// This code example is created for educational purpose
// by Thorsten Thormaehlen (contact: www.thormae.de).
// It is distributed without any warranty.

import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

import javax.imageio.ImageIO;
import com.jogamp.opengl.GL3;
import com.jogamp.opengl.GLAutoDrawable;
import com.jogamp.opengl.GLCapabilities;
import com.jogamp.opengl.GLEventListener;
import com.jogamp.opengl.GLProfile;
import com.jogamp.opengl.awt.GLCanvas;
import javax.swing.JFrame;

import com.jogamp.common.nio.Buffers;
import com.jogamp.opengl.util.FPSAnimator;

/**
 * Renders a textured, vertex-coloured pyramid with an OpenGL 3 pipeline:
 * one VAO, one interleaved VBO and a GLSL program loaded from
 * "./texture.vert" and "./texture.frag".
 *
 * <p>The methods taking a {@link GLAutoDrawable} fetch the GL3 interface from
 * it and issue GL calls, so they are intended to be called from the JOGL
 * {@code init}/{@code reshape}/{@code display} callbacks.
 */
class Renderer {

  /** Camera orbit angle in degrees; advanced by the caller between frames. */
  public float t;

  private enum VAOs {Pyramid, numVAOs};
  private enum VBOs {PyramidAll, numVBOs};

  // number of floats per attribute in the interleaved vertex layout
  private static final int POSITION_SIZE = 3;
  private static final int COLOR_SIZE = 4;
  private static final int TEXCOORD_SIZE = 2;
  private static final int NORMAL_SIZE = 3;
  private static final int FLOATS_PER_VERTEX =
      POSITION_SIZE + COLOR_SIZE + TEXCOORD_SIZE + NORMAL_SIZE;

  private int[] vaoID  = new int[VAOs.numVAOs.ordinal()];
  private int[] bufID = new int[VBOs.numVBOs.ordinal()];
  private int pyramidVertNo = 0;
  private int texID = 0;
  private int progID = 0;
  private int vertID = 0;
  private int fragID = 0;
  private int vertexLoc = 0;
  private int colorLoc = 0;
  private int texCoordLoc = 0;
  private int normalLoc = 0;
  private int projectionLoc = 0;
  private int modelviewLoc = 0;
  private int texLoc = 0;
  private float[] projection = new float[16];
  private float[] modelview = new float[16];

