1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# -----------------------------------------------------------------------------
4# Copyright 2020-2022 Arm Limited
5#
6# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7# use this file except in compliance with the License. You may obtain a copy
8# of the License at:
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15# License for the specific language governing permissions and limitations
16# under the License.
17# -----------------------------------------------------------------------------
18"""
19The ``astc_test_result_plot.py`` script consolidates all current sets of
20reference results into a single graphical plot.
21"""
22
23import re
24import os
25import sys
26
27import numpy as np
28import matplotlib.pyplot as plt
29
30import testlib.resultset as trs
31from collections import defaultdict as ddict
32
33
def find_reference_results(imageDir=None):
    """
    Scrape an image directory for reference result CSV files and return a
    mapping of the result sets.

    Only files whose whole name matches the reference result naming pattern,
    and which sit directly inside a directory named after one of the known
    image sets, are loaded.

    Args:
        imageDir (str): Root directory to scan. If not set, the ``Images``
            directory alongside this script is used.

    Returns:
        Returns a three deep tree of dictionaries, with the final dict
        pointing at a `ResultSet` object. The hierarchy is:

            imageSet => quality => encoder => result
    """
    if imageDir is None:
        scriptDir = os.path.dirname(__file__)
        imageDir = os.path.join(scriptDir, "Images")

    # Image sets that are included in the charts
    imageSets = ("Kodak", "Khronos", "HDRIHaven", "KodakSim")

    # Pattern for extracting useful data from the CSV file name
    filePat = re.compile(r"astc_reference-(.+)_(.+)_results\.csv")

    # Build a three level dictionary we can write into
    results = ddict(lambda: ddict(lambda: ddict()))

    # Find all CSVs, load them and store them in the dict tree
    for root, _dirs, files in os.walk(imageDir):
        imageSet = os.path.basename(root)
        if imageSet not in imageSets:
            continue

        for name in files:
            # Match the whole name so that stray copies of a result file,
            # such as "*_results.csv.bak", cannot replace the real results
            match = filePat.fullmatch(name)
            if not match:
                continue

            encoder, quality = match.groups()

            testRef = trs.ResultSet(imageSet)
            testRef.load_from_file(os.path.join(root, name))

            results[imageSet][quality]["ref-%s" % encoder] = testRef

    return results
74
75
def get_series(results, tgtEncoder, tgtQuality, resFilter=lambda x: True):
    """
    Collect the data series for one encoder at one quality level.

    Records are gathered across all image sets in the result tree, in the
    order the image sets and records are stored.

    Args:
        results: Result tree (imageSet => quality => encoder => result).
        tgtEncoder (str): Encoder key to select.
        tgtQuality (str): Quality key to select.
        resFilter: Predicate called with each record; only records for which
            it returns a true value are kept.

    Returns:
        tuple: Four parallel lists; the record ``cRate`` values, the record
        ``psnr`` values, the plot marker for each record, and the records.
    """
    # Record name fragment => plot marker; the first matching entry wins, so
    # the more specific fragments must come before the plain "ldr" entry
    markerTable = (
        ("ldr-xy", '$N$'),
        ("ldr-l", '$G$'),
        ("ldr", '$L$'),
        ("hdr", '$H$'),
    )

    records = []
    for imageResults in results.values():
        if tgtQuality not in imageResults:
            continue

        encoderResults = imageResults[tgtQuality]
        if tgtEncoder not in encoderResults:
            continue

        resultSet = encoderResults[tgtEncoder]
        records.extend(rec for rec in resultSet.records if resFilter(rec))

    mtsData = [rec.cRate for rec in records]
    psnrData = [rec.psnr for rec in records]
    marker = [
        next((glyph for tag, glyph in markerTable if tag in rec.name), '$?$')
        for rec in records
    ]

    return mtsData, psnrData, marker, records
