#!/usr/bin/env python3

# The best way to start this server is from the parent folder using:
#   python3 -m stage.motion_server --quiet

cmd_desc = "Communicate real-time motion commands received via OSC to the motion hardware."  # argparse description for --help

import argparse
import time
import math
import logging
import queue
import threading

import numpy as np

# import the pythonosc package
from pythonosc import udp_client
from pythonosc import dispatcher
from pythonosc import osc_server

# import the StepperSpline interface
from . import steppers

# OSC port for motor commands
from .config import motor_UDP_port, motor_host_IP

# common logging functions
from .logconfig import add_logging_args, open_log_file, configure_logging

# initialize logging for this module; all messages go to the shared 'motion' logger
log = logging.getLogger('motion')

#================================================================
class EventLoop:
    """Relay real-time OSC motion commands to the stepper-motor Arduino.

    OSC /tempo and /spline messages are dispatched on a daemon OSC server
    thread and forwarded to the Arduino as soon as they arrive.  The caller's
    thread runs run(), which polls Arduino status output and periodically logs
    message counts.

    NOTE(review): the OSC handlers and run() both use self.steppers from
    different threads; confirm StepperSplineClient tolerates concurrent use.
    """

    # Stepper channel names, in numeric channel-index order.
    _channels = 'xyza'

    def __init__(self, args):
        """Open the Arduino serial port, wait for it to wake up, and start the OSC server.

        :param args: parsed command-line namespace; reads verbose, debug and
                     arduino here, plus ip and unit in init_networking()
        """
        # configure event timing
        self.dt_ns = 500*1000*1000  # 2 Hz in nanoseconds per cycle

        # save flags
        self.verbose = args.verbose

        # Keep usage counts for logging.
        self.spline_msg_count = 0
        self.tempo_msg_count = 0

        # the number of microsteps per revolution is set by driver jumpers
        self.steps_per_rev = 800

        # Open the Arduino serial port
        log.info("Opening Arduino serial port %s.", args.arduino)
        self.steppers = steppers.StepperSplineClient(port=args.arduino, verbose=args.verbose, debug=args.debug)

        log.info("Waiting for Arduino wakeup.")
        self.steppers.wait_for_wakeup()

        # Create the networking listener
        self.init_networking(args)

    def close(self):
        """Stop the OSC server, disable the motor drivers, and close the serial port."""
        # Stop the OSC listener first so no message handler can write to the
        # serial port after it has been closed.  socketserver's shutdown()
        # waits for serve_forever() to exit, so only call it while the server
        # thread is still alive to avoid blocking forever.
        server_thread = getattr(self, 'server_thread', None)
        if server_thread is not None and server_thread.is_alive():
            self.server.shutdown()
        server = getattr(self, 'server', None)
        if server is not None:
            server.server_close()

        # Issue a command to turn off the drivers, then shut down the connections.
        log.info("Closing serial ports.")
        try:
            self.steppers.motor_enable(False)
        finally:
            # Release the serial port even if the disable command fails.
            self.steppers.close()

    #---------------------------------------------------------------
    def init_networking(self, args):
        """Map the OSC message handlers and start the UDP server on a daemon thread.

        The listening port is motor_UDP_port offset by the unit number, so that
        several units can run on one host.

        :param args: parsed command-line namespace; reads ip and unit
        """
        # Initialize the OSC message dispatch system.
        self.dispatch = dispatcher.Dispatcher()
        self.dispatch.map("/tempo",     self.motors_tempo)
        self.dispatch.map("/spline",    self.motors_spline)

        self.dispatch.set_default_handler(self.unknown_message)

        # Start and run the server.
        unit_port = motor_UDP_port + args.unit
        self.server = osc_server.OSCUDPServer((args.ip, unit_port), self.dispatch)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        log.info("started OSC server on port %s:%d", args.ip, unit_port)

    def unknown_message(self, msgaddr, *args):
        """Default handler for unrecognized OSC messages."""
        log.warning("Received unmapped OSC message %s: %s", msgaddr, args)

    @classmethod
    def _channel_mask(cls, channel, msgtype, single=True):
        """Convert an OSC channel argument into a stepper channel mask string.

        :param channel: a channel letter from 'xyza', or a number that is
                        truncated with int() and used as an index 0-3
        :param msgtype: message name used in warnings, e.g. 'tempo' or 'spline'
        :param single: when True, a string must name exactly one channel;
                       when False, a non-empty contiguous run of 'xyza'
                       (e.g. 'xyza') is also accepted
        :return: the channel mask string, or None if the channel is invalid
        :raises TypeError, ValueError: if a non-string channel is not numeric
        """
        if isinstance(channel, str):
            if single:
                valid = len(channel) == 1 and channel in cls._channels
            else:
                # NOTE(review): keeps the original substring acceptance so a
                # multi-channel mask such as 'xyza' is still forwarded; confirm
                # that send_tempo supports multi-channel masks.
                valid = len(channel) > 0 and channel in cls._channels
            if valid:
                return channel
            log.warning("OSC %s invalid channel index: %s", msgtype, channel)
            return None

        index = int(channel)
        # Reject negative indices explicitly; plain indexing would silently
        # map -1 to 'a'.
        if 0 <= index < len(cls._channels):
            return cls._channels[index]
        log.warning("OSC %s invalid channel index: %s", msgtype, channel)
        return None

    def _angles_to_knots(self, values):
        """Convert a flat sequence of angles in degrees into a knot matrix in microsteps.

        Only complete three-knot segments are kept; any trailing extra values
        are ignored.

        :param values: iterable of values convertible with float()
        :return: float ndarray of shape (3*segments, 1), one row per knot
        """
        angles = np.array([float(val) for val in values])

        # calculate the number of full segments; extra knots are ignored
        segments = angles.shape[0] // 3

        # trim the angle list if needed and rescale to microsteps
        knots = angles[0:3*segments] * (self.steps_per_rev/360.0)

        # convert to a matrix (one row per knot, one column for the one channel)
        return knots.reshape(3*segments, 1)

    def motors_tempo(self, msgaddr, *args):
        """Process /tempo messages received via OSC over UDP.

        The first message value selects the channel (see _channel_mask); the
        second specifies a spline execution rate in segments per minute.
        Malformed messages are logged and ignored.
        """
        try:
            mask = self._channel_mask(args[0], 'tempo', single=False)
            if mask is None:
                return

            tempo = float(args[1])
            self.steppers.send_tempo(tempo, mask)
            self.tempo_msg_count += 1

        except Exception as e:
            # Best effort: a bad packet must not kill the OSC server thread.
            log.warning("error processing OSC tempo message: %s (%s)", args, e)

    def motors_spline(self, msgaddr, *args):
        """Process /spline messages received via OSC over UDP.

        The first message value specifies a motor driver channel, followed by
        a sequence of Bezier spline knots, three per spline segment, specifying
        a relative displacement in degrees.  Messages with an invalid channel,
        or with no complete segment, are logged and ignored.
        """
        try:
            mask = self._channel_mask(args[0], 'spline')
            if mask is None:
                return

            knots = self._angles_to_knots(args[1:])
            if knots.shape[0] == 0:
                log.warning("OSC spline message has no complete segments: %s", args)
                return

            # immediately send to the Arduino
            self.steppers.send_knots(knots, mask)
            self.spline_msg_count += 1

        except Exception as e:
            # Best effort: a bad packet must not kill the OSC server thread.
            log.warning("error processing OSC spline message: %s (%s)", args, e)

    #---------------------------------------------------------------
    def run(self):
        """Run the event loop forever at a fixed cycle period of dt_ns.

        Most of the data flow occurs in the OSC callbacks, which send directly
        to the Arduino.  Each cycle here only polls Arduino status output and,
        about every five minutes, logs a debug summary of message counts.
        This method never returns normally; it exits only by an exception,
        e.g. KeyboardInterrupt.
        """
        start_t = time.monotonic_ns()
        next_cycle_t = start_t + self.dt_ns

        # number of cycles between periodic status reports (five minutes)
        log_interval = max(1, (5 * 60 * 1000*1000*1000) // self.dt_ns)
        event_cycles = 0

        while True:
            # Wait for the next cycle timepoint.  Deadlines advance by a fixed
            # step so the long-term rate stays stable even if the short-term
            # timing jitters.
            delay = next_cycle_t - time.monotonic_ns()
            if delay > 0:
                time.sleep(delay * 1e-9)
            next_cycle_t += self.dt_ns

            # Process status messages from the Arduino.
            self.steppers.poll_status()

            # periodically log status
            if event_cycles % log_interval == 0:
                log.debug("event cycle %d: spline msgs: %d, tempo msgs: %d",
                          event_cycles,
                          self.spline_msg_count,
                          self.tempo_msg_count)
            event_cycles += 1


#================================================================
# The following section is run when this is loaded as a script.
if __name__ == "__main__":

    # Initialize the command parser.
    parser = argparse.ArgumentParser(description = cmd_desc)
    parser.add_argument( '--ip', default=motor_host_IP,  help="Network interface address (default: %(default)s).")
    parser.add_argument( '--arduino', default='/dev/ttyACM0', help='Arduino serial port device (default is %(default)s.)')
    parser.add_argument( '--unit', default=0, type=int, help='Unit number (default is %(default)s.)')
    add_logging_args(parser)

    # Parse the command line, returning a Namespace.
    args = parser.parse_args()

    # set up logging
    open_log_file('logs/motion-server-%d.log' % (args.unit))
    log.info("Starting motion-server.py as unit %d", args.unit)

    # Modify logging settings as per common arguments.
    configure_logging(args)

    # Create the event loop; this opens the Arduino port and starts the OSC server.
    event_loop = EventLoop(args)

    # Begin the performance.  This may be safely interrupted by the user pressing Control-C.
    try:
        event_loop.run()

    except KeyboardInterrupt:
        log.info("User interrupted operation.")
        print("User interrupt, shutting down.")

    except Exception as e:
        # log.exception also records the traceback for post-mortem debugging.
        log.exception("unable to continue: %s: %s", type(e), e)
        print("unable to continue: ", type(e), e)

    finally:
        # Always disable the motor drivers and release the serial port, whether
        # the loop stopped by user request or because of an unexpected error.
        event_loop.close()
