this repo has no description
at trunk 1304 lines 50 kB view raw
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# WARNING: This is a temporary copy of code from the cpython library to
# facilitate bringup. Please file a task for anything you change!
# flake8: noqa
# fmt: off

import _thread
import builtins
import copy
import functools
import inspect
import keyword
import re
import sys
import types


__all__ = ['dataclass',
           'field',
           'Field',
           'FrozenInstanceError',
           'InitVar',
           'MISSING',

           # Helper functions.
           'fields',
           'asdict',
           'astuple',
           'make_dataclass',
           'replace',
           'is_dataclass',
           ]

# Conditions for adding methods.  The boxes indicate what action the
# dataclass decorator takes.  For all of these tables, when I talk
# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
# referring to the arguments to the @dataclass decorator.  When
# checking if a dunder method already exists, I mean check for an
# entry in the class's __dict__.  I never check to see if an attribute
# is defined in a base class.

# Key:
# +=========+=========================================+
# | Value   | Meaning                                 |
# +=========+=========================================+
# | <blank> | No action: no method is added.          |
# +---------+-----------------------------------------+
# | add     | Generated method is added.              |
# +---------+-----------------------------------------+
# | raise   | TypeError is raised.                    |
# +---------+-----------------------------------------+
# | None    | Attribute is set to None.               |
# +=========+=========================================+

# __init__
#
#   +--- init= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has __init__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+

# __repr__
#
#    +--- repr= parameter
#    |
#    v    |       |       |
#         |  no   |  yes  |  <--- class has __repr__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+


# __setattr__
# __delattr__
#
#    +--- frozen= parameter
#    |
#    v    |       |       |
#         |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
# +=======+=======+=======+
# | False |       |       |  <- the default
# +-------+-------+-------+
# | True  | add   | raise |
# +=======+=======+=======+
# Raise because not adding these methods would break the "frozen-ness"
# of the class.

# __eq__
#
#    +--- eq= parameter
#    |
#    v    |       |       |
#         |  no   |  yes  |  <--- class has __eq__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+

# __lt__
# __le__
# __gt__
# __ge__
#
#    +--- order= parameter
#    |
#    v    |       |       |
#         |  no   |  yes  |  <--- class has any comparison method in __dict__?
# +=======+=======+=======+
# | False |       |       |  <- the default
# +-------+-------+-------+
# | True  | add   | raise |
# +=======+=======+=======+
# Raise because to allow this case would interfere with using
# functools.total_ordering.

# __hash__

#    +------------------- unsafe_hash= parameter
#    |       +----------- eq= parameter
#    |       |       +--- frozen= parameter
#    |       |       |
#    v       v       v    |        |        |
#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
# +=======+=======+=======+========+========+
# | False | False | False |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | False | True  |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | True  | False | None   |        | <-- the default, not hashable
# +-------+-------+-------+--------+--------+
# | False | True  | True  | add    |        | Frozen, so hashable, allows override
# +-------+-------+-------+--------+--------+
# | True  | False | False | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | False | add    | raise  | Not frozen, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | True  | add    | raise  | Frozen, so hashable
# +=======+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class.  If the base is object, then
# id-based hashing is used.
#
# Note that a class may already have __hash__=None if it specified an
# __eq__ method in the class body (not one that was created by
# @dataclass).
#
# See _hash_action (below) for a coded version of this table.


class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""

# A sentinel object for default values to signal that a default
# factory will be used.  This is given a nice repr() which will appear
# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
    def __repr__(self):
        return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()

# A sentinel object to detect if a parameter is supplied or not.  Use
# a class to give it a better repr.
class _MISSING_TYPE:
    pass
MISSING = _MISSING_TYPE()

# Since most per-field metadata will be unused, create an empty
# read-only proxy that can be shared among all fields.
_EMPTY_METADATA = types.MappingProxyType({})

# Markers for the various kinds of fields and pseudo-fields.  They are
# compared by identity (Field._field_type is _FIELD, ...).
class _FIELD_BASE:
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name
_FIELD = _FIELD_BASE('_FIELD')
_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')

# The name of an attribute on the class where we store the Field
# objects.  Also used to check if a class is a Data Class.
_FIELDS = '__dataclass_fields__'

# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'

# The name of the function, that if it exists, is called at the end of
# __init__.
_POST_INIT_NAME = '__post_init__'

# String regex that string annotations for ClassVar or InitVar must match.
# Allows "identifier.identifier[" or "identifier[".
# https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')

