#!/usr/bin/env python3
# Obtain and process images from a Blackfly GigE machine vision camera using the FLIR Spinnaker PySpin API.

# One-line program description; passed to argparse and shown in the --help output.
cmd_desc = "Stream FLIR camera images using OSC UDP messages."

import os
import sys
import logging
import argparse
import datetime
import threading

# OpenCV
import cv2 as cv

# NumPy
import numpy as np

# python-osc
from pythonosc import osc_message_builder
from pythonosc import udp_client
from pythonosc import dispatcher
from pythonosc import osc_server

# theater space configuration
from .config import camera_roi, vision_UDP_port, client_UDP_port

# camera interface
from .camera import SpinnakerCamera
from .vision import VideoRegion

# common logging functions
from .logconfig import add_logging_args, open_log_file, configure_logging

# initialize logging for this module; ClientManager, EventLoop and the
# __main__ block below all log through this logger named 'server'
log = logging.getLogger('server')

#================================================================
def elapsed(time0, time1, interval):
    """Return True if at least `interval` seconds separate datetime time0 from time1."""
    span_seconds = (time1 - time0).total_seconds()
    return not span_seconds < interval

#================================================================
class ClientManager:
    """Keep track of OSC clients that have subscribed to motion and/or image data.

    Subscription requests are handled on a background OSC server thread
    (started by init_networking), while the main loop reads ``self.clients``
    and periodically calls drop_inactive_clients().  Each record in
    ``self.clients`` is keyed by client IP address and holds:
      'timestamp' -- datetime of the most recent request from that client
      'socket'    -- SimpleUDPClient used to send data back to the client
      'motion'    -- True once the client has requested motion events
      'images'    -- True once the client has requested image data
    """

    def __init__(self, args):
        """Create an empty client table and start the OSC request server.

        args -- parsed command-line namespace; args.recv is the address the
                OSC server binds to.
        """
        self.args = args

        # empty table of clients to start, keyed by IP address
        self.clients = {}

        # Create the networking input and output.
        self.init_networking(args)

    #---------------------------------------------------------------
    def _get_or_add_client(self, ip, port):
        """Return the record for `ip`, creating it (with a reply socket to ip:port) if absent."""
        client = self.clients.get(ip)

        if client is None:
            # Create an OSC socket to communicate back to the client.
            socket = udp_client.SimpleUDPClient(ip, port)
            log.info("Opened socket to send to client at %s:%d", ip, port)

            # create a new client record
            client = {'timestamp': datetime.datetime.now(),
                      'socket': socket,
                      'motion': False,
                      'images': False,
                      }
            self.clients[ip] = client

        return client

    def _subscribe(self, client_address, feature):
        """Enable `feature` ('motion' or 'images') for the sender and refresh its timestamp.

        client_address is a tuple (ip address, udp port number).  Replies are
        always sent to client_UDP_port on that host, not to the request's
        source port.
        """
        client = self._get_or_add_client(client_address[0], client_UDP_port)
        client[feature] = True
        client['timestamp'] = datetime.datetime.now()  # keeps the client alive

    #---------------------------------------------------------------
    def init_networking(self, args):
        """Map the OSC request addresses and run the OSC server on a daemon thread."""
        # Initialize the OSC message dispatch system.
        self.dispatch = dispatcher.Dispatcher()
        self.dispatch.map("/request/motion", self.motion_request, needs_reply_address=True)
        self.dispatch.map("/request/images", self.images_request, needs_reply_address=True)
        self.dispatch.set_default_handler(self.unknown_message, needs_reply_address=True)

        # Start and run the server; daemon so it does not keep the process alive on exit.
        self.server = osc_server.OSCUDPServer((args.recv, vision_UDP_port), self.dispatch)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        log.info("started OSC server on port %s:%d", args.recv, vision_UDP_port)

    def motion_request(self, client_address, msgaddr, *args):
        """OSC handler for /request/motion: subscribe the sender to motion events."""
        log.info("received motion data request from %s: %s %s", client_address, msgaddr, args)
        self._subscribe(client_address, 'motion')

    def images_request(self, client_address, msgaddr, *args):
        """OSC handler for /request/images: subscribe the sender to image data."""
        log.info("received images data request from %s: %s %s", client_address, msgaddr, args)
        self._subscribe(client_address, 'images')

    def unknown_message(self, client_address, msgaddr, *args):
        """Default handler for unrecognized OSC messages."""
        log.warning("Received unmapped OSC message from %s: %s: %s", client_address, msgaddr, args)

    #---------------------------------------------------------------
    def drop_inactive_clients(self, timeout=10):
        """Purge clients whose last request is more than `timeout` seconds old.

        Should be called periodically from the main loop.  Iterates over a
        snapshot of the table because the OSC server thread may add clients
        concurrently; iterating the live dict could raise
        "dictionary changed size during iteration".
        """
        now = datetime.datetime.now()
        for ip, record in list(self.clients.items()):
            if (now - record['timestamp']).total_seconds() > timeout:
                log.info("connection to %s has timed out", ip)
                self.clients.pop(ip, None)

