# This code example was created for educational purposes
# by Thorsten Thormaehlen (contact: www.thormae.de).
# It is distributed without any warranty.

# Requires PyOpenGL, PyGLM, and numpy, which can be installed with:
# pip install PyOpenGL PyOpenGL_accelerate PyGLM numpy

from OpenGL.GL import *
from OpenGL.GLUT import *
import numpy as np
import sys
import ctypes
import math
import glm

class Renderer:
    """Draw one colored triangle seen by a camera that orbits the origin.

    GL objects (shader program, VAO, VBO) are created in ``init`` and
    released in ``dispose``. ``t`` is the camera's orbit angle in degrees;
    the caller advances it between frames.
    """

    # number of floats per interleaved vertex: x, y, z, r, g, b, a
    FLOATS_PER_VERTEX = 7

    def __init__(self):
        self.t = 0.0  # camera orbit angle in degrees
        self.triangleVertNo = 0  # vertex count passed to glDrawArrays
        self.progID = 0
        self.vertID = 0
        self.fragID = 0
        self.vertexLoc = -1  # -1: attribute not (yet) found in the program
        self.colorLoc = -1
        self.vaoID = 0
        self.bufID = 0
        self.projectionLoc = -1
        self.modelviewLoc = -1
        self.projection = glm.mat4()
        self.modelview = glm.mat4()

    # public member functions
    def init(self):
        """Build the shader program and upload the triangle into a VAO/VBO.

        Must be called once a GL context is current (after the window exists).
        """
        glEnable(GL_DEPTH_TEST)

        self._setup_shaders()

        # create a Vertex Array Object (VAO)
        self.vaoID = glGenVertexArrays(1)

        # generate a Vertex Buffer Object (VBO)
        self.bufID = glGenBuffers(1)

        # bind the triangle VAO so the attribute setup below is recorded in it
        glBindVertexArray(self.vaoID)

        # interleaved per-vertex layout: position (x, y, z), color (r, g, b, a)
        triangle_vertex_data = np.array([
             0.0,  0.5, 0.0, 1.0, 0.0, 0.0, 1.0,
            -0.5, -0.5, 0.0, 0.0, 1.0, 0.0, 1.0,
             0.5, -0.5, 0.0, 0.0, 0.0, 1.0, 1.0,
        ], dtype=np.float32)
        self.triangleVertNo = triangle_vertex_data.size // self.FLOATS_PER_VERTEX

        glBindBuffer(GL_ARRAY_BUFFER, self.bufID)
        glBufferData(GL_ARRAY_BUFFER, triangle_vertex_data.nbytes, triangle_vertex_data, GL_STATIC_DRAW)

        float_size = np.float32().itemsize
        stride = self.FLOATS_PER_VERTEX * float_size  # stride in bytes

        # position: 3 floats starting at byte offset 0
        if self.vertexLoc != -1:
            glVertexAttribPointer(self.vertexLoc, 3, GL_FLOAT, GL_FALSE, stride, ctypes.c_void_p(0))
            glEnableVertexAttribArray(self.vertexLoc)

        # color: 4 floats following the 3 position floats
        if self.colorLoc != -1:
            offset = 3 * float_size  # offset in bytes
            glVertexAttribPointer(self.colorLoc, 4, GL_FLOAT, GL_FALSE, stride, ctypes.c_void_p(offset))
            glEnableVertexAttribArray(self.colorLoc)

    def resize(self, w, h):
        """Set the viewport to w x h and rebuild the perspective projection."""
        glViewport(0, 0, w, h)
        # glm.perspective replaces gluPerspective
        self.projection = glm.perspective(math.radians(45.0), self._aspect_ratio(w, h), 0.5, 4.0)

    def display(self):
        """Render one frame with the camera at orbit angle ``self.t``."""
        glClearColor(0.0, 0.0, 0.0, 0.0)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        # update the modelview matrix (camera orbits at radius 2, height 2)
        rad = math.radians(self.t)
        eye_vec = glm.vec3(2.0 * math.cos(rad), 2.0, 2.0 * math.sin(rad))
        target_vec = glm.vec3(0.0, 0.0, 0.0)
        up_vec = glm.vec3(0.0, 1.0, 0.0)
        # glm.lookAt replaces gluLookAt
        self.modelview = glm.lookAt(eye_vec, target_vec, up_vec)

        glUseProgram(self.progID)

        # load the current projection and modelview matrix into the
        # corresponding UNIFORM variables of the shader
        glUniformMatrix4fv(self.projectionLoc, 1, GL_FALSE, glm.value_ptr(self.projection))
        glUniformMatrix4fv(self.modelviewLoc, 1, GL_FALSE, glm.value_ptr(self.modelview))

        # bind triangle VAO and render
        glBindVertexArray(self.vaoID)
        glDrawArrays(GL_TRIANGLES, 0, self.triangleVertNo)

    def dispose(self):
        """Delete every GL object created by ``init``."""
        glDeleteVertexArrays(1, [self.vaoID])
        glDeleteBuffers(1, [self.bufID])
        glDeleteProgram(self.progID)
        glDeleteShader(self.vertID)
        glDeleteShader(self.fragID)

    # private member functions
    @staticmethod
    def _aspect_ratio(w, h):
        """Return the viewport aspect ratio w / h, clamping both sides to >= 1.

        GLUT may report a zero height (e.g. a collapsed or minimized window).
        A plain ``w / h`` would then raise ZeroDivisionError inside the
        reshape callback.
        """
        return max(w, 1) / max(h, 1)

    @staticmethod
    def _decode_log(log):
        """Return a GL info log as text, whether it arrives as bytes or str."""
        if isinstance(log, bytes):
            # driver logs are not guaranteed to be valid UTF-8
            return log.decode("utf-8", errors="replace")
        return log

    def _setup_shaders(self):
        """Compile uniform.vert/uniform.frag, link them, and look up locations."""
        # create shaders
        self.vertID = glCreateShader(GL_VERTEX_SHADER)
        self.fragID = glCreateShader(GL_FRAGMENT_SHADER)

        # load shader source from file
        vs_source = self._load_shader_src("uniform.vert")
        fs_source = self._load_shader_src("uniform.frag")

        # specify shader source
        glShaderSource(self.vertID, vs_source)
        glShaderSource(self.fragID, fs_source)

        # compile shaders
        glCompileShader(self.vertID)
        glCompileShader(self.fragID)

        # report compiler messages, if any
        self._print_shader_info_log(self.vertID)
        self._print_shader_info_log(self.fragID)

        # create program and attach shaders
        self.progID = glCreateProgram()
        glAttachShader(self.progID, self.vertID)
        glAttachShader(self.progID, self.fragID)

        # "outputColor" is a user-provided OUT variable of the fragment
        # shader. Its output is bound to the first color buffer in the
        # framebuffer; this must happen before linking.
        glBindFragDataLocation(self.progID, 0, "outputColor")

        # link the program and report linker messages, if any
        glLinkProgram(self.progID)
        self._print_program_info_log(self.progID)

        # "inputPosition" and "inputColor" are user-provided IN variables
        # of the vertex shader. Their locations are stored for use with
        # glVertexAttribPointer()/glEnableVertexAttribArray() in init().
        self.vertexLoc = glGetAttribLocation(self.progID, "inputPosition")
        self.colorLoc = glGetAttribLocation(self.progID, "inputColor")

        # "projection" and "modelview" are user-provided UNIFORM variables
        # of the vertex shader. Their locations are stored for display().
        self.projectionLoc = glGetUniformLocation(self.progID, "projection")
        self.modelviewLoc = glGetUniformLocation(self.progID, "modelview")

    def _print_shader_info_log(self, shader_obj):
        """Print the compile log of ``shader_obj`` when it is non-empty."""
        log = glGetShaderInfoLog(shader_obj)
        if log:
            print(f"Shader Log: {self._decode_log(log)}")

    def _print_program_info_log(self, prog_obj):
        """Print the link log of ``prog_obj`` when it is non-empty."""
        log = glGetProgramInfoLog(prog_obj)
        if log:
            print(f"Program Log: {self._decode_log(log)}")

    def _load_shader_src(self, filename):
        """Return the text of shader file ``filename``.

        Exits the process with status 1 when the file cannot be opened.
        """
        try:
            with open(filename, "r", encoding="utf-8") as f:
                return f.read()
        except OSError:
            print(f"Unable to open file {filename}", file=sys.stderr)
            sys.exit(1)

