# bezier.py
#
# Support for Bezier cubic spline interpolation and path generation.

# Import the third-party numpy library for matrix calculations.
# Note: this can be installed using 'pip3 install numpy' or 'pip3 install scipy'.
import numpy as np

#================================================================
class PathSpline:
    """Representation of a multi-axis multi-segment Bezier cubic spline
    trajectory.  This may be useful for interpolating a path through joint space
    or Cartesian space.

    Control points live in ``self.knots``, one row per knot and one column per
    axis.  Consecutive segments share their boundary knot, so a path of
    ``segs`` segments has ``1 + 3*segs`` rows.  The interpolator always
    evaluates the first segment in the buffer; completed segments are dropped
    as the path position ``self.u`` advances past them.
    """

    def __init__(self, axes=3):
        # Initialize a default spline buffer.  The invariant is that this will
        # always retain at least one segment.  It should always have (1+3*segs)
        # rows where segs is the number of path segments.
        self.knots = np.zeros((4, axes))

        # Set the default execution rate in spline segments per second.
        self.rate = 1.0

        # Initialize the current position at the end of the path.  The position
        # advances by one unit per segment.
        self.u = 1.0

        # Initialize the current interpolated value.
        self.value = np.zeros(axes)

    def set_tempo(self, spm):
        """Set the execution rate of the spline interpolator in units of
        segments/minute.  The initial default value is 60 segs/min (1 per
        second).

        The internal rate is clamped to the range [0.0001, 5] segments per
        second, i.e. [0.006, 300] segments per minute."""
        self.rate = max(min(spm / 60.0, 5), 0.0001)

    def axes(self):
        """Return the number of axes."""
        return self.knots.shape[1]

    def segments(self):
        """Return the number of spline segments in the path."""
        return (self.knots.shape[0] - 1) // 3

    def update(self, dt):
        """Advance the spline position given the elapsed interval dt (in
        seconds) and recalculate the spline interpolation.  Returns the current
        interpolated value.

        A non-positive dt leaves the position unchanged but still refreshes the
        interpolated value.  The returned array is ``self.value`` itself (it is
        updated in place on every call), not a copy."""
        if dt > 0.0:
            self.u += self.rate * dt
            # Roll over into the following segments for as long as the position
            # has passed the end of the active one.  A loop (rather than a single
            # check) keeps timing correct when one interval spans several
            # segments, e.g. after a scheduling hiccup or at a high tempo.  The
            # last segment is never dropped, preserving the buffer invariant.
            while self.u >= 1.0 and self.segments() > 1:
                self.knots = self.knots[3:]
                self.u -= 1.0

        # Always limit u to the first active segment.
        self.u = min(max(self.u, 0.0), 1.0)

        # Update the interpolation.
        self._recalculate()
        return self.value

    def is_done(self):
        """Returns True if the evaluation has reached the end of the last spline segment."""
        return (self.u >= 1.0) and (self.segments() == 1)

    def _recalculate(self):
        """Apply the De Casteljau algorithm to calculate the position of
        the spline at the current path point, storing it in ``self.value``."""
        u = self.u
        v = 1.0 - u

        # Each De Casteljau level linearly interpolates between adjacent points
        # of the previous level: 4 control knots -> 3 -> 2 -> the curve point.
        pts = self.knots[:4]
        pts = pts[:-1] * v + pts[1:] * u
        pts = pts[:-1] * v + pts[1:] * u
        self.value[:] = pts[0] * v + pts[1] * u

    def set_target(self, value):
        """Reset the path generator to a single spline segment from the current
        value to the given value.  The path defaults to 'ease-out', meaning the
        rate of change is discontinuous at the start and slows into the target.

        Any segments previously queued with add_spline are discarded.

        :param value: iterable with one value per axis"""
        # Build the replacement buffer before touching state, so a malformed
        # value raises without leaving the path half-reset.
        knots = np.empty((4, self.axes()))
        knots[0] = self.value   # start at the current levels
        knots[1:] = value       # derivative discontinuity, ease at the target, endpoint
        self.knots = knots

        # Reset the path position to start the new segment.
        self.u = 0.0

    def set_value(self, value):
        """Reset the path generator immediately to the given value with no
        interpolation.  Any queued segments are discarded.

        :param value: iterable with one value per axis"""
        # A single degenerate segment whose four knots all equal the value.
        knots = np.empty((4, self.axes()))
        knots[:] = value
        self.knots = knots

        # Place the position at the end of that segment so the path is done.
        self.u = 1.0

        # Reset the current value.
        self.value[:] = value

    def add_spline(self, knots):
        """Append a sequence of Bezier spline segments to the path
        generator, provided as a numpy matrix (or any array-like) with three
        rows per segment and one column per axis.

        The first row of each appended segment is its first interior control
        point; the segment's start knot is the last row already in the buffer.

        :param knots: array-like of shape (3*n, axes)
        :raises ValueError: if the shape does not match that layout"""
        new_knots = np.atleast_2d(np.asarray(knots, dtype=float))
        if (new_knots.ndim != 2 or new_knots.shape[1] != self.axes()
                or new_knots.shape[0] % 3 != 0):
            raise ValueError(
                "knots must have shape (3*n, %d), got %s"
                % (self.axes(), new_knots.shape))

        self.knots = np.vstack((self.knots, new_knots))

        # Note: the path position and current value are left unchanged.  If the
        # interpolator is at the end of the current segment, it will now begin
        # to advance into these new segments.
