this repo has no description
1#!/usr/bin/env python3
2# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
3# WARNING: This is a temporary copy of code from the cpython library to
4# facilitate bringup. Please file a task for anything you change!
5# flake8: noqa
6# fmt: off
7
8import _thread
9import builtins
10import copy
11import functools
12import inspect
13import keyword
14import re
15import sys
16import types
17
18
# Public API of this module: the decorator and its field machinery,
# followed by helper functions that operate on dataclasses.
__all__ = [
    # Decorator, field specification and related types/sentinels.
    'dataclass',
    'field',
    'Field',
    'FrozenInstanceError',
    'InitVar',
    'MISSING',

    # Helper functions.
    'fields',
    'asdict',
    'astuple',
    'make_dataclass',
    'replace',
    'is_dataclass',
]
34
35# Conditions for adding methods. The boxes indicate what action the
36# dataclass decorator takes. For all of these tables, when I talk
37# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
38# referring to the arguments to the @dataclass decorator. When
39# checking if a dunder method already exists, I mean check for an
40# entry in the class's __dict__. I never check to see if an attribute
41# is defined in a base class.
42
# Key:
# +=========+=========================================+
# | Value   | Meaning                                 |
# +=========+=========================================+
# | <blank> | No action: no method is added.          |
# +---------+-----------------------------------------+
# | add     | Generated method is added.              |
# +---------+-----------------------------------------+
# | raise   | TypeError is raised.                    |
# +---------+-----------------------------------------+
# | None    | Attribute is set to None.               |
# +=========+=========================================+
55
# __init__
#
#   +--- init= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has __init__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+
67
# __repr__
#
#   +--- repr= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has __repr__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+
79
80
# __setattr__
# __delattr__
#
#   +--- frozen= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
# +=======+=======+=======+
# | False |       |       |  <- the default
# +-------+-------+-------+
# | True  | add   | raise |
# +=======+=======+=======+
93# Raise because not adding these methods would break the "frozen-ness"
94# of the class.
95
# __eq__
#
#   +--- eq= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has __eq__ in __dict__?
# +=======+=======+=======+
# | False |       |       |
# +-------+-------+-------+
# | True  | add   |       |  <- the default
# +=======+=======+=======+
107
# __lt__
# __le__
# __gt__
# __ge__
#
#   +--- order= parameter
#   |
#   v     |       |       |
#         |  no   |  yes  |  <--- class has any comparison method in __dict__?
# +=======+=======+=======+
# | False |       |       |  <- the default
# +-------+-------+-------+
# | True  | add   | raise |
# +=======+=======+=======+
122# Raise because to allow this case would interfere with using
123# functools.total_ordering.
124
# __hash__

#   +------------------- unsafe_hash= parameter
#   |       +----------- eq= parameter
#   |       |       +--- frozen= parameter
#   |       |       |
#   v       v       v     |        |        |
#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
# +=======+=======+=======+========+========+
# | False | False | False |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | False | True  |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | True  | False | None   |        | <-- the default, not hashable
# +-------+-------+-------+--------+--------+
# | False | True  | True  | add    |        | Frozen, so hashable, allows override
# +-------+-------+-------+--------+--------+
# | True  | False | False | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | False | add    | raise  | Not frozen, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | True  | add    | raise  | Frozen, so hashable
# +=======+=======+=======+========+========+
150# For boxes that are blank, __hash__ is untouched and therefore
151# inherited from the base class. If the base is object, then
152# id-based hashing is used.
153#
154# Note that a class may already have __hash__=None if it specified an
155# __eq__ method in the class body (not one that was created by
156# @dataclass).
157#
158# See _hash_action (below) for a coded version of this table.
159
160
class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""
163
164# A sentinel object for default values to signal that a default
165# factory will be used. This is given a nice repr() which will appear
166# in the function signature of dataclasses' constructors.
167class _HAS_DEFAULT_FACTORY_CLASS:
168 def __repr__(self):
169 return '<factory>'
170_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
171
class _MISSING_TYPE:
    """Type of the MISSING sentinel, meaning "no value was supplied".

    MISSING is used instead of None because None is a legitimate value.
    Giving the sentinel its own class gives it a more descriptive repr
    than a bare object() would have.
    """


