17db96d56Sopenharmony_ci# Deliberately use "from dataclasses import *". Every name in __all__ 27db96d56Sopenharmony_ci# is tested, so they all must be present. This is a way to catch 37db96d56Sopenharmony_ci# missing ones. 47db96d56Sopenharmony_ci 57db96d56Sopenharmony_cifrom dataclasses import * 67db96d56Sopenharmony_ci 77db96d56Sopenharmony_ciimport abc 87db96d56Sopenharmony_ciimport io 97db96d56Sopenharmony_ciimport pickle 107db96d56Sopenharmony_ciimport inspect 117db96d56Sopenharmony_ciimport builtins 127db96d56Sopenharmony_ciimport types 137db96d56Sopenharmony_ciimport weakref 147db96d56Sopenharmony_ciimport traceback 157db96d56Sopenharmony_ciimport unittest 167db96d56Sopenharmony_cifrom unittest.mock import Mock 177db96d56Sopenharmony_cifrom typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol 187db96d56Sopenharmony_cifrom typing import get_type_hints 197db96d56Sopenharmony_cifrom collections import deque, OrderedDict, namedtuple 207db96d56Sopenharmony_cifrom functools import total_ordering 217db96d56Sopenharmony_ci 227db96d56Sopenharmony_ciimport typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 237db96d56Sopenharmony_ciimport dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 247db96d56Sopenharmony_ci 257db96d56Sopenharmony_ci# Just any custom exception we can catch. 
267db96d56Sopenharmony_ciclass CustomError(Exception): pass 277db96d56Sopenharmony_ci 287db96d56Sopenharmony_ciclass TestCase(unittest.TestCase): 297db96d56Sopenharmony_ci def test_no_fields(self): 307db96d56Sopenharmony_ci @dataclass 317db96d56Sopenharmony_ci class C: 327db96d56Sopenharmony_ci pass 337db96d56Sopenharmony_ci 347db96d56Sopenharmony_ci o = C() 357db96d56Sopenharmony_ci self.assertEqual(len(fields(C)), 0) 367db96d56Sopenharmony_ci 377db96d56Sopenharmony_ci def test_no_fields_but_member_variable(self): 387db96d56Sopenharmony_ci @dataclass 397db96d56Sopenharmony_ci class C: 407db96d56Sopenharmony_ci i = 0 417db96d56Sopenharmony_ci 427db96d56Sopenharmony_ci o = C() 437db96d56Sopenharmony_ci self.assertEqual(len(fields(C)), 0) 447db96d56Sopenharmony_ci 457db96d56Sopenharmony_ci def test_one_field_no_default(self): 467db96d56Sopenharmony_ci @dataclass 477db96d56Sopenharmony_ci class C: 487db96d56Sopenharmony_ci x: int 497db96d56Sopenharmony_ci 507db96d56Sopenharmony_ci o = C(42) 517db96d56Sopenharmony_ci self.assertEqual(o.x, 42) 527db96d56Sopenharmony_ci 537db96d56Sopenharmony_ci def test_field_default_default_factory_error(self): 547db96d56Sopenharmony_ci msg = "cannot specify both default and default_factory" 557db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, msg): 567db96d56Sopenharmony_ci @dataclass 577db96d56Sopenharmony_ci class C: 587db96d56Sopenharmony_ci x: int = field(default=1, default_factory=int) 597db96d56Sopenharmony_ci 607db96d56Sopenharmony_ci def test_field_repr(self): 617db96d56Sopenharmony_ci int_field = field(default=1, init=True, repr=False) 627db96d56Sopenharmony_ci int_field.name = "id" 637db96d56Sopenharmony_ci repr_output = repr(int_field) 647db96d56Sopenharmony_ci expected_output = "Field(name='id',type=None," \ 657db96d56Sopenharmony_ci f"default=1,default_factory={MISSING!r}," \ 667db96d56Sopenharmony_ci "init=True,repr=False,hash=None," \ 677db96d56Sopenharmony_ci "compare=True,metadata=mappingproxy({})," \ 
687db96d56Sopenharmony_ci f"kw_only={MISSING!r}," \ 697db96d56Sopenharmony_ci "_field_type=None)" 707db96d56Sopenharmony_ci 717db96d56Sopenharmony_ci self.assertEqual(repr_output, expected_output) 727db96d56Sopenharmony_ci 737db96d56Sopenharmony_ci def test_field_recursive_repr(self): 747db96d56Sopenharmony_ci rec_field = field() 757db96d56Sopenharmony_ci rec_field.type = rec_field 767db96d56Sopenharmony_ci rec_field.name = "id" 777db96d56Sopenharmony_ci repr_output = repr(rec_field) 787db96d56Sopenharmony_ci 797db96d56Sopenharmony_ci self.assertIn(",type=...,", repr_output) 807db96d56Sopenharmony_ci 817db96d56Sopenharmony_ci def test_recursive_annotation(self): 827db96d56Sopenharmony_ci class C: 837db96d56Sopenharmony_ci pass 847db96d56Sopenharmony_ci 857db96d56Sopenharmony_ci @dataclass 867db96d56Sopenharmony_ci class D: 877db96d56Sopenharmony_ci C: C = field() 887db96d56Sopenharmony_ci 897db96d56Sopenharmony_ci self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"])) 907db96d56Sopenharmony_ci 917db96d56Sopenharmony_ci def test_named_init_params(self): 927db96d56Sopenharmony_ci @dataclass 937db96d56Sopenharmony_ci class C: 947db96d56Sopenharmony_ci x: int 957db96d56Sopenharmony_ci 967db96d56Sopenharmony_ci o = C(x=32) 977db96d56Sopenharmony_ci self.assertEqual(o.x, 32) 987db96d56Sopenharmony_ci 997db96d56Sopenharmony_ci def test_two_fields_one_default(self): 1007db96d56Sopenharmony_ci @dataclass 1017db96d56Sopenharmony_ci class C: 1027db96d56Sopenharmony_ci x: int 1037db96d56Sopenharmony_ci y: int = 0 1047db96d56Sopenharmony_ci 1057db96d56Sopenharmony_ci o = C(3) 1067db96d56Sopenharmony_ci self.assertEqual((o.x, o.y), (3, 0)) 1077db96d56Sopenharmony_ci 1087db96d56Sopenharmony_ci # Non-defaults following defaults. 
1097db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 1107db96d56Sopenharmony_ci "non-default argument 'y' follows " 1117db96d56Sopenharmony_ci "default argument"): 1127db96d56Sopenharmony_ci @dataclass 1137db96d56Sopenharmony_ci class C: 1147db96d56Sopenharmony_ci x: int = 0 1157db96d56Sopenharmony_ci y: int 1167db96d56Sopenharmony_ci 1177db96d56Sopenharmony_ci # A derived class adds a non-default field after a default one. 1187db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 1197db96d56Sopenharmony_ci "non-default argument 'y' follows " 1207db96d56Sopenharmony_ci "default argument"): 1217db96d56Sopenharmony_ci @dataclass 1227db96d56Sopenharmony_ci class B: 1237db96d56Sopenharmony_ci x: int = 0 1247db96d56Sopenharmony_ci 1257db96d56Sopenharmony_ci @dataclass 1267db96d56Sopenharmony_ci class C(B): 1277db96d56Sopenharmony_ci y: int 1287db96d56Sopenharmony_ci 1297db96d56Sopenharmony_ci # Override a base class field and add a default to 1307db96d56Sopenharmony_ci # a field which didn't use to have a default. 1317db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 1327db96d56Sopenharmony_ci "non-default argument 'y' follows " 1337db96d56Sopenharmony_ci "default argument"): 1347db96d56Sopenharmony_ci @dataclass 1357db96d56Sopenharmony_ci class B: 1367db96d56Sopenharmony_ci x: int 1377db96d56Sopenharmony_ci y: int 1387db96d56Sopenharmony_ci 1397db96d56Sopenharmony_ci @dataclass 1407db96d56Sopenharmony_ci class C(B): 1417db96d56Sopenharmony_ci x: int = 0 1427db96d56Sopenharmony_ci 1437db96d56Sopenharmony_ci def test_overwrite_hash(self): 1447db96d56Sopenharmony_ci # Test that declaring this class isn't an error. It should 1457db96d56Sopenharmony_ci # use the user-provided __hash__. 
1467db96d56Sopenharmony_ci @dataclass(frozen=True) 1477db96d56Sopenharmony_ci class C: 1487db96d56Sopenharmony_ci x: int 1497db96d56Sopenharmony_ci def __hash__(self): 1507db96d56Sopenharmony_ci return 301 1517db96d56Sopenharmony_ci self.assertEqual(hash(C(100)), 301) 1527db96d56Sopenharmony_ci 1537db96d56Sopenharmony_ci # Test that declaring this class isn't an error. It should 1547db96d56Sopenharmony_ci # use the generated __hash__. 1557db96d56Sopenharmony_ci @dataclass(frozen=True) 1567db96d56Sopenharmony_ci class C: 1577db96d56Sopenharmony_ci x: int 1587db96d56Sopenharmony_ci def __eq__(self, other): 1597db96d56Sopenharmony_ci return False 1607db96d56Sopenharmony_ci self.assertEqual(hash(C(100)), hash((100,))) 1617db96d56Sopenharmony_ci 1627db96d56Sopenharmony_ci # But this one should generate an exception, because with 1637db96d56Sopenharmony_ci # unsafe_hash=True, it's an error to have a __hash__ defined. 1647db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 1657db96d56Sopenharmony_ci 'Cannot overwrite attribute __hash__'): 1667db96d56Sopenharmony_ci @dataclass(unsafe_hash=True) 1677db96d56Sopenharmony_ci class C: 1687db96d56Sopenharmony_ci def __hash__(self): 1697db96d56Sopenharmony_ci pass 1707db96d56Sopenharmony_ci 1717db96d56Sopenharmony_ci # Creating this class should not generate an exception, 1727db96d56Sopenharmony_ci # because even though __hash__ exists before @dataclass is 1737db96d56Sopenharmony_ci # called, (due to __eq__ being defined), since it's None 1747db96d56Sopenharmony_ci # that's okay. 1757db96d56Sopenharmony_ci @dataclass(unsafe_hash=True) 1767db96d56Sopenharmony_ci class C: 1777db96d56Sopenharmony_ci x: int 1787db96d56Sopenharmony_ci def __eq__(self): 1797db96d56Sopenharmony_ci pass 1807db96d56Sopenharmony_ci # The generated hash function works as we'd expect. 
1817db96d56Sopenharmony_ci self.assertEqual(hash(C(10)), hash((10,))) 1827db96d56Sopenharmony_ci 1837db96d56Sopenharmony_ci # Creating this class should generate an exception, because 1847db96d56Sopenharmony_ci # __hash__ exists and is not None, which it would be if it 1857db96d56Sopenharmony_ci # had been auto-generated due to __eq__ being defined. 1867db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 1877db96d56Sopenharmony_ci 'Cannot overwrite attribute __hash__'): 1887db96d56Sopenharmony_ci @dataclass(unsafe_hash=True) 1897db96d56Sopenharmony_ci class C: 1907db96d56Sopenharmony_ci x: int 1917db96d56Sopenharmony_ci def __eq__(self): 1927db96d56Sopenharmony_ci pass 1937db96d56Sopenharmony_ci def __hash__(self): 1947db96d56Sopenharmony_ci pass 1957db96d56Sopenharmony_ci 1967db96d56Sopenharmony_ci def test_overwrite_fields_in_derived_class(self): 1977db96d56Sopenharmony_ci # Note that x from C1 replaces x in Base, but the order remains 1987db96d56Sopenharmony_ci # the same as defined in Base. 
1997db96d56Sopenharmony_ci @dataclass 2007db96d56Sopenharmony_ci class Base: 2017db96d56Sopenharmony_ci x: Any = 15.0 2027db96d56Sopenharmony_ci y: int = 0 2037db96d56Sopenharmony_ci 2047db96d56Sopenharmony_ci @dataclass 2057db96d56Sopenharmony_ci class C1(Base): 2067db96d56Sopenharmony_ci z: int = 10 2077db96d56Sopenharmony_ci x: int = 15 2087db96d56Sopenharmony_ci 2097db96d56Sopenharmony_ci o = Base() 2107db96d56Sopenharmony_ci self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 2117db96d56Sopenharmony_ci 2127db96d56Sopenharmony_ci o = C1() 2137db96d56Sopenharmony_ci self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 2147db96d56Sopenharmony_ci 2157db96d56Sopenharmony_ci o = C1(x=5) 2167db96d56Sopenharmony_ci self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 2177db96d56Sopenharmony_ci 2187db96d56Sopenharmony_ci def test_field_named_self(self): 2197db96d56Sopenharmony_ci @dataclass 2207db96d56Sopenharmony_ci class C: 2217db96d56Sopenharmony_ci self: str 2227db96d56Sopenharmony_ci c=C('foo') 2237db96d56Sopenharmony_ci self.assertEqual(c.self, 'foo') 2247db96d56Sopenharmony_ci 2257db96d56Sopenharmony_ci # Make sure the first parameter is not named 'self'. 2267db96d56Sopenharmony_ci sig = inspect.signature(C.__init__) 2277db96d56Sopenharmony_ci first = next(iter(sig.parameters)) 2287db96d56Sopenharmony_ci self.assertNotEqual('self', first) 2297db96d56Sopenharmony_ci 2307db96d56Sopenharmony_ci # But we do use 'self' if no field named self. 2317db96d56Sopenharmony_ci @dataclass 2327db96d56Sopenharmony_ci class C: 2337db96d56Sopenharmony_ci selfx: str 2347db96d56Sopenharmony_ci 2357db96d56Sopenharmony_ci # Make sure the first parameter is named 'self'. 
2367db96d56Sopenharmony_ci sig = inspect.signature(C.__init__) 2377db96d56Sopenharmony_ci first = next(iter(sig.parameters)) 2387db96d56Sopenharmony_ci self.assertEqual('self', first) 2397db96d56Sopenharmony_ci 2407db96d56Sopenharmony_ci def test_field_named_object(self): 2417db96d56Sopenharmony_ci @dataclass 2427db96d56Sopenharmony_ci class C: 2437db96d56Sopenharmony_ci object: str 2447db96d56Sopenharmony_ci c = C('foo') 2457db96d56Sopenharmony_ci self.assertEqual(c.object, 'foo') 2467db96d56Sopenharmony_ci 2477db96d56Sopenharmony_ci def test_field_named_object_frozen(self): 2487db96d56Sopenharmony_ci @dataclass(frozen=True) 2497db96d56Sopenharmony_ci class C: 2507db96d56Sopenharmony_ci object: str 2517db96d56Sopenharmony_ci c = C('foo') 2527db96d56Sopenharmony_ci self.assertEqual(c.object, 'foo') 2537db96d56Sopenharmony_ci 2547db96d56Sopenharmony_ci def test_field_named_BUILTINS_frozen(self): 2557db96d56Sopenharmony_ci # gh-96151 2567db96d56Sopenharmony_ci @dataclass(frozen=True) 2577db96d56Sopenharmony_ci class C: 2587db96d56Sopenharmony_ci BUILTINS: int 2597db96d56Sopenharmony_ci c = C(5) 2607db96d56Sopenharmony_ci self.assertEqual(c.BUILTINS, 5) 2617db96d56Sopenharmony_ci 2627db96d56Sopenharmony_ci def test_field_named_like_builtin(self): 2637db96d56Sopenharmony_ci # Attribute names can shadow built-in names 2647db96d56Sopenharmony_ci # since code generation is used. 2657db96d56Sopenharmony_ci # Ensure that this is not happening. 
2667db96d56Sopenharmony_ci exclusions = {'None', 'True', 'False'} 2677db96d56Sopenharmony_ci builtins_names = sorted( 2687db96d56Sopenharmony_ci b for b in builtins.__dict__.keys() 2697db96d56Sopenharmony_ci if not b.startswith('__') and b not in exclusions 2707db96d56Sopenharmony_ci ) 2717db96d56Sopenharmony_ci attributes = [(name, str) for name in builtins_names] 2727db96d56Sopenharmony_ci C = make_dataclass('C', attributes) 2737db96d56Sopenharmony_ci 2747db96d56Sopenharmony_ci c = C(*[name for name in builtins_names]) 2757db96d56Sopenharmony_ci 2767db96d56Sopenharmony_ci for name in builtins_names: 2777db96d56Sopenharmony_ci self.assertEqual(getattr(c, name), name) 2787db96d56Sopenharmony_ci 2797db96d56Sopenharmony_ci def test_field_named_like_builtin_frozen(self): 2807db96d56Sopenharmony_ci # Attribute names can shadow built-in names 2817db96d56Sopenharmony_ci # since code generation is used. 2827db96d56Sopenharmony_ci # Ensure that this is not happening 2837db96d56Sopenharmony_ci # for frozen data classes. 2847db96d56Sopenharmony_ci exclusions = {'None', 'True', 'False'} 2857db96d56Sopenharmony_ci builtins_names = sorted( 2867db96d56Sopenharmony_ci b for b in builtins.__dict__.keys() 2877db96d56Sopenharmony_ci if not b.startswith('__') and b not in exclusions 2887db96d56Sopenharmony_ci ) 2897db96d56Sopenharmony_ci attributes = [(name, str) for name in builtins_names] 2907db96d56Sopenharmony_ci C = make_dataclass('C', attributes, frozen=True) 2917db96d56Sopenharmony_ci 2927db96d56Sopenharmony_ci c = C(*[name for name in builtins_names]) 2937db96d56Sopenharmony_ci 2947db96d56Sopenharmony_ci for name in builtins_names: 2957db96d56Sopenharmony_ci self.assertEqual(getattr(c, name), name) 2967db96d56Sopenharmony_ci 2977db96d56Sopenharmony_ci def test_0_field_compare(self): 2987db96d56Sopenharmony_ci # Ensure that order=False is the default. 
2997db96d56Sopenharmony_ci @dataclass 3007db96d56Sopenharmony_ci class C0: 3017db96d56Sopenharmony_ci pass 3027db96d56Sopenharmony_ci 3037db96d56Sopenharmony_ci @dataclass(order=False) 3047db96d56Sopenharmony_ci class C1: 3057db96d56Sopenharmony_ci pass 3067db96d56Sopenharmony_ci 3077db96d56Sopenharmony_ci for cls in [C0, C1]: 3087db96d56Sopenharmony_ci with self.subTest(cls=cls): 3097db96d56Sopenharmony_ci self.assertEqual(cls(), cls()) 3107db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a < b, 3117db96d56Sopenharmony_ci lambda a, b: a <= b, 3127db96d56Sopenharmony_ci lambda a, b: a > b, 3137db96d56Sopenharmony_ci lambda a, b: a >= b]): 3147db96d56Sopenharmony_ci with self.subTest(idx=idx): 3157db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 3167db96d56Sopenharmony_ci f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 3177db96d56Sopenharmony_ci fn(cls(), cls()) 3187db96d56Sopenharmony_ci 3197db96d56Sopenharmony_ci @dataclass(order=True) 3207db96d56Sopenharmony_ci class C: 3217db96d56Sopenharmony_ci pass 3227db96d56Sopenharmony_ci self.assertLessEqual(C(), C()) 3237db96d56Sopenharmony_ci self.assertGreaterEqual(C(), C()) 3247db96d56Sopenharmony_ci 3257db96d56Sopenharmony_ci def test_1_field_compare(self): 3267db96d56Sopenharmony_ci # Ensure that order=False is the default. 
3277db96d56Sopenharmony_ci @dataclass 3287db96d56Sopenharmony_ci class C0: 3297db96d56Sopenharmony_ci x: int 3307db96d56Sopenharmony_ci 3317db96d56Sopenharmony_ci @dataclass(order=False) 3327db96d56Sopenharmony_ci class C1: 3337db96d56Sopenharmony_ci x: int 3347db96d56Sopenharmony_ci 3357db96d56Sopenharmony_ci for cls in [C0, C1]: 3367db96d56Sopenharmony_ci with self.subTest(cls=cls): 3377db96d56Sopenharmony_ci self.assertEqual(cls(1), cls(1)) 3387db96d56Sopenharmony_ci self.assertNotEqual(cls(0), cls(1)) 3397db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a < b, 3407db96d56Sopenharmony_ci lambda a, b: a <= b, 3417db96d56Sopenharmony_ci lambda a, b: a > b, 3427db96d56Sopenharmony_ci lambda a, b: a >= b]): 3437db96d56Sopenharmony_ci with self.subTest(idx=idx): 3447db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 3457db96d56Sopenharmony_ci f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 3467db96d56Sopenharmony_ci fn(cls(0), cls(0)) 3477db96d56Sopenharmony_ci 3487db96d56Sopenharmony_ci @dataclass(order=True) 3497db96d56Sopenharmony_ci class C: 3507db96d56Sopenharmony_ci x: int 3517db96d56Sopenharmony_ci self.assertLess(C(0), C(1)) 3527db96d56Sopenharmony_ci self.assertLessEqual(C(0), C(1)) 3537db96d56Sopenharmony_ci self.assertLessEqual(C(1), C(1)) 3547db96d56Sopenharmony_ci self.assertGreater(C(1), C(0)) 3557db96d56Sopenharmony_ci self.assertGreaterEqual(C(1), C(0)) 3567db96d56Sopenharmony_ci self.assertGreaterEqual(C(1), C(1)) 3577db96d56Sopenharmony_ci 3587db96d56Sopenharmony_ci def test_simple_compare(self): 3597db96d56Sopenharmony_ci # Ensure that order=False is the default. 
3607db96d56Sopenharmony_ci @dataclass 3617db96d56Sopenharmony_ci class C0: 3627db96d56Sopenharmony_ci x: int 3637db96d56Sopenharmony_ci y: int 3647db96d56Sopenharmony_ci 3657db96d56Sopenharmony_ci @dataclass(order=False) 3667db96d56Sopenharmony_ci class C1: 3677db96d56Sopenharmony_ci x: int 3687db96d56Sopenharmony_ci y: int 3697db96d56Sopenharmony_ci 3707db96d56Sopenharmony_ci for cls in [C0, C1]: 3717db96d56Sopenharmony_ci with self.subTest(cls=cls): 3727db96d56Sopenharmony_ci self.assertEqual(cls(0, 0), cls(0, 0)) 3737db96d56Sopenharmony_ci self.assertEqual(cls(1, 2), cls(1, 2)) 3747db96d56Sopenharmony_ci self.assertNotEqual(cls(1, 0), cls(0, 0)) 3757db96d56Sopenharmony_ci self.assertNotEqual(cls(1, 0), cls(1, 1)) 3767db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a < b, 3777db96d56Sopenharmony_ci lambda a, b: a <= b, 3787db96d56Sopenharmony_ci lambda a, b: a > b, 3797db96d56Sopenharmony_ci lambda a, b: a >= b]): 3807db96d56Sopenharmony_ci with self.subTest(idx=idx): 3817db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 3827db96d56Sopenharmony_ci f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 3837db96d56Sopenharmony_ci fn(cls(0, 0), cls(0, 0)) 3847db96d56Sopenharmony_ci 3857db96d56Sopenharmony_ci @dataclass(order=True) 3867db96d56Sopenharmony_ci class C: 3877db96d56Sopenharmony_ci x: int 3887db96d56Sopenharmony_ci y: int 3897db96d56Sopenharmony_ci 3907db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a == b, 3917db96d56Sopenharmony_ci lambda a, b: a <= b, 3927db96d56Sopenharmony_ci lambda a, b: a >= b]): 3937db96d56Sopenharmony_ci with self.subTest(idx=idx): 3947db96d56Sopenharmony_ci self.assertTrue(fn(C(0, 0), C(0, 0))) 3957db96d56Sopenharmony_ci 3967db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a < b, 3977db96d56Sopenharmony_ci lambda a, b: a <= b, 3987db96d56Sopenharmony_ci lambda a, b: a != b]): 3997db96d56Sopenharmony_ci with self.subTest(idx=idx): 
4007db96d56Sopenharmony_ci self.assertTrue(fn(C(0, 0), C(0, 1))) 4017db96d56Sopenharmony_ci self.assertTrue(fn(C(0, 1), C(1, 0))) 4027db96d56Sopenharmony_ci self.assertTrue(fn(C(1, 0), C(1, 1))) 4037db96d56Sopenharmony_ci 4047db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a > b, 4057db96d56Sopenharmony_ci lambda a, b: a >= b, 4067db96d56Sopenharmony_ci lambda a, b: a != b]): 4077db96d56Sopenharmony_ci with self.subTest(idx=idx): 4087db96d56Sopenharmony_ci self.assertTrue(fn(C(0, 1), C(0, 0))) 4097db96d56Sopenharmony_ci self.assertTrue(fn(C(1, 0), C(0, 1))) 4107db96d56Sopenharmony_ci self.assertTrue(fn(C(1, 1), C(1, 0))) 4117db96d56Sopenharmony_ci 4127db96d56Sopenharmony_ci def test_compare_subclasses(self): 4137db96d56Sopenharmony_ci # Comparisons fail for subclasses, even if no fields 4147db96d56Sopenharmony_ci # are added. 4157db96d56Sopenharmony_ci @dataclass 4167db96d56Sopenharmony_ci class B: 4177db96d56Sopenharmony_ci i: int 4187db96d56Sopenharmony_ci 4197db96d56Sopenharmony_ci @dataclass 4207db96d56Sopenharmony_ci class C(B): 4217db96d56Sopenharmony_ci pass 4227db96d56Sopenharmony_ci 4237db96d56Sopenharmony_ci for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 4247db96d56Sopenharmony_ci (lambda a, b: a != b, True)]): 4257db96d56Sopenharmony_ci with self.subTest(idx=idx): 4267db96d56Sopenharmony_ci self.assertEqual(fn(B(0), C(0)), expected) 4277db96d56Sopenharmony_ci 4287db96d56Sopenharmony_ci for idx, fn in enumerate([lambda a, b: a < b, 4297db96d56Sopenharmony_ci lambda a, b: a <= b, 4307db96d56Sopenharmony_ci lambda a, b: a > b, 4317db96d56Sopenharmony_ci lambda a, b: a >= b]): 4327db96d56Sopenharmony_ci with self.subTest(idx=idx): 4337db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 4347db96d56Sopenharmony_ci "not supported between instances of 'B' and 'C'"): 4357db96d56Sopenharmony_ci fn(B(0), C(0)) 4367db96d56Sopenharmony_ci 4377db96d56Sopenharmony_ci def test_eq_order(self): 4387db96d56Sopenharmony_ci # Test 
combining eq and order. 4397db96d56Sopenharmony_ci for (eq, order, result ) in [ 4407db96d56Sopenharmony_ci (False, False, 'neither'), 4417db96d56Sopenharmony_ci (False, True, 'exception'), 4427db96d56Sopenharmony_ci (True, False, 'eq_only'), 4437db96d56Sopenharmony_ci (True, True, 'both'), 4447db96d56Sopenharmony_ci ]: 4457db96d56Sopenharmony_ci with self.subTest(eq=eq, order=order): 4467db96d56Sopenharmony_ci if result == 'exception': 4477db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 4487db96d56Sopenharmony_ci @dataclass(eq=eq, order=order) 4497db96d56Sopenharmony_ci class C: 4507db96d56Sopenharmony_ci pass 4517db96d56Sopenharmony_ci else: 4527db96d56Sopenharmony_ci @dataclass(eq=eq, order=order) 4537db96d56Sopenharmony_ci class C: 4547db96d56Sopenharmony_ci pass 4557db96d56Sopenharmony_ci 4567db96d56Sopenharmony_ci if result == 'neither': 4577db96d56Sopenharmony_ci self.assertNotIn('__eq__', C.__dict__) 4587db96d56Sopenharmony_ci self.assertNotIn('__lt__', C.__dict__) 4597db96d56Sopenharmony_ci self.assertNotIn('__le__', C.__dict__) 4607db96d56Sopenharmony_ci self.assertNotIn('__gt__', C.__dict__) 4617db96d56Sopenharmony_ci self.assertNotIn('__ge__', C.__dict__) 4627db96d56Sopenharmony_ci elif result == 'both': 4637db96d56Sopenharmony_ci self.assertIn('__eq__', C.__dict__) 4647db96d56Sopenharmony_ci self.assertIn('__lt__', C.__dict__) 4657db96d56Sopenharmony_ci self.assertIn('__le__', C.__dict__) 4667db96d56Sopenharmony_ci self.assertIn('__gt__', C.__dict__) 4677db96d56Sopenharmony_ci self.assertIn('__ge__', C.__dict__) 4687db96d56Sopenharmony_ci elif result == 'eq_only': 4697db96d56Sopenharmony_ci self.assertIn('__eq__', C.__dict__) 4707db96d56Sopenharmony_ci self.assertNotIn('__lt__', C.__dict__) 4717db96d56Sopenharmony_ci self.assertNotIn('__le__', C.__dict__) 4727db96d56Sopenharmony_ci self.assertNotIn('__gt__', C.__dict__) 4737db96d56Sopenharmony_ci self.assertNotIn('__ge__', C.__dict__) 
4747db96d56Sopenharmony_ci else: 4757db96d56Sopenharmony_ci assert False, f'unknown result {result!r}' 4767db96d56Sopenharmony_ci 4777db96d56Sopenharmony_ci def test_field_no_default(self): 4787db96d56Sopenharmony_ci @dataclass 4797db96d56Sopenharmony_ci class C: 4807db96d56Sopenharmony_ci x: int = field() 4817db96d56Sopenharmony_ci 4827db96d56Sopenharmony_ci self.assertEqual(C(5).x, 5) 4837db96d56Sopenharmony_ci 4847db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 4857db96d56Sopenharmony_ci r"__init__\(\) missing 1 required " 4867db96d56Sopenharmony_ci "positional argument: 'x'"): 4877db96d56Sopenharmony_ci C() 4887db96d56Sopenharmony_ci 4897db96d56Sopenharmony_ci def test_field_default(self): 4907db96d56Sopenharmony_ci default = object() 4917db96d56Sopenharmony_ci @dataclass 4927db96d56Sopenharmony_ci class C: 4937db96d56Sopenharmony_ci x: object = field(default=default) 4947db96d56Sopenharmony_ci 4957db96d56Sopenharmony_ci self.assertIs(C.x, default) 4967db96d56Sopenharmony_ci c = C(10) 4977db96d56Sopenharmony_ci self.assertEqual(c.x, 10) 4987db96d56Sopenharmony_ci 4997db96d56Sopenharmony_ci # If we delete the instance attribute, we should then see the 5007db96d56Sopenharmony_ci # class attribute. 
5017db96d56Sopenharmony_ci del c.x 5027db96d56Sopenharmony_ci self.assertIs(c.x, default) 5037db96d56Sopenharmony_ci 5047db96d56Sopenharmony_ci self.assertIs(C().x, default) 5057db96d56Sopenharmony_ci 5067db96d56Sopenharmony_ci def test_not_in_repr(self): 5077db96d56Sopenharmony_ci @dataclass 5087db96d56Sopenharmony_ci class C: 5097db96d56Sopenharmony_ci x: int = field(repr=False) 5107db96d56Sopenharmony_ci with self.assertRaises(TypeError): 5117db96d56Sopenharmony_ci C() 5127db96d56Sopenharmony_ci c = C(10) 5137db96d56Sopenharmony_ci self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 5147db96d56Sopenharmony_ci 5157db96d56Sopenharmony_ci @dataclass 5167db96d56Sopenharmony_ci class C: 5177db96d56Sopenharmony_ci x: int = field(repr=False) 5187db96d56Sopenharmony_ci y: int 5197db96d56Sopenharmony_ci c = C(10, 20) 5207db96d56Sopenharmony_ci self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 5217db96d56Sopenharmony_ci 5227db96d56Sopenharmony_ci def test_not_in_compare(self): 5237db96d56Sopenharmony_ci @dataclass 5247db96d56Sopenharmony_ci class C: 5257db96d56Sopenharmony_ci x: int = 0 5267db96d56Sopenharmony_ci y: int = field(compare=False, default=4) 5277db96d56Sopenharmony_ci 5287db96d56Sopenharmony_ci self.assertEqual(C(), C(0, 20)) 5297db96d56Sopenharmony_ci self.assertEqual(C(1, 10), C(1, 20)) 5307db96d56Sopenharmony_ci self.assertNotEqual(C(3), C(4, 10)) 5317db96d56Sopenharmony_ci self.assertNotEqual(C(3, 10), C(4, 10)) 5327db96d56Sopenharmony_ci 5337db96d56Sopenharmony_ci def test_no_unhashable_default(self): 5347db96d56Sopenharmony_ci # See bpo-44674. 
5357db96d56Sopenharmony_ci class Unhashable: 5367db96d56Sopenharmony_ci __hash__ = None 5377db96d56Sopenharmony_ci 5387db96d56Sopenharmony_ci unhashable_re = 'mutable default .* for field a is not allowed' 5397db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, unhashable_re): 5407db96d56Sopenharmony_ci @dataclass 5417db96d56Sopenharmony_ci class A: 5427db96d56Sopenharmony_ci a: dict = {} 5437db96d56Sopenharmony_ci 5447db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, unhashable_re): 5457db96d56Sopenharmony_ci @dataclass 5467db96d56Sopenharmony_ci class A: 5477db96d56Sopenharmony_ci a: Any = Unhashable() 5487db96d56Sopenharmony_ci 5497db96d56Sopenharmony_ci # Make sure that the machinery looking for hashability is using the 5507db96d56Sopenharmony_ci # class's __hash__, not the instance's __hash__. 5517db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, unhashable_re): 5527db96d56Sopenharmony_ci unhashable = Unhashable() 5537db96d56Sopenharmony_ci # This shouldn't make the variable hashable. 
5547db96d56Sopenharmony_ci unhashable.__hash__ = lambda: 0 5557db96d56Sopenharmony_ci @dataclass 5567db96d56Sopenharmony_ci class A: 5577db96d56Sopenharmony_ci a: Any = unhashable 5587db96d56Sopenharmony_ci 5597db96d56Sopenharmony_ci def test_hash_field_rules(self): 5607db96d56Sopenharmony_ci # Test all 6 cases of: 5617db96d56Sopenharmony_ci # hash=True/False/None 5627db96d56Sopenharmony_ci # compare=True/False 5637db96d56Sopenharmony_ci for (hash_, compare, result ) in [ 5647db96d56Sopenharmony_ci (True, False, 'field' ), 5657db96d56Sopenharmony_ci (True, True, 'field' ), 5667db96d56Sopenharmony_ci (False, False, 'absent'), 5677db96d56Sopenharmony_ci (False, True, 'absent'), 5687db96d56Sopenharmony_ci (None, False, 'absent'), 5697db96d56Sopenharmony_ci (None, True, 'field' ), 5707db96d56Sopenharmony_ci ]: 5717db96d56Sopenharmony_ci with self.subTest(hash=hash_, compare=compare): 5727db96d56Sopenharmony_ci @dataclass(unsafe_hash=True) 5737db96d56Sopenharmony_ci class C: 5747db96d56Sopenharmony_ci x: int = field(compare=compare, hash=hash_, default=5) 5757db96d56Sopenharmony_ci 5767db96d56Sopenharmony_ci if result == 'field': 5777db96d56Sopenharmony_ci # __hash__ contains the field. 5787db96d56Sopenharmony_ci self.assertEqual(hash(C(5)), hash((5,))) 5797db96d56Sopenharmony_ci elif result == 'absent': 5807db96d56Sopenharmony_ci # The field is not present in the hash. 5817db96d56Sopenharmony_ci self.assertEqual(hash(C(5)), hash(())) 5827db96d56Sopenharmony_ci else: 5837db96d56Sopenharmony_ci assert False, f'unknown result {result!r}' 5847db96d56Sopenharmony_ci 5857db96d56Sopenharmony_ci def test_init_false_no_default(self): 5867db96d56Sopenharmony_ci # If init=False and no default value, then the field won't be 5877db96d56Sopenharmony_ci # present in the instance. 
5887db96d56Sopenharmony_ci @dataclass 5897db96d56Sopenharmony_ci class C: 5907db96d56Sopenharmony_ci x: int = field(init=False) 5917db96d56Sopenharmony_ci 5927db96d56Sopenharmony_ci self.assertNotIn('x', C().__dict__) 5937db96d56Sopenharmony_ci 5947db96d56Sopenharmony_ci @dataclass 5957db96d56Sopenharmony_ci class C: 5967db96d56Sopenharmony_ci x: int 5977db96d56Sopenharmony_ci y: int = 0 5987db96d56Sopenharmony_ci z: int = field(init=False) 5997db96d56Sopenharmony_ci t: int = 10 6007db96d56Sopenharmony_ci 6017db96d56Sopenharmony_ci self.assertNotIn('z', C(0).__dict__) 6027db96d56Sopenharmony_ci self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 6037db96d56Sopenharmony_ci 6047db96d56Sopenharmony_ci def test_class_marker(self): 6057db96d56Sopenharmony_ci @dataclass 6067db96d56Sopenharmony_ci class C: 6077db96d56Sopenharmony_ci x: int 6087db96d56Sopenharmony_ci y: str = field(init=False, default=None) 6097db96d56Sopenharmony_ci z: str = field(repr=False) 6107db96d56Sopenharmony_ci 6117db96d56Sopenharmony_ci the_fields = fields(C) 6127db96d56Sopenharmony_ci # the_fields is a tuple of 3 items, each value 6137db96d56Sopenharmony_ci # is in __annotations__. 
6147db96d56Sopenharmony_ci self.assertIsInstance(the_fields, tuple) 6157db96d56Sopenharmony_ci for f in the_fields: 6167db96d56Sopenharmony_ci self.assertIs(type(f), Field) 6177db96d56Sopenharmony_ci self.assertIn(f.name, C.__annotations__) 6187db96d56Sopenharmony_ci 6197db96d56Sopenharmony_ci self.assertEqual(len(the_fields), 3) 6207db96d56Sopenharmony_ci 6217db96d56Sopenharmony_ci self.assertEqual(the_fields[0].name, 'x') 6227db96d56Sopenharmony_ci self.assertEqual(the_fields[0].type, int) 6237db96d56Sopenharmony_ci self.assertFalse(hasattr(C, 'x')) 6247db96d56Sopenharmony_ci self.assertTrue (the_fields[0].init) 6257db96d56Sopenharmony_ci self.assertTrue (the_fields[0].repr) 6267db96d56Sopenharmony_ci self.assertEqual(the_fields[1].name, 'y') 6277db96d56Sopenharmony_ci self.assertEqual(the_fields[1].type, str) 6287db96d56Sopenharmony_ci self.assertIsNone(getattr(C, 'y')) 6297db96d56Sopenharmony_ci self.assertFalse(the_fields[1].init) 6307db96d56Sopenharmony_ci self.assertTrue (the_fields[1].repr) 6317db96d56Sopenharmony_ci self.assertEqual(the_fields[2].name, 'z') 6327db96d56Sopenharmony_ci self.assertEqual(the_fields[2].type, str) 6337db96d56Sopenharmony_ci self.assertFalse(hasattr(C, 'z')) 6347db96d56Sopenharmony_ci self.assertTrue (the_fields[2].init) 6357db96d56Sopenharmony_ci self.assertFalse(the_fields[2].repr) 6367db96d56Sopenharmony_ci 6377db96d56Sopenharmony_ci def test_field_order(self): 6387db96d56Sopenharmony_ci @dataclass 6397db96d56Sopenharmony_ci class B: 6407db96d56Sopenharmony_ci a: str = 'B:a' 6417db96d56Sopenharmony_ci b: str = 'B:b' 6427db96d56Sopenharmony_ci c: str = 'B:c' 6437db96d56Sopenharmony_ci 6447db96d56Sopenharmony_ci @dataclass 6457db96d56Sopenharmony_ci class C(B): 6467db96d56Sopenharmony_ci b: str = 'C:b' 6477db96d56Sopenharmony_ci 6487db96d56Sopenharmony_ci self.assertEqual([(f.name, f.default) for f in fields(C)], 6497db96d56Sopenharmony_ci [('a', 'B:a'), 6507db96d56Sopenharmony_ci ('b', 'C:b'), 6517db96d56Sopenharmony_ci ('c', 
'B:c')]) 6527db96d56Sopenharmony_ci 6537db96d56Sopenharmony_ci @dataclass 6547db96d56Sopenharmony_ci class D(B): 6557db96d56Sopenharmony_ci c: str = 'D:c' 6567db96d56Sopenharmony_ci 6577db96d56Sopenharmony_ci self.assertEqual([(f.name, f.default) for f in fields(D)], 6587db96d56Sopenharmony_ci [('a', 'B:a'), 6597db96d56Sopenharmony_ci ('b', 'B:b'), 6607db96d56Sopenharmony_ci ('c', 'D:c')]) 6617db96d56Sopenharmony_ci 6627db96d56Sopenharmony_ci @dataclass 6637db96d56Sopenharmony_ci class E(D): 6647db96d56Sopenharmony_ci a: str = 'E:a' 6657db96d56Sopenharmony_ci d: str = 'E:d' 6667db96d56Sopenharmony_ci 6677db96d56Sopenharmony_ci self.assertEqual([(f.name, f.default) for f in fields(E)], 6687db96d56Sopenharmony_ci [('a', 'E:a'), 6697db96d56Sopenharmony_ci ('b', 'B:b'), 6707db96d56Sopenharmony_ci ('c', 'D:c'), 6717db96d56Sopenharmony_ci ('d', 'E:d')]) 6727db96d56Sopenharmony_ci 6737db96d56Sopenharmony_ci def test_class_attrs(self): 6747db96d56Sopenharmony_ci # We only have a class attribute if a default value is 6757db96d56Sopenharmony_ci # specified, either directly or via a field with a default. 6767db96d56Sopenharmony_ci default = object() 6777db96d56Sopenharmony_ci @dataclass 6787db96d56Sopenharmony_ci class C: 6797db96d56Sopenharmony_ci x: int 6807db96d56Sopenharmony_ci y: int = field(repr=False) 6817db96d56Sopenharmony_ci z: object = default 6827db96d56Sopenharmony_ci t: int = field(default=100) 6837db96d56Sopenharmony_ci 6847db96d56Sopenharmony_ci self.assertFalse(hasattr(C, 'x')) 6857db96d56Sopenharmony_ci self.assertFalse(hasattr(C, 'y')) 6867db96d56Sopenharmony_ci self.assertIs (C.z, default) 6877db96d56Sopenharmony_ci self.assertEqual(C.t, 100) 6887db96d56Sopenharmony_ci 6897db96d56Sopenharmony_ci def test_disallowed_mutable_defaults(self): 6907db96d56Sopenharmony_ci # For the known types, don't allow mutable default values. 
6917db96d56Sopenharmony_ci for typ, empty, non_empty in [(list, [], [1]), 6927db96d56Sopenharmony_ci (dict, {}, {0:1}), 6937db96d56Sopenharmony_ci (set, set(), set([1])), 6947db96d56Sopenharmony_ci ]: 6957db96d56Sopenharmony_ci with self.subTest(typ=typ): 6967db96d56Sopenharmony_ci # Can't use a zero-length value. 6977db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 6987db96d56Sopenharmony_ci f'mutable default {typ} for field ' 6997db96d56Sopenharmony_ci 'x is not allowed'): 7007db96d56Sopenharmony_ci @dataclass 7017db96d56Sopenharmony_ci class Point: 7027db96d56Sopenharmony_ci x: typ = empty 7037db96d56Sopenharmony_ci 7047db96d56Sopenharmony_ci 7057db96d56Sopenharmony_ci # Nor a non-zero-length value 7067db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 7077db96d56Sopenharmony_ci f'mutable default {typ} for field ' 7087db96d56Sopenharmony_ci 'y is not allowed'): 7097db96d56Sopenharmony_ci @dataclass 7107db96d56Sopenharmony_ci class Point: 7117db96d56Sopenharmony_ci y: typ = non_empty 7127db96d56Sopenharmony_ci 7137db96d56Sopenharmony_ci # Check subtypes also fail. 7147db96d56Sopenharmony_ci class Subclass(typ): pass 7157db96d56Sopenharmony_ci 7167db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 7177db96d56Sopenharmony_ci f"mutable default .*Subclass'>" 7187db96d56Sopenharmony_ci ' for field z is not allowed' 7197db96d56Sopenharmony_ci ): 7207db96d56Sopenharmony_ci @dataclass 7217db96d56Sopenharmony_ci class Point: 7227db96d56Sopenharmony_ci z: typ = Subclass() 7237db96d56Sopenharmony_ci 7247db96d56Sopenharmony_ci # Because this is a ClassVar, it can be mutable. 7257db96d56Sopenharmony_ci @dataclass 7267db96d56Sopenharmony_ci class C: 7277db96d56Sopenharmony_ci z: ClassVar[typ] = typ() 7287db96d56Sopenharmony_ci 7297db96d56Sopenharmony_ci # Because this is a ClassVar, it can be mutable. 
7307db96d56Sopenharmony_ci @dataclass 7317db96d56Sopenharmony_ci class C: 7327db96d56Sopenharmony_ci x: ClassVar[typ] = Subclass() 7337db96d56Sopenharmony_ci 7347db96d56Sopenharmony_ci def test_deliberately_mutable_defaults(self): 7357db96d56Sopenharmony_ci # If a mutable default isn't in the known list of 7367db96d56Sopenharmony_ci # (list, dict, set), then it's okay. 7377db96d56Sopenharmony_ci class Mutable: 7387db96d56Sopenharmony_ci def __init__(self): 7397db96d56Sopenharmony_ci self.l = [] 7407db96d56Sopenharmony_ci 7417db96d56Sopenharmony_ci @dataclass 7427db96d56Sopenharmony_ci class C: 7437db96d56Sopenharmony_ci x: Mutable 7447db96d56Sopenharmony_ci 7457db96d56Sopenharmony_ci # These 2 instances will share this value of x. 7467db96d56Sopenharmony_ci lst = Mutable() 7477db96d56Sopenharmony_ci o1 = C(lst) 7487db96d56Sopenharmony_ci o2 = C(lst) 7497db96d56Sopenharmony_ci self.assertEqual(o1, o2) 7507db96d56Sopenharmony_ci o1.x.l.extend([1, 2]) 7517db96d56Sopenharmony_ci self.assertEqual(o1, o2) 7527db96d56Sopenharmony_ci self.assertEqual(o1.x.l, [1, 2]) 7537db96d56Sopenharmony_ci self.assertIs(o1.x, o2.x) 7547db96d56Sopenharmony_ci 7557db96d56Sopenharmony_ci def test_no_options(self): 7567db96d56Sopenharmony_ci # Call with dataclass(). 7577db96d56Sopenharmony_ci @dataclass() 7587db96d56Sopenharmony_ci class C: 7597db96d56Sopenharmony_ci x: int 7607db96d56Sopenharmony_ci 7617db96d56Sopenharmony_ci self.assertEqual(C(42).x, 42) 7627db96d56Sopenharmony_ci 7637db96d56Sopenharmony_ci def test_not_tuple(self): 7647db96d56Sopenharmony_ci # Make sure we can't be compared to a tuple. 7657db96d56Sopenharmony_ci @dataclass 7667db96d56Sopenharmony_ci class Point: 7677db96d56Sopenharmony_ci x: int 7687db96d56Sopenharmony_ci y: int 7697db96d56Sopenharmony_ci self.assertNotEqual(Point(1, 2), (1, 2)) 7707db96d56Sopenharmony_ci 7717db96d56Sopenharmony_ci # And that we can't compare to another unrelated dataclass. 
7727db96d56Sopenharmony_ci @dataclass 7737db96d56Sopenharmony_ci class C: 7747db96d56Sopenharmony_ci x: int 7757db96d56Sopenharmony_ci y: int 7767db96d56Sopenharmony_ci self.assertNotEqual(Point(1, 3), C(1, 3)) 7777db96d56Sopenharmony_ci 7787db96d56Sopenharmony_ci def test_not_other_dataclass(self): 7797db96d56Sopenharmony_ci # Test that some of the problems with namedtuple don't happen 7807db96d56Sopenharmony_ci # here. 7817db96d56Sopenharmony_ci @dataclass 7827db96d56Sopenharmony_ci class Point3D: 7837db96d56Sopenharmony_ci x: int 7847db96d56Sopenharmony_ci y: int 7857db96d56Sopenharmony_ci z: int 7867db96d56Sopenharmony_ci 7877db96d56Sopenharmony_ci @dataclass 7887db96d56Sopenharmony_ci class Date: 7897db96d56Sopenharmony_ci year: int 7907db96d56Sopenharmony_ci month: int 7917db96d56Sopenharmony_ci day: int 7927db96d56Sopenharmony_ci 7937db96d56Sopenharmony_ci self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 7947db96d56Sopenharmony_ci self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 7957db96d56Sopenharmony_ci 7967db96d56Sopenharmony_ci # Make sure we can't unpack. 7977db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'unpack'): 7987db96d56Sopenharmony_ci x, y, z = Point3D(4, 5, 6) 7997db96d56Sopenharmony_ci 8007db96d56Sopenharmony_ci # Make sure another class with the same field names isn't 8017db96d56Sopenharmony_ci # equal. 8027db96d56Sopenharmony_ci @dataclass 8037db96d56Sopenharmony_ci class Point3Dv1: 8047db96d56Sopenharmony_ci x: int = 0 8057db96d56Sopenharmony_ci y: int = 0 8067db96d56Sopenharmony_ci z: int = 0 8077db96d56Sopenharmony_ci self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 8087db96d56Sopenharmony_ci 8097db96d56Sopenharmony_ci def test_function_annotations(self): 8107db96d56Sopenharmony_ci # Some dummy class and instance to use as a default. 
8117db96d56Sopenharmony_ci class F: 8127db96d56Sopenharmony_ci pass 8137db96d56Sopenharmony_ci f = F() 8147db96d56Sopenharmony_ci 8157db96d56Sopenharmony_ci def validate_class(cls): 8167db96d56Sopenharmony_ci # First, check __annotations__, even though they're not 8177db96d56Sopenharmony_ci # function annotations. 8187db96d56Sopenharmony_ci self.assertEqual(cls.__annotations__['i'], int) 8197db96d56Sopenharmony_ci self.assertEqual(cls.__annotations__['j'], str) 8207db96d56Sopenharmony_ci self.assertEqual(cls.__annotations__['k'], F) 8217db96d56Sopenharmony_ci self.assertEqual(cls.__annotations__['l'], float) 8227db96d56Sopenharmony_ci self.assertEqual(cls.__annotations__['z'], complex) 8237db96d56Sopenharmony_ci 8247db96d56Sopenharmony_ci # Verify __init__. 8257db96d56Sopenharmony_ci 8267db96d56Sopenharmony_ci signature = inspect.signature(cls.__init__) 8277db96d56Sopenharmony_ci # Check the return type, should be None. 8287db96d56Sopenharmony_ci self.assertIs(signature.return_annotation, None) 8297db96d56Sopenharmony_ci 8307db96d56Sopenharmony_ci # Check each parameter. 8317db96d56Sopenharmony_ci params = iter(signature.parameters.values()) 8327db96d56Sopenharmony_ci param = next(params) 8337db96d56Sopenharmony_ci # This is testing an internal name, and probably shouldn't be tested. 
8347db96d56Sopenharmony_ci self.assertEqual(param.name, 'self') 8357db96d56Sopenharmony_ci param = next(params) 8367db96d56Sopenharmony_ci self.assertEqual(param.name, 'i') 8377db96d56Sopenharmony_ci self.assertIs (param.annotation, int) 8387db96d56Sopenharmony_ci self.assertEqual(param.default, inspect.Parameter.empty) 8397db96d56Sopenharmony_ci self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 8407db96d56Sopenharmony_ci param = next(params) 8417db96d56Sopenharmony_ci self.assertEqual(param.name, 'j') 8427db96d56Sopenharmony_ci self.assertIs (param.annotation, str) 8437db96d56Sopenharmony_ci self.assertEqual(param.default, inspect.Parameter.empty) 8447db96d56Sopenharmony_ci self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 8457db96d56Sopenharmony_ci param = next(params) 8467db96d56Sopenharmony_ci self.assertEqual(param.name, 'k') 8477db96d56Sopenharmony_ci self.assertIs (param.annotation, F) 8487db96d56Sopenharmony_ci # Don't test for the default, since it's set to MISSING. 8497db96d56Sopenharmony_ci self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 8507db96d56Sopenharmony_ci param = next(params) 8517db96d56Sopenharmony_ci self.assertEqual(param.name, 'l') 8527db96d56Sopenharmony_ci self.assertIs (param.annotation, float) 8537db96d56Sopenharmony_ci # Don't test for the default, since it's set to MISSING. 
8547db96d56Sopenharmony_ci self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 8557db96d56Sopenharmony_ci self.assertRaises(StopIteration, next, params) 8567db96d56Sopenharmony_ci 8577db96d56Sopenharmony_ci 8587db96d56Sopenharmony_ci @dataclass 8597db96d56Sopenharmony_ci class C: 8607db96d56Sopenharmony_ci i: int 8617db96d56Sopenharmony_ci j: str 8627db96d56Sopenharmony_ci k: F = f 8637db96d56Sopenharmony_ci l: float=field(default=None) 8647db96d56Sopenharmony_ci z: complex=field(default=3+4j, init=False) 8657db96d56Sopenharmony_ci 8667db96d56Sopenharmony_ci validate_class(C) 8677db96d56Sopenharmony_ci 8687db96d56Sopenharmony_ci # Now repeat with __hash__. 8697db96d56Sopenharmony_ci @dataclass(frozen=True, unsafe_hash=True) 8707db96d56Sopenharmony_ci class C: 8717db96d56Sopenharmony_ci i: int 8727db96d56Sopenharmony_ci j: str 8737db96d56Sopenharmony_ci k: F = f 8747db96d56Sopenharmony_ci l: float=field(default=None) 8757db96d56Sopenharmony_ci z: complex=field(default=3+4j, init=False) 8767db96d56Sopenharmony_ci 8777db96d56Sopenharmony_ci validate_class(C) 8787db96d56Sopenharmony_ci 8797db96d56Sopenharmony_ci def test_missing_default(self): 8807db96d56Sopenharmony_ci # Test that MISSING works the same as a default not being 8817db96d56Sopenharmony_ci # specified. 
8827db96d56Sopenharmony_ci @dataclass 8837db96d56Sopenharmony_ci class C: 8847db96d56Sopenharmony_ci x: int=field(default=MISSING) 8857db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 8867db96d56Sopenharmony_ci r'__init__\(\) missing 1 required ' 8877db96d56Sopenharmony_ci 'positional argument'): 8887db96d56Sopenharmony_ci C() 8897db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 8907db96d56Sopenharmony_ci 8917db96d56Sopenharmony_ci @dataclass 8927db96d56Sopenharmony_ci class D: 8937db96d56Sopenharmony_ci x: int 8947db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 8957db96d56Sopenharmony_ci r'__init__\(\) missing 1 required ' 8967db96d56Sopenharmony_ci 'positional argument'): 8977db96d56Sopenharmony_ci D() 8987db96d56Sopenharmony_ci self.assertNotIn('x', D.__dict__) 8997db96d56Sopenharmony_ci 9007db96d56Sopenharmony_ci def test_missing_default_factory(self): 9017db96d56Sopenharmony_ci # Test that MISSING works the same as a default factory not 9027db96d56Sopenharmony_ci # being specified (which is really the same as a default not 9037db96d56Sopenharmony_ci # being specified, too). 
9047db96d56Sopenharmony_ci @dataclass 9057db96d56Sopenharmony_ci class C: 9067db96d56Sopenharmony_ci x: int=field(default_factory=MISSING) 9077db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 9087db96d56Sopenharmony_ci r'__init__\(\) missing 1 required ' 9097db96d56Sopenharmony_ci 'positional argument'): 9107db96d56Sopenharmony_ci C() 9117db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 9127db96d56Sopenharmony_ci 9137db96d56Sopenharmony_ci @dataclass 9147db96d56Sopenharmony_ci class D: 9157db96d56Sopenharmony_ci x: int=field(default=MISSING, default_factory=MISSING) 9167db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 9177db96d56Sopenharmony_ci r'__init__\(\) missing 1 required ' 9187db96d56Sopenharmony_ci 'positional argument'): 9197db96d56Sopenharmony_ci D() 9207db96d56Sopenharmony_ci self.assertNotIn('x', D.__dict__) 9217db96d56Sopenharmony_ci 9227db96d56Sopenharmony_ci def test_missing_repr(self): 9237db96d56Sopenharmony_ci self.assertIn('MISSING_TYPE object', repr(MISSING)) 9247db96d56Sopenharmony_ci 9257db96d56Sopenharmony_ci def test_dont_include_other_annotations(self): 9267db96d56Sopenharmony_ci @dataclass 9277db96d56Sopenharmony_ci class C: 9287db96d56Sopenharmony_ci i: int 9297db96d56Sopenharmony_ci def foo(self) -> int: 9307db96d56Sopenharmony_ci return 4 9317db96d56Sopenharmony_ci @property 9327db96d56Sopenharmony_ci def bar(self) -> int: 9337db96d56Sopenharmony_ci return 5 9347db96d56Sopenharmony_ci self.assertEqual(list(C.__annotations__), ['i']) 9357db96d56Sopenharmony_ci self.assertEqual(C(10).foo(), 4) 9367db96d56Sopenharmony_ci self.assertEqual(C(10).bar, 5) 9377db96d56Sopenharmony_ci self.assertEqual(C(10).i, 10) 9387db96d56Sopenharmony_ci 9397db96d56Sopenharmony_ci def test_post_init(self): 9407db96d56Sopenharmony_ci # Just make sure it gets called 9417db96d56Sopenharmony_ci @dataclass 9427db96d56Sopenharmony_ci class C: 9437db96d56Sopenharmony_ci def __post_init__(self): 9447db96d56Sopenharmony_ci raise 
CustomError() 9457db96d56Sopenharmony_ci with self.assertRaises(CustomError): 9467db96d56Sopenharmony_ci C() 9477db96d56Sopenharmony_ci 9487db96d56Sopenharmony_ci @dataclass 9497db96d56Sopenharmony_ci class C: 9507db96d56Sopenharmony_ci i: int = 10 9517db96d56Sopenharmony_ci def __post_init__(self): 9527db96d56Sopenharmony_ci if self.i == 10: 9537db96d56Sopenharmony_ci raise CustomError() 9547db96d56Sopenharmony_ci with self.assertRaises(CustomError): 9557db96d56Sopenharmony_ci C() 9567db96d56Sopenharmony_ci # post-init gets called, but doesn't raise. This is just 9577db96d56Sopenharmony_ci # checking that self is used correctly. 9587db96d56Sopenharmony_ci C(5) 9597db96d56Sopenharmony_ci 9607db96d56Sopenharmony_ci # If there's not an __init__, then post-init won't get called. 9617db96d56Sopenharmony_ci @dataclass(init=False) 9627db96d56Sopenharmony_ci class C: 9637db96d56Sopenharmony_ci def __post_init__(self): 9647db96d56Sopenharmony_ci raise CustomError() 9657db96d56Sopenharmony_ci # Creating the class won't raise 9667db96d56Sopenharmony_ci C() 9677db96d56Sopenharmony_ci 9687db96d56Sopenharmony_ci @dataclass 9697db96d56Sopenharmony_ci class C: 9707db96d56Sopenharmony_ci x: int = 0 9717db96d56Sopenharmony_ci def __post_init__(self): 9727db96d56Sopenharmony_ci self.x *= 2 9737db96d56Sopenharmony_ci self.assertEqual(C().x, 0) 9747db96d56Sopenharmony_ci self.assertEqual(C(2).x, 4) 9757db96d56Sopenharmony_ci 9767db96d56Sopenharmony_ci # Make sure that if we're frozen, post-init can't set 9777db96d56Sopenharmony_ci # attributes. 
9787db96d56Sopenharmony_ci @dataclass(frozen=True) 9797db96d56Sopenharmony_ci class C: 9807db96d56Sopenharmony_ci x: int = 0 9817db96d56Sopenharmony_ci def __post_init__(self): 9827db96d56Sopenharmony_ci self.x *= 2 9837db96d56Sopenharmony_ci with self.assertRaises(FrozenInstanceError): 9847db96d56Sopenharmony_ci C() 9857db96d56Sopenharmony_ci 9867db96d56Sopenharmony_ci def test_post_init_super(self): 9877db96d56Sopenharmony_ci # Make sure super() post-init isn't called by default. 9887db96d56Sopenharmony_ci class B: 9897db96d56Sopenharmony_ci def __post_init__(self): 9907db96d56Sopenharmony_ci raise CustomError() 9917db96d56Sopenharmony_ci 9927db96d56Sopenharmony_ci @dataclass 9937db96d56Sopenharmony_ci class C(B): 9947db96d56Sopenharmony_ci def __post_init__(self): 9957db96d56Sopenharmony_ci self.x = 5 9967db96d56Sopenharmony_ci 9977db96d56Sopenharmony_ci self.assertEqual(C().x, 5) 9987db96d56Sopenharmony_ci 9997db96d56Sopenharmony_ci # Now call super(), and it will raise. 10007db96d56Sopenharmony_ci @dataclass 10017db96d56Sopenharmony_ci class C(B): 10027db96d56Sopenharmony_ci def __post_init__(self): 10037db96d56Sopenharmony_ci super().__post_init__() 10047db96d56Sopenharmony_ci 10057db96d56Sopenharmony_ci with self.assertRaises(CustomError): 10067db96d56Sopenharmony_ci C() 10077db96d56Sopenharmony_ci 10087db96d56Sopenharmony_ci # Make sure post-init is called, even if not defined in our 10097db96d56Sopenharmony_ci # class. 
10107db96d56Sopenharmony_ci @dataclass 10117db96d56Sopenharmony_ci class C(B): 10127db96d56Sopenharmony_ci pass 10137db96d56Sopenharmony_ci 10147db96d56Sopenharmony_ci with self.assertRaises(CustomError): 10157db96d56Sopenharmony_ci C() 10167db96d56Sopenharmony_ci 10177db96d56Sopenharmony_ci def test_post_init_staticmethod(self): 10187db96d56Sopenharmony_ci flag = False 10197db96d56Sopenharmony_ci @dataclass 10207db96d56Sopenharmony_ci class C: 10217db96d56Sopenharmony_ci x: int 10227db96d56Sopenharmony_ci y: int 10237db96d56Sopenharmony_ci @staticmethod 10247db96d56Sopenharmony_ci def __post_init__(): 10257db96d56Sopenharmony_ci nonlocal flag 10267db96d56Sopenharmony_ci flag = True 10277db96d56Sopenharmony_ci 10287db96d56Sopenharmony_ci self.assertFalse(flag) 10297db96d56Sopenharmony_ci c = C(3, 4) 10307db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (3, 4)) 10317db96d56Sopenharmony_ci self.assertTrue(flag) 10327db96d56Sopenharmony_ci 10337db96d56Sopenharmony_ci def test_post_init_classmethod(self): 10347db96d56Sopenharmony_ci @dataclass 10357db96d56Sopenharmony_ci class C: 10367db96d56Sopenharmony_ci flag = False 10377db96d56Sopenharmony_ci x: int 10387db96d56Sopenharmony_ci y: int 10397db96d56Sopenharmony_ci @classmethod 10407db96d56Sopenharmony_ci def __post_init__(cls): 10417db96d56Sopenharmony_ci cls.flag = True 10427db96d56Sopenharmony_ci 10437db96d56Sopenharmony_ci self.assertFalse(C.flag) 10447db96d56Sopenharmony_ci c = C(3, 4) 10457db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (3, 4)) 10467db96d56Sopenharmony_ci self.assertTrue(C.flag) 10477db96d56Sopenharmony_ci 10487db96d56Sopenharmony_ci def test_post_init_not_auto_added(self): 10497db96d56Sopenharmony_ci # See bpo-46757, which had proposed always adding __post_init__. As 10507db96d56Sopenharmony_ci # Raymond Hettinger pointed out, that would be a breaking change. So, 10517db96d56Sopenharmony_ci # add a test to make sure that the current behavior doesn't change. 
10527db96d56Sopenharmony_ci 10537db96d56Sopenharmony_ci @dataclass 10547db96d56Sopenharmony_ci class A0: 10557db96d56Sopenharmony_ci pass 10567db96d56Sopenharmony_ci 10577db96d56Sopenharmony_ci @dataclass 10587db96d56Sopenharmony_ci class B0: 10597db96d56Sopenharmony_ci b_called: bool = False 10607db96d56Sopenharmony_ci def __post_init__(self): 10617db96d56Sopenharmony_ci self.b_called = True 10627db96d56Sopenharmony_ci 10637db96d56Sopenharmony_ci @dataclass 10647db96d56Sopenharmony_ci class C0(A0, B0): 10657db96d56Sopenharmony_ci c_called: bool = False 10667db96d56Sopenharmony_ci def __post_init__(self): 10677db96d56Sopenharmony_ci super().__post_init__() 10687db96d56Sopenharmony_ci self.c_called = True 10697db96d56Sopenharmony_ci 10707db96d56Sopenharmony_ci # Since A0 has no __post_init__, and one wasn't automatically added 10717db96d56Sopenharmony_ci # (because that's the rule: it's never added by @dataclass, it's only 10727db96d56Sopenharmony_ci # the class author that can add it), then B0.__post_init__ is called. 10737db96d56Sopenharmony_ci # Verify that. 10747db96d56Sopenharmony_ci c = C0() 10757db96d56Sopenharmony_ci self.assertTrue(c.b_called) 10767db96d56Sopenharmony_ci self.assertTrue(c.c_called) 10777db96d56Sopenharmony_ci 10787db96d56Sopenharmony_ci ###################################### 10797db96d56Sopenharmony_ci # Now, the same thing, except A1 defines __post_init__. 
10807db96d56Sopenharmony_ci @dataclass 10817db96d56Sopenharmony_ci class A1: 10827db96d56Sopenharmony_ci def __post_init__(self): 10837db96d56Sopenharmony_ci pass 10847db96d56Sopenharmony_ci 10857db96d56Sopenharmony_ci @dataclass 10867db96d56Sopenharmony_ci class B1: 10877db96d56Sopenharmony_ci b_called: bool = False 10887db96d56Sopenharmony_ci def __post_init__(self): 10897db96d56Sopenharmony_ci self.b_called = True 10907db96d56Sopenharmony_ci 10917db96d56Sopenharmony_ci @dataclass 10927db96d56Sopenharmony_ci class C1(A1, B1): 10937db96d56Sopenharmony_ci c_called: bool = False 10947db96d56Sopenharmony_ci def __post_init__(self): 10957db96d56Sopenharmony_ci super().__post_init__() 10967db96d56Sopenharmony_ci self.c_called = True 10977db96d56Sopenharmony_ci 10987db96d56Sopenharmony_ci # This time, B1.__post_init__ isn't being called. This mimics what 10997db96d56Sopenharmony_ci # would happen if A1.__post_init__ had been automatically added, 11007db96d56Sopenharmony_ci # instead of manually added as we see here. This test isn't really 11017db96d56Sopenharmony_ci # needed, but I'm including it just to demonstrate the changed 11027db96d56Sopenharmony_ci # behavior when A1 does define __post_init__. 11037db96d56Sopenharmony_ci c = C1() 11047db96d56Sopenharmony_ci self.assertFalse(c.b_called) 11057db96d56Sopenharmony_ci self.assertTrue(c.c_called) 11067db96d56Sopenharmony_ci 11077db96d56Sopenharmony_ci def test_class_var(self): 11087db96d56Sopenharmony_ci # Make sure ClassVars are ignored in __init__, __repr__, etc. 
11097db96d56Sopenharmony_ci @dataclass 11107db96d56Sopenharmony_ci class C: 11117db96d56Sopenharmony_ci x: int 11127db96d56Sopenharmony_ci y: int = 10 11137db96d56Sopenharmony_ci z: ClassVar[int] = 1000 11147db96d56Sopenharmony_ci w: ClassVar[int] = 2000 11157db96d56Sopenharmony_ci t: ClassVar[int] = 3000 11167db96d56Sopenharmony_ci s: ClassVar = 4000 11177db96d56Sopenharmony_ci 11187db96d56Sopenharmony_ci c = C(5) 11197db96d56Sopenharmony_ci self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 11207db96d56Sopenharmony_ci self.assertEqual(len(fields(C)), 2) # We have 2 fields. 11217db96d56Sopenharmony_ci self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 11227db96d56Sopenharmony_ci self.assertEqual(c.z, 1000) 11237db96d56Sopenharmony_ci self.assertEqual(c.w, 2000) 11247db96d56Sopenharmony_ci self.assertEqual(c.t, 3000) 11257db96d56Sopenharmony_ci self.assertEqual(c.s, 4000) 11267db96d56Sopenharmony_ci C.z += 1 11277db96d56Sopenharmony_ci self.assertEqual(c.z, 1001) 11287db96d56Sopenharmony_ci c = C(20) 11297db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (20, 10)) 11307db96d56Sopenharmony_ci self.assertEqual(c.z, 1001) 11317db96d56Sopenharmony_ci self.assertEqual(c.w, 2000) 11327db96d56Sopenharmony_ci self.assertEqual(c.t, 3000) 11337db96d56Sopenharmony_ci self.assertEqual(c.s, 4000) 11347db96d56Sopenharmony_ci 11357db96d56Sopenharmony_ci def test_class_var_no_default(self): 11367db96d56Sopenharmony_ci # If a ClassVar has no default value, it should not be set on the class. 11377db96d56Sopenharmony_ci @dataclass 11387db96d56Sopenharmony_ci class C: 11397db96d56Sopenharmony_ci x: ClassVar[int] 11407db96d56Sopenharmony_ci 11417db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 11427db96d56Sopenharmony_ci 11437db96d56Sopenharmony_ci def test_class_var_default_factory(self): 11447db96d56Sopenharmony_ci # It makes no sense for a ClassVar to have a default factory. When 11457db96d56Sopenharmony_ci # would it be called? 
Call it yourself, since it's class-wide. 11467db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 11477db96d56Sopenharmony_ci 'cannot have a default factory'): 11487db96d56Sopenharmony_ci @dataclass 11497db96d56Sopenharmony_ci class C: 11507db96d56Sopenharmony_ci x: ClassVar[int] = field(default_factory=int) 11517db96d56Sopenharmony_ci 11527db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 11537db96d56Sopenharmony_ci 11547db96d56Sopenharmony_ci def test_class_var_with_default(self): 11557db96d56Sopenharmony_ci # If a ClassVar has a default value, it should be set on the class. 11567db96d56Sopenharmony_ci @dataclass 11577db96d56Sopenharmony_ci class C: 11587db96d56Sopenharmony_ci x: ClassVar[int] = 10 11597db96d56Sopenharmony_ci self.assertEqual(C.x, 10) 11607db96d56Sopenharmony_ci 11617db96d56Sopenharmony_ci @dataclass 11627db96d56Sopenharmony_ci class C: 11637db96d56Sopenharmony_ci x: ClassVar[int] = field(default=10) 11647db96d56Sopenharmony_ci self.assertEqual(C.x, 10) 11657db96d56Sopenharmony_ci 11667db96d56Sopenharmony_ci def test_class_var_frozen(self): 11677db96d56Sopenharmony_ci # Make sure ClassVars work even if we're frozen. 
11687db96d56Sopenharmony_ci @dataclass(frozen=True) 11697db96d56Sopenharmony_ci class C: 11707db96d56Sopenharmony_ci x: int 11717db96d56Sopenharmony_ci y: int = 10 11727db96d56Sopenharmony_ci z: ClassVar[int] = 1000 11737db96d56Sopenharmony_ci w: ClassVar[int] = 2000 11747db96d56Sopenharmony_ci t: ClassVar[int] = 3000 11757db96d56Sopenharmony_ci 11767db96d56Sopenharmony_ci c = C(5) 11777db96d56Sopenharmony_ci self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 11787db96d56Sopenharmony_ci self.assertEqual(len(fields(C)), 2) # We have 2 fields 11797db96d56Sopenharmony_ci self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 11807db96d56Sopenharmony_ci self.assertEqual(c.z, 1000) 11817db96d56Sopenharmony_ci self.assertEqual(c.w, 2000) 11827db96d56Sopenharmony_ci self.assertEqual(c.t, 3000) 11837db96d56Sopenharmony_ci # We can still modify the ClassVar, it's only instances that are 11847db96d56Sopenharmony_ci # frozen. 11857db96d56Sopenharmony_ci C.z += 1 11867db96d56Sopenharmony_ci self.assertEqual(c.z, 1001) 11877db96d56Sopenharmony_ci c = C(20) 11887db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (20, 10)) 11897db96d56Sopenharmony_ci self.assertEqual(c.z, 1001) 11907db96d56Sopenharmony_ci self.assertEqual(c.w, 2000) 11917db96d56Sopenharmony_ci self.assertEqual(c.t, 3000) 11927db96d56Sopenharmony_ci 11937db96d56Sopenharmony_ci def test_init_var_no_default(self): 11947db96d56Sopenharmony_ci # If an InitVar has no default value, it should not be set on the class. 11957db96d56Sopenharmony_ci @dataclass 11967db96d56Sopenharmony_ci class C: 11977db96d56Sopenharmony_ci x: InitVar[int] 11987db96d56Sopenharmony_ci 11997db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 12007db96d56Sopenharmony_ci 12017db96d56Sopenharmony_ci def test_init_var_default_factory(self): 12027db96d56Sopenharmony_ci # It makes no sense for an InitVar to have a default factory. When 12037db96d56Sopenharmony_ci # would it be called? 
Call it yourself, since it's class-wide. 12047db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 12057db96d56Sopenharmony_ci 'cannot have a default factory'): 12067db96d56Sopenharmony_ci @dataclass 12077db96d56Sopenharmony_ci class C: 12087db96d56Sopenharmony_ci x: InitVar[int] = field(default_factory=int) 12097db96d56Sopenharmony_ci 12107db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 12117db96d56Sopenharmony_ci 12127db96d56Sopenharmony_ci def test_init_var_with_default(self): 12137db96d56Sopenharmony_ci # If an InitVar has a default value, it should be set on the class. 12147db96d56Sopenharmony_ci @dataclass 12157db96d56Sopenharmony_ci class C: 12167db96d56Sopenharmony_ci x: InitVar[int] = 10 12177db96d56Sopenharmony_ci self.assertEqual(C.x, 10) 12187db96d56Sopenharmony_ci 12197db96d56Sopenharmony_ci @dataclass 12207db96d56Sopenharmony_ci class C: 12217db96d56Sopenharmony_ci x: InitVar[int] = field(default=10) 12227db96d56Sopenharmony_ci self.assertEqual(C.x, 10) 12237db96d56Sopenharmony_ci 12247db96d56Sopenharmony_ci def test_init_var(self): 12257db96d56Sopenharmony_ci @dataclass 12267db96d56Sopenharmony_ci class C: 12277db96d56Sopenharmony_ci x: int = None 12287db96d56Sopenharmony_ci init_param: InitVar[int] = None 12297db96d56Sopenharmony_ci 12307db96d56Sopenharmony_ci def __post_init__(self, init_param): 12317db96d56Sopenharmony_ci if self.x is None: 12327db96d56Sopenharmony_ci self.x = init_param*2 12337db96d56Sopenharmony_ci 12347db96d56Sopenharmony_ci c = C(init_param=10) 12357db96d56Sopenharmony_ci self.assertEqual(c.x, 20) 12367db96d56Sopenharmony_ci 12377db96d56Sopenharmony_ci def test_init_var_preserve_type(self): 12387db96d56Sopenharmony_ci self.assertEqual(InitVar[int].type, int) 12397db96d56Sopenharmony_ci 12407db96d56Sopenharmony_ci # Make sure the repr is correct. 
12417db96d56Sopenharmony_ci self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') 12427db96d56Sopenharmony_ci self.assertEqual(repr(InitVar[List[int]]), 12437db96d56Sopenharmony_ci 'dataclasses.InitVar[typing.List[int]]') 12447db96d56Sopenharmony_ci self.assertEqual(repr(InitVar[list[int]]), 12457db96d56Sopenharmony_ci 'dataclasses.InitVar[list[int]]') 12467db96d56Sopenharmony_ci self.assertEqual(repr(InitVar[int|str]), 12477db96d56Sopenharmony_ci 'dataclasses.InitVar[int | str]') 12487db96d56Sopenharmony_ci 12497db96d56Sopenharmony_ci def test_init_var_inheritance(self): 12507db96d56Sopenharmony_ci # Note that this deliberately tests that a dataclass need not 12517db96d56Sopenharmony_ci # have a __post_init__ function if it has an InitVar field. 12527db96d56Sopenharmony_ci # It could just be used in a derived class, as shown here. 12537db96d56Sopenharmony_ci @dataclass 12547db96d56Sopenharmony_ci class Base: 12557db96d56Sopenharmony_ci x: int 12567db96d56Sopenharmony_ci init_base: InitVar[int] 12577db96d56Sopenharmony_ci 12587db96d56Sopenharmony_ci # We can instantiate by passing the InitVar, even though 12597db96d56Sopenharmony_ci # it's not used. 
12607db96d56Sopenharmony_ci b = Base(0, 10) 12617db96d56Sopenharmony_ci self.assertEqual(vars(b), {'x': 0}) 12627db96d56Sopenharmony_ci 12637db96d56Sopenharmony_ci @dataclass 12647db96d56Sopenharmony_ci class C(Base): 12657db96d56Sopenharmony_ci y: int 12667db96d56Sopenharmony_ci init_derived: InitVar[int] 12677db96d56Sopenharmony_ci 12687db96d56Sopenharmony_ci def __post_init__(self, init_base, init_derived): 12697db96d56Sopenharmony_ci self.x = self.x + init_base 12707db96d56Sopenharmony_ci self.y = self.y + init_derived 12717db96d56Sopenharmony_ci 12727db96d56Sopenharmony_ci c = C(10, 11, 50, 51) 12737db96d56Sopenharmony_ci self.assertEqual(vars(c), {'x': 21, 'y': 101}) 12747db96d56Sopenharmony_ci 12757db96d56Sopenharmony_ci def test_default_factory(self): 12767db96d56Sopenharmony_ci # Test a factory that returns a new list. 12777db96d56Sopenharmony_ci @dataclass 12787db96d56Sopenharmony_ci class C: 12797db96d56Sopenharmony_ci x: int 12807db96d56Sopenharmony_ci y: list = field(default_factory=list) 12817db96d56Sopenharmony_ci 12827db96d56Sopenharmony_ci c0 = C(3) 12837db96d56Sopenharmony_ci c1 = C(3) 12847db96d56Sopenharmony_ci self.assertEqual(c0.x, 3) 12857db96d56Sopenharmony_ci self.assertEqual(c0.y, []) 12867db96d56Sopenharmony_ci self.assertEqual(c0, c1) 12877db96d56Sopenharmony_ci self.assertIsNot(c0.y, c1.y) 12887db96d56Sopenharmony_ci self.assertEqual(astuple(C(5, [1])), (5, [1])) 12897db96d56Sopenharmony_ci 12907db96d56Sopenharmony_ci # Test a factory that returns a shared list. 
12917db96d56Sopenharmony_ci l = [] 12927db96d56Sopenharmony_ci @dataclass 12937db96d56Sopenharmony_ci class C: 12947db96d56Sopenharmony_ci x: int 12957db96d56Sopenharmony_ci y: list = field(default_factory=lambda: l) 12967db96d56Sopenharmony_ci 12977db96d56Sopenharmony_ci c0 = C(3) 12987db96d56Sopenharmony_ci c1 = C(3) 12997db96d56Sopenharmony_ci self.assertEqual(c0.x, 3) 13007db96d56Sopenharmony_ci self.assertEqual(c0.y, []) 13017db96d56Sopenharmony_ci self.assertEqual(c0, c1) 13027db96d56Sopenharmony_ci self.assertIs(c0.y, c1.y) 13037db96d56Sopenharmony_ci self.assertEqual(astuple(C(5, [1])), (5, [1])) 13047db96d56Sopenharmony_ci 13057db96d56Sopenharmony_ci # Test various other field flags. 13067db96d56Sopenharmony_ci # repr 13077db96d56Sopenharmony_ci @dataclass 13087db96d56Sopenharmony_ci class C: 13097db96d56Sopenharmony_ci x: list = field(default_factory=list, repr=False) 13107db96d56Sopenharmony_ci self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 13117db96d56Sopenharmony_ci self.assertEqual(C().x, []) 13127db96d56Sopenharmony_ci 13137db96d56Sopenharmony_ci # hash 13147db96d56Sopenharmony_ci @dataclass(unsafe_hash=True) 13157db96d56Sopenharmony_ci class C: 13167db96d56Sopenharmony_ci x: list = field(default_factory=list, hash=False) 13177db96d56Sopenharmony_ci self.assertEqual(astuple(C()), ([],)) 13187db96d56Sopenharmony_ci self.assertEqual(hash(C()), hash(())) 13197db96d56Sopenharmony_ci 13207db96d56Sopenharmony_ci # init (see also test_default_factory_with_no_init) 13217db96d56Sopenharmony_ci @dataclass 13227db96d56Sopenharmony_ci class C: 13237db96d56Sopenharmony_ci x: list = field(default_factory=list, init=False) 13247db96d56Sopenharmony_ci self.assertEqual(astuple(C()), ([],)) 13257db96d56Sopenharmony_ci 13267db96d56Sopenharmony_ci # compare 13277db96d56Sopenharmony_ci @dataclass 13287db96d56Sopenharmony_ci class C: 13297db96d56Sopenharmony_ci x: list = field(default_factory=list, compare=False) 13307db96d56Sopenharmony_ci 
self.assertEqual(C(), C([1])) 13317db96d56Sopenharmony_ci 13327db96d56Sopenharmony_ci def test_default_factory_with_no_init(self): 13337db96d56Sopenharmony_ci # We need a factory with a side effect. 13347db96d56Sopenharmony_ci factory = Mock() 13357db96d56Sopenharmony_ci 13367db96d56Sopenharmony_ci @dataclass 13377db96d56Sopenharmony_ci class C: 13387db96d56Sopenharmony_ci x: list = field(default_factory=factory, init=False) 13397db96d56Sopenharmony_ci 13407db96d56Sopenharmony_ci # Make sure the default factory is called for each new instance. 13417db96d56Sopenharmony_ci C().x 13427db96d56Sopenharmony_ci self.assertEqual(factory.call_count, 1) 13437db96d56Sopenharmony_ci C().x 13447db96d56Sopenharmony_ci self.assertEqual(factory.call_count, 2) 13457db96d56Sopenharmony_ci 13467db96d56Sopenharmony_ci def test_default_factory_not_called_if_value_given(self): 13477db96d56Sopenharmony_ci # We need a factory that we can test if it's been called. 13487db96d56Sopenharmony_ci factory = Mock() 13497db96d56Sopenharmony_ci 13507db96d56Sopenharmony_ci @dataclass 13517db96d56Sopenharmony_ci class C: 13527db96d56Sopenharmony_ci x: int = field(default_factory=factory) 13537db96d56Sopenharmony_ci 13547db96d56Sopenharmony_ci # Make sure that if a field has a default factory function, 13557db96d56Sopenharmony_ci # it's not called if a value is specified. 13567db96d56Sopenharmony_ci C().x 13577db96d56Sopenharmony_ci self.assertEqual(factory.call_count, 1) 13587db96d56Sopenharmony_ci self.assertEqual(C(10).x, 10) 13597db96d56Sopenharmony_ci self.assertEqual(factory.call_count, 1) 13607db96d56Sopenharmony_ci C().x 13617db96d56Sopenharmony_ci self.assertEqual(factory.call_count, 2) 13627db96d56Sopenharmony_ci 13637db96d56Sopenharmony_ci def test_default_factory_derived(self): 13647db96d56Sopenharmony_ci # See bpo-32896. 
13657db96d56Sopenharmony_ci @dataclass 13667db96d56Sopenharmony_ci class Foo: 13677db96d56Sopenharmony_ci x: dict = field(default_factory=dict) 13687db96d56Sopenharmony_ci 13697db96d56Sopenharmony_ci @dataclass 13707db96d56Sopenharmony_ci class Bar(Foo): 13717db96d56Sopenharmony_ci y: int = 1 13727db96d56Sopenharmony_ci 13737db96d56Sopenharmony_ci self.assertEqual(Foo().x, {}) 13747db96d56Sopenharmony_ci self.assertEqual(Bar().x, {}) 13757db96d56Sopenharmony_ci self.assertEqual(Bar().y, 1) 13767db96d56Sopenharmony_ci 13777db96d56Sopenharmony_ci @dataclass 13787db96d56Sopenharmony_ci class Baz(Foo): 13797db96d56Sopenharmony_ci pass 13807db96d56Sopenharmony_ci self.assertEqual(Baz().x, {}) 13817db96d56Sopenharmony_ci 13827db96d56Sopenharmony_ci def test_intermediate_non_dataclass(self): 13837db96d56Sopenharmony_ci # Test that an intermediate class that defines 13847db96d56Sopenharmony_ci # annotations does not define fields. 13857db96d56Sopenharmony_ci 13867db96d56Sopenharmony_ci @dataclass 13877db96d56Sopenharmony_ci class A: 13887db96d56Sopenharmony_ci x: int 13897db96d56Sopenharmony_ci 13907db96d56Sopenharmony_ci class B(A): 13917db96d56Sopenharmony_ci y: int 13927db96d56Sopenharmony_ci 13937db96d56Sopenharmony_ci @dataclass 13947db96d56Sopenharmony_ci class C(B): 13957db96d56Sopenharmony_ci z: int 13967db96d56Sopenharmony_ci 13977db96d56Sopenharmony_ci c = C(1, 3) 13987db96d56Sopenharmony_ci self.assertEqual((c.x, c.z), (1, 3)) 13997db96d56Sopenharmony_ci 14007db96d56Sopenharmony_ci # .y was not initialized. 14017db96d56Sopenharmony_ci with self.assertRaisesRegex(AttributeError, 14027db96d56Sopenharmony_ci 'object has no attribute'): 14037db96d56Sopenharmony_ci c.y 14047db96d56Sopenharmony_ci 14057db96d56Sopenharmony_ci # And if we again derive a non-dataclass, no fields are added. 
14067db96d56Sopenharmony_ci class D(C): 14077db96d56Sopenharmony_ci t: int 14087db96d56Sopenharmony_ci d = D(4, 5) 14097db96d56Sopenharmony_ci self.assertEqual((d.x, d.z), (4, 5)) 14107db96d56Sopenharmony_ci 14117db96d56Sopenharmony_ci def test_classvar_default_factory(self): 14127db96d56Sopenharmony_ci # It's an error for a ClassVar to have a factory function. 14137db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 14147db96d56Sopenharmony_ci 'cannot have a default factory'): 14157db96d56Sopenharmony_ci @dataclass 14167db96d56Sopenharmony_ci class C: 14177db96d56Sopenharmony_ci x: ClassVar[int] = field(default_factory=int) 14187db96d56Sopenharmony_ci 14197db96d56Sopenharmony_ci def test_is_dataclass(self): 14207db96d56Sopenharmony_ci class NotDataClass: 14217db96d56Sopenharmony_ci pass 14227db96d56Sopenharmony_ci 14237db96d56Sopenharmony_ci self.assertFalse(is_dataclass(0)) 14247db96d56Sopenharmony_ci self.assertFalse(is_dataclass(int)) 14257db96d56Sopenharmony_ci self.assertFalse(is_dataclass(NotDataClass)) 14267db96d56Sopenharmony_ci self.assertFalse(is_dataclass(NotDataClass())) 14277db96d56Sopenharmony_ci 14287db96d56Sopenharmony_ci @dataclass 14297db96d56Sopenharmony_ci class C: 14307db96d56Sopenharmony_ci x: int 14317db96d56Sopenharmony_ci 14327db96d56Sopenharmony_ci @dataclass 14337db96d56Sopenharmony_ci class D: 14347db96d56Sopenharmony_ci d: C 14357db96d56Sopenharmony_ci e: int 14367db96d56Sopenharmony_ci 14377db96d56Sopenharmony_ci c = C(10) 14387db96d56Sopenharmony_ci d = D(c, 4) 14397db96d56Sopenharmony_ci 14407db96d56Sopenharmony_ci self.assertTrue(is_dataclass(C)) 14417db96d56Sopenharmony_ci self.assertTrue(is_dataclass(c)) 14427db96d56Sopenharmony_ci self.assertFalse(is_dataclass(c.x)) 14437db96d56Sopenharmony_ci self.assertTrue(is_dataclass(d.d)) 14447db96d56Sopenharmony_ci self.assertFalse(is_dataclass(d.e)) 14457db96d56Sopenharmony_ci 14467db96d56Sopenharmony_ci def test_is_dataclass_when_getattr_always_returns(self): 
14477db96d56Sopenharmony_ci # See bpo-37868. 14487db96d56Sopenharmony_ci class A: 14497db96d56Sopenharmony_ci def __getattr__(self, key): 14507db96d56Sopenharmony_ci return 0 14517db96d56Sopenharmony_ci self.assertFalse(is_dataclass(A)) 14527db96d56Sopenharmony_ci a = A() 14537db96d56Sopenharmony_ci 14547db96d56Sopenharmony_ci # Also test for an instance attribute. 14557db96d56Sopenharmony_ci class B: 14567db96d56Sopenharmony_ci pass 14577db96d56Sopenharmony_ci b = B() 14587db96d56Sopenharmony_ci b.__dataclass_fields__ = [] 14597db96d56Sopenharmony_ci 14607db96d56Sopenharmony_ci for obj in a, b: 14617db96d56Sopenharmony_ci with self.subTest(obj=obj): 14627db96d56Sopenharmony_ci self.assertFalse(is_dataclass(obj)) 14637db96d56Sopenharmony_ci 14647db96d56Sopenharmony_ci # Indirect tests for _is_dataclass_instance(). 14657db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 14667db96d56Sopenharmony_ci asdict(obj) 14677db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 14687db96d56Sopenharmony_ci astuple(obj) 14697db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 14707db96d56Sopenharmony_ci replace(obj, x=0) 14717db96d56Sopenharmony_ci 14727db96d56Sopenharmony_ci def test_is_dataclass_genericalias(self): 14737db96d56Sopenharmony_ci @dataclass 14747db96d56Sopenharmony_ci class A(types.GenericAlias): 14757db96d56Sopenharmony_ci origin: type 14767db96d56Sopenharmony_ci args: type 14777db96d56Sopenharmony_ci self.assertTrue(is_dataclass(A)) 14787db96d56Sopenharmony_ci a = A(list, int) 14797db96d56Sopenharmony_ci self.assertTrue(is_dataclass(type(a))) 14807db96d56Sopenharmony_ci self.assertTrue(is_dataclass(a)) 14817db96d56Sopenharmony_ci 14827db96d56Sopenharmony_ci 14837db96d56Sopenharmony_ci def test_helper_fields_with_class_instance(self): 14847db96d56Sopenharmony_ci # Check that we can call fields() on either 
a class or instance, 14857db96d56Sopenharmony_ci # and get back the same thing. 14867db96d56Sopenharmony_ci @dataclass 14877db96d56Sopenharmony_ci class C: 14887db96d56Sopenharmony_ci x: int 14897db96d56Sopenharmony_ci y: float 14907db96d56Sopenharmony_ci 14917db96d56Sopenharmony_ci self.assertEqual(fields(C), fields(C(0, 0.0))) 14927db96d56Sopenharmony_ci 14937db96d56Sopenharmony_ci def test_helper_fields_exception(self): 14947db96d56Sopenharmony_ci # Check that TypeError is raised if not passed a dataclass or 14957db96d56Sopenharmony_ci # instance. 14967db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 14977db96d56Sopenharmony_ci fields(0) 14987db96d56Sopenharmony_ci 14997db96d56Sopenharmony_ci class C: pass 15007db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 15017db96d56Sopenharmony_ci fields(C) 15027db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 15037db96d56Sopenharmony_ci fields(C()) 15047db96d56Sopenharmony_ci 15057db96d56Sopenharmony_ci def test_clean_traceback_from_fields_exception(self): 15067db96d56Sopenharmony_ci stdout = io.StringIO() 15077db96d56Sopenharmony_ci try: 15087db96d56Sopenharmony_ci fields(object) 15097db96d56Sopenharmony_ci except TypeError as exc: 15107db96d56Sopenharmony_ci traceback.print_exception(exc, file=stdout) 15117db96d56Sopenharmony_ci printed_traceback = stdout.getvalue() 15127db96d56Sopenharmony_ci self.assertNotIn("AttributeError", printed_traceback) 15137db96d56Sopenharmony_ci self.assertNotIn("__dataclass_fields__", printed_traceback) 15147db96d56Sopenharmony_ci 15157db96d56Sopenharmony_ci def test_helper_asdict(self): 15167db96d56Sopenharmony_ci # Basic tests for asdict(), it should return a new dictionary. 
15177db96d56Sopenharmony_ci @dataclass 15187db96d56Sopenharmony_ci class C: 15197db96d56Sopenharmony_ci x: int 15207db96d56Sopenharmony_ci y: int 15217db96d56Sopenharmony_ci c = C(1, 2) 15227db96d56Sopenharmony_ci 15237db96d56Sopenharmony_ci self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 15247db96d56Sopenharmony_ci self.assertEqual(asdict(c), asdict(c)) 15257db96d56Sopenharmony_ci self.assertIsNot(asdict(c), asdict(c)) 15267db96d56Sopenharmony_ci c.x = 42 15277db96d56Sopenharmony_ci self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 15287db96d56Sopenharmony_ci self.assertIs(type(asdict(c)), dict) 15297db96d56Sopenharmony_ci 15307db96d56Sopenharmony_ci def test_helper_asdict_raises_on_classes(self): 15317db96d56Sopenharmony_ci # asdict() should raise on a class object. 15327db96d56Sopenharmony_ci @dataclass 15337db96d56Sopenharmony_ci class C: 15347db96d56Sopenharmony_ci x: int 15357db96d56Sopenharmony_ci y: int 15367db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 15377db96d56Sopenharmony_ci asdict(C) 15387db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 15397db96d56Sopenharmony_ci asdict(int) 15407db96d56Sopenharmony_ci 15417db96d56Sopenharmony_ci def test_helper_asdict_copy_values(self): 15427db96d56Sopenharmony_ci @dataclass 15437db96d56Sopenharmony_ci class C: 15447db96d56Sopenharmony_ci x: int 15457db96d56Sopenharmony_ci y: List[int] = field(default_factory=list) 15467db96d56Sopenharmony_ci initial = [] 15477db96d56Sopenharmony_ci c = C(1, initial) 15487db96d56Sopenharmony_ci d = asdict(c) 15497db96d56Sopenharmony_ci self.assertEqual(d['y'], initial) 15507db96d56Sopenharmony_ci self.assertIsNot(d['y'], initial) 15517db96d56Sopenharmony_ci c = C(1) 15527db96d56Sopenharmony_ci d = asdict(c) 15537db96d56Sopenharmony_ci d['y'].append(1) 15547db96d56Sopenharmony_ci self.assertEqual(c.y, []) 15557db96d56Sopenharmony_ci 15567db96d56Sopenharmony_ci def test_helper_asdict_nested(self): 
15577db96d56Sopenharmony_ci @dataclass 15587db96d56Sopenharmony_ci class UserId: 15597db96d56Sopenharmony_ci token: int 15607db96d56Sopenharmony_ci group: int 15617db96d56Sopenharmony_ci @dataclass 15627db96d56Sopenharmony_ci class User: 15637db96d56Sopenharmony_ci name: str 15647db96d56Sopenharmony_ci id: UserId 15657db96d56Sopenharmony_ci u = User('Joe', UserId(123, 1)) 15667db96d56Sopenharmony_ci d = asdict(u) 15677db96d56Sopenharmony_ci self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 15687db96d56Sopenharmony_ci self.assertIsNot(asdict(u), asdict(u)) 15697db96d56Sopenharmony_ci u.id.group = 2 15707db96d56Sopenharmony_ci self.assertEqual(asdict(u), {'name': 'Joe', 15717db96d56Sopenharmony_ci 'id': {'token': 123, 'group': 2}}) 15727db96d56Sopenharmony_ci 15737db96d56Sopenharmony_ci def test_helper_asdict_builtin_containers(self): 15747db96d56Sopenharmony_ci @dataclass 15757db96d56Sopenharmony_ci class User: 15767db96d56Sopenharmony_ci name: str 15777db96d56Sopenharmony_ci id: int 15787db96d56Sopenharmony_ci @dataclass 15797db96d56Sopenharmony_ci class GroupList: 15807db96d56Sopenharmony_ci id: int 15817db96d56Sopenharmony_ci users: List[User] 15827db96d56Sopenharmony_ci @dataclass 15837db96d56Sopenharmony_ci class GroupTuple: 15847db96d56Sopenharmony_ci id: int 15857db96d56Sopenharmony_ci users: Tuple[User, ...] 
15867db96d56Sopenharmony_ci @dataclass 15877db96d56Sopenharmony_ci class GroupDict: 15887db96d56Sopenharmony_ci id: int 15897db96d56Sopenharmony_ci users: Dict[str, User] 15907db96d56Sopenharmony_ci a = User('Alice', 1) 15917db96d56Sopenharmony_ci b = User('Bob', 2) 15927db96d56Sopenharmony_ci gl = GroupList(0, [a, b]) 15937db96d56Sopenharmony_ci gt = GroupTuple(0, (a, b)) 15947db96d56Sopenharmony_ci gd = GroupDict(0, {'first': a, 'second': b}) 15957db96d56Sopenharmony_ci self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 15967db96d56Sopenharmony_ci {'name': 'Bob', 'id': 2}]}) 15977db96d56Sopenharmony_ci self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 15987db96d56Sopenharmony_ci {'name': 'Bob', 'id': 2})}) 15997db96d56Sopenharmony_ci self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 16007db96d56Sopenharmony_ci 'second': {'name': 'Bob', 'id': 2}}}) 16017db96d56Sopenharmony_ci 16027db96d56Sopenharmony_ci def test_helper_asdict_builtin_object_containers(self): 16037db96d56Sopenharmony_ci @dataclass 16047db96d56Sopenharmony_ci class Child: 16057db96d56Sopenharmony_ci d: object 16067db96d56Sopenharmony_ci 16077db96d56Sopenharmony_ci @dataclass 16087db96d56Sopenharmony_ci class Parent: 16097db96d56Sopenharmony_ci child: Child 16107db96d56Sopenharmony_ci 16117db96d56Sopenharmony_ci self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 16127db96d56Sopenharmony_ci self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 16137db96d56Sopenharmony_ci 16147db96d56Sopenharmony_ci def test_helper_asdict_factory(self): 16157db96d56Sopenharmony_ci @dataclass 16167db96d56Sopenharmony_ci class C: 16177db96d56Sopenharmony_ci x: int 16187db96d56Sopenharmony_ci y: int 16197db96d56Sopenharmony_ci c = C(1, 2) 16207db96d56Sopenharmony_ci d = asdict(c, dict_factory=OrderedDict) 16217db96d56Sopenharmony_ci self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 
16227db96d56Sopenharmony_ci self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 16237db96d56Sopenharmony_ci c.x = 42 16247db96d56Sopenharmony_ci d = asdict(c, dict_factory=OrderedDict) 16257db96d56Sopenharmony_ci self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 16267db96d56Sopenharmony_ci self.assertIs(type(d), OrderedDict) 16277db96d56Sopenharmony_ci 16287db96d56Sopenharmony_ci def test_helper_asdict_namedtuple(self): 16297db96d56Sopenharmony_ci T = namedtuple('T', 'a b c') 16307db96d56Sopenharmony_ci @dataclass 16317db96d56Sopenharmony_ci class C: 16327db96d56Sopenharmony_ci x: str 16337db96d56Sopenharmony_ci y: T 16347db96d56Sopenharmony_ci c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 16357db96d56Sopenharmony_ci 16367db96d56Sopenharmony_ci d = asdict(c) 16377db96d56Sopenharmony_ci self.assertEqual(d, {'x': 'outer', 16387db96d56Sopenharmony_ci 'y': T(1, 16397db96d56Sopenharmony_ci {'x': 'inner', 16407db96d56Sopenharmony_ci 'y': T(11, 12, 13)}, 16417db96d56Sopenharmony_ci 2), 16427db96d56Sopenharmony_ci } 16437db96d56Sopenharmony_ci ) 16447db96d56Sopenharmony_ci 16457db96d56Sopenharmony_ci # Now with a dict_factory. OrderedDict is convenient, but 16467db96d56Sopenharmony_ci # since it compares to dicts, we also need to have separate 16477db96d56Sopenharmony_ci # assertIs tests. 16487db96d56Sopenharmony_ci d = asdict(c, dict_factory=OrderedDict) 16497db96d56Sopenharmony_ci self.assertEqual(d, {'x': 'outer', 16507db96d56Sopenharmony_ci 'y': T(1, 16517db96d56Sopenharmony_ci {'x': 'inner', 16527db96d56Sopenharmony_ci 'y': T(11, 12, 13)}, 16537db96d56Sopenharmony_ci 2), 16547db96d56Sopenharmony_ci } 16557db96d56Sopenharmony_ci ) 16567db96d56Sopenharmony_ci 16577db96d56Sopenharmony_ci # Make sure that the returned dicts are actually OrderedDicts. 
16587db96d56Sopenharmony_ci self.assertIs(type(d), OrderedDict) 16597db96d56Sopenharmony_ci self.assertIs(type(d['y'][1]), OrderedDict) 16607db96d56Sopenharmony_ci 16617db96d56Sopenharmony_ci def test_helper_asdict_namedtuple_key(self): 16627db96d56Sopenharmony_ci # Ensure that a field that contains a dict which has a 16637db96d56Sopenharmony_ci # namedtuple as a key works with asdict(). 16647db96d56Sopenharmony_ci 16657db96d56Sopenharmony_ci @dataclass 16667db96d56Sopenharmony_ci class C: 16677db96d56Sopenharmony_ci f: dict 16687db96d56Sopenharmony_ci T = namedtuple('T', 'a') 16697db96d56Sopenharmony_ci 16707db96d56Sopenharmony_ci c = C({T('an a'): 0}) 16717db96d56Sopenharmony_ci 16727db96d56Sopenharmony_ci self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 16737db96d56Sopenharmony_ci 16747db96d56Sopenharmony_ci def test_helper_asdict_namedtuple_derived(self): 16757db96d56Sopenharmony_ci class T(namedtuple('Tbase', 'a')): 16767db96d56Sopenharmony_ci def my_a(self): 16777db96d56Sopenharmony_ci return self.a 16787db96d56Sopenharmony_ci 16797db96d56Sopenharmony_ci @dataclass 16807db96d56Sopenharmony_ci class C: 16817db96d56Sopenharmony_ci f: T 16827db96d56Sopenharmony_ci 16837db96d56Sopenharmony_ci t = T(6) 16847db96d56Sopenharmony_ci c = C(t) 16857db96d56Sopenharmony_ci 16867db96d56Sopenharmony_ci d = asdict(c) 16877db96d56Sopenharmony_ci self.assertEqual(d, {'f': T(a=6)}) 16887db96d56Sopenharmony_ci # Make sure that t has been copied, not used directly. 16897db96d56Sopenharmony_ci self.assertIsNot(d['f'], t) 16907db96d56Sopenharmony_ci self.assertEqual(d['f'].my_a(), 6) 16917db96d56Sopenharmony_ci 16927db96d56Sopenharmony_ci def test_helper_astuple(self): 16937db96d56Sopenharmony_ci # Basic tests for astuple(), it should return a new tuple. 
16947db96d56Sopenharmony_ci @dataclass 16957db96d56Sopenharmony_ci class C: 16967db96d56Sopenharmony_ci x: int 16977db96d56Sopenharmony_ci y: int = 0 16987db96d56Sopenharmony_ci c = C(1) 16997db96d56Sopenharmony_ci 17007db96d56Sopenharmony_ci self.assertEqual(astuple(c), (1, 0)) 17017db96d56Sopenharmony_ci self.assertEqual(astuple(c), astuple(c)) 17027db96d56Sopenharmony_ci self.assertIsNot(astuple(c), astuple(c)) 17037db96d56Sopenharmony_ci c.y = 42 17047db96d56Sopenharmony_ci self.assertEqual(astuple(c), (1, 42)) 17057db96d56Sopenharmony_ci self.assertIs(type(astuple(c)), tuple) 17067db96d56Sopenharmony_ci 17077db96d56Sopenharmony_ci def test_helper_astuple_raises_on_classes(self): 17087db96d56Sopenharmony_ci # astuple() should raise on a class object. 17097db96d56Sopenharmony_ci @dataclass 17107db96d56Sopenharmony_ci class C: 17117db96d56Sopenharmony_ci x: int 17127db96d56Sopenharmony_ci y: int 17137db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 17147db96d56Sopenharmony_ci astuple(C) 17157db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 17167db96d56Sopenharmony_ci astuple(int) 17177db96d56Sopenharmony_ci 17187db96d56Sopenharmony_ci def test_helper_astuple_copy_values(self): 17197db96d56Sopenharmony_ci @dataclass 17207db96d56Sopenharmony_ci class C: 17217db96d56Sopenharmony_ci x: int 17227db96d56Sopenharmony_ci y: List[int] = field(default_factory=list) 17237db96d56Sopenharmony_ci initial = [] 17247db96d56Sopenharmony_ci c = C(1, initial) 17257db96d56Sopenharmony_ci t = astuple(c) 17267db96d56Sopenharmony_ci self.assertEqual(t[1], initial) 17277db96d56Sopenharmony_ci self.assertIsNot(t[1], initial) 17287db96d56Sopenharmony_ci c = C(1) 17297db96d56Sopenharmony_ci t = astuple(c) 17307db96d56Sopenharmony_ci t[1].append(1) 17317db96d56Sopenharmony_ci self.assertEqual(c.y, []) 17327db96d56Sopenharmony_ci 17337db96d56Sopenharmony_ci def test_helper_astuple_nested(self): 
17347db96d56Sopenharmony_ci @dataclass 17357db96d56Sopenharmony_ci class UserId: 17367db96d56Sopenharmony_ci token: int 17377db96d56Sopenharmony_ci group: int 17387db96d56Sopenharmony_ci @dataclass 17397db96d56Sopenharmony_ci class User: 17407db96d56Sopenharmony_ci name: str 17417db96d56Sopenharmony_ci id: UserId 17427db96d56Sopenharmony_ci u = User('Joe', UserId(123, 1)) 17437db96d56Sopenharmony_ci t = astuple(u) 17447db96d56Sopenharmony_ci self.assertEqual(t, ('Joe', (123, 1))) 17457db96d56Sopenharmony_ci self.assertIsNot(astuple(u), astuple(u)) 17467db96d56Sopenharmony_ci u.id.group = 2 17477db96d56Sopenharmony_ci self.assertEqual(astuple(u), ('Joe', (123, 2))) 17487db96d56Sopenharmony_ci 17497db96d56Sopenharmony_ci def test_helper_astuple_builtin_containers(self): 17507db96d56Sopenharmony_ci @dataclass 17517db96d56Sopenharmony_ci class User: 17527db96d56Sopenharmony_ci name: str 17537db96d56Sopenharmony_ci id: int 17547db96d56Sopenharmony_ci @dataclass 17557db96d56Sopenharmony_ci class GroupList: 17567db96d56Sopenharmony_ci id: int 17577db96d56Sopenharmony_ci users: List[User] 17587db96d56Sopenharmony_ci @dataclass 17597db96d56Sopenharmony_ci class GroupTuple: 17607db96d56Sopenharmony_ci id: int 17617db96d56Sopenharmony_ci users: Tuple[User, ...] 
17627db96d56Sopenharmony_ci @dataclass 17637db96d56Sopenharmony_ci class GroupDict: 17647db96d56Sopenharmony_ci id: int 17657db96d56Sopenharmony_ci users: Dict[str, User] 17667db96d56Sopenharmony_ci a = User('Alice', 1) 17677db96d56Sopenharmony_ci b = User('Bob', 2) 17687db96d56Sopenharmony_ci gl = GroupList(0, [a, b]) 17697db96d56Sopenharmony_ci gt = GroupTuple(0, (a, b)) 17707db96d56Sopenharmony_ci gd = GroupDict(0, {'first': a, 'second': b}) 17717db96d56Sopenharmony_ci self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 17727db96d56Sopenharmony_ci self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 17737db96d56Sopenharmony_ci self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 17747db96d56Sopenharmony_ci 17757db96d56Sopenharmony_ci def test_helper_astuple_builtin_object_containers(self): 17767db96d56Sopenharmony_ci @dataclass 17777db96d56Sopenharmony_ci class Child: 17787db96d56Sopenharmony_ci d: object 17797db96d56Sopenharmony_ci 17807db96d56Sopenharmony_ci @dataclass 17817db96d56Sopenharmony_ci class Parent: 17827db96d56Sopenharmony_ci child: Child 17837db96d56Sopenharmony_ci 17847db96d56Sopenharmony_ci self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 17857db96d56Sopenharmony_ci self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 17867db96d56Sopenharmony_ci 17877db96d56Sopenharmony_ci def test_helper_astuple_factory(self): 17887db96d56Sopenharmony_ci @dataclass 17897db96d56Sopenharmony_ci class C: 17907db96d56Sopenharmony_ci x: int 17917db96d56Sopenharmony_ci y: int 17927db96d56Sopenharmony_ci NT = namedtuple('NT', 'x y') 17937db96d56Sopenharmony_ci def nt(lst): 17947db96d56Sopenharmony_ci return NT(*lst) 17957db96d56Sopenharmony_ci c = C(1, 2) 17967db96d56Sopenharmony_ci t = astuple(c, tuple_factory=nt) 17977db96d56Sopenharmony_ci self.assertEqual(t, NT(1, 2)) 17987db96d56Sopenharmony_ci self.assertIsNot(t, astuple(c, tuple_factory=nt)) 17997db96d56Sopenharmony_ci c.x = 42 
18007db96d56Sopenharmony_ci t = astuple(c, tuple_factory=nt) 18017db96d56Sopenharmony_ci self.assertEqual(t, NT(42, 2)) 18027db96d56Sopenharmony_ci self.assertIs(type(t), NT) 18037db96d56Sopenharmony_ci 18047db96d56Sopenharmony_ci def test_helper_astuple_namedtuple(self): 18057db96d56Sopenharmony_ci T = namedtuple('T', 'a b c') 18067db96d56Sopenharmony_ci @dataclass 18077db96d56Sopenharmony_ci class C: 18087db96d56Sopenharmony_ci x: str 18097db96d56Sopenharmony_ci y: T 18107db96d56Sopenharmony_ci c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 18117db96d56Sopenharmony_ci 18127db96d56Sopenharmony_ci t = astuple(c) 18137db96d56Sopenharmony_ci self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 18147db96d56Sopenharmony_ci 18157db96d56Sopenharmony_ci # Now, using a tuple_factory. list is convenient here. 18167db96d56Sopenharmony_ci t = astuple(c, tuple_factory=list) 18177db96d56Sopenharmony_ci self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 18187db96d56Sopenharmony_ci 18197db96d56Sopenharmony_ci def test_dynamic_class_creation(self): 18207db96d56Sopenharmony_ci cls_dict = {'__annotations__': {'x': int, 'y': int}, 18217db96d56Sopenharmony_ci } 18227db96d56Sopenharmony_ci 18237db96d56Sopenharmony_ci # Create the class. 18247db96d56Sopenharmony_ci cls = type('C', (), cls_dict) 18257db96d56Sopenharmony_ci 18267db96d56Sopenharmony_ci # Make it a dataclass. 18277db96d56Sopenharmony_ci cls1 = dataclass(cls) 18287db96d56Sopenharmony_ci 18297db96d56Sopenharmony_ci self.assertEqual(cls1, cls) 18307db96d56Sopenharmony_ci self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 18317db96d56Sopenharmony_ci 18327db96d56Sopenharmony_ci def test_dynamic_class_creation_using_field(self): 18337db96d56Sopenharmony_ci cls_dict = {'__annotations__': {'x': int, 'y': int}, 18347db96d56Sopenharmony_ci 'y': field(default=5), 18357db96d56Sopenharmony_ci } 18367db96d56Sopenharmony_ci 18377db96d56Sopenharmony_ci # Create the class. 
18387db96d56Sopenharmony_ci cls = type('C', (), cls_dict) 18397db96d56Sopenharmony_ci 18407db96d56Sopenharmony_ci # Make it a dataclass. 18417db96d56Sopenharmony_ci cls1 = dataclass(cls) 18427db96d56Sopenharmony_ci 18437db96d56Sopenharmony_ci self.assertEqual(cls1, cls) 18447db96d56Sopenharmony_ci self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 18457db96d56Sopenharmony_ci 18467db96d56Sopenharmony_ci def test_init_in_order(self): 18477db96d56Sopenharmony_ci @dataclass 18487db96d56Sopenharmony_ci class C: 18497db96d56Sopenharmony_ci a: int 18507db96d56Sopenharmony_ci b: int = field() 18517db96d56Sopenharmony_ci c: list = field(default_factory=list, init=False) 18527db96d56Sopenharmony_ci d: list = field(default_factory=list) 18537db96d56Sopenharmony_ci e: int = field(default=4, init=False) 18547db96d56Sopenharmony_ci f: int = 4 18557db96d56Sopenharmony_ci 18567db96d56Sopenharmony_ci calls = [] 18577db96d56Sopenharmony_ci def setattr(self, name, value): 18587db96d56Sopenharmony_ci calls.append((name, value)) 18597db96d56Sopenharmony_ci 18607db96d56Sopenharmony_ci C.__setattr__ = setattr 18617db96d56Sopenharmony_ci c = C(0, 1) 18627db96d56Sopenharmony_ci self.assertEqual(('a', 0), calls[0]) 18637db96d56Sopenharmony_ci self.assertEqual(('b', 1), calls[1]) 18647db96d56Sopenharmony_ci self.assertEqual(('c', []), calls[2]) 18657db96d56Sopenharmony_ci self.assertEqual(('d', []), calls[3]) 18667db96d56Sopenharmony_ci self.assertNotIn(('e', 4), calls) 18677db96d56Sopenharmony_ci self.assertEqual(('f', 4), calls[4]) 18687db96d56Sopenharmony_ci 18697db96d56Sopenharmony_ci def test_items_in_dicts(self): 18707db96d56Sopenharmony_ci @dataclass 18717db96d56Sopenharmony_ci class C: 18727db96d56Sopenharmony_ci a: int 18737db96d56Sopenharmony_ci b: list = field(default_factory=list, init=False) 18747db96d56Sopenharmony_ci c: list = field(default_factory=list) 18757db96d56Sopenharmony_ci d: int = field(default=4, init=False) 18767db96d56Sopenharmony_ci e: int = 0 
18777db96d56Sopenharmony_ci 18787db96d56Sopenharmony_ci c = C(0) 18797db96d56Sopenharmony_ci # Class dict 18807db96d56Sopenharmony_ci self.assertNotIn('a', C.__dict__) 18817db96d56Sopenharmony_ci self.assertNotIn('b', C.__dict__) 18827db96d56Sopenharmony_ci self.assertNotIn('c', C.__dict__) 18837db96d56Sopenharmony_ci self.assertIn('d', C.__dict__) 18847db96d56Sopenharmony_ci self.assertEqual(C.d, 4) 18857db96d56Sopenharmony_ci self.assertIn('e', C.__dict__) 18867db96d56Sopenharmony_ci self.assertEqual(C.e, 0) 18877db96d56Sopenharmony_ci # Instance dict 18887db96d56Sopenharmony_ci self.assertIn('a', c.__dict__) 18897db96d56Sopenharmony_ci self.assertEqual(c.a, 0) 18907db96d56Sopenharmony_ci self.assertIn('b', c.__dict__) 18917db96d56Sopenharmony_ci self.assertEqual(c.b, []) 18927db96d56Sopenharmony_ci self.assertIn('c', c.__dict__) 18937db96d56Sopenharmony_ci self.assertEqual(c.c, []) 18947db96d56Sopenharmony_ci self.assertNotIn('d', c.__dict__) 18957db96d56Sopenharmony_ci self.assertIn('e', c.__dict__) 18967db96d56Sopenharmony_ci self.assertEqual(c.e, 0) 18977db96d56Sopenharmony_ci 18987db96d56Sopenharmony_ci def test_alternate_classmethod_constructor(self): 18997db96d56Sopenharmony_ci # Since __post_init__ can't take params, use a classmethod 19007db96d56Sopenharmony_ci # alternate constructor. This is mostly an example to show 19017db96d56Sopenharmony_ci # how to use this technique. 19027db96d56Sopenharmony_ci @dataclass 19037db96d56Sopenharmony_ci class C: 19047db96d56Sopenharmony_ci x: int 19057db96d56Sopenharmony_ci @classmethod 19067db96d56Sopenharmony_ci def from_file(cls, filename): 19077db96d56Sopenharmony_ci # In a real example, create a new instance 19087db96d56Sopenharmony_ci # and populate 'x' from contents of a file. 
19097db96d56Sopenharmony_ci value_in_file = 20 19107db96d56Sopenharmony_ci return cls(value_in_file) 19117db96d56Sopenharmony_ci 19127db96d56Sopenharmony_ci self.assertEqual(C.from_file('filename').x, 20) 19137db96d56Sopenharmony_ci 19147db96d56Sopenharmony_ci def test_field_metadata_default(self): 19157db96d56Sopenharmony_ci # Make sure the default metadata is read-only and of 19167db96d56Sopenharmony_ci # zero length. 19177db96d56Sopenharmony_ci @dataclass 19187db96d56Sopenharmony_ci class C: 19197db96d56Sopenharmony_ci i: int 19207db96d56Sopenharmony_ci 19217db96d56Sopenharmony_ci self.assertFalse(fields(C)[0].metadata) 19227db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 0) 19237db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 19247db96d56Sopenharmony_ci 'does not support item assignment'): 19257db96d56Sopenharmony_ci fields(C)[0].metadata['test'] = 3 19267db96d56Sopenharmony_ci 19277db96d56Sopenharmony_ci def test_field_metadata_mapping(self): 19287db96d56Sopenharmony_ci # Make sure only a mapping can be passed as metadata 19297db96d56Sopenharmony_ci # zero length. 19307db96d56Sopenharmony_ci with self.assertRaises(TypeError): 19317db96d56Sopenharmony_ci @dataclass 19327db96d56Sopenharmony_ci class C: 19337db96d56Sopenharmony_ci i: int = field(metadata=0) 19347db96d56Sopenharmony_ci 19357db96d56Sopenharmony_ci # Make sure an empty dict works. 19367db96d56Sopenharmony_ci d = {} 19377db96d56Sopenharmony_ci @dataclass 19387db96d56Sopenharmony_ci class C: 19397db96d56Sopenharmony_ci i: int = field(metadata=d) 19407db96d56Sopenharmony_ci self.assertFalse(fields(C)[0].metadata) 19417db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 0) 19427db96d56Sopenharmony_ci # Update should work (see bpo-35960). 
19437db96d56Sopenharmony_ci d['foo'] = 1 19447db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 1) 19457db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['foo'], 1) 19467db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 19477db96d56Sopenharmony_ci 'does not support item assignment'): 19487db96d56Sopenharmony_ci fields(C)[0].metadata['test'] = 3 19497db96d56Sopenharmony_ci 19507db96d56Sopenharmony_ci # Make sure a non-empty dict works. 19517db96d56Sopenharmony_ci d = {'test': 10, 'bar': '42', 3: 'three'} 19527db96d56Sopenharmony_ci @dataclass 19537db96d56Sopenharmony_ci class C: 19547db96d56Sopenharmony_ci i: int = field(metadata=d) 19557db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 3) 19567db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['test'], 10) 19577db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['bar'], '42') 19587db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata[3], 'three') 19597db96d56Sopenharmony_ci # Update should work. 19607db96d56Sopenharmony_ci d['foo'] = 1 19617db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 4) 19627db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['foo'], 1) 19637db96d56Sopenharmony_ci with self.assertRaises(KeyError): 19647db96d56Sopenharmony_ci # Non-existent key. 19657db96d56Sopenharmony_ci fields(C)[0].metadata['baz'] 19667db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 19677db96d56Sopenharmony_ci 'does not support item assignment'): 19687db96d56Sopenharmony_ci fields(C)[0].metadata['test'] = 3 19697db96d56Sopenharmony_ci 19707db96d56Sopenharmony_ci def test_field_metadata_custom_mapping(self): 19717db96d56Sopenharmony_ci # Try a custom mapping. 
19727db96d56Sopenharmony_ci class SimpleNameSpace: 19737db96d56Sopenharmony_ci def __init__(self, **kw): 19747db96d56Sopenharmony_ci self.__dict__.update(kw) 19757db96d56Sopenharmony_ci 19767db96d56Sopenharmony_ci def __getitem__(self, item): 19777db96d56Sopenharmony_ci if item == 'xyzzy': 19787db96d56Sopenharmony_ci return 'plugh' 19797db96d56Sopenharmony_ci return getattr(self, item) 19807db96d56Sopenharmony_ci 19817db96d56Sopenharmony_ci def __len__(self): 19827db96d56Sopenharmony_ci return self.__dict__.__len__() 19837db96d56Sopenharmony_ci 19847db96d56Sopenharmony_ci @dataclass 19857db96d56Sopenharmony_ci class C: 19867db96d56Sopenharmony_ci i: int = field(metadata=SimpleNameSpace(a=10)) 19877db96d56Sopenharmony_ci 19887db96d56Sopenharmony_ci self.assertEqual(len(fields(C)[0].metadata), 1) 19897db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['a'], 10) 19907db96d56Sopenharmony_ci with self.assertRaises(AttributeError): 19917db96d56Sopenharmony_ci fields(C)[0].metadata['b'] 19927db96d56Sopenharmony_ci # Make sure we're still talking to our custom mapping. 19937db96d56Sopenharmony_ci self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 19947db96d56Sopenharmony_ci 19957db96d56Sopenharmony_ci def test_generic_dataclasses(self): 19967db96d56Sopenharmony_ci T = TypeVar('T') 19977db96d56Sopenharmony_ci 19987db96d56Sopenharmony_ci @dataclass 19997db96d56Sopenharmony_ci class LabeledBox(Generic[T]): 20007db96d56Sopenharmony_ci content: T 20017db96d56Sopenharmony_ci label: str = '<unknown>' 20027db96d56Sopenharmony_ci 20037db96d56Sopenharmony_ci box = LabeledBox(42) 20047db96d56Sopenharmony_ci self.assertEqual(box.content, 42) 20057db96d56Sopenharmony_ci self.assertEqual(box.label, '<unknown>') 20067db96d56Sopenharmony_ci 20077db96d56Sopenharmony_ci # Subscripting the resulting class should work, etc. 
20087db96d56Sopenharmony_ci Alias = List[LabeledBox[int]] 20097db96d56Sopenharmony_ci 20107db96d56Sopenharmony_ci def test_generic_extending(self): 20117db96d56Sopenharmony_ci S = TypeVar('S') 20127db96d56Sopenharmony_ci T = TypeVar('T') 20137db96d56Sopenharmony_ci 20147db96d56Sopenharmony_ci @dataclass 20157db96d56Sopenharmony_ci class Base(Generic[T, S]): 20167db96d56Sopenharmony_ci x: T 20177db96d56Sopenharmony_ci y: S 20187db96d56Sopenharmony_ci 20197db96d56Sopenharmony_ci @dataclass 20207db96d56Sopenharmony_ci class DataDerived(Base[int, T]): 20217db96d56Sopenharmony_ci new_field: str 20227db96d56Sopenharmony_ci Alias = DataDerived[str] 20237db96d56Sopenharmony_ci c = Alias(0, 'test1', 'test2') 20247db96d56Sopenharmony_ci self.assertEqual(astuple(c), (0, 'test1', 'test2')) 20257db96d56Sopenharmony_ci 20267db96d56Sopenharmony_ci class NonDataDerived(Base[int, T]): 20277db96d56Sopenharmony_ci def new_method(self): 20287db96d56Sopenharmony_ci return self.y 20297db96d56Sopenharmony_ci Alias = NonDataDerived[float] 20307db96d56Sopenharmony_ci c = Alias(10, 1.0) 20317db96d56Sopenharmony_ci self.assertEqual(c.new_method(), 1.0) 20327db96d56Sopenharmony_ci 20337db96d56Sopenharmony_ci def test_generic_dynamic(self): 20347db96d56Sopenharmony_ci T = TypeVar('T') 20357db96d56Sopenharmony_ci 20367db96d56Sopenharmony_ci @dataclass 20377db96d56Sopenharmony_ci class Parent(Generic[T]): 20387db96d56Sopenharmony_ci x: T 20397db96d56Sopenharmony_ci Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 20407db96d56Sopenharmony_ci bases=(Parent[int], Generic[T]), namespace={'other': 42}) 20417db96d56Sopenharmony_ci self.assertIs(Child[int](1, 2).z, None) 20427db96d56Sopenharmony_ci self.assertEqual(Child[int](1, 2, 3).z, 3) 20437db96d56Sopenharmony_ci self.assertEqual(Child[int](1, 2, 3).other, 42) 20447db96d56Sopenharmony_ci # Check that type aliases work correctly. 
20457db96d56Sopenharmony_ci Alias = Child[T] 20467db96d56Sopenharmony_ci self.assertEqual(Alias[int](1, 2).x, 1) 20477db96d56Sopenharmony_ci # Check MRO resolution. 20487db96d56Sopenharmony_ci self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 20497db96d56Sopenharmony_ci 20507db96d56Sopenharmony_ci def test_dataclasses_pickleable(self): 20517db96d56Sopenharmony_ci global P, Q, R 20527db96d56Sopenharmony_ci @dataclass 20537db96d56Sopenharmony_ci class P: 20547db96d56Sopenharmony_ci x: int 20557db96d56Sopenharmony_ci y: int = 0 20567db96d56Sopenharmony_ci @dataclass 20577db96d56Sopenharmony_ci class Q: 20587db96d56Sopenharmony_ci x: int 20597db96d56Sopenharmony_ci y: int = field(default=0, init=False) 20607db96d56Sopenharmony_ci @dataclass 20617db96d56Sopenharmony_ci class R: 20627db96d56Sopenharmony_ci x: int 20637db96d56Sopenharmony_ci y: List[int] = field(default_factory=list) 20647db96d56Sopenharmony_ci q = Q(1) 20657db96d56Sopenharmony_ci q.y = 2 20667db96d56Sopenharmony_ci samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 20677db96d56Sopenharmony_ci for sample in samples: 20687db96d56Sopenharmony_ci for proto in range(pickle.HIGHEST_PROTOCOL + 1): 20697db96d56Sopenharmony_ci with self.subTest(sample=sample, proto=proto): 20707db96d56Sopenharmony_ci new_sample = pickle.loads(pickle.dumps(sample, proto)) 20717db96d56Sopenharmony_ci self.assertEqual(sample.x, new_sample.x) 20727db96d56Sopenharmony_ci self.assertEqual(sample.y, new_sample.y) 20737db96d56Sopenharmony_ci self.assertIsNot(sample, new_sample) 20747db96d56Sopenharmony_ci new_sample.x = 42 20757db96d56Sopenharmony_ci another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 20767db96d56Sopenharmony_ci self.assertEqual(new_sample.x, another_new_sample.x) 20777db96d56Sopenharmony_ci self.assertEqual(sample.y, another_new_sample.y) 20787db96d56Sopenharmony_ci 20797db96d56Sopenharmony_ci def test_dataclasses_qualnames(self): 20807db96d56Sopenharmony_ci @dataclass(order=True, 
unsafe_hash=True, frozen=True) 20817db96d56Sopenharmony_ci class A: 20827db96d56Sopenharmony_ci x: int 20837db96d56Sopenharmony_ci y: int 20847db96d56Sopenharmony_ci 20857db96d56Sopenharmony_ci self.assertEqual(A.__init__.__name__, "__init__") 20867db96d56Sopenharmony_ci for function in ( 20877db96d56Sopenharmony_ci '__eq__', 20887db96d56Sopenharmony_ci '__lt__', 20897db96d56Sopenharmony_ci '__le__', 20907db96d56Sopenharmony_ci '__gt__', 20917db96d56Sopenharmony_ci '__ge__', 20927db96d56Sopenharmony_ci '__hash__', 20937db96d56Sopenharmony_ci '__init__', 20947db96d56Sopenharmony_ci '__repr__', 20957db96d56Sopenharmony_ci '__setattr__', 20967db96d56Sopenharmony_ci '__delattr__', 20977db96d56Sopenharmony_ci ): 20987db96d56Sopenharmony_ci self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}") 20997db96d56Sopenharmony_ci 21007db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"): 21017db96d56Sopenharmony_ci A() 21027db96d56Sopenharmony_ci 21037db96d56Sopenharmony_ci 21047db96d56Sopenharmony_ciclass TestFieldNoAnnotation(unittest.TestCase): 21057db96d56Sopenharmony_ci def test_field_without_annotation(self): 21067db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 21077db96d56Sopenharmony_ci "'f' is a field but has no type annotation"): 21087db96d56Sopenharmony_ci @dataclass 21097db96d56Sopenharmony_ci class C: 21107db96d56Sopenharmony_ci f = field() 21117db96d56Sopenharmony_ci 21127db96d56Sopenharmony_ci def test_field_without_annotation_but_annotation_in_base(self): 21137db96d56Sopenharmony_ci @dataclass 21147db96d56Sopenharmony_ci class B: 21157db96d56Sopenharmony_ci f: int 21167db96d56Sopenharmony_ci 21177db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 21187db96d56Sopenharmony_ci "'f' is a field but has no type annotation"): 21197db96d56Sopenharmony_ci # This is still an error: make sure we don't pick up the 21207db96d56Sopenharmony_ci # type 
annotation in the base class. 21217db96d56Sopenharmony_ci @dataclass 21227db96d56Sopenharmony_ci class C(B): 21237db96d56Sopenharmony_ci f = field() 21247db96d56Sopenharmony_ci 21257db96d56Sopenharmony_ci def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 21267db96d56Sopenharmony_ci # Same test, but with the base class not a dataclass. 21277db96d56Sopenharmony_ci class B: 21287db96d56Sopenharmony_ci f: int 21297db96d56Sopenharmony_ci 21307db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 21317db96d56Sopenharmony_ci "'f' is a field but has no type annotation"): 21327db96d56Sopenharmony_ci # This is still an error: make sure we don't pick up the 21337db96d56Sopenharmony_ci # type annotation in the base class. 21347db96d56Sopenharmony_ci @dataclass 21357db96d56Sopenharmony_ci class C(B): 21367db96d56Sopenharmony_ci f = field() 21377db96d56Sopenharmony_ci 21387db96d56Sopenharmony_ci 21397db96d56Sopenharmony_ciclass TestDocString(unittest.TestCase): 21407db96d56Sopenharmony_ci def assertDocStrEqual(self, a, b): 21417db96d56Sopenharmony_ci # Because 3.6 and 3.7 differ in how inspect.signature work 21427db96d56Sopenharmony_ci # (see bpo #32108), for the time being just compare them with 21437db96d56Sopenharmony_ci # whitespace stripped. 
21447db96d56Sopenharmony_ci self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 21457db96d56Sopenharmony_ci 21467db96d56Sopenharmony_ci def test_existing_docstring_not_overridden(self): 21477db96d56Sopenharmony_ci @dataclass 21487db96d56Sopenharmony_ci class C: 21497db96d56Sopenharmony_ci """Lorem ipsum""" 21507db96d56Sopenharmony_ci x: int 21517db96d56Sopenharmony_ci 21527db96d56Sopenharmony_ci self.assertEqual(C.__doc__, "Lorem ipsum") 21537db96d56Sopenharmony_ci 21547db96d56Sopenharmony_ci def test_docstring_no_fields(self): 21557db96d56Sopenharmony_ci @dataclass 21567db96d56Sopenharmony_ci class C: 21577db96d56Sopenharmony_ci pass 21587db96d56Sopenharmony_ci 21597db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C()") 21607db96d56Sopenharmony_ci 21617db96d56Sopenharmony_ci def test_docstring_one_field(self): 21627db96d56Sopenharmony_ci @dataclass 21637db96d56Sopenharmony_ci class C: 21647db96d56Sopenharmony_ci x: int 21657db96d56Sopenharmony_ci 21667db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:int)") 21677db96d56Sopenharmony_ci 21687db96d56Sopenharmony_ci def test_docstring_two_fields(self): 21697db96d56Sopenharmony_ci @dataclass 21707db96d56Sopenharmony_ci class C: 21717db96d56Sopenharmony_ci x: int 21727db96d56Sopenharmony_ci y: int 21737db96d56Sopenharmony_ci 21747db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 21757db96d56Sopenharmony_ci 21767db96d56Sopenharmony_ci def test_docstring_three_fields(self): 21777db96d56Sopenharmony_ci @dataclass 21787db96d56Sopenharmony_ci class C: 21797db96d56Sopenharmony_ci x: int 21807db96d56Sopenharmony_ci y: int 21817db96d56Sopenharmony_ci z: str 21827db96d56Sopenharmony_ci 21837db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 21847db96d56Sopenharmony_ci 21857db96d56Sopenharmony_ci def test_docstring_one_field_with_default(self): 21867db96d56Sopenharmony_ci @dataclass 21877db96d56Sopenharmony_ci class C: 21887db96d56Sopenharmony_ci 
x: int = 3 21897db96d56Sopenharmony_ci 21907db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 21917db96d56Sopenharmony_ci 21927db96d56Sopenharmony_ci def test_docstring_one_field_with_default_none(self): 21937db96d56Sopenharmony_ci @dataclass 21947db96d56Sopenharmony_ci class C: 21957db96d56Sopenharmony_ci x: Union[int, type(None)] = None 21967db96d56Sopenharmony_ci 21977db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") 21987db96d56Sopenharmony_ci 21997db96d56Sopenharmony_ci def test_docstring_list_field(self): 22007db96d56Sopenharmony_ci @dataclass 22017db96d56Sopenharmony_ci class C: 22027db96d56Sopenharmony_ci x: List[int] 22037db96d56Sopenharmony_ci 22047db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 22057db96d56Sopenharmony_ci 22067db96d56Sopenharmony_ci def test_docstring_list_field_with_default_factory(self): 22077db96d56Sopenharmony_ci @dataclass 22087db96d56Sopenharmony_ci class C: 22097db96d56Sopenharmony_ci x: List[int] = field(default_factory=list) 22107db96d56Sopenharmony_ci 22117db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 22127db96d56Sopenharmony_ci 22137db96d56Sopenharmony_ci def test_docstring_deque_field(self): 22147db96d56Sopenharmony_ci @dataclass 22157db96d56Sopenharmony_ci class C: 22167db96d56Sopenharmony_ci x: deque 22177db96d56Sopenharmony_ci 22187db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 22197db96d56Sopenharmony_ci 22207db96d56Sopenharmony_ci def test_docstring_deque_field_with_default_factory(self): 22217db96d56Sopenharmony_ci @dataclass 22227db96d56Sopenharmony_ci class C: 22237db96d56Sopenharmony_ci x: deque = field(default_factory=deque) 22247db96d56Sopenharmony_ci 22257db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 22267db96d56Sopenharmony_ci 22277db96d56Sopenharmony_ci def test_docstring_with_no_signature(self): 
22287db96d56Sopenharmony_ci # See https://github.com/python/cpython/issues/103449 22297db96d56Sopenharmony_ci class Meta(type): 22307db96d56Sopenharmony_ci __call__ = dict 22317db96d56Sopenharmony_ci class Base(metaclass=Meta): 22327db96d56Sopenharmony_ci pass 22337db96d56Sopenharmony_ci 22347db96d56Sopenharmony_ci @dataclass 22357db96d56Sopenharmony_ci class C(Base): 22367db96d56Sopenharmony_ci pass 22377db96d56Sopenharmony_ci 22387db96d56Sopenharmony_ci self.assertDocStrEqual(C.__doc__, "C") 22397db96d56Sopenharmony_ci 22407db96d56Sopenharmony_ci 22417db96d56Sopenharmony_ciclass TestInit(unittest.TestCase): 22427db96d56Sopenharmony_ci def test_base_has_init(self): 22437db96d56Sopenharmony_ci class B: 22447db96d56Sopenharmony_ci def __init__(self): 22457db96d56Sopenharmony_ci self.z = 100 22467db96d56Sopenharmony_ci pass 22477db96d56Sopenharmony_ci 22487db96d56Sopenharmony_ci # Make sure that declaring this class doesn't raise an error. 22497db96d56Sopenharmony_ci # The issue is that we can't override __init__ in our class, 22507db96d56Sopenharmony_ci # but it should be okay to add __init__ to us if our base has 22517db96d56Sopenharmony_ci # an __init__. 22527db96d56Sopenharmony_ci @dataclass 22537db96d56Sopenharmony_ci class C(B): 22547db96d56Sopenharmony_ci x: int = 0 22557db96d56Sopenharmony_ci c = C(10) 22567db96d56Sopenharmony_ci self.assertEqual(c.x, 10) 22577db96d56Sopenharmony_ci self.assertNotIn('z', vars(c)) 22587db96d56Sopenharmony_ci 22597db96d56Sopenharmony_ci # Make sure that if we don't add an init, the base __init__ 22607db96d56Sopenharmony_ci # gets called. 
22617db96d56Sopenharmony_ci @dataclass(init=False) 22627db96d56Sopenharmony_ci class C(B): 22637db96d56Sopenharmony_ci x: int = 10 22647db96d56Sopenharmony_ci c = C() 22657db96d56Sopenharmony_ci self.assertEqual(c.x, 10) 22667db96d56Sopenharmony_ci self.assertEqual(c.z, 100) 22677db96d56Sopenharmony_ci 22687db96d56Sopenharmony_ci def test_no_init(self): 22697db96d56Sopenharmony_ci @dataclass(init=False) 22707db96d56Sopenharmony_ci class C: 22717db96d56Sopenharmony_ci i: int = 0 22727db96d56Sopenharmony_ci self.assertEqual(C().i, 0) 22737db96d56Sopenharmony_ci 22747db96d56Sopenharmony_ci @dataclass(init=False) 22757db96d56Sopenharmony_ci class C: 22767db96d56Sopenharmony_ci i: int = 2 22777db96d56Sopenharmony_ci def __init__(self): 22787db96d56Sopenharmony_ci self.i = 3 22797db96d56Sopenharmony_ci self.assertEqual(C().i, 3) 22807db96d56Sopenharmony_ci 22817db96d56Sopenharmony_ci def test_overwriting_init(self): 22827db96d56Sopenharmony_ci # If the class has __init__, use it no matter the value of 22837db96d56Sopenharmony_ci # init=. 
22847db96d56Sopenharmony_ci 22857db96d56Sopenharmony_ci @dataclass 22867db96d56Sopenharmony_ci class C: 22877db96d56Sopenharmony_ci x: int 22887db96d56Sopenharmony_ci def __init__(self, x): 22897db96d56Sopenharmony_ci self.x = 2 * x 22907db96d56Sopenharmony_ci self.assertEqual(C(3).x, 6) 22917db96d56Sopenharmony_ci 22927db96d56Sopenharmony_ci @dataclass(init=True) 22937db96d56Sopenharmony_ci class C: 22947db96d56Sopenharmony_ci x: int 22957db96d56Sopenharmony_ci def __init__(self, x): 22967db96d56Sopenharmony_ci self.x = 2 * x 22977db96d56Sopenharmony_ci self.assertEqual(C(4).x, 8) 22987db96d56Sopenharmony_ci 22997db96d56Sopenharmony_ci @dataclass(init=False) 23007db96d56Sopenharmony_ci class C: 23017db96d56Sopenharmony_ci x: int 23027db96d56Sopenharmony_ci def __init__(self, x): 23037db96d56Sopenharmony_ci self.x = 2 * x 23047db96d56Sopenharmony_ci self.assertEqual(C(5).x, 10) 23057db96d56Sopenharmony_ci 23067db96d56Sopenharmony_ci def test_inherit_from_protocol(self): 23077db96d56Sopenharmony_ci # Dataclasses inheriting from protocol should preserve their own `__init__`. 23087db96d56Sopenharmony_ci # See bpo-45081. 
23097db96d56Sopenharmony_ci 23107db96d56Sopenharmony_ci class P(Protocol): 23117db96d56Sopenharmony_ci a: int 23127db96d56Sopenharmony_ci 23137db96d56Sopenharmony_ci @dataclass 23147db96d56Sopenharmony_ci class C(P): 23157db96d56Sopenharmony_ci a: int 23167db96d56Sopenharmony_ci 23177db96d56Sopenharmony_ci self.assertEqual(C(5).a, 5) 23187db96d56Sopenharmony_ci 23197db96d56Sopenharmony_ci @dataclass 23207db96d56Sopenharmony_ci class D(P): 23217db96d56Sopenharmony_ci def __init__(self, a): 23227db96d56Sopenharmony_ci self.a = a * 2 23237db96d56Sopenharmony_ci 23247db96d56Sopenharmony_ci self.assertEqual(D(5).a, 10) 23257db96d56Sopenharmony_ci 23267db96d56Sopenharmony_ci 23277db96d56Sopenharmony_ciclass TestRepr(unittest.TestCase): 23287db96d56Sopenharmony_ci def test_repr(self): 23297db96d56Sopenharmony_ci @dataclass 23307db96d56Sopenharmony_ci class B: 23317db96d56Sopenharmony_ci x: int 23327db96d56Sopenharmony_ci 23337db96d56Sopenharmony_ci @dataclass 23347db96d56Sopenharmony_ci class C(B): 23357db96d56Sopenharmony_ci y: int = 10 23367db96d56Sopenharmony_ci 23377db96d56Sopenharmony_ci o = C(4) 23387db96d56Sopenharmony_ci self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 23397db96d56Sopenharmony_ci 23407db96d56Sopenharmony_ci @dataclass 23417db96d56Sopenharmony_ci class D(C): 23427db96d56Sopenharmony_ci x: int = 20 23437db96d56Sopenharmony_ci self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 23447db96d56Sopenharmony_ci 23457db96d56Sopenharmony_ci @dataclass 23467db96d56Sopenharmony_ci class C: 23477db96d56Sopenharmony_ci @dataclass 23487db96d56Sopenharmony_ci class D: 23497db96d56Sopenharmony_ci i: int 23507db96d56Sopenharmony_ci @dataclass 23517db96d56Sopenharmony_ci class E: 23527db96d56Sopenharmony_ci pass 23537db96d56Sopenharmony_ci self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 23547db96d56Sopenharmony_ci self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 
23557db96d56Sopenharmony_ci 23567db96d56Sopenharmony_ci def test_no_repr(self): 23577db96d56Sopenharmony_ci # Test a class with no __repr__ and repr=False. 23587db96d56Sopenharmony_ci @dataclass(repr=False) 23597db96d56Sopenharmony_ci class C: 23607db96d56Sopenharmony_ci x: int 23617db96d56Sopenharmony_ci self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 23627db96d56Sopenharmony_ci repr(C(3))) 23637db96d56Sopenharmony_ci 23647db96d56Sopenharmony_ci # Test a class with a __repr__ and repr=False. 23657db96d56Sopenharmony_ci @dataclass(repr=False) 23667db96d56Sopenharmony_ci class C: 23677db96d56Sopenharmony_ci x: int 23687db96d56Sopenharmony_ci def __repr__(self): 23697db96d56Sopenharmony_ci return 'C-class' 23707db96d56Sopenharmony_ci self.assertEqual(repr(C(3)), 'C-class') 23717db96d56Sopenharmony_ci 23727db96d56Sopenharmony_ci def test_overwriting_repr(self): 23737db96d56Sopenharmony_ci # If the class has __repr__, use it no matter the value of 23747db96d56Sopenharmony_ci # repr=. 
23757db96d56Sopenharmony_ci 23767db96d56Sopenharmony_ci @dataclass 23777db96d56Sopenharmony_ci class C: 23787db96d56Sopenharmony_ci x: int 23797db96d56Sopenharmony_ci def __repr__(self): 23807db96d56Sopenharmony_ci return 'x' 23817db96d56Sopenharmony_ci self.assertEqual(repr(C(0)), 'x') 23827db96d56Sopenharmony_ci 23837db96d56Sopenharmony_ci @dataclass(repr=True) 23847db96d56Sopenharmony_ci class C: 23857db96d56Sopenharmony_ci x: int 23867db96d56Sopenharmony_ci def __repr__(self): 23877db96d56Sopenharmony_ci return 'x' 23887db96d56Sopenharmony_ci self.assertEqual(repr(C(0)), 'x') 23897db96d56Sopenharmony_ci 23907db96d56Sopenharmony_ci @dataclass(repr=False) 23917db96d56Sopenharmony_ci class C: 23927db96d56Sopenharmony_ci x: int 23937db96d56Sopenharmony_ci def __repr__(self): 23947db96d56Sopenharmony_ci return 'x' 23957db96d56Sopenharmony_ci self.assertEqual(repr(C(0)), 'x') 23967db96d56Sopenharmony_ci 23977db96d56Sopenharmony_ci 23987db96d56Sopenharmony_ciclass TestEq(unittest.TestCase): 23997db96d56Sopenharmony_ci def test_no_eq(self): 24007db96d56Sopenharmony_ci # Test a class with no __eq__ and eq=False. 24017db96d56Sopenharmony_ci @dataclass(eq=False) 24027db96d56Sopenharmony_ci class C: 24037db96d56Sopenharmony_ci x: int 24047db96d56Sopenharmony_ci self.assertNotEqual(C(0), C(0)) 24057db96d56Sopenharmony_ci c = C(3) 24067db96d56Sopenharmony_ci self.assertEqual(c, c) 24077db96d56Sopenharmony_ci 24087db96d56Sopenharmony_ci # Test a class with an __eq__ and eq=False. 24097db96d56Sopenharmony_ci @dataclass(eq=False) 24107db96d56Sopenharmony_ci class C: 24117db96d56Sopenharmony_ci x: int 24127db96d56Sopenharmony_ci def __eq__(self, other): 24137db96d56Sopenharmony_ci return other == 10 24147db96d56Sopenharmony_ci self.assertEqual(C(3), 10) 24157db96d56Sopenharmony_ci 24167db96d56Sopenharmony_ci def test_overwriting_eq(self): 24177db96d56Sopenharmony_ci # If the class has __eq__, use it no matter the value of 24187db96d56Sopenharmony_ci # eq=. 
24197db96d56Sopenharmony_ci 24207db96d56Sopenharmony_ci @dataclass 24217db96d56Sopenharmony_ci class C: 24227db96d56Sopenharmony_ci x: int 24237db96d56Sopenharmony_ci def __eq__(self, other): 24247db96d56Sopenharmony_ci return other == 3 24257db96d56Sopenharmony_ci self.assertEqual(C(1), 3) 24267db96d56Sopenharmony_ci self.assertNotEqual(C(1), 1) 24277db96d56Sopenharmony_ci 24287db96d56Sopenharmony_ci @dataclass(eq=True) 24297db96d56Sopenharmony_ci class C: 24307db96d56Sopenharmony_ci x: int 24317db96d56Sopenharmony_ci def __eq__(self, other): 24327db96d56Sopenharmony_ci return other == 4 24337db96d56Sopenharmony_ci self.assertEqual(C(1), 4) 24347db96d56Sopenharmony_ci self.assertNotEqual(C(1), 1) 24357db96d56Sopenharmony_ci 24367db96d56Sopenharmony_ci @dataclass(eq=False) 24377db96d56Sopenharmony_ci class C: 24387db96d56Sopenharmony_ci x: int 24397db96d56Sopenharmony_ci def __eq__(self, other): 24407db96d56Sopenharmony_ci return other == 5 24417db96d56Sopenharmony_ci self.assertEqual(C(1), 5) 24427db96d56Sopenharmony_ci self.assertNotEqual(C(1), 1) 24437db96d56Sopenharmony_ci 24447db96d56Sopenharmony_ci 24457db96d56Sopenharmony_ciclass TestOrdering(unittest.TestCase): 24467db96d56Sopenharmony_ci def test_functools_total_ordering(self): 24477db96d56Sopenharmony_ci # Test that functools.total_ordering works with this class. 24487db96d56Sopenharmony_ci @total_ordering 24497db96d56Sopenharmony_ci @dataclass 24507db96d56Sopenharmony_ci class C: 24517db96d56Sopenharmony_ci x: int 24527db96d56Sopenharmony_ci def __lt__(self, other): 24537db96d56Sopenharmony_ci # Perform the test "backward", just to make 24547db96d56Sopenharmony_ci # sure this is being called. 
24557db96d56Sopenharmony_ci return self.x >= other 24567db96d56Sopenharmony_ci 24577db96d56Sopenharmony_ci self.assertLess(C(0), -1) 24587db96d56Sopenharmony_ci self.assertLessEqual(C(0), -1) 24597db96d56Sopenharmony_ci self.assertGreater(C(0), 1) 24607db96d56Sopenharmony_ci self.assertGreaterEqual(C(0), 1) 24617db96d56Sopenharmony_ci 24627db96d56Sopenharmony_ci def test_no_order(self): 24637db96d56Sopenharmony_ci # Test that no ordering functions are added by default. 24647db96d56Sopenharmony_ci @dataclass(order=False) 24657db96d56Sopenharmony_ci class C: 24667db96d56Sopenharmony_ci x: int 24677db96d56Sopenharmony_ci # Make sure no order methods are added. 24687db96d56Sopenharmony_ci self.assertNotIn('__le__', C.__dict__) 24697db96d56Sopenharmony_ci self.assertNotIn('__lt__', C.__dict__) 24707db96d56Sopenharmony_ci self.assertNotIn('__ge__', C.__dict__) 24717db96d56Sopenharmony_ci self.assertNotIn('__gt__', C.__dict__) 24727db96d56Sopenharmony_ci 24737db96d56Sopenharmony_ci # Test that __lt__ is still called 24747db96d56Sopenharmony_ci @dataclass(order=False) 24757db96d56Sopenharmony_ci class C: 24767db96d56Sopenharmony_ci x: int 24777db96d56Sopenharmony_ci def __lt__(self, other): 24787db96d56Sopenharmony_ci return False 24797db96d56Sopenharmony_ci # Make sure other methods aren't added. 
24807db96d56Sopenharmony_ci self.assertNotIn('__le__', C.__dict__) 24817db96d56Sopenharmony_ci self.assertNotIn('__ge__', C.__dict__) 24827db96d56Sopenharmony_ci self.assertNotIn('__gt__', C.__dict__) 24837db96d56Sopenharmony_ci 24847db96d56Sopenharmony_ci def test_overwriting_order(self): 24857db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 24867db96d56Sopenharmony_ci 'Cannot overwrite attribute __lt__' 24877db96d56Sopenharmony_ci '.*using functools.total_ordering'): 24887db96d56Sopenharmony_ci @dataclass(order=True) 24897db96d56Sopenharmony_ci class C: 24907db96d56Sopenharmony_ci x: int 24917db96d56Sopenharmony_ci def __lt__(self): 24927db96d56Sopenharmony_ci pass 24937db96d56Sopenharmony_ci 24947db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 24957db96d56Sopenharmony_ci 'Cannot overwrite attribute __le__' 24967db96d56Sopenharmony_ci '.*using functools.total_ordering'): 24977db96d56Sopenharmony_ci @dataclass(order=True) 24987db96d56Sopenharmony_ci class C: 24997db96d56Sopenharmony_ci x: int 25007db96d56Sopenharmony_ci def __le__(self): 25017db96d56Sopenharmony_ci pass 25027db96d56Sopenharmony_ci 25037db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 25047db96d56Sopenharmony_ci 'Cannot overwrite attribute __gt__' 25057db96d56Sopenharmony_ci '.*using functools.total_ordering'): 25067db96d56Sopenharmony_ci @dataclass(order=True) 25077db96d56Sopenharmony_ci class C: 25087db96d56Sopenharmony_ci x: int 25097db96d56Sopenharmony_ci def __gt__(self): 25107db96d56Sopenharmony_ci pass 25117db96d56Sopenharmony_ci 25127db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 25137db96d56Sopenharmony_ci 'Cannot overwrite attribute __ge__' 25147db96d56Sopenharmony_ci '.*using functools.total_ordering'): 25157db96d56Sopenharmony_ci @dataclass(order=True) 25167db96d56Sopenharmony_ci class C: 25177db96d56Sopenharmony_ci x: int 25187db96d56Sopenharmony_ci def __ge__(self): 25197db96d56Sopenharmony_ci pass 25207db96d56Sopenharmony_ci 
class TestHash(unittest.TestCase):
    """Tests for how @dataclass sets up __hash__ depending on the
    unsafe_hash, eq and frozen parameters and on a user-defined __hash__."""

    def test_unsafe_hash(self):
        # The generated __hash__ hashes the tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map a bool to an equivalent non-bool truth value:
            #  truthy -> (3,), falsy -> 0, and None is passed through.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        # Build one dataclass for the given decorator arguments (with or
        #  without a user-defined __hash__) and check it against `result`:
        #  'fn'        -> a generated __hash__ is in the class __dict__
        #  ''          -> no __hash__ was added to the class __dict__
        #  'none'      -> __hash__ is set to None in the class __dict__
        #  'exception' -> creating the class raises TypeError
        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    #  This only happens with with_hash==True.
                    # NOTE(review): a bare assert is skipped under `python -O`;
                    #  self.assertTrue(with_hash) would keep this check active.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # NOTE(review): skipped under `python -O`; self.fail()
                    #  would always report an unknown table value.
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        # res_no_defined_hash is the expected result when the class body
        #  does not define __hash__; res_defined_hash when it does.
        for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                  (False,        False, False,  '',                  ''),
                  (False,        False, True,   '',                  ''),
                  (False,        True,  False,  'none',              ''),
                  (False,        True,  True,   'fn',                ''),
                  (True,         False, False,  'fn',                'exception'),
                  (True,         False, True,   'fn',                'exception'),
                  (True,         True,  False,  'fn',                'exception'),
                  (True,         True,  True,   'fn',                'exception'),
                  ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            #  make sure the data-driven table in the decorator
            #  handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        #  and set to None.  This is normal Python behavior, not
        #  related to dataclasses.  Make sure we don't interfere with
        #  that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        #  unsafe_hash=True.  1 and 1.0 compare and hash equal, so the
        #  generated field-tuple hashes agree as well.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        #  specifying eq=True.  (C(1) != C(1) can only come from the
        #  user-defined __eq__ below.)
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the generated hash is that of an empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the generated hash is that of a 1-tuple.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        #  make sure that if the @dataclass parameter name is changed
        #  or the non-default hashing behavior changes, the default
        #  hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        #  specify any value in the decorator).
        # expected: 'unhashable' -> hash() raises TypeError,
        #  'base' -> Base.__hash__ is used, 'object' -> object.__hash__
        #  is used, 'tuple' -> a generated field-tuple hash is used.
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here.  object's
                    #  hash is derived from the instance's identity, so
                    #  the value returned by hash() isn't predictable and
                    #  won't tell us much.  So, just check the
                    #  function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # NOTE(review): skipped under `python -O`; self.fail()
                    #  would always report an unknown table value.
                    assert False, f'unknown value for expected={expected!r}'


