# arm_camera.py
#
# Sample Webots controller file for driving a two-link arm with three driven
# joints and an end-mounted camera.

# No copyright, 2020-2024, Garth Zeglin.  This file is
# explicitly placed in the public domain.

print("zyy_camera.py waking up.")

# Import the Webots simulator API.
from controller import Robot, Display

# Import the standard Python math library.
import math, random, time

# Import the third-party numpy library for matrix calculations.
# Note: this can be installed using 'pip3 install numpy'.
import numpy as np

# Import the third-party OpenCV library for image operations.
# Note: this can be installed using 'pip3 install opencv-python'
import cv2 as cv

# Import the bezier module from the same directory.  Note: this requires numpy.
import bezier

# Import the vision module from the same directory.  Note: this requires OpenCV.
import vision

# Define the time step in milliseconds between controller updates.
EVENT_LOOP_DT = 50
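# Note: the event loop period is usually chosen as a multiple of the
# basicTimeStep value in the world's WorldInfo node so the controller stays
# synchronized with the physics simulation steps.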

################################################################
class ZYY(Robot):
    def __init__(self):

        super(ZYY, self).__init__()
        self.robot_name = self.getName()
        print("%s: controller connected." % (self.robot_name))

        # Attempt to randomize the random library sequence.
        random.seed(time.time())

        # Initialize geometric constants.  These should match
        # the current geometry of the robot.
        self.base_z  = 0.29  # joint2 axis height: 0.12 fixed base height + 0.17 j2 z offset
        self.link1   = 0.5   # link1 length, i.e. distance between j2 and j3
        self.link2   = 0.5   # link2 length, i.e. distance between j3 and end

        # Fetch handles for the joint motors.
        self.motor1 = self.getDevice('motor1')
        self.motor2 = self.getDevice('motor2')
        self.motor3 = self.getDevice('motor3')

        # Adjust the motor controller properties.
        self.motor1.setAvailableTorque(20.0)
        self.motor2.setAvailableTorque(15.0)
        self.motor3.setAvailableTorque(10.0)

        # Adjust the low-level controller gains.
        print("%s: setting PID gains." % (self.robot_name))
        self.motor1.setControlPID(100.0, 0.0, 25.0)
        self.motor2.setControlPID( 50.0, 0.0, 25.0)
        self.motor3.setControlPID( 50.0, 0.0, 25.0)

        # Fetch handles for the joint sensors.
        self.joint1 = self.getDevice('joint1')
        self.joint2 = self.getDevice('joint2')
        self.joint3 = self.getDevice('joint3')

        # Save the most recent target position.
        self.joint_target = np.zeros(3)

        # Specify the sampling rate for the joint sensors.
        self.joint1.enable(EVENT_LOOP_DT)
        self.joint2.enable(EVENT_LOOP_DT)
        self.joint3.enable(EVENT_LOOP_DT)

        # Enable the overhead camera and set the sampling frame rate.
        self.camera = self.getDevice('camera')
        self.camera_frame_dt = EVENT_LOOP_DT # units are milliseconds
        self.camera_frame_timer = 0
        self.camera.enable(self.camera_frame_dt)
        self.frame_count = 0
        self.vision_target = None

        # Connect to the display object to show debugging data within the Webots GUI.
        # The display width and height should match the camera.
        self.display = self.getDevice('display')

        # Initialize behavior state machines.
        self.state_timer = 2*EVENT_LOOP_DT
        self.state_index = 0
        self._init_spline()
        return

    #================================================================
    def endpoint_forward_kinematics(self, q):
        """Compute the forward kinematics for the end point.  Returns the
        body-coordinate XYZ Cartesian position of the end point for a given joint
        angle vector.
        :param q: three-element list with [q1, q2, q3] joint angles in radians
        :return:  three-element list with endpoint [x,y,z] location

        """
        j1 = q[0]
        j2 = q[1]
        j3 = q[2]

        return [(self.link1*math.sin(j2) + self.link2*math.sin(j2 + j3))*math.cos(j1),
                (self.link1*math.sin(j2) + self.link2*math.sin(j2 + j3))*math.sin(j1),
                self.base_z + self.link1*math.cos(j2) + self.link2*math.cos(j2 + j3)]

    #================================================================
    def endpoint_inverse_kinematics(self, target):
        """Compute the joint angles for a target end position.  The target is a
        XYZ Cartesian position vector in body coordinates, and the result vector
        is a joint angles as list [j1, j2, j3].  If the target is out of reach,
        returns the closest pose.  With j1 between -pi and pi, and j2 and j3
        limited to positive rotations, the solution is always unique.
        """

        x = target[0]
        y = target[1]
        z = target[2]

        # the joint1 base Z rotation depends only upon x and y
        j1 = math.atan2(y, x)

        # distance within the XY plane from the origin to the endpoint projection
        xy_radial = math.sqrt(x*x + y*y)

        # find the Z offset from the J2 horizontal plane to the end point
        end_z = z - self.base_z

        # angle between the J2 horizontal plane and the endpoint
        theta = math.atan2(end_z, xy_radial)

        # radial distance from the J2 axis to the endpoint
        radius = math.sqrt(x*x + y*y + end_z*end_z)

        # use the law of cosines to compute the elbow angle solely as a function of the
        # link lengths and the radial distance from j2 to end
        #   R**2 = l1**2 + l2**2 - 2*l1*l2*cos(pi - elbow)
        acosarg = (radius*radius - self.link1**2 - self.link2**2) / (-2 * self.link1 * self.link2)
        if acosarg < -1.0:  elbow_supplement = math.pi
        elif acosarg > 1.0: elbow_supplement = 0.0
        else:               elbow_supplement = math.acos(acosarg)
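        # Worked example: with the arm fully extended, radius == link1 + link2 == 1.0,
        # so acosarg == (1.0 - 0.25 - 0.25) / -0.5 == -1.0, elbow_supplement is pi,
        # and the elbow angle returned below (pi - elbow_supplement) is zero.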

        print("theta:", theta, "radius:", radius, "acosarg:", acosarg)

        # use the law of sines to find the angle at the bottom vertex of the triangle defined by the links
        #  radius / sin(elbow_supplement)  = l2 / sin(alpha)
        if radius > 0.0:
            alpha = math.asin(self.link2 * math.sin(elbow_supplement) / radius)
        else:
            alpha = 0.0

        # calculate the joint angles
        return [j1, math.pi/2 - theta - alpha, math.pi - elbow_supplement]
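
    # Example consistency check (illustrative only, not called below): for a
    # reachable target the two kinematic functions should round-trip, e.g.
    #   q = self.endpoint_inverse_kinematics([0.3, 0.2, 0.8])
    #   self.endpoint_forward_kinematics(q)   # approximately [0.3, 0.2, 0.8]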


    #================================================================
    # motion primitives

    def go_joints(self, target):
        """Issue a position command to move to the given endpoint position expressed in joint angles as radians."""

        self.motor1.setPosition(target[0])
        self.motor2.setPosition(target[1])
        self.motor3.setPosition(target[2])
        self.joint_target = target
        # print("%s: moving to (%f, %f, %f)" % (self.robot_name, target[0], target[1], target[2]));


    #================================================================
    # Polling function to process camera input.  The API doesn't appear to
    # include any method for detecting frame capture, so this simply samples at
    # the expected frame rate.  It may retrieve duplicate frames or experience
    # time aliasing.

    def poll_camera(self):
        self.camera_frame_timer -= EVENT_LOOP_DT
        if self.camera_frame_timer < 0:
            self.camera_frame_timer += self.camera_frame_dt
            # read the camera
            image = self.camera.getImage()
            if image is not None:
                # print("read %d byte image from camera" % (len(image)))
                self.frame_count += 1

                # convert the image data from a raw bytestring into a 3D matrix
                # the pixel format is BGRA, which is compatible with OpenCV's default BGR channel ordering
                width  = self.camera.getWidth()
                height = self.camera.getHeight()
                frame  = np.frombuffer(image, dtype=np.uint8).reshape(height, width, 4)
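                # frame[y, x] is now a four-element [B, G, R, A] pixel value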

                # temporary: write image to file for offline testing
                # cv.imwrite("capture-images/%04d.png" % (self.frame_count), frame)

                # convert the BGRA image to single-channel grayscale
                frame = cv.cvtColor(frame, cv.COLOR_BGRA2GRAY)

                # extract features from the image
                render = self._find_round_image_targets(frame)

                # display a debugging image in the Webots display
                imageRef = self.display.imageNew(render.tobytes(), format=Display.RGB, width=width, height=height)
                self.display.imagePaste(imageRef, 0, 0, False)

    #================================================================
    # Analyze a grayscale image for circular targets.
    # Updates the internal state.  Returns an annotated debugging image.
    def _find_round_image_targets(self, grayscale):
        circles = vision.find_circles(grayscale)

        # generate controls from the feature data
        if circles is None:
            if self.vision_target is not None:
                print("target not visible, tracking dropped")
                self.vision_target = None
        else:
            # find the offset of the first circle center from the image center in pixels,
            # ignoring the radius value
            height, width = grayscale.shape
            halfsize = np.array((width/2, height/2))
            offset = circles[0][0:2] - halfsize

            # normalize to unit scale, i.e. range of each dimension is (-1,1)
            self.vision_target = offset / halfsize

            # invert the Y coordinate so that positive values are 'up'
            self.vision_target[1] *= -1.0
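            # Worked example: a circle centered at pixel (160, 60) in a 320x240
            # image has an offset of (0, -60) from the image center, which
            # normalizes and inverts to a vision_target of approximately (0.0, 0.5),
            # i.e. the target lies above the image center.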
            # print("target: ", self.vision_target)

        # produce a debugging image
        return vision.annotate_circles(grayscale, circles)

    #================================================================

    # Define joint-space movement sequences as Bezier cubic spline trajectories
    # specified in degrees.  Each sequence should have three rows per spline
    # segment.

    # This initial motion brings the arm from the vertical home position to a
    # 'work' pose with the camera pointing forward.

    _home_to_work = np.array(
        [
            [  0,  0,  0],
            [  0, 10, 60],
            [  0, 10, 60], # waypoint
            ])

    # This begins and ends at the 'work' pose, and is looped by poll_spline_path().
    _path = np.array([
                      [  0,  10, 60],
                      [  15, 10, 60],
                      [  15, 10, 60], # waypoint

                      [  15, 10, 60],
                      [ -15, 10, 60],
                      [ -15, 10, 60], # waypoint

                      [ -15, 10, 60],
                      [   0, 10, 60],
                      [   0, 10, 60], # waypoint

                      [   0, 10, 60],
                      [   0, 10, 60],
                      [   0, 10, 60], # waypoint
        ])


    def _init_spline(self):
        """Create the Bezier spline interpolator and add the initial path from the home position."""
        self.interpolator = bezier.PathSpline(axes=3)
        self.interpolator.set_tempo(45.0)
        self.interpolator.add_spline(self._home_to_work)
        self._last_segment_count = self.interpolator.segments()

    def poll_spline_path(self):
        """Update function to loop a cubic spline trajectory."""
        degrees = self.interpolator.update(dt=EVENT_LOOP_DT*0.001)
        radians = [math.radians(theta) for theta in degrees]

        # mirror the base rotation of the right robot for symmetric motion
        if 'right' in self.robot_name:
            radians[0] *= -1

        self.go_joints(radians)
        # print("u: ", self.interpolator.u, "knots:", self.interpolator.knots.shape[0])

        segment_count = self.interpolator.segments()
        if segment_count != self._last_segment_count:
            self._last_segment_count = segment_count
            print("Starting next spline segment.")

        if self.interpolator.is_done():
            print("Restarting the spline path.")
            self.interpolator.add_spline(self._path)

    #================================================================
    # Define joint-space movement sequences.  For convenience the joint angles
    # are specified in degrees, then converted to radians for the controllers.
    _right_poses = [[0, 0, 0],
                    [-45, 0, 60],
                    [45, 60, 60],
                    [0, 60, 0],
                    ]

    _left_poses = [[0, 0, 0],
                   [45, 0, 60],
                   [-45, 60, 60],
                   [0, 60, 0],
                   ]

    #================================================================
    def poll_sequence_activity(self):
        """State machine update function to walk through a series of poses at regular intervals."""

        # Update the state timer
        self.state_timer -= EVENT_LOOP_DT

        # If the timer has elapsed, reset the timer and update the outputs.
        if self.state_timer < 0:
            self.state_timer += 2000

            # Look up the next pose.
            if 'left' in self.robot_name:
                next_pose = self._left_poses[self.state_index]
                self.state_index = (self.state_index + 1) % len(self._left_poses)
            else:
                next_pose = self._right_poses[self.state_index]
                self.state_index = (self.state_index + 1) % len(self._right_poses)

            # Convert the pose to radians and issue to the motor controllers.
            angles = [math.radians(next_pose[0]), math.radians(next_pose[1]), math.radians(next_pose[2])]
            self.go_joints(angles)

    #================================================================
    def apply_target_tracking(self):
        # self.vision_target is a 2D vector representing the normalized target
        # displacement from the image center.  The following approximates the
        # robot yaw and pitch error in radians using the half-angle of the field
        # of view.  The negative sign results from the robot kinematics, e.g. a
        # positive yaw moves the camera to the left, so a positive image plane
        # error needs a negative yaw to track toward the right.

        # horizontal field of view of the Webots camera in radians; this should
        # match the fieldOfView value of the Camera node in the world file.
        field_of_view = 0.75
        yaw_error, pitch_error = -0.5 * field_of_view * self.vision_target
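        # e.g. with a 0.75 radian field of view, a vision_target of (0.5, 0.0)
        # yields a yaw_error of about -0.19 radians, turning the camera toward
        # the right half of the image.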

        # print("yaw and pitch error:", yaw_error, pitch_error)

        # Calculate a joint velocity to reduce angular error.  For now, this
        # uses the base Z axis to control yaw and the elbow Y axis to control
        # pitch; the shoulder Y axis is unchanged.  The feedback gain scales
        # angular error to angular velocity.
        joint_velocity = np.array((0.5 * yaw_error, 0.0, 0.5 * pitch_error))
        print("target tracking joint velocity:", joint_velocity)

        # Update the position targets for this cycle for the given velocity.
        dt = 0.001 * EVENT_LOOP_DT
        new_joint_target = self.joint_target + dt * joint_velocity

        self.go_joints(new_joint_target)
        return


    #================================================================
    def run(self):
        # Run loop to execute a periodic script until the simulation quits.
        # If step() returns -1, the simulation is quitting.
        while self.step(EVENT_LOOP_DT) != -1:
            # Read simulator clock time.
            self.sim_time = self.getTime()

            # Process available camera data.
            self.poll_camera()

            # If a target is visible, try tracking it
            if self.vision_target is not None:
                self.apply_target_tracking()

            else:
                # Update the spline trajectory generator.
                self.poll_spline_path()

################################################################
# Start the script.
robot = ZYY()
robot.run()