MISSING = _MISSING_TYPE()
177
# Most fields carry no metadata, so all of them share this one empty,
# read-only mapping instead of each allocating its own.
_EMPTY_METADATA = types.MappingProxyType(dict())
181
class _FIELD_BASE:
    """Marker telling real fields apart from ClassVar/InitVar pseudo-fields.

    No __eq__ is defined, so markers compare by identity; repr() is the
    marker's name.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


# One marker per kind of entry that may appear among a class's fields.
_FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR = (
    _FIELD_BASE(marker_name)
    for marker_name in ('_FIELD', '_FIELD_CLASSVAR', '_FIELD_INITVAR'))
191
# Class attribute that stores the {field name: Field} mapping. Its
# presence is also how a class is recognised as a dataclass.
_FIELDS = '__dataclass_fields__'

# Class attribute that stores the parameters given to @dataclass.
_PARAMS = '__dataclass_params__'

# Optional method which, if it exists, is called at the end of __init__.
_POST_INIT_NAME = '__post_init__'

# Regex applied to string annotations to spot ClassVar/InitVar.
# Group 1 captures an optional leading "module." prefix and group 2 the
# identifier that follows it. For example, "typing.ClassVar[int]" gives
# ('typing', 'ClassVar') and "InitVar[int]" gives (None, 'InitVar').
# Nothing after the identifier (such as the "[") is examined.
# See https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')
208
class _InitVarMeta(type):
    """Metaclass making InitVar subscriptable: InitVar[T] -> InitVar(T)."""

    def __getitem__(self, params):
        return InitVar(params)


class InitVar(metaclass=_InitVarMeta):
    """Marker for init-only pseudo-fields.

    Such a field is a parameter of the generated __init__ and is passed
    on to __post_init__, but it is not assigned to the instance. The
    wrapped annotation is kept in ``.type``.
    """
    __slots__ = ('type', )

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        wrapped = self.type
        if not isinstance(wrapped, type):
            # typing objects, e.g. List[int]
            return f'dataclasses.InitVar[{wrapped!r}]'
        return f'dataclasses.InitVar[{wrapped.__name__}]'
226
227
class Field:
    """Description of one dataclass field (or ClassVar/InitVar pseudo-field).

    Instances are only ever created from within this module, and only by
    field(). They are exposed externally as (conceptually) read-only
    objects.

    ``name`` and ``type`` start as None. They are not known when the
    Field is created, so they are filled in later. Both are set by the
    time the Field is stored in a class's _FIELDS mapping.
    """
    __slots__ = (
        'name',
        'type',
        'default',
        'default_factory',
        'repr',
        'hash',
        'init',
        'compare',
        'metadata',
        '_field_type',  # Private: not to be used by user code.
    )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        # Unknown until the owning annotation is processed.
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        if metadata is None:
            # Share the single empty read-only mapping.
            self.metadata = _EMPTY_METADATA
        else:
            # Expose the caller's mapping through a read-only view.
            self.metadata = types.MappingProxyType(metadata)
        # Assigned later: _FIELD, _FIELD_CLASSVAR or _FIELD_INITVAR.
        self._field_type = None

    def __repr__(self):
        parts = (
            f'name={self.name!r}',
            f'type={self.type!r}',
            f'default={self.default!r}',
            f'default_factory={self.default_factory!r}',
            f'init={self.init!r}',
            f'repr={self.repr!r}',
            f'hash={self.hash!r}',
            f'compare={self.compare!r}',
            f'metadata={self.metadata!r}',
            f'_field_type={self._field_type}',
        )
        return 'Field(' + ','.join(parts) + ')'

    def __set_name__(self, owner, name):
        # PEP 487 hook. If the default value is a descriptor that defines
        # __set_name__, forward the call so the descriptor learns its
        # owner and name at the right time. _process_class later replaces
        # this Field in the class dict with that default value. See
        # https://www.python.org/dev/peps/pep-0487/#implementation-details.
        hook = getattr(type(self.default), '__set_name__', None)
        if hook:
            hook(self.default, owner, name)
294
295
class _DataclassParams:
    """Record of the arguments that were passed to @dataclass."""
    __slots__ = ('init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen')

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        self.init = init
        self.repr = repr
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen

    def __repr__(self):
        # One "name=value" entry per slot, in declaration order.
        settings = ','.join(f'{attr}={getattr(self, attr)!r}'
                            for attr in _DataclassParams.__slots__)
        return f'_DataclassParams({settings})'
322
323
# field() is exposed instead of Field itself so that a type checker can
# be told (via overloads) that the result type depends on the arguments.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the field's default value. default_factory is a
    zero-argument callable that produces the value instead.

    If init is true, the field is a parameter of the generated __init__().
    If repr is true, it is included in the object's repr(). If hash is
    true, it is included in hash(). If hash is None, compare decides
    instead. If compare is true, the field is used by the comparison
    methods.

    metadata, if given, must be a mapping. It is stored (behind a
    read-only view) but not otherwise examined by dataclass.

    Raises ValueError if both default and default_factory are specified.
    """
    both_given = default is not MISSING and default_factory is not MISSING
    if both_given:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare, metadata)
347
348
def _tuple_str(obj_name, fields):
    """Return source text for a tuple of obj_name's field attributes.

    For obj_name "self" and fields named x and y the result is
    "(self.x,self.y,)". Every member is followed by a comma, so a single
    field still yields a 1-tuple. With no fields the result is "()".
    """
    if not fields:
        return '()'
    members = ''.join(f'{obj_name}.{f.name},' for f in fields)
    return f'({members})'
359
360
def _recursive_repr(user_function):
    """Decorate a __repr__ so a recursive call on the same object gives '...'.

    The logic is copied from reprlib.recursive_repr to avoid the
    dependency. In-progress calls are tracked per (object id, thread id),
    so a repr running in one thread doesn't affect another thread.
    """
    active = set()

    @functools.wraps(user_function)
    def wrapper(self):
        token = (id(self), _thread.get_ident())
        if token in active:
            return '...'
        active.add(token)
        try:
            return user_function(self)
        finally:
            active.discard(token)

    return wrapper
