1// types.UnionType -- used to represent e.g. Union[int, str], int | str 2#include "Python.h" 3#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK 4#include "pycore_unionobject.h" 5#include "structmember.h" 6 7 8static PyObject *make_union(PyObject *); 9 10 11typedef struct { 12 PyObject_HEAD 13 PyObject *args; 14 PyObject *parameters; 15} unionobject; 16 17static void 18unionobject_dealloc(PyObject *self) 19{ 20 unionobject *alias = (unionobject *)self; 21 22 _PyObject_GC_UNTRACK(self); 23 24 Py_XDECREF(alias->args); 25 Py_XDECREF(alias->parameters); 26 Py_TYPE(self)->tp_free(self); 27} 28 29static int 30union_traverse(PyObject *self, visitproc visit, void *arg) 31{ 32 unionobject *alias = (unionobject *)self; 33 Py_VISIT(alias->args); 34 Py_VISIT(alias->parameters); 35 return 0; 36} 37 38static Py_hash_t 39union_hash(PyObject *self) 40{ 41 unionobject *alias = (unionobject *)self; 42 PyObject *args = PyFrozenSet_New(alias->args); 43 if (args == NULL) { 44 return (Py_hash_t)-1; 45 } 46 Py_hash_t hash = PyObject_Hash(args); 47 Py_DECREF(args); 48 return hash; 49} 50 51static PyObject * 52union_richcompare(PyObject *a, PyObject *b, int op) 53{ 54 if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) { 55 Py_RETURN_NOTIMPLEMENTED; 56 } 57 58 PyObject *a_set = PySet_New(((unionobject*)a)->args); 59 if (a_set == NULL) { 60 return NULL; 61 } 62 PyObject *b_set = PySet_New(((unionobject*)b)->args); 63 if (b_set == NULL) { 64 Py_DECREF(a_set); 65 return NULL; 66 } 67 PyObject *result = PyObject_RichCompare(a_set, b_set, op); 68 Py_DECREF(b_set); 69 Py_DECREF(a_set); 70 return result; 71} 72 73static int 74is_same(PyObject *left, PyObject *right) 75{ 76 int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right); 77 return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right; 78} 79 80static int 81contains(PyObject **items, Py_ssize_t size, PyObject *obj) 82{ 83 for (int i = 0; i < size; i++) { 84 int is_duplicate = is_same(items[i], obj); 85 if (is_duplicate) { // -1 or 1 86 return is_duplicate; 87 } 88 } 89 return 0; 90} 91 92static PyObject * 93merge(PyObject **items1, Py_ssize_t size1, 94 PyObject **items2, Py_ssize_t size2) 95{ 96 PyObject *tuple = NULL; 97 Py_ssize_t pos = 0; 98 99 for (int i = 0; i < size2; i++) { 100 PyObject *arg = items2[i]; 101 int is_duplicate = contains(items1, size1, arg); 102 if (is_duplicate < 0) { 103 Py_XDECREF(tuple); 104 return NULL; 105 } 106 if (is_duplicate) { 107 continue; 108 } 109 110 if (tuple == NULL) { 111 tuple = PyTuple_New(size1 + size2 - i); 112 if (tuple == NULL) { 113 return NULL; 114 } 115 for (; pos < size1; pos++) { 116 PyObject *a = items1[pos]; 117 Py_INCREF(a); 118 PyTuple_SET_ITEM(tuple, pos, a); 119 } 120 } 121 Py_INCREF(arg); 122 PyTuple_SET_ITEM(tuple, pos, arg); 123 pos++; 124 } 125 126 if (tuple) { 127 (void) _PyTuple_Resize(&tuple, pos); 128 } 129 return tuple; 130} 131 132static PyObject ** 133get_types(PyObject **obj, Py_ssize_t *size) 134{ 135 if (*obj == Py_None) { 136 *obj = (PyObject *)&_PyNone_Type; 137 } 138 if (_PyUnion_Check(*obj)) { 139 PyObject *args = ((unionobject *) *obj)->args; 140 *size = PyTuple_GET_SIZE(args); 141 return &PyTuple_GET_ITEM(args, 0); 142 } 143 else { 144 *size = 1; 145 return obj; 146 } 147} 148 149static int 150is_unionable(PyObject *obj) 151{ 152 return (obj == Py_None || 153 PyType_Check(obj) || 154 _PyGenericAlias_Check(obj) || 155 _PyUnion_Check(obj)); 156} 157 158PyObject * 159_Py_union_type_or(PyObject* self, PyObject* other) 160{ 161 if (!is_unionable(self) || !is_unionable(other)) { 162 Py_RETURN_NOTIMPLEMENTED; 163 } 164 165 Py_ssize_t size1, size2; 166 PyObject **items1 = get_types(&self, &size1); 167 PyObject **items2 = get_types(&other, &size2); 168 PyObject *tuple = merge(items1, size1, items2, size2); 169 if (tuple == NULL) { 170 if (PyErr_Occurred()) { 171 return NULL; 172 } 173 Py_INCREF(self); 174 return self; 175 } 176 177 PyObject *new_union = make_union(tuple); 178 Py_DECREF(tuple); 179 return new_union; 180} 181 182static int 183union_repr_item(_PyUnicodeWriter *writer, PyObject *p) 184{ 185 PyObject *qualname = NULL; 186 PyObject *module = NULL; 187 PyObject *tmp; 188 PyObject *r = NULL; 189 int err; 190 191 if (p == (PyObject *)&_PyNone_Type) { 192 return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4); 193 } 194 195 if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) { 196 goto exit; 197 } 198 199 if (tmp) { 200 Py_DECREF(tmp); 201 if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) { 202 goto exit; 203 } 204 if (tmp) { 205 // It looks like a GenericAlias 206 Py_DECREF(tmp); 207 goto use_repr; 208 } 209 } 210 211 if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) { 212 goto exit; 213 } 214 if (qualname == NULL) { 215 goto use_repr; 216 } 217 if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) { 218 goto exit; 219 } 220 if (module == NULL || module == Py_None) { 221 goto use_repr; 222 } 223 224 // Looks like a class 225 if (PyUnicode_Check(module) && 226 _PyUnicode_EqualToASCIIString(module, "builtins")) 227 { 228 // builtins don't need a module name 229 r = PyObject_Str(qualname); 230 goto exit; 231 } 232 else { 233 r = PyUnicode_FromFormat("%S.%S", module, qualname); 234 goto exit; 235 } 236 237use_repr: 238 r = PyObject_Repr(p); 239exit: 240 Py_XDECREF(qualname); 241 Py_XDECREF(module); 242 if (r == NULL) { 243 return -1; 244 } 245 err = _PyUnicodeWriter_WriteStr(writer, r); 246 Py_DECREF(r); 247 return err; 248} 249 250static PyObject * 251union_repr(PyObject *self) 252{ 253 unionobject *alias = (unionobject *)self; 254 Py_ssize_t len = PyTuple_GET_SIZE(alias->args); 255 256 _PyUnicodeWriter writer; 257 _PyUnicodeWriter_Init(&writer); 258 for (Py_ssize_t i = 0; i < len; i++) { 259 if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) { 260 goto error; 261 } 262 PyObject *p = PyTuple_GET_ITEM(alias->args, i); 263 if (union_repr_item(&writer, p) < 0) { 264 goto error; 265 } 266 } 267 return _PyUnicodeWriter_Finish(&writer); 268error: 269 _PyUnicodeWriter_Dealloc(&writer); 270 return NULL; 271} 272 273static PyMemberDef union_members[] = { 274 {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY}, 275 {0} 276}; 277 278static PyObject * 279union_getitem(PyObject *self, PyObject *item) 280{ 281 unionobject *alias = (unionobject *)self; 282 // Populate __parameters__ if needed. 283 if (alias->parameters == NULL) { 284 alias->parameters = _Py_make_parameters(alias->args); 285 if (alias->parameters == NULL) { 286 return NULL; 287 } 288 } 289 290 PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item); 291 if (newargs == NULL) { 292 return NULL; 293 } 294 295 PyObject *res; 296 Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); 297 if (nargs == 0) { 298 res = make_union(newargs); 299 } 300 else { 301 res = PyTuple_GET_ITEM(newargs, 0); 302 Py_INCREF(res); 303 for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { 304 PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); 305 Py_SETREF(res, PyNumber_Or(res, arg)); 306 if (res == NULL) { 307 break; 308 } 309 } 310 } 311 Py_DECREF(newargs); 312 return res; 313} 314 315static PyMappingMethods union_as_mapping = { 316 .mp_subscript = union_getitem, 317}; 318 319static PyObject * 320union_parameters(PyObject *self, void *Py_UNUSED(unused)) 321{ 322 unionobject *alias = (unionobject *)self; 323 if (alias->parameters == NULL) { 324 alias->parameters = _Py_make_parameters(alias->args); 325 if (alias->parameters == NULL) { 326 return NULL; 327 } 328 } 329 Py_INCREF(alias->parameters); 330 return alias->parameters; 331} 332 333static PyGetSetDef union_properties[] = { 334 {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL}, 335 {0} 336}; 337 338static PyNumberMethods union_as_number = { 339 .nb_or = _Py_union_type_or, // Add __or__ function 340}; 341 342static const char* const cls_attrs[] = { 343 "__module__", // Required for compatibility with typing module 344 NULL, 345}; 346 347static PyObject * 348union_getattro(PyObject *self, PyObject *name) 349{ 350 unionobject *alias = (unionobject *)self; 351 if (PyUnicode_Check(name)) { 352 for (const char * const *p = cls_attrs; ; p++) { 353 if (*p == NULL) { 354 break; 355 } 356 if (_PyUnicode_EqualToASCIIString(name, *p)) { 357 return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name); 358 } 359 } 360 } 361 return PyObject_GenericGetAttr(self, name); 362} 363 364PyObject * 365_Py_union_args(PyObject *self) 366{ 367 assert(_PyUnion_Check(self)); 368 return ((unionobject *) self)->args; 369} 370 371PyTypeObject _PyUnion_Type = { 372 PyVarObject_HEAD_INIT(&PyType_Type, 0) 373 .tp_name = "types.UnionType", 374 .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n" 375 "\n" 376 "E.g. for int | str"), 377 .tp_basicsize = sizeof(unionobject), 378 .tp_dealloc = unionobject_dealloc, 379 .tp_alloc = PyType_GenericAlloc, 380 .tp_free = PyObject_GC_Del, 381 .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, 382 .tp_traverse = union_traverse, 383 .tp_hash = union_hash, 384 .tp_getattro = union_getattro, 385 .tp_members = union_members, 386 .tp_richcompare = union_richcompare, 387 .tp_as_mapping = &union_as_mapping, 388 .tp_as_number = &union_as_number, 389 .tp_repr = union_repr, 390 .tp_getset = union_properties, 391}; 392 393static PyObject * 394make_union(PyObject *args) 395{ 396 assert(PyTuple_CheckExact(args)); 397 398 unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); 399 if (result == NULL) { 400 return NULL; 401 } 402 403 Py_INCREF(args); 404 result->parameters = NULL; 405 result->args = args; 406 _PyObject_GC_TRACK(result); 407 return (PyObject*)result; 408} 409