#!/usr/bin/env python3
#================================================================
# Import standard Python modules.
import argparse, csv
# Import the numpy and OpenCV modules.
import numpy as np
import cv2 as cv
# Documentation on scikit-learn:
# http://scikit-learn.org/stable/modules/clustering.html#clustering
# pip: scikit-learn
# MacPorts: py39-scikit-learn
import sklearn.preprocessing
import sklearn.cluster
# The following example was referenced while developing this script:
# https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab
#================================================================
# Main script follows.
def parse_arguments(argv=None):
    """Parse the command line of the palette selection script.

    Args:
        argv: Sequence of argument strings to parse. When None, argparse
            falls back to the process command line.

    Returns:
        The argparse namespace with the attributes ``verbose``, ``colors``,
        ``samples``, ``preview``, ``output`` and ``inputs``.
    """
    parser = argparse.ArgumentParser(
        description="""Select a color palette automatically from a set of images.""")
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Enable more detailed output.')
    parser.add_argument('--colors', type=int, default=32,
                        help='Number of colors in palette (default: %(default)s).')
    parser.add_argument('--samples', type=int, default=10000,
                        help='Pixels to sample from each image (default: %(default)s).')
    parser.add_argument('--preview', type=str, default='palette.png',
                        help='Path of output color table preview image (default: %(default)s).')
    parser.add_argument('--output', type=str, default='palette.csv',
                        help='Path of output color table (default: %(default)s).')
    parser.add_argument('inputs', nargs="+", help='Images to process.')
    args = parser.parse_args(argv)

    # Reject counts that cannot produce a palette before any image is read.
    if args.colors < 1:
        parser.error("--colors must be at least 1.")
    if args.samples < 1:
        parser.error("--samples must be at least 1.")
    return args


def read_sources(paths, verbose=False):
    """Read the input images, skipping any that cannot be used.

    Args:
        paths: Iterable of image file paths.
        verbose: When true, report each path and the shape of each image.

    Returns:
        List of the images that were read successfully and have three
        channels, in the order of ``paths``.
    """
    sources = []
    for path in paths:
        if verbose:
            print(f"Reading {path}.")
        img = cv.imread(path)
        # cv.imread signals failure by returning None rather than raising.
        if img is None:
            print(f"Unable to read image {path}, skipping.")
            continue
        if verbose:
            print(f"Image size: {img.shape}.")
        # A two-dimensional array has a single implicit plane.
        planes = img.shape[2] if img.ndim == 3 else 1
        if planes != 3:
            print(f"The source image needs to be three-channel RGB, {planes} channels were found.")
            continue
        sources.append(img)
    return sources


def sample_pixels(sources, count, verbose=False):
    """Draw the same number of random pixels from every source image.

    Args:
        sources: List of three-channel images.
        count: Number of pixels to draw from each image.
        verbose: When true, report images that are smaller than ``count``.

    Returns:
        Array of shape (len(sources), count, 3) holding the sampled pixels.
    """
    # Imported explicitly: the file-level imports only make sklearn.utils
    # available as a side effect of importing other sklearn subpackages.
    import sklearn.utils

    batches = []
    for src in sources:
        pixels = src.reshape((-1, 3))
        # Sampling without replacement cannot yield more pixels than the
        # image holds, so small images are sampled with replacement to keep
        # every image equally weighted in the clustering.
        too_small = pixels.shape[0] < count
        if too_small and verbose:
            print(f"Image has only {pixels.shape[0]} pixels, sampling with replacement.")
        # See https://scikit-learn.org/stable/modules/generated/sklearn.utils.resample.html
        batches.append(sklearn.utils.resample(pixels, replace=too_small, n_samples=count))
    # Assemble the Python list into a single ndarray.
    return np.stack(batches, axis=0)


def quantize_centers(centers):
    """Convert floating-point cluster centers to a one-row 8-bit image.

    Values are rounded to the nearest integer (plain truncation would bias
    every channel downward) and clipped to the 8-bit range.

    Args:
        centers: Array-like of shape (n, 3).

    Returns:
        uint8 array of shape (1, n, 3).
    """
    rounded = np.rint(np.asarray(centers, dtype=float))
    return np.clip(rounded, 0, 255).astype(np.uint8).reshape((1, -1, 3))


def find_palette(samples, colors, verbose=False):
    """Cluster the sampled pixels in CIELAB space and return the palette.

    Args:
        samples: Three-channel 8-bit BGR pixel array as produced by
            sample_pixels().
        colors: Number of clusters, i.e. palette entries.
        verbose: When true, print the pixel array shape and raw centers.

    Returns:
        uint8 BGR array of shape (1, colors, 3).
    """
    # Convert all samples to CIELAB color space.
    # https://en.wikipedia.org/wiki/CIELAB_color_space
    # https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab
    lab = cv.cvtColor(samples, cv.COLOR_BGR2Lab)
    # Reshape into a linear pixel array.
    pix = lab.reshape((-1, 3))
    if verbose:
        print(f"CIELAB pixel array shape: {pix.shape} type: {pix.dtype}")
    # Apply clustering to find representative color points.
    kmeans = sklearn.cluster.KMeans(n_clusters=colors).fit(pix)
    if verbose:
        print("Raw cluster centers:\n", kmeans.cluster_centers_)
    # Convert the cluster vectors back to BGR colors as a 2D image.
    return cv.cvtColor(quantize_centers(kmeans.cluster_centers_), cv.COLOR_Lab2BGR)


def write_palette_csv(palette, path):
    """Write the palette as CSV rows of integer R, G, B values.

    Args:
        palette: BGR array of shape (1, n, 3).
        path: Destination file path.
    """
    # newline="" lets the csv module control line endings itself, which
    # avoids blank rows between records on Windows.
    with open(path, "w", newline="") as stream:
        writer = csv.writer(stream)
        for blue, green, red in palette[0]:
            # The palette is stored in BGR order; the file is written as RGB.
            writer.writerow((int(red), int(green), int(blue)))


def main(argv=None):
    """Run the palette selection and return a process exit status.

    Args:
        argv: Optional argument list, forwarded to parse_arguments().

    Returns:
        0 on success, 1 when no usable image was found or the preview image
        could not be written.
    """
    args = parse_arguments(argv)

    # Read a list of input images.
    sources = read_sources(args.inputs, args.verbose)
    if not sources:
        print("No usable input images were found.")
        return 1

    # Sample an equal-sized subset of pixels from each flattened image.
    samples = sample_pixels(sources, args.samples, args.verbose)
    if args.verbose:
        print(f"Sample array size: {samples.shape} type: {samples.dtype}")

    palette = find_palette(samples, args.colors, args.verbose)

    status = 0
    # Upscale for the preview output.
    preview = cv.resize(palette, None, fx=32, fy=32, interpolation=cv.INTER_NEAREST)
    if not cv.imwrite(args.preview, preview):
        print(f"Unable to write preview image {args.preview}.")
        status = 1

    # Write a CSV file with integer RGB values.
    write_palette_csv(palette, args.output)
    return status


if __name__ == "__main__":
    raise SystemExit(main())