    import os
    import pickle

    import cv2
    import mediapipe as mp
    import numpy as np
    # not called directly, but importing it makes a missing scikit-learn
    # install fail fast when the pickled RandomForestClassifier is loaded
    from sklearn.ensemble import RandomForestClassifier

    from mouse_class import Mouse
    from hand_detection import normalise_landmarks
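
    # `Mouse` (mouse_class.py) is expected to smooth the incoming predictions and drive
    # the OS cursor; `normalise_landmarks` (hand_detection.py) is assumed to apply the
    # same landmark normalisation that was used when the classifier was trained.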
      
    
    ## main: open the camera, run hand detection and turn predictions into mouse actions
    def main():
        # create the mouse controller
        mouse = Mouse()
    
        # load the trained gesture classifier
        current_dir = os.path.dirname(__file__)
        model_path = os.path.abspath(os.path.join(current_dir, os.pardir, 'trained_models', 'trained_Moni_data.p'))
        with open(model_path, 'rb') as f:
            model_dict = pickle.load(f)
        model = model_dict['model']
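        # the pickle is assumed to hold a dict of the form {'model': <fitted classifier>},
        # matching how the training step of this repo saves its result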
        
        # create hand detection object
        mp_hands = mp.solutions.hands
        mp_drawing = mp.solutions.drawing_utils
        
        # open the default camera
        cap = cv2.VideoCapture(0)
        
        # if the camera cannot be opened, warn and stop
        if not cap.isOpened():
            print("Warning: cannot reach camera")
            return
        print("Program is running, push 'q' to quit.")
            
        # mediapipe hand-detection object
        with mp_hands.Hands(
                max_num_hands=1,               # track at most one hand
                model_complexity=1,            # full landmark model (0 is lighter but less accurate)
                min_detection_confidence=0.9,  # confidence required to detect a hand
                min_tracking_confidence=0.9,   # confidence required to keep tracking it
                ) as hands:
            
            # read frames from the web camera
            while cap.isOpened():
                ret, frame = cap.read()
                
                if not ret:
                    print("Warning: cannot read camera input")
                    break
                    
                # flip frame to appear as a mirror
                frame = cv2.flip(frame, 1)
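                # MediaPipe expects RGB input, while OpenCV captures frames in BGR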
                frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                ## hand detection
                results = hands.process(frameRGB)
                
                landmark_list = []
                mouse_command = None
                if results.multi_hand_landmarks:
                    # multi_hand_landmarks holds one entry per detected hand; with
                    # max_num_hands=2 you would iterate over it, e.g.
                    # for num, hand in enumerate(results.multi_hand_landmarks):
                    # here max_num_hands=1, so at most one hand is returned
                    hand_landmarks = results.multi_hand_landmarks[0]
    
                    # draw landmarks on frame
                    mp_drawing.draw_landmarks(frameRGB, hand_landmarks, mp_hands.HAND_CONNECTIONS, 
                                                mp_drawing.DrawingSpec(color=(250, 0, 0), thickness=2, circle_radius=4),
                                                mp_drawing.DrawingSpec(color=(0, 250, 0), thickness=2, circle_radius=2),
                                                )
                    
                    # collect the (x, y) coordinates of all 21 landmarks; the index of each landmark
                    # is described in https://github.com/google-ai-edge/mediapipe/blob/master/mediapipe/python/solutions/hands.py
                    for lm in hand_landmarks.landmark:
                        landmark_list.append((lm.x, lm.y))
                
                    # normalise landmarks so prediction sees the same features as training
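                    # (the exact scheme lives in hand_detection.py; presumably the coordinates
                    # are shifted and rescaled to be position- and size-invariant)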
                    normalised_landmark_list = normalise_landmarks(landmark_list)
                
                    # apply model
                    pred = model.predict(np.asarray(normalised_landmark_list).reshape(1, -1))
                    mouse_command = pred[0]
                    cv2.putText(frameRGB, text=mouse_command, org=(30, 30),
                                fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1, color=(255, 0, 0), thickness=1)
    
                    mouse.add_prediction(mouse_command)
                    if mouse_command in ("move cursor", "grab"):
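                        # landmark 8 is INDEX_FINGER_TIP; its position steers the cursor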
                        mouse.get_hand_pos(landmark_list[8])
                # convert back to BGR and show the annotated frame
                frame_annotated = cv2.cvtColor(frameRGB, cv2.COLOR_RGB2BGR)
                cv2.imshow('Hand tracking', frame_annotated)
                
                # or show original frame without annotation
                # cv2.imshow('Hand tracking', frame)
                
                # check for key presses; mask to the low byte for portable comparison
                key = cv2.waitKey(1) & 0xFF

                if key == ord('q'):
                    print("Quit camera")
                    break
    
        cap.release()
        cv2.destroyAllWindows()
        
        print("Program closed")
    
    if __name__ == '__main__':
        main()