# gestures.py

# Raspberry Pi Pico - Accelerometer Processing Demo

# Read an analog accelerometer input with the ADC and detect patterns of motion.
# Drive the onboard LED and print messages to indicate state.
# Assumes a three-axis analog accelerometer is connected to the three ADC inputs.

import board
import time
import analogio
import digitalio

# Import filters from the 'signals' directory provided with the course.  These
# files should be copied to the top-level directory of the CIRCUITPY filesystem
# on the Pico.

import linear
import biquad

#---------------------------------------------------------------

# Coefficients for Low-Pass Butterworth IIR digital filter, generated using
# filter_gen.py.  Sampling rate: 100.0 Hz, frequency: 20.0 Hz.  Filter is order
# 4, implemented as second-order sections (biquads). Reference:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html
low_pass_100_20 = [
    # section 0
    [
        [1.0, -0.32897568, 0.06458765],         # A coeff
        [0.04658291, 0.09316581, 0.04658291],   # B coeff
    ],
    # section 1
    [
        [1.0, -0.45311952, 0.46632557],         # A coeff
        [1.0, 2.0, 1.0],                        # B coeff
    ],
]

#---------------------------------------------------------------
class Accelerometer:
    # Per-axis raw ADC endpoints, in x, y, z order.  Each pair is handed to
    # linear.map() as the input range, together with the output range -1.0 to 1.0.
    _RAW_RANGES = ((26240, 39120), (26288, 39360), (26800, 40128))

    def __init__(self):
        """Three-axis analog accelerometer input with per-axis linear calibration
        and low-pass filtering."""

        # one ADC channel per axis
        self.x_in = analogio.AnalogIn(board.A0)
        self.y_in = analogio.AnalogIn(board.A1)
        self.z_in = analogio.AnalogIn(board.A2)

        # sampling schedule, measured in nanoseconds
        self.sampling_interval = 10000000   # 10 ms between samples, i.e. 100 Hz
        self.sampling_timer = 0

        # latest results, in x, y, z order
        self.raw  = [0, 0, 0]               # unprocessed ADC readings
        self.grav = [0.0, 0.0, 0.0]         # calibrated and filtered values

        # each axis gets its own smoothing filter built from the same coefficients
        self.filters = [biquad.BiquadFilter(low_pass_100_20) for _ in range(3)]

    def poll(self, elapsed):
        """Take and process a new sample whenever the sampling timer runs out.
        Intended to be called on every pass of the event loop.

        :param elapsed: nanoseconds elapsed since the last cycle.
        """

        self.sampling_timer -= elapsed
        if self.sampling_timer >= 0:
            return      # the next sample is not due yet

        # schedule the following sample relative to when this one was due
        self.sampling_timer += self.sampling_interval

        # read all three channels back to back
        self.raw = [self.x_in.value, self.y_in.value, self.z_in.value]

        # calibrate each raw reading using the endpoints for its axis
        calibrated = [linear.map(reading, low, high, -1.0, 1.0)
                      for reading, (low, high) in zip(self.raw, self._RAW_RANGES)]

        # smooth each calibrated value with the filter for its axis
        self.grav = [axis_filter.update(value)
                     for axis_filter, value in zip(self.filters, calibrated)]
            
#---------------------------------------------------------------
class Blinker:
    def __init__(self):
        """Onboard LED driver which blinks at an adjustable rate."""
        # Configure the built-in green LED pin as an output.
        led = digitalio.DigitalInOut(board.LED)  # GP25
        led.direction = digitalio.Direction.OUTPUT
        self.led = led

        self.update_timer = 0
        self.set_rate(1.0)

    def poll(self, elapsed):
        """Toggle the LED whenever the blink timer runs out.  Intended to be called
        on every pass of the event loop.

        :param elapsed: nanoseconds elapsed since the last cycle.
        """
        self.update_timer -= elapsed
        if self.update_timer >= 0:
            return      # the current half-period has not finished

        self.update_timer += self.update_interval
        self.led.value = not self.led.value

    def set_rate(self, Hz):
        """Set the blink rate.

        :param Hz: number of full on/off cycles per second.
        """
        # the LED changes state twice per cycle, so the timer runs at half the period
        half_period = 500000000 / Hz
        self.update_interval = int(half_period)     # nanoseconds
        