#================================================================
class EventLoop:
    """Main acquisition loop.

    Each iteration grabs a camera frame, updates the motion trackers, sends
    motion events to clients that requested them, sends one full-resolution
    tile of the frame to clients that requested images, and then drops
    inactive clients.
    """

    def __init__(self, camera, args):
        """Set up client management, region trackers, and the image tile grid.

        camera -- started camera object providing acquire_image(),
                  capture_width and capture_height
        args   -- parsed command-line namespace, passed on to ClientManager
        """
        self.camera = camera
        self.args = args
        self.last_capture = None

        # create a manager to keep track of client requests
        self.manager = ClientManager(args)

        # create a tracker for each active image region specified in the configuration
        self.trackers = [VideoRegion(roi) for roi in camera_roi]

        self.packet_count = 0
        self.next_tile = -1   # advanced before use, so the first tile sent is tile 0
        self.tile_width = 144
        self.tile_height = 54
        # Only whole tiles are sent; any right/bottom remainder of the image is skipped.
        self.num_tile_cols = self.camera.capture_width // self.tile_width
        self.num_tile_rows = self.camera.capture_height // self.tile_height

    #---------------------------------------------------------------
    def _count_packet(self):
        """Increment the sent-packet counter, logging progress every 900 packets."""
        self.packet_count += 1
        if self.packet_count % 900 == 0:
            log.info("Sent %d packets", self.packet_count)

    def _image_sockets(self):
        """Return the reply sockets of all clients that have requested image data."""
        # Snapshot the table: the OSC server thread may add clients concurrently.
        return [record['socket'] for record in list(self.manager.clients.values())
                if record['images'] is True]

    def _send_packet(self, packet, socket=None):
        """Send an OSC packet to `socket`, or to every image client when socket is None."""
        targets = [socket] if socket is not None else self._image_sockets()
        for target in targets:
            target.send(packet)
            self._count_packet()

    #---------------------------------------------------------------
    def write_capture_file(self, image):
        """Save `image` as a timestamped PNG in the 'logs' folder."""
        now = datetime.datetime.now()
        folder = os.path.expanduser("logs")
        # cv.imwrite returns False rather than raising when the folder is missing.
        os.makedirs(folder, exist_ok=True)
        filename = now.strftime("%Y-%m-%d-%H-%M-%S") + "-camera.png"
        path = os.path.join(folder, filename)
        if cv.imwrite(path, image):
            log.debug("Wrote region image to %s", path)
        else:
            log.warning("Failed to write capture image to %s", path)

    #---------------------------------------------------------------
    def send_subsampled_image(self, image, socket=None):
        """Send a complete image at 1/5 resolution as one '/image' OSC message.

        With 2x2 decimation and a height window applied, the system is
        returning (270, 720) images.  Dividing by 5 yields a (54, 144) image
        with 7776 bytes which still fits within an OSC UDP packet.

        socket -- destination OSC client; when None the image is sent to every
                  client that has requested images.
        """
        subsampled = image[::5, ::5]
        height, width = subsampled.shape[:2]  # works for mono (2-D) and multi-channel arrays
        msg = osc_message_builder.OscMessageBuilder(address='/image')
        msg.add_arg(height) # number of rows of pixel data
        msg.add_arg(width)  # number of columns of pixel data
        msg.add_arg(self.camera.capture_height) # full image pixel rows
        msg.add_arg(self.camera.capture_width)  # full image pixel columns
        msg.add_arg(subsampled.tobytes(), arg_type='b') # blob of row-major pixel bytes
        self._send_packet(msg.build(), socket)

    #---------------------------------------------------------------
    def send_image_tile(self, socket, image, tile_row, tile_col):
        """Send one full-resolution tile of `image` as a '/tile' OSC message.

        socket             -- destination OSC client (None sends to all image clients)
        tile_row, tile_col -- tile grid position; the pixel offset is
                              (tile_row * tile_height, tile_col * tile_width)
        """
        offset_cols = tile_col * self.tile_width
        offset_rows = tile_row * self.tile_height
        tile = image[offset_rows:offset_rows+self.tile_height, offset_cols:offset_cols+self.tile_width]
        height, width = tile.shape[:2]
        msg = osc_message_builder.OscMessageBuilder(address='/tile')
        msg.add_arg(height) # tile pixel rows
        msg.add_arg(width)  # tile pixel columns
        msg.add_arg(offset_rows) # tile pixel row position in full image
        msg.add_arg(offset_cols) # tile pixel column position in full image
        msg.add_arg(self.camera.capture_height) # full image pixel rows
        msg.add_arg(self.camera.capture_width)  # full image pixel columns
        msg.add_arg(tile.tobytes(), arg_type='b') # blob of row-major pixel bytes
        self._send_packet(msg.build(), socket)

    #---------------------------------------------------------------
    def send_debug_region(self, image, row, col, socket=None):
        """Write a debugging image into the remote debug buffer at pixel (row, col).

        socket -- destination OSC client; when None the image is sent to every
                  client that has requested images.
        """
        height, width = image.shape[:2]
        msg = osc_message_builder.OscMessageBuilder(address='/debug')
        msg.add_arg(height) # image pixel rows
        msg.add_arg(width)  # image pixel columns
        msg.add_arg(row) # image pixel row position in full image
        msg.add_arg(col) # image pixel column position in full image
        msg.add_arg(self.camera.capture_height) # full image pixel rows
        msg.add_arg(self.camera.capture_width)  # full image pixel columns
        msg.add_arg(image.tobytes(), arg_type='b') # blob of row-major pixel bytes
        self._send_packet(msg.build(), socket)

    #---------------------------------------------------------------
    def _advance_tile(self):
        """Step the tile cursor in row-major order and return its (row, col).

        Returns None when no whole tile fits in the image (empty tile grid).
        """
        num_tiles = self.num_tile_cols * self.num_tile_rows
        if num_tiles == 0:
            return None
        self.next_tile = (self.next_tile + 1) % num_tiles
        return divmod(self.next_tile, self.num_tile_cols)

    def _update_trackers(self, image, now):
        """Update each region tracker; on a state change notify motion clients and capture."""
        for i, tracker in enumerate(self.trackers):
            if not tracker.update(image):
                continue
            log.info("tracker %d reports motion state %s", i, tracker.motion_detected)

            # send motion data to active clients (snapshot: the server thread may add clients)
            for record in list(self.manager.clients.values()):
                if record['motion'] is True:
                    record['socket'].send_message("/motion", [i, tracker.motion_detected])

            # capture the event image (with a rate limit)
            if tracker.capture_timestamp is None or elapsed(tracker.capture_timestamp, now, 30):
                tracker.write_capture_file()

    #---------------------------------------------------------------
    def run(self):
        """Run the acquire / track / stream loop forever (until interrupted)."""
        while True:
            image = self.camera.acquire_image()

            # ignore empty images
            if image is None:
                continue

            # update the motion detection for each active region
            self._update_trackers(image, datetime.datetime.now())

            # cycle through the image tiles, sending one tile per frame to image clients
            tile = self._advance_tile()
            if tile is not None:
                next_row, next_col = tile
                for socket in self._image_sockets():
                    self.send_image_tile(socket, image, next_row, next_col)

            # check whether any connections have timed out
            self.manager.drop_inactive_clients()


#================================================================
if __name__ == '__main__':

    # Initialize the command parser.
    parser = argparse.ArgumentParser(description = cmd_desc)
    parser.add_argument("--recv", default="0.0.0.0",    help="IP address of the OSC receiver (default: %(default)s).")
    add_logging_args(parser)

    # Parse the command line, returning a Namespace.
    args = parser.parse_args()

    # Both the log file and EventLoop.write_capture_file write under 'logs';
    # make sure the folder exists before anything is opened there.
    os.makedirs('logs', exist_ok=True)

    # set up logging
    open_log_file('logs/vision-server.log')
    log.info("Starting vision_server.py")

    # Modify logging settings as per common arguments.
    configure_logging(args)

    # Open the camera.
    camera = SpinnakerCamera(args)
    camera.start_continuous()

    # Report failure to the calling shell/supervisor via a nonzero exit status.
    exit_status = 0

    if not camera.continuous_mode:
        log.error("Camera not acquiring, quitting.")
        exit_status = 1

    else:
        # Create the event loop.
        event_loop = EventLoop(camera, args)

        try:
            event_loop.run()

        except KeyboardInterrupt:
            print("User quit.")

    log.info("Exiting vision server.")
    sys.exit(exit_status)