# Metaclass whose __getitem__ makes InitVar[T] evaluate to InitVar(T).
class _InitVarMeta(type):
    def __getitem__(self, params):
        return InitVar(params)

# Annotation marker for init-only pseudo-fields: an InitVar field becomes
# an __init__ parameter and is passed to __post_init__, but no instance
# attribute is assigned for it (see _field_init / _init_fn).
class InitVar(metaclass=_InitVarMeta):
    __slots__ = ('type', )

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        if isinstance(self.type, type):
            type_name = self.type.__name__
        else:
            # typing objects, e.g. List[int]
            type_name = repr(self.type)
        return f'dataclasses.InitVar[{type_name}]'


# Instances of Field are only ever created from within this module,
# and only from the field() function, although Field instances are
# exposed externally as (conceptually) read-only objects.
#
# name and type are filled in after the fact, not in __init__.
# They're not known at the time this class is instantiated, but it's
# convenient if they're available later.
#
# When cls._FIELDS is filled in with a dict of Field objects (keyed by
# field name), the name and type fields will have been populated.
class Field:
    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        self.metadata = (_EMPTY_METADATA
                         if metadata is None else
                         types.MappingProxyType(metadata))
        self._field_type = None

    def __repr__(self):
        return ('Field('
                f'name={self.name!r},'
                f'type={self.type!r},'
                f'default={self.default!r},'
                f'default_factory={self.default_factory!r},'
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'hash={self.hash!r},'
                f'compare={self.compare!r},'
                f'metadata={self.metadata!r},'
                f'_field_type={self._field_type}'
                ')')

    # This is used to support the PEP 487 __set_name__ protocol in the
    # case where we're using a field that contains a descriptor as a
    # default value.  For details on __set_name__, see
    # https://www.python.org/dev/peps/pep-0487/#implementation-details.
    #
    # Note that in _process_class, this Field object is overwritten
    # with the default value, so the end result is a descriptor that
    # had __set_name__ called on it at the right time.
    def __set_name__(self, owner, name):
        func = getattr(type(self.default), '__set_name__', None)
        if func:
            # There is a __set_name__ method on the descriptor, call
            # it.
            func(self.default, owner, name)


# Record of the arguments given to @dataclass; stored on the processed
# class under the _PARAMS attribute.
class _DataclassParams:
    __slots__ = ('init',
                 'repr',
                 'eq',
                 'order',
                 'unsafe_hash',
                 'frozen',
                 )

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        self.init = init
        self.repr = repr
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen

    def __repr__(self):
        return ('_DataclassParams('
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'eq={self.eq!r},'
                f'order={self.order!r},'
                f'unsafe_hash={self.unsafe_hash!r},'
                f'frozen={self.frozen!r}'
                ')')


# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the default value of the field.  default_factory is a
    0-argument function called to initialize a field's value.  If init
    is True, the field will be a parameter to the class's __init__()
    function.  If repr is True, the field will be included in the
    object's repr().  If hash is True, the field will be included in
    the object's hash().  If compare is True, the field will be used
    in comparison functions.  metadata, if specified, must be a
    mapping which is stored but not otherwise examined by dataclass.

    It is an error to specify both default and default_factory.
    """

    if default is not MISSING and default_factory is not MISSING:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata)


def _tuple_str(obj_name, fields):
    # Return a string representing each field of obj_name as a tuple
    # member.  fields is a sequence of Field objects; only their .name
    # is used.  So, if fields are named 'x' and 'y' and obj_name is
    # "self", return "(self.x,self.y,)".

    # Special case for the 0-tuple.
    if not fields:
        return '()'
    # Note the trailing comma, needed if this turns out to be a 1-tuple.
    return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'


# This function's logic is copied from "recursive_repr" function in
# reprlib module to avoid dependency.
def _recursive_repr(user_function):
    # Decorator to make a repr function return "..." for a recursive
    # call.  Re-entrancy is tracked per (object id, thread id).
    repr_running = set()

    @functools.wraps(user_function)
    def wrapper(self):
        key = id(self), _thread.get_ident()
        if key in repr_running:
            return '...'
        repr_running.add(key)
        try:
            result = user_function(self)
        finally:
            repr_running.discard(key)
        return result
    return wrapper


def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    # Build a function called `name` from source text and return it.
    # The function is defined inside a wrapper, __create_fn__, whose
    # parameters are the keys of `locals`, so the generated code sees
    # those values as closure variables; `globals` is its module
    # namespace.
    #
    # Note that we mutate locals when exec() is called.  Caller
    # beware!  The only callers are internal to this module, so no
    # worries about external callers.
    if locals is None:
        locals = {}
    if 'BUILTINS' not in locals:
        locals['BUILTINS'] = builtins
    return_annotation = ''
    if return_type is not MISSING:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
    args = ','.join(args)
    body = '\n'.join(f'  {b}' for b in body)

    # Compute the text of the entire function.
    txt = f' def {name}({args}){return_annotation}:\n{body}'

    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)


def _field_assign(frozen, name, value, self_name):
    # If we're a frozen class, then assign to our fields in __init__
    # via object.__setattr__.  Otherwise, just use a simple
    # assignment.
    #
    # self_name is what "self" is called in this function: don't
    # hard-code "self", since that might be a field name.
    if frozen:
        return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
    return f'{self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name):
    # Return the text of the line in the body of __init__ that will
    # initialize this field.
    #
    # Note: despite its name, `globals` is the namespace dict that
    # _init_fn later hands to _create_fn as `locals`.  Defaults and
    # default factories are stored in it under '_dflt_<field name>' so
    # the generated code can reference them.

    default_name = f'_dflt_{f.name}'
    if f.default_factory is not MISSING:
        if f.init:
            # This field has a default factory.  If a parameter is
            # given, use it.  If not, call the factory.
            globals[default_name] = f.default_factory
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # This is a field that's not in the __init__ params, but
            # has a default factory function.  It needs to be
            # initialized here by calling the factory function,
            # because there's no other way to initialize it.

            # For a field initialized with a default=defaultvalue, the
            # class dict just has the default value
            # (cls.fieldname=defaultvalue).  But that won't work for a
            # default factory, the factory must be called in __init__
            # and we must assign that to self.fieldname.  We can't
            # fall back to the class dict's value, both because it's
            # not set, and because it might be different per-class
            # (which, after all, is why we have a factory function!).

            globals[default_name] = f.default_factory
            value = f'{default_name}()'
    else:
        # No default factory.
        if f.init:
            if f.default is MISSING:
                # There's no default, just do an assignment.
                value = f.name
            elif f.default is not MISSING:
                globals[default_name] = f.default
                value = f.name
        else:
            # This field does not need initialization.  Signify that
            # to the caller by returning None.
            return None

    # Only test this now, so that we can create variables for the
    # default.  However, return None to signify that we're not going
    # to actually do the assignment statement for InitVars.
    if f._field_type is _FIELD_INITVAR:
        return None

    # Now, actually generate the field assignment.
    return _field_assign(frozen, f.name, value, self_name)


def _init_param(f):
    # Return the __init__ parameter string for this field.  For
    # example, the equivalent of 'x:int=3' (except instead of 'int',
    # reference a variable set to int, and instead of '3', reference a
    # variable set to 3).
    if f.default is MISSING and f.default_factory is MISSING:
        # There's no default, and no default_factory, just output the
        # variable name and type.
        default = ''
    elif f.default is not MISSING:
        # There's a default, this will be the name that's used to look
        # it up.
        default = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # There's a factory function.  Set a marker.
        default = '=_HAS_DEFAULT_FACTORY'
    return f'{f.name}:_type_{f.name}{default}'


def _init_fn(fields, frozen, has_post_init, self_name, globals):
    # Build the dataclass __init__.  fields contains both real fields
    # and InitVar pseudo-fields.

    # Make sure we don't have fields without defaults following fields
    # with defaults.  This actually would be caught when exec-ing the
    # function source code, but catching it here gives a better error
    # message, and future-proofs us in case we build up the function
    # using ast.
    seen_default = False
    for f in fields:
        # Only consider fields in the __init__ call.
        if f.init:
            if not (f.default is MISSING and f.default_factory is MISSING):
                seen_default = True
            elif seen_default:
                raise TypeError(f'non-default argument {f.name!r} '
                                'follows default argument')

    locals = {f'_type_{f.name}': f.type for f in fields}
    locals.update({
        'MISSING': MISSING,
        '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
    })

    body_lines = []
    for f in fields:
        line = _field_init(f, frozen, locals, self_name)
        # line is None means that this field doesn't require
        # initialization (it's a pseudo-field).  Just skip it.
        if line:
            body_lines.append(line)

    # Does this class have a post-init function?
    if has_post_init:
        params_str = ','.join(f.name for f in fields
                              if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')

    # If no body lines, use 'pass'.
    if not body_lines:
        body_lines = ['pass']

    return _create_fn('__init__',
                      [self_name] + [_init_param(f) for f in fields if f.init],
                      body_lines,
                      locals=locals,
                      globals=globals,
                      return_type=None)


def _repr_fn(fields, globals):
    # Build a __repr__ of the form "ClassQualname(a=..., b=...)" over
    # the given fields; recursive references render as '...'.
    fn = _create_fn('__repr__',
                    ('self',),
                    ['return self.__class__.__qualname__ + f"(' +
                     ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                for f in fields]) +
                     ')"'],
                    globals=globals)
    return _recursive_repr(fn)


def _frozen_get_del_attr(cls, fields, globals):
    # Build (__setattr__, __delattr__) for a frozen dataclass.  Both
    # raise FrozenInstanceError when the instance's exact type is cls
    # or the attribute name is one of the fields; otherwise they defer
    # to super(cls, self).
    locals = {'cls': cls,
              'FrozenInstanceError': FrozenInstanceError}
    if fields:
        fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
    else:
        # Special case for the zero-length tuple.
        fields_str = '()'
    return (_create_fn('__setattr__',
                       ('self', 'name', 'value'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
                        f'super(cls, self).__setattr__(name, value)'),
                       locals=locals,
                       globals=globals),
            _create_fn('__delattr__',
                       ('self', 'name'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
                        f'super(cls, self).__delattr__(name)'),
                       locals=locals,
                       globals=globals),
            )


def _cmp_fn(name, op, self_tuple, other_tuple, globals):
    # Create a comparison function.  If the fields in the object are
    # named 'x' and 'y', then self_tuple is the string
    # '(self.x,self.y,)' and other_tuple is the string
    # '(other.x,other.y,)'.  Comparing with an object of a different
    # class returns NotImplemented.

    return _create_fn(name,
                      ('self', 'other'),
                      ['if other.__class__ is self.__class__:',
                       f' return {self_tuple}{op}{other_tuple}',
                       'return NotImplemented'],
                      globals=globals)


def _hash_fn(fields, globals):
    # Build a __hash__ that hashes the tuple of the given fields' values.
    self_tuple = _tuple_str('self', fields)
    return _create_fn('__hash__',
                      ('self',),
                      [f'return hash({self_tuple})'],
                      globals=globals)


def _is_classvar(a_type, typing):
    # This test uses a typing internal class, but it's the best way to
    # test if this is a ClassVar.
    return (a_type is typing.ClassVar
            or (type(a_type) is typing._GenericAlias
                and a_type.__origin__ is typing.ClassVar))


def _is_initvar(a_type, dataclasses):
    # The module we're checking against is the module we're
    # currently in (dataclasses.py).
    return (a_type is dataclasses.InitVar
            or type(a_type) is dataclasses.InitVar)


def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    # Given a type annotation string, does it refer to a_type in
    # a_module?  For example, when checking that annotation denotes a
    # ClassVar, then a_module is typing, and a_type is
    # typing.ClassVar.

    # It's possible to look up a_module given a_type, but it involves
    # looking in sys.modules (again!), and seems like a waste since
    # the caller already knows a_module.

    # - annotation is a string type annotation
    # - cls is the class that this annotation was found in
    # - a_module is the module we want to match
    # - a_type is the type in that module we want to match
    # - is_type_predicate is a function called with (obj, a_module)
    #   that determines if obj is of the desired type.

    # Since this test does not do a local namespace lookup (and
    # instead only a module (global) lookup), there are some things it
    # gets wrong.

    # With string annotations, cv0 will be detected as a ClassVar:
    #   CV = ClassVar
    #   @dataclass
    #   class C0:
    #     cv0: CV

    # But in this example cv1 will not be detected as a ClassVar:
    #   @dataclass
    #   class C1:
    #     CV = ClassVar
    #     cv1: CV

    # In C1, the code in this function (_is_type) will look up "CV" in
    # the module and not find it, so it will not consider cv1 as a
    # ClassVar.  This is a fairly obscure corner case, and the best
    # way to fix it would be to eval() the string "CV" with the
    # correct global and local namespaces.  However that would involve
    # a eval() penalty for every single field of every dataclass
    # that's defined.  It was judged not worth it.

    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if match:
        ns = None
        module_name = match.group(1)
        # cls.__module__ may name a module that is not in sys.modules
        # (e.g. a custom string, which _process_class tolerates).  In
        # that case there is no namespace to search, so don't crash on
        # a None module -- just report "not a match".
        module = sys.modules.get(cls.__module__)
        if not module_name:
            # No module name, assume the class's module did
            # "from dataclasses import InitVar".
            if module is not None:
                ns = module.__dict__
        else:
            # Look up module_name in the class's module.
            if module and module.__dict__.get(module_name) is a_module:
                ns = sys.modules.get(a_type.__module__).__dict__
        if ns and is_type_predicate(ns.get(match.group(2)), a_module):
            return True
    return False


def _get_field(cls, a_name, a_type):
    # Return a Field object for this field name and type.  ClassVars
    # and InitVars are also returned, but marked as such (see
    # f._field_type).

    # If the default value isn't derived from Field, then it's only a
    # normal default value.  Convert it to a Field().
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        # TODO(T42989996): Uncomment when member descriptors are implemented
        # if isinstance(default, types.MemberDescriptorType):
        #     # This is a field in __slots__, so it has no default value.
        #     default = MISSING
        f = field(default=default)

    # Only at this point do we know the name and the type.  Set them.
    f.name = a_name
    f.type = a_type

    # Assume it's a normal field until proven otherwise.  We're next
    # going to decide if it's a ClassVar or InitVar, everything else
    # is just a normal field.
    f._field_type = _FIELD

    # In addition to checking for actual types here, also check for
    # string annotations.  get_type_hints() won't always work for us
    # (see https://github.com/python/typing/issues/508 for example),
    # plus it's expensive and would require an eval for every string
    # annotation.  So, make a best effort to see if this is a ClassVar
    # or InitVar using regex's and checking that the thing referenced
    # is actually of the correct type.

    # For the complete discussion, see https://bugs.python.org/issue33453

    # If typing has not been imported, then it's impossible for any
    # annotation to be a ClassVar.  So, only look for ClassVar if
    # typing has been imported by any module (not necessarily cls's
    # module).
    typing = sys.modules.get('typing')
    if typing:
        if (_is_classvar(a_type, typing)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, typing, typing.ClassVar,
                             _is_classvar))):
            f._field_type = _FIELD_CLASSVAR

    # If the type is InitVar, or if it's a matching string annotation,
    # then it's an InitVar.
    if f._field_type is _FIELD:
        # The module we're checking against is the module we're
        # currently in (dataclasses.py).
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, dataclasses, dataclasses.InitVar,
                             _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Validations for individual fields.  This is delayed until now,
    # instead of in the Field() constructor, since only here do we
    # know the field name, which allows for better error reporting.

    # Special restrictions for ClassVar and InitVar.
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')
        # Should I check for other field settings? default_factory
        # seems the most serious to check for.  Maybe add others.  For
        # example, how about init=False (or really,
        # init=<not-the-default-init-value>)?  It makes no sense for
        # ClassVar and InitVar to specify init=<anything>.

    # For real fields, disallow mutable defaults for known types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f


def _set_new_attribute(cls, name, value):
    # Never overwrites an existing attribute.  Returns True if the
    # attribute already exists (only cls's own __dict__ is checked,
    # not base classes).
    if name in cls.__dict__:
        return True
    setattr(cls, name, value)
    return False


# Decide if/how we're going to create a hash function.  Key is
# (unsafe_hash, eq, frozen, does-hash-exist).  Value is the action to
# take.  The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.

def _hash_set_none(cls, fields, globals):
    # Action: make the class unhashable (__hash__ = None).
    return None

def _hash_add(cls, fields, globals):
    # Action: generate __hash__ over fields whose hash flag is true, or
    # whose compare flag is true when hash was left as None.
    flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
    return _hash_fn(flds, globals)

def _hash_exception(cls, fields, globals):
    # Raise an exception.
    raise TypeError(f'Cannot overwrite attribute __hash__ '
                    f'in class {cls.__name__}')

#
#                +-------------------------------------- unsafe_hash?
#                |      +------------------------------- eq?
#                |      |      +------------------------ frozen?
#                |      |      |      +----------------  has-explicit-hash?
#                |      |      |      |
#                |      |      |      |        +-------  action
#                |      |      |      |        |
#                v      v      v      v        v
_hash_action = {(False, False, False, False): None,
                (False, False, False, True ): None,
                (False, False, True,  False): None,
                (False, False, True,  True ): None,
                (False, True,  False, False): _hash_set_none,
                (False, True,  False, True ): None,
                (False, True,  True,  False): _hash_add,
                (False, True,  True,  True ): None,
                (True,  False, False, False): _hash_add,
                (True,  False, False, True ): _hash_exception,
                (True,  False, True,  False): _hash_add,
                (True,  False, True,  True ): _hash_exception,
                (True,  True,  False, False): _hash_add,
                (True,  True,  False, True ): _hash_exception,
                (True,  True,  True,  False): _hash_add,
                (True,  True,  True,  True ): _hash_exception,
                }
# See https://bugs.python.org/issue32929#msg312829 for an if-statement
# version of this table.
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
    """Turn cls into a dataclass in place and return it.

    Collects fields from dataclass bases (walked in reverse MRO order) and
    from cls's own __annotations__. It records the fields and the decorator
    parameters on the class, then adds the dunder methods selected by the
    boolean arguments (see the decision tables near the top of this module).

    Raises TypeError if a Field class attribute has no annotation, if this
    class's frozen-ness disagrees with its dataclass bases, or if the
    ordering or frozen methods cannot be set because they would overwrite
    existing attributes. Raises ValueError if order is true but eq is false.
    """
    # Now that dicts retain insertion order, there's no reason to use
    # an ordered dict.  I am leveraging that ordering here, because
    # derived class fields overwrite base class fields, but the order
    # is defined by the base class, which is found first.
    fields = {}

    if cls.__module__ in sys.modules:
        globals = sys.modules[cls.__module__].__dict__
    else:
        # Theoretically this can happen if someone writes
        # a custom string to cls.__module__.  In which case
        # such dataclass won't be fully introspectable
        # (w.r.t. typing.get_type_hints) but will still function
        # correctly.
        globals = {}

    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen))

    # Find our base classes in reverse MRO order, and exclude
    # ourselves.  In reversed order so that more derived classes
    # override earlier field definitions in base classes.  As long as
    # we're iterating over them, see if any are frozen.
    any_frozen_base = False
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        # Only process classes that have been processed by our
        # decorator.  That is, they have a _FIELDS attribute.
        base_fields = getattr(b, _FIELDS, None)
        # Compare against None rather than testing truthiness: a dataclass
        # base with no fields has an empty (falsy) _FIELDS dict, but its
        # frozen flag must still take part in the inheritance checks below.
        if base_fields is not None:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if getattr(b, _PARAMS).frozen:
                any_frozen_base = True

    # Annotations that are defined in this class (not in base
    # classes).  If __annotations__ isn't present, then this class
    # adds no new annotations.  We use this to compute fields that are
    # added by this class.
    #
    # Fields are found from cls_annotations, which is guaranteed to be
    # ordered.  Default values are from class attributes, if a field
    # has a default.  If the default value is a Field(), then it
    # contains additional info beyond (and possibly including) the
    # actual default value.  Pseudo-fields ClassVars and InitVars are
    # included, despite the fact that they're not real fields.  That's
    # dealt with later.
    cls_annotations = cls.__dict__.get('__annotations__', {})

    # Now find fields in our class.  While doing so, validate some
    # things, and set the default values (as class attributes) where
    # we can.
    cls_fields = [_get_field(cls, name, type)
                  for name, type in cls_annotations.items()]
    for f in cls_fields:
        fields[f.name] = f

        # If the class attribute (which is the default value for this
        # field) exists and is of type 'Field', replace it with the
        # real default.  This is so that normal class introspection
        # sees a real default value, not a Field.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                # If there's no default, delete the class attribute.
                # This happens if we specify field(repr=False), for
                # example (that is, we specified a field object, but
                # no default value).  Also if we're using a default
                # factory.  The class attribute should not be set at
                # all in the post-processed class.
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Do we have any Field members that don't also have annotations?
    for name, value in cls.__dict__.items():
        if isinstance(value, Field) and name not in cls_annotations:
            raise TypeError(f'{name!r} is a field but has no type annotation')

    # Check rules that apply if we are derived from any dataclasses.
    if has_dataclass_bases:
        # Raise an exception if any of our bases are frozen, but we're not.
        if any_frozen_base and not frozen:
            raise TypeError('cannot inherit non-frozen dataclass from a '
                            'frozen one')

        # Raise an exception if we're frozen, but none of our bases are.
        if not any_frozen_base and frozen:
            raise TypeError('cannot inherit frozen dataclass from a '
                            'non-frozen one')

    # Remember all of the fields on our class (including bases).  This
    # also marks this class as being a dataclass.
    setattr(cls, _FIELDS, fields)

    # Was this class defined with an explicit __hash__?  Note that if
    # __eq__ is defined in this class, then python will automatically
    # set __hash__ to None.  This is a heuristic, as it's possible
    # that such a __hash__ == None was not auto-generated, but it's
    # close enough.
    class_hash = cls.__dict__.get('__hash__', MISSING)
    has_explicit_hash = not (class_hash is MISSING or
                             (class_hash is None and '__eq__' in cls.__dict__))

    # If we're generating ordering methods, we must be generating the
    # eq methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        # Does this class have a post-init function?
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        # Include InitVars and regular fields (so, not ClassVars).
        flds = [f for f in fields.values()
                if f._field_type in (_FIELD, _FIELD_INITVAR)]
        _set_new_attribute(cls, '__init__',
                           _init_fn(flds,
                                    frozen,
                                    has_post_init,
                                    # The name to use for the "self"
                                    # param in __init__.  Use "self"
                                    # if possible.
                                    '__dataclass_self__' if 'self' in fields
                                            else 'self',
                                    globals,
                          ))

    # Get the fields as a list, and include only real fields.  This is
    # used in all of the following methods.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        flds = [f for f in field_list if f.repr]
        _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

    if eq:
        # Create __eq__ method.  There's no need for a __ne__ method,
        # since python will call __eq__ and negate it.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        _set_new_attribute(cls, '__eq__',
                           _cmp_fn('__eq__', '==',
                                   self_tuple, other_tuple,
                                   globals=globals))

    if order:
        # Create and set the ordering methods.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        for name, op in [('__lt__', '<'),
                         ('__le__', '<='),
                         ('__gt__', '>'),
                         ('__ge__', '>='),
                         ]:
            if _set_new_attribute(cls, name,
                                  _cmp_fn(name, op, self_tuple, other_tuple,
                                          globals=globals)):
                raise TypeError(f'Cannot overwrite attribute {name} '
                                f'in class {cls.__name__}. Consider using '
                                'functools.total_ordering')

    if frozen:
        for fn in _frozen_get_del_attr(cls, field_list, globals):
            if _set_new_attribute(cls, fn.__name__, fn):
                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
                                f'in class {cls.__name__}')

    # Decide if/how we're going to create a hash function.
    hash_action = _hash_action[bool(unsafe_hash),
                               bool(eq),
                               bool(frozen),
                               has_explicit_hash]
    if hash_action:
        # No need to call _set_new_attribute here, since by the time
        # we're here the overwriting is unconditional.
        cls.__hash__ = hash_action(cls, field_list, globals)

    if not getattr(cls, '__doc__'):
        # Create a class doc-string.
        # TODO(T63180083): Uncomment below when inspect.signature is implemented
        # cls.__doc__ = (cls.__name__ +
        #                str(inspect.signature(cls)).replace(' -> None', ''))
        pass

    return cls


def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):
    """Returns the same class as was passed in, with dunder methods
    added based on the fields defined in the class.

    Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class. If
    repr is true, a __repr__() method is added. If eq is true, an
    __eq__() method is added. If order is true, rich comparison dunder
    methods are added. If unsafe_hash is true, a __hash__() method
    function is added. If frozen is true, fields may not be assigned to
    after instance creation.
    """

    def wrap(cls):
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)

    # See if we're being called as @dataclass or @dataclass().
    if cls is None:
        # We're called with parens.
        return wrap

    # We're called as @dataclass without parens.
    return wrap(cls)


def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field.

    Raises TypeError if class_or_instance is not a dataclass or an
    instance of one.
    """

    # Might it be worth caching this, per class?
    try:
        fields = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        # 'from None': the missing attribute is an internal detail; the
        # TypeError alone explains the problem to the caller.
        raise TypeError('must be called with a dataclass type or instance') from None

    # Exclude pseudo-fields.  Note that fields is sorted by insertion
    # order, so the order of the tuple is as the fields were defined.
    return tuple(f for f in fields.values() if f._field_type is _FIELD)


