17db96d56Sopenharmony_ci"""Test the secrets module.
27db96d56Sopenharmony_ci
37db96d56Sopenharmony_ciAs most of the functions in secrets are thin wrappers around functions
47db96d56Sopenharmony_cidefined elsewhere, we don't need to test them exhaustively.
57db96d56Sopenharmony_ci"""
67db96d56Sopenharmony_ci
77db96d56Sopenharmony_ci
87db96d56Sopenharmony_ciimport secrets
97db96d56Sopenharmony_ciimport unittest
107db96d56Sopenharmony_ciimport string
117db96d56Sopenharmony_ci
127db96d56Sopenharmony_ci
137db96d56Sopenharmony_ci# === Unit tests ===
147db96d56Sopenharmony_ci
class Compare_Digest_Tests(unittest.TestCase):
    """Test secrets.compare_digest function.

    Covers equal and unequal operands for both ``str`` and ``bytes``,
    rejection of mixed ``str``/``bytes`` operands, and the return type.
    """

    def test_equal(self):
        # Test compare_digest functionality with equal (byte/text) strings.
        for s in ("a", "bcd", "xyz123"):
            with self.subTest(s=s):
                # Two separately-built operands with the same value.
                a = s*100
                b = s*100
                self.assertTrue(secrets.compare_digest(a, b))
                self.assertTrue(secrets.compare_digest(a.encode('utf-8'),
                                                       b.encode('utf-8')))

    def test_unequal(self):
        # Test compare_digest functionality with unequal (byte/text) strings.
        # Operands of different lengths.
        self.assertFalse(secrets.compare_digest("abc", "abcd"))
        self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
        # Operands of equal length that differ only in the final character.
        for s in ("x", "mn", "a1b2c3"):
            with self.subTest(s=s):
                a = s*100 + "q"
                b = s*100 + "k"
                self.assertFalse(secrets.compare_digest(a, b))
                self.assertFalse(secrets.compare_digest(a.encode('utf-8'),
                                                        b.encode('utf-8')))

    def test_bad_types(self):
        # Test that compare_digest raises with mixed types.
        a = 'abcde'
        b = a.encode('utf-8')
        # Use TestCase assertions rather than bare ``assert`` so these
        # preconditions are still checked when Python runs with -O.
        self.assertIsInstance(a, str)
        self.assertIsInstance(b, bytes)
        self.assertRaises(TypeError, secrets.compare_digest, a, b)
        self.assertRaises(TypeError, secrets.compare_digest, b, a)

    def test_bool(self):
        # Test that compare_digest returns a bool for text and byte operands.
        self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool)
        self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool)
        self.assertIsInstance(secrets.compare_digest(b"abc", b"abc"), bool)
        self.assertIsInstance(secrets.compare_digest(b"abc", b"xyz"), bool)
497db96d56Sopenharmony_ci
507db96d56Sopenharmony_ci
class Random_Tests(unittest.TestCase):
    """Test wrappers around SystemRandom methods.

    Checks that randbits, choice and randbelow return values in their
    documented ranges and reject invalid arguments.
    """

    def test_randbits(self):
        # Test randbits: every result must fit in the requested bit width.
        errmsg = "randbits(%d) returned %d"
        for numbits in (3, 12, 30):
            for _ in range(6):
                n = secrets.randbits(numbits)
                self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n))

    def test_choice(self):
        # Test choice: the result must be drawn from the given sequence.
        items = [1, 2, 4, 8, 16, 32, 64]
        for _ in range(10):
            # assertIn reports the offending value on failure, unlike
            # assertTrue(x in items).
            self.assertIn(secrets.choice(items), items)
        # An empty sequence has nothing to choose from.
        self.assertRaises(IndexError, secrets.choice, [])

    def test_randbelow(self):
        # Test randbelow: results lie in range(bound) for bounds >= 1.
        # The loop starts at 1, the smallest valid bound, whose only
        # possible result is 0.
        for i in range(1, 10):
            self.assertIn(secrets.randbelow(i), range(i))
        # Non-positive bounds are rejected.
        self.assertRaises(ValueError, secrets.randbelow, 0)
        self.assertRaises(ValueError, secrets.randbelow, -1)
747db96d56Sopenharmony_ci
757db96d56Sopenharmony_ci
class Token_Tests(unittest.TestCase):
    """Test token functions.

    Covers default-size handling, and the return type, length and alphabet
    of token_bytes, token_hex and token_urlsafe.
    """

    @staticmethod
    def _urlsafe_length(nbytes):
        """Return the unpadded Base64 length of *nbytes* bytes.

        This is ceil(4 * nbytes / 3), computed with integer arithmetic.
        """
        return (4 * nbytes + 2) // 3

    def test_token_defaults(self):
        # Test that token_* functions handle default size correctly.
        for func in (secrets.token_bytes, secrets.token_hex,
                     secrets.token_urlsafe):
            with self.subTest(func=func):
                name = func.__name__
                try:
                    func()
                except TypeError:
                    self.fail("%s cannot be called with no argument" % name)
                try:
                    func(None)
                except TypeError:
                    self.fail("%s cannot be called with None" % name)
        # Passing None must select DEFAULT_ENTROPY bytes for every function.
        size = secrets.DEFAULT_ENTROPY
        self.assertEqual(len(secrets.token_bytes(None)), size)
        self.assertEqual(len(secrets.token_hex(None)), 2*size)
        self.assertEqual(len(secrets.token_urlsafe(None)),
                         self._urlsafe_length(size))

    def test_token_bytes(self):
        # Test token_bytes.
        for n in (1, 8, 17, 100):
            with self.subTest(n=n):
                # Check type and length on the same token.
                token = secrets.token_bytes(n)
                self.assertIsInstance(token, bytes)
                self.assertEqual(len(token), n)

    def test_token_hex(self):
        # Test token_hex: two hex digits per byte.
        for n in (1, 12, 25, 90):
            with self.subTest(n=n):
                s = secrets.token_hex(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), 2*n)
                self.assertTrue(all(c in string.hexdigits for c in s))

    def test_token_urlsafe(self):
        # Test token_urlsafe: URL-safe Base64 alphabet, no '=' padding.
        legal = string.ascii_letters + string.digits + '-_'
        for n in (1, 11, 28, 76):
            with self.subTest(n=n):
                s = secrets.token_urlsafe(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), self._urlsafe_length(n))
                self.assertTrue(all(c in legal for c in s))
1217db96d56Sopenharmony_ci
1227db96d56Sopenharmony_ci
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
125