# zyy.py
#
# Sample Webots controller file for driving a two-link arm with three driven
# joints.  This example provides inverse kinematics for performing
# position-controlled trajectories.

# No copyright, 2020-2024, Garth Zeglin.  This file is
# explicitly placed in the public domain.

# Print a startup banner to stdout as soon as the controller script loads.
print("zyy.py waking up.", end="\n")

# Import the Webots simulator API.
from controller import Robot

# Import the standard Python math, random, and time libraries.
import math, random, time

# Import the bezier module from the same directory.  Note: this requires numpy.
import bezier

# Import the third-party numpy library for matrix calculations.
# Note: this can be installed using 'pip3 install numpy' or 'pip3 install scipy'.
import numpy as np

# Controller update period in milliseconds.  It is used both as the simulator
# step size in the run loop and as the sensor sampling period.
EVENT_LOOP_DT: int = 20

################################################################
class ZYY(Robot):
    """Webots controller for a two-link arm with three driven joints.

    joint1 rotates the arm about the vertical base (Z) axis; joint2 and joint3
    pitch the two links in the vertical plane.  The run loop plays a looping
    joint-space Bezier spline trajectory.  A fixed-pose sequence state machine
    and endpoint forward/inverse kinematics are also provided.
    """

    # Interval in milliseconds between successive poses in poll_sequence_activity().
    SEQUENCE_INTERVAL = 2000

    def __init__(self):
        """Connect to the simulator, configure motors and sensors, and reset behavior state."""
        super().__init__()
        self.robot_name = self.getName()
        print("%s: controller connected." % (self.robot_name))

        # Attempt to randomize the random library sequence.
        random.seed(time.time())

        # Initialize geometric constants.  These should match
        # the current geometry of the robot.
        self.base_z = 0.29  # joint2 axis height: 0.12 fixed base height + 0.17 j2 z offset
        self.link1 = 0.5    # link1 length, i.e. distance between j2 and j3
        self.link2 = 0.5    # link2 length, i.e. distance between j3 and end

        # Fetch handles for the joint motors.
        self.motor1 = self.getDevice('motor1')
        self.motor2 = self.getDevice('motor2')
        self.motor3 = self.getDevice('motor3')

        # Adjust the motor controller properties.
        self.motor1.setAvailableTorque(20.0)
        self.motor2.setAvailableTorque(15.0)
        self.motor3.setAvailableTorque(10.0)

        # Adjust the low-level controller gains.
        print("%s: setting PID gains." % (self.robot_name))
        self.motor1.setControlPID(100.0, 0.0, 25.0)
        self.motor2.setControlPID( 50.0, 0.0, 15.0)
        self.motor3.setControlPID( 50.0, 0.0, 15.0)

        # Fetch handles for the joint sensors.
        self.joint1 = self.getDevice('joint1')
        self.joint2 = self.getDevice('joint2')
        self.joint3 = self.getDevice('joint3')

        # Specify the sampling rate for the joint sensors.
        self.joint1.enable(EVENT_LOOP_DT)
        self.joint2.enable(EVENT_LOOP_DT)
        self.joint3.enable(EVENT_LOOP_DT)

        # Connect to the end sensor.
        self.end_sensor = self.getDevice("endRangeSensor")
        self.end_sensor.enable(EVENT_LOOP_DT)  # set sampling period in milliseconds
        self.end_sensor_interval = 1000        # ms between polled readings
        self.end_sensor_timer = 1000           # ms until the first polled reading

        # Initialize behavior state machines.
        self.state_timer = 2*EVENT_LOOP_DT
        self.state_index = 0
        self._init_spline()

    #================================================================
    def endpoint_forward_kinematics(self, q):
        """Compute the forward kinematics for the end point.

        Returns the body-coordinate XYZ Cartesian position of the end point for
        a given joint angle vector.  j2 is measured from the vertical and j3 is
        relative to link1.

        :param q: three-element sequence with [q1, q2, q3] joint angles in radians
        :return:  three-element list with endpoint [x, y, z] location
        """
        j1 = q[0]
        j2 = q[1]
        j3 = q[2]

        # Horizontal distance of the end point from the base rotation axis.
        radial = self.link1*math.sin(j2) + self.link2*math.sin(j2 + j3)

        return [radial*math.cos(j1),
                radial*math.sin(j1),
                self.base_z + self.link1*math.cos(j2) + self.link2*math.cos(j2 + j3)]

    #================================================================
    def endpoint_inverse_kinematics(self, target, verbose=False):
        """Compute the joint angles for a target end position.

        If the target is out of reach, returns the closest pose.  With j1
        between -pi and pi, and j2 and j3 limited to positive rotations, the
        solution is always unique.

        :param target: XYZ Cartesian position vector in body coordinates
        :param verbose: if true, print intermediate values for debugging
        :return: joint angles as list [j1, j2, j3] in radians
        """
        x = target[0]
        y = target[1]
        z = target[2]

        # the joint1 base Z rotation depends only upon x and y
        j1 = math.atan2(y, x)

        # distance within the XY plane from the origin to the endpoint projection
        xy_radial = math.sqrt(x*x + y*y)

        # find the Z offset from the J2 horizontal plane to the end point
        end_z = z - self.base_z

        # angle between the J2 horizontal plane and the endpoint
        theta = math.atan2(end_z, xy_radial)

        # radial distance from the J2 axis to the endpoint
        radius = math.sqrt(x*x + y*y + end_z*end_z)

        # Use the law of cosines to compute the interior triangle angle at the
        # elbow solely from the link lengths and the distance from j2 to end:
        #   R**2 = l1**2 + l2**2 - 2*l1*l2*cos(elbow_supplement)
        # Clamping the cosine maps unreachable targets onto the nearest pose:
        # fully extended (too far away) or fully folded (too close).
        acosarg = (radius*radius - self.link1**2 - self.link2**2) / (-2 * self.link1 * self.link2)
        elbow_supplement = math.acos(max(-1.0, min(1.0, acosarg)))

        if verbose:
            print("theta:", theta, "radius:", radius, "acosarg:", acosarg)

        # Angle at the j2 vertex between link1 and the line to the endpoint.
        # This uses atan2 on the sine/cosine components rather than the law of
        # sines with asin.  asin can raise a math domain error when rounding
        # pushes its argument just past 1.0 near the base, and it cannot
        # produce the obtuse angles that occur when link2 > link1.
        alpha = math.atan2(self.link2 * math.sin(elbow_supplement),
                           self.link1 - self.link2 * math.cos(elbow_supplement))

        # calculate the joint angles
        return [j1, math.pi/2 - theta - alpha, math.pi - elbow_supplement]

    #================================================================
    # motion primitives

    def go_joints(self, target):
        """Issue position commands to the three joint motors.

        :param target: three-element sequence of joint angles in radians
        """
        self.motor1.setPosition(target[0])
        self.motor2.setPosition(target[1])
        self.motor3.setPosition(target[2])

    #================================================================
    def poll_sensors(self):
        """Poll sensor input; the end range sensor is read every end_sensor_interval ms."""
        self.end_sensor_timer -= EVENT_LOOP_DT
        if self.end_sensor_timer >= 0:
            return
        self.end_sensor_timer += self.end_sensor_interval

        # read the distance sensor
        distance = self.end_sensor.getValue()

        # Placeholder hook for reacting to a nearby obstacle; currently a no-op.
        if distance < 0.9:
            # print("%s: range sensor detected obstacle at %f." % (self.robot_name, distance))
            pass

    #================================================================

    # Define a joint-space movement sequence as a Bezier cubic spline trajectory
    # specified in degrees.  This should have three rows per spline segment.

    _path = np.array([               # the first waypoint is implicitly zero
                      [  0,  0,  0],
                      [  5, 30, 30],
                      [  5, 30, 30], # waypoint
                      [  5, 30, 30],
                      [  5, 30, 30],
                      [  5, 30, 90], # waypoint
                      [  5, 30, 30],
                      [  5, 30, 30],
                      [  5, 30, 30], # waypoint
                      [  5, 30, 30],
                      [  0,  0,  0],
                      [  0,  0,  0], # waypoint
                      [  0,  0,  0],
                      [ 45, 30, 30],
                      [ 45, 30, 30], # waypoint
                      [ 45, 30, 30],
                      [ 45, 30, 30],
                      [ 45, 30, 90], # waypoint
                      [ 45, 30, 30],
                      [ 45, 30, 30],
                      [ 45, 30, 30], # waypoint
                      [ 45, 30, 30],
                      [  0,  0,  0],
                      [  0,  0,  0], # waypoint (should be zero for continuity)
                      ])

    def _init_spline(self):
        """Create the spline interpolator and load the joint-space path."""
        self.interpolator = bezier.PathSpline(axes=3)
        self.interpolator.set_tempo(30.0)
        self.interpolator.add_spline(self._path)
        self._last_segment_count = self.interpolator.segments()

    def poll_spline_path(self):
        """Update function to loop a cubic spline trajectory."""
        degrees = self.interpolator.update(dt=EVENT_LOOP_DT*0.001)
        radians = [math.radians(theta) for theta in degrees]

        # mirror the base rotation of the right robot for symmetric motion
        if 'right' in self.robot_name:
            radians[0] *= -1

        self.go_joints(radians)

        segment_count = self.interpolator.segments()
        if segment_count != self._last_segment_count:
            self._last_segment_count = segment_count
            print("Starting next spline segment.")

        if self.interpolator.is_done():
            print("Restarting the spline path.")
            self.interpolator.add_spline(self._path)

    #================================================================
    # Define joint-space movement sequences.  For convenience the joint angles
    # are specified in degrees, then converted to radians for the controllers.
    _right_poses = [[0, 0, 0],
                    [-45, 0, 60],
                    [45, 60, 60],
                    [0, 60, 0],
                    ]

    _left_poses = [[0, 0, 0],
                   [45, 0, 60],
                   [-45, 60, 60],
                   [0, 60, 0],
                   ]

    #================================================================
    def poll_sequence_activity(self):
        """State machine update function to walk through a series of poses at regular intervals."""
        self.state_timer -= EVENT_LOOP_DT
        if self.state_timer >= 0:
            return

        # The timer has elapsed: reset it and issue the next pose.
        self.state_timer += self.SEQUENCE_INTERVAL

        poses = self._left_poses if 'left' in self.robot_name else self._right_poses
        next_pose = poses[self.state_index]
        self.state_index = (self.state_index + 1) % len(poses)

        # Convert the pose from degrees to radians for the motor controllers.
        self.go_joints([math.radians(angle) for angle in next_pose])

    #================================================================
    def run(self):
        """Run the periodic control loop until the simulator quits.

        step() returning -1 signals that the simulator is quitting.
        """
        while self.step(EVENT_LOOP_DT) != -1:
            # Read simulator clock time.
            self.sim_time = self.getTime()

            # Read sensor values.
            self.poll_sensors()

            # Alternative behavior: step through fixed poses instead of the spline.
            # self.poll_sequence_activity()

            # Update the spline trajectory generator.
            self.poll_spline_path()


################################################################
# Start the script.
# Only construct the controller and enter the (blocking) run loop when this
# file is executed as a script, not when it is imported.
if __name__ == "__main__":
    robot = ZYY()
    robot.run()
