Jogamp
67c49578349468bd9c4a1f04bba54ac0ab225935
[jocl-demos.git] / src / com / jogamp / opencl / demos / fractal / MultiDeviceFractal.java
1 package com.jogamp.opencl.demos.fractal;
2
3 import com.jogamp.opencl.CLBuffer;
4 import com.jogamp.opencl.CLCommandQueue;
5 import com.jogamp.opencl.CLDevice;
6 import com.jogamp.opencl.CLEvent;
7 import com.jogamp.opencl.CLEventList;
8 import com.jogamp.opencl.CLException;
9 import com.jogamp.opencl.gl.CLGLBuffer;
10 import com.jogamp.opencl.gl.CLGLContext;
11 import com.jogamp.opencl.CLKernel;
12 import com.jogamp.opencl.CLPlatform;
13 import com.jogamp.opencl.CLProgram;
14 import com.jogamp.opencl.CLProgram.CompilerOptions;
15 import com.jogamp.opencl.util.CLProgramConfiguration;
16 import com.jogamp.opengl.util.awt.TextRenderer;
17 import java.awt.Color;
18 import java.awt.Dimension;
19 import java.awt.Font;
20 import java.awt.Frame;
21 import java.awt.Point;
22 import java.awt.Window;
23 import java.awt.event.KeyAdapter;
24 import java.awt.event.KeyEvent;
25 import java.awt.event.MouseAdapter;
26 import java.awt.event.MouseEvent;
27 import java.awt.event.MouseWheelEvent;
28 import java.awt.event.WindowAdapter;
29 import java.awt.event.WindowEvent;
30 import java.io.IOException;
31 import java.nio.IntBuffer;
32 import java.util.logging.Level;
33 import java.util.logging.Logger;
34 import javax.media.opengl.DebugGL2;
35 import javax.media.opengl.GL;
36 import javax.media.opengl.GL2;
37 import javax.media.opengl.GLAutoDrawable;
38 import javax.media.opengl.GLCapabilities;
39 import javax.media.opengl.GLContext;
40 import javax.media.opengl.GLEventListener;
41 import javax.media.opengl.GLProfile;
42 import javax.media.opengl.awt.GLCanvas;
43 import javax.swing.SwingUtilities;
44
45 import static com.jogamp.common.nio.Buffers.*;
46 import static javax.media.opengl.GL2.*;
47 import static com.jogamp.opencl.CLMemory.Mem.*;
48 import static com.jogamp.opencl.CLDevice.Type.*;
49 import static com.jogamp.opencl.CLEvent.ProfilingCommand.*;
50 import static com.jogamp.opencl.CLCommandQueue.Mode.*;
51 import static java.lang.Math.*;
52
53 /**
54  * Computes the Mandelbrot set with OpenCL using multiple GPUs and renders the result with OpenGL.
55  * A shared PBO is used as storage for the fractal image.<br/>
56  * http://en.wikipedia.org/wiki/Mandelbrot_set
57  * <p>
58  * controls:<br/>
59  * keys 1-9 control parallelism level<br/>
60  * space enables/disables slice seperator<br/>
61  * 'd' toggles between 32/64bit floatingpoint precision<br/>
62  * mouse/mousewheel to drag and zoom<br/>
63  * </p>
64  * @author Michael Bien
65  */
66 public class MultiDeviceFractal implements GLEventListener {
67
68     // max number of used GPUs
69     private static final int MAX_PARRALLELISM_LEVEL = 8;
70
71     // max per pixel iterations to compute the fractal
72     private static final int MAX_ITERATIONS         = 500;
73
74     private GLCanvas canvas;
75
76     private CLGLContext clContext;
77     private CLCommandQueue[] queues;
78     private CLKernel[] kernels;
79     private CLProgram[] programs;
80     private CLEventList probes;
81     private CLGLBuffer<?>[] pboBuffers;
82     private CLBuffer<IntBuffer>[] colorMap;
83
84     private int width  = 0;
85     private int height = 0;
86
87     private double minX = -2f;
88     private double minY = -1.2f;
89     private double maxX  = 0.6f;
90     private double maxY  = 1.3f;
91
92     private int slices;
93
94     private boolean drawSeperator;
95     private boolean doublePrecision;
96     private boolean buffersInitialized;
97     private boolean rebuild;
98
99     private final TextRenderer textRenderer;
100
101     public MultiDeviceFractal(int width, int height) {
102
103         this.width = width;
104         this.height = height;
105
106         canvas = new GLCanvas(new GLCapabilities(GLProfile.get(GLProfile.GL2)));
107         canvas.addGLEventListener(this);
108         initSceneInteraction();
109
110         Frame frame = new Frame("JOCL Multi Device Mandelbrot Set");
111         frame.addWindowListener(new WindowAdapter() {
112             @Override
113             public void windowClosing(WindowEvent e) {
114                 MultiDeviceFractal.this.release(e.getWindow());
115             }
116         }); 
117         canvas.setPreferredSize(new Dimension(width, height));
118         frame.add(canvas);
119         frame.pack();
120
121         frame.setVisible(true);
122
123         textRenderer = new TextRenderer(frame.getFont().deriveFont(Font.BOLD, 14), true, true, null, false);
124     }
125
126     @Override
127     public void init(GLAutoDrawable drawable) {
128
129         if(clContext == null) {
130             // enable GL error checking using the composable pipeline
131             drawable.setGL(new DebugGL2(drawable.getGL().getGL2()));
132
133             drawable.getGL().glFinish();
134             initCL(drawable.getContext());
135
136             GL2 gl = drawable.getGL().getGL2();
137
138             gl.setSwapInterval(0);
139             gl.glDisable(GL_DEPTH_TEST);
140             gl.glClearColor(0.0f, 0.0f, 0.0f, 1.0f);
141
142             initView(gl, drawable.getWidth(), drawable.getHeight());
143
144             initPBO(gl);
145             drawable.getGL().glFinish();
146
147             setKernelConstants();
148         }
149     }
150
151     private void initCL(GLContext glCtx){
152         try {
153             CLPlatform platform = CLPlatform.getDefault();
154             // SLI on NV platform wasn't very fast (or did not work at all -> CL_INVALID_OPERATION)
155             if(platform.getICDSuffix().equals("NV")) {
156                 clContext = CLGLContext.create(glCtx, platform.getMaxFlopsDevice(GPU));
157             }else{
158                 clContext = CLGLContext.create(glCtx, platform, ALL);
159             }
160             CLDevice[] devices = clContext.getDevices();
161
162             slices = min(devices.length, MAX_PARRALLELISM_LEVEL);
163
164             // create command queues for every GPU, setup colormap and init kernels
165             queues = new CLCommandQueue[slices];
166             kernels = new CLKernel[slices];
167             probes = new CLEventList(slices);
168             colorMap = new CLBuffer[slices];
169
170             for (int i = 0; i < slices; i++) {
171
172                 colorMap[i] = clContext.createIntBuffer(32*2, READ_ONLY);
173                 initColorMap(colorMap[i].getBuffer(), 32, Color.BLUE, Color.GREEN, Color.RED);
174
175                 // create command queue and upload color map buffer on each used device
176                 queues[i] = devices[i].createCommandQueue(PROFILING_MODE).putWriteBuffer(colorMap[i], true); // blocking upload
177
178             }
179
180             // check if we have 64bit FP support on all devices
181             // if yes we can use only one program for all devices + one kernel per device.
182             // if not we will have to create (at least) one program for 32 and one for 64bit devices.
183             // since there are different vendor extensions for double FP we use one program per device.
184             // (OpenCL spec is not very clear about this usecases)
185             boolean all64bit = true;
186             for (CLDevice device : devices) {
187                 if(!isDoubleFPAvailable(device)) {
188                     all64bit = false;
189                     break;
190                 }
191             }
192
193             // load program(s)
194             if(all64bit) {
195                 programs = new CLProgram[] {
196                     clContext.createProgram(getClass().getResourceAsStream("Mandelbrot.cl"))
197                 };
198             }else{
199                 programs = new CLProgram[slices];
200                 for (int i = 0; i < slices; i++) {
201                     programs[i] = clContext.createProgram(getClass().getResourceAsStream("Mandelbrot.cl"));
202                 }
203             }
204
205             buildProgram();
206
207         } catch (IOException ex) {
208             Logger.getLogger(getClass().getName()).log(Level.SEVERE, "can not find 'Mandelbrot.cl' in classpath.", ex);
209             if(clContext != null) {
210                 clContext.release();
211             }
212         } catch (CLException ex) {
213             Logger.getLogger(getClass().getName()).log(Level.SEVERE, "something went wrong, hopefully nobody got hurt", ex);
214             if(clContext != null) {
215                 clContext.release();
216             }
217         }
218
219     }
220
221     private void initColorMap(IntBuffer colorMap, int stepSize, Color... colors) {
222         
223         for (int n = 0; n < colors.length - 1; n++) {
224
225             Color color = colors[n];
226             int r0 = color.getRed();
227             int g0 = color.getGreen();
228             int b0 = color.getBlue();
229
230             color = colors[n + 1];
231             int r1 = color.getRed();
232             int g1 = color.getGreen();
233             int b1 = color.getBlue();
234
235             int deltaR = r1 - r0;
236             int deltaG = g1 - g0;
237             int deltaB = b1 - b0;
238
239             for (int step = 0; step < stepSize; step++) {
240                 float alpha = (float) step / (stepSize - 1);
241                 int r = (int) (r0 + alpha * deltaR);
242                 int g = (int) (g0 + alpha * deltaG);
243                 int b = (int) (b0 + alpha * deltaB);
244                 colorMap.put((r << 16) | (g << 8) | (b << 0));
245             }
246         }
247         colorMap.rewind();
248
249     }
250
251     private void initView(GL2 gl, int width, int height) {
252
253         gl.glViewport(0, 0, width, height);
254
255         gl.glMatrixMode(GL_MODELVIEW);
256         gl.glLoadIdentity();
257
258         gl.glMatrixMode(GL_PROJECTION);
259         gl.glLoadIdentity();
260         gl.glOrtho(0.0, width, 0.0, height, 0.0, 1.0);
261     }
262
263     @SuppressWarnings("unchecked")
264     private void initPBO(GL gl) {
265
266         if(pboBuffers != null) {
267             int[] oldPbos = new int[pboBuffers.length];
268             for (int i = 0; i < pboBuffers.length; i++) {
269                 CLGLBuffer<?> buffer = pboBuffers[i];
270                 oldPbos[i] = buffer.GLID;
271                 buffer.release();
272             }
273             gl.glDeleteBuffers(oldPbos.length, oldPbos, 0);
274         }
275
276         pboBuffers = new CLGLBuffer[slices];
277
278         int[] pbo = new int[slices];
279         gl.glGenBuffers(slices, pbo, 0);
280
281         // setup one empty PBO per slice
282         for (int i = 0; i < slices; i++) {
283
284             final int size = width*height * SIZEOF_INT / slices ;
285             gl.glBindBuffer(GL_PIXEL_UNPACK_BUFFER, pbo[i]);
286             gl.glBufferData(GL_PIXEL_UNPACK_BUFFER, size, null, GL_STREAM_DRAW);
287             gl.glBindBuffer(GL_PIXEL_UNPACK_BUFFER, 0);
288
289             pboBuffers[i] = clContext.createFromGLBuffer(pbo[i], size, WRITE_ONLY);
290         }
291
292         buffersInitialized = true;
293     }
294
295     private void buildProgram() {
296
297         /*
298          * workaround: The driver keeps using the old binaries for some reason.
299          * to solve this we simple create a new program and release the old.
300          * however rebuilding programs should be possible -> remove when drivers are fixed.
301          * (again: the spec is not very clear about this kind of usages)
302          */
303         if(programs[0] != null && rebuild) {
304             for(int i = 0; i < programs.length; i++) {
305                 String source = programs[i].getSource();
306                 programs[i].release();
307                 programs[i] = clContext.createProgram(source);
308             }
309         }
310
311         // disable 64bit floating point math if not available
312         for(int i = 0; i < programs.length; i++) {
313             CLDevice device = queues[i].getDevice();
314
315             CLProgramConfiguration configure = programs[i].prepare();
316             if(doublePrecision && isDoubleFPAvailable(device)) {
317                 //cl_khr_fp64
318                 configure.withDefine("DOUBLE_FP");
319
320                 //amd's verson of double precision floating point math
321                 if(!device.isDoubleFPAvailable() && device.isExtensionAvailable("cl_amd_fp64")) {
322                     configure.withDefine("AMD_FP");
323                 }
324             }
325             if(programs.length > 1) {
326                 configure.forDevice(device);
327             }
328             System.out.println(configure);
329             configure.withOption(CompilerOptions.FAST_RELAXED_MATH).build();
330          }
331
332         rebuild = false;
333
334         for (int i = 0; i < kernels.length; i++) {
335             // init kernel with constants
336             kernels[i] = programs[min(i, programs.length)].createCLKernel("mandelbrot");
337         }
338
339     }
340
341     // init kernels with constants
342     private void setKernelConstants() {
343         for (int i = 0; i < slices; i++) {
344             kernels[i].setForce32BitArgs(!doublePrecision || !isDoubleFPAvailable(queues[i].getDevice()))
345                       .setArg(6, pboBuffers[i])
346                       .setArg(7, colorMap[i])
347                       .setArg(8, colorMap[i].getBuffer().capacity())
348                       .setArg(9, MAX_ITERATIONS);
349         }
350     }
351
352     // rendering cycle
353     @Override
354     public void display(GLAutoDrawable drawable) {
355         GL gl = drawable.getGL();
356
357         // make sure GL does not use our objects before we start computeing
358         gl.glFinish();
359         if(!buffersInitialized) {
360             initPBO(gl);
361             setKernelConstants();
362         }
363         if(rebuild) {
364             buildProgram();
365             setKernelConstants();
366         }
367         compute();
368
369         render(gl.getGL2());
370     }
371
372     // OpenCL
373     private void compute() {
374
375         int sliceWidth = (int)(width / (float)slices);
376         double rangeX  = (maxX - minX) / slices;
377         double rangeY  = (maxY - minY);
378
379         // release all old events, you can't reuse events in OpenCL
380         probes.release();
381
382         // start computation
383         for (int i = 0; i < slices; i++) {
384
385             kernels[i].putArg(     sliceWidth).putArg(height)
386                       .putArg(minX + rangeX*i).putArg(  minY)
387                       .putArg(       rangeX  ).putArg(rangeY)
388                       .rewind();
389
390             // aquire GL objects, and enqueue a kernel with a probe from the list
391             queues[i].putAcquireGLObject(pboBuffers[i])
392                      .put2DRangeKernel(kernels[i], 0, 0, sliceWidth, height, 0, 0, probes)
393                      .putReleaseGLObject(pboBuffers[i]);
394
395         }
396
397         // block until done (important: finish before doing further gl work)
398         for (int i = 0; i < slices; i++) {
399             queues[i].finish();
400         }
401
402     }
403
404     // OpenGL
405     private void render(GL2 gl) {
406
407         gl.glClear(GL_COLOR_BUFFER_BIT);
408
409         //draw slices
410         int sliceWidth = width / slices;
411
412         for (int i = 0; i < slices; i++) {
413
414             int seperatorOffset = drawSeperator?i:0;
415
416             gl.glBindBuffer(GL_PIXEL_UNPACK_BUFFER, pboBuffers[i].GLID);
417             gl.glRasterPos2i(sliceWidth*i + seperatorOffset, 0);
418
419             gl.glDrawPixels(sliceWidth, height, GL_BGRA, GL_UNSIGNED_BYTE, 0);
420
421         }
422         gl.glBindBuffer(GL_PIXEL_UNPACK_BUFFER, 0);
423
424         //draw info text
425         textRenderer.beginRendering(width, height, false);
426
427             textRenderer.draw("device/time/precision", 10, height-15);
428
429             for (int i = 0; i < slices; i++) {
430                 CLDevice device = queues[i].getDevice();
431                 boolean doubleFP = doublePrecision && isDoubleFPAvailable(device);
432                 CLEvent event = probes.getEvent(i);
433                 long start = event.getProfilingInfo(START);
434                 long end = event.getProfilingInfo(END);
435                 textRenderer.draw(device.getType().toString()+i +" "
436                                + (int)((end-start)/1000000.0f)+"ms @"
437                                + (doubleFP?"64bit":"32bit"), 10, height-(20+16*(slices-i)));
438             }
439
440         textRenderer.endRendering();
441     }
442
443     @Override
444     public void reshape(GLAutoDrawable drawable, int x, int y, int width, int height) {
445
446         if(this.width == width && this.height == height)
447             return;
448
449         this.width = width;
450         this.height = height;
451
452         initPBO(drawable.getGL());
453         setKernelConstants();
454
455         initView(drawable.getGL().getGL2(), width, height);
456         
457     }
458
459     private void initSceneInteraction() {
460
461         MouseAdapter mouseAdapter = new MouseAdapter() {
462
463             Point lastpos = new Point();
464
465             @Override
466             public void mouseDragged(MouseEvent e) {
467                 
468                 double offsetX = (lastpos.x - e.getX()) * (maxX - minX) / width;
469                 double offsetY = (lastpos.y - e.getY()) * (maxY - minY) / height;
470
471                 minX += offsetX;
472                 minY -= offsetY;
473
474                 maxX += offsetX;
475                 maxY -= offsetY;
476
477                 lastpos = e.getPoint();
478
479                 canvas.display();
480
481             }
482
483             @Override
484             public void mouseMoved(MouseEvent e) {
485                 lastpos = e.getPoint();
486             }
487             
488             @Override
489             public void mouseWheelMoved(MouseWheelEvent e) {
490                 float rotation = e.getWheelRotation() / 25.0f;
491
492                 double deltaX = rotation * (maxX - minX);
493                 double deltaY = rotation * (maxY - minY);
494
495                 // offset for "zoom to cursor"
496                 double offsetX = (e.getX() / (float)width - 0.5f) * deltaX * 2;
497                 double offsetY = (e.getY() / (float)height- 0.5f) * deltaY * 2;
498
499                 minX += deltaX+offsetX;
500                 minY += deltaY-offsetY;
501
502                 maxX +=-deltaX+offsetX;
503                 maxY +=-deltaY-offsetY;
504
505                 canvas.display();
506             }
507         };
508
509         KeyAdapter keyAdapter = new KeyAdapter() {
510
511             @Override
512             public void keyPressed(KeyEvent e) {
513                 if(e.getKeyCode() == KeyEvent.VK_SPACE) {
514                     drawSeperator = !drawSeperator;
515                 }else if(e.getKeyChar() > '0' && e.getKeyChar() < '9') {
516                     int number = e.getKeyChar()-'0';
517                     slices = min(number, min(queues.length, MAX_PARRALLELISM_LEVEL));
518                     buffersInitialized = false;
519                 }else if(e.getKeyCode() == KeyEvent.VK_D) {
520                     doublePrecision = !doublePrecision;
521                     rebuild = true;
522                 }
523                 canvas.display();
524             }
525
526         };
527
528         canvas.addMouseMotionListener(mouseAdapter);
529         canvas.addMouseWheelListener(mouseAdapter);
530         canvas.addKeyListener(keyAdapter);
531     }
532
533
534     private boolean isDoubleFPAvailable(CLDevice device) {
535         return device.isDoubleFPAvailable() || device.isExtensionAvailable("cl_amd_fp64");
536     }
537
538     private void release(Window win) {
539         if(clContext != null) {
540             // releases all resources
541             clContext.release();
542         }
543         win.dispose();
544     }
545
546     @Override
547     public void dispose(GLAutoDrawable drawable) {
548     }
549
550     public static void main(String args[]) {
551         
552         //false for webstart compatibility
553         GLProfile.initSingleton(false);
554         
555         SwingUtilities.invokeLater(new Runnable() {
556             @Override public void run() {
557                 new MultiDeviceFractal(512, 512);
558             }
559         });
560     }
561
562 }
http://JogAmp.org git info: FAQ, tutorial and man pages.