class TestFrozen(unittest.TestCase):
    """Tests for frozen=True: attribute assignment is blocked, and frozen
    and non-frozen dataclasses cannot inherit from each other."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        # Both inherited and newly declared fields are frozen.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    def test_inherit_nonfrozen_from_empty_frozen(self):
        # Even a frozen base with no fields rejects a non-frozen
        #  dataclass subclass.
        @dataclass(frozen=True)
        class C:
            pass

        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    # Test both ways: with an intermediate normal (non-dataclass)
    #  class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Mirror image of test_inherit_nonfrozen_from_frozen.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from a plain (non-dataclass) class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

            d = D(10)
            with self.assertRaises(FrozenInstanceError):
                d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.
        # A plain (non-dataclass) subclass of a frozen dataclass can set
        #  new attributes, but the inherited fields stay frozen.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__, so a frozen dataclass
        #  whose body defines either of them is rejected with TypeError.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen, a user-defined __setattr__ is kept, and the
        #  generated __init__ assigns through it (hence 10 * 2).
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))


class TestSlots(unittest.TestCase):
    """Tests for dataclasses with user-written __slots__, and for the
    slots=True and weakref_slot=True options."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        #  also have a default value (of type
        #  types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.  (Derived
        #  declares no __slots__ of its own, so its instances are not
        #  restricted to Base's slots.)
        d.z = 10

    def test_generated_slots(self):
        @dataclass(slots=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        self.assertEqual((c.x, c.y), (1, 2))

        c.x = 3
        c.y = 4
        self.assertEqual((c.x, c.y), (3, 4))

        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
            c.z = 5

    def test_add_slots_when_slots_exists(self):
        # slots=True conflicts with a class that already has __slots__.
        with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
            @dataclass(slots=True)
            class C:
                __slots__ = ('x',)
                x: int

    def test_generated_slots_value(self):
        # Generated __slots__ must leave out names that a base class
        #  already provides as slots, whatever form the base used for
        #  __slots__ (set, dict, list, or a single string).

        class Root:
            __slots__ = {'x'}

        class Root2(Root):
            __slots__ = {'k': '...', 'j': ''}

        class Root3(Root2):
            __slots__ = ['h']

        class Root4(Root3):
            __slots__ = 'aa'

        @dataclass(slots=True)
        class Base(Root4):
            y: int
            j: str
            h: str

        self.assertEqual(Base.__slots__, ('y', ))

        @dataclass(slots=True)
        class Derived(Base):
            aa: float
            x: str
            z: int
            k: str
            h: str

        self.assertEqual(Derived.__slots__, ('z', ))

        @dataclass
        class AnotherDerived(Base):
            z: int

        self.assertNotIn('__slots__', AnotherDerived.__dict__)

    def test_cant_inherit_from_iterator_slots(self):
        # A base whose __slots__ is an iterator can't have its slot names
        #  recovered afterwards (presumably because the iterator has been
        #  consumed), so slots=True must raise instead of guessing.

        class Root:
            __slots__ = iter(['a'])

        class Root2(Root):
            __slots__ = ('b', )

        with self.assertRaisesRegex(
           TypeError,
            "^Slots of 'Root' cannot be determined"
        ):
            @dataclass(slots=True)
            class C(Root2):
                x: int

    def test_returns_new_class(self):
        # slots=True returns a new class and leaves the original untouched.
        class A:
            x: int

        B = dataclass(A, slots=True)
        self.assertIsNot(A, B)

        self.assertFalse(hasattr(A, "__slots__"))
        self.assertTrue(hasattr(B, "__slots__"))

    # Can't be local to test_frozen_pickle: pickle has to be able to find
    #  these classes by name when loading, which isn't possible for
    #  classes created inside a function.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsClass:
        foo: str
        bar: int

    @dataclass(frozen=True)
    class FrozenWithoutSlotsClass:
        foo: str
        bar: int

    def test_frozen_pickle(self):
        # bpo-43999

        self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsClass("a", 1)
                p = pickle.loads(pickle.dumps(obj, protocol=proto))
                self.assertIsNot(obj, p)
                self.assertEqual(obj, p)

                obj = self.FrozenWithoutSlotsClass("a", 1)
                p = pickle.loads(pickle.dumps(obj, protocol=proto))
                self.assertIsNot(obj, p)
                self.assertEqual(obj, p)

    # The next three classes define their own __getstate__ and/or
    #  __setstate__.  Since the instances are frozen, the hooks record
    #  that they ran via object.__setattr__.  The *_called flags use
    #  compare=False so they don't affect equality checks.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsGetStateClass:
        foo: str
        bar: int

        getstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            object.__setattr__(self, 'getstate_called', True)
            return [self.foo, self.bar]

    @dataclass(frozen=True, slots=True)
    class FrozenSlotsSetStateClass:
        foo: str
        bar: int

        setstate_called: bool = field(default=False, compare=False)

        def __setstate__(self, state):
            object.__setattr__(self, 'setstate_called', True)
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])

    @dataclass(frozen=True, slots=True)
    class FrozenSlotsAllStateClass:
        foo: str
        bar: int

        getstate_called: bool = field(default=False, compare=False)
        setstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            object.__setattr__(self, 'getstate_called', True)
            return [self.foo, self.bar]

        def __setstate__(self, state):
            object.__setattr__(self, 'setstate_called', True)
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])

    def test_frozen_slots_pickle_custom_state(self):
        # User-defined __getstate__ must be called when dumping.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsGetStateClass('a', 1)
                dumped = pickle.dumps(obj, protocol=proto)

                self.assertTrue(obj.getstate_called)
                self.assertEqual(obj, pickle.loads(dumped))

        # User-defined __setstate__ must be called when loading.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsSetStateClass('a', 1)
                obj2 = pickle.loads(pickle.dumps(obj, protocol=proto))

                self.assertTrue(obj2.setstate_called)
                self.assertEqual(obj, obj2)

        # Both hooks together.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsAllStateClass('a', 1)
                dumped = pickle.dumps(obj, protocol=proto)

                self.assertTrue(obj.getstate_called)

                obj2 = pickle.loads(dumped)
                self.assertTrue(obj2.setstate_called)
                self.assertEqual(obj, obj2)

    def test_slots_with_default_no_init(self):
        # Originally reported in bpo-44649.
        @dataclass(slots=True)
        class A:
            a: str
            b: str = field(default='b', init=False)

        obj = A("a")
        self.assertEqual(obj.a, 'a')
        self.assertEqual(obj.b, 'b')

    def test_slots_with_default_factory_no_init(self):
        # Originally reported in bpo-44649.
        @dataclass(slots=True)
        class A:
            a: str
            b: str = field(default_factory=lambda:'b', init=False)

        obj = A("a")
        self.assertEqual(obj.a, 'a')
        self.assertEqual(obj.b, 'b')

    def test_slots_no_weakref(self):
        # Without weakref_slot=True, slots=True adds no __weakref__ slot,
        #  so instances can't be weakly referenced.
        @dataclass(slots=True)
        class A:
            # No weakref.
            pass

        self.assertNotIn("__weakref__", A.__slots__)
        a = A()
        with self.assertRaisesRegex(TypeError,
                                    "cannot create weak reference"):
            weakref.ref(a)
        with self.assertRaises(AttributeError):
            a.__weakref__

    def test_slots_weakref(self):
        @dataclass(slots=True, weakref_slot=True)
        class A:
            a: int

        self.assertIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)

        self.assertIs(a.__weakref__, a_ref)

    def test_slots_weakref_base_str(self):
        class Base:
            __slots__ = '__weakref__'

        @dataclass(slots=True)
        class A(Base):
            a: int

        # __weakref__ is in the base class, not A.  But an A is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

    def test_slots_weakref_base_tuple(self):
        # Same as test_slots_weakref_base_str, but use a tuple instead of
        #  a string in the base class.
        class Base:
            __slots__ = ('__weakref__',)

        @dataclass(slots=True)
        class A(Base):
            a: int

        # __weakref__ is in the base class, not A.  But an A is still
        # weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

    def test_weakref_slot_without_slot(self):
        with self.assertRaisesRegex(TypeError,
                                    "weakref_slot is True but slots is False"):
            @dataclass(weakref_slot=True)
            class A:
                a: int

    def test_weakref_slot_make_dataclass(self):
        A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True)
        self.assertIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

        # And make sure it raises if slots=True is not given.
        with self.assertRaisesRegex(TypeError,
                                    "weakref_slot is True but slots is False"):
            B = make_dataclass('B', [('a', int),], weakref_slot=True)

    def test_weakref_slot_subclass_weakref_slot(self):
        @dataclass(slots=True, weakref_slot=True)
        class Base:
            field: int

        # A *can* also specify weakref_slot=True if it wants to (gh-93521)
        @dataclass(slots=True, weakref_slot=True)
        class A(Base):
            ...

        # __weakref__ is in the base class, not A.  But an instance of A
        # is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)

    def test_weakref_slot_subclass_no_weakref_slot(self):
        @dataclass(slots=True, weakref_slot=True)
        class Base:
            field: int

        @dataclass(slots=True)
        class A(Base):
            ...

        # __weakref__ is in the base class, not A.  Even though A doesn't
        # specify weakref_slot, it should still be weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)

    def test_weakref_slot_normal_base_weakref_slot(self):
        class Base:
            __slots__ = ('__weakref__',)

        @dataclass(slots=True, weakref_slot=True)
        class A(Base):
            field: int

        # __weakref__ is in the base class, not A.  But an instance of
        # A is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)