  /**
   * Builds the shader program, uploads the pyramid geometry into a VBO,
   * records its attribute layout in a VAO and loads the texture.
   *
   * @param d drawable providing the GL context
   */
  public void init(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    gl.glEnable(GL3.GL_DEPTH_TEST);

    setupShaders(d);

    // create a Vertex Array Object (VAO)
    gl.glGenVertexArrays(VAOs.numVAOs.ordinal(), vaoID, 0);

    // generate a Vertex Buffer Object (VBO)
    gl.glGenBuffers(VBOs.numVBOs.ordinal(), bufID, 0);

    // bind the pyramid VAO; the attribute setup below is recorded in it
    gl.glBindVertexArray(vaoID[VAOs.Pyramid.ordinal()]);

    // interleaved per-vertex layout:
    // position (x,y,z) | color (r,g,b,a) | texCoord (s,t) | normal (nx,ny,nz)
    float pyramidVertexData[] = {
        0.0f, 0.0f, 2.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.5f, 1.0f, 0.0000f,-0.9701f, 0.2425f,
       -0.5f,-0.5f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0000f,-0.9701f, 0.2425f,
        0.5f,-0.5f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0000f,-0.9701f, 0.2425f,
        0.0f, 0.0f, 2.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.5f, 1.0f, 0.9701f, 0.0000f, 0.2425f,
        0.5f,-0.5f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.9701f, 0.0000f, 0.2425f,
        0.5f, 0.5f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.9701f, 0.0000f, 0.2425f,
        0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.5f, 1.0f, 0.0000f, 0.9701f, 0.2425f,
        0.5f, 0.5f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0000f, 0.9701f, 0.2425f,
       -0.5f, 0.5f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0000f, 0.9701f, 0.2425f,
        0.0f, 0.0f, 2.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.5f, 1.0f,-0.9701f, 0.0000f, 0.2425f,
       -0.5f, 0.5f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f,-0.9701f, 0.0000f, 0.2425f,
       -0.5f,-0.5f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f,-0.9701f, 0.0000f, 0.2425f
    };

    // derive the vertex count from the data instead of hard-coding it
    pyramidVertNo = pyramidVertexData.length / FLOATS_PER_VERTEX;

    FloatBuffer pyramidVertexFB = Buffers.newDirectFloatBuffer(pyramidVertexData.length);
    pyramidVertexFB.put(pyramidVertexData);
    pyramidVertexFB.flip();

    gl.glBindBuffer(GL3.GL_ARRAY_BUFFER, bufID[VBOs.PyramidAll.ordinal()]);
    gl.glBufferData(GL3.GL_ARRAY_BUFFER, pyramidVertexFB.capacity() * Buffers.SIZEOF_FLOAT,
                    pyramidVertexFB, GL3.GL_STATIC_DRAW);

    // byte stride and per-attribute byte offsets within one vertex
    int stride = FLOATS_PER_VERTEX * Buffers.SIZEOF_FLOAT;
    int positionOffset = 0;
    int colorOffset = positionOffset + POSITION_SIZE * Buffers.SIZEOF_FLOAT;
    int texCoordOffset = colorOffset + COLOR_SIZE * Buffers.SIZEOF_FLOAT;
    int normalOffset = texCoordOffset + TEXCOORD_SIZE * Buffers.SIZEOF_FLOAT;

    enableFloatAttrib(gl, vertexLoc, POSITION_SIZE, stride, positionOffset);
    enableFloatAttrib(gl, colorLoc, COLOR_SIZE, stride, colorOffset);
    enableFloatAttrib(gl, texCoordLoc, TEXCOORD_SIZE, stride, texCoordOffset);
    enableFloatAttrib(gl, normalLoc, NORMAL_SIZE, stride, normalOffset);

    texID = loadTexture(d, "checkerboard.png");
  }

  // Points attribute 'loc' at 'size' floats located 'offset' bytes into each
  // vertex of the bound GL_ARRAY_BUFFER and enables it. Skips attributes the
  // linked program does not use (glGetAttribLocation returned -1).
  private static void enableFloatAttrib(GL3 gl, int loc, int size, int stride, int offset) {
    if (loc == -1) {
      return;
    }
    gl.glVertexAttribPointer(loc, size, GL3.GL_FLOAT, false, stride, offset);
    gl.glEnableVertexAttribArray(loc);
  }

  /**
   * Updates the viewport and rebuilds the projection matrix for a new
   * window size.
   *
   * @param d drawable providing the GL context
   * @param w new width in pixels
   * @param h new height in pixels
   */
  public void resize(GLAutoDrawable d, int w, int h) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    gl.glViewport(0, 0, w, h);

    // a zero height would give an infinite aspect ratio and a degenerate
    // projection matrix, so clamp the divisor to 1
    float aspect = (float) w / (float) Math.max(h, 1);