380
381
def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    """Create function `name` by exec()-ing generated source code.

    args: sequence of parameter source strings, e.g. ('self', 'x:_type_x').
    body: sequence of statement source strings; each becomes one body line.
        A statement that must be nested (e.g. under an ``if``) carries its
        own extra leading space.
    globals: global namespace used for the exec().
    locals: mapping whose entries become visible to the generated code by
        name. They are passed as arguments to a wrapper function that
        encloses the new function, so the new function sees them as
        closure variables.
    return_type: if given, used as the return annotation.

    Note that `locals` is mutated: 'BUILTINS', and possibly
    '_return_type', are added. Caller beware! The only callers are
    internal to this module, so no worries about external callers.
    """
    if locals is None:
        locals = {}
    if 'BUILTINS' not in locals:
        locals['BUILTINS'] = builtins
    return_annotation = ''
    if return_type is not MISSING:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
    args = ','.join(args)
    # The generated def sits one space deep inside __create_fn__ (below).
    # Its body must be indented further than the def itself, so use two
    # spaces. With a single space the body would be at the same depth as
    # the def, and exec() would raise IndentationError.
    body = '\n'.join(f'  {b}' for b in body)

    # Compute the text of the entire function.
    txt = f' def {name}({args}){return_annotation}:\n{body}'

    # Wrap it in __create_fn__, whose parameters are the keys of `locals`.
    # Calling the wrapper with **locals binds those values as free
    # variables of the inner function.
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)
407
408
def _field_assign(frozen, name, value, self_name):
    """Return the __init__ statement that stores `value` in field `name`.

    Frozen classes reject ordinary attribute assignment, so their fields
    are set through object.__setattr__. That is reached via the BUILTINS
    name that _create_fn provides. `self_name` is the name of the
    instance parameter. It is not hard-coded as "self" because a field
    might itself be called "self".
    """
    if not frozen:
        return f'{self_name}.{name}={value}'
    return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
419
420
def _field_init(f, frozen, globals, self_name):
    """Return the __init__ body line that initializes field `f`, or None.

    A default value or default factory that the generated code needs is
    stored into `globals` under the name _dflt_<field name>. None means
    no statement is needed. That is the case for a field that is neither
    an __init__ parameter nor has a default factory, and for InitVar
    pseudo-fields, which are never assigned.
    """
    default_name = f'_dflt_{f.name}'

    if f.default_factory is not MISSING:
        # In both sub-cases the generated code must be able to reach the
        # factory.
        globals[default_name] = f.default_factory
        if f.init:
            # The parameter defaults to the _HAS_DEFAULT_FACTORY marker.
            # Call the factory only when no argument was supplied.
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter, so the factory has to be called
            # here. There is no class-dict value to fall back on: it is
            # not set, and a factory exists precisely because each
            # instance should get its own value.
            value = f'{default_name}()'
    elif f.init:
        # Plain parameter. If there is a default, expose it under
        # _dflt_<name>, which the signature built by _init_param uses.
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name
    else:
        # Not in __init__ and no factory: nothing to initialize.
        return None

    # Checked only now, so that the _dflt_ variable above is still
    # created. InitVars, however, never get an assignment statement.
    if f._field_type is _FIELD_INITVAR:
        return None

    return _field_assign(frozen, f.name, value, self_name)
473
474
def _init_param(f):
    """Return the __init__ parameter text for field `f`.

    Example: 'x:_type_x=_dflt_x'. The annotation and default are names of
    variables bound to the real type and default value, not literals. A
    field with a default factory gets the _HAS_DEFAULT_FACTORY marker as
    its default. A field with neither gets no default at all.
    """
    if f.default is not MISSING:
        # Real default value, looked up through its _dflt_ variable.
        suffix = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # Factory: the marker tells __init__ to call it.
        suffix = '=_HAS_DEFAULT_FACTORY'
    else:
        # Required parameter.
        suffix = ''
    return f'{f.name}:_type_{f.name}{suffix}'