class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.
32927db96d56Sopenharmony_ci class D: 32937db96d56Sopenharmony_ci def __set_name__(self, owner, name): 32947db96d56Sopenharmony_ci self.name = name + 'x' 32957db96d56Sopenharmony_ci def __get__(self, instance, owner): 32967db96d56Sopenharmony_ci if instance is not None: 32977db96d56Sopenharmony_ci return 1 32987db96d56Sopenharmony_ci return self 32997db96d56Sopenharmony_ci 33007db96d56Sopenharmony_ci # This is the case of just normal descriptor behavior, no 33017db96d56Sopenharmony_ci # dataclass code is involved in initializing the descriptor. 33027db96d56Sopenharmony_ci @dataclass 33037db96d56Sopenharmony_ci class C: 33047db96d56Sopenharmony_ci c: int=D() 33057db96d56Sopenharmony_ci self.assertEqual(C.c.name, 'cx') 33067db96d56Sopenharmony_ci 33077db96d56Sopenharmony_ci # Now test with a default value and init=False, which is the 33087db96d56Sopenharmony_ci # only time this is really meaningful. If not using 33097db96d56Sopenharmony_ci # init=False, then the descriptor will be overwritten, anyway. 33107db96d56Sopenharmony_ci @dataclass 33117db96d56Sopenharmony_ci class C: 33127db96d56Sopenharmony_ci c: int=field(default=D(), init=False) 33137db96d56Sopenharmony_ci self.assertEqual(C.c.name, 'cx') 33147db96d56Sopenharmony_ci self.assertEqual(C().c, 1) 33157db96d56Sopenharmony_ci 33167db96d56Sopenharmony_ci def test_non_descriptor(self): 33177db96d56Sopenharmony_ci # PEP 487 says __set_name__ should work on non-descriptors. 33187db96d56Sopenharmony_ci # Create a descriptor. 
33197db96d56Sopenharmony_ci 33207db96d56Sopenharmony_ci class D: 33217db96d56Sopenharmony_ci def __set_name__(self, owner, name): 33227db96d56Sopenharmony_ci self.name = name + 'x' 33237db96d56Sopenharmony_ci 33247db96d56Sopenharmony_ci @dataclass 33257db96d56Sopenharmony_ci class C: 33267db96d56Sopenharmony_ci c: int=field(default=D(), init=False) 33277db96d56Sopenharmony_ci self.assertEqual(C.c.name, 'cx') 33287db96d56Sopenharmony_ci 33297db96d56Sopenharmony_ci def test_lookup_on_instance(self): 33307db96d56Sopenharmony_ci # See bpo-33175. 33317db96d56Sopenharmony_ci class D: 33327db96d56Sopenharmony_ci pass 33337db96d56Sopenharmony_ci 33347db96d56Sopenharmony_ci d = D() 33357db96d56Sopenharmony_ci # Create an attribute on the instance, not type. 33367db96d56Sopenharmony_ci d.__set_name__ = Mock() 33377db96d56Sopenharmony_ci 33387db96d56Sopenharmony_ci # Make sure d.__set_name__ is not called. 33397db96d56Sopenharmony_ci @dataclass 33407db96d56Sopenharmony_ci class C: 33417db96d56Sopenharmony_ci i: int=field(default=d, init=False) 33427db96d56Sopenharmony_ci 33437db96d56Sopenharmony_ci self.assertEqual(d.__set_name__.call_count, 0) 33447db96d56Sopenharmony_ci 33457db96d56Sopenharmony_ci def test_lookup_on_class(self): 33467db96d56Sopenharmony_ci # See bpo-33175. 33477db96d56Sopenharmony_ci class D: 33487db96d56Sopenharmony_ci pass 33497db96d56Sopenharmony_ci D.__set_name__ = Mock() 33507db96d56Sopenharmony_ci 33517db96d56Sopenharmony_ci # Make sure D.__set_name__ is called. 
33527db96d56Sopenharmony_ci @dataclass 33537db96d56Sopenharmony_ci class C: 33547db96d56Sopenharmony_ci i: int=field(default=D(), init=False) 33557db96d56Sopenharmony_ci 33567db96d56Sopenharmony_ci self.assertEqual(D.__set_name__.call_count, 1) 33577db96d56Sopenharmony_ci 33587db96d56Sopenharmony_ci def test_init_calls_set(self): 33597db96d56Sopenharmony_ci class D: 33607db96d56Sopenharmony_ci pass 33617db96d56Sopenharmony_ci 33627db96d56Sopenharmony_ci D.__set__ = Mock() 33637db96d56Sopenharmony_ci 33647db96d56Sopenharmony_ci @dataclass 33657db96d56Sopenharmony_ci class C: 33667db96d56Sopenharmony_ci i: D = D() 33677db96d56Sopenharmony_ci 33687db96d56Sopenharmony_ci # Make sure D.__set__ is called. 33697db96d56Sopenharmony_ci D.__set__.reset_mock() 33707db96d56Sopenharmony_ci c = C(5) 33717db96d56Sopenharmony_ci self.assertEqual(D.__set__.call_count, 1) 33727db96d56Sopenharmony_ci 33737db96d56Sopenharmony_ci def test_getting_field_calls_get(self): 33747db96d56Sopenharmony_ci class D: 33757db96d56Sopenharmony_ci pass 33767db96d56Sopenharmony_ci 33777db96d56Sopenharmony_ci D.__set__ = Mock() 33787db96d56Sopenharmony_ci D.__get__ = Mock() 33797db96d56Sopenharmony_ci 33807db96d56Sopenharmony_ci @dataclass 33817db96d56Sopenharmony_ci class C: 33827db96d56Sopenharmony_ci i: D = D() 33837db96d56Sopenharmony_ci 33847db96d56Sopenharmony_ci c = C(5) 33857db96d56Sopenharmony_ci 33867db96d56Sopenharmony_ci # Make sure D.__get__ is called. 
33877db96d56Sopenharmony_ci D.__get__.reset_mock() 33887db96d56Sopenharmony_ci value = c.i 33897db96d56Sopenharmony_ci self.assertEqual(D.__get__.call_count, 1) 33907db96d56Sopenharmony_ci 33917db96d56Sopenharmony_ci def test_setting_field_calls_set(self): 33927db96d56Sopenharmony_ci class D: 33937db96d56Sopenharmony_ci pass 33947db96d56Sopenharmony_ci 33957db96d56Sopenharmony_ci D.__set__ = Mock() 33967db96d56Sopenharmony_ci 33977db96d56Sopenharmony_ci @dataclass 33987db96d56Sopenharmony_ci class C: 33997db96d56Sopenharmony_ci i: D = D() 34007db96d56Sopenharmony_ci 34017db96d56Sopenharmony_ci c = C(5) 34027db96d56Sopenharmony_ci 34037db96d56Sopenharmony_ci # Make sure D.__set__ is called. 34047db96d56Sopenharmony_ci D.__set__.reset_mock() 34057db96d56Sopenharmony_ci c.i = 10 34067db96d56Sopenharmony_ci self.assertEqual(D.__set__.call_count, 1) 34077db96d56Sopenharmony_ci 34087db96d56Sopenharmony_ci def test_setting_uninitialized_descriptor_field(self): 34097db96d56Sopenharmony_ci class D: 34107db96d56Sopenharmony_ci pass 34117db96d56Sopenharmony_ci 34127db96d56Sopenharmony_ci D.__set__ = Mock() 34137db96d56Sopenharmony_ci 34147db96d56Sopenharmony_ci @dataclass 34157db96d56Sopenharmony_ci class C: 34167db96d56Sopenharmony_ci i: D 34177db96d56Sopenharmony_ci 34187db96d56Sopenharmony_ci # D.__set__ is not called because there's no D instance to call it on 34197db96d56Sopenharmony_ci D.__set__.reset_mock() 34207db96d56Sopenharmony_ci c = C(5) 34217db96d56Sopenharmony_ci self.assertEqual(D.__set__.call_count, 0) 34227db96d56Sopenharmony_ci 34237db96d56Sopenharmony_ci # D.__set__ still isn't called after setting i to an instance of D 34247db96d56Sopenharmony_ci # because descriptors don't behave like that when stored as instance vars 34257db96d56Sopenharmony_ci c.i = D() 34267db96d56Sopenharmony_ci c.i = 5 34277db96d56Sopenharmony_ci self.assertEqual(D.__set__.call_count, 0) 34287db96d56Sopenharmony_ci 34297db96d56Sopenharmony_ci def test_default_value(self): 
34307db96d56Sopenharmony_ci class D: 34317db96d56Sopenharmony_ci def __get__(self, instance: Any, owner: object) -> int: 34327db96d56Sopenharmony_ci if instance is None: 34337db96d56Sopenharmony_ci return 100 34347db96d56Sopenharmony_ci 34357db96d56Sopenharmony_ci return instance._x 34367db96d56Sopenharmony_ci 34377db96d56Sopenharmony_ci def __set__(self, instance: Any, value: int) -> None: 34387db96d56Sopenharmony_ci instance._x = value 34397db96d56Sopenharmony_ci 34407db96d56Sopenharmony_ci @dataclass 34417db96d56Sopenharmony_ci class C: 34427db96d56Sopenharmony_ci i: D = D() 34437db96d56Sopenharmony_ci 34447db96d56Sopenharmony_ci c = C() 34457db96d56Sopenharmony_ci self.assertEqual(c.i, 100) 34467db96d56Sopenharmony_ci 34477db96d56Sopenharmony_ci c = C(5) 34487db96d56Sopenharmony_ci self.assertEqual(c.i, 5) 34497db96d56Sopenharmony_ci 34507db96d56Sopenharmony_ci def test_no_default_value(self): 34517db96d56Sopenharmony_ci class D: 34527db96d56Sopenharmony_ci def __get__(self, instance: Any, owner: object) -> int: 34537db96d56Sopenharmony_ci if instance is None: 34547db96d56Sopenharmony_ci raise AttributeError() 34557db96d56Sopenharmony_ci 34567db96d56Sopenharmony_ci return instance._x 34577db96d56Sopenharmony_ci 34587db96d56Sopenharmony_ci def __set__(self, instance: Any, value: int) -> None: 34597db96d56Sopenharmony_ci instance._x = value 34607db96d56Sopenharmony_ci 34617db96d56Sopenharmony_ci @dataclass 34627db96d56Sopenharmony_ci class C: 34637db96d56Sopenharmony_ci i: D = D() 34647db96d56Sopenharmony_ci 34657db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'): 34667db96d56Sopenharmony_ci c = C() 34677db96d56Sopenharmony_ci 34687db96d56Sopenharmony_ciclass TestStringAnnotations(unittest.TestCase): 34697db96d56Sopenharmony_ci def test_classvar(self): 34707db96d56Sopenharmony_ci # Some expressions recognized as ClassVar really aren't. 
But 34717db96d56Sopenharmony_ci # if you're using string annotations, it's not an exact 34727db96d56Sopenharmony_ci # science. 34737db96d56Sopenharmony_ci # These tests assume that both "import typing" and "from 34747db96d56Sopenharmony_ci # typing import *" have been run in this file. 34757db96d56Sopenharmony_ci for typestr in ('ClassVar[int]', 34767db96d56Sopenharmony_ci 'ClassVar [int]', 34777db96d56Sopenharmony_ci ' ClassVar [int]', 34787db96d56Sopenharmony_ci 'ClassVar', 34797db96d56Sopenharmony_ci ' ClassVar ', 34807db96d56Sopenharmony_ci 'typing.ClassVar[int]', 34817db96d56Sopenharmony_ci 'typing.ClassVar[str]', 34827db96d56Sopenharmony_ci ' typing.ClassVar[str]', 34837db96d56Sopenharmony_ci 'typing .ClassVar[str]', 34847db96d56Sopenharmony_ci 'typing. ClassVar[str]', 34857db96d56Sopenharmony_ci 'typing.ClassVar [str]', 34867db96d56Sopenharmony_ci 'typing.ClassVar [ str]', 34877db96d56Sopenharmony_ci 34887db96d56Sopenharmony_ci # Not syntactically valid, but these will 34897db96d56Sopenharmony_ci # be treated as ClassVars. 34907db96d56Sopenharmony_ci 'typing.ClassVar.[int]', 34917db96d56Sopenharmony_ci 'typing.ClassVar+', 34927db96d56Sopenharmony_ci ): 34937db96d56Sopenharmony_ci with self.subTest(typestr=typestr): 34947db96d56Sopenharmony_ci @dataclass 34957db96d56Sopenharmony_ci class C: 34967db96d56Sopenharmony_ci x: typestr 34977db96d56Sopenharmony_ci 34987db96d56Sopenharmony_ci # x is a ClassVar, so C() takes no args. 34997db96d56Sopenharmony_ci C() 35007db96d56Sopenharmony_ci 35017db96d56Sopenharmony_ci # And it won't appear in the class's dict because it doesn't 35027db96d56Sopenharmony_ci # have a default. 
35037db96d56Sopenharmony_ci self.assertNotIn('x', C.__dict__) 35047db96d56Sopenharmony_ci 35057db96d56Sopenharmony_ci def test_isnt_classvar(self): 35067db96d56Sopenharmony_ci for typestr in ('CV', 35077db96d56Sopenharmony_ci 't.ClassVar', 35087db96d56Sopenharmony_ci 't.ClassVar[int]', 35097db96d56Sopenharmony_ci 'typing..ClassVar[int]', 35107db96d56Sopenharmony_ci 'Classvar', 35117db96d56Sopenharmony_ci 'Classvar[int]', 35127db96d56Sopenharmony_ci 'typing.ClassVarx[int]', 35137db96d56Sopenharmony_ci 'typong.ClassVar[int]', 35147db96d56Sopenharmony_ci 'dataclasses.ClassVar[int]', 35157db96d56Sopenharmony_ci 'typingxClassVar[str]', 35167db96d56Sopenharmony_ci ): 35177db96d56Sopenharmony_ci with self.subTest(typestr=typestr): 35187db96d56Sopenharmony_ci @dataclass 35197db96d56Sopenharmony_ci class C: 35207db96d56Sopenharmony_ci x: typestr 35217db96d56Sopenharmony_ci 35227db96d56Sopenharmony_ci # x is not a ClassVar, so C() takes one arg. 35237db96d56Sopenharmony_ci self.assertEqual(C(10).x, 10) 35247db96d56Sopenharmony_ci 35257db96d56Sopenharmony_ci def test_initvar(self): 35267db96d56Sopenharmony_ci # These tests assume that both "import dataclasses" and "from 35277db96d56Sopenharmony_ci # dataclasses import *" have been run in this file. 35287db96d56Sopenharmony_ci for typestr in ('InitVar[int]', 35297db96d56Sopenharmony_ci 'InitVar [int]' 35307db96d56Sopenharmony_ci ' InitVar [int]', 35317db96d56Sopenharmony_ci 'InitVar', 35327db96d56Sopenharmony_ci ' InitVar ', 35337db96d56Sopenharmony_ci 'dataclasses.InitVar[int]', 35347db96d56Sopenharmony_ci 'dataclasses.InitVar[str]', 35357db96d56Sopenharmony_ci ' dataclasses.InitVar[str]', 35367db96d56Sopenharmony_ci 'dataclasses .InitVar[str]', 35377db96d56Sopenharmony_ci 'dataclasses. 
InitVar[str]', 35387db96d56Sopenharmony_ci 'dataclasses.InitVar [str]', 35397db96d56Sopenharmony_ci 'dataclasses.InitVar [ str]', 35407db96d56Sopenharmony_ci 35417db96d56Sopenharmony_ci # Not syntactically valid, but these will 35427db96d56Sopenharmony_ci # be treated as InitVars. 35437db96d56Sopenharmony_ci 'dataclasses.InitVar.[int]', 35447db96d56Sopenharmony_ci 'dataclasses.InitVar+', 35457db96d56Sopenharmony_ci ): 35467db96d56Sopenharmony_ci with self.subTest(typestr=typestr): 35477db96d56Sopenharmony_ci @dataclass 35487db96d56Sopenharmony_ci class C: 35497db96d56Sopenharmony_ci x: typestr 35507db96d56Sopenharmony_ci 35517db96d56Sopenharmony_ci # x is an InitVar, so doesn't create a member. 35527db96d56Sopenharmony_ci with self.assertRaisesRegex(AttributeError, 35537db96d56Sopenharmony_ci "object has no attribute 'x'"): 35547db96d56Sopenharmony_ci C(1).x 35557db96d56Sopenharmony_ci 35567db96d56Sopenharmony_ci def test_isnt_initvar(self): 35577db96d56Sopenharmony_ci for typestr in ('IV', 35587db96d56Sopenharmony_ci 'dc.InitVar', 35597db96d56Sopenharmony_ci 'xdataclasses.xInitVar', 35607db96d56Sopenharmony_ci 'typing.xInitVar[int]', 35617db96d56Sopenharmony_ci ): 35627db96d56Sopenharmony_ci with self.subTest(typestr=typestr): 35637db96d56Sopenharmony_ci @dataclass 35647db96d56Sopenharmony_ci class C: 35657db96d56Sopenharmony_ci x: typestr 35667db96d56Sopenharmony_ci 35677db96d56Sopenharmony_ci # x is not an InitVar, so there will be a member x. 
35687db96d56Sopenharmony_ci self.assertEqual(C(10).x, 10) 35697db96d56Sopenharmony_ci 35707db96d56Sopenharmony_ci def test_classvar_module_level_import(self): 35717db96d56Sopenharmony_ci from test import dataclass_module_1 35727db96d56Sopenharmony_ci from test import dataclass_module_1_str 35737db96d56Sopenharmony_ci from test import dataclass_module_2 35747db96d56Sopenharmony_ci from test import dataclass_module_2_str 35757db96d56Sopenharmony_ci 35767db96d56Sopenharmony_ci for m in (dataclass_module_1, dataclass_module_1_str, 35777db96d56Sopenharmony_ci dataclass_module_2, dataclass_module_2_str, 35787db96d56Sopenharmony_ci ): 35797db96d56Sopenharmony_ci with self.subTest(m=m): 35807db96d56Sopenharmony_ci # There's a difference in how the ClassVars are 35817db96d56Sopenharmony_ci # interpreted when using string annotations or 35827db96d56Sopenharmony_ci # not. See the imported modules for details. 35837db96d56Sopenharmony_ci if m.USING_STRINGS: 35847db96d56Sopenharmony_ci c = m.CV(10) 35857db96d56Sopenharmony_ci else: 35867db96d56Sopenharmony_ci c = m.CV() 35877db96d56Sopenharmony_ci self.assertEqual(c.cv0, 20) 35887db96d56Sopenharmony_ci 35897db96d56Sopenharmony_ci 35907db96d56Sopenharmony_ci # There's a difference in how the InitVars are 35917db96d56Sopenharmony_ci # interpreted when using string annotations or 35927db96d56Sopenharmony_ci # not. See the imported modules for details. 35937db96d56Sopenharmony_ci c = m.IV(0, 1, 2, 3, 4) 35947db96d56Sopenharmony_ci 35957db96d56Sopenharmony_ci for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 35967db96d56Sopenharmony_ci with self.subTest(field_name=field_name): 35977db96d56Sopenharmony_ci with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 35987db96d56Sopenharmony_ci # Since field_name is an InitVar, it's 35997db96d56Sopenharmony_ci # not an instance field. 
36007db96d56Sopenharmony_ci getattr(c, field_name) 36017db96d56Sopenharmony_ci 36027db96d56Sopenharmony_ci if m.USING_STRINGS: 36037db96d56Sopenharmony_ci # iv4 is interpreted as a normal field. 36047db96d56Sopenharmony_ci self.assertIn('not_iv4', c.__dict__) 36057db96d56Sopenharmony_ci self.assertEqual(c.not_iv4, 4) 36067db96d56Sopenharmony_ci else: 36077db96d56Sopenharmony_ci # iv4 is interpreted as an InitVar, so it 36087db96d56Sopenharmony_ci # won't exist on the instance. 36097db96d56Sopenharmony_ci self.assertNotIn('not_iv4', c.__dict__) 36107db96d56Sopenharmony_ci 36117db96d56Sopenharmony_ci def test_text_annotations(self): 36127db96d56Sopenharmony_ci from test import dataclass_textanno 36137db96d56Sopenharmony_ci 36147db96d56Sopenharmony_ci self.assertEqual( 36157db96d56Sopenharmony_ci get_type_hints(dataclass_textanno.Bar), 36167db96d56Sopenharmony_ci {'foo': dataclass_textanno.Foo}) 36177db96d56Sopenharmony_ci self.assertEqual( 36187db96d56Sopenharmony_ci get_type_hints(dataclass_textanno.Bar.__init__), 36197db96d56Sopenharmony_ci {'foo': dataclass_textanno.Foo, 36207db96d56Sopenharmony_ci 'return': type(None)}) 36217db96d56Sopenharmony_ci 36227db96d56Sopenharmony_ci 36237db96d56Sopenharmony_ciclass TestMakeDataclass(unittest.TestCase): 36247db96d56Sopenharmony_ci def test_simple(self): 36257db96d56Sopenharmony_ci C = make_dataclass('C', 36267db96d56Sopenharmony_ci [('x', int), 36277db96d56Sopenharmony_ci ('y', int, field(default=5))], 36287db96d56Sopenharmony_ci namespace={'add_one': lambda self: self.x + 1}) 36297db96d56Sopenharmony_ci c = C(10) 36307db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (10, 5)) 36317db96d56Sopenharmony_ci self.assertEqual(c.add_one(), 11) 36327db96d56Sopenharmony_ci 36337db96d56Sopenharmony_ci 36347db96d56Sopenharmony_ci def test_no_mutate_namespace(self): 36357db96d56Sopenharmony_ci # Make sure a provided namespace isn't mutated. 
36367db96d56Sopenharmony_ci ns = {} 36377db96d56Sopenharmony_ci C = make_dataclass('C', 36387db96d56Sopenharmony_ci [('x', int), 36397db96d56Sopenharmony_ci ('y', int, field(default=5))], 36407db96d56Sopenharmony_ci namespace=ns) 36417db96d56Sopenharmony_ci self.assertEqual(ns, {}) 36427db96d56Sopenharmony_ci 36437db96d56Sopenharmony_ci def test_base(self): 36447db96d56Sopenharmony_ci class Base1: 36457db96d56Sopenharmony_ci pass 36467db96d56Sopenharmony_ci class Base2: 36477db96d56Sopenharmony_ci pass 36487db96d56Sopenharmony_ci C = make_dataclass('C', 36497db96d56Sopenharmony_ci [('x', int)], 36507db96d56Sopenharmony_ci bases=(Base1, Base2)) 36517db96d56Sopenharmony_ci c = C(2) 36527db96d56Sopenharmony_ci self.assertIsInstance(c, C) 36537db96d56Sopenharmony_ci self.assertIsInstance(c, Base1) 36547db96d56Sopenharmony_ci self.assertIsInstance(c, Base2) 36557db96d56Sopenharmony_ci 36567db96d56Sopenharmony_ci def test_base_dataclass(self): 36577db96d56Sopenharmony_ci @dataclass 36587db96d56Sopenharmony_ci class Base1: 36597db96d56Sopenharmony_ci x: int 36607db96d56Sopenharmony_ci class Base2: 36617db96d56Sopenharmony_ci pass 36627db96d56Sopenharmony_ci C = make_dataclass('C', 36637db96d56Sopenharmony_ci [('y', int)], 36647db96d56Sopenharmony_ci bases=(Base1, Base2)) 36657db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'required positional'): 36667db96d56Sopenharmony_ci c = C(2) 36677db96d56Sopenharmony_ci c = C(1, 2) 36687db96d56Sopenharmony_ci self.assertIsInstance(c, C) 36697db96d56Sopenharmony_ci self.assertIsInstance(c, Base1) 36707db96d56Sopenharmony_ci self.assertIsInstance(c, Base2) 36717db96d56Sopenharmony_ci 36727db96d56Sopenharmony_ci self.assertEqual((c.x, c.y), (1, 2)) 36737db96d56Sopenharmony_ci 36747db96d56Sopenharmony_ci def test_init_var(self): 36757db96d56Sopenharmony_ci def post_init(self, y): 36767db96d56Sopenharmony_ci self.x *= y 36777db96d56Sopenharmony_ci 36787db96d56Sopenharmony_ci C = make_dataclass('C', 
36797db96d56Sopenharmony_ci [('x', int), 36807db96d56Sopenharmony_ci ('y', InitVar[int]), 36817db96d56Sopenharmony_ci ], 36827db96d56Sopenharmony_ci namespace={'__post_init__': post_init}, 36837db96d56Sopenharmony_ci ) 36847db96d56Sopenharmony_ci c = C(2, 3) 36857db96d56Sopenharmony_ci self.assertEqual(vars(c), {'x': 6}) 36867db96d56Sopenharmony_ci self.assertEqual(len(fields(c)), 1) 36877db96d56Sopenharmony_ci 36887db96d56Sopenharmony_ci def test_class_var(self): 36897db96d56Sopenharmony_ci C = make_dataclass('C', 36907db96d56Sopenharmony_ci [('x', int), 36917db96d56Sopenharmony_ci ('y', ClassVar[int], 10), 36927db96d56Sopenharmony_ci ('z', ClassVar[int], field(default=20)), 36937db96d56Sopenharmony_ci ]) 36947db96d56Sopenharmony_ci c = C(1) 36957db96d56Sopenharmony_ci self.assertEqual(vars(c), {'x': 1}) 36967db96d56Sopenharmony_ci self.assertEqual(len(fields(c)), 1) 36977db96d56Sopenharmony_ci self.assertEqual(C.y, 10) 36987db96d56Sopenharmony_ci self.assertEqual(C.z, 20) 36997db96d56Sopenharmony_ci 37007db96d56Sopenharmony_ci def test_other_params(self): 37017db96d56Sopenharmony_ci C = make_dataclass('C', 37027db96d56Sopenharmony_ci [('x', int), 37037db96d56Sopenharmony_ci ('y', ClassVar[int], 10), 37047db96d56Sopenharmony_ci ('z', ClassVar[int], field(default=20)), 37057db96d56Sopenharmony_ci ], 37067db96d56Sopenharmony_ci init=False) 37077db96d56Sopenharmony_ci # Make sure we have a repr, but no init. 37087db96d56Sopenharmony_ci self.assertNotIn('__init__', vars(C)) 37097db96d56Sopenharmony_ci self.assertIn('__repr__', vars(C)) 37107db96d56Sopenharmony_ci 37117db96d56Sopenharmony_ci # Make sure random other params don't work. 
37127db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 37137db96d56Sopenharmony_ci C = make_dataclass('C', 37147db96d56Sopenharmony_ci [], 37157db96d56Sopenharmony_ci xxinit=False) 37167db96d56Sopenharmony_ci 37177db96d56Sopenharmony_ci def test_no_types(self): 37187db96d56Sopenharmony_ci C = make_dataclass('Point', ['x', 'y', 'z']) 37197db96d56Sopenharmony_ci c = C(1, 2, 3) 37207db96d56Sopenharmony_ci self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 37217db96d56Sopenharmony_ci self.assertEqual(C.__annotations__, {'x': 'typing.Any', 37227db96d56Sopenharmony_ci 'y': 'typing.Any', 37237db96d56Sopenharmony_ci 'z': 'typing.Any'}) 37247db96d56Sopenharmony_ci 37257db96d56Sopenharmony_ci C = make_dataclass('Point', ['x', ('y', int), 'z']) 37267db96d56Sopenharmony_ci c = C(1, 2, 3) 37277db96d56Sopenharmony_ci self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 37287db96d56Sopenharmony_ci self.assertEqual(C.__annotations__, {'x': 'typing.Any', 37297db96d56Sopenharmony_ci 'y': int, 37307db96d56Sopenharmony_ci 'z': 'typing.Any'}) 37317db96d56Sopenharmony_ci 37327db96d56Sopenharmony_ci def test_invalid_type_specification(self): 37337db96d56Sopenharmony_ci for bad_field in [(), 37347db96d56Sopenharmony_ci (1, 2, 3, 4), 37357db96d56Sopenharmony_ci ]: 37367db96d56Sopenharmony_ci with self.subTest(bad_field=bad_field): 37377db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r'Invalid field: '): 37387db96d56Sopenharmony_ci make_dataclass('C', ['a', bad_field]) 37397db96d56Sopenharmony_ci 37407db96d56Sopenharmony_ci # And test for things with no len(). 
37417db96d56Sopenharmony_ci for bad_field in [float, 37427db96d56Sopenharmony_ci lambda x:x, 37437db96d56Sopenharmony_ci ]: 37447db96d56Sopenharmony_ci with self.subTest(bad_field=bad_field): 37457db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 37467db96d56Sopenharmony_ci make_dataclass('C', ['a', bad_field]) 37477db96d56Sopenharmony_ci 37487db96d56Sopenharmony_ci def test_duplicate_field_names(self): 37497db96d56Sopenharmony_ci for field in ['a', 'ab']: 37507db96d56Sopenharmony_ci with self.subTest(field=field): 37517db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 37527db96d56Sopenharmony_ci make_dataclass('C', [field, 'a', field]) 37537db96d56Sopenharmony_ci 37547db96d56Sopenharmony_ci def test_keyword_field_names(self): 37557db96d56Sopenharmony_ci for field in ['for', 'async', 'await', 'as']: 37567db96d56Sopenharmony_ci with self.subTest(field=field): 37577db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'must not be keywords'): 37587db96d56Sopenharmony_ci make_dataclass('C', ['a', field]) 37597db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'must not be keywords'): 37607db96d56Sopenharmony_ci make_dataclass('C', [field]) 37617db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'must not be keywords'): 37627db96d56Sopenharmony_ci make_dataclass('C', [field, 'a']) 37637db96d56Sopenharmony_ci 37647db96d56Sopenharmony_ci def test_non_identifier_field_names(self): 37657db96d56Sopenharmony_ci for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 37667db96d56Sopenharmony_ci with self.subTest(field=field): 37677db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 37687db96d56Sopenharmony_ci make_dataclass('C', ['a', field]) 37697db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 37707db96d56Sopenharmony_ci make_dataclass('C', [field]) 37717db96d56Sopenharmony_ci with 
self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 37727db96d56Sopenharmony_ci make_dataclass('C', [field, 'a']) 37737db96d56Sopenharmony_ci 37747db96d56Sopenharmony_ci def test_underscore_field_names(self): 37757db96d56Sopenharmony_ci # Unlike namedtuple, it's okay if dataclass field names have 37767db96d56Sopenharmony_ci # an underscore. 37777db96d56Sopenharmony_ci make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 37787db96d56Sopenharmony_ci 37797db96d56Sopenharmony_ci def test_funny_class_names_names(self): 37807db96d56Sopenharmony_ci # No reason to prevent weird class names, since 37817db96d56Sopenharmony_ci # types.new_class allows them. 37827db96d56Sopenharmony_ci for classname in ['()', 'x,y', '*', '2@3', '']: 37837db96d56Sopenharmony_ci with self.subTest(classname=classname): 37847db96d56Sopenharmony_ci C = make_dataclass(classname, ['a', 'b']) 37857db96d56Sopenharmony_ci self.assertEqual(C.__name__, classname) 37867db96d56Sopenharmony_ci 37877db96d56Sopenharmony_ciclass TestReplace(unittest.TestCase): 37887db96d56Sopenharmony_ci def test(self): 37897db96d56Sopenharmony_ci @dataclass(frozen=True) 37907db96d56Sopenharmony_ci class C: 37917db96d56Sopenharmony_ci x: int 37927db96d56Sopenharmony_ci y: int 37937db96d56Sopenharmony_ci 37947db96d56Sopenharmony_ci c = C(1, 2) 37957db96d56Sopenharmony_ci c1 = replace(c, x=3) 37967db96d56Sopenharmony_ci self.assertEqual(c1.x, 3) 37977db96d56Sopenharmony_ci self.assertEqual(c1.y, 2) 37987db96d56Sopenharmony_ci 37997db96d56Sopenharmony_ci def test_frozen(self): 38007db96d56Sopenharmony_ci @dataclass(frozen=True) 38017db96d56Sopenharmony_ci class C: 38027db96d56Sopenharmony_ci x: int 38037db96d56Sopenharmony_ci y: int 38047db96d56Sopenharmony_ci z: int = field(init=False, default=10) 38057db96d56Sopenharmony_ci t: int = field(init=False, default=100) 38067db96d56Sopenharmony_ci 38077db96d56Sopenharmony_ci c = C(1, 2) 38087db96d56Sopenharmony_ci c1 = replace(c, x=3) 38097db96d56Sopenharmony_ci 
self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 38107db96d56Sopenharmony_ci self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 38117db96d56Sopenharmony_ci 38127db96d56Sopenharmony_ci 38137db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 'init=False'): 38147db96d56Sopenharmony_ci replace(c, x=3, z=20, t=50) 38157db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 'init=False'): 38167db96d56Sopenharmony_ci replace(c, z=20) 38177db96d56Sopenharmony_ci replace(c, x=3, z=20, t=50) 38187db96d56Sopenharmony_ci 38197db96d56Sopenharmony_ci # Make sure the result is still frozen. 38207db96d56Sopenharmony_ci with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 38217db96d56Sopenharmony_ci c1.x = 3 38227db96d56Sopenharmony_ci 38237db96d56Sopenharmony_ci # Make sure we can't replace an attribute that doesn't exist, 38247db96d56Sopenharmony_ci # if we're also replacing one that does exist. Test this 38257db96d56Sopenharmony_ci # here, because setting attributes on frozen instances is 38267db96d56Sopenharmony_ci # handled slightly differently from non-frozen ones. 
38277db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 38287db96d56Sopenharmony_ci "keyword argument 'a'"): 38297db96d56Sopenharmony_ci c1 = replace(c, x=20, a=5) 38307db96d56Sopenharmony_ci 38317db96d56Sopenharmony_ci def test_invalid_field_name(self): 38327db96d56Sopenharmony_ci @dataclass(frozen=True) 38337db96d56Sopenharmony_ci class C: 38347db96d56Sopenharmony_ci x: int 38357db96d56Sopenharmony_ci y: int 38367db96d56Sopenharmony_ci 38377db96d56Sopenharmony_ci c = C(1, 2) 38387db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 38397db96d56Sopenharmony_ci "keyword argument 'z'"): 38407db96d56Sopenharmony_ci c1 = replace(c, z=3) 38417db96d56Sopenharmony_ci 38427db96d56Sopenharmony_ci def test_invalid_object(self): 38437db96d56Sopenharmony_ci @dataclass(frozen=True) 38447db96d56Sopenharmony_ci class C: 38457db96d56Sopenharmony_ci x: int 38467db96d56Sopenharmony_ci y: int 38477db96d56Sopenharmony_ci 38487db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 38497db96d56Sopenharmony_ci replace(C, x=3) 38507db96d56Sopenharmony_ci 38517db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, 'dataclass instance'): 38527db96d56Sopenharmony_ci replace(0, x=3) 38537db96d56Sopenharmony_ci 38547db96d56Sopenharmony_ci def test_no_init(self): 38557db96d56Sopenharmony_ci @dataclass 38567db96d56Sopenharmony_ci class C: 38577db96d56Sopenharmony_ci x: int 38587db96d56Sopenharmony_ci y: int = field(init=False, default=10) 38597db96d56Sopenharmony_ci 38607db96d56Sopenharmony_ci c = C(1) 38617db96d56Sopenharmony_ci c.y = 20 38627db96d56Sopenharmony_ci 38637db96d56Sopenharmony_ci # Make sure y gets the default value. 38647db96d56Sopenharmony_ci c1 = replace(c, x=5) 38657db96d56Sopenharmony_ci self.assertEqual((c1.x, c1.y), (5, 10)) 38667db96d56Sopenharmony_ci 38677db96d56Sopenharmony_ci # Trying to replace y is an error. 
38687db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 'init=False'): 38697db96d56Sopenharmony_ci replace(c, x=2, y=30) 38707db96d56Sopenharmony_ci 38717db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, 'init=False'): 38727db96d56Sopenharmony_ci replace(c, y=30) 38737db96d56Sopenharmony_ci 38747db96d56Sopenharmony_ci def test_classvar(self): 38757db96d56Sopenharmony_ci @dataclass 38767db96d56Sopenharmony_ci class C: 38777db96d56Sopenharmony_ci x: int 38787db96d56Sopenharmony_ci y: ClassVar[int] = 1000 38797db96d56Sopenharmony_ci 38807db96d56Sopenharmony_ci c = C(1) 38817db96d56Sopenharmony_ci d = C(2) 38827db96d56Sopenharmony_ci 38837db96d56Sopenharmony_ci self.assertIs(c.y, d.y) 38847db96d56Sopenharmony_ci self.assertEqual(c.y, 1000) 38857db96d56Sopenharmony_ci 38867db96d56Sopenharmony_ci # Trying to replace y is an error: can't replace ClassVars. 38877db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 38887db96d56Sopenharmony_ci "unexpected keyword argument 'y'"): 38897db96d56Sopenharmony_ci replace(c, y=30) 38907db96d56Sopenharmony_ci 38917db96d56Sopenharmony_ci replace(c, x=5) 38927db96d56Sopenharmony_ci 38937db96d56Sopenharmony_ci def test_initvar_is_specified(self): 38947db96d56Sopenharmony_ci @dataclass 38957db96d56Sopenharmony_ci class C: 38967db96d56Sopenharmony_ci x: int 38977db96d56Sopenharmony_ci y: InitVar[int] 38987db96d56Sopenharmony_ci 38997db96d56Sopenharmony_ci def __post_init__(self, y): 39007db96d56Sopenharmony_ci self.x *= y 39017db96d56Sopenharmony_ci 39027db96d56Sopenharmony_ci c = C(1, 10) 39037db96d56Sopenharmony_ci self.assertEqual(c.x, 10) 39047db96d56Sopenharmony_ci with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " 39057db96d56Sopenharmony_ci "specified with replace()"): 39067db96d56Sopenharmony_ci replace(c, x=3) 39077db96d56Sopenharmony_ci c = replace(c, x=3, y=5) 39087db96d56Sopenharmony_ci self.assertEqual(c.x, 15) 39097db96d56Sopenharmony_ci 
39107db96d56Sopenharmony_ci def test_initvar_with_default_value(self): 39117db96d56Sopenharmony_ci @dataclass 39127db96d56Sopenharmony_ci class C: 39137db96d56Sopenharmony_ci x: int 39147db96d56Sopenharmony_ci y: InitVar[int] = None 39157db96d56Sopenharmony_ci z: InitVar[int] = 42 39167db96d56Sopenharmony_ci 39177db96d56Sopenharmony_ci def __post_init__(self, y, z): 39187db96d56Sopenharmony_ci if y is not None: 39197db96d56Sopenharmony_ci self.x += y 39207db96d56Sopenharmony_ci if z is not None: 39217db96d56Sopenharmony_ci self.x += z 39227db96d56Sopenharmony_ci 39237db96d56Sopenharmony_ci c = C(x=1, y=10, z=1) 39247db96d56Sopenharmony_ci self.assertEqual(replace(c), C(x=12)) 39257db96d56Sopenharmony_ci self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42)) 39267db96d56Sopenharmony_ci self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1)) 39277db96d56Sopenharmony_ci 39287db96d56Sopenharmony_ci def test_recursive_repr(self): 39297db96d56Sopenharmony_ci @dataclass 39307db96d56Sopenharmony_ci class C: 39317db96d56Sopenharmony_ci f: "C" 39327db96d56Sopenharmony_ci 39337db96d56Sopenharmony_ci c = C(None) 39347db96d56Sopenharmony_ci c.f = c 39357db96d56Sopenharmony_ci self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 39367db96d56Sopenharmony_ci 39377db96d56Sopenharmony_ci def test_recursive_repr_two_attrs(self): 39387db96d56Sopenharmony_ci @dataclass 39397db96d56Sopenharmony_ci class C: 39407db96d56Sopenharmony_ci f: "C" 39417db96d56Sopenharmony_ci g: "C" 39427db96d56Sopenharmony_ci 39437db96d56Sopenharmony_ci c = C(None, None) 39447db96d56Sopenharmony_ci c.f = c 39457db96d56Sopenharmony_ci c.g = c 39467db96d56Sopenharmony_ci self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 39477db96d56Sopenharmony_ci ".<locals>.C(f=..., g=...)") 39487db96d56Sopenharmony_ci 39497db96d56Sopenharmony_ci def test_recursive_repr_indirection(self): 39507db96d56Sopenharmony_ci @dataclass 39517db96d56Sopenharmony_ci class C: 
39527db96d56Sopenharmony_ci f: "D" 39537db96d56Sopenharmony_ci 39547db96d56Sopenharmony_ci @dataclass 39557db96d56Sopenharmony_ci class D: 39567db96d56Sopenharmony_ci f: "C" 39577db96d56Sopenharmony_ci 39587db96d56Sopenharmony_ci c = C(None) 39597db96d56Sopenharmony_ci d = D(None) 39607db96d56Sopenharmony_ci c.f = d 39617db96d56Sopenharmony_ci d.f = c 39627db96d56Sopenharmony_ci self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 39637db96d56Sopenharmony_ci ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 39647db96d56Sopenharmony_ci ".<locals>.D(f=...))") 39657db96d56Sopenharmony_ci 39667db96d56Sopenharmony_ci def test_recursive_repr_indirection_two(self): 39677db96d56Sopenharmony_ci @dataclass 39687db96d56Sopenharmony_ci class C: 39697db96d56Sopenharmony_ci f: "D" 39707db96d56Sopenharmony_ci 39717db96d56Sopenharmony_ci @dataclass 39727db96d56Sopenharmony_ci class D: 39737db96d56Sopenharmony_ci f: "E" 39747db96d56Sopenharmony_ci 39757db96d56Sopenharmony_ci @dataclass 39767db96d56Sopenharmony_ci class E: 39777db96d56Sopenharmony_ci f: "C" 39787db96d56Sopenharmony_ci 39797db96d56Sopenharmony_ci c = C(None) 39807db96d56Sopenharmony_ci d = D(None) 39817db96d56Sopenharmony_ci e = E(None) 39827db96d56Sopenharmony_ci c.f = d 39837db96d56Sopenharmony_ci d.f = e 39847db96d56Sopenharmony_ci e.f = c 39857db96d56Sopenharmony_ci self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 39867db96d56Sopenharmony_ci ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 39877db96d56Sopenharmony_ci ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 39887db96d56Sopenharmony_ci ".<locals>.E(f=...)))") 39897db96d56Sopenharmony_ci 39907db96d56Sopenharmony_ci def test_recursive_repr_misc_attrs(self): 39917db96d56Sopenharmony_ci @dataclass 39927db96d56Sopenharmony_ci class C: 39937db96d56Sopenharmony_ci f: "C" 39947db96d56Sopenharmony_ci g: int 39957db96d56Sopenharmony_ci 39967db96d56Sopenharmony_ci c = C(None, 1) 
39977db96d56Sopenharmony_ci c.f = c 39987db96d56Sopenharmony_ci self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 39997db96d56Sopenharmony_ci ".<locals>.C(f=..., g=1)") 40007db96d56Sopenharmony_ci 40017db96d56Sopenharmony_ci ## def test_initvar(self): 40027db96d56Sopenharmony_ci ## @dataclass 40037db96d56Sopenharmony_ci ## class C: 40047db96d56Sopenharmony_ci ## x: int 40057db96d56Sopenharmony_ci ## y: InitVar[int] 40067db96d56Sopenharmony_ci 40077db96d56Sopenharmony_ci ## c = C(1, 10) 40087db96d56Sopenharmony_ci ## d = C(2, 20) 40097db96d56Sopenharmony_ci 40107db96d56Sopenharmony_ci ## # In our case, replacing an InitVar is a no-op 40117db96d56Sopenharmony_ci ## self.assertEqual(c, replace(c, y=5)) 40127db96d56Sopenharmony_ci 40137db96d56Sopenharmony_ci ## replace(c, x=5) 40147db96d56Sopenharmony_ci 40157db96d56Sopenharmony_ciclass TestAbstract(unittest.TestCase): 40167db96d56Sopenharmony_ci def test_abc_implementation(self): 40177db96d56Sopenharmony_ci class Ordered(abc.ABC): 40187db96d56Sopenharmony_ci @abc.abstractmethod 40197db96d56Sopenharmony_ci def __lt__(self, other): 40207db96d56Sopenharmony_ci pass 40217db96d56Sopenharmony_ci 40227db96d56Sopenharmony_ci @abc.abstractmethod 40237db96d56Sopenharmony_ci def __le__(self, other): 40247db96d56Sopenharmony_ci pass 40257db96d56Sopenharmony_ci 40267db96d56Sopenharmony_ci @dataclass(order=True) 40277db96d56Sopenharmony_ci class Date(Ordered): 40287db96d56Sopenharmony_ci year: int 40297db96d56Sopenharmony_ci month: 'Month' 40307db96d56Sopenharmony_ci day: 'int' 40317db96d56Sopenharmony_ci 40327db96d56Sopenharmony_ci self.assertFalse(inspect.isabstract(Date)) 40337db96d56Sopenharmony_ci self.assertGreater(Date(2020,12,25), Date(2020,8,31)) 40347db96d56Sopenharmony_ci 40357db96d56Sopenharmony_ci def test_maintain_abc(self): 40367db96d56Sopenharmony_ci class A(abc.ABC): 40377db96d56Sopenharmony_ci @abc.abstractmethod 40387db96d56Sopenharmony_ci def foo(self): 40397db96d56Sopenharmony_ci pass 
40407db96d56Sopenharmony_ci 40417db96d56Sopenharmony_ci @dataclass 40427db96d56Sopenharmony_ci class Date(A): 40437db96d56Sopenharmony_ci year: int 40447db96d56Sopenharmony_ci month: 'Month' 40457db96d56Sopenharmony_ci day: 'int' 40467db96d56Sopenharmony_ci 40477db96d56Sopenharmony_ci self.assertTrue(inspect.isabstract(Date)) 40487db96d56Sopenharmony_ci msg = 'class Date with abstract method foo' 40497db96d56Sopenharmony_ci self.assertRaisesRegex(TypeError, msg, Date) 40507db96d56Sopenharmony_ci 40517db96d56Sopenharmony_ci 40527db96d56Sopenharmony_ciclass TestMatchArgs(unittest.TestCase): 40537db96d56Sopenharmony_ci def test_match_args(self): 40547db96d56Sopenharmony_ci @dataclass 40557db96d56Sopenharmony_ci class C: 40567db96d56Sopenharmony_ci a: int 40577db96d56Sopenharmony_ci self.assertEqual(C(42).__match_args__, ('a',)) 40587db96d56Sopenharmony_ci 40597db96d56Sopenharmony_ci def test_explicit_match_args(self): 40607db96d56Sopenharmony_ci ma = () 40617db96d56Sopenharmony_ci @dataclass 40627db96d56Sopenharmony_ci class C: 40637db96d56Sopenharmony_ci a: int 40647db96d56Sopenharmony_ci __match_args__ = ma 40657db96d56Sopenharmony_ci self.assertIs(C(42).__match_args__, ma) 40667db96d56Sopenharmony_ci 40677db96d56Sopenharmony_ci def test_bpo_43764(self): 40687db96d56Sopenharmony_ci @dataclass(repr=False, eq=False, init=False) 40697db96d56Sopenharmony_ci class X: 40707db96d56Sopenharmony_ci a: int 40717db96d56Sopenharmony_ci b: int 40727db96d56Sopenharmony_ci c: int 40737db96d56Sopenharmony_ci self.assertEqual(X.__match_args__, ("a", "b", "c")) 40747db96d56Sopenharmony_ci 40757db96d56Sopenharmony_ci def test_match_args_argument(self): 40767db96d56Sopenharmony_ci @dataclass(match_args=False) 40777db96d56Sopenharmony_ci class X: 40787db96d56Sopenharmony_ci a: int 40797db96d56Sopenharmony_ci self.assertNotIn('__match_args__', X.__dict__) 40807db96d56Sopenharmony_ci 40817db96d56Sopenharmony_ci @dataclass(match_args=False) 40827db96d56Sopenharmony_ci class Y: 
40837db96d56Sopenharmony_ci a: int 40847db96d56Sopenharmony_ci __match_args__ = ('b',) 40857db96d56Sopenharmony_ci self.assertEqual(Y.__match_args__, ('b',)) 40867db96d56Sopenharmony_ci 40877db96d56Sopenharmony_ci @dataclass(match_args=False) 40887db96d56Sopenharmony_ci class Z(Y): 40897db96d56Sopenharmony_ci z: int 40907db96d56Sopenharmony_ci self.assertEqual(Z.__match_args__, ('b',)) 40917db96d56Sopenharmony_ci 40927db96d56Sopenharmony_ci # Ensure parent dataclass __match_args__ is seen, if child class 40937db96d56Sopenharmony_ci # specifies match_args=False. 40947db96d56Sopenharmony_ci @dataclass 40957db96d56Sopenharmony_ci class A: 40967db96d56Sopenharmony_ci a: int 40977db96d56Sopenharmony_ci z: int 40987db96d56Sopenharmony_ci @dataclass(match_args=False) 40997db96d56Sopenharmony_ci class B(A): 41007db96d56Sopenharmony_ci b: int 41017db96d56Sopenharmony_ci self.assertEqual(B.__match_args__, ('a', 'z')) 41027db96d56Sopenharmony_ci 41037db96d56Sopenharmony_ci def test_make_dataclasses(self): 41047db96d56Sopenharmony_ci C = make_dataclass('C', [('x', int), ('y', int)]) 41057db96d56Sopenharmony_ci self.assertEqual(C.__match_args__, ('x', 'y')) 41067db96d56Sopenharmony_ci 41077db96d56Sopenharmony_ci C = make_dataclass('C', [('x', int), ('y', int)], match_args=True) 41087db96d56Sopenharmony_ci self.assertEqual(C.__match_args__, ('x', 'y')) 41097db96d56Sopenharmony_ci 41107db96d56Sopenharmony_ci C = make_dataclass('C', [('x', int), ('y', int)], match_args=False) 41117db96d56Sopenharmony_ci self.assertNotIn('__match__args__', C.__dict__) 41127db96d56Sopenharmony_ci 41137db96d56Sopenharmony_ci C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)}) 41147db96d56Sopenharmony_ci self.assertEqual(C.__match_args__, ('z',)) 41157db96d56Sopenharmony_ci 41167db96d56Sopenharmony_ci 41177db96d56Sopenharmony_ciclass TestKeywordArgs(unittest.TestCase): 41187db96d56Sopenharmony_ci def test_no_classvar_kwarg(self): 41197db96d56Sopenharmony_ci msg = 
'field a is a ClassVar but specifies kw_only' 41207db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 41217db96d56Sopenharmony_ci @dataclass 41227db96d56Sopenharmony_ci class A: 41237db96d56Sopenharmony_ci a: ClassVar[int] = field(kw_only=True) 41247db96d56Sopenharmony_ci 41257db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 41267db96d56Sopenharmony_ci @dataclass 41277db96d56Sopenharmony_ci class A: 41287db96d56Sopenharmony_ci a: ClassVar[int] = field(kw_only=False) 41297db96d56Sopenharmony_ci 41307db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 41317db96d56Sopenharmony_ci @dataclass(kw_only=True) 41327db96d56Sopenharmony_ci class A: 41337db96d56Sopenharmony_ci a: ClassVar[int] = field(kw_only=False) 41347db96d56Sopenharmony_ci 41357db96d56Sopenharmony_ci def test_field_marked_as_kwonly(self): 41367db96d56Sopenharmony_ci ####################### 41377db96d56Sopenharmony_ci # Using dataclass(kw_only=True) 41387db96d56Sopenharmony_ci @dataclass(kw_only=True) 41397db96d56Sopenharmony_ci class A: 41407db96d56Sopenharmony_ci a: int 41417db96d56Sopenharmony_ci self.assertTrue(fields(A)[0].kw_only) 41427db96d56Sopenharmony_ci 41437db96d56Sopenharmony_ci @dataclass(kw_only=True) 41447db96d56Sopenharmony_ci class A: 41457db96d56Sopenharmony_ci a: int = field(kw_only=True) 41467db96d56Sopenharmony_ci self.assertTrue(fields(A)[0].kw_only) 41477db96d56Sopenharmony_ci 41487db96d56Sopenharmony_ci @dataclass(kw_only=True) 41497db96d56Sopenharmony_ci class A: 41507db96d56Sopenharmony_ci a: int = field(kw_only=False) 41517db96d56Sopenharmony_ci self.assertFalse(fields(A)[0].kw_only) 41527db96d56Sopenharmony_ci 41537db96d56Sopenharmony_ci ####################### 41547db96d56Sopenharmony_ci # Using dataclass(kw_only=False) 41557db96d56Sopenharmony_ci @dataclass(kw_only=False) 41567db96d56Sopenharmony_ci class A: 41577db96d56Sopenharmony_ci a: int 41587db96d56Sopenharmony_ci self.assertFalse(fields(A)[0].kw_only) 
41597db96d56Sopenharmony_ci 41607db96d56Sopenharmony_ci @dataclass(kw_only=False) 41617db96d56Sopenharmony_ci class A: 41627db96d56Sopenharmony_ci a: int = field(kw_only=True) 41637db96d56Sopenharmony_ci self.assertTrue(fields(A)[0].kw_only) 41647db96d56Sopenharmony_ci 41657db96d56Sopenharmony_ci @dataclass(kw_only=False) 41667db96d56Sopenharmony_ci class A: 41677db96d56Sopenharmony_ci a: int = field(kw_only=False) 41687db96d56Sopenharmony_ci self.assertFalse(fields(A)[0].kw_only) 41697db96d56Sopenharmony_ci 41707db96d56Sopenharmony_ci ####################### 41717db96d56Sopenharmony_ci # Not specifying dataclass(kw_only) 41727db96d56Sopenharmony_ci @dataclass 41737db96d56Sopenharmony_ci class A: 41747db96d56Sopenharmony_ci a: int 41757db96d56Sopenharmony_ci self.assertFalse(fields(A)[0].kw_only) 41767db96d56Sopenharmony_ci 41777db96d56Sopenharmony_ci @dataclass 41787db96d56Sopenharmony_ci class A: 41797db96d56Sopenharmony_ci a: int = field(kw_only=True) 41807db96d56Sopenharmony_ci self.assertTrue(fields(A)[0].kw_only) 41817db96d56Sopenharmony_ci 41827db96d56Sopenharmony_ci @dataclass 41837db96d56Sopenharmony_ci class A: 41847db96d56Sopenharmony_ci a: int = field(kw_only=False) 41857db96d56Sopenharmony_ci self.assertFalse(fields(A)[0].kw_only) 41867db96d56Sopenharmony_ci 41877db96d56Sopenharmony_ci def test_match_args(self): 41887db96d56Sopenharmony_ci # kw fields don't show up in __match_args__. 
41897db96d56Sopenharmony_ci @dataclass(kw_only=True) 41907db96d56Sopenharmony_ci class C: 41917db96d56Sopenharmony_ci a: int 41927db96d56Sopenharmony_ci self.assertEqual(C(a=42).__match_args__, ()) 41937db96d56Sopenharmony_ci 41947db96d56Sopenharmony_ci @dataclass 41957db96d56Sopenharmony_ci class C: 41967db96d56Sopenharmony_ci a: int 41977db96d56Sopenharmony_ci b: int = field(kw_only=True) 41987db96d56Sopenharmony_ci self.assertEqual(C(42, b=10).__match_args__, ('a',)) 41997db96d56Sopenharmony_ci 42007db96d56Sopenharmony_ci def test_KW_ONLY(self): 42017db96d56Sopenharmony_ci @dataclass 42027db96d56Sopenharmony_ci class A: 42037db96d56Sopenharmony_ci a: int 42047db96d56Sopenharmony_ci _: KW_ONLY 42057db96d56Sopenharmony_ci b: int 42067db96d56Sopenharmony_ci c: int 42077db96d56Sopenharmony_ci A(3, c=5, b=4) 42087db96d56Sopenharmony_ci msg = "takes 2 positional arguments but 4 were given" 42097db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42107db96d56Sopenharmony_ci A(3, 4, 5) 42117db96d56Sopenharmony_ci 42127db96d56Sopenharmony_ci 42137db96d56Sopenharmony_ci @dataclass(kw_only=True) 42147db96d56Sopenharmony_ci class B: 42157db96d56Sopenharmony_ci a: int 42167db96d56Sopenharmony_ci _: KW_ONLY 42177db96d56Sopenharmony_ci b: int 42187db96d56Sopenharmony_ci c: int 42197db96d56Sopenharmony_ci B(a=3, b=4, c=5) 42207db96d56Sopenharmony_ci msg = "takes 1 positional argument but 4 were given" 42217db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42227db96d56Sopenharmony_ci B(3, 4, 5) 42237db96d56Sopenharmony_ci 42247db96d56Sopenharmony_ci # Explicitly make a field that follows KW_ONLY be non-keyword-only. 
42257db96d56Sopenharmony_ci @dataclass 42267db96d56Sopenharmony_ci class C: 42277db96d56Sopenharmony_ci a: int 42287db96d56Sopenharmony_ci _: KW_ONLY 42297db96d56Sopenharmony_ci b: int 42307db96d56Sopenharmony_ci c: int = field(kw_only=False) 42317db96d56Sopenharmony_ci c = C(1, 2, b=3) 42327db96d56Sopenharmony_ci self.assertEqual(c.a, 1) 42337db96d56Sopenharmony_ci self.assertEqual(c.b, 3) 42347db96d56Sopenharmony_ci self.assertEqual(c.c, 2) 42357db96d56Sopenharmony_ci c = C(1, b=3, c=2) 42367db96d56Sopenharmony_ci self.assertEqual(c.a, 1) 42377db96d56Sopenharmony_ci self.assertEqual(c.b, 3) 42387db96d56Sopenharmony_ci self.assertEqual(c.c, 2) 42397db96d56Sopenharmony_ci c = C(1, b=3, c=2) 42407db96d56Sopenharmony_ci self.assertEqual(c.a, 1) 42417db96d56Sopenharmony_ci self.assertEqual(c.b, 3) 42427db96d56Sopenharmony_ci self.assertEqual(c.c, 2) 42437db96d56Sopenharmony_ci c = C(c=2, b=3, a=1) 42447db96d56Sopenharmony_ci self.assertEqual(c.a, 1) 42457db96d56Sopenharmony_ci self.assertEqual(c.b, 3) 42467db96d56Sopenharmony_ci self.assertEqual(c.c, 2) 42477db96d56Sopenharmony_ci 42487db96d56Sopenharmony_ci def test_KW_ONLY_as_string(self): 42497db96d56Sopenharmony_ci @dataclass 42507db96d56Sopenharmony_ci class A: 42517db96d56Sopenharmony_ci a: int 42527db96d56Sopenharmony_ci _: 'dataclasses.KW_ONLY' 42537db96d56Sopenharmony_ci b: int 42547db96d56Sopenharmony_ci c: int 42557db96d56Sopenharmony_ci A(3, c=5, b=4) 42567db96d56Sopenharmony_ci msg = "takes 2 positional arguments but 4 were given" 42577db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42587db96d56Sopenharmony_ci A(3, 4, 5) 42597db96d56Sopenharmony_ci 42607db96d56Sopenharmony_ci def test_KW_ONLY_twice(self): 42617db96d56Sopenharmony_ci msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified" 42627db96d56Sopenharmony_ci 42637db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42647db96d56Sopenharmony_ci @dataclass 42657db96d56Sopenharmony_ci class A: 
42667db96d56Sopenharmony_ci a: int 42677db96d56Sopenharmony_ci X: KW_ONLY 42687db96d56Sopenharmony_ci Y: KW_ONLY 42697db96d56Sopenharmony_ci b: int 42707db96d56Sopenharmony_ci c: int 42717db96d56Sopenharmony_ci 42727db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42737db96d56Sopenharmony_ci @dataclass 42747db96d56Sopenharmony_ci class A: 42757db96d56Sopenharmony_ci a: int 42767db96d56Sopenharmony_ci X: KW_ONLY 42777db96d56Sopenharmony_ci b: int 42787db96d56Sopenharmony_ci Y: KW_ONLY 42797db96d56Sopenharmony_ci c: int 42807db96d56Sopenharmony_ci 42817db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 42827db96d56Sopenharmony_ci @dataclass 42837db96d56Sopenharmony_ci class A: 42847db96d56Sopenharmony_ci a: int 42857db96d56Sopenharmony_ci X: KW_ONLY 42867db96d56Sopenharmony_ci b: int 42877db96d56Sopenharmony_ci c: int 42887db96d56Sopenharmony_ci Y: KW_ONLY 42897db96d56Sopenharmony_ci 42907db96d56Sopenharmony_ci # But this usage is okay, since it's not using KW_ONLY. 42917db96d56Sopenharmony_ci @dataclass 42927db96d56Sopenharmony_ci class A: 42937db96d56Sopenharmony_ci a: int 42947db96d56Sopenharmony_ci _: KW_ONLY 42957db96d56Sopenharmony_ci b: int 42967db96d56Sopenharmony_ci c: int = field(kw_only=True) 42977db96d56Sopenharmony_ci 42987db96d56Sopenharmony_ci # And if inheriting, it's okay. 42997db96d56Sopenharmony_ci @dataclass 43007db96d56Sopenharmony_ci class A: 43017db96d56Sopenharmony_ci a: int 43027db96d56Sopenharmony_ci _: KW_ONLY 43037db96d56Sopenharmony_ci b: int 43047db96d56Sopenharmony_ci c: int 43057db96d56Sopenharmony_ci @dataclass 43067db96d56Sopenharmony_ci class B(A): 43077db96d56Sopenharmony_ci _: KW_ONLY 43087db96d56Sopenharmony_ci d: int 43097db96d56Sopenharmony_ci 43107db96d56Sopenharmony_ci # Make sure the error is raised in a derived class. 
43117db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, msg): 43127db96d56Sopenharmony_ci @dataclass 43137db96d56Sopenharmony_ci class A: 43147db96d56Sopenharmony_ci a: int 43157db96d56Sopenharmony_ci _: KW_ONLY 43167db96d56Sopenharmony_ci b: int 43177db96d56Sopenharmony_ci c: int 43187db96d56Sopenharmony_ci @dataclass 43197db96d56Sopenharmony_ci class B(A): 43207db96d56Sopenharmony_ci X: KW_ONLY 43217db96d56Sopenharmony_ci d: int 43227db96d56Sopenharmony_ci Y: KW_ONLY 43237db96d56Sopenharmony_ci 43247db96d56Sopenharmony_ci 43257db96d56Sopenharmony_ci def test_post_init(self): 43267db96d56Sopenharmony_ci @dataclass 43277db96d56Sopenharmony_ci class A: 43287db96d56Sopenharmony_ci a: int 43297db96d56Sopenharmony_ci _: KW_ONLY 43307db96d56Sopenharmony_ci b: InitVar[int] 43317db96d56Sopenharmony_ci c: int 43327db96d56Sopenharmony_ci d: InitVar[int] 43337db96d56Sopenharmony_ci def __post_init__(self, b, d): 43347db96d56Sopenharmony_ci raise CustomError(f'{b=} {d=}') 43357db96d56Sopenharmony_ci with self.assertRaisesRegex(CustomError, 'b=3 d=4'): 43367db96d56Sopenharmony_ci A(1, c=2, b=3, d=4) 43377db96d56Sopenharmony_ci 43387db96d56Sopenharmony_ci @dataclass 43397db96d56Sopenharmony_ci class B: 43407db96d56Sopenharmony_ci a: int 43417db96d56Sopenharmony_ci _: KW_ONLY 43427db96d56Sopenharmony_ci b: InitVar[int] 43437db96d56Sopenharmony_ci c: int 43447db96d56Sopenharmony_ci d: InitVar[int] 43457db96d56Sopenharmony_ci def __post_init__(self, b, d): 43467db96d56Sopenharmony_ci self.a = b 43477db96d56Sopenharmony_ci self.c = d 43487db96d56Sopenharmony_ci b = B(1, c=2, b=3, d=4) 43497db96d56Sopenharmony_ci self.assertEqual(asdict(b), {'a': 3, 'c': 4}) 43507db96d56Sopenharmony_ci 43517db96d56Sopenharmony_ci def test_defaults(self): 43527db96d56Sopenharmony_ci # For kwargs, make sure we can have defaults after non-defaults. 
43537db96d56Sopenharmony_ci @dataclass 43547db96d56Sopenharmony_ci class A: 43557db96d56Sopenharmony_ci a: int = 0 43567db96d56Sopenharmony_ci _: KW_ONLY 43577db96d56Sopenharmony_ci b: int 43587db96d56Sopenharmony_ci c: int = 1 43597db96d56Sopenharmony_ci d: int 43607db96d56Sopenharmony_ci 43617db96d56Sopenharmony_ci a = A(d=4, b=3) 43627db96d56Sopenharmony_ci self.assertEqual(a.a, 0) 43637db96d56Sopenharmony_ci self.assertEqual(a.b, 3) 43647db96d56Sopenharmony_ci self.assertEqual(a.c, 1) 43657db96d56Sopenharmony_ci self.assertEqual(a.d, 4) 43667db96d56Sopenharmony_ci 43677db96d56Sopenharmony_ci # Make sure we still check for non-kwarg non-defaults not following 43687db96d56Sopenharmony_ci # defaults. 43697db96d56Sopenharmony_ci err_regex = "non-default argument 'z' follows default argument" 43707db96d56Sopenharmony_ci with self.assertRaisesRegex(TypeError, err_regex): 43717db96d56Sopenharmony_ci @dataclass 43727db96d56Sopenharmony_ci class A: 43737db96d56Sopenharmony_ci a: int = 0 43747db96d56Sopenharmony_ci z: int 43757db96d56Sopenharmony_ci _: KW_ONLY 43767db96d56Sopenharmony_ci b: int 43777db96d56Sopenharmony_ci c: int = 1 43787db96d56Sopenharmony_ci d: int 43797db96d56Sopenharmony_ci 43807db96d56Sopenharmony_ci def test_make_dataclass(self): 43817db96d56Sopenharmony_ci A = make_dataclass("A", ['a'], kw_only=True) 43827db96d56Sopenharmony_ci self.assertTrue(fields(A)[0].kw_only) 43837db96d56Sopenharmony_ci 43847db96d56Sopenharmony_ci B = make_dataclass("B", 43857db96d56Sopenharmony_ci ['a', ('b', int, field(kw_only=False))], 43867db96d56Sopenharmony_ci kw_only=True) 43877db96d56Sopenharmony_ci self.assertTrue(fields(B)[0].kw_only) 43887db96d56Sopenharmony_ci self.assertFalse(fields(B)[1].kw_only) 43897db96d56Sopenharmony_ci 43907db96d56Sopenharmony_ci 43917db96d56Sopenharmony_ciif __name__ == '__main__': 43927db96d56Sopenharmony_ci unittest.main() 4393