    // this function replaces gluPerspective
    mat4Perspective(projection, 30.0f, aspect, 1.0f, 10.0f);
  }

  /**
   * Renders one frame: clears the buffers, recomputes the camera from
   * {@link #t}, uploads the matrices and draws the pyramid.
   *
   * @param d drawable providing the GL context
   */
  public void display(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3();  // get the OpenGL >= 3 graphics context

    gl.glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
    gl.glClear(GL3.GL_COLOR_BUFFER_BIT | GL3.GL_DEPTH_BUFFER_BIT);

    // camera orbits on a circle of radius 5 in the z=5 plane
    // and looks at the point (0, 0, 0.5);
    // mat4LookAt replaces gluLookAt
    double rad = Math.toRadians(t);
    mat4LookAt(modelview,
               5.0f * (float) Math.cos(rad), 5.0f * (float) Math.sin(rad), 5.0f, // eye
               0.0f, 0.0f, 0.5f, // look at
               0.0f, 0.0f, 1.0f); // up

    gl.glUseProgram(progID);

    // load the current projection and modelview matrix into the
    // corresponding UNIFORM variables of the shader
    gl.glUniformMatrix4fv(projectionLoc, 1, false, projection, 0);
    gl.glUniformMatrix4fv(modelviewLoc, 1, false, modelview, 0);

    // activate texture unit 0
    gl.glActiveTexture(GL3.GL_TEXTURE0);
    // bind texture
    gl.glBindTexture(GL3.GL_TEXTURE_2D, texID);
    // inform the shader to use texture unit 0
    gl.glUniform1i(texLoc, 0);

    // bind pyramid VAO
    gl.glBindVertexArray(vaoID[VAOs.Pyramid.ordinal()]);
    // render data
    gl.glDrawArrays(GL3.GL_TRIANGLES, 0, pyramidVertNo);

    gl.glFlush();
  }

  /**
   * Compiles "./texture.vert" and "./texture.frag", links them into a program
   * and looks up the attribute and uniform locations used by the renderer.
   * Compile and link logs are printed to stderr.
   *
   * @param d drawable providing the GL context
   */
  public void  setupShaders(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context

    vertID = gl.glCreateShader(GL3.GL_VERTEX_SHADER);
    fragID = gl.glCreateShader(GL3.GL_FRAGMENT_SHADER);

    String[] vs = loadShaderSrc("./texture.vert");
    String[] fs = loadShaderSrc("./texture.frag");

    gl.glShaderSource(vertID, 1, vs, null, 0);
    gl.glShaderSource(fragID, 1, fs, null, 0);

    // compile the shaders
    gl.glCompileShader(vertID);
    gl.glCompileShader(fragID);

    // check for errors
    printShaderInfoLog(d, vertID);
    printShaderInfoLog(d, fragID);

    // create program and attach shaders
    progID = gl.glCreateProgram();
    gl.glAttachShader(progID, vertID);
    gl.glAttachShader(progID, fragID);

    // "outputColor" is a user-provided OUT variable
    // of the fragment shader.
    // Its output is bound to the first color buffer
    // in the framebuffer
    gl.glBindFragDataLocation(progID, 0, "outputColor");

    // link the program
    gl.glLinkProgram(progID);
    // output error messages
    printProgramInfoLog(d, progID);

    // "inputPosition", "inputColor", "inputTexCoord" and "inputNormal" are
    // user-provided IN variables of the vertex shader. Their locations are
    // stored for the attribute setup in init(); -1 means the attribute is
    // not active in the linked program.
    vertexLoc = gl.glGetAttribLocation(progID, "inputPosition");
    colorLoc = gl.glGetAttribLocation(progID, "inputColor");
    texCoordLoc = gl.glGetAttribLocation(progID, "inputTexCoord");
    normalLoc = gl.glGetAttribLocation(progID, "inputNormal");

    // "projection", "modelview" and "myTexture" are user-provided
    // UNIFORM variables of the shader program.
    // Their locations are stored to be used later in display()
    projectionLoc = gl.glGetUniformLocation(progID, "projection");
    modelviewLoc = gl.glGetUniformLocation(progID, "modelview");
    texLoc = gl.glGetUniformLocation(progID, "myTexture");
  }

  // Reads a text file into a single-element array (the shape glShaderSource
  // expects). Each line is terminated with '\n'. On I/O failure the error is
  // reported and whatever was read so far (possibly "") is returned.
  private String[] loadShaderSrc(String name) {
    StringBuilder sb = new StringBuilder();
    // try-with-resources closes the reader even when readLine() throws
    try (BufferedReader br = new BufferedReader(new FileReader(name))) {
      String line;
      while ((line = br.readLine()) != null) {
        sb.append(line).append('\n');
      }
    } catch (IOException e) {
      System.err.println("Could not read shader source " + name);
      e.printStackTrace();
    }
    return new String[] {sb.toString()};
  }

  // Prints the compile log of shader 'obj' to stderr, if it has one.
  private void printShaderInfoLog(GLAutoDrawable d, int obj) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    IntBuffer lengthBuf = IntBuffer.allocate(1);
    gl.glGetShaderiv(obj, GL3.GL_INFO_LOG_LENGTH, lengthBuf);
    int infoLogLength = lengthBuf.get(0);
    if (infoLogLength > 0) {
      ByteBuffer log = ByteBuffer.allocate(infoLogLength);
      gl.glGetShaderInfoLog(obj, infoLogLength, lengthBuf, log);
      printInfoLog(log, lengthBuf.get(0));
    }
  }

  // Prints the link log of program 'obj' to stderr, if it has one.
  private void printProgramInfoLog(GLAutoDrawable d, int obj) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    IntBuffer lengthBuf = IntBuffer.allocate(1);
    gl.glGetProgramiv(obj, GL3.GL_INFO_LOG_LENGTH, lengthBuf);
    int infoLogLength = lengthBuf.get(0);
    if (infoLogLength > 0) {
      ByteBuffer log = ByteBuffer.allocate(infoLogLength);
      gl.glGetProgramInfoLog(obj, infoLogLength, lengthBuf, log);
      printInfoLog(log, lengthBuf.get(0));
    }
  }

  // Writes the first 'written' bytes of an info log to stderr. 'written' is the
  // length reported back by glGet*InfoLog, which excludes the NUL terminator,
  // so the terminator is no longer echoed. The count is clamped to the buffer.
  private static void printInfoLog(ByteBuffer log, int written) {
    byte[] bytes = log.array();
    int n = Math.min(Math.max(written, 0), bytes.length);
    StringBuilder sb = new StringBuilder(n);
    for (int i = 0; i < n; i++) {
      sb.append((char) bytes[i]);
    }
    System.err.print(sb);
  }

  // Loads an image file into a new 2D texture (linear filtering, no mipmaps).
  // Returns a valid textureID on success, otherwise 0 (missing file,
  // read error or an image format ImageIO cannot decode).
  private int loadTexture(GLAutoDrawable d, String filename) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context

    int level = 0;
    int border = 0;

    BufferedImage bufferedImage;
    // try-with-resources closes the stream even when decoding throws
    try (FileInputStream fileInputStream = new FileInputStream(new File(filename))) {
      bufferedImage = ImageIO.read(fileInputStream);
    } catch (FileNotFoundException e) {
      System.out.println("Can not find texture data file " + filename);
      return 0;
    } catch (IOException e) {
      e.printStackTrace();
      return 0;
    }
    // ImageIO.read returns null (rather than throwing) when no registered
    // reader understands the file format
    if (bufferedImage == null) {
      System.out.println("Unsupported texture image format " + filename);
      return 0;
    }

    int width = bufferedImage.getWidth();
    int height = bufferedImage.getHeight();
    ByteBuffer buffer = toFlippedRgbaBuffer(bufferedImage);

    // data is aligned in byte order
    gl.glPixelStorei(GL3.GL_UNPACK_ALIGNMENT, 1);

    // request textureID
    final int[] textureID = new int[1];
    gl.glGenTextures(1, textureID, 0);

    // bind texture
    gl.glBindTexture(GL3.GL_TEXTURE_2D, textureID[0]);

    // define how to filter the texture
    gl.glTexParameteri(GL3.GL_TEXTURE_2D, GL3.GL_TEXTURE_MAG_FILTER,
        GL3.GL_LINEAR);
    gl.glTexParameteri(GL3.GL_TEXTURE_2D, GL3.GL_TEXTURE_MIN_FILTER,
        GL3.GL_LINEAR);

    // specify the 2D texture map; RGBA bytes are supplied but the GL_RGB
    // internal format means the alpha channel is not stored
    gl.glTexImage2D(GL3.GL_TEXTURE_2D, level, GL3.GL_RGB, width, height,
        border, GL3.GL_RGBA, GL3.GL_UNSIGNED_BYTE, buffer);

    return textureID[0];
  }

  // Converts an image into a direct buffer of tightly packed RGBA bytes,
  // flipping it vertically because the texture origin in OpenGL is the
  // lower-left corner while image rows start at the top.
  private static ByteBuffer toFlippedRgbaBuffer(BufferedImage image) {
    int width = image.getWidth();
    int height = image.getHeight();

    // getRGB returns packed 0xAARRGGBB pixels
    int[] argb = new int[width * height];
    image.getRGB(0, 0, width, height, argb, 0, width);

    ByteBuffer buffer = ByteBuffer.allocateDirect(argb.length * 4);
    buffer.order(ByteOrder.nativeOrder());
    for (int y = 0; y < height; y++) {
      int k = (height - 1 - y) * width;
      for (int x = 0; x < width; x++, k++) {
        int p = argb[k];
        buffer.put((byte) (p >>> 16)); // R
        buffer.put((byte) (p >>> 8));  // G
        buffer.put((byte) p);          // B
        buffer.put((byte) (p >>> 24)); // A
      }
    }
    buffer.rewind();
    return buffer;
  }


  // the following functions are some matrix and vector helpers
  // they work for this demo but in general it is recommended
  // to use more advanced matrix libraries.
  // Matrices are float[16] in OpenGL column-major order (index = col*4 + row).

  // dot product of two 3-vectors
  private float vec3Dot(float[] a, float[] b) {
    return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
  }

  // res = a x b; res must not alias a or b
  private void vec3Cross(float[] a, float[] b, float[] res) {
    res[0] = a[1] * b[2]  -  b[1] * a[2];
    res[1] = a[2] * b[0]  -  b[2] * a[0];
    res[2] = a[0] * b[1]  -  b[0] * a[1];
  }

  // normalizes a in place; a zero vector is left unchanged instead of
  // being turned into NaNs by a division by zero
  private void vec3Normalize(float[] a) {
    float mag = (float) Math.sqrt(a[0] * a[0]  +  a[1] * a[1]  +  a[2] * a[2]);
    if (mag == 0.0f) {
      return;
    }
    a[0] /= mag; a[1] /= mag; a[2] /= mag;
  }

  // sets a to the 4x4 identity matrix
  private void mat4Identity(float[] a) {
    for (int i = 0; i < 16; ++i) a[i] = 0.0f;
    for (int i = 0; i < 4; ++i) a[i + i * 4] = 1.0f;
  }

  // res = a * b (column-major); res must not alias a or b
  private void mat4Multiply(float[] a, float[] b, float[] res) {
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        res[j*4 + i] = 0.0f;
        for (int k = 0; k < 4; ++k) {
          res[j*4 + i] += a[k*4 + i] * b[j*4 + k];
        }
      }
    }
  }

  // writes a gluPerspective-style projection into a;
  // fov is the vertical field of view in degrees
  private void mat4Perspective(float[] a, float fov, float aspect, float zNear, float zFar) {
    float f = 1.0f / (float) (Math.tan (fov/2.0f * (Math.PI / 180.0f)));
    mat4Identity(a);
    a[0] = f / aspect;
    a[1 * 4 + 1] = f;
    a[2 * 4 + 2] = (zFar + zNear)  / (zNear - zFar);
    a[3 * 4 + 2] = (2.0f * zFar * zNear) / (zNear - zFar);
    a[2 * 4 + 3] = -1.0f;
    a[3 * 4 + 3] = 0.0f;
  }

  // writes a gluLookAt-style view matrix into viewMatrix
  private void mat4LookAt(float[] viewMatrix,
      float eyeX, float eyeY, float eyeZ,
      float centerX, float centerY, float centerZ,
      float upX, float upY, float upZ) {

    float dir[] = new float[3];
    float right[] = new float[3];
    float up[] = new float[3];
    float eye[] = new float[3];

    up[0]=upX; up[1]=upY; up[2]=upZ;
    eye[0]=eyeX; eye[1]=eyeY; eye[2]=eyeZ;

    // build an orthonormal camera basis from the viewing direction and up hint
    dir[0]=centerX-eyeX; dir[1]=centerY-eyeY; dir[2]=centerZ-eyeZ;
    vec3Normalize(dir);
    vec3Cross(dir,up,right);
    vec3Normalize(right);
    vec3Cross(right,dir,up);
    vec3Normalize(up);
    // first row
    viewMatrix[0]  = right[0];
    viewMatrix[4]  = right[1];
    viewMatrix[8]  = right[2];
    viewMatrix[12] = -vec3Dot(right, eye);
    // second row
    viewMatrix[1]  = up[0];
    viewMatrix[5]  = up[1];
    viewMatrix[9]  = up[2];
    viewMatrix[13] = -vec3Dot(up, eye);
    // third row (camera looks down its -z axis)
    viewMatrix[2]  = -dir[0];
    viewMatrix[6]  = -dir[1];
    viewMatrix[10] = -dir[2];
    viewMatrix[14] =  vec3Dot(dir, eye);
    // fourth row
    viewMatrix[3]  = 0.0f;
    viewMatrix[7]  = 0.0f;
    viewMatrix[11] = 0.0f;
    viewMatrix[15] = 1.0f;
  }

  // debugging helper: prints a column-major matrix row by row to stdout
  private void mat4Print(float[] a) {
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        System.out.print( a[j * 4 + i] + " ");
      }
      System.out.println(" ");
    }
  }


}

