main.py

    import cv2
    import random
    import mediapipe as mp
    import pickle
    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    import time
    import os
    from tkinter import Tk, Label
    from PIL import Image, ImageTk
    
    from mouse_class import Mouse
    from hand_detection import normalise_landmarks
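    # NOTE: Mouse and normalise_landmarks are project-local helpers that are not
    # shown in this file. The roles sketched below are assumptions, not the
    # actual implementations:
    #   - normalise_landmarks(landmark_list): turns the 21 (x, y) pairs into the
    #     flat feature vector the classifier was trained on
    #   - Mouse: translates predicted gesture labels (and the index-fingertip
    #     position) into cursor actions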
    
    def main():
        # define Mouse controller
        mouse = Mouse()
    
        # load model
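        # the pickle is expected to hold a dict with the fitted classifier under
        # the 'model' key (likely the RandomForestClassifier imported above)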
        current_dir = os.path.dirname(__file__)
        model_path = os.path.abspath(os.path.join(current_dir, os.pardir, 'trained_models', 'trained_Moni_data.p'))
        with open(model_path, 'rb') as f:
            model_dict = pickle.load(f)
        model = model_dict['model']
        
        # create hand detection object
        mp_hands = mp.solutions.hands
        mp_drawing = mp.solutions.drawing_utils
        
        # open video
        cap = cv2.VideoCapture(0)
        
        # if the camera cannot be opened, warn and exit
        if not cap.isOpened():
            print("Warning: Cannot reach camera")
            return
        
        # set up Tkinter window
        root = Tk()
        root.title("Hand Tracking - Always on Top")
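        # keep the small preview window above all other windows so it stays visible while in use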
        root.attributes("-topmost", True)
        video_label = Label(root)
        video_label.pack()
    
        # adjust window geometry
        # Get the screen width and height
        screen_width = root.winfo_screenwidth()
        screen_height = root.winfo_screenheight()
        
        # Define window size and position (160x120 window at the bottom-right corner)
        window_width = 160
        window_height = 120
        x_position = screen_width - window_width - 10  # 10px margin from the right
        y_position = screen_height - window_height - 70  # 70px margin from the bottom
    
        # Set window geometry
        root.geometry(f"{window_width}x{window_height}+{x_position}+{y_position}")
        # mediapipe hand object
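        # max_num_hands=1 tracks a single hand; the 0.9 detection/tracking thresholds
        # favour precision over recall, so tracking may drop in poor lighting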
        with mp_hands.Hands(max_num_hands=1, model_complexity=1,
                            min_detection_confidence=0.9, min_tracking_confidence=0.9) as hands:
            
            def update_frame():
                ret, frame = cap.read()
                if not ret:
                    print("Warning: Cannot read camera input")
                    root.destroy()
                    return
                
                # flip frame and process it
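                # mirroring makes on-screen motion match the user's hand; MediaPipe expects RGB input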
                frame = cv2.flip(frame, 1)
                frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
                # Hand detection
                results = hands.process(frameRGB)
                
                landmark_list = []
                mouse_command = None
                if results.multi_hand_landmarks:
                    # at most one hand is detected because max_num_hands=1
                    hand_landmarks = results.multi_hand_landmarks[0]
    
                    # Draw landmarks on frame
                    mp_drawing.draw_landmarks(
                        frameRGB, hand_landmarks, mp_hands.HAND_CONNECTIONS, 
                        mp_drawing.DrawingSpec(color=(250, 0, 0), thickness=2, circle_radius=4),
                        mp_drawing.DrawingSpec(color=(0, 250, 0), thickness=2, circle_radius=2)
                    )
                    
                    # get landmark list with indices described in https://github.com/google-ai-edge/mediapipe/blob/master/mediapipe/python/solutions/hands.py
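                    # MediaPipe returns 21 landmarks per hand; only x and y are kept, giving 42 values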
                    for lm in hand_landmarks.landmark:
                        landmark_list.append((lm.x, lm.y))
                    
                    # normalise landmarks so they match the preprocessing used when the model was trained
                    normalised_landmark_list = normalise_landmarks(landmark_list)
                    
                    # apply model
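                    # predict() returns an array holding one gesture label, e.g. "move cursor" or "grab"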
                    pred = model.predict(np.asarray(normalised_landmark_list).reshape(1, -1))
                    mouse_command = pred[0]
                    cv2.putText(
                        img=frameRGB, text=pred[0], org=(30, 30), 
                        fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1, color=(255, 0, 0), thickness=1
                    )
    
                    mouse.add_prediction(mouse_command)
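                    # landmark 8 is the index fingertip in MediaPipe's hand model;
                    # it drives the cursor position while moving or grabbing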
                    if mouse_command in ("move cursor", "grab"):
                        mouse.get_hand_pos(landmark_list[8])
                
                # Convert frame to Tkinter-compatible format and display
                frameRGB_resized = cv2.resize(frameRGB, (root.winfo_width(), root.winfo_height()))
                img = ImageTk.PhotoImage(Image.fromarray(frameRGB_resized))
                video_label.config(image=img)
                video_label.image = img
    
                # Refresh frame
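                # ~10 ms delay; the effective frame rate is bounded by camera latency and detection time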
                root.after(10, update_frame)
    
            # Start updating frames
            update_frame()
    
            # Quit the program properly
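            # closing the window releases the webcam before tearing down the Tk root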
            root.protocol("WM_DELETE_WINDOW", lambda: (cap.release(), root.destroy()))
            root.mainloop()
    
        cap.release()
        print("Program closed")
    
    if __name__ == '__main__':
        main()