1# SPDX-License-Identifier: Apache-2.0
2# -----------------------------------------------------------------------------
3# Copyright 2019-2022 Arm Limited
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy
7# of the License at:
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations
15# under the License.
16# -----------------------------------------------------------------------------
17"""
18This module contains code for loading image metadata from a file path on disk.
19
20The directory path is structured:
21
22    TestSetName/TestFormat/FileName
23
24... and the file name is structured:
25
26    colorProfile-colorFormat-name[-flags].extension
27"""
28
29from collections.abc import Iterable
30import os
31import re
32import subprocess as sp
33
34from PIL import Image as PILImage
35
36import testlib.misc as misc
37
38
# Argument-vector prefix used to invoke ImageMagick's "convert" tool.
# Image.get_colors copies this list per call and, when os.name == "nt",
# prefixes it with "magick".
CONVERT_BINARY = ["convert"]
40
41
class ImageException(Exception):
    """Raised when a test image is missing or its path/name is malformed."""
46
47
class TestImage():
    """
    Objects of this type contain metadata for a single test image on disk.

    Attributes:
        filePath: The absolute path of the file on disk.
        outFilePath: The path of the output file on disk, without extension.
        testSet: The name of the test set.
        testFormat: The test format group.
        testFile: The test file name.
        colorProfile: The image compression color profile.
        colorFormat: The image color format.
        name: The image human name.
        is3D: True if the image is 3D, else False.
        isAlphaScaled: True if the image wants alpha scaling, else False.
        isMask: True if the image is flagged as a mask image, else False.
        TEST_EXTS: Expected test image extensions.
        PROFILES: Tuple of valid color profile values.
        FORMATS: Tuple of valid color format values.
        FLAGS: Map of valid flags (key) and their meaning (value).
    """
    TEST_EXTS = (".jpg", ".png", ".tga", ".dds", ".hdr", ".ktx")

    PROFILES = ("ldr", "ldrs", "hdr")

    FORMATS = ("l", "la", "xy", "rgb", "rgba")

    FLAGS = {
        # Flags for image compression control
        "3": "3D image",
        "m": "Mask image",
        "a": "Alpha scaled image"
    }

    def __init__(self, filePath):
        """
        Create a new image definition, based on a structured file path.

        The path is interpreted relative to ``<module dir>/../Images`` and
        must be exactly ``TestSetName/TestFormat/FileName``.

        Args:
            filePath (str): The path of the image on disk.

        Raises:
            ImageException: The image couldn't be found or is unstructured.
        """
        self.filePath = os.path.abspath(filePath)
        if not os.path.exists(self.filePath):
            raise ImageException("Image doesn't exist (%s)" % filePath)

        # Decode the path
        scriptDir = os.path.dirname(__file__)
        rootInDir = os.path.join(scriptDir, "..", "Images")
        partialPath = os.path.relpath(self.filePath, rootInDir)
        parts = misc.path_splitall(partialPath)
        if len(parts) != 3:
            raise ImageException("Image path not path triplet (%s)" % parts)
        self.testSet = parts[0]
        self.testFormat = parts[1]
        self.testFile = parts[2]

        # Decode the file name
        self.decode_file_name(self.testFile)

        # Output file path mirrors the input layout under TestOutput, and is
        # stored without an extension so callers can append their own
        rootOutDir = os.path.join(scriptDir, "..", "..", "TestOutput")
        outFilePath = os.path.join(rootOutDir, partialPath)
        outFilePath = os.path.abspath(outFilePath)
        outFilePath = os.path.splitext(outFilePath)[0]
        self.outFilePath = outFilePath

    def decode_file_name(self, fileName):
        """
        Utility function to decode metadata from an encoded file name.

        The name is ``colorProfile-colorFormat-name[-flags].extension``.
        Requires ``self.testFormat`` to already be set, because the profile
        and format are cross-checked against the directory name.

        Args:
            fileName (str): The file name to tokenize.

        Raises:
            ImageException: The image file path is badly structured.
        """
        # Strip off the extension
        rootName = os.path.splitext(fileName)[0]

        parts = rootName.split("-")

        # The profile, format, and name fields are mandatory; accepting a
        # shorter name would leave those attributes unset on the object
        if len(parts) < 3:
            raise ImageException("Image name not name triplet (%s)" % fileName)

        # Decode the mandatory fields
        self.colorProfile = parts[0]
        if self.colorProfile not in self.PROFILES:
            raise ImageException("Unknown color profile (%s)" % parts[0])

        self.colorFormat = parts[1]
        if self.colorFormat not in self.FORMATS:
            raise ImageException("Unknown color format (%s)" % parts[1])

        # Consistency check between directory and file names
        reencode = "%s-%s" % (self.colorProfile, self.colorFormat)
        compare = self.testFormat.lower()
        if reencode != compare:
            dat = (self.testFormat, reencode)
            raise ImageException("Mismatched test and image (%s:%s)" % dat)

        self.name = parts[2]

        # Set default values for the optional fields
        self.is3D = False
        self.isAlphaScaled = False
        self.isMask = False

        # Decode the flags field if present; each character is one flag
        if len(parts) >= 4:
            flags = parts[3]
            seenFlags = set()
            for flag in flags:
                if flag in seenFlags:
                    raise ImageException("Duplicate flag (%s)" % flag)
                if flag not in self.FLAGS:
                    raise ImageException("Unknown flag (%s)" % flag)
                seenFlags.add(flag)

            self.is3D = "3" in seenFlags
            self.isAlphaScaled = "a" in seenFlags
            self.isMask = "m" in seenFlags

    def get_size(self):
        """
        Get the dimensions of this test image, if format is known.

        Known cases today where the format is not known:

        * 3D .dds files.
        * Any .ktx, .hdr, .exr, or .astc file.

        Returns:
            tuple(int, int): The dimensions of a 2D image, or ``None`` if PIL
            could not open the file.
        """
        try:
            # PIL opens files lazily and keeps the handle open; the context
            # manager releases it once the header-derived size is read
            with PILImage.open(self.filePath) as img:
                return (img.size[0], img.size[1])
        except IOError:
            # HDR files
            return None
        except NotImplementedError:
            # DDS files
            return None