# --- Global Renderer Instance ---
# Module-level so the GLUT callbacks below, which take no user-data
# argument, can reach it. Its GL objects are created later by renderer.init().
renderer = Renderer()

# --- GLUT Callbacks ---
def glut_resize(w, h):
    """GLUT reshape callback: hand the new window size to the renderer."""
    target = renderer
    target.resize(w, h)

def glut_display():
    """GLUT display callback: draw one frame, then present the back buffer."""
    frame_source = renderer
    frame_source.display()
    # the window is double-buffered (GLUT_DOUBLE), so show what was drawn
    glutSwapBuffers()
    
def glut_close():
    """GLUT close callback: release the renderer's GL resources."""
    owner = renderer
    owner.dispose()

def timer(v):
    """Advance the camera orbit by one degree and request a redraw.

    Re-arms itself every 20 ms; ``v`` counts how many times it has fired.
    """
    offset = 1.0  # degrees per tick
    renderer.t += offset
    # Ask GLUT to redraw rather than drawing from inside the timer callback.
    # Rendering then happens in the registered display callback, and
    # several pending redisplay requests collapse into a single redraw.
    glutPostRedisplay()
    glutTimerFunc(20, timer, v + 1)

def main():
    """Open the GLUT window, wire up callbacks, and run the event loop."""
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_DEPTH | GLUT_DOUBLE | GLUT_RGBA)
    glutInitWindowPosition(100, 100)
    glutInitWindowSize(320, 320)
    glutCreateWindow(b"Shader-based gluLookAt simulation")

    # registration order: display, reshape, close
    callback_table = (
        (glutDisplayFunc, glut_display),
        (glutReshapeFunc, glut_resize),
        (glutCloseFunc, glut_close),
    )
    for register, handler in callback_table:
        register(handler)

    # GL objects can only be created once the window (and its context) exists
    renderer.init()

    # first animation tick after 20 ms; timer() re-arms itself afterwards
    glutTimerFunc(20, timer, 0)

    glutMainLoop()

# Start the demo only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()