#---------------------------------------------------------------
class Logic:
    def __init__(self, accel, blinker):
        """Application state machine.

        Steps through the states 'init', 'idle', 'right', 'left' and 'win' based on
        threshold tests of the accelerometer vector and on timeouts, and sets the
        blink rate on entry to each state.

        :param accel: object with a 'grav' sequence; elements 0 and 1 are read.
        :param blinker: object with a set_rate(Hz) method.
        """
        self.accel   = accel
        self.blinker = blinker
        self.update_interval = 500000000           # period of 2 Hz in nanoseconds
        self.update_timer    = 0

        # initialize the state machine
        self.state = 'init'
        self.state_changed = False      # True until the first tick in a new state
        self.state_timer = 0            # seconds spent in the current state

    def poll(self, elapsed):
        """Polling function to be called as frequently as possible from the event loop.
        Evaluates the state machine each time the update timer runs out.

        :param elapsed: nanoseconds elapsed since the last cycle.
        """
        self.update_timer -= elapsed
        if self.update_timer < 0:
            self.update_timer += self.update_interval
            self.tick()            # evaluate the state machine

    def transition(self, new_state):
        """Set the state machine to enter a new state on the next tick.  Restarts
        the state timer and prints a message naming the new state.

        :param new_state: name of the state to enter.
        """
        self.state         = new_state
        self.state_changed = True
        self.state_timer   = 0
        print("Entering state:", new_state)

    def tick(self):
        """Evaluate the state machine rules.

        At most one transition is taken per tick.  In states with both a gesture
        rule and a timeout rule, the gesture rule is tested first so that a
        gesture completed on the timeout tick still counts, without also
        announcing a return to 'idle'.
        """

        # advance elapsed time in seconds
        self.state_timer += 1e-9 * self.update_interval

        # set up a flag to process transitions within a state clause
        entering_state = self.state_changed
        self.state_changed = False

        # select the state clause to evaluate
        if self.state == 'init':
            self.transition('idle')

        elif self.state == 'idle':
            if entering_state:
                self.blinker.set_rate(0.5)

            if self.accel.grav[1] < -0.5:
                self.transition('right')

        elif self.state == 'right':
            if entering_state:
                self.blinker.set_rate(1.0)

            if self.accel.grav[1] > 0.5:
                self.transition('left')
            elif self.state_timer > 2.0:
                self.transition('idle')

        elif self.state == 'left':
            if entering_state:
                self.blinker.set_rate(2.0)

            if self.accel.grav[0] < -0.5:
                self.transition('win')
            elif self.state_timer > 2.0:
                self.transition('idle')

        elif self.state == 'win':
            if entering_state:
                self.blinker.set_rate(4.0)

            if self.state_timer > 2.0:
                self.transition('idle')
                
#---------------------------------------------------------------
class Status:
    def __init__(self, accel, blinker, logic):
        """Periodic console reporter for debugging."""
        self.accel, self.blinker, self.logic = accel, blinker, logic

        # Reporting period in nanoseconds: 100000000 gives 10 Hz output;
        # 1000000000 would slow it to 1 Hz.
        self.update_interval = 100000000
        self.update_timer = 0

    def poll(self, elapsed):
        """Print a report whenever the reporting timer runs out.  Intended to be
        called on every pass of the event loop.

        :param elapsed: nanoseconds elapsed since the last cycle.
        """
        self.update_timer -= elapsed
        if self.update_timer >= 0:
            return      # the next report is not due yet

        self.update_timer += self.update_interval

        # Report the processed vector; self.accel.raw could be printed instead
        # to inspect the unprocessed ADC readings.
        print(tuple(self.accel.grav))
                    
#---------------------------------------------------------------
# Initialize global objects.
accel = Accelerometer()
blinker = Blinker()
logic = Logic(accel, blinker)
status = Status(accel, blinker, logic)

#---------------------------------------------------------------
# Cooperative event loop.  Each pass measures the time since the previous pass
# and hands it to every object's poll() method in a fixed order.

tasks = (accel, logic, blinker, status)

last_clock = time.monotonic_ns()

while True:
    # nanoseconds since the previous pass
    now = time.monotonic_ns()
    elapsed, last_clock = now - last_clock, now

    # give each non-preemptive thread its turn
    for task in tasks:
        task.poll(elapsed)