492
493
def _init_fn(fields, frozen, has_post_init, self_name, globals):
    """Generate the dataclass __init__ method.

    fields: real fields plus InitVar pseudo-fields.
    frozen: if true, fields are assigned through object.__setattr__.
    has_post_init: if true, __init__ ends by calling __post_init__ with
        the InitVar values.
    self_name: name used for the instance parameter.
    globals: global namespace passed on to _create_fn.

    Raises TypeError if, among the __init__ parameters, a field without a
    default follows a field with one.
    """
    # exec() would also reject this, but checking here gives a clearer
    # message and doesn't depend on how the function is built.
    seen_default = False
    for f in fields:
        if not f.init:
            continue
        if f.default is not MISSING or f.default_factory is not MISSING:
            seen_default = True
        elif seen_default:
            raise TypeError(f'non-default argument {f.name!r} '
                            'follows default argument')

    # Names visible to the generated code: each field's annotation, plus
    # the sentinels. _field_init adds the _dflt_ entries to this mapping.
    fn_locals = {f'_type_{f.name}': f.type for f in fields}
    fn_locals['MISSING'] = MISSING
    fn_locals['_HAS_DEFAULT_FACTORY'] = _HAS_DEFAULT_FACTORY

    # A None (falsy) result means the field needs no statement.
    body_lines = [line
                  for line in (_field_init(f, frozen, fn_locals, self_name)
                               for f in fields)
                  if line]

    if has_post_init:
        # The InitVar values are forwarded, in order, to __post_init__.
        initvar_names = ','.join(f.name for f in fields
                                 if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({initvar_names})')

    parameters = [self_name] + [_init_param(f) for f in fields if f.init]
    return _create_fn('__init__',
                      parameters,
                      body_lines or ['pass'],
                      locals=fn_locals,
                      globals=globals,
                      return_type=None)
542
543
def _repr_fn(fields, globals):
    """Generate a __repr__ listing the given fields as name=repr(value).

    The output has the form "QualName(x=1, y='a')", using the class's
    __qualname__. The method is wrapped with _recursive_repr, so an
    object that contains itself prints '...' for the inner occurrence.
    """
    shown = ', '.join([f"{f.name}={{self.{f.name}!r}}" for f in fields])
    body = 'return self.__class__.__qualname__ + f"(' + shown + ')"'
    raw_repr = _create_fn('__repr__', ('self',), [body], globals=globals)
    return _recursive_repr(raw_repr)
553
554
def _frozen_get_del_attr(cls, fields, globals):
    """Generate the (__setattr__, __delattr__) pair for frozen class `cls`.

    Both raise FrozenInstanceError when the instance's type is exactly
    `cls`, or when the attribute is one of `fields`. Otherwise they
    delegate to super(cls, self).
    """
    # Shared by both generated functions (_create_fn adds to it).
    fn_locals = {'cls': cls,
                 'FrozenInstanceError': FrozenInstanceError}
    if not fields:
        # Special case for the zero-length tuple.
        fields_str = '()'
    else:
        fields_str = '(' + ''.join(f'{f.name!r},' for f in fields) + ')'

    guard = f'if type(self) is cls or name in {fields_str}:'
    setattr_fn = _create_fn(
        '__setattr__',
        ('self', 'name', 'value'),
        (guard,
         ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
         'super(cls, self).__setattr__(name, value)'),
        locals=fn_locals,
        globals=globals)
    delattr_fn = _create_fn(
        '__delattr__',
        ('self', 'name'),
        (guard,
         ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
         'super(cls, self).__delattr__(name)'),
        locals=fn_locals,
        globals=globals)
    return setattr_fn, delattr_fn
578
579
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
    """Generate comparison method `name` that applies `op` to field tuples.

    self_tuple and other_tuple are source strings such as
    '(self.x,self.y,)' and '(other.x,other.y,)'. The method returns
    NotImplemented unless both operands have exactly the same class.
    """
    body = [
        'if other.__class__ is self.__class__:',
        f' return {self_tuple}{op}{other_tuple}',
        'return NotImplemented',
    ]
    return _create_fn(name, ('self', 'other'), body, globals=globals)
592
593
def _hash_fn(fields, globals):
    """Generate a __hash__ that hashes the tuple of the given fields."""
    hashed_tuple = _tuple_str('self', fields)
    body = [f'return hash({hashed_tuple})']
    return _create_fn('__hash__', ('self',), body, globals=globals)
600
601
def _is_classvar(a_type, typing):
    """Return True if a_type is typing.ClassVar, bare or subscripted.

    `typing` is the typing module, supplied by the caller. Detecting the
    subscripted form relies on typing's private _GenericAlias class, but
    it is the best available way to test for a ClassVar.
    """
    if a_type is typing.ClassVar:
        return True
    return (type(a_type) is typing._GenericAlias
            and a_type.__origin__ is typing.ClassVar)
608
609
def _is_initvar(a_type, dataclasses):
    """Return True if a_type is InitVar itself or an InitVar[...] object.

    `dataclasses` is the module to check against, i.e. the module this
    code lives in, supplied by the caller.
    """
    initvar_cls = dataclasses.InitVar
    return a_type is initvar_cls or type(a_type) is initvar_cls
615
616
def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    """Return True if string `annotation` refers to `a_type` in `a_module`.

    For example, when checking that an annotation denotes a ClassVar,
    a_module is typing and a_type is typing.ClassVar.

    annotation: a string type annotation.
    cls: the class that this annotation was found in.
    a_module: the module we want to match.
    a_type: the type in that module we want to match.
    is_type_predicate: called as is_type_predicate(obj, a_module). It
        decides whether obj is of the desired type.

    Only the leading "name" or "module.name" part of the annotation is
    examined. If cls.__module__ does not name a module in sys.modules,
    the annotation cannot be resolved and False is returned.
    """
    # It's possible to look up a_module given a_type, but it involves
    # looking in sys.modules (again!), and seems like a waste since the
    # caller already knows a_module.
    #
    # Since this test does not do a local namespace lookup (and instead
    # only a module (global) lookup), there are some things it gets wrong.
    #
    # With string annotations, cv0 will be detected as a ClassVar:
    #   CV = ClassVar
    #   @dataclass
    #   class C0:
    #     cv0: CV
    #
    # But in this example cv1 will not be detected as a ClassVar:
    #   @dataclass
    #   class C1:
    #     CV = ClassVar
    #     cv1: CV
    #
    # In C1, this function will look up "CV" in the module and not find
    # it, so it will not consider cv1 as a ClassVar. This is a fairly
    # obscure corner case. The best way to fix it would be to eval() the
    # string "CV" with the correct global and local namespaces. However,
    # that would add an eval() penalty for every single field of every
    # dataclass that's defined. It was judged not worth it.

    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if not match:
        return False

    cls_module = sys.modules.get(cls.__module__)
    if cls_module is None:
        # cls.__module__ can name a module that isn't in sys.modules, e.g.
        # when a custom string was assigned to it (_process_class also
        # allows this). There is no namespace to resolve the name in, so
        # report "no match" instead of failing on None.__dict__.
        return False

    ns = None
    module_name = match.group(1)
    if not module_name:
        # No module name, assume the class's module did
        # "from dataclasses import InitVar".
        ns = cls_module.__dict__
    elif cls_module.__dict__.get(module_name) is a_module:
        # "module.name": module_name is bound to a_module in the class's
        # module, so resolve the name in a_type's own module.
        type_module = sys.modules.get(a_type.__module__)
        if type_module is not None:
            ns = type_module.__dict__

    if ns and is_type_predicate(ns.get(match.group(2)), a_module):
        return True
    return False
674
675
def _get_field(cls, a_name, a_type):
    """Return the Field for the annotation `a_name: a_type` of class `cls`.

    The class attribute of that name, if any, supplies the default. A
    Field is used as-is; any other value becomes field(default=value).
    The returned Field has name and type filled in. Its _field_type is
    _FIELD, _FIELD_CLASSVAR or _FIELD_INITVAR: ClassVars and InitVars are
    returned too, just marked as such.

    Raises TypeError if a ClassVar or InitVar has a default_factory.
    Raises ValueError if a real field's default is a list, dict or set.
    """
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        # TODO(T42989996): Uncomment when member descriptors are implemented
        # if isinstance(default, types.MemberDescriptorType):
        #     # This is a field in __slots__, so it has no default value.
        #     default = MISSING
        f = field(default=default)

    # Only now are the name and type known.
    f.name = a_name
    f.type = a_type

    # Start as a normal field and reclassify if it turns out to be a
    # ClassVar or InitVar.
    #
    # Real annotation objects are checked directly. String annotations
    # are matched with a regex plus a module lookup (see _is_type)
    # instead of get_type_hints(). get_type_hints() doesn't always work
    # here (https://github.com/python/typing/issues/508), is expensive,
    # and would need an eval() of every string annotation. For the full
    # discussion, see https://bugs.python.org/issue33453.
    f._field_type = _FIELD
    is_string_annotation = isinstance(f.type, str)

    # If no module has imported typing, no annotation can be a ClassVar,
    # so only look for one when typing is present.
    typing = sys.modules.get('typing')
    if typing and (_is_classvar(a_type, typing)
                   or (is_string_annotation
                       and _is_type(f.type, cls, typing, typing.ClassVar,
                                    _is_classvar))):
        f._field_type = _FIELD_CLASSVAR

    if f._field_type is _FIELD:
        # InitVar is defined in this very module.
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
                or (is_string_annotation
                    and _is_type(f.type, cls, dataclasses,
                                 dataclasses.InitVar, _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Per-field validation happens here rather than in Field(), because
    # only now is the field name known for the error message.
    #
    # A default factory makes no sense for a pseudo-field. Other settings
    # (e.g. init=) are deliberately not checked.
    if (f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR)
            and f.default_factory is not MISSING):
        raise TypeError(f'field {f.name} cannot have a '
                        'default factory')

    # For real fields, reject mutable defaults of these well-known types:
    # the single default object would be shared by every instance.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f
757
758
def _set_new_attribute(cls, name, value):
    """Set cls.<name> = value unless `name` is already in cls.__dict__.

    Never overwrites an existing attribute. Returns True if the attribute
    already existed (nothing was set) and False if it was added. Only the
    class's own __dict__ is consulted, so an inherited attribute of the
    same name does not prevent the assignment.
    """
    already_defined = name in cls.__dict__
    if not already_defined:
        setattr(cls, name, value)
    return already_defined
766
767
# Decide if/how a __hash__ is created. _hash_action, below, maps the key
# (unsafe_hash, eq, frozen, has-explicit-hash) to an action, called as
# action(cls, fields, globals):
#   - _hash_add returns a generated __hash__;
#   - _hash_set_none returns None;
#   - _hash_exception raises TypeError.
# The most common outcome is to do nothing. That is encoded as None in
# the table rather than as a no-op function.

def _hash_set_none(cls, fields, globals):
    # eq=True, not frozen, no unsafe_hash and no explicit __hash__: the
    # instances are not hashable (__hash__ becomes None).
    return None


def _hash_add(cls, fields, globals):
    # Hash the fields that opt in. An explicit hash= setting wins;
    # otherwise a field is hashed if it takes part in comparisons.
    hashed = [f for f in fields
              if (f.hash if f.hash is not None else f.compare)]
    return _hash_fn(hashed, globals)


def _hash_exception(cls, fields, globals):
    # unsafe_hash=True was requested, but the class body already defines
    # __hash__.
    raise TypeError(f'Cannot overwrite attribute __hash__ '
                    f'in class {cls.__name__}')


#
#                +-------------------------------------- unsafe_hash?
#                |      +------------------------------- eq?
#                |      |      +------------------------ frozen?
#                |      |      |      +---------------- has-explicit-hash?
#                |      |      |      |
#                |      |      |      |       +-------- action
#                |      |      |      |       |
#                v      v      v      v       v
_hash_action = {(False, False, False, False): None,
                (False, False, False, True ): None,
                (False, False, True,  False): None,
                (False, False, True,  True ): None,
                (False, True,  False, False): _hash_set_none,
                (False, True,  False, True ): None,
                (False, True,  True,  False): _hash_add,
                (False, True,  True,  True ): None,
                (True,  False, False, False): _hash_add,
                (True,  False, False, True ): _hash_exception,
                (True,  False, True,  False): _hash_add,
                (True,  False, True,  True ): _hash_exception,
                (True,  True,  False, False): _hash_add,
                (True,  True,  False, True ): _hash_exception,
                (True,  True,  True,  False): _hash_add,
                (True,  True,  True,  True ): _hash_exception,
                }
# See https://bugs.python.org/issue32929#msg312829 for an if-statement
# version of this table.
813
814
815def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
816 # Now that dicts retain insertion order, there's no reason to use
817 # an ordered dict. I am leveraging that ordering here, because
818 # derived class fields overwrite base class fields, but the order
819 # is defined by the base class, which is found first.
820 fields = {}
821
822 if cls.__module__ in sys.modules:
823 globals = sys.modules[cls.__module__].__dict__
824 else:
825 # Theoretically this can happen if someone writes
826 # a custom string to cls.__module__. In which case
827 # such dataclass won't be fully introspectable
828 # (w.r.t. typing.get_type_hints) but will still function
829 # correctly.
830 globals = {}
831
832 setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
833 unsafe_hash, frozen))
834
835 # Find our base classes in reverse MRO order, and exclude
836 # ourselves. In reversed order so that more derived classes
837 # override earlier field definitions in base classes. As long as
838 # we're iterating over them, see if any are frozen.
839 any_frozen_base = False
840 has_dataclass_bases = False
841 for b in cls.__mro__[-1:0:-1]:
842 # Only process classes that have been processed by our
843 # decorator. That is, they have a _FIELDS attribute.
844 base_fields = getattr(b, _FIELDS, None)
845 if base_fields:
846 has_dataclass_bases = True
847 for f in base_fields.values():
848 fields[f.name] = f
849 if getattr(b, _PARAMS).frozen:
850 any_frozen_base = True
851
852 # Annotations that are defined in this class (not in base
853 # classes). If __annotations__ isn't present, then this class
854 # adds no new annotations. We use this to compute fields that are
855 # added by this class.
856 #
857 # Fields are found from cls_annotations, which is guaranteed to be
858 # ordered. Default values are from class attributes, if a field
859 # has a default. If the default value is a Field(), then it
860 # contains additional info beyond (and possibly including) the
861 # actual default value. Pseudo-fields ClassVars and InitVars are
862 # included, despite the fact that they're not real fields. That's
863 # dealt with later.
864 cls_annotations = cls.__dict__.get('__annotations__', {})
865
866 # Now find fields in our class. While doing so, validate some
867 # things, and set the default values (as class attributes) where
868 # we can.
869 cls_fields = [_get_field(cls, name, type)
870 for name, type in cls_annotations.items()]
871 for f in cls_fields:
872 fields[f.name] = f
873
874 # If the class attribute (which is the default value for this
875 # field) exists and is of type 'Field', replace it with the
876 # real default. This is so that normal class introspection
877 # sees a real default value, not a Field.
878 if isinstance(getattr(cls, f.name, None), Field):
879 if f.default is MISSING:
880 # If there's no default, delete the class attribute.
881 # This happens if we specify field(repr=False), for
882 # example (that is, we specified a field object, but
883 # no default value). Also if we're using a default
884 # factory. The class attribute should not be set at
885 # all in the post-processed class.
886 delattr(cls, f.name)
887 else:
888 setattr(cls, f.name, f.default)
889
890 # Do we have any Field members that don't also have annotations?
891 for name, value in cls.__dict__.items():
892 if isinstance(value, Field) and not name in cls_annotations:
893 raise TypeError(f'{name!r} is a field but has no type annotation')
894
895 # Check rules that apply if we are derived from any dataclasses.
896 if has_dataclass_bases:
897 # Raise an exception if any of our bases are frozen, but we're not.
898 if any_frozen_base and not frozen:
899 raise TypeError('cannot inherit non-frozen dataclass from a '
900 'frozen one')
901
902 # Raise an exception if we're frozen, but none of our bases are.
903 if not any_frozen_base and frozen:
904 raise TypeError('cannot inherit frozen dataclass from a '
905 'non-frozen one')
906
907 # Remember all of the fields on our class (including bases). This
908 # also marks this class as being a dataclass.
909 setattr(cls, _FIELDS, fields)
910
911 # Was this class defined with an explicit __hash__? Note that if
912 # __eq__ is defined in this class, then python will automatically
913 # set __hash__ to None. This is a heuristic, as it's possible
914 # that such a __hash__ == None was not auto-generated, but it
915 # close enough.
916 class_hash = cls.__dict__.get('__hash__', MISSING)
917 has_explicit_hash = not (class_hash is MISSING or
918 (class_hash is None and '__eq__' in cls.__dict__))
919
920 # If we're generating ordering methods, we must be generating the
921 # eq methods.
922 if order and not eq:
923 raise ValueError('eq must be true if order is true')
924
925 if init:
926 # Does this class have a post-init function?
927 has_post_init = hasattr(cls, _POST_INIT_NAME)
928
929 # Include InitVars and regular fields (so, not ClassVars).
930 flds = [f for f in fields.values()
931 if f._field_type in (_FIELD, _FIELD_INITVAR)]
932 _set_new_attribute(cls, '__init__',
933 _init_fn(flds,
934 frozen,
935 has_post_init,
936 # The name to use for the "self"
937 # param in __init__. Use "self"
938 # if possible.
939 '__dataclass_self__' if 'self' in fields
940 else 'self',
941 globals,
942 ))
943
944 # Get the fields as a list, and include only real fields. This is
945 # used in all of the following methods.
946 field_list = [f for f in fields.values() if f._field_type is _FIELD]
947
948 if repr:
949 flds = [f for f in field_list if f.repr]
950 _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
951
952 if eq:
953 # Create _eq__ method. There's no need for a __ne__ method,
954 # since python will call __eq__ and negate it.
955 flds = [f for f in field_list if f.compare]
956 self_tuple = _tuple_str('self', flds)
957 other_tuple = _tuple_str('other', flds)
958 _set_new_attribute(cls, '__eq__',
959 _cmp_fn('__eq__', '==',
960 self_tuple, other_tuple,
961 globals=globals))
962
963 if order:
964 # Create and set the ordering methods.
965 flds = [f for f in field_list if f.compare]
966 self_tuple = _tuple_str('self', flds)
967 other_tuple = _tuple_str('other', flds)
968 for name, op in [('__lt__', '<'),
969 ('__le__', '<='),
970 ('__gt__', '>'),
971 ('__ge__', '>='),
972 ]:
973 if _set_new_attribute(cls, name,
974 _cmp_fn(name, op, self_tuple, other_tuple,
975 globals=globals)):
976 raise TypeError(f'Cannot overwrite attribute {name} '
977 f'in class {cls.__name__}. Consider using '
978 'functools.total_ordering')
979
980 if frozen:
981 for fn in _frozen_get_del_attr(cls, field_list, globals):
982 if _set_new_attribute(cls, fn.__name__, fn):
983 raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
984 f'in class {cls.__name__}')
985
986 # Decide if/how we're going to create a hash function.
987 hash_action = _hash_action[bool(unsafe_hash),
988 bool(eq),
989 bool(frozen),
990 has_explicit_hash]
991 if hash_action:
992 # No need to call _set_new_attribute here, since by the time
993 # we're here the overwriting is unconditional.
994 cls.__hash__ = hash_action(cls, field_list, globals)
995
996 if not getattr(cls, '__doc__'):
997 # Create a class doc-string.
998 # TODO(T63180083): Uncomment below when inspect.signature is implemented
999 # cls.__doc__ = (cls.__name__ +
1000 # str(inspect.signature(cls)).replace(' -> None', ''))
1001 pass
1002
1003 return cls
1004
1005
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):
    """Add generated dunder methods to a class, based on its fields.

    Fields are discovered from the class's PEP 526 __annotations__. The
    class object passed in is modified and returned (not a new class).

    May be applied bare (``@dataclass``) or with arguments
    (``@dataclass(...)``); in the second form a decorator is returned.

    If init is true, an __init__() method is added to the class. If
    repr is true, a __repr__() method is added. If eq is true, an
    __eq__() method is added. If order is true, rich comparison dunder
    methods (__lt__, __le__, __gt__, __ge__) are added. If unsafe_hash
    is true, a __hash__() method is added. If frozen is true, fields
    may not be assigned to after instance creation.
    """
    def decorate(target):
        return _process_class(target, init, repr, eq, order, unsafe_hash,
                              frozen)

    # With parentheses (@dataclass(...)) cls is None, so return the
    # decorator for Python to apply. Used bare (@dataclass), cls is the
    # class itself and is processed immediately.
    return decorate if cls is None else decorate(cls)
