#!/usr/bin/env python3

# Start this server from the parent folder as a module (required by the
# package-relative imports below), e.g.:
#   python3 -m stage.lighting_server --quiet

# Program description shown in the command-line help (passed to argparse in __main__).
cmd_desc = "Communicate real-time lighting commands received via OSC to the DMX hardware."

import argparse
import time
import math
import logging
import threading

import numpy as np

# import the pythonosc package
from pythonosc import dispatcher
from pythonosc import osc_server

# import the DMXUSBPro interface
from . import dmxusbpro

# OSC network settings (UDP port, host IP) and the fixture table for lighting commands
from .config import lighting_UDP_port, lighting_host_IP, fixtures

# common logging functions
from .logconfig import add_logging_args, open_log_file, configure_logging

# Module logger named 'lighting'.  Output destination and verbosity are set up
# in the __main__ section via open_log_file() and configure_logging().
log = logging.getLogger('lighting')

#================================================================
class FixtureSpline:
    """Representation of a single DMX lighting fixture animated using cubic
    Bezier splines.  For dimmer packs this is typically a single channel
    representing a single monochrome lamp, for RGB or RGBA fixtures this
    operates three or four channels.  This stores DMX channel assignments for
    convenience but uses an external object to write the current universe.

    The knot buffer ``knots`` has one column per channel.  Its first four rows
    form the active segment.  Each further segment appended by add_spline()
    contributes three rows, and the previous segment's last knot is its
    start point.  ``u`` in [0, 1] is the position within the active segment.
    """
    def __init__(self, channels=1, DMX_base=0):

        # save the DMX address of the first channel
        self.DMX_base = DMX_base

        # Initialize a default spline buffer.  The invariant is that this will
        # always retain at least one segment (four knots).
        self.knots = np.zeros((4, channels))
        self.u     = 1.0
        self.rate  = 1.0   # execution rate in segments per sec

        # Current interpolated value, one entry per channel.
        self.value = np.zeros((channels))

    def set_tempo(self, spm):
        """Set the execution rate of the spline interpolator in units of
        segments/minute.  The initial default value is 60 segs/min (1 per
        second).  The rate is clamped to the range [0.0001, 5] segments/sec."""
        self.rate = max(min(spm/60.0, 5), 0.0001)

    def channels(self):
        """Return the number of channels required by the fixture.  E.g., each
        dimmer pack output is one channel, an RGB fixture has three, an RGBW
        fixture has four."""
        return self.knots.shape[1]

    def update(self, dt, dmx):
        """Advance the spline position given the elapsed interval dt (in
        seconds), recalculate the spline interpolation, and send data to the
        given dmx universe object.

        :param dt: elapsed time in seconds; non-positive values do not advance
        :param dmx: object with a sliceable ``universe`` array; this fixture's
            channels starting at DMX_base are overwritten with the current
            value clamped to 0-255 and converted to uint8
        """
        if dt > 0.0:
            self.u += self.rate * dt
            # Roll over into queued segments.  Each completed segment is
            # dropped by removing its first three knots (its last knot is the
            # next segment's start).  Loop so that a large dt can pass through
            # several segments, but always leave at least one segment.
            while self.u >= 1.0 and (self.knots.shape[0] - 1) // 3 > 1:
                self.knots = self.knots[3:]
                self.u -= 1.0

        # always limit u to the first active segment
        self.u = min(max(self.u, 0.0), 1.0)

        # update interpolation
        self._recalculate()

        # update the DMX universe
        num_channels = self.knots.shape[1]
        dmx.universe[self.DMX_base:self.DMX_base+num_channels] = np.clip(self.value, 0, 255).astype(np.uint8)

    def _recalculate(self):
        """Apply the De Casteljau algorithm to calculate the position of
        the spline at the current path point, storing it in self.value."""

        # unroll the De Casteljau iteration and save intermediate points
        num_channels = self.knots.shape[1]
        beta = np.empty((3, num_channels))

        u = self.u
        for k in range(3):
            beta[k] = self.knots[k] * (1 - u) + self.knots[k+1] * u

        for k in range(2):
            beta[k] = beta[k] * (1 - u) + beta[k+1] * u

        self.value[:] = beta[0] * (1 - u) + beta[1] * u

    def set_target(self, value):
        """Reset the path generator to a single spline segment from the current
        value to the given value.  The path defaults to 'ease-out', meaning the
        rate of change is discontinuous at the start and slows into the target.
        Any queued segments are discarded.

        :param value: iterable with one DMX value per channel (a single value
            broadcasts across all channels)"""
        # Build a fresh buffer so that segments queued by add_spline() do not
        # play after the target is reached.
        knots = np.empty((4, self.channels()))
        knots[0, :] = self.value   # start at the current levels
        knots[1, :] = value        # derivative discontinuity
        knots[2, :] = value        # ease at the target
        knots[3, :] = value        # target endpoint
        self.knots = knots

        # reset the path position to start the new segment
        self.u = 0.0

    def set_value(self, value):
        """Reset the path generator immediately to the given value with no
        interpolation.  Any queued segments are discarded.

        :param value: iterable with one DMX value per channel (a single value
            broadcasts across all channels)"""

        # a constant segment holding the new value
        knots = np.empty((4, self.channels()))
        knots[:, :] = value
        self.knots = knots

        # place the path position at the end of the segment
        self.u = 1.0

        # reset current value
        self.value[:] = value

    def add_spline(self, knots):
        """Append a sequence of Bezier spline segments to the lighting sequence
        generator, provided as a numpy matrix with three rows per segment and
        one column per channel.  The values are treated as DMX command values,
        and may stray outside 0-255 even though the interpolated output will be
        clamped."""

        self.knots = np.vstack((self.knots, knots))

        # Note: the path position and current value are left unchanged.  If the
        # interpolator is at the end of the current segment, it will now begin
        # to advance into these new segments.

