"""Generate forward and inverse kinematic expressions for a ZYY three-DOF two-link arm.

Copyright (c) 2013-2024 Garth Zeglin. All rights reserved. Licensed under the
terms of the BSD 3-clause license as included in LICENSE.

N.B. These have not been checked in detail.
"""

import os, sys, time, argparse

import sympy as sym

# import a utility class for symbolic manipulation of homogeneous transforms
import transform3d

#================================================================
class ZYYKinematicsGenerator(object):
    """Generator to derive kinematic expressions for the ZYY arm using sympy.

    The arm has a base joint j1 rotating about Z followed by two joints j2 and
    j3 rotating about Y; with all joint angles zero both links lie along +Z.
    The arm joint angles are specified in radians.

    All derivations run in the constructor.  The resulting expressions are
    printed and also kept as attributes for programmatic use:

      end_transform   -- homogeneous transform of the end frame (sympy Matrix)
      end_vector      -- origin of the end frame, a 3x1 sympy Matrix
      ik_joint_angles -- 3x1 sympy Matrix [j1, j2, j3] placing the endpoint
                         at the symbolic point (x, y, z)

    Symbolic parameters: base_z (height of the J2 axis above the origin),
    link1 (J2 to J3 length) and link2 (J3 to endpoint length).
    """
    def __init__(self, verbose=False):
        # when True, print progress messages in addition to the derived results
        self.verbose = verbose

        # Generate tuples of the needed symbols.
        self._joint_symbols = sym.symbols('j1 j2 j3')
        self._parameter_symbols = sym.symbols('base_z link1 link2')
        self._point_symbols = sym.symbols('x y z')

        # Generate kinematics expressions using sympy.
        self._generate_endpoint_frame_transform()
        self._generate_endpoint_inverse_kinematics()

    #================================================================
    def _generate_endpoint_frame_transform(self):
        """Derive the homogeneous transform of the default endpoint frame as a
        function of the joint angles and link parameters.

        The origin of the end frame is printed both as a sympy expression and
        as Python code.  The transform and origin are stored as
        self.end_transform and self.end_vector and returned as the tuple
        (end_transform, end_vector).
        """
        if self.verbose: print("Generating endpoint frame transforms...")
        j1, j2, j3 = self._joint_symbols
        base, link1, link2 = self._parameter_symbols

        # NOTE(review): the steps below assume each Transform3D call composes
        # onto the current transform (post-multiplication), so each operation
        # is expressed in the frame left by the previous one.
        m = transform3d.Transform3D()

        # apply base Z rotation between fixed base frame and the rotating base frame
        m.rotate_z(j1)

        # apply Z offset between origin at ground height and J2 axis
        m.translate(0, 0, base)

        # apply Y rotation for joint2
        m.rotate_y(j2)

        # apply the Z offset between joint2 and joint3 (the first link)
        m.translate(0, 0, link1)

        # apply Y rotation for joint3
        m.rotate_y(j3)

        # apply the Z offset between joint3 and the end (the second link)
        m.translate(0, 0, link2)

        # find the homogeneous transform representing the end frame
        end_transform = sym.trigsimp(m.ctm)

        # the translation column holds the origin of the end frame
        end_vector = end_transform[0:3, 3]

        self.end_transform = end_transform
        self.end_vector = end_vector

        print("end vector as expression:", end_vector)
        print("\nend vector location as code:", sym.pycode(end_vector))
        return end_transform, end_vector

    #================================================================
    def _generate_endpoint_inverse_kinematics(self):
        """Derive the joint angles that place the endpoint at a given location.

        Solves the forward kinematics of this class for (j1, j2, j3) given the
        symbolic endpoint (x, y, z).  Of the possible solutions, this is the
        one with the arm leaning toward the target and the elbow angle j3 in
        [0, pi].

        The result is printed as a sympy expression and as Python code, stored
        as self.ik_joint_angles, and returned as the 3x1 Matrix [j1, j2, j3].
        """
        if self.verbose: print("Generating endpoint inverse kinematics...")
        base, link1, link2 = self._parameter_symbols
        x, y, z = self._point_symbols

        # the joint1 base Z rotation depends only upon x and y
        j1_angle = sym.atan2(y, x)

        # Work in the vertical plane containing both links, with its origin at
        # the J2 axis, which sits base_z above the ground-level origin.
        radial = sym.sqrt(x*x + y*y)   # horizontal distance to the endpoint
        height = z - base              # endpoint height above the J2 axis

        # squared distance from the J2 axis to the endpoint
        distance_sq = radial**2 + height**2

        # law of cosines for the elbow, where pi - j3 is the interior angle:
        #   R**2 = l1**2 + l2**2 - 2*l1*l2*cos(pi - j3) = l1**2 + l2**2 + 2*l1*l2*cos(j3)
        j3_angle = sym.acos((distance_sq - link1**2 - link2**2) / (2 * link1 * link2))

        # angle at the J2 vertex between link1 and the line to the endpoint;
        # atan2 stays correct when this angle is obtuse, unlike asin
        alpha = sym.atan2(link2 * sym.sin(j3_angle), link1 + link2 * sym.cos(j3_angle))

        # The forward kinematics measures j2 from vertical (+Z), so use the
        # direction to the endpoint measured from vertical, not its elevation.
        # NOTE(review): assumes transform3d.rotate_y is right-handed (a positive
        # angle tilts +Z toward +X) -- confirm against transform3d.
        j2_angle = sym.atan2(radial, height) - alpha

        joint_angles = sym.trigsimp(sym.Matrix([j1_angle, j2_angle, j3_angle]))
        self.ik_joint_angles = joint_angles

        print("inverse kinematics as expression:", joint_angles)
        print("\ninverse kinematics as code:", sym.pycode(joint_angles))
        return joint_angles


################################################################
if __name__ == "__main__":
    # process command line arguments
    parser = argparse.ArgumentParser(description = """Kinematic function code generators.""")
    parser.add_argument( '-v', '--verbose', action='store_true', help='Enable more detailed output.' )
    args = parser.parse_args()

    if args.verbose:
        print("Beginning derivation of kinematics functions...")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can; defining
    # start unconditionally also avoids a possibly-unbound name.
    start = time.perf_counter()

    # constructing the generator performs and prints all the derivations
    kine = ZYYKinematicsGenerator(verbose=args.verbose)

    if args.verbose:
        end = time.perf_counter()
        print("Completed in %f seconds." % (end - start))