1030
1031
def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field. Pseudo-fields (ClassVar and InitVar entries) are
    excluded.

    Raises:
        TypeError: if class_or_instance is neither a dataclass nor an
            instance of one.
    """

    # Might it be worth caching this, per class?
    try:
        all_fields = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        # Hide the internal AttributeError context; the only useful
        # information for the caller is that the argument is not a
        # dataclass.
        raise TypeError('must be called with a dataclass type or instance') from None

    # Exclude pseudo-fields. Note that the fields mapping preserves
    # insertion order, so the order of the tuple is as the fields were
    # defined.
    return tuple(f for f in all_fields.values() if f._field_type is _FIELD)
1048
1049
def _is_dataclass_instance(obj):
    """Return True if obj is an instance of a dataclass.

    The check is made on type(obj), so passing a dataclass class itself
    normally yields False (its type is its metaclass).
    """
    obj_type = type(obj)
    return hasattr(obj_type, _FIELDS)
1053
1054
def is_dataclass(obj):
    """Return True if obj is a dataclass or an instance of one."""
    # Classes are inspected directly; any other object via its type.
    candidate = obj
    if not isinstance(candidate, type):
        candidate = type(candidate)
    return hasattr(candidate, _FIELDS)
1060
1061
def asdict(obj, *, dict_factory=dict):
    """Convert a dataclass instance into a new dictionary mapping field
    names to field values.

    Example usage:

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    When supplied, 'dict_factory' is used in place of the built-in dict.
    Conversion is recursive: field values that are themselves dataclass
    instances are converted too, and the contents of built-in
    containers (tuples, lists and dicts) are walked as well.

    Raises TypeError if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
