1#!/usr/bin/env python3
2
3#================================================================
4# Import standard Python modules.
import argparse
import csv
import sys
6
7# Import the numpy and OpenCV modules.
8import numpy as np
9import cv2 as cv
10
11# Documentation on scikit-learn:
12# http://scikit-learn.org/stable/modules/clustering.html#clustering
13# pip: scikit-learn
14# MacPorts: py39-scikit-learn
15
16import sklearn.preprocessing
17import sklearn.cluster
18
19# The following example was referenced while developing this script:
20# https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab
21
22#================================================================
23# Main script follows.
def read_images(paths, verbose=False):
    """Load the three-channel BGR images named by *paths*.

    Files that cannot be read, and images that do not have exactly three
    channels, are reported on stderr and skipped so that one bad input does
    not abort the whole run.

    :param paths: iterable of image file paths.
    :param verbose: print progress and image sizes when true.
    :returns: list of image arrays of shape (rows, cols, 3), in BGR order as
        returned by cv.imread.
    """
    images = []
    for path in paths:
        if verbose:
            print(f"Reading {path}.")
        img = cv.imread(path)
        # cv.imread reports failure by returning None rather than raising.
        if img is None:
            print(f"Unable to read image {path}, skipping it.", file=sys.stderr)
            continue
        if verbose:
            print(f"Image size: {img.shape}.")
        planes = img.shape[2] if img.ndim == 3 else 1
        if planes != 3:
            print(f"The source image needs to be three-channel RGB, {planes} channels were found.",
                  file=sys.stderr)
            continue
        images.append(img)
    return images


def sample_pixels(images, count, rng=None):
    """Draw *count* pixels from every image and stack them into one array.

    Each image contributes the same number of pixels, so large images do not
    dominate the palette. Pixels are drawn without replacement when an image
    has at least *count* pixels and with replacement otherwise, so small
    images are still usable instead of raising.

    :param images: non-empty sequence of arrays of shape (rows, cols, 3).
    :param count: number of pixels to draw from each image (>= 1).
    :param rng: numpy Generator, integer seed, or None for a fresh generator.
    :returns: array of shape (len(images), count, 3) with the images' dtype.
    :raises ValueError: if *images* is empty.
    """
    if len(images) == 0:
        raise ValueError("no images to sample from")
    rng = np.random.default_rng(rng)
    rows = []
    for img in images:
        flat = img.reshape((-1, 3))
        n_pixels = flat.shape[0]
        picks = rng.choice(n_pixels, size=count, replace=n_pixels < count)
        rows.append(flat[picks])
    # Stacking keeps a 2-D, three-channel layout that cv.cvtColor accepts.
    return np.stack(rows, axis=0)


def lab_centers_to_uint8(centers):
    """Round float cluster centers to the nearest valid 8-bit channel value.

    A plain astype(np.uint8) would truncate (e.g. 127.9 -> 127), biasing
    every color downward; values are also clipped to 0..255 before casting.
    """
    return np.clip(np.rint(centers), 0, 255).astype(np.uint8)


def write_palette_csv(stream, palette):
    """Write one "R,G,B" integer row per color of a BGR palette.

    :param stream: text stream; files should be opened with newline=""
        as the csv module requires.
    :param palette: array of shape (k, 3) holding BGR colors.
    """
    writer = csv.writer(stream)
    for blue, green, red in palette:
        writer.writerow((int(red), int(green), int(blue)))


def main(argv=None):
    """Select a color palette by clustering sampled image pixels in CIELAB.

    :param argv: argument list to parse; defaults to sys.argv[1:].
    :returns: process exit status (0 on success, 1 if an output failed).
    """
    parser = argparse.ArgumentParser(description="""Select a color palette automatically from a set of images.""")
    parser.add_argument('-v', '--verbose', action='store_true', help='Enable more detailed output.')
    parser.add_argument('--colors', type=int, default=32,
                        help='Number of colors in palette (default: %(default)s).')
    parser.add_argument('--samples', type=int, default=10000,
                        help='Pixels to sample from each image (default: %(default)s).')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible palettes (default: nondeterministic).')
    parser.add_argument('--preview', type=str, default='palette.png',
                        help='Path of output color table preview image (default: %(default)s).')
    parser.add_argument('--output', type=str, default='palette.csv',
                        help='Path of output color table (default: %(default)s).')
    parser.add_argument('inputs', nargs="+", help='Images to process.')
    args = parser.parse_args(argv)

    if args.colors < 1:
        parser.error("--colors must be at least 1")
    if args.samples < 1:
        parser.error("--samples must be at least 1")

    sources = read_images(args.inputs, verbose=args.verbose)
    if not sources:
        parser.error("no usable three-channel input images were found")

    # Equal-sized pixel subset per image, shaped (images, samples, 3).
    samples = sample_pixels(sources, args.samples, np.random.default_rng(args.seed))
    if args.verbose:
        print(f"Sample array size: {samples.shape} type: {samples.dtype}")

    total = samples.shape[0] * samples.shape[1]
    if total < args.colors:
        parser.error(f"only {total} pixels sampled, fewer than --colors={args.colors}")

    # Cluster in CIELAB so distances roughly track perceived color difference.
    # https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab
    lab = cv.cvtColor(samples, cv.COLOR_BGR2Lab)
    pix = lab.reshape((-1, 3))
    if args.verbose:
        print(f"CIELAB pixel array shape: {pix.shape} type: {pix.dtype}")

    kmeans = sklearn.cluster.KMeans(n_clusters=args.colors, random_state=args.seed).fit(pix)
    if args.verbose:
        print("Raw cluster centers:\n", kmeans.cluster_centers_)

    # Convert the cluster vectors back to BGR colors as a 1 x colors image.
    centers = lab_centers_to_uint8(kmeans.cluster_centers_).reshape((1, args.colors, 3))
    palette = cv.cvtColor(centers, cv.COLOR_Lab2BGR)

    status = 0
    # Upscale each palette entry to a 32x32 swatch for the preview output.
    preview = cv.resize(palette, None, fx=32, fy=32, interpolation=cv.INTER_NEAREST)
    if not cv.imwrite(args.preview, preview):
        print(f"Unable to write preview image {args.preview}.", file=sys.stderr)
        status = 1

    with open(args.output, "w", newline="") as stream:
        write_palette_csv(stream, palette[0])
    return status


if __name__ == "__main__":
    sys.exit(main())
91
92
93