#!/usr/bin/env python3

"""\
midi_to_mqtt.py : sample code in Python to translate MIDI events into MQTT messages

No copyright, 2021, Garth Zeglin.  This file is explicitly placed in the public domain.

"""
#----------------------------------------------------------------
# Configuration

# The following variables must be customized to set up the network and MIDI
# port settings.

# Host name of the IDeATe MQTT broker.
mqtt_hostname = "mqtt.ideate.cmu.edu"

# Broker port.  Each course is assigned its own port; see
# https://mqtt.ideate.cmu.edu for the list.
mqtt_portnum = 8884   # 16-223

# Course credentials, issued by the instructor.  Both must be filled in
# before the script will run.
mqtt_username = ''
mqtt_password = ''

# Base topic for everything this script publishes (normally the student's
# Andrew ID).
mqtt_topic = 'orchestra'

# Topic to subscribe to for incoming messages (normally the partner's
# Andrew ID).
mqtt_subscription = 'unspecified'

# Text to look for in the MIDI input port names.  This default matches the
# Akai MPD218 drum pad as it appears on macOS.
midi_portname = 'MPD218 Port A'

#----------------------------------------------------------------
# Standard Python modules.
import sys, time, signal, platform

#----------------------------------------------------------------
if '' in (mqtt_username, mqtt_password, mqtt_topic, midi_portname):
    # The shipped placeholders are empty strings, so there is nothing to
    # connect with.  Report this on stderr and exit with a nonzero status.
    # That way a shell or launcher can tell a misconfigured run apart from a
    # successful one.
    print("""\
This script must be customized before it can be used.  Please edit the file with
a Python or text editor and set the variables appropriately in the Configuration
section at the top of the file.
""", file=sys.stderr)
    sys.exit(1)
    
#----------------------------------------------------------------

# Import the MQTT client library.
# documentation: https://www.eclipse.org/paho/clients/python/docs/
import paho.mqtt.client as mqtt

# For documentation on python-rtmidi: https://pypi.org/project/python-rtmidi/
import rtmidi

