1# SPDX-License-Identifier: Apache-2.0
2# -----------------------------------------------------------------------------
3# Copyright 2019-2023 Arm Limited
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy
7# of the License at:
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations
15# under the License.
16# -----------------------------------------------------------------------------
17"""
18These classes provide an abstraction around the astcenc command line tool,
19allowing the rest of the image test suite to ignore changes in the command line
20interface.
21"""
22
23import os
24import re
25import subprocess as sp
26import sys
27
28
class EncoderBase():
    """
    This class is a Python wrapper for the `astcenc` binary, providing an
    abstract means to set command line options and parse key results.

    This is an abstract base class providing some generic helper functionality
    used by concrete instantiations of subclasses. Subclasses must implement
    `build_cli()` and all of the `get_*_pattern()` methods.

    Attributes:
        binary: The encoder binary path.
        variant: The encoder SIMD variant being tested.
        name: The encoder name to use in reports.
        VERSION: The encoder version or branch.
        SWITCHES: Dict of switch replacements for different color formats.
        OUTPUTS: Dict of output file extensions for different color formats.
    """

    VERSION = None
    SWITCHES = None
    OUTPUTS = None

    def __init__(self, name, variant, binary):
        """
        Create a new encoder instance.

        Args:
            name (str): The name of the encoder.
            variant (str): The SIMD variant of the encoder.
            binary (str): The path to the binary on the file system.
        """
        self.name = name
        self.variant = variant
        self.binary = binary

    def build_cli(self, image, blockSize="6x6", preset="-thorough",
                  keepOutput=True, threads=None):
        """
        Build the command line needed for the given test.

        Args:
            image (TestImage): The test image to compress.
            blockSize (str): The block size to use.
            preset (str): The quality-performance preset to use.
            keepOutput (bool): Should the test preserve output images? This is
                only a hint and discarding output may be ignored if the encoder
                version used can't do it natively.
            threads (int or None): The thread count to use.

        Returns:
            list(str): A list of command line arguments.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        # pylint: disable=unused-argument,no-self-use,redundant-returns-doc
        # Raise rather than `assert False` so the check survives `python -O`.
        raise NotImplementedError("Missing subclass implementation")

    def execute(self, command):
        """
        Run a subprocess with the specified command.

        On failure (the binary cannot be launched, or it exits non-zero) the
        command line and any diagnostics are printed and the whole process
        exits with status 1.

        Args:
            command (list(str)): The list of command line arguments.

        Returns:
            list(str): The output log (stdout) split into lines.
        """
        # pylint: disable=no-self-use
        try:
            result = sp.run(command, stdout=sp.PIPE, stderr=sp.PIPE,
                            check=True, universal_newlines=True)
        except (OSError, sp.CalledProcessError) as ex:
            print("ERROR: Test run failed")
            print("  + %s" % " ".join(command))
            qcommand = ["\"%s\"" % x for x in command]
            print("  + %s" % ", ".join(qcommand))

            # Surface the tool's own diagnostics; without them a failure
            # only shows the command line that was run.
            if isinstance(ex, sp.CalledProcessError):
                print("  + Exit code: %d" % ex.returncode)
                if ex.stderr:
                    print(ex.stderr)
            else:
                print("  + %s" % ex)
            sys.exit(1)

        return result.stdout.splitlines()

    def parse_output(self, image, output):
        """
        Parse the log output for PSNR and performance metrics.

        Each log line is tested against every metric pattern using
        `re.match`, so patterns are anchored at the start of the line. If a
        metric appears on several lines the last match wins.

        Args:
            image (TestImage): The test image to compress.
            output (list(str)): The output log from the compression process.

        Returns:
            tuple(float, float, float, float): PSNR in dB, TotalTime in
            seconds, CodingTime in seconds, and CodingRate (as reported by the
            encoder, in MT/s for the patterns in this file).

        Raises:
            AssertionError: If any of the four metrics is missing from the log.
        """
        # Regex pattern for image quality
        patternPSNR = re.compile(self.get_psnr_pattern(image))
        patternTTime = re.compile(self.get_total_time_pattern())
        patternCTime = re.compile(self.get_coding_time_pattern())
        patternCRate = re.compile(self.get_coding_rate_pattern())

        # Extract results from the log
        runPSNR = None
        runTTime = None
        runCTime = None
        runCRate = None

        for line in output:
            match = patternPSNR.match(line)
            if match:
                runPSNR = float(match.group(1))

            match = patternTTime.match(line)
            if match:
                runTTime = float(match.group(1))

            match = patternCTime.match(line)
            if match:
                runCTime = float(match.group(1))

            match = patternCRate.match(line)
            if match:
                runCRate = float(match.group(1))

        stdout = "\n".join(output)
        assert runPSNR is not None, "No coding PSNR found %s" % stdout
        assert runTTime is not None, "No total time found %s" % stdout
        assert runCTime is not None, "No coding time found %s" % stdout
        assert runCRate is not None, "No coding rate found %s" % stdout
        return (runPSNR, runTTime, runCTime, runCRate)

    def get_psnr_pattern(self, image):
        """
        Get the regex pattern to match the image quality metric.

        Note, while this function is called PSNR for some images we may choose
        to match another metric (e.g. mPSNR for HDR images).

        Args:
            image (TestImage): The test image we are compressing.

        Returns:
            str: The string for a regex pattern.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        # pylint: disable=unused-argument,no-self-use,redundant-returns-doc
        raise NotImplementedError("Missing subclass implementation")

    def get_total_time_pattern(self):
        """
        Get the regex pattern to match the total compression time.

        Returns:
            str: The string for a regex pattern.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        # pylint: disable=unused-argument,no-self-use,redundant-returns-doc
        raise NotImplementedError("Missing subclass implementation")

    def get_coding_time_pattern(self):
        """
        Get the regex pattern to match the coding compression time.

        Returns:
            str: The string for a regex pattern.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        # pylint: disable=unused-argument,no-self-use,redundant-returns-doc
        raise NotImplementedError("Missing subclass implementation")

    def get_coding_rate_pattern(self):
        """
        Get the regex pattern to match the coding rate.

        Returns:
            str: The string for a regex pattern.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        # pylint: disable=unused-argument,no-self-use,redundant-returns-doc
        raise NotImplementedError("Missing subclass implementation")

    def run_test(self, image, blockSize, preset, testRuns, keepOutput=True,
                 threads=None):
        """
        Run the test N times.

        Args:
            image (TestImage): The test image to compress.
            blockSize (str): The block size to use.
            preset (str): The quality-performance preset to use.
            testRuns (int): The number of test runs. If less than 1 no run is
                made and the initial sentinel values are returned.
            keepOutput (bool): Should the test preserve output images? This is
                only a hint and discarding output may be ignored if the encoder
                version used can't do it natively.
            threads (int or None): The thread count to use.

        Returns:
            tuple(float, float, float, float): Returns the best results from
            the N test runs, as PSNR (dB), total time (seconds), coding time
            (seconds), and coding rate (M pixels/s). Each metric is the best
            value seen in any run, so the four values may come from
            different runs.
        """
        # pylint: disable=assignment-from-no-return
        command = self.build_cli(image, blockSize, preset, keepOutput, threads)

        # Execute test runs
        bestPSNR = 0
        bestTTime = sys.float_info.max
        bestCTime = sys.float_info.max
        bestCRate = 0

        for _ in range(0, testRuns):
            output = self.execute(command)
            result = self.parse_output(image, output)

            # Keep the best results (highest PSNR, lowest times, highest rate)
            bestPSNR = max(bestPSNR, result[0])
            bestTTime = min(bestTTime, result[1])
            bestCTime = min(bestCTime, result[2])
            bestCRate = max(bestCRate, result[3])

        return (bestPSNR, bestTTime, bestCTime, bestCRate)
230
231
class Encoder2x(EncoderBase):
    """
    This class wraps the latest `astcenc` 2.x series binaries from main branch.

    Attributes:
        VERSION: Branch label used in the reported encoder name.
        SWITCHES: Map from test image color profile to the astcenc mode switch.
        OUTPUTS: Map from test image color profile to the extension of the
            decompressed output image.
    """
    VERSION = "main"

    SWITCHES = {
        "ldr": "-tl",
        "ldrs": "-ts",
        "hdr": "-th",
        "hdra": "-tH"
    }

    OUTPUTS = {
        "ldr": ".png",
        "ldrs": ".png",
        "hdr": ".exr",
        "hdra": ".exr"
    }

    def __init__(self, variant, binary=None):
        """
        Create a new encoder instance.

        Args:
            variant (str): The SIMD variant of the encoder; "universal"
                selects the binary without a variant suffix.
            binary (str or None): The path to the binary. If None, a path
                under ./bin/ is derived from the variant, with ".exe" appended
                on Windows.
        """
        name = "astcenc-%s-%s" % (variant, self.VERSION)

        if binary is None:
            if variant != "universal":
                binary = f"./bin/astcenc-{variant}"
            else:
                binary = "./bin/astcenc"

            if os.name == 'nt':
                binary = f"{binary}.exe"

        super().__init__(name, variant, binary)

    def build_cli(self, image, blockSize="6x6", preset="-thorough",
                  keepOutput=True, threads=None):
        """
        Build the astcenc 2.x command line for the given test.

        Args:
            image (TestImage): The test image to compress.
            blockSize (str): The block size to use.
            preset (str): The quality-performance preset to use, including
                its leading "-".
            keepOutput (bool): If True, write the output image to disk,
                creating its directory; otherwise write to the null device.
            threads (int or None): The thread count to use.

        Returns:
            list(str): A list of command line arguments.
        """
        opmode = self.SWITCHES[image.colorProfile]
        srcPath = image.filePath

        if keepOutput:
            # Store outputs under <dir>/<encoder>/<preset>/<block size>/ so
            # runs with different settings do not overwrite each other.
            outPath = image.outFilePath + self.OUTPUTS[image.colorProfile]
            dstDir = os.path.join(os.path.dirname(outPath), self.name,
                                  preset[1:], blockSize)
            os.makedirs(dstDir, exist_ok=True)
            dstPath = os.path.join(dstDir, os.path.basename(outPath))
        else:
            # Discard the decompressed image ("nul" on Windows, "/dev/null"
            # elsewhere).
            dstPath = os.devnull

        command = [
            self.binary, opmode, srcPath, dstPath,
            blockSize, preset, "-silent"
        ]

        if image.colorFormat == "xy":
            command.append("-normal")

        if image.isAlphaScaled:
            command.append("-a")
            command.append("1")

        if threads is not None:
            command.append("-j")
            command.append("%u" % threads)

        return command

    def get_psnr_pattern(self, image):
        """
        Get the regex pattern to match the image quality metric.

        Args:
            image (TestImage): The test image we are compressing.

        Returns:
            str: mPSNR pattern for HDR profiles, otherwise the LDR RGB or RGBA
            PSNR pattern depending on the image color format.
        """
        # "hdra" is an HDR mode like "hdr" (it maps to -tH and writes .exr
        # output), so it must use the HDR quality metric. The previous
        # `!= "hdr"` test sent it down the LDR path, whose pattern presumably
        # never matches an HDR quality log.
        if image.colorProfile in ("hdr", "hdra"):
            return r"\s*mPSNR \(RGB\)(?: \[.*?\] )?:\s*([0-9.]*) dB.*"

        if image.colorFormat == "rgba":
            return r"\s*PSNR \(LDR-RGBA\):\s*([0-9.]*) dB"

        return r"\s*PSNR \(LDR-RGB\):\s*([0-9.]*) dB"

    def get_total_time_pattern(self):
        """
        Get the regex pattern to match the total compression time.

        Returns:
            str: The string for a regex pattern.
        """
        return r"\s*Total time:\s*([0-9.]*) s"

    def get_coding_time_pattern(self):
        """
        Get the regex pattern to match the coding compression time.

        Returns:
            str: The string for a regex pattern.
        """
        return r"\s*Coding time:\s*([0-9.]*) s"

    def get_coding_rate_pattern(self):
        """
        Get the regex pattern to match the coding rate (in MT/s).

        Returns:
            str: The string for a regex pattern.
        """
        return r"\s*Coding rate:\s*([0-9.]*) MT/s"
320
321
class Encoder2xRel(Encoder2x):
    """
    Wrapper for a released 2.x series `astcenc` binary.

    Released binaries live under ./Binaries/<version>/, and the release
    version replaces the "main" branch label in the reported encoder name.
    """
    def __init__(self, version, variant):
        # The instance attribute shadows the class-level VERSION, so the
        # parent constructor builds the report name from the release version.
        self.VERSION = version

        exeName = "astcenc" if variant == "universal" else f"astcenc-{variant}"
        exeSuffix = ".exe" if os.name == 'nt' else ""

        super().__init__(variant, f"./Binaries/{version}/{exeName}{exeSuffix}")
339
340
class Encoder1_7(EncoderBase):
    """
    Wrapper for the 1.7 series `astcenc` binary in ./Binaries/1.7/.

    NOTE(review): the timing and coding-rate patterns here are identical to
    the 2.x ones, so the 1.7 binary in use is presumably a build that prints
    2.x-style timing lines. Confirm against the binary before relying on it.
    """
    VERSION = "1.7"

    SWITCHES = {
        "ldr": "-tl",
        "ldrs": "-ts",
        "hdr": "-t"
    }

    OUTPUTS = {
        "ldr": ".tga",
        "ldrs": ".tga",
        "hdr": ".htga"
    }

    def __init__(self):
        """Create the 1.7 encoder wrapper; it has no SIMD variant."""
        exeSuffix = ".exe" if os.name == 'nt' else ""
        super().__init__(f"astcenc-{self.VERSION}", None,
                         f"./Binaries/1.7/astcenc{exeSuffix}")

    def build_cli(self, image, blockSize="6x6", preset="-thorough",
                  keepOutput=True, threads=None):
        """
        Build the astcenc 1.7 command line for the given test.

        The keepOutput hint is ignored: the output image is always written,
        under <dir>/<encoder>/<preset>/<block size>/, creating that directory.
        A "-fastest" preset is mapped to "-fast".

        Returns:
            list(str): A list of command line arguments.
        """
        preset = {"-fastest": "-fast"}.get(preset, preset)

        profile = image.colorProfile
        opmode = self.SWITCHES[profile]
        outName = image.outFilePath + self.OUTPUTS[profile]

        outDir = os.path.join(os.path.dirname(outName), self.name,
                              preset[1:], blockSize)
        os.makedirs(outDir, exist_ok=True)
        dstPath = os.path.join(outDir, os.path.basename(outName))

        command = [self.binary, opmode, image.filePath, dstPath, blockSize,
                   preset, "-silentmode", "-time", "-showpsnr"]

        if image.colorFormat == "xy":
            command.extend(["-normal_psnr"])
        if profile == "hdr":
            command.extend(["-hdr"])
        if image.isAlphaScaled:
            command.extend(["-alphablend"])
        if threads is not None:
            command.extend(["-j", "%u" % threads])

        return command

    def get_psnr_pattern(self, image):
        """Return the mPSNR pattern for HDR, else the LDR RGB/RGBA PSNR one."""
        if image.colorProfile == "hdr":
            return r"mPSNR \(RGB\)(?: \[.*?\] )?:\s*([0-9.]*) dB.*"
        if image.colorFormat == "rgba":
            return r"PSNR \(LDR-RGBA\):\s*([0-9.]*) dB"
        return r"PSNR \(LDR-RGB\):\s*([0-9.]*) dB"

    def get_total_time_pattern(self):
        """Return the pattern for the 2.x-style total time log line."""
        return r"\s*Total time:\s*([0-9.]*) s"

    def get_coding_time_pattern(self):
        """Return the pattern for the 2.x-style coding time log line."""
        return r"\s*Coding time:\s*([0-9.]*) s"

    def get_coding_rate_pattern(self):
        """Return the pattern for the 2.x-style coding rate log line."""
        return r"\s*Coding rate:\s*([0-9.]*) MT/s"
428