191
192
class Image():
    """
    Wrapper around an image on the file system.

    Pixel values are sampled by running ImageMagick as a subprocess.
    """

    # TODO: We don't support KTX yet, as ImageMagick doesn't.
    SUPPORTED_LDR = ["bmp", "jpg", "png", "tga"]
    SUPPORTED_HDR = ["exr", "hdr"]

    # Decode ImageMagick's annoying named color outputs. Note that this only
    # handles "known" cases triggered by our test images, we don't support
    # the entire ImageMagick named color table.
    _NAMED_COLORS = {
        "black": (0.0, 0.0, 0.0, 1.0),
        "white": (1.0, 1.0, 1.0, 1.0),
        "red": (1.0, 0.0, 0.0, 1.0),
        "blue": (0.0, 0.0, 1.0, 1.0),
    }

    # Number of comma-separated fields expected for each ImageMagick
    # "model(...)" color tuple that we know how to decode
    _COLOR_MODELS = {
        "srgba": 4,
        "srgb": 3,
        "rgba": 4,
        "rgb": 3,
        "graya": 2,
        "gray": 1,
    }

    @classmethod
    def is_format_supported(cls, fileFormat, profile=None):
        """
        Test if a given file format is supported by the library.

        Args:
            fileFormat (str): The file extension (excluding the ".").
            profile (str or None): The profile (ldr, ldrs, or hdr) of the
                image, or None to accept any supported LDR or HDR format.
                The sRGB "ldrs" profile uses the same formats as "ldr".

        Returns:
            bool: `True` if the image is supported, `False` otherwise.
        """
        assert profile in [None, "ldr", "ldrs", "hdr"]

        if profile in ("ldr", "ldrs"):
            return fileFormat in cls.SUPPORTED_LDR

        if profile == "hdr":
            return fileFormat in cls.SUPPORTED_HDR

        return fileFormat in cls.SUPPORTED_LDR or \
            fileFormat in cls.SUPPORTED_HDR

    def __init__(self, filePath):
        """
        Construct a new Image.

        Args:
            filePath (str): The path to the image on disk.
        """
        self.filePath = filePath
        self.proxyPath = None

    @staticmethod
    def _decode_channel(value, isColor):
        """
        Convert one ImageMagick channel string to a float.

        Percentages map to 0.0-1.0 for every channel. Bare numbers are
        divided by 255 for color channels, and used as-is for alpha.

        Args:
            value (str): The channel text, e.g. "255", "50%", or "0.5".
            isColor (bool): True for a color channel, False for alpha.

        Returns:
            float: The decoded channel value.
        """
        value = value.strip()
        if value.endswith("%"):
            return float(value[:-1]) / 100.0
        if isColor:
            return float(value) / 255.0
        return float(value)

    @classmethod
    def _decode_color(cls, rawcolor):
        """
        Decode one ImageMagick ``%[pixel:...]`` result into RGBA floats.

        Args:
            rawcolor (str): The stripped ImageMagick output.

        Returns:
            list(float): A new [r, g, b, a] list, or ``None`` if the output
            is not a recognized named color or color tuple.
        """
        named = cls._NAMED_COLORS.get(rawcolor)
        if named is not None:
            return list(named)

        model, sep, body = rawcolor.partition("(")
        fieldCount = cls._COLOR_MODELS.get(model)
        if not sep or fieldCount is None or not body.endswith(")"):
            return None

        fields = body[:-1].split(",")
        if len(fields) != fieldCount:
            return None

        # Gray models carry one color channel, RGB models carry three; any
        # trailing field is alpha
        colorCount = 1 if model.startswith("gray") else 3
        channels = [cls._decode_channel(value, i < colorCount)
                    for i, value in enumerate(fields)]

        if colorCount == 1:
            channels = [channels[0]] * 3 + channels[1:]

        # Models without an alpha field are fully opaque
        if len(channels) == 3:
            channels.append(1.0)

        return channels

    def get_colors(self, coords):
        """
        Get the image colors at the given coordinate.

        Args:
            coords (tuple or list): A single coordinate, or a list of
                coordinates to sample.

        Returns:
            tuple: A single sample color (if `coords` was a coordinate).
            list: A list of sample colors (if `coords` was a list); an empty
            list of coordinates returns an empty list.

            Colors are returned as float values between 0.0 and 1.0 for LDR,
            and float values which may exceed 1.0 for HDR.

        Raises:
            subprocess.CalledProcessError: ImageMagick failed.
            ValueError: ImageMagick returned a color we cannot decode.
        """
        # Nothing to sample; avoids unpacking an empty coordinate below
        if len(coords) == 0:
            return []

        # We accept both a list of positions and a single position;
        # canonicalize here so the main processing only handles lists
        isList = isinstance(coords[0], Iterable)
        if not isList:
            coords = [coords]

        # Command prefix shared by every sample; -auto-orient ensures convert
        # factors in the format origin before pixels are indexed
        baseCommand = list(CONVERT_BINARY) + [self.filePath, "-auto-orient"]
        if os.name == "nt":
            baseCommand.insert(0, "magick")

        colors = []
        for (x, y) in coords:
            command = baseCommand + [
                "-format", "%%[pixel:p{%u,%u}]" % (x, y),
                "info:"
            ]

            result = sp.run(command, stdout=sp.PIPE, stderr=sp.PIPE,
                            check=True, universal_newlines=True)

            rawcolor = result.stdout.strip()
            color = self._decode_color(rawcolor)
            if color is None:
                raise ValueError("Unknown ImageMagick color (%s) at (%s, %s)"
                                 % (rawcolor, x, y))
            colors.append(color)

        extension = os.path.splitext(self.filePath)[1].lower()

        # ImageMagick decodes DDS files as BGRA not RGBA; manually correct
        if extension == ".dds":
            for color in colors:
                color[0], color[2] = color[2], color[0]

        # ImageMagick decodes EXR files with premult alpha; manually correct.
        # Fully transparent samples have no recoverable color, so skip them
        # rather than dividing by zero.
        if extension == ".exr":
            for color in colors:
                alpha = color[3]
                if alpha != 0.0:
                    color[0] /= alpha
                    color[1] /= alpha
                    color[2] /= alpha

        # Undo list canonicalization if we were given a single scalar coord
        if not isList:
            return colors[0]

        return colors
362