// This code example was created for educational purposes
// by Thorsten Thormaehlen (contact: www.thormae.de).
// It is distributed without any warranty.

import java.awt.event.KeyAdapter;
import java.awt.event.KeyEvent;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.charset.StandardCharsets;

import javax.swing.JFrame;

import com.jogamp.common.nio.Buffers;
import com.jogamp.math.Matrix4f;
import com.jogamp.math.Vec3f;
import com.jogamp.opengl.GL3;
import com.jogamp.opengl.GLAutoDrawable;
import com.jogamp.opengl.GLCapabilities;
import com.jogamp.opengl.GLEventListener;
import com.jogamp.opengl.GLProfile;
import com.jogamp.opengl.awt.GLCanvas;
import com.jogamp.opengl.util.FPSAnimator;

class Renderer {

  public float t = 0.0f;
  public int modeVal = 1;

  private enum VAOs {Scene, numVAOs};
  private enum VBOs {SceneAll, numVBOs};
  private int[] vaoID = new int[VAOs.numVAOs.ordinal()];
  private int[] bufID = new int[VBOs.numVBOs.ordinal()];
  private int sceneVertNo = 0;
  private int progID = 0;
  private int vertID = 0;
  private int fragID = 0;
  private int vertexLoc = 0;
  private int texCoordLoc = 0;
  private int normalLoc = 0;
  private int projectionLoc = 0;
  private int modelviewLoc = 0;
  private int normalMatrixLoc = 0;
  private int modeLoc = 0;
  private Matrix4f projection = new Matrix4f();
  private Matrix4f modelview = new Matrix4f();

  public void init(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    gl.glEnable(GL3.GL_DEPTH_TEST);

    setupShaders(d);

    // create a Vertex Array Objects (VAO)
    gl.glGenVertexArrays(VAOs.numVAOs.ordinal(), vaoID, 0);

    // generate a Vertex Buffer Object (VBO)
    gl.glGenBuffers(VBOs.numVBOs.ordinal(), bufID, 0);

    // binding the Triangle VAO
    gl.glBindVertexArray(vaoID[VAOs.Scene.ordinal()]);

    int perVertexFloats = (3 + 2 + 3);
    float data[] = loadVertexData("./teapot.vbo", perVertexFloats);

    sceneVertNo = data.length / perVertexFloats;


    FloatBuffer sceneVertexFB = Buffers.newDirectFloatBuffer(data.length);
    sceneVertexFB.put(data);
    sceneVertexFB.flip();

    gl.glBindBuffer(GL3.GL_ARRAY_BUFFER, bufID[VBOs.SceneAll.ordinal()]);
    gl.glBufferData(GL3.GL_ARRAY_BUFFER, sceneVertexFB.capacity() * Buffers.SIZEOF_FLOAT,
            sceneVertexFB, GL3.GL_STATIC_DRAW);

    int stride = (3 + 2 + 3) * Buffers.SIZEOF_FLOAT;
    int offset = 0;

    // position
    if (vertexLoc != -1) {
      gl.glVertexAttribPointer(vertexLoc, 3, GL3.GL_FLOAT, false, stride, offset);
      gl.glEnableVertexAttribArray(vertexLoc);
    }

    // texCoord
    if (texCoordLoc != -1) {
      offset = 3 * Buffers.SIZEOF_FLOAT;
      gl.glVertexAttribPointer(texCoordLoc, 2, GL3.GL_FLOAT, false, stride, offset);
      gl.glEnableVertexAttribArray(texCoordLoc);
    }

    // normal
    if (normalLoc != -1) {
      offset = (3 + 2) * Buffers.SIZEOF_FLOAT;
      gl.glVertexAttribPointer(normalLoc, 3, GL3.GL_FLOAT, false, stride, offset);
      gl.glEnableVertexAttribArray(normalLoc);
    }
  }


  public void resize(GLAutoDrawable d, int w, int h) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    gl.glViewport(0, 0, w, h);

