#!/usr/bin/env python3
"""camera-viewer.py : request UDP image stream from vision server and display using OpenCV
"""
import argparse
import time
import logging
import numpy as np
import datetime

# import threading

# https://docs.opencv.org/modules/highgui/doc/highgui.html
# https://docs.opencv.org/modules/highgui/doc/reading_and_writing_images_and_video.html
# https://docs.opencv.org/modules/video/doc/motion_analysis_and_object_tracking.html
import cv2 as cv

from pythonosc import udp_client
from pythonosc import dispatcher
from pythonosc import osc_server

from stage.config import theater_IP, vision_UDP_port, client_UDP_port, camera_roi

# initialize logging for this module
log = logging.getLogger('viewer')

#================================================================
def elapsed(time0, time1, interval):
    """Report whether at least `interval` seconds separate two datetimes.

    Args:
        time0: the earlier reference datetime.
        time1: the later datetime to compare against time0.
        interval: threshold in seconds (int or float).

    Returns:
        True when (time1 - time0) spans `interval` seconds or more, else False.
        A negative span (time1 before time0) therefore only passes for a
        non-positive interval.
    """
    span_seconds = (time1 - time0).total_seconds()
    return interval <= span_seconds

#================================================================
class VideoViewer:
    """Display an 8-bit single-channel image stream received over OSC/UDP.

    The constructor binds an OSC UDP server on (args.recv, client_UDP_port)
    that dispatches /image, /tile, /debug and /motion messages to the handler
    methods below. It also opens a client socket to the vision server at
    (args.ip, vision_UDP_port) and sends it /request/images and
    /request/motion. The server is not started here: the caller is expected to
    run self.server.serve_forever(), so every handler (and every OpenCV window
    update) runs on that caller's thread.

    Two frame buffers of identical size are kept: self.frame (shown in the
    "viewer" window) and self.debug (filled by /debug messages; its window is
    currently commented out).
    """

    def __init__(self, args):
        """Set up OSC I/O, the viewer window and the frame buffers.

        Args:
            args: parsed command-line namespace providing `recv` (local address
                to bind the OSC receiver to) and `ip` (address of the vision
                server to send requests to).
        """
        # Initialize the OSC message dispatch system.
        self.dispatch = dispatcher.Dispatcher()
        self.dispatch.map("/image",  self.image_message)
        self.dispatch.map("/tile",   self.tile_message)
        self.dispatch.map("/debug",  self.debug_message)
        self.dispatch.map("/motion", self.motion_message)
        self.dispatch.set_default_handler(self.unknown_message)

        # Bind the receiving server; it is run later by the caller via
        # self.server.serve_forever().
        self.server = osc_server.OSCUDPServer((args.recv, client_UDP_port), self.dispatch)
        log.info("started OSC server on port %s:%d", args.recv, client_UDP_port)

        # Create a socket to communicate to the vision system.
        self.vision = udp_client.SimpleUDPClient(args.ip, vision_UDP_port)
        log.info("Opened vision client socket to send to %s:%d", args.ip, vision_UDP_port)

        # Region of interest rectangles, specified as pairs of points in (x,y)
        # (col,row) order; drawn over every redisplayed frame.
        self.roi = camera_roi

        # Open the viewer window.
        cv.namedWindow("viewer")

        # Default frame buffer until the first message announces the real size.
        self.frame_height = 270
        self.frame_width  = 720
        self.frame = np.zeros((self.frame_height, self.frame_width), dtype=np.uint8)

        # Debug buffer (same size as the frame); its window is disabled for now.
        # cv.namedWindow("debug")
        self.debug = np.zeros((self.frame_height, self.frame_width), dtype=np.uint8)

        self.update_window()

        # Request data streams from the vision server.
        self.vision.send_message("/request/images", [])
        self.vision.send_message("/request/motion", [])
        # Time of the last /request/images, used by tile_message to renew it.
        self.last_timestamp = datetime.datetime.now()

    #---------------------------------------------------------------
    def update_window(self):
        """Show the current frame buffer and pump the OpenCV event loop.

        cv.waitKey(33) blocks for up to ~33 ms; since this is called once per
        received message, it bounds how fast messages are handled.
        NOTE(review): tiles arriving faster than ~30/s will back up in the
        socket buffer — confirm this throttling is intended.
        """
        cv.imshow("viewer", self.frame)
        # cv.imshow("debug", self.debug)

        # Calling waitKey allows the window system event loop to run and update
        # the display.  The waitKey argument is in milliseconds.
        if cv.waitKey(33) & 0xFF == ord('q'):
            # Logged only: the server loop keeps running after 'q'.
            log.info("user pressed q")

    #---------------------------------------------------------------
    def set_frame_size(self, height, width):
        """Reallocate both buffers as zeroed (height, width) uint8 arrays if the size changed.

        Existing buffer contents are discarded on a resize; nothing happens
        when the size already matches.
        """
        if self.frame_height != height or self.frame_width != width:
            log.info("Resizing frame buffer to %d, %d.", height, width)
            self.frame_height = height
            self.frame_width  = width
            self.frame = np.zeros((self.frame_height, self.frame_width), dtype=np.uint8)
            self.debug = np.zeros((self.frame_height, self.frame_width), dtype=np.uint8)

    #---------------------------------------------------------------
    @staticmethod
    def _decode_pixels(pixels, height, width):
        """Return a writable (height, width) uint8 array copied from an OSC blob.

        np.frombuffer over an immutable bytes blob yields a read-only view;
        the copy makes the result safe to keep as a frame buffer that is later
        modified in place (tile blits, ROI rectangles).

        Raises:
            ValueError: if len(pixels) != height * width.
        """
        return np.frombuffer(pixels, dtype=np.uint8).reshape(height, width).copy()

    #---------------------------------------------------------------
    def unknown_message(self, msgaddr, *args):
        """Default handler for unrecognized OSC messages: log a warning."""
        log.warning("Received unmapped OSC message %s: %s", msgaddr, args)

    #---------------------------------------------------------------
    def redraw_annotations(self):
        """Draw every ROI rectangle (pixel value 255) onto both buffers in place."""
        for region in self.roi:
            cv.rectangle(self.frame, region[0], region[1], (255))
            cv.rectangle(self.debug, region[0], region[1], (255))

    #---------------------------------------------------------------
    def image_message(self, msgaddr, *args):
        """Process /image messages received via OSC over UDP.

        Arguments: height, width, full_height, full_width (ints) and pixels
        (blob of height*width uint8 values, row-major). The whole frame buffer
        is replaced; a subsampled image is scaled up to full size with
        nearest-neighbor interpolation.
        """
        height      = args[0] # int
        width       = args[1] # int
        full_height = args[2] # int
        full_width  = args[3] # int
        pixels      = args[4] # blob

        print("received /image message with %d, %d, %d, %d, %d" % (height, width, full_height, full_width, len(pixels)))

        # extract the message bytestring as a writable image
        new_image = self._decode_pixels(pixels, height, width)

        # resize the frame buffer if needed
        self.set_frame_size(full_height, full_width)

        if height != full_height or width != full_width:
            # Scale the subsampled image into the existing buffer; cv.resize
            # takes the destination size as (width, height).
            self.frame[...] = cv.resize(new_image, (full_width, full_height), interpolation=cv.INTER_NEAREST)
        else:
            # Already full size: adopt the private, writable decoded copy
            # (never the read-only message buffer itself).
            self.frame = new_image

        # redisplay
        self.redraw_annotations()
        self.update_window()

    #---------------------------------------------------------------
    def tile_message(self, msgaddr, *args):
        """Process /tile messages received via OSC over UDP.

        Arguments: height, width, offset_rows, offset_cols, full_height,
        full_width (ints) and pixels (blob of height*width uint8 values). The
        tile is pasted into self.frame at (offset_rows, offset_cols). Every
        5 seconds the /request/images stream request is renewed.
        """
        height      = args[0] # int
        width       = args[1] # int
        offset_rows = args[2] # int
        offset_cols = args[3] # int
        full_height = args[4] # int
        full_width  = args[5] # int
        pixels      = args[6] # blob

        # print("received /tile message with %d, %d, %d, %d, %d, %d, %d" % (height, width, offset_rows, offset_cols, full_height, full_width, len(pixels)))

        # extract the message bytestring as an image
        new_image = self._decode_pixels(pixels, height, width)

        # resize the frame buffer if needed
        self.set_frame_size(full_height, full_width)

        # update a portion of the frame buffer data
        self.frame[offset_rows:offset_rows+height, offset_cols:offset_cols+width] = new_image

        # redisplay
        self.redraw_annotations()
        self.update_window()

        # occasionally renew request for image stream
        now = datetime.datetime.now()
        if elapsed(self.last_timestamp, now, 5):
            self.vision.send_message("/request/images", [])
            self.last_timestamp = now

    #---------------------------------------------------------------
    def debug_message(self, msgaddr, *args):
        """Process /debug messages received via OSC over UDP.

        Same argument layout as /tile, but the tile is pasted into the
        self.debug buffer instead of self.frame.
        """
        height      = args[0] # int
        width       = args[1] # int
        offset_rows = args[2] # int
        offset_cols = args[3] # int
        full_height = args[4] # int
        full_width  = args[5] # int
        pixels      = args[6] # blob

        print("received /debug message with %d, %d, %d, %d, %d, %d, %d" %
              (height, width, offset_rows, offset_cols, full_height, full_width, len(pixels)))

        # extract the message bytestring as an image
        new_image = self._decode_pixels(pixels, height, width)

        # resize the frame buffers if needed (this may reallocate self.debug,
        # so the buffer is only referenced afterwards)
        self.set_frame_size(full_height, full_width)

        # update a portion of the debug buffer data
        self.debug[offset_rows:offset_rows+height, offset_cols:offset_cols+width] = new_image

        # redisplay
        self.redraw_annotations()
        self.update_window()

    #---------------------------------------------------------------
    def motion_message(self, msgaddr, *args):
        """Process /motion messages received via OSC over UDP: print tracker and value."""
        tracker = args[0]
        value = args[1]
        print("received /motion message with %s %s" % (tracker, value))

