1a8e1175bSopenharmony_ci#!/usr/bin/env python3
2a8e1175bSopenharmony_ci"""Test the program psa_constant_names.
3a8e1175bSopenharmony_ciGather constant names from header files and test cases. Compile a C program
4a8e1175bSopenharmony_cito print out their numerical values, feed these numerical values to
5a8e1175bSopenharmony_cipsa_constant_names, and check that the output is the original name.
6a8e1175bSopenharmony_ciReturn 0 if all test cases pass, 1 if the output was not always as expected,
7a8e1175bSopenharmony_cior 1 (with a Python backtrace) if there was an operational error.
8a8e1175bSopenharmony_ci"""
9a8e1175bSopenharmony_ci
10a8e1175bSopenharmony_ci# Copyright The Mbed TLS Contributors
11a8e1175bSopenharmony_ci# SPDX-License-Identifier: Apache-2.0
12a8e1175bSopenharmony_ci#
13a8e1175bSopenharmony_ci# Licensed under the Apache License, Version 2.0 (the "License"); you may
14a8e1175bSopenharmony_ci# not use this file except in compliance with the License.
15a8e1175bSopenharmony_ci# You may obtain a copy of the License at
16a8e1175bSopenharmony_ci#
17a8e1175bSopenharmony_ci# http://www.apache.org/licenses/LICENSE-2.0
18a8e1175bSopenharmony_ci#
19a8e1175bSopenharmony_ci# Unless required by applicable law or agreed to in writing, software
20a8e1175bSopenharmony_ci# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
21a8e1175bSopenharmony_ci# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22a8e1175bSopenharmony_ci# See the License for the specific language governing permissions and
23a8e1175bSopenharmony_ci# limitations under the License.
24a8e1175bSopenharmony_ci
25a8e1175bSopenharmony_ciimport argparse
26a8e1175bSopenharmony_cifrom collections import namedtuple
27a8e1175bSopenharmony_ciimport os
28a8e1175bSopenharmony_ciimport re
29a8e1175bSopenharmony_ciimport subprocess
30a8e1175bSopenharmony_ciimport sys
31a8e1175bSopenharmony_cifrom typing import Iterable, List, Optional, Tuple
32a8e1175bSopenharmony_ci
33a8e1175bSopenharmony_ciimport scripts_path # pylint: disable=unused-import
34a8e1175bSopenharmony_cifrom mbedtls_dev import c_build_helper
35a8e1175bSopenharmony_cifrom mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
36a8e1175bSopenharmony_cifrom mbedtls_dev import typing_util
37a8e1175bSopenharmony_ci
def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Build the collection of inputs to test psa_constant_names with.

    A fresh ``inputs_class`` instance first parses every file in
    ``headers``, then every test data file in ``test_suites``. It is then
    asked to add numerical values and to gather macro arguments, and is
    finally returned.
    """
    collector = inputs_class()
    for header_path in headers:
        collector.parse_header(header_path)
    for suite_path in test_suites:
        collector.parse_test_cases(suite_path)
    # Finalization steps run only after every source file has been parsed.
    collector.add_numerical_values()
    collector.gather_arguments()
    return collector
50a8e1175bSopenharmony_ci
def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[str] = None,
          keep_c: bool = False) -> List[str]:
    """Compile and run a C program that prints the value of each expression.

    ``status`` values are cast to ``long`` and printed as signed decimal.
    Every other type word is cast to ``unsigned long`` and printed as
    zero-padded hexadecimal. ``include_path`` and ``keep_c`` are forwarded
    unchanged to ``c_build_helper.get_c_expression_values``, whose result
    (the printed values) is returned.

    NOTE(review): ``include_path`` is annotated ``Optional[str]``, but
    ``main`` collects ``--include`` with ``action='append'`` (a list).
    Confirm which form c_build_helper expects.
    """
    signed = type_word == 'status'
    cast_to = 'long' if signed else 'unsigned long'
    printf_format = '%ld' if signed else '0x%08lx'
    caller = 'test_psa_constant_names.py for {} values'.format(type_word)
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format, expressions,
        caller=caller,
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c,
    )
71a8e1175bSopenharmony_ci
# Any run of whitespace; normalize() deletes every match.
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Return ``expr`` with all whitespace removed.

    This lets comparisons ignore trivial differences. For example,
    ``F(a, b)`` and ``F(a,b)`` normalize to the same string. Whitespace is
    currently the only difference that is ignored.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)
79a8e1175bSopenharmony_ci
# PSA_ALG_AEAD_WITH_SHORTENED_TAG(<CCM|CHACHA20_POLY1305|GCM>, 16), i.e. a
# "truncation" that, as the constant's name says, yields the algorithm
# itself. Only spaces may follow the comma, and nothing may follow ')'.
ALG_TRUNCATED_TO_SELF_RE = re.compile(
    r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
    r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
    r', *16\)\Z'
)

def is_simplifiable(expr: str) -> bool:
    """Tell whether ``expr`` would be printed back in a simpler form.

    psa_constant_names prints such expressions in their simple form rather
    than in the form they were written, so they can never round-trip.
    Callers must leave them out of the test.
    """
    return ALG_TRUNCATED_TO_SELF_RE.match(expr) is not None
95a8e1175bSopenharmony_ci
def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[str] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    The macro names known for ``type_word`` are expanded into expressions
    by ``inputs``. Expressions that psa_constant_names would print in a
    simplified form (see ``is_simplifiable``) are dropped, because their
    output could never match the input text.

    Return a tuple of two lists ``(expressions, values)``:

    * ``expressions`` is the sorted list of expressions to test.
    * ``values`` holds, for each expression, a string representation of its
      integer value as printed by the generated C program (see ``run_c``).
      The lists are expected to be parallel (callers zip them). That depends
      on the C helper printing one value per expression, so confirm there.

    ``include_path`` and ``keep_c`` are passed through to ``run_c``.
    NOTE(review): ``main`` supplies ``include_path`` as a list of
    directories, despite the ``Optional[str]`` annotation.
    """
    names = inputs.get_names(type_word)
    # Sorting gives a stable, reproducible order for the generated C program
    # and for the report.
    expressions = sorted(expr
                         for expr in inputs.generate_expressions(names)
                         if not is_simplifiable(expr))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values