111
112
def get_series_rel(results, refEncoder, refQuality, tgtEncoder, tgtQuality, resFilter=lambda x: True):
    """
    Collect the data series for one encoder and quality level, expressed
    relative to a reference encoder and quality level.

    Target and reference records are paired by position in their series.
    NOTE(review): this assumes both result sets list the same images in the
    same order; nothing here checks that the paired records match.

    Args:
        results: Result tree (imageSet => quality => encoder => result).
        refEncoder (str): Reference encoder, or None to use ``tgtEncoder``.
        refQuality (str): Reference quality, or None to use ``tgtQuality``.
        tgtEncoder (str): Encoder key to select.
        tgtQuality (str): Quality key to select.
        resFilter: Predicate called with each record; only records for which
            it returns a true value are kept.

    Returns:
        tuple: Four parallel lists; the target ``cRate`` values divided by the
        reference values, the target ``psnr`` values minus the reference
        values, the plot marker for each target record, and the target
        records.

    Raises:
        IndexError: The reference series has fewer records than the target.
    """
    mts1, psnr1, marker1, rec1 = get_series(results, tgtEncoder, tgtQuality, resFilter)

    if refEncoder is None:
        refEncoder = tgtEncoder

    if refQuality is None:
        refQuality = tgtQuality

    mts2, psnr2, _, _ = get_series(results, refEncoder, refQuality, resFilter)

    # Fail with a message that names the offending series, rather than with
    # a bare "list index out of range" from the pairing below
    if len(mts2) < len(mts1):
        raise IndexError(
            "Reference series %s/%s has %u records, but target series %s/%s has %u"
            % (refEncoder, refQuality, len(mts2),
               tgtEncoder, tgtQuality, len(mts1)))

    mtsm = [tgt / ref for tgt, ref in zip(mts1, mts2)]
    psnrm = [tgt - ref for tgt, ref in zip(psnr1, psnr2)]

    return mtsm, psnrm, marker1, rec1
129
130
def get_human_eq_name(encoder, quality):
    """
    Build a human-readable name for an encoder and quality pair.

    Args:
        encoder (str): Encoder key, of the form ``ref-<version>`` or
            ``ref-<version>-<isa>``.
        quality (str): Quality level name.

    Returns:
        str: The name, e.g. ``astcenc-avx2 4.4 -fast`` or
        ``astcenc 1.7 -fast``.
    """
    fields = encoder.split("-")
    version = fields[1]

    if len(fields) != 2:
        return f"astcenc-{fields[2]} {version} -{quality}"

    return f"astcenc {version} -{quality}"
137
138
def get_human_e_name(encoder):
    """
    Build a human-readable name for an encoder.

    Args:
        encoder (str): Encoder key, of the form ``ref-<version>`` or
            ``ref-<version>-<isa>``.

    Returns:
        str: The name, e.g. ``astcenc-avx2 4.4`` or ``astcenc 1.7``.
    """
    fields = encoder.split("-")
    version = fields[1]

    if len(fields) != 2:
        return f"astcenc-{fields[2]} {version}"

    return f"astcenc {version}"
145
146
def get_human_q_name(quality):
    """
    Build a human-readable name for a quality level.

    Args:
        quality (str): Quality level name.

    Returns:
        str: The quality name prefixed with a dash, e.g. ``-fast``.
    """
    return f"-{quality}"
