// This code example was created for educational purposes
// by Thorsten Thormaehlen (contact: www.thormae.de).
// It is distributed without any warranty.

import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

import javax.imageio.ImageIO;

import com.jogamp.opengl.GL3;
import com.jogamp.opengl.GLAutoDrawable;
import com.jogamp.opengl.GLCapabilities;
import com.jogamp.opengl.GLEventListener;
import com.jogamp.opengl.GLProfile;
import com.jogamp.opengl.awt.GLCanvas;

import javax.swing.JFrame;

import com.jogamp.common.nio.Buffers;
import com.jogamp.math.Matrix4f;
import com.jogamp.math.Vec3f;
import com.jogamp.opengl.util.FPSAnimator;

/**
 * Renders a vertex-colored, textured pyramid with a VAO, one interleaved VBO and a
 * GLSL 1.40 shader program.
 *
 * <p>Every method taking a {@link GLAutoDrawable} obtains its GL3 interface from it, so it
 * must be called while that drawable's GL context is current (i.e. from the
 * {@link GLEventListener} callbacks).
 */
class Renderer {

    /** Camera orbit angle in degrees, read by {@link #display}; advanced by the caller. */
    public float t;

    private enum VAOs {Pyramid, numVAOs}

    private enum VBOs {PyramidAll, numVBOs}

    /** Floats per interleaved vertex: position (3) + color (4) + texCoord (2) + normal (3). */
    private static final int FLOATS_PER_VERTEX = 3 + 4 + 2 + 3;

    private final int[] vaoID = new int[VAOs.numVAOs.ordinal()];
    private final int[] bufID = new int[VBOs.numVBOs.ordinal()];
    private int pyramidVertNo = 0;
    private int texID = 0;
    private int progID = 0;
    private int vertID = 0;
    private int fragID = 0;
    private int vertexLoc = 0;
    private int colorLoc = 0;
    private int texCoordLoc = 0;
    private int normalLoc = 0;
    private int projectionLoc = 0;
    private int modelviewLoc = 0;
    private int texLoc = 0;
    private final Matrix4f projection = new Matrix4f();
    private final Matrix4f modelview = new Matrix4f();
    // Reused for every matrix upload: glUniformMatrix4fv consumes the array during the
    // call, so allocating a fresh float[16] each frame is unnecessary.
    private final float[] matrixData = new float[16];

    /**
     * Builds the shader program, uploads the pyramid geometry into the VAO/VBO and loads
     * the checkerboard texture.
     *
     * @param d drawable whose GL context is current
     */
    public void init(GLAutoDrawable d) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
        gl.glEnable(GL3.GL_DEPTH_TEST);

        setupShaders(d);

        // create the Vertex Array Objects (VAOs)
        gl.glGenVertexArrays(VAOs.numVAOs.ordinal(), vaoID, 0);

        // generate the Vertex Buffer Objects (VBOs)
        gl.glGenBuffers(VBOs.numVBOs.ordinal(), bufID, 0);

        // bind the pyramid VAO
        gl.glBindVertexArray(vaoID[VAOs.Pyramid.ordinal()]);

        // interleaved per-vertex layout: x y z | r g b a | s t | nx ny nz
        float[] pyramidVertexData = {
                0.0f, 0.0f, 2.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.5f, 1.0f, 0.0000f, -0.9701f, 0.2425f,
                -0.5f, -0.5f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0000f, -0.9701f, 0.2425f,
                0.5f, -0.5f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0000f, -0.9701f, 0.2425f,
                0.0f, 0.0f, 2.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.5f, 1.0f, 0.9701f, 0.0000f, 0.2425f,
                0.5f, -0.5f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.9701f, 0.0000f, 0.2425f,
                0.5f, 0.5f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.9701f, 0.0000f, 0.2425f,
                0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.5f, 1.0f, 0.0000f, 0.9701f, 0.2425f,
                0.5f, 0.5f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0000f, 0.9701f, 0.2425f,
                -0.5f, 0.5f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0000f, 0.9701f, 0.2425f,
                0.0f, 0.0f, 2.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.5f, 1.0f, -0.9701f, 0.0000f, 0.2425f,
                -0.5f, 0.5f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, -0.9701f, 0.0000f, 0.2425f,
                -0.5f, -0.5f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, -0.9701f, 0.0000f, 0.2425f
        };