#================================================================
class EventLoop:
    """Lighting server main loop.

    OSC messages (/fixture, /spline, /tempo) are received by a UDP server
    running on a daemon thread and applied to per-fixture FixtureSpline
    interpolators.  The run() loop advances those interpolators on a fixed
    25 Hz cycle and retransmits the DMX universe to the hardware.  The
    interpolators are shared between the two threads and guarded by
    self.array_lock.
    """
    def __init__(self, args):
        """Build the fixture interpolators, open the DMX port, then start the
        OSC listener.

        :param args: parsed command-line namespace; uses args.ip, args.dmx,
            args.verbose and args.debug
        """

        # configure event timing
        self.dt_ns = 40*1000*1000  # 25 Hz in nanoseconds per cycle

        # Keep usage counts for logging.
        self.spline_msg_count = 0
        self.fixture_msg_count = 0
        self.tempo_msg_count = 0

        # Create fade interpolators. The lock is used to synchronize access between
        # the networking thread and the main event loop thread.
        self.array_lock = threading.Lock()

        # Use the stage.config.fixtures array to configure the spline interpolators.
        self.splines = []
        self.fixture_map = {}
        for i, f in enumerate(fixtures):
            log.debug("Found fixture %d: %s", i, f)
            name     = f.get('name')
            base     = f.get('dmx')
            channels = f.get('channels')
            interpolator = FixtureSpline(channels=channels, DMX_base=base)
            self.splines.append(interpolator)
            self.fixture_map[name] = interpolator

        # Number of configured fixtures (kept under its historical name).
        self.num_channels = len(self.splines)

        # Open the DMX controller serial port
        log.info("Opening DMX serial port.")
        self.dmx = dmxusbpro.DMXUSBPro(port=args.dmx, verbose=args.verbose, debug=args.debug)
        self.dmx.open_serial_port()

        # Start the networking listener last.  Its handlers run on another
        # thread and use the lock and fixture tables created above, so no
        # message can be processed before that state exists.  This also avoids
        # leaving a bound UDP socket behind if opening the DMX port fails.
        self.init_networking(args)

    #---------------------------------------------------------------
    def set_targets(self, values):
        """Set all spline interpolation targets.

        :param values: one value per fixture, in fixture order; each is passed
            to set_target() as a one-element list"""
        with self.array_lock:
            for i, value in enumerate(values):
                self.splines[i].set_target([value])

    def set_values(self, values):
        """Immediately set all spline interpolator values.

        :param values: one value per fixture, in fixture order; each is passed
            to set_value() as a one-element list"""
        with self.array_lock:
            for i, value in enumerate(values):
                self.splines[i].set_value([value])

    #---------------------------------------------------------------
    def close(self):
        """Close the DMX serial port."""
        log.info("Closing serial ports.")
        self.dmx.close_serial_port()

    #---------------------------------------------------------------
    def init_networking(self, args):
        """Create the OSC dispatcher and start a UDP server for it on a daemon
        thread, listening on (args.ip, lighting_UDP_port)."""
        # Initialize the OSC message dispatch system.
        self.dispatch = dispatcher.Dispatcher()
        self.dispatch.map("/fixture",   self.lights_fixture)
        self.dispatch.map("/spline",    self.lights_spline)
        self.dispatch.map("/tempo",     self.lights_tempo)

        self.dispatch.set_default_handler(self.unknown_message)

        # Start and run the server.
        self.server = osc_server.OSCUDPServer((args.ip, lighting_UDP_port), self.dispatch)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        log.info("started OSC server on port %s:%d", args.ip, lighting_UDP_port)

    def unknown_message(self, msgaddr, *args):
        """Default handler for unrecognized OSC messages."""
        log.warning("Received unmapped OSC message %s: %s", msgaddr, args)

    def _find_fixture(self, fixture):
        """Return the FixtureSpline addressed by an OSC fixture argument.

        The argument is first looked up as a fixture name.  Failing that, it is
        converted with int() and used as an index into self.splines.

        :raises ValueError, TypeError: if the argument is neither a known name
            nor convertible to an integer
        :raises IndexError: if the numeric index is out of range; negative
            indices are rejected rather than counting from the end
        """
        interpolator = self.fixture_map.get(fixture)
        if interpolator is not None:
            return interpolator
        index = int(fixture)
        if not 0 <= index < len(self.splines):
            raise IndexError("fixture index %d out of range" % index)
        return self.splines[index]

    def lights_fixture(self, msgaddr, *args):
        """Process /fixture messages received via OSC over UDP.  The first value
        is treated as a fixture index or name, followed by one or more DMX
        values specifying the target level or color.  The fixture starts a
        smooth transition to the specified value.  Malformed messages are
        logged and ignored.
        """
        try:
            values = np.clip(np.array(args[1:]), 0, 255)
            interpolator = self._find_fixture(args[0])

            with self.array_lock:
                interpolator.set_target(values)

            self.fixture_msg_count += 1

        except Exception as exc:
            log.warning("OSC message failed for /fixture: %s (%s)", args, exc)

    def lights_spline(self, msgaddr, *args):
        """Process /spline messages received via OSC over UDP.  The first
        message value specifies a fixture, followed by a row-major array
        representing a matrix of Bezier spline knots.  The matrix has one column
        per fixture channel and three rows per spline segment, and specifies DMX
        output values.  The implicit first knot is the current fixture state.
        Values beyond the last complete segment are ignored.  Malformed
        messages are logged and ignored.
        """
        try:
            interpolator = self._find_fixture(args[0])

            # calculate the inferred matrix properties, ignoring excess values
            channels     = interpolator.channels()
            num_values   = len(args) - 1
            num_rows     = num_values // channels
            num_segments = num_rows // 3
            num_used     = 3*num_segments*channels

            # convert the usable args into a knot matrix
            knots = np.array(args[1:1+num_used]).reshape(3*num_segments, channels)

            # and add to the fixture lighting trajectory
            with self.array_lock:
                interpolator.add_spline(knots)

            self.spline_msg_count += 1

        except Exception as exc:
            log.warning("error processing OSC spline message: %s (%s)", args, exc)

    def lights_tempo(self, msgaddr, *args):
        """Process /tempo messages received via OSC over UDP.  The first message
        value specifies a fixture, the second a spline execution rate in
        segments per minute.  Malformed messages are logged and ignored.
        """
        try:
            interpolator = self._find_fixture(args[0])
            tempo = float(args[1])

            # the rate is read by update() on the event loop thread
            with self.array_lock:
                interpolator.set_tempo(tempo)

            self.tempo_msg_count += 1

        except Exception as exc:
            log.warning("error processing OSC tempo message: %s (%s)", args, exc)

    #---------------------------------------------------------------
    # Event loop.
    def run(self):
        """Run the lighting server event loop to periodically advance the
        fixture splines and retransmit the DMX universe to the hardware.  OSC
        UDP network messages are received by a background thread and applied
        to the spline interpolators that this loop advances.  Never returns
        normally."""

        dt_seconds = self.dt_ns * 1e-9
        status_interval = 25*60*5   # cycles between status logs (5 minutes at 25 Hz)

        start_t = time.monotonic_ns()
        next_cycle_t = start_t + self.dt_ns
        event_cycles = 0

        while True:
            # wait for the next cycle timepoint, keeping the long
            # term rate stable even if the short term timing jitters
            delay = next_cycle_t - time.monotonic_ns()
            if delay > 0:
                time.sleep(delay * 1e-9)
            next_cycle_t += self.dt_ns

            # advance the spline interpolators by one nominal cycle and write
            # their outputs into the DMX universe
            with self.array_lock:
                for spline in self.splines:
                    spline.update(dt_seconds, self.dmx)

            # and send to the hardware
            self.dmx.send_universe()

            # Note: self.dmx.flush_serial_input() is deliberately not called.
            # It is unnecessary given the output-only usage of the Enttec
            # device, and appears to trigger an FTDI driver bug on recent
            # versions of macOS.

            # periodically log status; the counters are plain ints updated by
            # the OSC thread outside the lock, so no locking is needed to read them
            if event_cycles % status_interval == 0:
                log.debug("event cycle %d: fixture msgs: %d, spline msgs: %d, tempo msgs: %d",
                          event_cycles,
                          self.fixture_msg_count,
                          self.spline_msg_count,
                          self.tempo_msg_count)
            event_cycles += 1

#================================================================
# The following section is run when this is loaded as a script.
if __name__ == "__main__":

    # set up logging
    open_log_file('logs/lighting-server.log')
    log.info("Starting lighting-server.py")
    np.set_printoptions(linewidth=np.inf)

    # Initialize the command parser.
    parser = argparse.ArgumentParser(description = cmd_desc)
    parser.add_argument( '--ip', default=lighting_host_IP,  help="Network interface address (default: %(default)s).")
    parser.add_argument( '--dmx', default='/dev/ttyUSB0', help='DMX serial port device (default: %(default)s).')
    add_logging_args(parser)

    # Parse the command line, returning a Namespace.
    args = parser.parse_args()

    # Modify logging settings as per common arguments.
    configure_logging(args)

    # Create the event loop (opens the DMX port and starts the OSC listener).
    event_loop = EventLoop(args)

    # Begin the performance.  This may be safely interrupted by the user pressing
    # Control-C.  The 'finally' clause closes the serial port on every exit path,
    # including unexpected exceptions, not only on Control-C.
    try:
        event_loop.run()

    except KeyboardInterrupt:
        log.info("User interrupted operation.")
        print("User interrupt, shutting down.")

    finally:
        event_loop.close()

    log.info("Exiting lighting-server.py")