149
150
def plot(results, chartRows, chartCols, blockSizes,
         relative, pivotEncoder, pivotQuality, fileName, limits):
    """
    Render a grid of scatter charts and save it to an image file.

    The grid has one row per quality level and one column per encoder. Each
    chart plots one point per record, colored by block size and drawn with
    the per-record marker returned by `get_series()`.

    Args:
        results: Result tree (imageSet => quality => encoder => result).
        chartRows (list(str)): Quality levels, one grid row per entry.
        chartCols (list(str)): Encoder names, one grid column per entry.
        blockSizes (list(str)): Block sizes, one data series per entry.
        relative (bool): True to plot each series relative to the pivot,
            False to plot the absolute values.
        pivotEncoder (str): Reference encoder for relative charts, or None
            to compare against the encoder of each chart.
        pivotQuality (str): Reference quality for relative charts, or None
            to compare against the quality of each chart.
        fileName (str): Path of the image file to write.
        limits (tuple): Optional ``(xLimits, yLimits)`` pair applied to every
            chart; each entry is either None or a ``(min, max)`` tuple.

    Raises:
        ValueError: A relative plot was requested without any pivot.
    """
    # The pivot description is the same for every chart, so resolve it once
    # and reject a missing pivot before any figure is created
    tag = None
    if relative:
        if pivotEncoder and pivotQuality:
            tag = get_human_eq_name(pivotEncoder, pivotQuality)
        elif pivotEncoder:
            tag = get_human_e_name(pivotEncoder)
        elif pivotQuality:
            tag = get_human_q_name(pivotQuality)
        else:
            raise ValueError(
                "Relative plots need a pivot encoder or a pivot quality")

    # Disable squeezing so that axs is always a 2D array; by default a grid
    # with a single row or column is returned as a 1D array, and a 1x1 grid
    # as a bare Axes, neither of which can be indexed as axs[i, j]
    fig, axs = plt.subplots(nrows=len(chartRows), ncols=len(chartCols),
                            sharex=True, sharey=True, figsize=(15, 8.43),
                            squeeze=False)

    # Shared axes hide the inner tick labels; show them on every chart
    for a in fig.axes:
        a.tick_params(
            axis="x", which="both",
            bottom=True, top=False, labelbottom=True)

        a.tick_params(
            axis="y", which="both",
            left=True, right=False, labelleft=True)

    for i, row in enumerate(chartRows):
        for j, col in enumerate(chartCols):
            ax = axs[i, j]

            # No chart is drawn for "fastest" with the 1.7 or 2.0 encoders;
            # presumably those releases have no such preset
            if row == "fastest" and (("1.7" in col) or ("2.0" in col)):
                fig.delaxes(ax)
                continue

            title = get_human_eq_name(col, row)

            if not relative:
                ax.set_title(title, y=0.97, backgroundcolor="white")
                ax.set_xlabel('Coding performance (MTex/s)')
                ax.set_ylabel('PSNR (dB)')
            else:
                ax.set_title("%s vs. %s" % (title, tag), y=0.97, backgroundcolor="white")
                ax.set_xlabel('Performance scaling')
                ax.set_ylabel('PSNR delta (dB)')

            for k, series in enumerate(blockSizes):
                fn = lambda x: x.blkSz == series

                if not relative:
                    x, y, m, r = get_series(results, col, row, fn)
                else:
                    x, y, m, r = get_series_rel(results, pivotEncoder, pivotQuality,
                                                col, row, fn)

                # Markers differ per point, so each point is plotted on its
                # own; only the first is labelled so that the legend gets a
                # single entry per block size
                label = "%s blocks" % series
                for xp, yp, mp in zip(x, y, m):
                    ax.scatter([xp],[yp], s=16, marker=mp,
                               color="C%u" % k, label=label)
                    label = None

            if i == 0 and j == 0:
                ax.legend(loc="lower right")

    for ax in axs.flat:
        ax.grid(ls=':')

        if limits and limits[0]:
            ax.set_xlim(left=limits[0][0], right=limits[0][1])
        if limits and limits[1]:
            ax.set_ylim(bottom=limits[1][0], top=limits[1][1])

    fig.tight_layout()
    fig.savefig(fileName)

    # pyplot keeps every figure alive until it is closed explicitly
    plt.close(fig)