#================================================================
if __name__ == "__main__":

    # enable console log stream
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
    root_logger = logging.getLogger()
    root_logger.addHandler(console_handler)

    parser = argparse.ArgumentParser(description="Receive and display images received via OSC UDP.")
    parser.add_argument('-q', '--quiet', action='store_true', help="Run without viewer window or console output.")
    parser.add_argument('--debug', action='store_true', help='Enable debugging output to log file.')
    parser.add_argument('--verbose', action='store_true', help='Enable even more detailed logging output.')
    parser.add_argument("--ip", default=theater_IP,  help="IP address of the OSC destination (default: %(default)s).")
    parser.add_argument("--recv", default="0.0.0.0",    help="IP address of the OSC receiver (default: %(default)s).")
    args = parser.parse_args()

    # The root logger defaults to WARNING, which would silently drop every
    # log.info() call in this module, so derive its level from the flags.
    # NOTE(review): --quiet only reduces logging here; the viewer window and
    # the print() diagnostics in VideoViewer are not suppressed.
    if args.debug or args.verbose:
        root_logger.setLevel(logging.DEBUG)
    elif args.quiet:
        root_logger.setLevel(logging.WARNING)
    else:
        root_logger.setLevel(logging.INFO)

    viewer = VideoViewer(args)

    try:
        log.info("Entering server loop.")
        # Blocks; all OSC handlers and window updates run on this thread.
        viewer.server.serve_forever()

    except KeyboardInterrupt:
        print("User quit.")

    finally:
        # Release the bound UDP socket and close the OpenCV window(s).
        viewer.server.server_close()
        cv.destroyAllWindows()