1084
1085
def _asdict_inner(obj, dict_factory):
    """Recursively convert obj on behalf of asdict().

    - Dataclass instances become dict_factory([(name, value), ...]) over
      their fields, with each value converted recursively.
    - Namedtuples, lists, tuples and dicts (including subclasses) are
      rebuilt as the same type with recursively converted contents.
    - Any other value is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple. Recurse into it, but the returned
        # object is another namedtuple of the same type. This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).

        # I'm not using namedtuple's _asdict()
        # method, because:
        # - it does not recurse in to the namedtuple fields and
        #   convert them to dicts (using dict_factory).
        # - I don't actually want to return a dict here. The main
        #   use case here is json.dumps, and it handles converting
        #   namedtuples to lists. Admittedly we're losing some
        #   information here when we produce a json list instead of a
        #   dict. Note that if we returned dicts here instead of
        #   namedtuples, we could no longer call asdict() on a data
        #   structure where a namedtuple was used as a dict key.

        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        obj_type = type(obj)
        if hasattr(obj_type, 'default_factory'):
            # obj is a defaultdict (or a dict type with the same
            # constructor shape): its first positional argument is the
            # default_factory, so passing an iterable of pairs would raise
            # TypeError. Rebuild it with the same factory and fill it in.
            result = obj_type(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
            return result
        return obj_type((_asdict_inner(k, dict_factory),
                         _asdict_inner(v, dict_factory))
                        for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
1125
1126
def astuple(obj, *, tuple_factory=tuple):
    """Convert a dataclass instance into a new tuple of its field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    When supplied, 'tuple_factory' is used in place of the built-in
    tuple. Conversion is recursive: field values that are themselves
    dataclass instances are converted too, and the contents of built-in
    containers (tuples, lists and dicts) are walked as well.

    Raises TypeError if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
