#!/usr/bin/env python3

# Extended help text; passed below as the argparse epilog (printed verbatim by
# --help thanks to RawDescriptionHelpFormatter).
usage_guide = """\
This script generates C/C++ code for a decision tree classifier using the
scikit-learn machine learning library.

The code is targeted toward the Arduino API, so it creates an explicit tree
using constants instead of using tables.  It defaults to integer inputs but can
also generate floating point code.  The scikit-learn decision tree model uses
floating point, so the integer model entails a loss of accuracy as the price for
a faster, smaller model.  This can be mitigated by scaling the training data to
use the full fixed point range of 16-bit integers.

References:
  https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
  https://en.wikipedia.org/wiki/Decision_tree_learning
"""

################################################################
# Standard Python libraries.
import sys, argparse, logging

# Set up debugging output.
# logging.getLogger().setLevel(logging.DEBUG)

# Extension libraries.
import numpy as np
import sklearn.tree

# Optional extension libraries.
try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    _matplotlib_available = True

except ModuleNotFoundError:
    logging.warning("Matplotlib not found, plot file generation not available.")
    _matplotlib_available = False

################################################################
# Recursively generate predicate code for a given binary tree node.
def node_code_gen(stream, clf, node=0, level=0, integer=False):
    """Write nested C if/else statements implementing the subtree rooted at *node*.

    Each internal node becomes ``if (input[feature] <= threshold) { ... } else { ... }``
    and each leaf becomes ``return <class label>;``.

    Args:
        stream: writable text stream receiving the generated code.
        clf: fitted scikit-learn decision tree classifier; reads ``clf.tree_``
            and ``clf.classes_``.
        node: index of the tree node to emit (0 is the root).
        level: nesting depth, used only for indentation.
        integer: when true, thresholds are emitted as integers so the
            comparison is exact for integer-valued inputs.
    """
    # properly indent the output code
    prefix = "  " * (level + 1)

    # look up the node properties in each numerical array defining the tree
    tree = clf.tree_
    left = tree.children_left[node]
    right = tree.children_right[node]

    if left != right:
        # internal node: emit the split test and both subtrees
        threshold = tree.threshold[node]
        feature = tree.feature[node]
        if integer:
            # For an integer x, (x <= t) is equivalent to (x <= floor(t)).
            # int() would truncate toward zero, which is wrong for negative
            # thresholds (-2.5 must become -3, not -2).
            threshold = int(np.floor(threshold))
        stream.write(f"{prefix}if (input[{feature}] <= {threshold}) {{\n")
        node_code_gen(stream, clf, left, level + 1, integer)
        stream.write(f"{prefix}}} else {{\n")
        node_code_gen(stream, clf, right, level + 1, integer)
        stream.write(f"{prefix}}}\n")

    else:
        # leaf node: np.argmax gives the index of the majority class within
        # clf.classes_, so map it back to the actual training label.
        cls = clf.classes_[np.argmax(tree.value[node])]
        stream.write(f"{prefix}return {cls};\n")

################################################################        
# Write a C/C++ file with a classification function.
def emit_classifier(filename, clf, name='classify', dims=2, integer=False):
    """Write a C/C++ source file defining ``int <name>(<type> input[<dims>])``.

    The function body is generated by node_code_gen() from the fitted tree
    *clf*.  The input element type is ``int`` when *integer* is true and
    ``float`` otherwise.

    Args:
        filename: path of the output source file (overwritten).
        clf: fitted scikit-learn decision tree classifier.
        name: name of the generated C function.
        dims: number of input features (length of the input array).
        integer: generate integer inputs and thresholds instead of float.
    """
    # Previously the float branch produced the literal token "False" here,
    # which yields uncompilable C.
    itype = 'int' if integer else 'float'
    with open(filename, "w") as output:
        output.write("// Decision tree classifier generated using classify_gen.py\n")
        output.write(f"int {name}({itype} input[{dims}])\n{{\n")
        node_code_gen(output, clf, 0, 0, integer)
        output.write("}\n")