        // derive the vertex count from the data so the two can never disagree
        pyramidVertNo = pyramidVertexData.length / FLOATS_PER_VERTEX;

        FloatBuffer pyramidVertexFB = Buffers.newDirectFloatBuffer(pyramidVertexData.length);
        pyramidVertexFB.put(pyramidVertexData);
        pyramidVertexFB.flip();

        gl.glBindBuffer(GL3.GL_ARRAY_BUFFER, bufID[VBOs.PyramidAll.ordinal()]);
        gl.glBufferData(GL3.GL_ARRAY_BUFFER, pyramidVertexFB.capacity() * Buffers.SIZEOF_FLOAT,
                pyramidVertexFB, GL3.GL_STATIC_DRAW);

        int stride = FLOATS_PER_VERTEX * Buffers.SIZEOF_FLOAT;

        setupAttribute(gl, vertexLoc, 3, stride, 0);            // position
        setupAttribute(gl, colorLoc, 4, stride, 3);             // color
        setupAttribute(gl, texCoordLoc, 2, stride, 3 + 4);      // texCoord
        setupAttribute(gl, normalLoc, 3, stride, 3 + 4 + 2);    // normal

        texID = loadTexture(d, "checkerboard.png");
    }

    /**
     * Points vertex attribute {@code loc} at {@code size} floats located {@code offsetFloats}
     * floats into each interleaved vertex of the currently bound VBO, and enables it.
     * A location of -1 (attribute not active in the linked program) is skipped.
     */
    private static void setupAttribute(GL3 gl, int loc, int size, int stride, int offsetFloats) {
        if (loc == -1) {
            return;
        }
        gl.glVertexAttribPointer(loc, size, GL3.GL_FLOAT, false, stride,
                (long) offsetFloats * Buffers.SIZEOF_FLOAT);
        gl.glEnableVertexAttribArray(loc);
    }

    /**
     * Updates the viewport and the perspective projection for a new surface size.
     *
     * @param d drawable whose GL context is current
     * @param w new width in pixels
     * @param h new height in pixels
     */
    public void resize(GLAutoDrawable d, int w, int h) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
        gl.glViewport(0, 0, w, h);

        // A minimized or not-yet-laid-out window can report a 0 width/height; clamp both so
        // the aspect ratio stays finite and non-zero instead of producing an invalid matrix.
        float aspect = (float) Math.max(w, 1) / (float) Math.max(h, 1);

        // setToPerspective replaces gluPerspective
        projection.setToPerspective((float) Math.toRadians(30.0f), aspect, 1.0f, 10.0f);
    }

    /**
     * Renders one frame: clears the buffers, places the orbiting camera according to
     * {@link #t}, uploads the matrices and texture unit, and draws the pyramid.
     *
     * @param d drawable whose GL context is current
     */
    public void display(GLAutoDrawable d) {
        GL3 gl = d.getGL().getGL3();  // get the OpenGL >= 3 graphics context

        gl.glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
        gl.glClear(GL3.GL_COLOR_BUFFER_BIT | GL3.GL_DEPTH_BUFFER_BIT);

        // the camera orbits on a circle of radius 5 in the z=5 plane
        // and looks at the point (0, 0, 0.5)
        double rad = Math.PI / 180.0f * t;
        Vec3f eye = new Vec3f(5.0f * (float) Math.cos(rad), 5.0f * (float) Math.sin(rad), 5.0f);
        Vec3f center = new Vec3f(0.0f, 0.0f, 0.5f);
        Vec3f up = new Vec3f(0.0f, 0.0f, 1.0f);
        Matrix4f tempMatrix = new Matrix4f();
        // setToLookAt replaces gluLookAt
        modelview.setToLookAt(eye, center, up, tempMatrix);

        gl.glUseProgram(progID);

        // load the current projection and modelview matrix into the
        // corresponding UNIFORM variables of the shader
        gl.glUniformMatrix4fv(projectionLoc, 1, false, projection.get(matrixData), 0);
        gl.glUniformMatrix4fv(modelviewLoc, 1, false, modelview.get(matrixData), 0);

        // activate texture unit 0
        gl.glActiveTexture(GL3.GL_TEXTURE0);
        // bind texture
        gl.glBindTexture(GL3.GL_TEXTURE_2D, texID);
        // inform the shader to use texture unit 0
        gl.glUniform1i(texLoc, 0);

        // bind pyramid VAO
        gl.glBindVertexArray(vaoID[VAOs.Pyramid.ordinal()]);
        // render data
        gl.glDrawArrays(GL3.GL_TRIANGLES, 0, pyramidVertNo);
    }

    /**
     * Compiles the embedded vertex and fragment shaders, links them into {@code progID},
     * prints any compiler/linker logs to stderr, and looks up attribute and uniform locations.
     *
     * @param d drawable whose GL context is current
     */
    public void setupShaders(GLAutoDrawable d) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context

        vertID = gl.glCreateShader(GL3.GL_VERTEX_SHADER);
        fragID = gl.glCreateShader(GL3.GL_FRAGMENT_SHADER);

        // provide shader code as multiline string
        String[] vs = new String[]{
                """
#version 140

in vec3 inputPosition;
in vec4 inputColor;
in vec2 inputTexCoord;
in vec3 inputNormal;

uniform mat4 projection;
uniform mat4 modelview;

out vec3 forFragColor;
out vec2 forFragTexCoord;

void main(){
    forFragColor = inputColor.rgb;
    forFragTexCoord = inputTexCoord;
    gl_Position =  projection * modelview * vec4(inputPosition, 1.0);
}
"""
        };

        String[] fs = new String[]{
                """
#version 140

in vec3 forFragColor;
in vec2 forFragTexCoord;
out vec4 outputColor;

uniform sampler2D myTexture;

void main() {
    vec3 textureColor = vec3( texture(myTexture, forFragTexCoord) );
    outputColor = vec4(forFragColor * textureColor ,1.0);\s
}
"""
        };

        gl.glShaderSource(vertID, 1, vs, null, 0);
        gl.glShaderSource(fragID, 1, fs, null, 0);

        // compile the shaders
        gl.glCompileShader(vertID);
        gl.glCompileShader(fragID);

        // print compiler messages, if any
        printShaderInfoLog(d, vertID);
        printShaderInfoLog(d, fragID);

        // create program and attach shaders
        progID = gl.glCreateProgram();
        gl.glAttachShader(progID, vertID);
        gl.glAttachShader(progID, fragID);

        // "outputColor" is a user-provided OUT variable
        // of the fragment shader.
        // Its output is bound to the first color buffer
        // in the framebuffer (must happen before linking)
        gl.glBindFragDataLocation(progID, 0, "outputColor");

        // link the program
        gl.glLinkProgram(progID);
        // print linker messages, if any
        printProgramInfoLog(d, progID);

        // "inputPosition", "inputColor", "inputTexCoord" and "inputNormal" are
        // user-provided IN variables of the vertex shader.
        // Their locations are stored to be used later with
        // glVertexAttribPointer()/glEnableVertexAttribArray();
        // -1 means the attribute is not active in the linked program.
        vertexLoc = gl.glGetAttribLocation(progID, "inputPosition");
        colorLoc = gl.glGetAttribLocation(progID, "inputColor");
        texCoordLoc = gl.glGetAttribLocation(progID, "inputTexCoord");
        normalLoc = gl.glGetAttribLocation(progID, "inputNormal");

        // "projection", "modelview" and "myTexture" are user-provided
        // UNIFORM variables of the shaders.
        // Their locations are stored to be used later
        projectionLoc = gl.glGetUniformLocation(progID, "projection");
        modelviewLoc = gl.glGetUniformLocation(progID, "modelview");
        texLoc = gl.glGetUniformLocation(progID, "myTexture");
    }

    /**
     * Reads a shader source file into a single-element array suitable for glShaderSource.
     * Not called by the code in this class (the shaders are embedded in setupShaders).
     *
     * @param name path of the file to read
     * @return the file contents with '\n' line endings; whatever was read before an
     *         I/O error (possibly empty) if reading fails
     */
    private String[] loadShaderSrc(String name) {
        StringBuilder sb = new StringBuilder();
        // try-with-resources closes the reader even when reading fails
        try (BufferedReader br = new BufferedReader(new FileReader(name))) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line).append('\n');
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new String[]{sb.toString()};
    }

    /** Prints the info log of shader object {@code obj} to stderr, if it is non-empty. */
    private void printShaderInfoLog(GLAutoDrawable d, int obj) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
        IntBuffer infoLogLengthBuf = IntBuffer.allocate(1);
        gl.glGetShaderiv(obj, GL3.GL_INFO_LOG_LENGTH, infoLogLengthBuf);
        int infoLogLength = infoLogLengthBuf.get(0);
        if (infoLogLength > 0) {
            ByteBuffer byteBuffer = ByteBuffer.allocate(infoLogLength);
            gl.glGetShaderInfoLog(obj, infoLogLength, infoLogLengthBuf, byteBuffer);
            printInfoLog(byteBuffer);
        }
    }

    /** Prints the info log of program object {@code obj} to stderr, if it is non-empty. */
    private void printProgramInfoLog(GLAutoDrawable d, int obj) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context
        IntBuffer infoLogLengthBuf = IntBuffer.allocate(1);
        gl.glGetProgramiv(obj, GL3.GL_INFO_LOG_LENGTH, infoLogLengthBuf);
        int infoLogLength = infoLogLengthBuf.get(0);
        if (infoLogLength > 0) {
            ByteBuffer byteBuffer = ByteBuffer.allocate(infoLogLength);
            gl.glGetProgramInfoLog(obj, infoLogLength, infoLogLengthBuf, byteBuffer);
            printInfoLog(byteBuffer);
        }
    }

    /**
     * Writes a GL info log to stderr. GL_INFO_LOG_LENGTH counts the trailing NUL, so
     * printing stops at the first zero byte instead of emitting the terminator.
     */
    private static void printInfoLog(ByteBuffer log) {
        for (byte b : log.array()) {
            if (b == 0) {
                break;
            }
            // mask to avoid sign-extending bytes >= 0x80 into bogus chars
            System.err.print((char) (b & 0xFF));
        }
    }

    /**
     * Loads an image file into a new, linearly filtered 2D texture.
     *
     * @param d        drawable whose GL context is current
     * @param filename path of an image file in a format readable by ImageIO
     * @return a valid texture ID on success, otherwise 0
     */
    private int loadTexture(GLAutoDrawable d, String filename) {
        GL3 gl = d.getGL().getGL3(); // get the OpenGL 3 graphics context

        int level = 0;
        int border = 0;

        BufferedImage bufferedImage;
        // open and decode the file; try-with-resources closes the stream even if decoding fails
        try (FileInputStream fileInputStream = new FileInputStream(new File(filename))) {
            bufferedImage = ImageIO.read(fileInputStream);
        } catch (FileNotFoundException e) {
            System.out.println("Can not find texture data file " + filename);
            return 0;
        } catch (IOException e) {
            e.printStackTrace();
            return 0;
        }
        // ImageIO.read returns null (instead of throwing) when no reader supports the format
        if (bufferedImage == null) {
            System.err.println("Unsupported image format in texture data file " + filename);
            return 0;
        }

        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        ByteBuffer buffer = toBottomUpRgba(bufferedImage);

        // data is aligned in byte order
        gl.glPixelStorei(GL3.GL_UNPACK_ALIGNMENT, 1);

        // request textureID
        final int[] textureID = new int[1];
        gl.glGenTextures(1, textureID, 0);

        // bind texture
        gl.glBindTexture(GL3.GL_TEXTURE_2D, textureID[0]);

        // define how to filter the texture
        gl.glTexParameteri(GL3.GL_TEXTURE_2D, GL3.GL_TEXTURE_MAG_FILTER,
                GL3.GL_LINEAR);
        gl.glTexParameteri(GL3.GL_TEXTURE_2D, GL3.GL_TEXTURE_MIN_FILTER,
                GL3.GL_LINEAR);

        // specify the 2D texture map (RGBA source data, RGB internal format)
        gl.glTexImage2D(GL3.GL_TEXTURE_2D, level, GL3.GL_RGB, width, height,
                border, GL3.GL_RGBA, GL3.GL_UNSIGNED_BYTE, buffer);

        return textureID[0];
    }

    /**
     * Converts an image into a direct buffer of tightly packed RGBA bytes. Rows are
     * vertically flipped because the texture origin in OpenGL is the lower-left corner.
     */
    private static ByteBuffer toBottomUpRgba(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        int[] pixelIntData = new int[width * height];
        // getRGB yields packed ARGB ints
        image.getRGB(0, 0, width, height, pixelIntData, 0, width);
        ByteBuffer buffer = ByteBuffer.allocateDirect(pixelIntData.length * 4);
        buffer.order(ByteOrder.nativeOrder());
        for (int y = 0; y < height; y++) {
            int k = (height - 1 - y) * width;
            for (int x = 0; x < width; x++) {
                int argb = pixelIntData[k++];
                buffer.put((byte) (argb >>> 16)); // R
                buffer.put((byte) (argb >>> 8));  // G
                buffer.put((byte) argb);          // B
                buffer.put((byte) (argb >>> 24)); // A
            }
        }
        buffer.rewind();
        return buffer;
    }
}

