# tilt_table.py
#
# Sample Webots controller file for driving the
# two-DOF tilt table device.
#
# No copyright, 2022, Garth Zeglin.  This file is
# explicitly placed in the public domain.

# References:
#   https://scikit-learn.org/stable/modules/svm.html
#   https://www.cyberbotics.com/doc/reference/led?tab-language=python
#   https://www.cyberbotics.com/doc/reference/camera?tab-language=python

print("loading tilt_table.py...")

# Import the Webots simulator API.
from controller import Robot
from controller import Mouse
from controller import Keyboard

# Import standard Python libraries.
import math, random

# Import scikit libraries.
from sklearn import svm
import numpy as np

print("tilt_table.py waking up...")
print("Use the mouse to control the table tilt, then press A to toggle automatic mode.")

# Define the time step in milliseconds between
# controller updates.  Every device below is polled
# at this same rate.
EVENT_LOOP_DT = 20

# Request a proxy object representing the robot to
# control.
robot = Robot()

# Fetch handles for the two tilt-axis joint motors.
# NOTE(review): the device names 'j1' and 'j2' must match the
# robot model in the world file; getDevice() returns None for
# an unknown name, which would fail later at setPosition().
motor1 = robot.getDevice('j1')
motor2 = robot.getDevice('j2')

# Enable computer keyboard input for user control.
# last_key remembers the previous keycode so the event loop
# can debounce (edge-detect) keypresses.
keyboard = Keyboard()
keyboard.enable(EVENT_LOOP_DT)
last_key = None

# Request mouse updates for user input.
mouse = Mouse()
mouse.enable(EVENT_LOOP_DT)

# Enable the overhead camera, including the simulator's
# object-recognition feature used to track the ball on the table.
camera = robot.getDevice('camera')
camera.enable(EVENT_LOOP_DT)
camera.recognitionEnable(EVENT_LOOP_DT)

# Connect to the spotlight.
spotlight = robot.getDevice('spotlight')

################################################################
# Global buffer to record machine states for policy regression.
# The system observes tuples [ball_u, ball_v, j1, j2]
# The policy regression maps [ball_u, ball_v] -> [d_j1, d_j2]
samples = []        # recorded [ball_u, ball_v, j1, j2] training tuples
J1_policy = None    # fitted SVR model for joint 1, or None until trained
J2_policy = None    # fitted SVR model for joint 2, or None until trained
automatic = False   # False: mouse teleoperation; True: learned policy drives

def make_policy(history):
    """Fit a pair of SVR models mapping observed state to actuator commands.

    history is a sequence of [ball_u, ball_v, j1, j2] rows recorded during
    manual teleoperation: columns 0-1 are the observed (scaled) ball pixel
    position, columns 2-3 are the (scaled) joint commands issued for that
    observation.

    Returns a (J1, J2) tuple of fitted sklearn.svm.SVR models; each maps a
    two-element [ball_u, ball_v] state vector to one joint command.

    Raises ValueError if history is empty, since SVR cannot be fit without
    at least one sample (previously this produced an opaque numpy
    IndexError from the empty-array slice).
    """
    if len(history) == 0:
        raise ValueError("make_policy requires at least one training sample")

    ndata = np.array(history)
    states = ndata[:, 0:2]   # columns 0-1: observed ball position (input)
    Y1 = ndata[:, 2]         # column 2: recorded J1 command (target)
    Y2 = ndata[:, 3]         # column 3: recorded J2 command (target)

    # Fit one independent regressor per joint.
    J1 = svm.SVR()
    J2 = svm.SVR()
    J1.fit(states, Y1)
    J2.fit(states, Y2)
    return J1, J2

################################################################
# Run an event loop until the simulation quits,
# indicated by the step function returning -1.
while robot.step(EVENT_LOOP_DT) != -1:

    # Read simulator clock time in seconds (discretized by EVENT_LOOP_DT milliseconds).
    t = robot.getTime()

    # Read the camera ball tracker state.  The recognition module
    # returns a list of visible objects; the ball is assumed to be
    # the first (and only) one in view.
    objects = camera.getRecognitionObjects()
    if objects is None or len(objects) == 0:
        print("Warning: no ball detected.")
        ball = None
    else:
        # read the [u, v] integer pixel location of ball
        # [0,0] is upper left, [64,64] is lower right
        ball = objects[0].getPositionOnImage()

    # Drive the light based on ball position: white when the ball
    # is within 10 pixels of the image center, otherwise off.
    if ball is not None:
        if abs(ball[0]-32) < 10 and abs(ball[1]-32) < 10:
            spotlight.set(0xffffff) # 0xRRGGBB color integer
        else:
            spotlight.set(0x000000)

    # Read any computer keyboard keypresses.  Returns -1 or an integer keycode while a key is held down.
    # This is debounced to detect only changes.
    key = keyboard.getKey()
    if key != last_key:
        last_key = key
        if key != -1:
            # convert the integer key number to a lowercase single-character string
            letter = chr(key).lower()
            # Set mode based on keypresses.
            if letter == 'a':
                if automatic is False:
                    # BUG FIX: entering automatic mode with an empty sample
                    # buffer used to crash inside make_policy()/SVR.fit();
                    # refuse the request until training data exists.
                    if len(samples) == 0:
                        print("No training samples recorded yet; drive with the mouse first.")
                    else:
                        print("Entering automatic mode with %d input samples." % (len(samples)))
                        automatic = True
                        J1_policy, J2_policy = make_policy(samples)
                else:
                    print("Exiting automatic mode.")
                    automatic = False
                    samples = []

    if automatic:
        # In automatic mode, use the learned policy to select the next
        # action; when the ball is not visible, hold the last command.
        if ball is not None:
            # scale the samples to unity range; SVM is sensitive to scaling
            # (must mirror the 0.01/10x scaling used when recording samples)
            state = [0.01*ball[0], 0.01*ball[1]]
            angle1 = 0.1 * J1_policy.predict([state])[0]
            angle2 = 0.1 * J2_policy.predict([state])[0]
            # Clamp the commands to the safe +/- 0.1 radian tilt range.
            angle1 = min(max(angle1, -0.1), 0.1)
            angle2 = min(max(angle2, -0.1), 0.1)
            motor1.setPosition(angle1)
            motor2.setPosition(angle2)
            # print("Policy:", state, angle1, angle2)

    else:
        # Read the mouse state to control the robot.
        # The (u,v) coordinates are (0,0) in the upper left and (1,1) in the lower right.
        # The coordinates are NaN while the pointer is outside the 3D view.
        ms = mouse.getState()
        if (not math.isnan(ms.u)) and (not math.isnan(ms.v)):
            angle1 = -0.2 * (ms.u - 0.5)
            angle2 =  0.2 * (ms.v - 0.5)
            motor1.setPosition(angle1)
            motor2.setPosition(angle2)

            # record the policy trajectory for later regression
            if ball is not None:
                # scale the samples to unity range; SVM is sensitive to scaling
                samp = [0.01*ball[0], 0.01*ball[1], 10*angle1, 10*angle2]
                samples.append(samp)
                # print("Sample: ", samp)