112a8e1175bSopenharmony_ci
class Tests:
    """An object representing tests and their results.

    ``run_all`` and ``run_one`` feed numerical values to psa_constant_names.
    Each output line that does not match the expression the value came from
    is recorded in ``errors``. ``report`` writes a summary.
    """

    # One mismatch: the type word, the expected expression, the value that
    # was passed to psa_constant_names, and the line it printed for it.
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    # Output recorded for expressions that psa_constant_names printed no
    # line for. Such expressions are reported as failures instead of being
    # silently skipped.
    MISSING_OUTPUT = '<no output>'

    def __init__(self, options) -> None:
        """Remember the command line ``options`` and start with no results.

        ``options`` is read for ``include``, ``keep_c``, ``program`` and
        ``show`` (as set up by ``main``).
        """
        self.options = options
        self.count = 0 # number of expressions tested so far
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Every expression counts as one test case. If the program prints
        fewer lines than it was given values, each expression without a
        line is recorded as an error whose output is ``MISSING_OUTPUT``.
        ``subprocess.check_output`` raises ``CalledProcessError`` if the
        program exits with a non-zero status.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        outputs = output_bytes.decode('ascii').strip().splitlines()
        # zip() stops at its shortest argument. Without padding, a truncated
        # output would leave the remaining cases counted but never checked,
        # i.e. silently passing.
        shortfall = len(expressions) - len(outputs)
        if shortfall > 0:
            outputs += [self.MISSING_OUTPUT] * shortfall
        self.count += len(expressions)
        for expr, value, output in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
            if normalize(expr) != normalize(output):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=output))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs.

        Call ``run_one`` once for each supported type word, in a fixed order.
        """
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``, one line each.
        Also write a total: the number of test cases followed by either the
        number of failures or ``PASS``.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')
169a8e1175bSopenharmony_ci
# Header files to collect macro names from. The paths are relative to the
# first --include directory (main() joins them with options.include[0]).
HEADERS = [
    'psa/crypto.h',
    'psa/crypto_extra.h',
    'psa/crypto_values.h',
]
# Test data files to collect expressions from. The paths are used as given,
# so they are relative to the current working directory.
TEST_SUITES = [
    'tests/suites/test_suite_psa_crypto_metadata.data',
]
172a8e1175bSopenharmony_ci
def main():
    """Command line entry point.

    Parse the options, gather inputs from ``HEADERS`` and ``TEST_SUITES``,
    run every test, and print the report on stdout. Exit with status 1 if
    any test case failed.
    """
    # Inside a function, __doc__ resolves to the module docstring.
    parser = argparse.ArgumentParser(description=__doc__)
    # With action='append', argparse appends command-line values after the
    # default, so 'include' always stays first and -I adds extra directories.
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    # Paired on/off flags share a dest. The first-registered action's
    # default applies, so the registration order below matters.
    parser.add_argument('--keep-c', action='store_true', dest='keep_c',
                        default=False, help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c', action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show', action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show', action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()

    include_dir = options.include[0]
    header_paths = [os.path.join(include_dir, header) for header in HEADERS]
    inputs = gather_inputs(header_paths, TEST_SUITES)

    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)
201a8e1175bSopenharmony_ci
if __name__ == '__main__':
    # Run the tests only when executed as a script, not when imported.
    main()
204