import cv2
import mediapipe as mp
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from tkinter import Tk, Label
from PIL import Image, ImageTk
import pyautogui

from mouse_class import Mouse
from keyboard_class import Keyboard
from specialkeys_class import Specialkeys
from hand_detection import normalise_landmarks, landmarks_from_results
from tools import load_model, set_camera_window

# Hide the mediapipe warning "UserWarning: SymbolDatabase.GetPrototype() is deprecated.
# Please use message_factory.GetMessageClass() instead. SymbolDatabase.GetPrototype()
# will be removed soon."
# NOTE(review): this filter has no message pattern, so it ignores *every* UserWarning,
# including any raised by sklearn, pyautogui, etc. Add message=... if other
# UserWarnings should stay visible.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# True while the right hand is classified by the mouse model, False while it is
# classified by the keyboard model; starts in mouse mode.
MOUSE_ACTIVE = True
# Debounce flag for the "change the model" gesture: set right after a toggle so
# that holding the gesture over several frames does not flip the mode repeatedly.
FREEZE_CHANGE_MODEL = False

# Modifier keys that are released whenever no left hand is visible and on exit.
_MODIFIER_KEYS = ("shift", "ctrl", "alt")
# Right-hand gesture label that toggles between the mouse and keyboard models.
_CHANGE_MODEL_COMMAND = "change the model"
# Mouse-model labels that also update the tracked hand size and position.
_CURSOR_COMMANDS = ("move cursor", "drag")


def _next_mode(command, mouse_active, freeze):
    """Return the updated ``(mouse_active, freeze)`` pair for one right-hand prediction.

    The change-model gesture toggles ``mouse_active`` only when ``freeze`` is
    clear, and then sets ``freeze`` so that holding the gesture over the
    following frames does not flip the mode back. Any other command clears
    ``freeze`` so the next change-model gesture is honoured again.
    """
    if command != _CHANGE_MODEL_COMMAND:
        return mouse_active, False
    if freeze:
        return mouse_active, True
    return not mouse_active, True


def _release_modifier_keys():
    """Send a key-up event for every key in ``_MODIFIER_KEYS``."""
    for key in _MODIFIER_KEYS:
        pyautogui.keyUp(key)


def _landmark_xy(hand_landmarks):
    """Return the ``(x, y)`` pair of every landmark of one detected hand.

    Landmark indices are described in
    https://github.com/google-ai-edge/mediapipe/blob/master/mediapipe/python/solutions/hands.py
    """
    return [(lm.x, lm.y) for lm in hand_landmarks.landmark]


def _predict(model, landmark_list):
    """Normalise *landmark_list* and return the label *model* predicts for it."""
    features = np.asarray(normalise_landmarks(landmark_list)).reshape(1, -1)
    return model.predict(features)[0]