    // setToPerspective replaces gluPerspective
    projection.setToPerspective((float) Math.toRadians(45.0f), (float) w / (float) h, 0.5f, 4.0f);

  }

  public void display(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3();  // get the OpenGL >= 3 graphics context

    gl.glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
    gl.glClear(GL3.GL_COLOR_BUFFER_BIT | GL3.GL_DEPTH_BUFFER_BIT);

    // camera orbits in the z=1.5 plane
    // and looks at the origin
    // mat4LookAt replaces gluLookAt
    double rad = Math.PI / 180.0f * t;
    Vec3f eye = new Vec3f(1.5f * (float) Math.cos(rad), 1.5f * (float) Math.sin(rad), 1.5f);
    Vec3f center = new Vec3f(0.0f, 0.0f, 0.0f);
    Vec3f up = new Vec3f(0.0f, 0.0f, 1.0f);
    Matrix4f tempMatrix = new Matrix4f();
    // setToLookAt replaces gluLookAt
    modelview.setToLookAt(eye, center, up, tempMatrix);

    // rotational part of the modelview transformation
    Matrix4f normalmatrix = new Matrix4f(modelview);
    normalmatrix.invert();
    normalmatrix.transpose();

    gl.glUseProgram(progID);

    // load the current projection and modelview matrix into the
    // corresponding UNIFORM variables of the shader
    gl.glUniformMatrix4fv(projectionLoc, 1, false, projection.get(new float[16]), 0);
    gl.glUniformMatrix4fv(modelviewLoc, 1, false, modelview.get(new float[16]), 0);
    gl.glUniformMatrix4fv(normalMatrixLoc, 1, false, normalmatrix.get(new float[16]), 0);
    gl.glUniform1i(modeLoc, modeVal);


    // bind scene VAO
    gl.glBindVertexArray(vaoID[VAOs.Scene.ordinal()]);

    // render data
    gl.glDrawArrays(GL3.GL_TRIANGLES, 0, sceneVertNo);
  }

  public void dispose(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3();  // get the OpenGL >= 3 graphics context
    gl.glDeleteVertexArrays(VAOs.numVAOs.ordinal(), vaoID, 0);
    gl.glDeleteBuffers(VBOs.numVBOs.ordinal(), bufID, 0);
    gl.glDeleteProgram(progID);
    gl.glDeleteShader(vertID);
    gl.glDeleteShader(fragID);
  }

  private void setupShaders(GLAutoDrawable d) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context

    vertID = gl.glCreateShader(GL3.GL_VERTEX_SHADER);
    fragID = gl.glCreateShader(GL3.GL_FRAGMENT_SHADER);

    String[] vs = loadShaderSrc("./pass.vert");
    String[] fs = loadShaderSrc("./pass.frag");

    gl.glShaderSource(vertID, 1, vs, null, 0);
    gl.glShaderSource(fragID, 1, fs, null, 0);

    // compile the shader
    gl.glCompileShader(vertID);
    gl.glCompileShader(fragID);

    // check for errors
    printShaderInfoLog(d, vertID);
    printShaderInfoLog(d, fragID);

    // create program and attach shaders
    progID = gl.glCreateProgram();
    gl.glAttachShader(progID, vertID);
    gl.glAttachShader(progID, fragID);

    // "outColor" is a user-provided OUT variable
    // of the fragment shader.
    // Its output is bound to the first color buffer
    // in the framebuffer
    gl.glBindFragDataLocation(progID, 0, "outputColor");

    // link the program
    gl.glLinkProgram(progID);
    // output error messages
    printProgramInfoLog(d, progID);

    // retrieve the location of the IN variables of the vertex shader
    vertexLoc = gl.glGetAttribLocation(progID, "inputPosition");
    texCoordLoc = gl.glGetAttribLocation(progID, "inputTexCoord");
    normalLoc = gl.glGetAttribLocation(progID, "inputNormal");

    // retrieve the location of the UNIFORM variables of the vertex shader
    projectionLoc = gl.glGetUniformLocation(progID, "projection");
    modelviewLoc = gl.glGetUniformLocation(progID, "modelview");
    normalMatrixLoc = gl.glGetUniformLocation(progID, "normalMat");
    modeLoc = gl.glGetUniformLocation(progID, "mode");
  }

  private String[] loadShaderSrc(String name) {
    StringBuilder sb = new StringBuilder();
    try {
      InputStream is = getClass().getResourceAsStream(name);
      BufferedReader br = new BufferedReader(new InputStreamReader(is));
      String line;
      while ((line = br.readLine()) != null) {
        sb.append(line);
        sb.append('\n');
      }
      is.close();
    } catch (Exception e) {
      e.printStackTrace();
    }
    return new String[]{sb.toString()};
  }

  private void printShaderInfoLog(GLAutoDrawable d, int obj) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    IntBuffer infoLogLengthBuf = IntBuffer.allocate(1);
    int infoLogLength;
    gl.glGetShaderiv(obj, GL3.GL_INFO_LOG_LENGTH, infoLogLengthBuf);
    infoLogLength = infoLogLengthBuf.get(0);
    if (infoLogLength > 0) {
      ByteBuffer byteBuffer = ByteBuffer.allocate(infoLogLength);
      gl.glGetShaderInfoLog(obj, infoLogLength, infoLogLengthBuf, byteBuffer);
      for (byte b : byteBuffer.array()) {
        System.err.print((char) b);
      }
    }
  }

  private void printProgramInfoLog(GLAutoDrawable d, int obj) {
    GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
    IntBuffer infoLogLengthBuf = IntBuffer.allocate(1);
    int infoLogLength;
    gl.glGetProgramiv(obj, GL3.GL_INFO_LOG_LENGTH, infoLogLengthBuf);
    infoLogLength = infoLogLengthBuf.get(0);
    if (infoLogLength > 0) {
      ByteBuffer byteBuffer = ByteBuffer.allocate(infoLogLength);
      gl.glGetProgramInfoLog(obj, infoLogLength, infoLogLengthBuf, byteBuffer);
      for (byte b : byteBuffer.array()) {
        System.err.print((char) b);
      }
    }
  }

  private float[] loadVertexData(String filename, int perVertexFloats) {

    float[] floatArray = new float[0];
    // read vertex data from file
    int vertSize = 0;
    try {
      InputStream is = getClass().getResourceAsStream(filename);
      BufferedReader br = new BufferedReader(new InputStreamReader(is));
      String line = br.readLine();
      if (line != null) {
        vertSize = Integer.parseInt(line);
        floatArray = new float[vertSize];
      }
      int i = 0;
      while ((line = br.readLine()) != null && i < floatArray.length) {
        floatArray[i] = Float.parseFloat(line);
        i++;
      }
      if (i != vertSize || (vertSize % perVertexFloats) != 0) {
        floatArray = new float[0];
      }
      br.close();
    } catch (FileNotFoundException e) {
      System.out.println("Can not find vbo data file " + filename);
    } catch (IOException e) {
      e.printStackTrace();
    }
    return floatArray;
  }
}