/**
 * Application window hosting a {@link GLCanvas}. It forwards the JOGL lifecycle
 * callbacks to a {@link Renderer} and advances the renderer's orbit angle by one
 * degree per displayed frame.
 */
class MyGui extends JFrame implements GLEventListener {

    // Created eagerly, before the canvas is shown and this listener is registered, so
    // no GL callback can ever observe a null renderer.
    private final Renderer renderer = new Renderer();

    /**
     * Builds the window, attaches the GL canvas and starts a 60 FPS animator.
     * Should be called on the Swing event dispatch thread.
     */
    public void createGUI() {
        setTitle("Accessing textures in the shader");
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        // The renderer calls getGL3() unconditionally, so ask for a GL3 profile explicitly.
        // The platform default profile may resolve to GL2 on some systems, where getGL3() fails.
        GLProfile glp = GLProfile.get(GLProfile.GL3);
        GLCapabilities caps = new GLCapabilities(glp);
        GLCanvas canvas = new GLCanvas(caps);
        setSize(320, 320);

        getContentPane().add(canvas);
        final FPSAnimator ani = new FPSAnimator(canvas, 60, true);
        canvas.addGLEventListener(this);
        setVisible(true);

        ani.start();
    }

    @Override
    public void init(GLAutoDrawable d) {
        renderer.init(d);
    }

    @Override
    public void reshape(GLAutoDrawable d, int x, int y, int width, int height) {
        renderer.resize(d, width, height);
    }

    @Override
    public void display(GLAutoDrawable d) {
        // advance the camera orbit angle by one degree per frame
        float offset = 1.0f;
        renderer.t += offset;
        renderer.display(d);
    }

    @Override
    public void dispose(GLAutoDrawable d) {
    }
}

/**
 * Entry point of the textured-pyramid demo.
 */
public class ShaderTexture {
    /**
     * Schedules creation of the demo window on the Swing event dispatch thread.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Swing components must be constructed on the event dispatch thread.
        javax.swing.SwingUtilities.invokeLater(() -> {
            final MyGui window = new MyGui();
            window.createGUI();
        });
    }
}
