# This code example is created for educational purpose
# by Thorsten Thormaehlen (contact: www.thormae.de).
# It is distributed without any warranty.

# Requires PyOpenGL, PyGLM, numpy and Pillow (PIL), which can be installed with:
# pip install PyOpenGL PyOpenGL_accelerate PyGLM numpy Pillow

from OpenGL.GL import *
from OpenGL.GLUT import *
import numpy as np
import sys
import ctypes
import math
import glm
from PIL import Image

class Renderer:
    """Draw a textured, vertex-colored pyramid with a GLSL 1.40 shader program.

    Usage: call ``init()`` once an OpenGL context is current, ``resize()`` on
    every window reshape, ``display()`` once per frame and ``dispose()`` to
    release the GL objects. ``t`` is the camera orbit angle in degrees; it is
    read by ``display()`` and advanced by the caller.
    """

    # Number of float32 components stored per vertex in the interleaved VBO:
    # position (3) + color (4) + texCoord (2) + normal (3).
    _FLOATS_PER_VERTEX = 12

    def __init__(self):
        self.t = 0.0  # camera orbit angle in degrees
        self.pyramidVertNo = 0
        self.texID = 0
        self.progID = 0
        self.vertID = 0
        self.fragID = 0
        # attribute/uniform locations; -1 means "not active in the program"
        self.vertexLoc = -1
        self.colorLoc = -1
        self.texCoordLoc = -1
        self.normalLoc = -1
        self.texLoc = -1
        self.vaoID = 0
        self.bufID = 0
        self.projectionLoc = -1
        self.modelviewLoc = -1
        self.projection = glm.mat4()
        self.modelview = glm.mat4()

    # public member functions
    def init(self):
        """Create the shader program, the pyramid geometry (VAO + VBO) and the texture.

        Must be called with a current OpenGL context.
        """
        glEnable(GL_DEPTH_TEST)

        self._setup_shaders()

        # create a Vertex Array Object (VAO)
        self.vaoID = glGenVertexArrays(1)

        # generate a Vertex Buffer Object (VBO)
        self.bufID = glGenBuffers(1)

        # bind the pyramid VAO; the attribute pointers set below are stored in it
        glBindVertexArray(self.vaoID)

        # interleaved layout per vertex: x y z | r g b a | s t | nx ny nz
        pyramid_vertex_data = np.array([
            0.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.5, 1.0, 0.0000,-0.9701, 0.2425,
            -0.5,-0.5, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0000,-0.9701, 0.2425,
            0.5,-0.5, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0000,-0.9701, 0.2425,
            0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 1.0, 0.5, 1.0, 0.9701, 0.0000, 0.2425,
            0.5,-0.5, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.9701, 0.0000, 0.2425,
            0.5, 0.5, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.9701, 0.0000, 0.2425,
            0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0000, 0.9701, 0.2425,
            0.5, 0.5, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0000, 0.9701, 0.2425,
            -0.5, 0.5, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0000, 0.9701, 0.2425,
            0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 1.0, 0.5, 1.0,-0.9701, 0.0000, 0.2425,
            -0.5, 0.5, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0,-0.9701, 0.0000, 0.2425,
            -0.5,-0.5, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0,-0.9701, 0.0000, 0.2425
        ], dtype=np.float32)
        # derive the vertex count from the data instead of hard-coding it (12)
        self.pyramidVertNo = pyramid_vertex_data.size // self._FLOATS_PER_VERTEX

        glBindBuffer(GL_ARRAY_BUFFER, self.bufID)
        glBufferData(GL_ARRAY_BUFFER, pyramid_vertex_data.nbytes, pyramid_vertex_data, GL_STATIC_DRAW)

        float_size = pyramid_vertex_data.itemsize
        stride = self._FLOATS_PER_VERTEX * float_size  # bytes between consecutive vertices

        # byte offsets of each attribute inside one interleaved vertex
        self._enable_attrib(self.vertexLoc, 3, stride, 0)
        self._enable_attrib(self.colorLoc, 4, stride, 3 * float_size)
        self._enable_attrib(self.texCoordLoc, 2, stride, (3 + 4) * float_size)
        self._enable_attrib(self.normalLoc, 3, stride, (3 + 4 + 2) * float_size)

        self.texID = self._load_texture("checkerboard.png")

    def resize(self, w, h):
        """Set the viewport to the new window size and rebuild the projection matrix."""
        glViewport(0, 0, w, h)
        # glm.perspective replaces gluPerspective; the aspect ratio is guarded
        # because GLUT can report a zero-sized (e.g. minimised) window, for
        # which a plain w / h would raise ZeroDivisionError.
        self.projection = glm.perspective(math.radians(30.0), self._aspect_ratio(w, h), 1.0, 10.0)

    def display(self):
        """Render one frame of the pyramid seen from a camera orbiting at angle ``t``."""
        glClearColor(0.0, 0.0, 0.0, 0.0)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        # update the modelview matrix (camera orbits around the z axis)
        rad = math.radians(self.t)
        eye_vec = glm.vec3(5.0 * math.cos(rad), 5.0 * math.sin(rad), 5.0)
        target_vec = glm.vec3(0.0, 0.0, 0.5)
        up_vec = glm.vec3(0.0, 0.0, 1.0)
        # glm.lookAt replaces gluLookAt
        self.modelview = glm.lookAt(eye_vec, target_vec, up_vec)

        glUseProgram(self.progID)

        # activate texture unit 0 and bind the texture to it
        glActiveTexture(GL_TEXTURE0)
        glBindTexture(GL_TEXTURE_2D, self.texID)
        # inform the shader sampler to use texture unit 0
        glUniform1i(self.texLoc, 0)

        # load the current projection and modelview matrix into the
        # corresponding UNIFORM variables of the shader
        glUniformMatrix4fv(self.projectionLoc, 1, GL_FALSE, glm.value_ptr(self.projection))
        glUniformMatrix4fv(self.modelviewLoc, 1, GL_FALSE, glm.value_ptr(self.modelview))

        # bind pyramid VAO and render its triangles
        glBindVertexArray(self.vaoID)
        glDrawArrays(GL_TRIANGLES, 0, self.pyramidVertNo)

    def dispose(self):
        """Delete every OpenGL object created by ``init()``."""
        glDeleteVertexArrays(1, [self.vaoID])
        glDeleteBuffers(1, [self.bufID])
        glDeleteTextures([self.texID])
        glDeleteProgram(self.progID)
        glDeleteShader(self.vertID)
        glDeleteShader(self.fragID)

    # private member functions
    @staticmethod
    def _aspect_ratio(w, h):
        """Return ``w / h``, or 1.0 when either dimension is not positive."""
        if w <= 0 or h <= 0:
            return 1.0
        return w / h

    @staticmethod
    def _enable_attrib(location, size, stride, offset):
        """Point attribute ``location`` at ``size`` floats of the bound VBO and enable it.

        ``stride`` and ``offset`` are in bytes. A location of -1 (attribute
        not active in the linked program) is skipped.
        """
        if location == -1:
            return
        glVertexAttribPointer(location, size, GL_FLOAT, GL_FALSE, stride, ctypes.c_void_p(offset))
        glEnableVertexAttribArray(location)

    def _setup_shaders(self):
        """Compile and link the shader program and look up its attribute/uniform locations."""

        # create shaders
        self.vertID = glCreateShader(GL_VERTEX_SHADER)
        self.fragID = glCreateShader(GL_FRAGMENT_SHADER)

        # provide shader code as multiline string
        vs_source = """
#version 140
 
in vec3 inputPosition;
in vec4 inputColor;
in vec2 inputTexCoord;
in vec3 inputNormal;

uniform mat4 projection;
uniform mat4 modelview;

out vec3 forFragColor;
out vec2 forFragTexCoord;

void main(){
    forFragColor = inputColor.rgb;
    forFragTexCoord = inputTexCoord;
    gl_Position =  projection * modelview * vec4(inputPosition, 1.0);
}
"""

        fs_source = """
#version 140
 
in vec3 forFragColor;
in vec2 forFragTexCoord;
out vec4 outputColor;

uniform sampler2D myTexture;

void main() {
    vec3 textureColor = vec3( texture(myTexture, forFragTexCoord) );
    // Note: The original C++ code seems to be multiplying the vertex color 
    // by the texture color.
    outputColor = vec4(forFragColor * textureColor ,1.0); 
}
"""

        # specify shader source
        glShaderSource(self.vertID, vs_source)
        glShaderSource(self.fragID, fs_source)

        # compile shaders
        glCompileShader(self.vertID)
        glCompileShader(self.fragID)

        # print compiler messages (errors/warnings), if any
        self._print_shader_info_log(self.vertID)
        self._print_shader_info_log(self.fragID)

        # create program and attach shaders
        self.progID = glCreateProgram()
        glAttachShader(self.progID, self.vertID)
        glAttachShader(self.progID, self.fragID)

        # "outputColor" is a user-provided OUT variable of the fragment
        # shader; bind it to the first color buffer of the framebuffer.
        # This must happen before linking to take effect.
        glBindFragDataLocation(self.progID, 0, "outputColor")

        # link the program and print linker messages, if any
        glLinkProgram(self.progID)
        self._print_program_info_log(self.progID)

        # retrieve the location of the IN variables of the vertex shader
        # (-1 if the linker removed an unused input, e.g. inputNormal)
        self.vertexLoc = glGetAttribLocation(self.progID, "inputPosition")
        self.colorLoc = glGetAttribLocation(self.progID, "inputColor")
        self.texCoordLoc = glGetAttribLocation(self.progID, "inputTexCoord")
        self.normalLoc = glGetAttribLocation(self.progID, "inputNormal")

        # retrieve the location of the UNIFORM variables
        # (matrices: vertex shader; sampler: fragment shader)
        self.projectionLoc = glGetUniformLocation(self.progID, "projection")
        self.modelviewLoc = glGetUniformLocation(self.progID, "modelview")
        self.texLoc = glGetUniformLocation(self.progID, "myTexture")

    def _print_shader_info_log(self, shader_obj):
        """Print the compile log of ``shader_obj`` if it is non-empty."""
        log = glGetShaderInfoLog(shader_obj)
        if log:
            print(f"Shader Log: {log.decode()}")

    def _print_program_info_log(self, prog_obj):
        """Print the link log of ``prog_obj`` if it is non-empty."""
        log = glGetProgramInfoLog(prog_obj)
        if log:
            print(f"Program Log: {log.decode()}")

    def _load_shader_src(self, filename):
        """Return the text of ``filename``; exit the process if it cannot be read."""
        try:
            with open(filename, 'r') as f:
                return f.read()
        except OSError:
            print(f"Unable to open file {filename}")
            sys.exit(1)

    def _load_texture(self, filename):
        """Load an image file into a new 2D RGB texture and return its GL name.

        The image is flipped vertically because OpenGL treats the first row
        of texel data as the bottom row. Returns 0 (no texture) on failure.
        """
        try:
            # close the file handle opened by Image.open; transpose() loads
            # the pixel data, so the result stays valid after the file closes
            with Image.open(filename) as src:
                img = src.transpose(Image.FLIP_TOP_BOTTOM)
            if img.mode != 'RGB':
                img = img.convert('RGB')

            img_data = np.asarray(img, dtype=np.uint8)  # (height, width, 3)
            width, height = img.size
        except FileNotFoundError:
            print(f"Error: Texture file '{filename}' not found.")
            return 0
        except Exception as e:
            print(f"Error loading texture: {e}")
            return 0

        # request textureID
        tex_id = glGenTextures(1)

        # bind texture
        glBindTexture(GL_TEXTURE_2D, tex_id)

        # define how to filter the texture
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)

        # Tightly packed RGB rows are 3 * width bytes long, which breaks the
        # default 4-byte unpack alignment whenever width is not a multiple of 4
        # and would shear the uploaded texture.
        glPixelStorei(GL_UNPACK_ALIGNMENT, 1)

        # specify the 2D texture map
        glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, width, height, 0, GL_RGB, GL_UNSIGNED_BYTE, img_data)

        # return unique texture identifier
        return tex_id

