#!/usr/bin/env python3

usage_guide = """\
This script generates Python code for a decision tree classifier using the
scikit-learn machine learning library.

The generated function is an explicit tree of nested conditionals using
constants instead of table lookups; the underlying tree tables are also written
out after the function.  This is intended for pedagogical purposes.

It defaults to floating point inputs but can also generate integer code.  The
scikit-learn decision tree model uses floating point, so the integer model
entails a loss of accuracy.  This can be mitigated by scaling the training data
to use the full fixed point range of 16-bit integers.

References:
  https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
  https://en.wikipedia.org/wiki/Decision_tree_learning
"""

################################################################
# Standard Python libraries.
import sys, argparse, logging

# Set up debugging output.
# logging.getLogger().setLevel(logging.DEBUG)

# Extension libraries.
import numpy as np
import sklearn.tree

# Optional extension libraries.
try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    _matplotlib_available = True

except ModuleNotFoundError:
    logging.warning("Matplotlib not found, plot file generation not available.")
    _matplotlib_available = False

################################################################
# Recursively generate predicate code for a given binary tree node.
def node_code_gen(stream, clf, node=0, level=0, integer=False):
    """Recursively write Python predicate code for one node of a fitted tree.

    A split node becomes an ``if input[f] <= t:`` / ``else:`` pair whose
    branches are generated recursively.  A leaf node becomes a ``return`` of
    the class label with the largest value at that leaf.

    Parameters:
      stream  -- writable text stream receiving the generated source
      clf     -- fitted sklearn.tree.DecisionTreeClassifier
      node    -- index of the node to emit (0 is the root)
      level   -- nesting depth, used to indent the output
      integer -- if True, emit integer thresholds for integer-valued inputs
    """
    # Two spaces per nesting level, plus one level for the function body.
    prefix = "  " * (level + 1)

    # Look up the node properties in the parallel arrays defining the tree.
    tree = clf.tree_
    left = tree.children_left[node]
    right = tree.children_right[node]

    # A leaf is recognized by having identical child indices.
    if left != right:
        threshold = tree.threshold[node]
        feature = tree.feature[node]
        if integer:
            # For integer x, `x <= t` is equivalent to `x <= floor(t)`.  Plain
            # int() truncates toward zero, which is wrong for negative,
            # non-integral thresholds (-2.5 must become -3, not -2).
            threshold = int(np.floor(threshold))
        stream.write(f"{prefix}if input[{feature}] <= {threshold}:\n")
        node_code_gen(stream, clf, left, level + 1, integer)
        stream.write(f"{prefix}else:\n")
        node_code_gen(stream, clf, right, level + 1, integer)
    else:
        # argmax gives a column index into clf.classes_, not the label itself;
        # map it back so labels such as 1..N are reproduced correctly.
        # .item() converts the numpy scalar to a plain Python value, so repr()
        # emits a valid literal (e.g. 2 or 'cat').
        label = np.asarray(clf.classes_[np.argmax(tree.value[node])]).item()
        stream.write(f"{prefix}return {label!r}\n")

################################################################        
# Write a Python file with a classification function.
def emit_classifier(filename, clf, name='classify', dims=2, integer=False):
    """Write a Python source file containing a classification function.

    The file holds a function ``name(input)`` implemented as an explicit tree
    of nested conditionals (produced by node_code_gen).  It is followed by the
    raw tree tables, so a table-driven evaluator can be written against the
    same model.

    Parameters:
      filename -- path of the Python file to create (overwritten if present)
      clf      -- fitted sklearn.tree.DecisionTreeClassifier
      name     -- name of the generated function
      dims     -- number of input elements, recorded in a header comment
      integer  -- if True, generate integer thresholds
    """
    tree = clf.tree_

    def table(values):
        # tolist() yields plain Python numbers at full precision, and repr()
        # never abbreviates.  np.array2string elides arrays longer than 1000
        # entries with '...' (invalid Python) and rounds floats to 8 digits.
        return repr(np.asarray(values).tolist())

    thresholds = tree.threshold
    if integer:
        # Keep the table consistent with integer predicates:
        # for integer x, x <= t  <=>  x <= floor(t).
        thresholds = np.floor(thresholds).astype(int)

    # Class label with the largest value at each node.  This is only meaningful
    # where the node is a leaf (children_left == children_right); without it
    # the tables cannot be used to produce a prediction.
    flat_values = tree.value.reshape(len(tree.value), -1)
    node_class = clf.classes_[np.argmax(flat_values, axis=1)]

    with open(filename, "w", encoding="utf-8") as output:
        output.write("# Decision tree classifier generated using classifier_gen.py\n")
        output.write(f"# Input vector has {dims} elements.\n")
        output.write(f"def {name}(input):\n")
        node_code_gen(output, clf, 0, 0, integer)

        # Also emit the underlying tables for use with an evaluation function.
        output.write("\n\n# tree data\n")
        output.write(f"children_left  = {table(tree.children_left)}\n")
        output.write(f"children_right = {table(tree.children_right)}\n")
        output.write(f"feature        = {table(tree.feature)}\n")
        output.write(f"threshold      = {table(thresholds)}\n")
        output.write(f"node_class     = {table(node_class)}\n")