################################################################
# Recursively plot a 2-D binary decision tree as colored rectangles segmenting the plane.
def plot_2d_binary_segmentation(ax, clf, node, bounds):
    """Add a translucent rectangle to *ax* for every leaf beneath *node*.

    *bounds* holds ``((xmin, xmax), (ymin, ymax))`` for the region covered by
    *node*.  Each split narrows the region of its two children along the split
    feature: the left child keeps values up to the threshold, the right child
    keeps values from the threshold upward.  A leaf's rectangle is filled with
    the matplotlib cycle color ``C<n>``, where ``n`` is the index of the
    leaf's majority class.
    """
    tree = clf.tree_
    lo_child = tree.children_left[node]
    hi_child = tree.children_right[node]

    if lo_child == hi_child:
        # Leaf: shade the whole region.
        majority = np.argmax(tree.value[node])
        x_lo = bounds[0][0]
        x_hi = bounds[0][1]
        y_lo = bounds[1][0]
        y_hi = bounds[1][1]
        shade = patches.Rectangle((x_lo, y_lo),
                                  x_hi - x_lo,
                                  y_hi - y_lo,
                                  edgecolor='black',
                                  facecolor="C%d" % majority,
                                  alpha=0.5)
        ax.add_patch(shade)
        return

    cut = tree.threshold[node]
    axis = tree.feature[node]

    # Left child: clip the upper bound of the split axis at the threshold.
    below = bounds.copy()
    below[axis][1] = cut
    plot_2d_binary_segmentation(ax, clf, lo_child, below)

    # Right child: clip the lower bound of the split axis at the threshold.
    above = bounds.copy()
    above[axis][0] = cut
    plot_2d_binary_segmentation(ax, clf, hi_child, above)
        
################################################################
# Write an image file with a plot of 2-dimensional training data overlaid
# over the planar segmentation defined by a binary decision tree.
def plot_2d_tree(clf, samples, filename, *, title="Decision Tree Classifier",
                 xlabel="Position (cm)", ylabel="Velocity (cm/sec)"):
    """Save a scatter plot of *samples* over the tree's planar segmentation.

    Args:
        clf: fitted scikit-learn decision tree classifier with 2 input features.
        samples: 2-D array; columns 0 and 1 are plotted as x and y.
        filename: output image path; its parent directory is created if needed.
        title, xlabel, ylabel: plot annotations (defaults match the original
            hard-coded text).
    """
    import os

    # The default output path lives in a subdirectory; create it rather than
    # failing in savefig() after the C code has already been generated.
    if isinstance(filename, (str, os.PathLike)):
        directory = os.path.dirname(os.fspath(filename))
        if directory:
            os.makedirs(directory, exist_ok=True)

    fig, ax = plt.subplots(nrows=1)
    try:
        fig.set_dpi(160)
        fig.set_size_inches((8, 6))

        # Bound the segmentation rectangles by the extent of the sample data.
        xmin = np.min(samples[:, 0])
        xmax = np.max(samples[:, 0])
        ymin = np.min(samples[:, 1])
        ymax = np.max(samples[:, 1])
        bounds = np.array(((xmin, xmax), (ymin, ymax)))
        plot_2d_binary_segmentation(ax, clf, 0, bounds)

        ax.plot(samples[:, 0], samples[:, 1], 'o')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

################################################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""Generate C code for a decision tree classifier.""",
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog=usage_guide)
    parser.add_argument('--training', type=str, default='data/training.csv', help='Path of training data input file.')
    parser.add_argument('--space', action='store_true', help = 'Assume input is whitespace-delimited (default is comma delimited).')
    parser.add_argument('--leaf', type=int, default=10, help = 'Minimum number of samples in each leaf node (default is 10).')
    parser.add_argument('--name', type=str, default='classify', help = 'Name of C filter function.')
    parser.add_argument('--float', action='store_true', help = 'Enable floating point code (default is integer).')
    parser.add_argument('--out', type=str, default="classify.ino", help='Path of C code output file.')
    parser.add_argument('--plot', type=str, default="plots/classify.png", help='Path of plot image output file (empty string disables plotting).')
    args = parser.parse_args()

    # Each training data line is an integer label followed by the sample's
    # feature values; comma-delimited by default, whitespace-delimited with
    # --space (delimiter=None is numpy's whitespace mode).  ndmin=2 keeps a
    # single-line file as a 1 x columns array instead of collapsing it to 1-D.
    delimiter = None if args.space else ','
    training = np.loadtxt(args.training, delimiter=delimiter, ndmin=2)

    dims = training.shape[1] - 1
    if training.size == 0 or dims < 1:
        parser.error(f"{args.training}: no training samples with features found")

    labels = training[:, 0].astype(int)
    samples = training[:, 1:]

    # Assume the data is somewhat noisy, so require each leaf node to represent some minimum
    # number of samples specified by 'min_samples_leaf' to reduce overfitting.
    clf = sklearn.tree.DecisionTreeClassifier(min_samples_leaf=args.leaf).fit(samples, labels)

    # Generate the C code file.
    emit_classifier(args.out, clf, args.name, dims, not args.float)

    # Generate the plot image file (only possible for 2-D inputs).
    if _matplotlib_available and args.plot and dims == 2:
        plot_2d_tree(clf, samples, args.plot)
