#!/usr/bin/env python3
# Program description; passed to argparse.ArgumentParser as the --help text.
cmd_desc = "Plot a single-channel cubic Bezier spline using matplotlib."

import argparse
import numpy as np
import matplotlib.pyplot as plt

################################################################
def Bezier_spline_segment_coefficients(N):
    """Return an (N,4) fitting matrix for evaluating one cubic Bezier segment.

    The segment is sampled at N evenly spaced parameter values u running from
    0.0 to 1.0 inclusive.  In the 1-D case, when the resulting matrix B
    premultiplies a (4,) or (4,1) vector of control points P, the result is an
    (N,) or (N,1) array of interpolated values: Y = np.dot(B, P).  Higher
    dimensions work the same way: for D dimensions, B premultiplies a (4,D)
    matrix of D-dimensional control points, and the result is an (N,D) matrix
    of interpolated vectors.

    Each row holds the four Bernstein weights for its u value, so every row
    sums to 1.  For N >= 2 the first row is (1,0,0,0) and the last row is
    (0,0,0,1), so the curve passes through the first and last control points.
    """

    # Bezier cubic spline basis definition. These are coefficients that
    # generate a set of polynomials given a set of control points:
    #
    #   f(u) =  [ u^3  u^2   u   1 ] * basis * [p0 p1 p2 p3]'
    #
    Bezier_basis = np.array(((-1, 3, -3, 1), (3, -6, 3, 0), (-3, 3, 0, 0), (1, 0, 0, 0)))

    # Build the fixed matrix of polynomial terms.  Each row is
    # [u^3, u^2, u, 1]: np.vander with 4 columns produces exactly these
    # decreasing powers.
    u = np.linspace(0.0, 1.0, num=N)
    umat = np.vander(u, 4)

    # multiply to find the fitting matrix
    return np.dot(umat, Bezier_basis)

################################################################
def Bezier_spline_values(knots, coeff):
    """Evaluate a piecewise cubic Bezier spline one segment at a time.

    knots is a 1-D sequence of control values.  Adjacent segments share an
    endpoint, so segment s uses knots[3*s : 3*s+4].  Trailing knots that do
    not complete a segment are ignored.  coeff is an (N,4) fitting matrix such
    as the one returned by Bezier_spline_segment_coefficients(N).

    Returns a 1-D array of 1 + segments*(N-1) values.  The endpoint shared by
    two adjacent segments is stored only once.

    Raises ValueError if fewer than four knots are supplied.
    """
    knots = np.asarray(knots, dtype=float)
    N = coeff.shape[0]
    segments = (knots.shape[0] - 1) // 3
    if segments < 1:
        raise ValueError("at least four knot values are required")

    # Empty array to hold the result; it is filled segment by segment.  Each
    # segment's first sample overwrites the previous segment's last sample,
    # which is the same point because the knots are shared.
    y = np.empty(1 + segments * (N - 1))
    for seg in range(segments):
        yidx = seg * (N - 1)
        pidx = seg * 3
        y[yidx:yidx + N] = np.dot(coeff, knots[pidx:pidx + 4])
    return y

################################################################
def Bezier_knot_positions(numknots):
    """Return the u coordinate at which each knot is drawn.

    Knot i sits at u = i/3.  Segment s therefore spans integer u values s to
    s+1, with its two interior control points at the thirds.  This is the
    exact control polygon of the plotted (u, y) curve.  Any excess knots
    continue the same spacing past the end of the last segment.
    """
    return np.arange(numknots) / 3.0

################################################################
def main():
    """Parse the command line, evaluate the spline, and plot it with its knots."""
    # Initialize the command parser.  type=float makes argparse report bad
    # numbers as a usage error instead of a traceback.
    parser = argparse.ArgumentParser(description = cmd_desc)
    parser.add_argument('knots', nargs='*', type=float, help = "Spline knot values (minimum of four).")

    # Parse the command line, returning a Namespace.
    args = parser.parse_args()

    if len(args.knots) > 3:
        knots = np.array(args.knots)

    elif args.knots:
        # One to three values cannot form a segment.  Report it rather than
        # silently discarding the user's input.
        parser.error("at least four knot values are required")

    else:
        print("Using demonstration spline.")
        # supply a demo Bezier spline
        knots = np.array((0, 0, 1, 1,   1, 0, 0,   0.5, 1.0, 0.8,  0.8, -0.2, 0.0,   -0.2, 0.25, 0.5,   0.75, 1.0, 1.0))

    print("Using", knots)
    numknots = knots.shape[0]
    segments = (numknots - 1) // 3

    # set up interpolation
    N = 51
    coeff = Bezier_spline_segment_coefficients(N)
    y = Bezier_spline_values(knots, coeff)

    # calculate all u values for plotting; each segment covers one unit of u
    u = np.linspace(0.0, segments, num=y.shape[0])

    # calculate the X positions of the knot points for the scatter plot, including excess knots
    sx = Bezier_knot_positions(numknots)

    # plot the curve
    fig, ax = plt.subplots()
    ax.plot(u, y)
    ax.set_xlabel("u")
    ax.set_ylabel("y")

    # plot the knots
    ax.scatter(sx, knots)

    # display in a window and wait until closed
    plt.show()

################################################################
if __name__ == "__main__":
    main()
