#!/usr/bin/env python3

################################################################
# Standard Python libraries.
import sys, argparse, logging

# Set up debugging output.
# logging.getLogger().setLevel(logging.DEBUG)

# Extension libraries.
import numpy as np
import scipy.signal
import sklearn.tree
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Model data.
# label_names = ['close', 'far', 'receding', 'approaching']

# Each line of training.dat is one labeled sample: whitespace-separated numbers
# with the integer class label first, followed by the feature values.
# ndmin=2 keeps the array two-dimensional even when the file holds a single
# sample; otherwise loadtxt returns a 1-D row and the column slicing fails.
training = np.loadtxt('training.dat', ndmin=2)
labels = training[:, 0].astype(int)
samples = training[:, 1:]

print(f"samples: {samples}")
print(f"labels: {labels}")

clf = sklearn.tree.DecisionTreeClassifier().fit(samples, labels)

# Dump the raw parallel arrays describing the fitted tree.
print(f"thresholds: {clf.tree_.threshold}")
print(f"features:   {clf.tree_.feature}")
print(f"left:       {clf.tree_.children_left}")
print(f"right:      {clf.tree_.children_right}")
print(f"values:     {clf.tree_.value}")


# samples = [[0, 20], [60.0, 0.0], [40, -10], [45, 0], [50, 0], [55, 0], [100, -40]]
# classes = clf.predict(samples)
# names = [label_names[i] for i in classes]
# print(f"Classes of {samples}\n is {classes}\nnamed {names}")

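# The tree-walking code below follows the scikit-learn example
# "Understanding the decision tree structure":
# https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py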

def walk_tree(clf, node=0, level=0, integer=False):
    """Print the subtree rooted at *node* as nested C ``if``/``else`` code.

    Internal nodes are printed as ``if (input[f] <= t) { ... } else { ... }``
    and leaves as ``return <label>;``, each annotated with its node index.

    Parameters
    ----------
    clf : fitted tree classifier exposing ``tree_`` (and normally ``classes_``).
    node : index of the node to emit; 0 is the root.
    level : nesting depth, used only to indent the printed code.
    integer : when true, thresholds are printed as integers suitable for
        comparing integer-valued inputs.
    """
    tree = clf.tree_
    left = tree.children_left[node]
    right = tree.children_right[node]
    prefix = "  " * (level + 1)

    # scikit-learn marks a leaf by giving it equal child indices (TREE_LEAF).
    if left != right:
        threshold = tree.threshold[node]
        feature = tree.feature[node]
        if integer:
            # For integer x, (x <= t) is exactly (x <= floor(t)).  Plain int()
            # truncates toward zero, which is wrong for negative non-integral
            # thresholds (-2.5 would become -2 instead of -3).
            threshold = int(np.floor(threshold))

        print(f"{prefix}if (input[{feature}] <= {threshold}) {{ // node {node}")
        walk_tree(clf, left, level + 1, integer)
        print(f"{prefix}}} else {{")
        walk_tree(clf, right, level + 1, integer)
        print(f"{prefix}}}")
    else:
        # tree_.value is indexed by position in classes_, not by label, so map
        # the winning position back to the actual class label.
        index = int(np.argmax(tree.value[node]))
        classes = getattr(clf, "classes_", None)
        cls = index if classes is None else classes[index]
        print(f"{prefix}return {cls}; // node {node}")

def emit_classifier(clf, name='classify', dims=2, integer=False):
    """Print a complete C function named *name* that evaluates the tree *clf*.

    The printed signature is ``int name(T input[dims])``, where ``T`` is
    ``int`` when *integer* is true and ``float`` otherwise.  The body is the
    nested if/else chain printed by ``walk_tree``.
    """
    # The element type must be a real C type name.  'float' matches the
    # fractional thresholds walk_tree prints when integer is False.
    itype = 'int' if integer else 'float'
    print(f"int {name}({itype} input[{dims}])\n{{")
    walk_tree(clf, 0, 0, integer)
    print(f"}} // end of {name}()")

# Emit the classifier twice: once with the default settings as classify(),
# and once as iclassify() taking a two-element integer input array.
emit_classifier(clf, name='classify')
emit_classifier(clf, name='iclassify', dims=2, integer=True)

# sklearn.tree.plot_tree(clf)
# plt.show()

################################################################
def plot_2d_binary_segmentation(ax, clf, node, bounds):
    """Recursively draw the leaf regions of a two-feature decision tree.

    Each leaf reached from *node* is added to *ax* as a semi-transparent
    rectangle coloured by its winning class index.

    Parameters
    ----------
    ax : matplotlib Axes that receives one Rectangle patch per leaf.
    clf : fitted tree classifier exposing ``tree_``.
    node : index of the tree node to draw (0 is the root).
    bounds : ``((xmin, xmax), (ymin, ymax))`` region covered by *node*;
        not modified.
    """
    # Take a float copy.  A shallow .copy() of nested lists would alias rows
    # between recursive calls, and an integer array would silently truncate
    # fractional split thresholds.
    bounds = np.array(bounds, dtype=float)
    logging.debug("Entering node %s, bounds are %s", node, bounds)

    left = clf.tree_.children_left[node]
    right = clf.tree_.children_right[node]
    if left != right:
        threshold = clf.tree_.threshold[node]
        feature = clf.tree_.feature[node]
        # The left child covers feature <= threshold: shrink its upper bound.
        leftbounds = bounds.copy()
        leftbounds[feature, 1] = threshold
        plot_2d_binary_segmentation(ax, clf, left, leftbounds)
        # The right child covers feature > threshold: raise its lower bound.
        rightbounds = bounds.copy()
        rightbounds[feature, 0] = threshold
        plot_2d_binary_segmentation(ax, clf, right, rightbounds)
    else:
        cls = np.argmax(clf.tree_.value[node])
        rectangle = patches.Rectangle((bounds[0][0], bounds[1][0]),  # lower left corner
                                      bounds[0][1] - bounds[0][0],   # width
                                      bounds[1][1] - bounds[1][0],   # height
                                      edgecolor='black',
                                      facecolor="C%d" % cls,
                                      alpha=0.5)
        ax.add_patch(rectangle)
            
def plot_2d_tree(clf, samples, filename="classify.png"):
    """Save an image of the decision regions of *clf* with *samples* overlaid.

    Only the first two columns of *samples* are used.  Regions are drawn over
    the bounding box of those columns.

    Parameters
    ----------
    clf : fitted two-feature tree classifier exposing ``tree_``.
    samples : array-like of shape (n, >=2) holding the sample points.
    filename : path of the image file written by ``savefig``.
    """
    samples = np.asarray(samples, dtype=float)
    fig, ax = plt.subplots(nrows=1)
    fig.set_dpi(160)
    fig.set_size_inches((8, 6))

    # Float dtype ensures fractional split thresholds are not truncated when
    # the samples happen to be integers.
    bounds = np.array(((np.min(samples[:, 0]), np.max(samples[:, 0])),
                       (np.min(samples[:, 1]), np.max(samples[:, 1]))),
                      dtype=float)
    plot_2d_binary_segmentation(ax, clf, 0, bounds)

    ax.plot(samples[:, 0], samples[:, 1], 'o')
    ax.set_title("Classification of Range Data Samples")
    ax.set_xlabel("Position (cm)")
    ax.set_ylabel("Velocity (cm/sec)")
    fig.savefig(filename)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    

################################################################    
# Render the tree's decision regions and the training samples to an image.
plot_2d_tree(clf, samples, filename="classify.png")
# print(samples)