################################################################
# Recursively plot a 2-D binary decision tree as colored rectangles segmenting the plane.
def plot_2d_binary_segmentation(ax, clf, node, bounds):
    """Recursively draw the leaf regions of a 2-D decision tree as rectangles.

    Parameters:
      ax     -- matplotlib Axes to draw into
      clf    -- fitted sklearn.tree.DecisionTreeClassifier with two features
      node   -- index of the subtree root to draw (0 for the whole tree)
      bounds -- 2x2 array-like ((xmin, xmax), (ymin, ymax)) giving the region
                covered by `node`; it is not modified
    """
    tree = clf.tree_
    left = tree.children_left[node]
    right = tree.children_right[node]
    if left != right:
        threshold = tree.threshold[node]
        feature = tree.feature[node]
        # Take independent float copies.  A shallow copy of a nested list would
        # let the left branch's edit leak into the right branch, and an
        # integer array would silently truncate the threshold.
        left_bounds = np.array(bounds, dtype=float)
        left_bounds[feature][1] = threshold
        plot_2d_binary_segmentation(ax, clf, left, left_bounds)
        right_bounds = np.array(bounds, dtype=float)
        right_bounds[feature][0] = threshold
        plot_2d_binary_segmentation(ax, clf, right, right_bounds)
    else:
        # Color each region by the column index of its majority class.
        cls = np.argmax(tree.value[node])
        (x0, x1), (y0, y1) = bounds
        rectangle = patches.Rectangle((x0, y0),   # lower left corner
                                      x1 - x0,    # width
                                      y1 - y0,    # height
                                      edgecolor = 'black',
                                      facecolor = "C%d" % cls,
                                      alpha = 0.5)
        ax.add_patch(rectangle)
        
################################################################
# Write an image file with a plot of 2-dimensional training data overlaid
# over the planar segmentation defined by a binary decision tree.
def plot_2d_tree(clf, samples, labels, filename):
    """Save a plot of 2-D training samples over the tree's planar segmentation.

    Parameters:
      clf      -- fitted sklearn.tree.DecisionTreeClassifier with two features
      samples  -- array of training samples; columns 0 and 1 are plotted
      labels   -- array of per-sample labels, same length as samples
      filename -- path of the image file to write
    """
    fig, ax = plt.subplots(nrows=1)
    fig.set_dpi(160)
    fig.set_size_inches((8, 6))

    # Shade the leaf regions over the bounding box of the data.
    bounds = np.array(((np.min(samples[:, 0]), np.max(samples[:, 0])),
                       (np.min(samples[:, 1]), np.max(samples[:, 1]))))
    plot_2d_binary_segmentation(ax, clf, 0, bounds)

    # Overlay the samples, one series per distinct label in sorted order.
    for label in np.unique(labels):
        samps = samples[labels == label]
        ax.plot(samps[:, 0], samps[:, 1], 'o')

    ax.set_title("Decision Tree Classifier")
    ax.set_xlabel("X Direction")
    ax.set_ylabel("Y Direction")
    fig.savefig(filename)
    # Release the figure; pyplot otherwise keeps a reference to it.
    plt.close(fig)

################################################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""Generate code for a decision tree classifier.""",
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog=usage_guide)
    parser.add_argument('--training', type=str, default='data/training.csv', help='Path of training data input file.')
    parser.add_argument('--space', action='store_true', help = 'Assume input is whitespace-delimited (default is comma delimited).')
    parser.add_argument('--leaf', type=int, default=10, help = 'Minimum number of samples in each leaf node (default is 10).')
    parser.add_argument('--name', type=str, default='classify', help = 'Name of Python filter function.')
    parser.add_argument('--integer', action='store_true', help = 'Enable integer code (default is floating point).')
    parser.add_argument('--out', type=str, default="../classifier.py", help='Path of Python code output file.')
    parser.add_argument('--plot', type=str, default="../plots/classifier.png", help='Path of plot image output file (empty string disables plotting).')
    args = parser.parse_args()

    # Each training data line is an integer label followed by the sample's
    # feature values.  Lines are comma-delimited by default, or
    # whitespace-delimited with --space (loadtxt treats delimiter=None as
    # whitespace).  ndmin=2 keeps a single-line file as a (1, n) array, so the
    # column slicing below still works.
    delimiter = None if args.space else ','
    training = np.loadtxt(args.training, delimiter=delimiter, ndmin=2)

    labels = training[:, 0].astype(int)
    samples = training[:, 1:]
    dims = samples.shape[1]
    logging.info("Loaded %d training samples with %d features.", len(samples), dims)

    # Assume the data is somewhat noisy, so require each leaf node to represent some minimum
    # number of samples specified by 'min_samples_leaf' to reduce overfitting.
    clf = sklearn.tree.DecisionTreeClassifier(min_samples_leaf=args.leaf).fit(samples, labels)

    # Generate the code file.
    emit_classifier(args.out, clf, args.name, dims, args.integer)

    # Generate the plot image file (only possible for 2-D inputs).
    if _matplotlib_available and args.plot and dims == 2:
        plot_2d_tree(clf, samples, labels, args.plot)