class MyGui extends JFrame implements GLEventListener {

  /** Window-title labels for shader modes 1..5, indexed by (mode - 1). */
  private static final String[] MODE_NAMES = {
    "Global Normals",
    "Local Normals",
    "Global Vertex Positions",
    "Local Vertex Positions",
    "Texture Coordinates"
  };

  /** Degrees the camera advances along its orbit per displayed frame. */
  private static final float ORBIT_STEP_DEGREES = 1.0f;

  // Created at construction so it exists before the canvas can invoke any
  // GLEventListener callback (init/reshape/display) or key event.
  private final Renderer renderer = new Renderer();

  /**
   * Builds the window and GL canvas, and wires up rendering and the '1'..'5' mode keys.
   * It then starts a 60 fps animator.
   */
  public void createGUI() {
    setTitle("Transforming normals");
    setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

    GLProfile glp = GLProfile.getDefault();
    GLCapabilities caps = new GLCapabilities(glp);
    GLCanvas canvas = new GLCanvas(caps);
    setSize(320, 320);

    getContentPane().add(canvas);
    FPSAnimator ani = new FPSAnimator(canvas, 60, true);
    canvas.addGLEventListener(this);

    canvas.addKeyListener(new KeyAdapter() {
      @Override
      public void keyPressed(KeyEvent event) {
        // KeyEvent.VK_1..VK_5 equal the ASCII digits '1'..'5' and map to modes 1..5;
        // any other key leaves the mode and title untouched.
        int mode = event.getKeyCode() - KeyEvent.VK_1 + 1;
        if (mode >= 1 && mode <= MODE_NAMES.length) {
          renderer.modeVal = mode;
          setTitle(MODE_NAMES[mode - 1]);
        }
      }
    });
    setVisible(true);

    ani.start();
  }

  @Override
  public void init(GLAutoDrawable d) {
    renderer.init(d);
  }

  @Override
  public void reshape(GLAutoDrawable d, int x, int y, int width, int height) {
    renderer.resize(d, width, height);
  }

  /** Advances the orbit angle by one step and renders a frame. */
  @Override
  public void display(GLAutoDrawable d) {
    // Wrap at 360 degrees. The angle is only used through cos/sin of degrees, so wrapping does
    // not change the image. An unbounded float would stop growing once it exceeds 2^24 (about
    // three days at 60 fps), which would freeze the animation.
    renderer.t = (renderer.t + ORBIT_STEP_DEGREES) % 360.0f;
    renderer.display(d);
  }

  @Override
  public void dispose(GLAutoDrawable d) {
    renderer.dispose(d);
  }
}

public class ShaderNormalTrans {
  public static void main(String[] args) {
    javax.swing.SwingUtilities.invokeLater(new Runnable() {
      public void run() {
        MyGui myGUI = new MyGui();
        myGUI.createGUI();
      }
    });
  }
}