def main():
    """Run the webcam hand-gesture controller until its window is closed.

    The right hand is classified by the mouse model or the keyboard model
    (switched with the "change the model" gesture). The left hand is
    classified by the special-keys model. Each prediction is forwarded to the
    matching controller and drawn, together with the hand landmarks, on the
    camera frame shown in the Tkinter window. The camera is released and the
    modifier keys are let go when the loop ends.
    """
    # controllers that receive the predicted commands
    mouse = Mouse()
    keyboard = Keyboard()
    specialkeys = Specialkeys()

    # one trained classifier per controller
    model_mouse = load_model(device="mouse")
    model_keyboard = load_model(device="keyboard")
    model_specialkeys = load_model(device="specialkeys")

    mp_hands = mp.solutions.hands
    mp_drawing = mp.solutions.drawing_utils
    # drawing styles never change, so build them once instead of every frame
    right_point_spec = mp_drawing.DrawingSpec(color=(250, 0, 0), thickness=2, circle_radius=4)
    right_line_spec = mp_drawing.DrawingSpec(color=(0, 250, 0), thickness=2, circle_radius=2)
    left_point_spec = mp_drawing.DrawingSpec(color=(0, 250, 0), thickness=2, circle_radius=4)
    left_line_spec = mp_drawing.DrawingSpec(color=(0, 120, 120), thickness=2, circle_radius=2)

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Warning: Cannot reach camera")
        cap.release()
        return

    try:
        root, video_label = set_camera_window()
        after_id = None  # id of the pending root.after() frame callback

        with mp_hands.Hands(max_num_hands=2, model_complexity=1,
                            min_detection_confidence=0.9, min_tracking_confidence=0.9) as hands:

            def update_frame():
                """Process one camera frame, display it, and schedule the next one."""
                global MOUSE_ACTIVE
                global FREEZE_CHANGE_MODEL
                nonlocal after_id

                ret, frame = cap.read()
                if not ret:
                    print("Warning: Cannot read camera input")
                    root.destroy()
                    return

                # flip horizontally (mirror view), then convert BGR -> RGB for MediaPipe
                frame_rgb = cv2.cvtColor(cv2.flip(frame, 1), cv2.COLOR_BGR2RGB)
                results = hands.process(frame_rgb)

                left_hand_landmarks = right_hand_landmarks = None
                if results.multi_hand_landmarks:
                    # at least one hand found; split by handedness (either may be None)
                    left_hand_landmarks, right_hand_landmarks = landmarks_from_results(results)

                if right_hand_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        frame_rgb, right_hand_landmarks, mp_hands.HAND_CONNECTIONS,
                        right_point_spec, right_line_spec,
                    )
                    right_landmark_list = _landmark_xy(right_hand_landmarks)

                    # the right hand drives either the mouse or the keyboard model
                    if MOUSE_ACTIVE:
                        command = _predict(model_mouse, right_landmark_list)
                        mouse.add_prediction(command)
                        if command in _CURSOR_COMMANDS:
                            mouse.get_hand_size(right_landmark_list[12], right_landmark_list[0])
                            mouse.get_hand_pos(right_landmark_list[9])
                    else:
                        command = _predict(model_keyboard, right_landmark_list)
                        keyboard.add_prediction(command)

                    # same debounce rule for both modes: any non-change gesture re-arms the toggle
                    MOUSE_ACTIVE, FREEZE_CHANGE_MODEL = _next_mode(
                        command, MOUSE_ACTIVE, FREEZE_CHANGE_MODEL
                    )

                    cv2.putText(
                        img=frame_rgb,
                        text=f"{command}, MOUSE: {MOUSE_ACTIVE}",
                        org=(30, 30), fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1,
                        color=(255, 0, 0), thickness=1,
                    )

                if left_hand_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        frame_rgb, left_hand_landmarks, mp_hands.HAND_CONNECTIONS,
                        left_point_spec, left_line_spec,
                    )
                    special_command = _predict(model_specialkeys, _landmark_xy(left_hand_landmarks))
                    cv2.putText(
                        img=frame_rgb, text=str(special_command), org=(30, 60),
                        fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1, color=(0, 255, 0), thickness=1,
                    )
                    specialkeys.add_prediction(special_command)
                else:
                    # no left hand in view (including no hand at all): release the modifiers
                    _release_modifier_keys()

                # scale the frame to the current window size and display it
                frame_resized = cv2.resize(frame_rgb, (root.winfo_width(), root.winfo_height()))
                img = ImageTk.PhotoImage(Image.fromarray(frame_resized))
                video_label.config(image=img)
                video_label.image = img  # keep a reference so the PhotoImage is not garbage-collected

                after_id = root.after(10, update_frame)

            def close_window():
                """Cancel the pending frame callback and close the window."""
                if after_id is not None:
                    root.after_cancel(after_id)
                root.destroy()

            # Register the close handler before the first frame: a failed first read
            # destroys root, after which root.protocol() would raise TclError.
            root.protocol("WM_DELETE_WINDOW", close_window)
            update_frame()
            root.mainloop()
    finally:
        cap.release()
        # a modifier may still be held down when the loop stops
        _release_modifier_keys()

    print("Program closed")

if __name__ == "__main__":
    # start the controller only when run as a script, not when imported
    main()