1149
1150
def _astuple_inner(obj, tuple_factory):
    """Recursively convert obj on behalf of astuple().

    - Dataclass instances become tuple_factory([value, ...]) over their
      fields, with each value converted recursively.
    - Namedtuples, lists, tuples and dicts (including subclasses) are
      rebuilt as the same type with recursively converted contents.
    - Any other value is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple. Recurse into it, but the returned
        # object is another namedtuple of the same type. This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        obj_type = type(obj)
        if hasattr(obj_type, 'default_factory'):
            # obj is a defaultdict (or a dict type with the same
            # constructor shape): its first positional argument is the
            # default_factory, so passing an iterable of pairs would raise
            # TypeError. Rebuild it with the same factory and fill it in.
            result = obj_type(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
            return result
        return obj_type((_astuple_inner(k, tuple_factory),
                         _astuple_inner(v, tuple_factory))
                        for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
1176
1177
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):
    """Create and return a new dataclass dynamically.

    The class is named 'cls_name'. Each entry of the iterable 'fields'
    is either a plain name string, a (name, type) pair, or a
    (name, type, Field) triple. When the type is omitted, the string
    'typing.Any' is used as the annotation. A Field given in a triple
    becomes the class attribute for that name, just as
    'name: type = field(...)' would in a class body.

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    'bases' and 'namespace' have the same meaning as for the builtin
    type() function; 'namespace' is copied, never mutated.

    init, repr, eq, order, unsafe_hash and frozen are forwarded to
    dataclass().

    Raises TypeError for a malformed field entry, a field name that is
    not an identifier, a keyword, or a duplicate.
    """
    # Private copy: the caller's mapping must not be modified.
    body = {} if namespace is None else namespace.copy()

    # Collect annotations while checking every name is an identifier,
    # not a keyword, and not repeated. The annotations dict doubles as
    # the record of names already seen.
    annotations = {}
    for item in fields:
        if isinstance(item, str):
            name, tp = item, 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            body[name] = spec
        else:
            raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in annotations:
            raise TypeError(f'Field name duplicated: {name!r}')

        annotations[name] = tp

    body['__annotations__'] = annotations
    # types.new_class() is used instead of plain type() so that generic
    # dataclasses can also be created dynamically.
    cls = types.new_class(cls_name, bases, {},
                          lambda class_ns: class_ns.update(body))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)
1242
1243
def replace(*args, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes. Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    The new object is built by calling the class, so __init__() and any
    __post_init__() run. Fields declared with init=False may not appear
    in 'changes' (ValueError), and every InitVar must be supplied
    explicitly (ValueError).
    """
    nargs = len(args)
    if nargs == 1:
        (obj,) = args
    elif nargs == 0:
        if 'obj' not in changes:
            raise TypeError("replace() missing 1 required positional argument: 'obj'")
        # Deprecated spelling: replace(obj=instance, ...).
        obj = changes.pop('obj')
        import warnings
        warnings.warn("Passing 'obj' as keyword argument is deprecated",
                      DeprecationWarning, stacklevel=2)
    else:
        raise TypeError(f'replace() takes 1 positional argument but {nargs} were given')

    # 'changes' is a fresh dict built for this call (even for
    # replace(obj, **my_changes)), so filling it in below is safe.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    for f in getattr(obj, _FIELDS).values():
        kind = f._field_type
        # ClassVars are not passed to __init__, so ignore them.
        if kind is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # init=False fields cannot be passed to __init__.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name in changes:
            continue
        # InitVars are not stored on the instance, so there is no value
        # to carry over; the caller has to provide one.
        if kind is _FIELD_INITVAR:
            raise ValueError(f"InitVar {f.name!r} "
                             'must be specified with replace()')
        changes[f.name] = getattr(obj, f.name)

    # Any key in 'changes' that is not an __init__ parameter makes this
    # call raise TypeError, which is the desired error.
    return obj.__class__(**changes)
replace.__text_signature__ = '(obj, /, **kwargs)'