def _is_dataclass_instance(obj):
    """Returns True if obj is an instance of a dataclass."""
    return hasattr(type(obj), _FIELDS)


def is_dataclass(obj):
    """Returns True if obj is a dataclass or an instance of a
    dataclass."""
    cls = obj if isinstance(obj, type) else type(obj)
    return hasattr(cls, _FIELDS)


def asdict(obj, *, dict_factory=dict):
    """Return the fields of a dataclass instance as a new dictionary mapping
    field names to field values.

    Example usage:

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, 'dict_factory' will be used instead of built-in dict.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """
    if not _is_dataclass_instance(obj):
        raise TypeError("asdict() should be called on dataclass instances")
    return _asdict_inner(obj, dict_factory)


def _asdict_inner(obj, dict_factory):
    """Recursive worker for asdict(); converts obj and everything inside it."""
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).

        # I'm not using namedtuple's _asdict()
        # method, because:
        # - it does not recurse in to the namedtuple fields and
        #   convert them to dicts (using dict_factory).
        # - I don't actually want to return a dict here.  The main
        #   use case here is json.dumps, and it handles converting
        #   namedtuples to lists.  Admittedly we're losing some
        #   information here when we produce a json list instead of a
        #   dict.  Note that if we returned dicts here instead of
        #   namedtuples, we could no longer call asdict() on a data
        #   structure where a namedtuple was used as a dict key.

        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is presumably a collections.defaultdict (detected by its
            # default_factory attribute).  Its constructor takes the factory
            # as the first positional argument instead of an iterable of
            # pairs, so rebuild it item by item (see bpo-35540).
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
            return result
        return type(obj)((_asdict_inner(k, dict_factory),
                          _asdict_inner(v, dict_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)


def astuple(obj, *, tuple_factory=tuple):
    """Return the fields of a dataclass instance as a new tuple of field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, 'tuple_factory' will be used instead of built-in tuple.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """

    if not _is_dataclass_instance(obj):
        raise TypeError("astuple() should be called on dataclass instances")
    return _astuple_inner(obj, tuple_factory)


def _astuple_inner(obj, tuple_factory):
    """Recursive worker for astuple(); converts obj and everything inside it."""
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is presumably a collections.defaultdict (detected by its
            # default_factory attribute).  Its constructor takes the factory
            # as the first positional argument instead of an iterable of
            # pairs, so rebuild it item by item (see bpo-35540).
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
            return result
        return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)


def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable
    whose items are either name, (name, type) or (name, type, Field).
    If type is omitted, use the string 'typing.Any'.  Field objects are
    created by the equivalent of calling 'field(name, type [, Field-info])'.

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.

    The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
    dataclass().

    Raises TypeError for a malformed field item, a field name that is not a
    valid identifier, a keyword, or a duplicate.
    """

    if namespace is None:
        namespace = {}
    else:
        # Copy namespace since we're going to mutate it.
        namespace = namespace.copy()

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.
    seen = set()
    anns = {}
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            namespace[name] = spec
        else:
            raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in seen:
            raise TypeError(f'Field name duplicated: {name!r}')

        seen.add(name)
        anns[name] = tp

    namespace['__annotations__'] = anns
    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)


def replace(*args, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Raises TypeError if obj is missing, given more than once positionally,
    or is not a dataclass instance.  Raises ValueError if an init=False
    field is passed in changes, or if an InitVar is not passed.
    """
    if len(args) > 1:
        raise TypeError(f'replace() takes 1 positional argument but {len(args)} were given')
    if args:
        obj, = args
    elif 'obj' in changes:
        obj = changes.pop('obj')
        import warnings
        warnings.warn("Passing 'obj' as keyword argument is deprecated",
                      DeprecationWarning, stacklevel=2)
    else:
        raise TypeError("replace() missing 1 required positional argument: 'obj'")

    # We're going to mutate 'changes', but that's okay because it's a
    # new dict, even if called with 'replace(obj, **my_changes)'.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    # It's an error to have init=False fields in 'changes'.
    # If a field is not in 'changes', read its value from the provided obj.

    for f in getattr(obj, _FIELDS).values():
        # Only consider normal fields or InitVars.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # Error if this field is specified in changes.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name not in changes:
            if f._field_type is _FIELD_INITVAR:
                raise ValueError(f"InitVar {f.name!r} "
                                 'must be specified with replace()')
            changes[f.name] = getattr(obj, f.name)

    # Create the new object, which calls __init__() and
    # __post_init__() (if defined), using all of the init fields we've
    # added and/or left in 'changes'.  If there are values supplied in
    # changes that aren't fields, this will correctly raise a
    # TypeError.
    return obj.__class__(**changes)
replace.__text_signature__ = '(obj, /, **kwargs)'