239
240
def main():
    """
    Render every configured chart from the stored reference results.

    Each chart is described by a list holding, in order, the arguments that
    follow ``results`` in a call to `plot()`: the quality rows, the encoder
    columns, the block sizes, the relative flag, the pivot encoder, the
    pivot quality, the output file name, and the axis limits.

    Returns:
        int: The process return code.
    """
    # X axis ranges used by the absolute and the relative charts
    absXLimits = (0, 80)
    relXLimits = (0.8, None)

    # Release used to represent each series
    last1x = "1.7"
    last2x = "2.5"
    last3x = "3.7"
    prev4x = "4.3"
    last4x = "4.4"
    lastMain = "main"

    # Encoder names used by more than one chart
    enc1x = f"ref-{last1x}"
    enc2x = f"ref-{last2x}-avx2"
    enc3x = f"ref-{last3x}-avx2"
    encPrev4x = f"ref-{prev4x}-avx2"
    enc4x = f"ref-{last4x}-avx2"
    encMain = f"ref-{lastMain}-avx2"

    # Settings shared between charts
    stdQualities = ["thorough", "medium", "fast"]
    allQualities = ["thorough", "medium", "fast", "fastest"]
    blocks = ["4x4", "6x6", "8x8"]
    noLimits = (None, None)
    relLimits = (relXLimits, None)

    charts = [
        # Latest in stable series, relative scores
        [stdQualities, [enc2x, enc3x, enc4x], blocks,
         True, enc1x, None,
         "results-relative-stable-series.png", noLimits],
        # Latest in stable series, absolute scores
        [stdQualities, [enc1x, enc2x, enc3x, enc4x], blocks,
         False, None, None,
         "results-absolute-stable-series.png", (absXLimits, None)],
        # Latest 2.x vs 1.x release, relative scores
        [stdQualities, [enc2x], blocks,
         True, enc1x, None,
         "results-relative-2.x-vs-1.x.png", noLimits],
        # Latest 3.x vs 1.x release, relative scores
        [stdQualities, [enc3x], blocks,
         True, enc1x, None,
         "results-relative-3.x-vs-1.x.png", noLimits],
        # Latest 4.x vs 1.x release, relative scores
        [stdQualities, [enc4x], blocks,
         True, enc1x, None,
         "results-relative-4.x-vs-1.x.png", noLimits],
        # Latest 3.x vs 2.x release, relative scores
        [allQualities, [enc3x], blocks,
         True, enc2x, None,
         "results-relative-3.x-vs-2.x.png", noLimits],
        # Latest 4.x vs 3.x release, relative scores
        [allQualities, [enc4x], blocks,
         True, enc3x, None,
         "results-relative-4.x-vs-3.x.png", relLimits],
        # Relative ISAs of latest
        [allQualities, [f"ref-{last4x}-sse4.1", enc4x], blocks,
         True, f"ref-{last4x}-sse2", None,
         "results-relative-4.x-isa.png", noLimits],
        # Relative quality of latest
        [["medium", "fast", "fastest"], [enc4x], blocks,
         True, None, "thorough",
         "results-relative-4.x-quality.png", noLimits],
        # Latest 4.x vs previous 4.x release, relative scores
        [allQualities, [enc4x], blocks,
         True, encPrev4x, None,
         "results-relative-4.x-vs-4.x.png", relLimits],
        # Main branch vs latest 4.x release, relative scores
        [allQualities, [encMain], blocks,
         True, enc4x, None,
         "results-relative-main-vs-4.x.png", relLimits],
    ]

    results = find_reference_results()

    # Force select is triggered by adding a trailing entry to the argument list
    # of the charts that you want rendered; designed for debugging use cases
    expectedLength = 8
    longest = max((len(chart) for chart in charts), default=0)
    forceSelect = longest != expectedLength

    for chart in charts:
        outputName = chart[6]

        # If force select is enabled then only keep the forced ones
        if len(chart) != longest:
            print("Skipping %s" % outputName)
            continue

        print("Generating %s" % outputName)

        # If force select is enabled then strip the dummy force option
        plotArgs = chart[:expectedLength] if forceSelect else chart

        plot(results, *plotArgs)

    return 0
424
425
# When run as a script, use the return value of main() as the exit code
if __name__ == "__main__":
    sys.exit(main())
428