# mediapipe_examples/face_detection.py
# adapted from https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/face_detector/python/face_detector.ipynb
# model file: https://storage.googleapis.com/mediapipe-models/face_detector/blaze_face_short_range/float16/1/blaze_face_short_range.tflite
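#
# usage (assuming the model file above has been downloaded into the working directory):
#   python face_detection.py -i face.png detected.png --verbose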

import math
import argparse

import numpy as np
import cv2 as cv

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo MediaPipe face detection on a static image.")
    parser.add_argument('--verbose', action='store_true', help="enable more detailed console output")
    parser.add_argument('-i', '--input', default="face.png", type=str, help="input image file name (default: %(default)s)")
    parser.add_argument('out', default="detected.png", type=str, nargs='?', help="output image file name (default: %(default)s)")
    args = parser.parse_args()

    # Create a FaceDetector object.
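    # Note: model_asset_path is relative, so the model file linked at the top of
    # this script must be present in the working directory.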
    base_options = python.BaseOptions(model_asset_path='blaze_face_short_range.tflite')
    options = vision.FaceDetectorOptions(base_options=base_options)
    detector = vision.FaceDetector.create_from_options(options)

    # Load the input image.
    image = mp.Image.create_from_file(args.input)

    # Detect faces in the input image.
    detection_result = detector.detect(image)
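    # detection_result.detections holds one entry per detected face; each entry
    # carries a bounding box and a list of keypoints in normalized image coordinates.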

    if args.verbose:
        print(f"detection_result: {detection_result}")

    # Process the detection result into an annotated image.
    annotated = np.copy(image.numpy_view())  # numpy_view() is read-only, so copy before drawing
    annotated = cv.cvtColor(annotated, cv.COLOR_RGB2BGR)  # MediaPipe images are RGB; OpenCV expects BGR
    height, width, channels = annotated.shape

    for detection in detection_result.detections:
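        # With the BlazeFace short-range model each detection typically carries six
        # keypoints (eyes, nose tip, mouth center, ear tragions); treat the exact
        # set as model-dependent.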
        for pt in detection.keypoints:
            if args.verbose:
                print("Processing keypoint:", pt)
            # Keypoints are normalized to [0, 1]; convert to pixel coordinates.
            keypoint_px = math.floor(pt.x * width), math.floor(pt.y * height)
            color, thickness, radius = (0, 255, 0), 2, 2
            cv.circle(annotated, keypoint_px, radius, color, thickness)

    cv.imwrite(args.out, annotated)
