#SOURCES: 
#https://www.youtube.com/watch?v=FygLqV15TxQ
#https://www.youtube.com/watch?v=EgjwKM3KzGU

import cv2
import mediapipe as mp
import time
import numpy as np

# Hand-detection helper class built on MediaPipe Hands
class handDetector():
    """Detect hands in BGR frames with MediaPipe Hands, then read, measure and draw landmarks.

    Call findHands() on a frame first. It stores the detection result in
    ``self.results``, which every other method reads. Until a detection has
    run, or when nothing was detected, the helpers behave as "no hands".
    ``node`` arguments are MediaPipe hand-landmark indices (21 per hand, 0-20).
    """

    def __init__(self, mode=False, maxHands=2, detectionCon=0.5, modelComplexity=1, trackCon=0.5):
        """Create the underlying MediaPipe Hands model.

        Args:
            mode (bool, optional): forwarded as ``static_image_mode``. Defaults to False.
            maxHands (int, optional): forwarded as ``max_num_hands``. Defaults to 2.
            detectionCon (float, optional): forwarded as ``min_detection_confidence``. Defaults to 0.5.
            modelComplexity (int, optional): forwarded as ``model_complexity``. Defaults to 1.
            trackCon (float, optional): forwarded as ``min_tracking_confidence``. Defaults to 0.5.
        """
        self.mode = mode
        self.maxHands = maxHands
        self.detectionCon = detectionCon
        self.modelComplex = modelComplexity
        self.trackCon = trackCon
        self.mpHands = mp.solutions.hands
        # Keyword arguments: the mapping no longer depends on Hands()' positional order.
        self.hands = self.mpHands.Hands(
            static_image_mode=self.mode,
            max_num_hands=self.maxHands,
            model_complexity=self.modelComplex,
            min_detection_confidence=self.detectionCon,
            min_tracking_confidence=self.trackCon,
        )
        self.mpDraw = mp.solutions.drawing_utils  # draws the 21 landmark dots and their connections
        # No detection has run yet; the helpers treat this as "no hands".
        self.results = None

    def _detected_hands(self):
        """Return the landmark sets from the last findHands() call, or [] if there are none."""
        # getattr also covers instances created without __init__.
        results = getattr(self, "results", None)
        if results is None or not results.multi_hand_landmarks:
            return []
        return results.multi_hand_landmarks

    def findHands(self, img, draw=True):
        """Run hand detection on a BGR frame and optionally draw the landmarks on it.

        Args:
            img (numpy.ndarray): BGR frame. It is drawn on in place when ``draw`` is True.
            draw (bool, optional): draw landmarks and connections. Defaults to True.

        Returns:
            numpy.ndarray: the same ``img`` object.
        """
        imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # A read-only buffer lets MediaPipe take the image by reference (per its examples).
        # imgRGB is a throwaway copy, so there is no need to make it writeable again.
        imgRGB.flags.writeable = False
        self.results = self.hands.process(imgRGB)
        if draw:
            for handLMS in self._detected_hands():
                self.mpDraw.draw_landmarks(
                    img, handLMS, self.mpHands.HAND_CONNECTIONS,
                    self.mpDraw.DrawingSpec(color=(164, 230, 32), thickness=2, circle_radius=2),
                    self.mpDraw.DrawingSpec(color=(115, 115, 102), thickness=2, circle_radius=2),
                )
        return img

    def getPositions(self, img):
        """Get the pixel position of every landmark of every detected hand.

        Args:
            img (numpy.ndarray): frame whose size is used to scale the normalised coordinates.

        Returns:
            list of lists: ``[landmark_index, x, y]`` entries. Hands are concatenated in
            detection order. The list is empty when no hand was detected.
        """
        hands = self._detected_hands()
        if not hands:
            return []
        h, w = img.shape[:2]  # hoisted: the frame size is the same for every landmark
        return [
            [idx, int(lm.x * w), int(lm.y * h)]
            for hand in hands
            for idx, lm in enumerate(hand.landmark)
        ]

    def drawLine(self, img, x1, y1, x2, y2, draw=True):
        """Draw a line between two pixel positions.

        Args:
            img (numpy.ndarray): frame to draw on (modified in place).
            x1 (int): x of the first point.
            y1 (int): y of the first point.
            x2 (int): x of the second point.
            y2 (int): y of the second point.
            draw (bool, optional): do nothing when False. Defaults to True.
        """
        if draw:
            cv2.line(img, (x1, y1), (x2, y2), (51, 153, 255), 5)

    def get_length(self, img, node1, node2, draw=False):
        """Measure the pixel distance between two landmarks.

        When several hands are detected, the last one is used.

        Args:
            img (numpy.ndarray): frame used for scaling, and drawn on when ``draw`` is True.
            node1 (int): first landmark index (e.g. 0, 4, 8, 12, 16, 20).
            node2 (int): second landmark index. It may equal ``node1``, which gives 0.
            draw (bool, optional): print the coordinates and draw the segment and
                both landmarks. Defaults to False.

        Returns:
            int: distance in pixels, truncated. The result is 0 when no hand is detected.

        Raises:
            IndexError: if a node index is outside the hand's landmark list.
        """
        hands = self._detected_hands()
        if not hands:
            # Keep the historical "0 when nothing is detected" result without
            # drawing a line to sentinel coordinates.
            return 0
        h, w = img.shape[:2]
        landmarks = hands[-1].landmark
        p1, p2 = landmarks[node1], landmarks[node2]
        x1, y1 = int(p1.x * w), int(p1.y * h)
        x2, y2 = int(p2.x * w), int(p2.y * h)
        length = int(np.hypot(x2 - x1, y2 - y1))
        if draw:  # all of this is for visuals
            print(f"X1: {x1} , X2: {x2} , Y1: {y1} , Y2: {y2}")
            print(f"length: {length}")
            self.drawLine(img, x1, y1, x2, y2)
            self.drawPosition(img, True, node1)
            self.drawPosition(img, True, node2)
        return length

    def drawPosition(self, img, draw=True, node=0):
        """Draw a filled circle on one landmark of every detected hand.

        Args:
            img (numpy.ndarray): frame to draw on (modified in place).
            draw (bool, optional): do nothing when False. Defaults to True.
            node (int, optional): landmark index to highlight. Defaults to 0.

        Raises:
            IndexError: if ``node`` is outside the hand's landmark list.
        """
        if not draw:
            return
        hands = self._detected_hands()
        if not hands:
            return
        h, w = img.shape[:2]
        for hand in hands:
            lm = hand.landmark[node]
            cv2.circle(img, (int(lm.x * w), int(lm.y * h)), 15, (51, 153, 255), cv2.FILLED)

    def showFPS(self, img, draw=True, pTime=0):
        """Write the frame rate onto the frame.

        Args:
            img (numpy.ndarray): frame to draw on (modified in place).
            draw (bool, optional): when False nothing is drawn and ``pTime`` is
                returned unchanged. Defaults to True.
            pTime (float, optional): timestamp returned by the previous call. Defaults to 0.

        Returns:
            float: the timestamp to pass as ``pTime`` on the next call.
        """
        if not draw:
            return pTime
        cTime = time.time()
        elapsed = cTime - pTime
        # Two calls can land on the same clock tick; avoid ZeroDivisionError.
        fps = 1 / elapsed if elapsed > 0 else 0
        cv2.putText(img, f'FPS: {int(fps)}', (5, 470), cv2.FONT_HERSHEY_PLAIN,
                    2, (164, 230, 32), 2)
        return cTime

def main(camera_index=1):
    """Run the live hand-tracking demo until 'q' is pressed or the camera stops delivering frames.

    Args:
        camera_index (int, optional): OpenCV camera index. 0 is usually the built-in
            camera and 1 an external one. Defaults to 1.
    """
    pTime = 0  # timestamp of the previous frame, used for the FPS overlay
    cap = cv2.VideoCapture(camera_index)
    detector = handDetector()
    try:
        while True:
            success, img = cap.read()
            if not success:
                # No frame (wrong index, camera unplugged): stop instead of
                # passing None into cv2.cvtColor.
                break
            img = detector.findHands(img)
            pTime = detector.showFPS(img, pTime=pTime)
            # Distance wrist (0) -> thumb tip (4), drawn and printed on the frame.
            detector.get_length(img, 0, 4, True)
            cv2.imshow("Video", img)
            # Mask to the low byte; some platforms set extra bits in waitKey's result.
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Always free the camera and windows, even if a frame raises.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()