# --- Global Renderer Instance ---
# Module-level Renderer shared by the GLUT callback functions below, which
# reach it through this global name. Constructing it only initialises Python
# attributes; the GL objects are created later by renderer.init().
renderer = Renderer()

# --- GLUT Callbacks ---
def glut_resize(w, h):
    """GLUT reshape callback: forward the new window size to the renderer."""
    size = (w, h)
    renderer.resize(*size)

def glut_display():
    """GLUT display callback: render one frame, then present the back buffer."""
    frame_renderer = renderer
    frame_renderer.display()
    glutSwapBuffers()
    
def glut_close():
    """GLUT close callback: release the renderer's OpenGL resources."""
    release = renderer.dispose
    release()

def timer(v):
    """GLUT timer callback: advance the camera orbit and request a redraw.

    Re-arms itself every 20 ms, passing ``v + 1`` to the next invocation.
    """
    offset = 1.0  # degrees added to the orbit angle per tick
    renderer.t += offset
    # Ask GLUT to repaint through the registered display callback instead of
    # rendering and swapping buffers directly from inside the timer callback.
    glutPostRedisplay()
    glutTimerFunc(20, timer, v + 1)

def main():
    """Create the GLUT window, register callbacks, initialise the renderer and run the event loop."""
    glutInit(sys.argv)
    # depth buffer, double buffering and RGBA color
    glutInitDisplayMode(GLUT_DEPTH | GLUT_DOUBLE | GLUT_RGBA)
    glutInitWindowPosition(100, 100)
    glutInitWindowSize(320, 320)

    glutCreateWindow(b"Accessing textures in the shader")

    # register callbacks, in the same order as before: display, reshape, close
    registrations = (
        (glutDisplayFunc, glut_display),
        (glutReshapeFunc, glut_resize),
        (glutCloseFunc, glut_close),
    )
    for register, callback in registrations:
        register(callback)

    # GL objects can only be created once the window (and its context) exists
    renderer.init()

    # first animation tick after 20 ms; timer() schedules the following ones
    glutTimerFunc(20, timer, 0)

    # hand control to the GLUT event loop
    glutMainLoop()

# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()