Image Analysis in Python with SciKit

Full Code

#!/usr/bin/env python3

#================================================================
# Import standard Python modules.
import argparse, csv

# Import the numpy and OpenCV modules.
import numpy as np
import cv2 as cv

# Documentation on scikit-learn:
#   http://scikit-learn.org/stable/modules/clustering.html#clustering
#   pip:      scikit-learn
#   MacPorts: py39-scikit-learn

import sklearn.preprocessing
import sklearn.cluster

# The following example was referenced while developing this script:
# https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab

#================================================================
# Main script follows.
if __name__ == "__main__":
    # Command-line interface: palette size, per-image sample count, the two
    # output paths, and one or more input image paths.
    parser = argparse.ArgumentParser(description = """Select a color palette automatically from a set of images.""")
    parser.add_argument( '-v', '--verbose', action='store_true', help='Enable more detailed output.' )
    parser.add_argument( '--colors', type=int, default=32, help='Number of colors in palette (default: %(default)s).')
    parser.add_argument( '--samples', type=int, default=10000, help='Pixels to sample from each image (default: %(default)s).')
    parser.add_argument( '--preview', type=str, default='palette.png',
                         help='Path of output color table preview image (default: %(default)s).')
    parser.add_argument( '--output', type=str, default='palette.csv',
                         help='Path of output color table (default: %(default)s).')
    parser.add_argument( 'inputs', nargs="+", help = 'Images to process.')

    args = parser.parse_args()

    # Reject counts that cannot produce a palette before doing any work.
    if args.colors < 1:
        parser.error("--colors must be at least 1.")
    if args.samples < 1:
        parser.error("--samples must be at least 1.")

    # Read the list of input images, skipping any that cannot be used.
    sources = []
    for path in args.inputs:
        if args.verbose:
            print(f"Reading {path}.")
        img = cv.imread(path)
        # cv.imread reports a missing or undecodable file by returning None
        # rather than raising, so check before touching img.shape.
        if img is None:
            print(f"Unable to read {path} as an image, skipping it.")
            continue
        if args.verbose:
            print(f"Image size: {img.shape}.")
        # A two-dimensional array is a single-plane (grayscale) image.
        planes = img.shape[2] if img.ndim == 3 else 1
        if planes != 3:
            print(f"The source image needs to be three-channel RGB, {planes} channels were found.")
        else:
            sources.append(img)

    if not sources:
        parser.error("none of the inputs could be read as a three-channel image.")

    # Sample an equal-sized subset of pixels from each flattened image.
    # Pixels are drawn without replacement when the image has enough of them;
    # a smaller image is sampled with replacement so that every image still
    # contributes exactly args.samples rows and the arrays can be stacked.
    rng = np.random.default_rng()
    samples = []
    for src in sources:
        flat = src.reshape((-1, 3))
        count = flat.shape[0]
        picks = rng.choice(count, size=args.samples, replace=count < args.samples)
        samples.append(flat[picks])

    # Assemble the Python list into a single ndarray of shape
    # (number of images, args.samples, 3).
    samples = np.stack(samples, axis=0)

    if args.verbose:
        print(f"Sample array size: {samples.shape} type: {samples.dtype}")

    # Convert all samples to CIELAB color space.  The stacked samples are
    # handed to OpenCV as if they were one image with a row per source.
    #   https://en.wikipedia.org/wiki/CIELAB_color_space
    #   https://docs.opencv.org/4.5.3/de/d25/imgproc_color_conversions.html#color_convert_rgb_lab
    lab = cv.cvtColor(samples, cv.COLOR_BGR2Lab)

    # Reshape into a linear pixel array.
    pix = lab.reshape((-1, 3))
    if args.verbose:
        print(f"CIELAB pixel array shape: {pix.shape} type: {pix.dtype}")

    # Clustering needs at least as many points as requested clusters.
    if pix.shape[0] < args.colors:
        parser.error(f"--colors ({args.colors}) exceeds the number of sampled pixels ({pix.shape[0]}).")

    # Apply clustering to find representative color points.
    kmeans = sklearn.cluster.KMeans(n_clusters=args.colors).fit(pix)

    if args.verbose:
        print("Raw cluster centers:\n", kmeans.cluster_centers_)

    # Convert the cluster vectors back to RGB colors as a 2D image.  The
    # centers are floating-point means, so round to the nearest 8-bit value
    # instead of truncating toward zero.
    centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
    centers = centers.reshape((1, -1, 3))
    palette = cv.cvtColor(centers, cv.COLOR_Lab2BGR)

    # Upscale for the preview output.
    preview = cv.resize(palette, None, fx=32, fy=32, interpolation=cv.INTER_NEAREST)
    if not cv.imwrite(args.preview, preview):
        print(f"Unable to write preview image {args.preview}.")

    # Write a CSV file with integer RGB values.  The csv module requires the
    # file to be opened with newline="" so it controls the line endings itself.
    with open(args.output, "w", newline="") as stream:
        writer = csv.writer(stream)
        for color in palette[0,:]:
            # OpenCV stores channels as B, G, R; emit them in R, G, B order.
            writer.writerow((int(color[2]), int(color[1]), int(color[0])))