# array.py
#
# Sample Webots controller file for driving the array
# of underactuated two-joint devices.
#
# No copyright, 2020-2023, Garth Zeglin.  This file is
# explicitly placed in the public domain.

# Print a message to the console as soon as this module starts loading.
print("loading array.py...")

import math
import random
import numpy as np

# Import the Webots simulator API.
from controller import Supervisor

# Define the time step in milliseconds between controller updates.  It is passed
# to step() in ArrayRobot.run() and is also used (scaled to seconds) to count
# down the performance state machine's timer.
EVENT_LOOP_DT = 20

################################################################
# The sample controller is defined as a class which is then instantiated as a
# singleton control object.  This is conventional Python practice and is also
# closer to the way the physical hardware is programmed.

class ArrayRobot(Supervisor):
    """Supervisor controller that drives a 2D grid of velocity-controlled motors.

    The grid size is read from the robot node's own ``num_x`` and ``num_y``
    fields; the motors are then sequenced through a timed performance state
    machine that is polled once per control step (see ``poll_performance_state``).

    Attributes:
        joints: list of rows of motor device handles, indexed ``[x][y]``.
        velocity: float array of shape ``(num_x, num_y)`` holding the velocity
            most recently commanded to each motor.
        x_mesh, y_mesh: constant float arrays shaped like ``velocity`` with
            ``x_mesh[x, y] == x`` and ``y_mesh[x, y] == y``.
        state: name of the current performance state.
        new_state: True until the first poll after entering ``state``.
        state_timer: seconds remaining on the current state's timer.
        current_time: simulator clock value from ``getTime()``, refreshed each step.
    """

    def __init__(self):

        # Initialize the superclass which implements the Robot API.
        super().__init__()

        robot_name = self.getName()
        print("%s: controller connected." % (robot_name))

        # Read back the array dimensions.  This is possible because the Robot is
        # configured with Supervisor permissions.
        num_x = self.getSelf().getField('num_x').value
        num_y = self.getSelf().getField('num_y').value
        print("Array has %d x %d units." % (num_x, num_y))

        # Fetch handles for the motors into a 2D list indexed [x][y].  The naming
        # scheme ("motor" followed by the 1-based x and y indices) is defined in
        # the proto file for the array.
        # NOTE(review): "%d%d" is ambiguous once either dimension exceeds 9
        # (e.g. "motor111"); confirm the proto's naming for larger arrays.
        self.joints = [
            [self.getDevice("motor%d%d" % (x, y)) for y in range(1, num_y + 1)]
            for x in range(1, num_x + 1)
        ]

        # Configure the motors for velocity control by setting
        # the position targets to infinity, starting at rest.
        for row in self.joints:
            for motor in row:
                motor.setPosition(float('inf'))
                motor.setVelocity(0.0)

        # Create a numpy array with a velocity for each motor.
        self.velocity = np.zeros((num_x, num_y))

        # Create some additional constant index arrays useful for generative functions.
        # Each of these has the same shape as the motor array.
        # x_mesh varies only along the grid 'X' direction, y_mesh varies only along
        # the grid 'Y' direction.  (With meshgrid's default 'xy' indexing the outputs
        # have shape (num_x, num_y) and the first output varies along columns.)
        self.y_mesh, self.x_mesh = np.meshgrid(np.linspace(0, num_y-1, num_y), np.linspace(0, num_x-1, num_x))

        # Initialize the performance state.  new_state=True makes the first poll
        # execute the 'init' entry actions.
        self.state = 'init'
        self.new_state = True
        self.state_timer = 0
        self.current_time = 0.0

    #----------------------------------------------------------------
    # Internal method: send the current velocity grid to the motors, so that
    # joints[x][y] is commanded velocity[x, y].
    def _update_motor(self):
        for vel_row, motor_row in zip(self.velocity, self.joints):
            for vel, motor in zip(vel_row, motor_row):
                motor.setVelocity(vel)

    #----------------------------------------------------------------
    # Set every motor to zero velocity.
    def stop_all(self):
        self.velocity[:,:] = 0.0
        self._update_motor()

    # Set every motor to the same specific velocity.
    def unison(self, velocity):
        self.velocity[:,:] = velocity
        self._update_motor()

    # Mirror the left half of the velocity grid onto the right half (along the
    # grid 'Y' axis), negated: column width-1-k receives -velocity[:, k] for each
    # column k in the left half.  For an odd width the center column is unchanged;
    # a width of 0 or 1 leaves the grid unchanged.
    def lateral_mirror(self):
        width = self.velocity.shape[1]
        halfwidth = width // 2
        # Slice the left half explicitly before reversing it: a reversed slice
        # starting at (halfwidth-1) wraps to the last column when halfwidth == 0,
        # and the target must begin at width-halfwidth (not halfwidth) so both
        # sides contain the same number of columns when width is odd.
        left_reversed = self.velocity[:, :halfwidth][:, ::-1]
        self.velocity[:, width - halfwidth:] = -left_reversed
        self._update_motor()

    # Set every motor to an independent uniformly random velocity in (-3, 3].
    def randomize_velocity(self):
        self.velocity = 3 - 6*np.random.rand(*self.velocity.shape)
        self._update_motor()

    # Set every motor to a wave function varying along the X direction: amplitude 3,
    # period 3 time units, phase advancing 0.5 rad per grid row.  Note: in the sample
    # model, the array has been rotated so this is vertical.
    def wave_x(self):
        period = 3.0
        phase_velocity = 2*math.pi / period
        phase = self.x_mesh
        self.velocity = 3 * np.sin((phase_velocity * self.current_time) + 0.5*phase)
        self._update_motor()

    # Set every motor to a wave function varying along the Y direction: values in
    # [0, 3], period 3 time units, phase advancing 0.3 rad per grid column.  Note: in
    # the sample model, this is the horizontal direction across the room.
    def wave_y(self):
        period = 3.0
        phase_velocity = 2*math.pi / period
        phase = self.y_mesh
        self.velocity = 1.5 + 1.5 * np.sin((phase_velocity * self.current_time) + 0.3*phase)
        self._update_motor()

    #----------------------------------------------------------------
    # Implement the state machine for the overall performance state, updated
    # at the main event rate.  Sequence and nominal durations:
    #   init (1 s, stopped) -> spinny (3 s) -> horizontal_wave (6 s)
    #   -> vertical_wave (6 s) -> tremble (3 s) -> rest (6 s, stopped) -> spinny ...
    # On the first poll in a state its entry actions run (new_state is True);
    # afterwards the state either transitions when its timer expires or applies
    # its continuous side-effects.
    def poll_performance_state(self):

        # Implement a basic countdown timer (in seconds) used by many states.
        timer_expired = False
        self.state_timer -= 0.001*EVENT_LOOP_DT
        if self.state_timer < 0:
            timer_expired = True

        # Implement the entry-flag logic used by many states: consume the flag
        # so that entry actions run only once.
        new_state = self.new_state
        self.new_state = False

        # -----------------------------------

        # Evaluate the side-effects and transition rules for each state.
        if self.state == 'init':
            if new_state:
                self.state_timer = 1.0    # initial pause
                self.stop_all()

            elif timer_expired:
                self._transition('spinny')

        elif self.state == 'spinny':
            # apply side effects on entry
            if new_state:
                print("Entering", self.state)
                self.state_timer = 3.0
                self.unison(-6.0)
                self.lateral_mirror()

            # apply transition rules
            elif timer_expired:
                self._transition('horizontal_wave')

        elif self.state == 'horizontal_wave':
            # apply side effects on entry
            if new_state:
                print("Entering", self.state)
                self.state_timer = 6.0

            # apply transition rules
            elif timer_expired:
                self._transition('vertical_wave')

            # apply continuous side-effects
            else:
                self.wave_y()

        elif self.state == 'vertical_wave':
            # apply side effects on entry
            if new_state:
                print("Entering", self.state)
                self.state_timer = 6.0

            # apply transition rules
            elif timer_expired:
                self._transition('tremble')

            # apply continuous side-effects
            else:
                self.wave_x()

        elif self.state == 'tremble':
            # apply side effects on entry
            if new_state:
                print("Entering", self.state)
                self.state_timer = 3.0

            # apply transition rules
            elif timer_expired:
                self._transition('rest')

            # apply continuous side-effects
            else:
                self.randomize_velocity()

        elif self.state == 'rest':
            if new_state:
                print("Entering", self.state)
                self.state_timer = 6.0
                self.stop_all()

            elif timer_expired:
                self._transition('spinny')

        else:
            print("Invalid state, resetting.")
            # Go through _transition so the entry flag is raised and the 'init'
            # entry actions (stop all motors, start the initial pause) run on the
            # next poll instead of being skipped by the already-expired timer.
            self._transition('init')

    # Internal method for state machine transitions: switch to new_state and
    # raise the entry flag so its entry actions run on the next poll.
    def _transition(self, new_state):
        self.new_state = True
        self.state = new_state
        self.state_timer = 0

    #----------------------------------------------------------------
    def run(self):
        # Run loop to execute a periodic script until the simulation quits.
        # self.step() returns -1 when the simulator is quitting.

        while self.step(EVENT_LOOP_DT) != -1:
            # Read simulator clock time.
            self.current_time = self.getTime()

            # Poll the performance state machine.
            self.poll_performance_state()


################################################################
# If running directly from Webots as a script, the following will begin execution.
# This will have no effect when this file is imported as a module.

if __name__ == "__main__":
    # Create the singleton controller object and enter its event loop, which
    # returns once step() reports that the simulator is quitting.
    robot = ArrayRobot()
    robot.run()