#----------------------------------------------------------------
class MidiMqtt(object):
    """Class to manage a MIDI input connection and process MIDI events into MQTT messages.

    The constructor opens a MIDI input and registers midi_received() as its
    callback.  Each incoming message is decoded by decode_message() and passed
    to a handler for its message type.

    Note-on and control-change messages are republished through the MQTT
    client under the module-level ``mqtt_topic``.  The MPD218 helpers turn pad
    notes and knob controllers into grid positions, so that messages can also
    be routed to per-instrument and per-project subtopics.
    """

    def __init__(self, portname, client):
        """Open a MIDI input and start forwarding its messages.

        The first available input port whose name contains ``portname`` is
        opened.  If none matches, a virtual input named ``portname`` is
        created instead, except on Windows, where virtual inputs are not
        supported.

        :param portname: substring to look for in the MIDI input port names;
                         also used as the name of the fallback virtual port
        :param client: MQTT client used for output; only its
                       ``publish(topic, payload)`` method is called
        """
        # save the client handle
        self.client = client

        print(f"Opening MIDI input to look for {portname}.")

        # Initialize the MIDI input system and read the currently available ports.
        self.midi_in = rtmidi.MidiIn()

        # Register the callback before any port is opened, so it applies to
        # whichever port ends up open: the matched hardware port or the
        # virtual-port fallback below.  RtMidi recommends setting the
        # callback before opening a port.
        self.midi_in.set_callback(self.midi_received)

        for idx, name in enumerate(self.midi_in.get_ports()):
            if portname in name:
                print("Found preferred MIDI input device %d: %s" % (idx, name))
                self.midi_in.open_port(idx)
                break
            else:
                print("Ignoring unselected MIDI device: ", name)

        if not self.midi_in.is_port_open():
            if platform.system() == 'Windows':
                print("Virtual MIDI inputs are not currently supported on Windows, see python-rtmidi documentation.")
            else:
                print("Creating virtual MIDI input.")
                self.midi_in.open_virtual_port(portname)

    def decode_mpd218_key(self, key):
        """Interpret a MPD218 pad event key value as a row, column, and bank position.
        Row 0 is the front/bottom row (Pads 1-4), row 3 is the back/top row (Pads 13-16).
        Column 0 is the left, column 3 is the right.
        Bank 0 is the A bank, bank 1 is the B bank, bank 2 is the C bank.

        Note 36 is pad 1 of bank A.  Keys below 36 produce a negative bank.

        :param key: an integer MIDI note value
        :return: (row, column, bank)
        """
        # Decode the key into coordinates on the 4x4 pad grid.
        bank = (key - 36) // 16
        pos = (key - 36) % 16
        row = pos // 4
        col = pos % 4
        return row, col, bank

    def decode_mpd218_cc(self, cc):
        """Interpret a MPD218 knob control change event as a knob index and bank position.
        The MPD218 uses a non-contiguous set of channel indices so this normalizes the result.
        The knob index ranges from 1 to 6 matching the knob labels.
        Bank 0 is the A bank, bank 1 is the B bank, bank 2 is the C bank.

        Bank A knobs use controllers 3, 9, 12, 13, 14 and 15.  Any other
        controller below 16 is not a knob, and ``knob`` is returned as None.
        Controllers 16 and up are treated as consecutive groups of six
        (16-21 are bank B, 22-27 are bank C).

        :param cc: an integer MIDI control channel identifier
        :return: (knob, bank), where knob may be None for an unmapped controller
        """
        if cc < 16:
            knob = {3:1, 9:2, 12:3, 13:4, 14:5, 15:6}.get(cc)
            bank = 0
        else:
            knob = 1 + ((cc - 16) % 6)
            bank = 1 + ((cc - 16) // 6)
        return knob, bank

    def decode_message(self, message):
        """Decode a MIDI message expressed as a list of integers and perform callbacks
        for recognized message types.

        Only channel pressure (2 bytes), and note on/off, control change and
        polyphonic key pressure (3 bytes) are dispatched.  Anything else is
        ignored.

        :param message: list of integers containing a single MIDI message
        :return: whatever the dispatched handler returns, otherwise None
        """
        if len(message) > 0:
            status = message[0] & 0xf0
            channel = (message[0] & 0x0f) + 1

            if len(message) == 2:
                if status == 0xd0: # == 0xdx, channel pressure, any channel
                    return self.channel_pressure(channel, message[1])

            elif len(message) == 3:
                if status == 0x90: # == 0x9x, note on, any channel
                    return self.note_on(channel, message[1], message[2])

                elif status == 0x80: # == 0x8x, note off, any channel
                    return self.note_off(channel, message[1], message[2])

                elif status == 0xb0: # == 0xbx, control change, any channel
                    return self.control_change(channel, message[1], message[2])

                elif status == 0xa0: # == 0xax, polyphonic key pressure, any channel
                    return self.polyphonic_key_pressure(channel, message[1], message[2])

    def midi_received(self, data, unused):
        """python-rtmidi input callback: decode one incoming MIDI message.

        :param data: a ``(message, delta_time)`` pair, where ``message`` is
                     the list of MIDI bytes; ``delta_time`` is not used here
        :param unused: the optional user data given to set_callback() (unused)
        """
        msg, delta_time = data
        self.decode_message(msg)

    def note_off(self, channel, key, velocity):
        """Function to receive messages starting with 0x80 through 0x8F.

        Currently ignored.

        :param channel: integer from 1 to 16
        :param key: integer from 0 to 127
        :param velocity: integer from 0 to 127
        """
        pass

    def note_on(self, channel, key, velocity):
        """Function to receive messages starting with 0x90 through 0x9F.

        Publishes "<key> <velocity>" to ``<mqtt_topic>/noteOn``.  It also
        publishes "<row> <velocity>" to ``<mqtt_topic>/instrument<column+1>``,
        using the MPD218 pad grid position (so the instrument is 1 to 4).

        :param channel: integer from 1 to 16
        :param key: integer from 0 to 127
        :param velocity: integer from 0 to 127
        """
        topic = mqtt_topic + '/noteOn'
        payload = "%d %d" % (key, velocity)
        self.client.publish(topic, payload)

        # apply some logic to route control inputs to other topics
        row, column, bank = self.decode_mpd218_key(key)
        topic = mqtt_topic + '/instrument%d' % (column+1)
        payload = "%d %d" % (row, velocity)
        self.client.publish(topic, payload)

    def polyphonic_key_pressure(self, channel, key, value):
        """Function to receive messages starting with 0xA0 through 0xAF.

        Currently ignored.

        :param channel: integer from 1 to 16
        :param key: integer from 0 to 127
        :param value: integer from 0 to 127
        """
        pass

    def control_change(self, channel, control, value):
        """Function to receive messages starting with 0xB0 through 0xBF.

        Publishes "<control> <value>" to ``<mqtt_topic>/controlChange``.  For
        controllers that map to an MPD218 knob, it also publishes "<value>" to
        ``<mqtt_topic>/project<knob>``.  The knob bank is not used for routing.

        :param channel: integer from 1 to 16
        :param control: integer from 0 to 127; some have special meanings
        :param value: integer from 0 to 127
        """
        topic = mqtt_topic + '/controlChange'
        payload = "%d %d" % (control, value)
        self.client.publish(topic, payload)

        # apply some logic to route control inputs to other topics
        knob, bank = self.decode_mpd218_cc(control)
        if knob is None:
            # A controller below 16 that is not an MPD218 knob has no project
            # topic.  Formatting None with %d would raise TypeError inside the
            # MIDI callback thread.
            return
        topic = mqtt_topic + '/project%d' % (knob)
        payload = "%d" % (value)
        self.client.publish(topic, payload)


    def channel_pressure(self, channel, value):
        """Function to receive messages starting with 0xD0 through 0xDF.

        Currently ignored.

        :param channel: integer from 1 to 16
        :param value: integer from 0 to 127
        """
        pass

#----------------------------------------------------------------
# Global script variables.

# Handles used by the control-C handler.  Both stay None until the MQTT
# client and the MIDI bridge are created further down.
client = None
midi_port = None
    
#----------------------------------------------------------------
# Attach a handler to the keyboard interrupt (control-C).
def _sigint_handler(signum, frame):
    """Shut the bridge down when the user presses control-C.

    Closes the MIDI input (if the bridge was created), disconnects from the
    MQTT broker and stops the paho network thread, then exits with status 0.

    :param signum: number of the signal being handled (SIGINT)
    :param frame: interrupted stack frame (unused)
    """
    print("Keyboard interrupt caught, closing down...")
    # Stop MIDI input first so the callback thread publishes nothing more
    # while the MQTT connection is being torn down.
    if midi_port is not None:
        midi_port.midi_in.close_port()
    if client is not None:
        # Send an MQTT DISCONNECT for a clean shutdown, then wait for the
        # background network thread started by loop_start() to finish.
        client.disconnect()
        client.loop_stop()
    sys.exit(0)
signal.signal(signal.SIGINT, _sigint_handler)

#----------------------------------------------------------------
# MQTT networking functions.

#----------------------------------------------------------------
# The callback for when the broker responds to our connection request.
def on_connect(client, userdata, flags, rc):
    """paho-mqtt callback run when the broker answers our connection request.

    On success (``rc == 0``), subscribes to ``mqtt_subscription``.  Doing this
    here rather than once at startup means the subscription is renewed every
    time the client reconnects.  On failure, the result code is reported and
    no subscription is attempted.

    :param client: the paho client instance
    :param userdata: user data configured on the client (unused)
    :param flags: response flags sent by the broker
    :param rc: connection result code; 0 means the connection was accepted
    """
    if rc != 0:
        # A refused connection cannot carry a subscription; just report it.
        print(f"MQTT connection refused, result code: {rc}")
        return

    print(f"MQTT connected with flags: {flags}, result code: {rc}")
    client.subscribe(mqtt_subscription)

#----------------------------------------------------------------
# The callback for when a message has been received on a topic to which this
# client is subscribed.  The message variable is a MQTTMessage that describes
# all of the message parameters.

# Some useful MQTTMessage fields: topic, payload, qos, retain, mid, properties.
#   The payload is a binary string (bytes).
#   qos is an integer quality of service indicator (0,1, or 2)
#   mid is an integer message ID.
def on_message(client, userdata, msg):
    """Print an incoming MQTT message to the console.

    :param client: the paho client instance that received the message
    :param userdata: user data configured on the client (unused)
    :param msg: the received message; its topic and raw payload are printed
    """
    topic, payload = msg.topic, msg.payload
    print(f"message received: topic: {topic} payload: {payload}")

#----------------------------------------------------------------
# Launch the MQTT network client
client = mqtt.Client()
client.enable_logger()

# Attach the callbacks before connecting so the first broker response is
# handled.
client.on_connect = on_connect
client.on_message = on_message

# Encrypted connection, authenticated with the course credentials.
client.tls_set()
client.username_pw_set(username=mqtt_username, password=mqtt_password)

# connect_async() leaves the actual connection to the network thread that
# loop_start() launches, so the script does not block here.
client.connect_async(host=mqtt_hostname, port=mqtt_portnum)
client.loop_start()

#----------------------------------------------------------------
# Connect to the MIDI port.  MIDI events will be received on a separate thread.
midi_port = MidiMqtt(midi_portname, client)

# The MIDI callback and the MQTT network client run on their own threads.
# The main thread only needs to stay alive, so it idles, waking once a minute.
while True:
    time.sleep(60)