/**
 * Main window: hosts a {@link GLCanvas} driven by an {@link FPSAnimator}
 * and forwards the JOGL callbacks to a {@link Renderer}.
 */
class MyGui extends JFrame implements GLEventListener {

  // amount added to Renderer.t (the camera angle) on every display callback
  private static final float ANGLE_STEP_PER_FRAME = 1.0f;

  // Created together with the frame so it already exists when this object is
  // registered as a GLEventListener. Previously it was assigned only after
  // the listener was added and the frame shown, leaving a window in which a
  // GL callback could see a null renderer.
  private final Renderer renderer = new Renderer();

  /**
   * Builds the window, attaches the GL canvas and starts the animator
   * (target 60 FPS).
   */
  public void createGUI() {
    setTitle("Accessing textures in the shader");
    setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

    GLProfile glp = GLProfile.getDefault();
    GLCapabilities caps = new GLCapabilities(glp);
    GLCanvas canvas = new GLCanvas(caps);
    setSize(320, 320);

    getContentPane().add(canvas);
    final FPSAnimator ani = new FPSAnimator(canvas, 60, true);
    canvas.addGLEventListener(this);
    setVisible(true);

    ani.start();
  }

  @Override
  public void init(GLAutoDrawable d) {
    renderer.init(d);
  }

  @Override
  public void reshape(GLAutoDrawable d, int x, int y, int width, int height) {
    renderer.resize(d, width, height);
  }

  @Override
  public void display(GLAutoDrawable d) {
    // advance the animation, then draw the frame
    renderer.t += ANGLE_STEP_PER_FRAME;
    renderer.display(d);
  }

  @Override
  public void dispose(GLAutoDrawable d) {
    // nothing to release here; the process exits when the window closes
    // (EXIT_ON_CLOSE)
  }
}

public class ShaderTexture {
  public static void main(String[] args) {
    javax.swing.SwingUtilities.invokeLater(new Runnable() {
      public void run() {
        MyGui myGUI = new MyGui();
        myGUI.createGUI();
      }
    });
  }
}
