# This code example was created for educational purposes
# by Thorsten Thormaehlen (contact: www.thormae.de).
# It is distributed without any warranty.

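# Requires PyOpenGL, PyGLM, and numpy, which can be installed with:
# pip install PyOpenGL PyOpenGL_accelerate PyGLM numpy
# The files teapot.vbo, pass.vert and pass.frag are opened by relative path,
# i.e. they must be present in the current working directory.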

from OpenGL.GL import *
from OpenGL.GLUT import *
import numpy as np
import sys
import ctypes
import math
import glm
import os

class Renderer:
    """Draw the teapot mesh with a GLSL program, viewed by an orbiting camera.

    Usage: construct (no OpenGL calls happen here), call ``init()`` once an
    OpenGL context is current, then ``resize()`` / ``display()`` from the
    window callbacks, and finally ``dispose()`` to delete the GL objects.
    """

    def __init__(self):
        self.t = 0.0              # camera orbit angle in degrees (see display)
        self.modeVal = 1          # value uploaded to the "mode" uniform
        self.sceneVertNo = 0      # vertex count passed to glDrawArrays
        # GL object names; 0 until init() creates them
        self.progID = 0
        self.vertID = 0
        self.fragID = 0
        # attribute locations; -1 means "not active in the linked program"
        self.vertexLoc = -1
        self.texCoordLoc = -1
        self.normalLoc = -1
        self.vaoID = 0
        self.bufID = 0
        # uniform locations; -1 means "not active in the linked program"
        self.projectionLoc = -1
        self.modelviewLoc = -1
        self.normalMatrixLoc = -1
        self.modeLoc = -1
        self.projection = glm.mat4()
        self.modelview = glm.mat4()

    # public member functions
    def init(self):
        """Create the shader program, the VAO and the VBO holding the teapot.

        Requires a current OpenGL context. If "teapot.vbo" cannot be loaded,
        an error is printed and the process exits with status 1 (the same way
        a missing shader source file is handled).
        """
        glEnable(GL_DEPTH_TEST)

        self._setup_shaders()

        # create a Vertex Array Object (VAO)
        self.vaoID = glGenVertexArrays(1)

        # generate a Vertex Buffer Object (VBO)
        self.bufID = glGenBuffers(1)

        # bind the scene VAO so the attribute setup below is recorded in it
        glBindVertexArray(self.vaoID)

        # interleaved layout per vertex: position (3), texCoord (2), normal (3)
        per_vertex_floats = 3 + 2 + 3
        vertex_data = self._load_vertex_data("teapot.vbo", per_vertex_floats)
        if vertex_data is None:
            # _load_vertex_data returns None on any read or format error;
            # continuing would crash on len(None).
            print("Unable to load vertex data from teapot.vbo")
            sys.exit(1)
        self.sceneVertNo = len(vertex_data) // per_vertex_floats

        # Convert list of floats to a contiguous float32 numpy array
        np_vertex_data = np.array(vertex_data, dtype=np.float32)

        glBindBuffer(GL_ARRAY_BUFFER, self.bufID)
        glBufferData(GL_ARRAY_BUFFER, np_vertex_data.nbytes, np_vertex_data, GL_STATIC_DRAW)

        float_size = np.dtype(np.float32).itemsize
        stride = per_vertex_floats * float_size  # 8 * 4 bytes

        # (attribute location, component count, offset in floats).
        # The offsets are fixed by the buffer layout. They must NOT depend on
        # whether an earlier attribute is active: the shader compiler may drop
        # an unused input (location -1), but its bytes are still in the buffer.
        attributes = (
            (self.vertexLoc, 3, 0),    # position
            (self.texCoordLoc, 2, 3),  # texCoord
            (self.normalLoc, 3, 5),    # normal
        )
        for loc, size, offset_floats in attributes:
            if loc == -1:
                continue
            glVertexAttribPointer(loc, size, GL_FLOAT, GL_FALSE, stride,
                                  ctypes.c_void_p(offset_floats * float_size))
            glEnableVertexAttribArray(loc)

    def resize(self, w, h):
        """Set the viewport to w x h and rebuild the perspective projection.

        A height of 0 (e.g. a minimized window) would make the aspect ratio a
        division by zero, so the aspect uses a height of at least 1.
        """
        glViewport(0, 0, w, h)
        aspect = w / max(h, 1)
        # glm.perspective replaces gluPerspective
        self.projection = glm.perspective(math.radians(30.0), aspect, 0.5, 4.0)

    def display(self):
        """Render one frame: update the camera, upload uniforms, draw the mesh."""
        glClearColor(0.0, 0.0, 0.0, 0.0)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        # The camera orbits on a circle of radius 1.5 in the z = 1.5 plane
        # and looks at the origin; self.t is the orbit angle in degrees.
        rad = math.radians(self.t)
        eye_vec = glm.vec3(1.5 * math.cos(rad), 1.5 * math.sin(rad), 1.5)
        target_vec = glm.vec3(0.0, 0.0, 0.0)
        up_vec = glm.vec3(0.0, 0.0, 1.0)
        # glm.lookAt replaces gluLookAt
        self.modelview = glm.lookAt(eye_vec, target_vec, up_vec)

        # Normal matrix: inverse-transpose of the modelview matrix, which keeps
        # transformed normals perpendicular to the surface.
        normal_matrix = glm.transpose(glm.inverse(self.modelview))

        glUseProgram(self.progID)

        # load the current projection and modelview matrix into the
        # corresponding UNIFORM variables of the shader
        glUniformMatrix4fv(self.projectionLoc, 1, GL_FALSE, glm.value_ptr(self.projection))
        glUniformMatrix4fv(self.modelviewLoc, 1, GL_FALSE, glm.value_ptr(self.modelview))
        glUniformMatrix4fv(self.normalMatrixLoc, 1, GL_FALSE, glm.value_ptr(normal_matrix))
        glUniform1i(self.modeLoc, self.modeVal)

        # bind scene VAO and render data
        glBindVertexArray(self.vaoID)
        glDrawArrays(GL_TRIANGLES, 0, self.sceneVertNo)

    def dispose(self):
        """Delete the VAO, VBO, program and shader objects created by init()."""
        glDeleteVertexArrays(1, [self.vaoID])
        glDeleteBuffers(1, [self.bufID])
        glDeleteProgram(self.progID)
        glDeleteShader(self.vertID)
        glDeleteShader(self.fragID)

    # private member functions
    def _setup_shaders(self):
        """Compile pass.vert / pass.frag, link them, and look up GL locations.

        Compile and link problems are only printed (via the info logs); they
        do not abort.
        """
        # create shaders
        self.vertID = glCreateShader(GL_VERTEX_SHADER)
        self.fragID = glCreateShader(GL_FRAGMENT_SHADER)

        # load shader source from file
        vs_source = self._load_shader_src("pass.vert")
        fs_source = self._load_shader_src("pass.frag")

        # specify shader source
        glShaderSource(self.vertID, vs_source)
        glShaderSource(self.fragID, fs_source)

        # compile shaders
        glCompileShader(self.vertID)
        glCompileShader(self.fragID)

        # print compiler messages, if any
        self._print_shader_info_log(self.vertID)
        self._print_shader_info_log(self.fragID)

        # create program and attach shaders
        self.progID = glCreateProgram()
        glAttachShader(self.progID, self.vertID)
        glAttachShader(self.progID, self.fragID)

        # "outputColor" is a user-provided OUT variable of the fragment
        # shader. Its output is bound to the first color buffer in the
        # framebuffer. This must happen before linking to take effect.
        glBindFragDataLocation(self.progID, 0, "outputColor")

        # link the program and print linker messages, if any
        glLinkProgram(self.progID)
        self._print_program_info_log(self.progID)

        # retrieve the location of the IN variables of the vertex shader
        self.vertexLoc = glGetAttribLocation(self.progID, "inputPosition")
        self.texCoordLoc = glGetAttribLocation(self.progID, "inputTexCoord")
        self.normalLoc = glGetAttribLocation(self.progID, "inputNormal")

        # retrieve the location of the UNIFORM variables of the program
        self.projectionLoc = glGetUniformLocation(self.progID, "projection")
        self.modelviewLoc = glGetUniformLocation(self.progID, "modelview")
        self.normalMatrixLoc = glGetUniformLocation(self.progID, "normalMat")
        self.modeLoc = glGetUniformLocation(self.progID, "mode")

    @staticmethod
    def _log_text(log):
        """Return an info log as str, tolerating bytes that are not valid UTF-8."""
        if isinstance(log, bytes):
            return log.decode(errors="replace")
        return log

    def _print_shader_info_log(self, shader_obj):
        """Print the shader's info log (compiler output) if it is non-empty."""
        log = glGetShaderInfoLog(shader_obj)
        if log:
            print(f"Shader Log: {self._log_text(log)}")

    def _print_program_info_log(self, prog_obj):
        """Print the program's info log (linker output) if it is non-empty."""
        log = glGetProgramInfoLog(prog_obj)
        if log:
            print(f"Program Log: {self._log_text(log)}")

    def _load_shader_src(self, filename):
        """Return the text of *filename*; print an error and exit(1) on failure."""
        try:
            with open(filename, 'r') as f:
                return f.read()
        except IOError:
            print(f"Unable to open file {filename}")
            sys.exit(1)

    def _load_vertex_data(self, filename, per_vertex_floats):
        """Read interleaved vertex floats from a text .vbo file.

        Expected format: the first line holds the total number of floats.
        All remaining whitespace-separated tokens are float values. Tokens that
        do not parse as floats are skipped.

        Args:
            filename: path of the .vbo text file.
            per_vertex_floats: floats per vertex; the total must be a multiple.

        Returns:
            A list of floats of exactly the announced length, or None if the
            file cannot be opened, the header is not an integer, the number of
            floats read differs from the header, or the count is not a multiple
            of per_vertex_floats. An error message is printed in every None case.
        """
        data = []
        try:
            with open(filename, 'r') as input_file:
                try:
                    num_floats = int(input_file.readline().strip())
                except ValueError:
                    print(f"Error reading number of floats from vbo data file {filename}")
                    return None

                for line in input_file:
                    for token in line.split():
                        try:
                            data.append(float(token))
                        except ValueError:
                            continue
        except IOError:
            print(f"Can not find vbo data file {filename}")
            return None

        if len(data) != num_floats or num_floats % per_vertex_floats != 0:
            print(f"Invalid vbo data file {filename}: header announces {num_floats} "
                  f"floats (groups of {per_vertex_floats}), read {len(data)}")
            return None

        return data


# --- Global Renderer Instance ---
# Single module-level Renderer shared by the GLUT callbacks below, which reach
# it as a global. Constructing it issues no OpenGL calls; GL resources are
# created later by renderer.init() in main(), after the window exists.
renderer = Renderer()

# --- GLUT Callbacks ---
def glut_resize(w, h):
    """GLUT reshape callback: hand the new window size to the renderer."""
    renderer.resize(w=w, h=h)

def glut_display():
    """GLUT display callback: draw one frame, then present the back buffer."""
    # The window is double-buffered (GLUT_DOUBLE), so the frame drawn by the
    # renderer only becomes visible after the buffer swap.
    draw_frame = renderer.display
    draw_frame()
    glutSwapBuffers()

def glut_close():
    """GLUT close callback: release the renderer's OpenGL objects."""
    release_gl_objects = renderer.dispose
    release_gl_objects()

def timer(v):
    """GLUT timer callback: advance the camera orbit and request a redraw.

    Re-arms itself every 20 ms, passing an incremented counter.

    Args:
        v: tick counter passed through glutTimerFunc.
    """
    # Advance the orbit angle by one degree per tick. The angle is only used
    # through cos/sin of its radian value, so wrapping at 360 changes nothing
    # visually but keeps the float bounded during long runs.
    renderer.t = (renderer.t + 1.0) % 360.0
    # Ask GLUT to redraw instead of drawing from the timer. GLUT then runs the
    # registered display callback from its event loop and coalesces pending
    # redraw requests.
    glutPostRedisplay()
    glutTimerFunc(20, timer, v + 1)

def glut_keyboard(key, x, y):
    redraw = False
    mode_str = ""
    char_key = key.decode('utf-8')
    
    if char_key == '1':
        renderer.modeVal = 1
        redraw = True
        mode_str = "Global Normals"
    elif char_key == '2':
        renderer.modeVal = 2
        redraw = True
        mode_str = "Local Normals"
    elif char_key == '3':
        renderer.modeVal = 3
        redraw = True
        mode_str = "Global Vertex Positions"
    elif char_key == '4':
        renderer.modeVal = 4
        redraw = True
        mode_str = "Local Vertex Positions"
    elif char_key == '5':
        renderer.modeVal = 5
        redraw = True
        mode_str = "Texture Coordinates"
        
    if redraw:
        glut_display()
        print(mode_str)
        glutSetWindowTitle(mode_str.encode('utf-8'))

def main():
    """Create the GLUT window, register the callbacks and run the event loop."""
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_DEPTH | GLUT_DOUBLE | GLUT_RGBA)
    glutInitWindowPosition(100, 100)
    glutInitWindowSize(320, 320)
    glutCreateWindow(b"Transforming normals")

    # Register the event handlers in one table, in the original order.
    callbacks = (
        (glutDisplayFunc, glut_display),
        (glutReshapeFunc, glut_resize),
        (glutKeyboardFunc, glut_keyboard),
        (glutCloseFunc, glut_close),
    )
    for register, handler in callbacks:
        register(handler)

    # renderer.init() issues OpenGL calls, so it runs only after
    # glutCreateWindow has created the window.
    renderer.init()

    # Start the animation timer; the first tick fires after 20 ms.
    glutTimerFunc(20, timer, 0)

    # Hand control to GLUT's event loop.
    glutMainLoop()

# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()