this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2from __future__ import annotations
3
4import ast
5import linecache
6import sys
7from ast import (
8 AST,
9 And,
10 AnnAssign,
11 Assign,
12 AsyncFor,
13 AsyncFunctionDef,
14 AsyncWith,
15 Attribute,
16 AugAssign,
17 Await,
18 BinOp,
19 BoolOp,
20 Bytes,
21 Call,
22 ClassDef,
23 Compare,
24 Constant,
25 DictComp,
26 Ellipsis,
27 For,
28 FormattedValue,
29 FunctionDef,
30 GeneratorExp,
31 If,
32 IfExp,
33 Import,
34 ImportFrom,
35 Index,
36 Is,
37 IsNot,
38 JoinedStr,
39 Lambda,
40 ListComp,
41 Module,
42 Name,
43 NameConstant,
44 Num,
45 Return,
46 SetComp,
47 Slice,
48 Starred,
49 Str,
50 Subscript,
51 Try,
52 UnaryOp,
53 While,
54 With,
55 Yield,
56 YieldFrom,
57 cmpop,
58 expr,
59)
60from contextlib import contextmanager, nullcontext
61from enum import IntEnum
62from functools import partial
63from types import BuiltinFunctionType, CodeType, MethodDescriptorType
64from typing import (
65 Callable as typingCallable,
66 Collection,
67 Dict,
68 Generator,
69 Generic,
70 Iterable,
71 List,
72 Mapping,
73 NoReturn,
74 Optional,
75 Sequence,
76 Set,
77 Tuple,
78 Type,
79 TypeVar,
80 Union,
81 cast,
82)
83
84from __static__ import chkdict # pyre-ignore[21]: unknown module
85from _static import ( # pyre-fixme[21]: Could not find module `_static`.
86 TYPED_BOOL,
87 TYPED_INT_8BIT,
88 TYPED_INT_16BIT,
89 TYPED_INT_32BIT,
90 TYPED_INT_64BIT,
91 TYPED_OBJECT,
92 TYPED_ARRAY,
93 TYPED_INT_UNSIGNED,
94 TYPED_INT_SIGNED,
95 TYPED_INT8,
96 TYPED_INT16,
97 TYPED_INT32,
98 TYPED_INT64,
99 TYPED_UINT8,
100 TYPED_UINT16,
101 TYPED_UINT32,
102 TYPED_UINT64,
103 SEQ_LIST,
104 SEQ_TUPLE,
105 SEQ_LIST_INEXACT,
106 SEQ_ARRAY_INT8,
107 SEQ_ARRAY_INT16,
108 SEQ_ARRAY_INT32,
109 SEQ_ARRAY_INT64,
110 SEQ_ARRAY_UINT8,
111 SEQ_ARRAY_UINT16,
112 SEQ_ARRAY_UINT32,
113 SEQ_ARRAY_UINT64,
114 SEQ_SUBSCR_UNCHECKED,
115 SEQ_REPEAT_INEXACT_SEQ,
116 SEQ_REPEAT_INEXACT_NUM,
117 SEQ_REPEAT_REVERSED,
118 SEQ_REPEAT_PRIMITIVE_NUM,
119 PRIM_OP_EQ_INT,
120 PRIM_OP_NE_INT,
121 PRIM_OP_LT_INT,
122 PRIM_OP_LE_INT,
123 PRIM_OP_GT_INT,
124 PRIM_OP_GE_INT,
125 PRIM_OP_LT_UN_INT,
126 PRIM_OP_LE_UN_INT,
127 PRIM_OP_GT_UN_INT,
128 PRIM_OP_GE_UN_INT,
129 PRIM_OP_ADD_INT,
130 PRIM_OP_SUB_INT,
131 PRIM_OP_MUL_INT,
132 PRIM_OP_DIV_INT,
133 PRIM_OP_DIV_UN_INT,
134 PRIM_OP_MOD_INT,
135 PRIM_OP_MOD_UN_INT,
136 PRIM_OP_LSHIFT_INT,
137 PRIM_OP_RSHIFT_INT,
138 PRIM_OP_RSHIFT_UN_INT,
139 PRIM_OP_XOR_INT,
140 PRIM_OP_OR_INT,
141 PRIM_OP_AND_INT,
142 PRIM_OP_NEG_INT,
143 PRIM_OP_INV_INT,
144 PRIM_OP_ADD_DBL,
145 PRIM_OP_SUB_DBL,
146 PRIM_OP_MUL_DBL,
147 PRIM_OP_DIV_DBL,
148 PRIM_OP_MOD_DBL,
149 PROM_OP_POW_DBL,
150 FAST_LEN_INEXACT,
151 FAST_LEN_LIST,
152 FAST_LEN_DICT,
153 FAST_LEN_SET,
154 FAST_LEN_TUPLE,
155 FAST_LEN_ARRAY,
156 FAST_LEN_STR,
157 TYPED_DOUBLE,
158 RAND_MAX,
159 rand,
160)
161
162from . import symbols, opcode38static
163from .consts import SC_LOCAL, SC_GLOBAL_EXPLICIT, SC_GLOBAL_IMPLICIT
164from .opcodebase import Opcode
165from .optimizer import AstOptimizer
166from .pyassem import Block, PyFlowGraph, PyFlowGraphCinder, IndexedSet
167from .pycodegen import (
168 AugAttribute,
169 AugName,
170 AugSubscript,
171 CodeGenerator,
172 CinderCodeGenerator,
173 Delegator,
174 compile,
175 wrap_aug,
176 FOR_LOOP,
177)
178from .symbols import Scope, SymbolVisitor, ModuleScope, ClassScope
179from .unparse import to_expr
180from .visitor import ASTVisitor, ASTRewriter, TAst
181
182
183try:
184 import xxclassloader # pyre-ignore[21]: unknown module
185 from xxclassloader import spamobj
186except ImportError:
187 spamobj = None
188
189
def exec_static(
    source: str,
    locals: Dict[str, object],
    globals: Dict[str, object],
    modname: str = "<module>",
) -> None:
    """Compile ``source`` with the static compiler and execute the result.

    The source is compiled in "exec" mode under the fixed filename
    "<module>", with StaticCodeGenerator as the compiler and ``modname`` as
    the module name.

    NOTE(review): the two namespaces are handed to the builtin ``exec``
    positionally in declaration order, so ``locals`` lands in the position
    ``exec`` documents as globals and ``globals`` in the position it
    documents as locals. Existing callers may depend on this; confirm before
    changing the order.
    """
    static_code = compile(
        source,
        "<module>",
        "exec",
        modname=modname,
        compiler=StaticCodeGenerator,
    )
    exec(static_code, locals, globals)  # noqa: P204
200
201
# Annotation-only declarations: each line names a module-level object and
# its type without binding a value here.
# NOTE(review): presumably the values are assigned further down the file so
# that code defined before those assignments (e.g. SymbolTable.__init__) can
# refer to the names -- confirm.
CBOOL_TYPE: CIntType
INT8_TYPE: CIntType
INT16_TYPE: CIntType
INT32_TYPE: CIntType
INT64_TYPE: CIntType
INT64_VALUE: CIntInstance
SIGNED_CINT_TYPES: Sequence[CIntType]
INT_TYPE: NumClass
INT_EXACT_TYPE: NumClass
FLOAT_TYPE: NumClass
COMPLEX_TYPE: NumClass
BOOL_TYPE: Class
ARRAY_TYPE: Class
DICT_TYPE: Class
LIST_TYPE: Class
TUPLE_TYPE: Class
SET_TYPE: Class

OBJECT_TYPE: Class
OBJECT: Value

# DYNAMIC_TYPE is what an annotation that cannot be resolved falls back to
# (see TypeRef.resolved); DYNAMIC is the value bound to nodes whose type is
# not tracked statically (see Object.bind_call).
DYNAMIC_TYPE: DynamicClass
DYNAMIC: DynamicInstance
FUNCTION_TYPE: Class
METHOD_TYPE: Class
MEMBER_TYPE: Class
NONE_TYPE: Class
TYPE_TYPE: Class
ARG_TYPE: Class
SLICE_TYPE: Class

CHAR_TYPE: CIntType
DOUBLE_TYPE: CDoubleType
235
# Prefix for temporary variable names. The prefix contains "." characters,
# which cannot appear in a Python identifier, so a name built from it can
# never clash with a user-defined name.
_TMP_VAR_PREFIX = "_pystatic_.0._tmp__"
240
241
# Source-level spelling of each supported comparison operator, keyed by the
# operator's AST node class.
CMPOP_SIGILS: Mapping[Type[cmpop], str] = {
    ast.Lt: "<",
    ast.Gt: ">",
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.LtE: "<=",
    ast.GtE: ">=",
    ast.Is: "is",
    # `is not` is a distinct operator and must not be rendered as plain `is`.
    ast.IsNot: "is not",
}
252
253
def syntax_error(msg: str, filename: str, node: AST) -> TypedSyntaxError:
    """Build (but do not raise) a TypedSyntaxError positioned at ``node``.

    The error details are ``(filename, lineno, offset, source_line)``, where
    the last three items come from ``error_location(filename, node)``.
    """
    location = error_location(filename, node)
    return TypedSyntaxError(msg, (filename, *location))
257
258
def error_location(filename: str, node: AST) -> Tuple[int, int, Optional[str]]:
    """Return ``(lineno, col_offset, source_line)`` for ``node``.

    The source line is fetched through ``linecache``; when it yields an
    empty string (for example, the file cannot be read) the third element is
    ``None`` instead.
    """
    line_number = node.lineno
    text = linecache.getline(filename, line_number)
    if not text:
        text = None
    return (line_number, node.col_offset, text)
262
263
@contextmanager
def error_context(filename: str, node: AST) -> Generator[None, None, None]:
    """Add error location context to any TypedSyntaxError raised in with block.

    A TypedSyntaxError escaping the block has its ``filename`` filled in when
    it is None, and its ``lineno``/``offset``/``text`` filled in from ``node``
    when both ``lineno`` and ``offset`` are None. The exception is always
    re-raised; other exception types pass through untouched.
    """
    try:
        yield
    except TypedSyntaxError as err:
        if err.filename is None:
            err.filename = filename
        has_no_position = err.lineno is None and err.offset is None
        if has_no_position:
            err.lineno, err.offset, err.text = error_location(filename, node)
        raise
275
276
class TypeRef:
    """An annotation that has not been resolved yet.

    Keeps the annotation expression together with the module table it was
    written in, so resolution can be performed later by that module.
    """

    def __init__(self, module: ModuleTable, ref: ast.expr) -> None:
        self.ref = ref
        self.module = module

    def resolved(self, is_declaration: bool = False) -> Class:
        """Resolve the annotation through the owning module.

        ``is_declaration`` is forwarded to ``module.resolve_annotation``.
        When the module cannot resolve the annotation (it returns None),
        DYNAMIC_TYPE is returned instead.
        """
        klass = self.module.resolve_annotation(
            self.ref, is_declaration=is_declaration
        )
        return DYNAMIC_TYPE if klass is None else klass

    def __repr__(self) -> str:
        module_name = self.module.name
        dumped_ref = ast.dump(self.ref)
        return f"TypeRef({module_name}, {dumped_ref})"
293
294
class ResolvedTypeRef(TypeRef):
    """A TypeRef whose target class is already known.

    TypeRef.__init__ is not invoked, so instances carry no ``module`` or
    ``ref`` attributes; only the resolved class is stored.
    """

    def __init__(self, type: Class) -> None:
        self._resolved = type

    def resolved(self, is_declaration: bool = False) -> Class:
        """Return the stored class; ``is_declaration`` is ignored."""
        return self._resolved

    def __repr__(self) -> str:
        resolved_class = self.resolved()
        return f"ResolvedTypeRef({resolved_class})"
304
305
# Pyre doesn't support recursive generics, so we can't represent the
# recursively nested tuples that make up a type_descr. Fortunately we don't
# need to: type descrs are never parsed in Python, only generated and emitted
# as constants. So they are simply typed as `Tuple[object, ...]`.
TypeDescr = Tuple[object, ...]
311
312
class TypeName:
    """The name of a type: the module it lives in plus its own name."""

    def __init__(self, module: str, name: str) -> None:
        self.module = module
        self.name = name

    @property
    def type_descr(self) -> TypeDescr:
        """The metadata emitted into the const pool to describe a type.

        For a plain type name this is the two-element tuple
        ``(module, name)``. Subclasses may extend it (GenericTypeName appends
        a tuple holding the type_descr of each generic argument).
        """
        descr = (self.module, self.name)
        return descr

    @property
    def friendly_name(self) -> str:
        """Human-readable name.

        The module prefix is omitted when the module is empty or is one of
        "builtins", "__static__" or "typing".
        """
        unprefixed_modules = ("builtins", "__static__", "typing")
        if not self.module or self.module in unprefixed_modules:
            return self.name
        return f"{self.module}.{self.name}"
334
335
class GenericTypeName(TypeName):
    """A TypeName that additionally records the classes bound as generic
    arguments."""

    def __init__(self, module: str, name: str, args: Tuple[Class, ...]) -> None:
        super().__init__(module, name)
        self.args = args

    @property
    def type_descr(self) -> TypeDescr:
        """``(module, name, args)`` where ``args`` is a tuple holding the
        type_descr of every generic argument, in order."""
        arg_descrs = tuple(arg.type_descr for arg in self.args)
        return (self.module, self.name, arg_descrs)

    @property
    def friendly_name(self) -> str:
        """The base friendly name followed by the bracketed, comma-separated
        instance names of the generic arguments."""
        arg_names = ", ".join(arg.instance.name for arg in self.args)
        base_name = super().friendly_name
        return f"{base_name}[{arg_names}]"
352
353
# A generic instantiation is identified by the tuple of classes bound to its
# type parameters.
GenericTypeIndex = Tuple["Class", ...]
# Maps a class to a table from argument tuple to class.
# NOTE(review): presumably the outer key is the generic type definition and
# the inner values are its cached instantiations -- confirm against
# make_generic_type implementations.
GenericTypesDict = Dict["Class", Dict[GenericTypeIndex, "Class"]]
356
357
class SymbolTable:
    """Registry of ModuleTable objects keyed by module name.

    A new table is pre-populated with module tables for "builtins", "typing"
    and "__static__" (plus "xxclassloader" when SPAM_OBJ is not None) and
    holds its own copy of the built-in generic type tables.
    """

    def __init__(self) -> None:
        self.modules: Dict[str, ModuleTable] = {}
        # Names visible in every module (see ModuleTable.resolve_name, which
        # falls back to the builtins table).
        builtins_children = {
            "object": OBJECT_TYPE,
            "type": TYPE_TYPE,
            "None": NONE_TYPE.instance,
            "int": INT_EXACT_TYPE,
            "complex": COMPLEX_EXACT_TYPE,
            "str": STR_EXACT_TYPE,
            "bytes": BYTES_TYPE,
            "bool": BOOL_TYPE,
            "float": FLOAT_EXACT_TYPE,
            "len": LenFunction(FUNCTION_TYPE, boxed=True),
            "min": ExtremumFunction(FUNCTION_TYPE, is_min=True),
            "max": ExtremumFunction(FUNCTION_TYPE, is_min=False),
            "list": LIST_EXACT_TYPE,
            "tuple": TUPLE_EXACT_TYPE,
            "set": SET_EXACT_TYPE,
            "sorted": SortedFunction(FUNCTION_TYPE),
            "Exception": EXCEPTION_TYPE,
            "BaseException": BASE_EXCEPTION_TYPE,
            "isinstance": IsInstanceFunction(),
            "issubclass": IsSubclassFunction(),
            "staticmethod": STATIC_METHOD_TYPE,
            "reveal_type": RevealTypeFunction(),
        }
        # Built from builtins_children *before* the two "<...>" keys below
        # are added to that same dict.
        # NOTE(review): whether StrictBuiltins copies the dict or keeps a
        # reference is not visible here -- confirm before reordering.
        strict_builtins = StrictBuiltins(builtins_children)
        typing_children = {
            # TODO: Need typed members for dict
            "Dict": DICT_TYPE,
            "List": LIST_TYPE,
            "Final": FINAL_TYPE,
            "final": FINAL_METHOD_TYPE,
            "NamedTuple": NAMED_TUPLE_TYPE,
            "Optional": OPTIONAL_TYPE,
            "Union": UNION_TYPE,
            "Tuple": TUPLE_TYPE,
            "TYPE_CHECKING": BOOL_TYPE.instance,
        }

        builtins_children["<builtins>"] = strict_builtins
        builtins_children["<fixed-modules>"] = StrictBuiltins(
            {"typing": StrictBuiltins(typing_children)}
        )

        # Each ModuleTable receives the dict object itself (not a copy) as
        # its members.
        self.builtins = self.modules["builtins"] = ModuleTable(
            "builtins",
            "<builtins>",
            self,
            builtins_children,
        )
        self.typing = self.modules["typing"] = ModuleTable(
            "typing", "<typing>", self, typing_children
        )
        self.statics = self.modules["__static__"] = ModuleTable(
            "__static__",
            "<__static__>",
            self,
            {
                "Array": ARRAY_EXACT_TYPE,
                "CheckedDict": CHECKED_DICT_EXACT_TYPE,
                "allow_weakrefs": ALLOW_WEAKREFS_TYPE,
                "box": BoxFunction(FUNCTION_TYPE),
                "cast": CastFunction(FUNCTION_TYPE),
                "clen": LenFunction(FUNCTION_TYPE, boxed=False),
                "dynamic_return": DYNAMIC_RETURN_TYPE,
                "size_t": UINT64_TYPE,
                "ssize_t": INT64_TYPE,
                "cbool": CBOOL_TYPE,
                "inline": INLINE_TYPE,
                # This is a way to disable the static compiler for
                # individual functions/methods
                "_donotcompile": DONOTCOMPILE_TYPE,
                "int8": INT8_TYPE,
                "int16": INT16_TYPE,
                "int32": INT32_TYPE,
                "int64": INT64_TYPE,
                "uint8": UINT8_TYPE,
                "uint16": UINT16_TYPE,
                "uint32": UINT32_TYPE,
                "uint64": UINT64_TYPE,
                "char": CHAR_TYPE,
                "double": DOUBLE_TYPE,
                "unbox": UnboxFunction(FUNCTION_TYPE),
                "nonchecked_dicts": BOOL_TYPE.instance,
                "pydict": DICT_TYPE,
                "PyDict": DICT_TYPE,
                "Vector": VECTOR_TYPE,
                "RAND_MAX": NumClass(
                    TypeName("builtins", "int"), pytype=int, literal_value=RAND_MAX
                ).instance,
                "rand": reflect_builtin_function(rand),
            },
        )

        # NOTE(review): `xxclassloader` is only bound when its import at the
        # top of the file succeeded; this presumably coincides with SPAM_OBJ
        # being non-None -- confirm where SPAM_OBJ is defined.
        if SPAM_OBJ is not None:
            self.modules["xxclassloader"] = ModuleTable(
                "xxclassloader",
                "<xxclassloader>",
                self,
                {
                    "spamobj": SPAM_OBJ,
                    "XXGeneric": XX_GENERIC_TYPE,
                    "foo": reflect_builtin_function(xxclassloader.foo),
                    "bar": reflect_builtin_function(xxclassloader.bar),
                    "neg": reflect_builtin_function(xxclassloader.neg),
                },
            )

        # We need to clone the dictionaries for each type so that, as we
        # populate generic instantiations, we don't store them in the global
        # dict for built-in types
        self.generic_types: GenericTypesDict = {
            k: dict(v) for k, v in BUILTIN_GENERICS.items()
        }

    def __getitem__(self, name: str) -> ModuleTable:
        """Return the module table registered under ``name`` (KeyError if
        there is none)."""
        return self.modules[name]

    def __setitem__(self, name: str, value: ModuleTable) -> None:
        """Register ``value`` as the module table for ``name``."""
        self.modules[name] = value

    def add_module(self, name: str, filename: str, tree: AST) -> None:
        """Run the declaration pass over ``tree`` and finish its binding.

        NOTE(review): presumably DeclarationVisitor registers the resulting
        ModuleTable in ``self.modules`` under ``name`` -- confirm.
        """
        decl_visit = DeclarationVisitor(name, filename, self)
        decl_visit.visit(tree)
        decl_visit.finish_bind()

    def compile(
        self, name: str, filename: str, tree: AST, optimize: int = 0
    ) -> CodeType:
        """Compile the module ``tree`` and return the resulting code object.

        The module is first added via ``add_module`` unless ``name`` is
        already registered. The tree is then passed through AstOptimizer
        (``optimize`` is enabled when the level is greater than zero),
        analyzed for scopes and types, and emitted with StaticCodeGenerator.
        """
        if name not in self.modules:
            self.add_module(name, filename, tree)

        # The optimizer's result replaces `tree` for all later passes.
        tree = AstOptimizer(optimize=optimize > 0).visit(tree)

        # Analyze variable scopes
        s = SymbolVisitor()
        s.visit(tree)

        # Analyze the types of objects within local scopes
        type_binder = TypeBinder(s, filename, self, name, optimize)
        type_binder.visit(tree)

        # Compile the code w/ the static compiler
        graph = StaticCodeGenerator.flow_graph(
            name, filename, s.scopes[tree], peephole_enabled=True
        )
        graph.setFlag(StaticCodeGenerator.consts.CO_STATICALLY_COMPILED)

        code_gen = StaticCodeGenerator(
            None, tree, s, graph, self, name, flags=0, optimization_lvl=optimize
        )
        code_gen.visit(tree)

        return code_gen.getCode()

    def import_module(self, name: str) -> None:
        """Does nothing in this class."""
        pass
517
518
# Unconstrained generic type variable.
TType = TypeVar("TType")
520
521
class ModuleTable:
    """Compile-time record of one module.

    Holds the module's top-level members, the values bound to its AST nodes,
    and the declarations that still need to be finished once the first pass
    over the module is complete.
    """

    def __init__(
        self,
        name: str,
        filename: str,
        symtable: SymbolTable,
        members: Optional[Dict[str, Value]] = None,
    ) -> None:
        self.name = name
        self.filename = filename
        self.symtable = symtable
        # A falsy `members` (None or an empty dict) is replaced by a new dict.
        self.children: Dict[str, Value] = members or {}
        self.types: Dict[Union[AST, Delegator], Value] = {}
        self.node_data: Dict[Tuple[Union[AST, Delegator], object], object] = {}
        self.nonchecked_dicts = False
        self.noframe = False
        self.decls: List[Tuple[AST, Optional[Value]]] = []
        # TODO: final constants should be typed to literals, and
        # this should be removed in the future
        self.named_finals: Dict[str, ast.Constant] = {}
        # Functions of this module decorated with `dynamic_return`. The
        # `.args` node is stored rather than the `FunctionDef` itself,
        # because the strict modules rewriter replaces the latter between
        # the decls visit and type binding / codegen.
        self.dynamic_returns: Set[ast.AST] = set()
        # Set once the first pass over the module (which populates imports
        # and the types it defines) has completed. Resolving annotations
        # before that point is not safe.
        self.first_pass_done = False

    def finish_bind(self) -> None:
        """Mark the first pass as done and finish every pending declaration.

        Declarations with a value have ``finish_bind`` called on that value.
        Value-less ``AnnAssign`` declarations annotated as Final must assign
        a value; a constant assigned to a plain name with a non-CType Final
        is recorded in ``named_finals``. The pending list is cleared
        afterwards.
        """
        self.first_pass_done = True
        for decl_node, decl_value in self.decls:
            with error_context(self.filename, decl_node):
                if decl_value is not None:
                    decl_value.finish_bind(self)
                    continue
                if not isinstance(decl_node, ast.AnnAssign):
                    continue

                declared = self.resolve_annotation(
                    decl_node.annotation, is_declaration=True
                )
                if not isinstance(declared, FinalClass):
                    continue

                assigned = decl_node.value
                if not assigned:
                    raise TypedSyntaxError(
                        "Must assign a value when declaring a Final"
                    )

                target = decl_node.target
                is_named_constant = (
                    not isinstance(declared, CType)
                    and isinstance(target, ast.Name)
                    and isinstance(assigned, ast.Constant)
                )
                if is_named_constant:
                    self.named_finals[target.id] = assigned

        # The pending declarations are no longer needed.
        self.decls.clear()

    def resolve_type(self, node: ast.AST) -> Optional[Class]:
        """Resolve an expression naming a type, or return None."""
        # TODO handle Call
        return self._resolve(node, self.resolve_type)

    def _resolve(
        self,
        node: ast.AST,
        _resolve: typingCallable[[ast.AST], Optional[Class]],
        _resolve_subscr_target: Optional[
            typingCallable[[ast.AST], Optional[Class]]
        ] = None,
    ) -> Optional[Class]:
        """Shared resolution of names and ``X[...]`` subscripts.

        ``_resolve`` is the callback used for nested expressions (generic
        arguments); ``_resolve_subscr_target``, when given, is used instead
        for the subscripted value itself. Generic arguments that fail to
        resolve become DYNAMIC_TYPE. Returns None for any other node shape.
        """
        if isinstance(node, ast.Name):
            named = self.resolve_name(node.id)
            return named if isinstance(named, Class) else None

        if not isinstance(node, Subscript):
            # TODO handle Attribute
            return None

        subscript = node.slice
        if not isinstance(subscript, Index):
            return None

        resolve_target = _resolve_subscr_target or _resolve
        target = resolve_target(node.value)
        if target is None:
            return None

        # `X[A, B]` carries a tuple of arguments, `X[A]` a single one.
        arg_expr = subscript.value
        if isinstance(arg_expr, ast.Tuple):
            arg_exprs = arg_expr.elts
        else:
            arg_exprs = [arg_expr]
        type_args = tuple(_resolve(elt) or DYNAMIC_TYPE for elt in arg_exprs)

        generic = target.make_generic_type(type_args, self.symtable.generic_types)
        return generic or target

    def resolve_annotation(
        self,
        node: ast.AST,
        is_declaration: bool = False,
    ) -> Optional[Class]:
        """Resolve a type annotation to the class it denotes, or None.

        Must only be called after the first pass has completed. A Final
        annotation outside of a declaration raises TypedSyntaxError. The
        result is passed through ``inexact_type``.
        """
        assert self.first_pass_done, (
            "Type annotations cannot be resolved until after initial pass, "
            "so that all imports and types are available."
        )

        with error_context(self.filename, node):
            klass = self._resolve_annotation(node)

            if not is_declaration and isinstance(klass, FinalClass):
                raise TypedSyntaxError(
                    "Final annotation is only valid in initial declaration "
                    "of attribute or module-level constant",
                )

            # TODO until we support runtime checking of unions, we must for
            # safety resolve union annotations to dynamic (except for
            # optionals, which we can check at runtime)
            unchecked_union = (
                isinstance(klass, UnionType)
                and klass is not UNION_TYPE
                and klass is not OPTIONAL_TYPE
                and klass.opt_type is None
            )
            if unchecked_union:
                return None

            if not klass:
                return None
            # Even when e.g. `builtins.str` is known to be exactly `str`,
            # annotating `x: str` must not exclude subclasses.
            return inexact_type(klass)

    def _resolve_annotation(self, node: ast.AST) -> Optional[Class]:
        """Resolve annotation-specific forms on top of ``_resolve``.

        Handles string annotations (parsed and resolved recursively),
        ``None``, and ``A | B`` unions.
        """
        # Try the forms that are not annotation-specific first. The outer
        # target of a subscript (e.g. `Final` in `Final[int]`) is resolved
        # with `is_declaration=True` so that `Final` is allowed there; if
        # this is not actually a declaration, the calling
        # `resolve_annotation` still rejects the generic Final returned here.
        generic = self._resolve(
            node,
            self.resolve_annotation,
            _resolve_subscr_target=partial(
                self.resolve_annotation, is_declaration=True
            ),
        )
        if generic:
            return generic

        if isinstance(node, ast.Str):
            # pyre-ignore[16]: `AST` has no attribute `body`.
            return self.resolve_annotation(ast.parse(node.s, "", "eval").body)

        if isinstance(node, ast.Constant):
            literal = node.value
            if literal is None:
                return NONE_TYPE
            if isinstance(literal, str):
                return self.resolve_annotation(ast.parse(node.value, "", "eval").body)
            return None

        if isinstance(node, NameConstant) and node.value is None:
            return NONE_TYPE

        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            left_type = self.resolve_annotation(node.left)
            right_type = self.resolve_annotation(node.right)
            if left_type is None or right_type is None:
                return None
            return UNION_TYPE.make_generic_type(
                (left_type, right_type), self.symtable.generic_types
            )

        return None

    def resolve_name(self, name: str) -> Optional[Value]:
        """Look ``name`` up in this module, falling back to the builtins
        module table when the module has no (truthy) entry for it."""
        own = self.children.get(name)
        return own or self.symtable.builtins.children.get(name)

    def get_final_literal(self, node: AST, scope: Scope) -> Optional[ast.Constant]:
        """Return the constant of a module-level Final that ``node`` loads.

        Returns None unless ``node`` is a Name in Load context that is
        recorded in ``named_finals`` and is not shadowed by a definition in
        a non-module ``scope``.
        """
        if not isinstance(node, Name):
            return None

        literal = self.named_finals.get(node.id, None)
        if literal is None or not isinstance(node.ctx, ast.Load):
            return None

        # A definition of the same name in a local scope shadows the Final.
        shadowed = not isinstance(scope, ModuleScope) and node.id in scope.defs
        return None if shadowed else literal
704
705
# Type variables bounded by `Class`. `TClass` is declared covariant and
# parameterizes `Object`; `TClassInv` has no variance flag (invariant) and is
# used in method signatures such as `Value.bind_descr_get`.
TClass = TypeVar("TClass", bound="Class", covariant=True)
TClassInv = TypeVar("TClassInv", bound="Class")
708
709
class Value:
    """Base class for everything the compiler tracks at compile time.

    The ``bind_*`` hooks take part in type binding and, in this base class,
    mostly reject the operation with a syntax error built by the visitor.
    The ``emit_*`` hooks take part in code generation and mostly defer to the
    code generator's default behavior.
    """

    def __init__(self, klass: Class) -> None:
        """klass: the Class of this object"""
        self.klass = klass

    @property
    def name(self) -> str:
        """Name used in error messages; here, the Python class name."""
        return type(self).__name__

    def finish_bind(self, module: ModuleTable) -> None:
        """Hook called from ModuleTable.finish_bind; nothing to do here."""
        return None

    def make_generic_type(
        self, index: GenericTypeIndex, generic_types: GenericTypesDict
    ) -> Optional[Class]:
        """Bind generic arguments; the base value is not generic (None)."""
        return None

    def get_iter_type(self, node: ast.expr, visitor: TypeBinder) -> Value:
        """returns the type that is produced when iterating over this value"""
        raise visitor.syntax_error(f"cannot iterate over {self.name}", node)

    def as_oparg(self) -> int:
        """Integer oparg describing this value; unsupported by default."""
        raise TypeError(f"{self.name} not valid here")

    def bind_attr(
        self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind an attribute load on this value; rejected by default."""
        raise visitor.syntax_error(f"cannot load attribute from {self.name}", node)

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a call of this value; rejected by default."""
        raise visitor.syntax_error(f"cannot call {self.name}", node)

    def check_args_for_primitives(self, node: ast.Call, visitor: TypeBinder) -> None:
        """Raise a syntax error if any positional or keyword argument of the
        call is bound to a CInstance. Positional arguments are checked
        first."""
        for positional in node.args:
            if isinstance(visitor.get_type(positional), CInstance):
                raise visitor.syntax_error(
                    "Call argument cannot be a primitive", positional
                )
        for keyword in node.keywords:
            kw_value = keyword.value
            if isinstance(visitor.get_type(kw_value), CInstance):
                raise visitor.syntax_error(
                    "Call argument cannot be a primitive", kw_value
                )

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Bind access to this value as a class member; rejected by default."""
        raise visitor.syntax_error(f"cannot get descriptor {self.name}", node)

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Optional[Value]:
        """Result of using this value as a function decorator; None here."""
        return None

    def bind_decorate_class(self, klass: Class) -> Class:
        """Result of using this value as a class decorator; DYNAMIC_TYPE
        here."""
        return DYNAMIC_TYPE

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Bind a subscript of this value; rejected by default."""
        raise visitor.syntax_error(f"cannot index {self.name}", node)

    def emit_subscr(
        self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a subscript using the generator's default visit."""
        code_gen.defaultVisit(node, aug_flag)

    def emit_store_subscr(
        self, node: ast.Subscript, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit ROT_THREE followed by STORE_SUBSCR."""
        for opname in ("ROT_THREE", "STORE_SUBSCR"):
            code_gen.emit(opname)

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit a call using the generator's default visit."""
        code_gen.defaultVisit(node)

    def emit_attr(
        self, node: Union[ast.Attribute, AugAttribute], code_gen: Static38CodeGenerator
    ) -> None:
        """Emit STORE_ATTR, DELETE_ATTR or LOAD_ATTR (chosen by the node's
        context) with the mangled attribute name."""
        ctx = node.ctx
        if isinstance(ctx, ast.Store):
            opname = "STORE_ATTR"
        elif isinstance(ctx, ast.Del):
            opname = "DELETE_ATTR"
        else:
            opname = "LOAD_ATTR"
        code_gen.emit(opname, code_gen.mangle(node.attr))

    def bind_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind a comparison with this value on the left; rejected by
        default."""
        raise visitor.syntax_error(f"cannot compare with {self.name}", node)

    def bind_reverse_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind a comparison with this value on the right; rejected by
        default."""
        raise visitor.syntax_error(f"cannot reverse with {self.name}", node)

    def emit_compare(self, op: cmpop, code_gen: Static38CodeGenerator) -> None:
        """Emit a comparison using the generator's default."""
        code_gen.defaultEmitCompare(op)

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind a binary operation; rejected by default."""
        raise visitor.syntax_error(f"cannot bin op with {self.name}", node)

    def bind_reverse_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind a reflected binary operation; rejected by default."""
        raise visitor.syntax_error(f"cannot reverse bin op with {self.name}", node)

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind a unary operation; rejected by default."""
        raise visitor.syntax_error(f"cannot reverse unary op with {self.name}", node)

    def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a binary operation using the generator's default visit."""
        code_gen.defaultVisit(node)

    def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
        """Emit a generic iterator-driven ``for`` loop.

        Layout: GET_ITER on the iterable, then a block that runs FOR_ITER,
        stores the target, runs the body and jumps back; the anchor block
        (FOR_ITER's jump target) pops the loop and runs the ``else`` body, if
        any, before the final block.
        """
        loop_start = code_gen.newBlock("default_forloop_start")
        loop_anchor = code_gen.newBlock("default_forloop_anchor")
        loop_after = code_gen.newBlock("default_forloop_after")

        code_gen.set_lineno(node)
        code_gen.push_loop(FOR_LOOP, loop_start, loop_after)
        code_gen.visit(node.iter)
        code_gen.emit("GET_ITER")

        code_gen.nextBlock(loop_start)
        code_gen.emit("FOR_ITER", loop_anchor)
        code_gen.visit(node.target)
        code_gen.visit(node.body)
        code_gen.emit("JUMP_ABSOLUTE", loop_start)

        code_gen.nextBlock(loop_anchor)
        code_gen.pop_loop()
        if node.orelse:
            code_gen.visit(node.orelse)

        code_gen.nextBlock(loop_after)

    def emit_unaryop(self, node: ast.UnaryOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a unary operation using the generator's default visit."""
        code_gen.defaultVisit(node)

    def emit_augassign(
        self, node: ast.AugAssign, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit an augmented assignment using the generator's default
        visit."""
        code_gen.defaultVisit(node)

    def emit_augname(
        self, node: AugName, code_gen: Static38CodeGenerator, mode: str
    ) -> None:
        """Emit an augmented-assignment name using the generator's default
        visit."""
        code_gen.defaultVisit(node, mode)

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Bind a constant against this value; rejected by default."""
        raise visitor.syntax_error(f"cannot constant with {self.name}", node)

    def emit_constant(
        self, node: ast.Constant, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a constant; returns whatever the default visit returns."""
        return code_gen.defaultVisit(node)

    def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
        """Emit a name; returns whatever the default visit returns."""
        return code_gen.defaultVisit(node)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a conditional jump via CinderCodeGenerator.compileJumpIf."""
        CinderCodeGenerator.compileJumpIf(code_gen, test, next, is_if_true)

    def emit_jumpif_pop(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a conditional jump via CinderCodeGenerator.compileJumpIfPop."""
        CinderCodeGenerator.compileJumpIfPop(code_gen, test, next, is_if_true)

    def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
        """Boxing is unsupported by default (RuntimeError)."""
        unsupported = code_gen.get_type(node)
        raise RuntimeError(f"Unsupported box type: {unsupported}")

    def emit_unbox(self, node: expr, code_gen: Static38CodeGenerator) -> None:
        """Unboxing is unsupported by default (RuntimeError)."""
        raise RuntimeError("Unsupported unbox type")

    def get_fast_len_type(self) -> Optional[int]:
        """Fast-len oparg for this value, or None when there is none."""
        return None

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit a length computation.

        The boxed form falls back to an ordinary call; the unboxed form is
        unsupported by default (RuntimeError).
        """
        if boxed:
            return self.emit_call(node, code_gen)
        raise RuntimeError("Unsupported type for clen()")

    def make_generic(
        self, new_type: Class, name: GenericTypeName, generic_types: GenericTypesDict
    ) -> Value:
        """Return the version of this value for a generic instantiation; the
        base value is returned unchanged."""
        return self

    def emit_convert(self, to_type: Value, code_gen: Static38CodeGenerator) -> None:
        """Emit a conversion to ``to_type``; nothing is emitted by default."""
        return None
929
930
class Object(Value, Generic[TClass]):
    """Represents an instance of a type at compile time.

    Operations the compiler has no specific knowledge about are bound as
    dynamic rather than rejected (contrast with the Value base class).
    """

    klass: TClass

    @property
    def name(self) -> str:
        """The instance name of this object's class."""
        return self.klass.instance_name

    def as_oparg(self) -> int:
        """Objects are described by the TYPED_OBJECT oparg."""
        return TYPED_OBJECT

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a call on an arbitrary object as dynamic.

        The call node is typed DYNAMIC, every positional and keyword argument
        expression is visited, and primitive (CInstance) arguments are
        rejected. Returns NO_EFFECT.
        """
        visitor.set_type(node, DYNAMIC)
        for arg in node.args:
            visitor.visit(arg)

        for arg in node.keywords:
            visitor.visit(arg.value)
        self.check_args_for_primitives(node, visitor)
        return NO_EFFECT

    def bind_attr(
        self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind an attribute load on an instance.

        The first class in the MRO that declares the attribute binds it
        through its ``bind_descr_get``. Otherwise ``__class__`` is typed as
        this object's class and any other attribute as DYNAMIC.
        """
        for base in self.klass.mro:
            member = base.members.get(node.attr)
            if member is not None:
                member.bind_descr_get(node, self, self.klass, visitor, type_ctx)
                return

        if node.attr == "__class__":
            visitor.set_type(node, self.klass)
        else:
            visitor.set_type(node, DYNAMIC)

    def emit_attr(
        self, node: Union[ast.Attribute, AugAttribute], code_gen: Static38CodeGenerator
    ) -> None:
        """Emit an attribute access, using field opcodes for slots.

        The MRO is searched for a Slot member named ``node.attr``. For the
        first one found, stores and loads are emitted as STORE_FIELD /
        LOAD_FIELD with the container's type descr extended by the slot name;
        deletes are emitted as DELETE_ATTR. If no slot is found, the generic
        Value.emit_attr is used.
        """
        for base in self.klass.mro:
            member = base.members.get(node.attr)
            if isinstance(member, Slot):
                type_descr = member.container_type.type_descr
                type_descr += (member.slot_name,)
                if isinstance(node.ctx, ast.Store):
                    code_gen.emit("STORE_FIELD", type_descr)
                elif isinstance(node.ctx, ast.Del):
                    # Mangle the name exactly like the generic attribute path
                    # (Value.emit_attr) does, so that deleting a private
                    # attribute (`del self.__x`) targets the mangled name.
                    code_gen.emit("DELETE_ATTR", code_gen.mangle(node.attr))
                else:
                    code_gen.emit("LOAD_FIELD", type_descr)
                return

        super().emit_attr(node, code_gen)

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClass]],
        ctx: Class,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Accessing an arbitrary object as a class member yields DYNAMIC."""
        visitor.set_type(node, DYNAMIC)

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Bind a subscript: the index's class must be assignable to
        DYNAMIC_TYPE (per the visitor's check), and the result is DYNAMIC."""
        visitor.check_can_assign_from(DYNAMIC_TYPE, type.klass, node)
        visitor.set_type(node, DYNAMIC)

    def bind_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Type both the operator and the comparison as DYNAMIC; returns
        False."""
        visitor.set_type(op, DYNAMIC)
        visitor.set_type(node, DYNAMIC)
        return False

    def bind_reverse_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Same as bind_compare, for this object on the right-hand side."""
        visitor.set_type(op, DYNAMIC)
        visitor.set_type(node, DYNAMIC)
        return False

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Return False without assigning a type to the node."""
        return False

    def bind_reverse_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type the node as DYNAMIC and return False."""
        # we'll set the type in case we're the only one called
        visitor.set_type(node, DYNAMIC)
        return False

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """``not x`` is typed as a bool instance; every other unary operator
        yields DYNAMIC."""
        if isinstance(node.op, ast.Not):
            visitor.set_type(node, BOOL_TYPE.instance)
        else:
            visitor.set_type(node, DYNAMIC)

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Type a constant by the Python type of its value (looked up in
        CONSTANT_TYPES) and check that it is assignable to this object's
        class."""
        node_type = CONSTANT_TYPES[type(node.value)]
        visitor.set_type(node, node_type)
        visitor.check_can_assign_from(self.klass, node_type.klass, node)

    def get_iter_type(self, node: ast.expr, visitor: TypeBinder) -> Value:
        """returns the type that is produced when iterating over this value"""
        return DYNAMIC

    def __repr__(self) -> str:
        return f"<{self.name}>"
1060
1061
class Class(Object["Class"]):
    """Compile-time model of a type object.

    Holds the type's name, its base classes, the value that stands for its
    instances and the members declared on it.
    """

    # subclasses set this to True to keep instance_name from adding Exact[...]
    suppress_exact = False

    def __init__(
        self,
        type_name: TypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[Value] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        is_exact: bool = False,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Create the class description.

        ``klass`` falls back to TYPE_TYPE, ``instance`` to a fresh ``Object``
        for this class, ``bases`` to an empty list and ``members`` to an
        empty dict. When ``pytype`` is given the members are extended with
        the result of ``make_type_dict(self, pytype)``.
        """
        super().__init__(klass or TYPE_TYPE)
        assert bases is None or isinstance(bases, list)
        self.type_name = type_name
        self.instance: Value = instance or Object(self)
        self.bases: List[Class] = bases or []
        # both caches are filled lazily by the mro / mro_type_descrs properties
        self._mro: Optional[List[Class]] = None
        self._mro_type_descrs: Optional[Set[TypeDescr]] = None
        self.members: Dict[str, Value] = members or {}
        self.is_exact = is_exact
        self.is_final = False
        self.allow_weakrefs = False
        self.donotcompile = False
        if pytype:
            self.members.update(make_type_dict(self, pytype))
        # slot redefinitions attempted while the type is being declared;
        # they are checked and cleared in finish_bind
        self._slot_redefs: Dict[str, List[TypeRef]] = {}

    @property
    def name(self) -> str:
        """Display name of the type object itself: ``Type[<instance name>]``."""
        return f"Type[{self.instance_name}]"

    @property
    def instance_name(self) -> str:
        """Display name of instances; wrapped in ``Exact[...]`` for exact,
        non-suppressed classes."""
        name = self.qualname
        if not self.is_exact or self.suppress_exact:
            return name
        return f"Exact[{name}]"

    @property
    def qualname(self) -> str:
        """Friendly name taken from the type name."""
        return self.type_name.friendly_name

    @property
    def is_generic_parameter(self) -> bool:
        """True when this Class stands for a generic parameter (never, here)."""
        return False

    @property
    def contains_generic_parameters(self) -> bool:
        """True when this class contains generic parameters (never, here)."""
        return False

    @property
    def is_generic_type(self) -> bool:
        """True when this class is a generic type (never, here)."""
        return False

    @property
    def is_generic_type_definition(self) -> bool:
        """True when this class is a generic type definition, i.e. a generic
        type that still has unbound type parameters (never, here)."""
        return False

    @property
    def generic_type_def(self) -> Optional[Class]:
        """The generic type definition this class came from (none, here)."""
        return None

    def make_generic_type(
        self,
        index: Tuple[Class, ...],
        generic_types: GenericTypesDict,
    ) -> Optional[Class]:
        """Bind generic parameters to a generic type definition.

        A plain class has nothing to bind, so the answer is None.
        """
        return None

    def bind_attr(
        self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind ``cls.attr`` through the first class in the MRO declaring it.

        When no class in the MRO declares the attribute the inherited
        binding is used instead.
        """
        for ancestor in self.mro:
            found = ancestor.members.get(node.attr)
            if found is None:
                continue
            found.bind_descr_get(node, None, self, visitor, type_ctx)
            return

        super().bind_attr(node, visitor, type_ctx)

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind ``cls | other`` as a union type; defer every other operator.

        A right operand typed as the None instance or as dynamic is first
        replaced by the corresponding class. Raises a syntax error when the
        right operand is still not a class.
        """
        if not isinstance(node.op, ast.BitOr):
            return super().bind_binop(node, visitor, type_ctx)

        rtype = visitor.get_type(node.right)
        if rtype is NONE_TYPE.instance:
            rtype = NONE_TYPE
        if rtype is DYNAMIC:
            rtype = DYNAMIC_TYPE
        if not isinstance(rtype, Class):
            raise visitor.syntax_error(
                f"unsupported operand type(s) for |: {self.name} and {rtype.name}",
                node,
            )
        members_of_union = (self, rtype)
        visitor.set_type(
            node,
            UNION_TYPE.make_generic_type(
                members_of_union, visitor.symtable.generic_types
            ),
        )
        return True

    @property
    def can_be_narrowed(self) -> bool:
        return True

    @property
    def type_descr(self) -> TypeDescr:
        """Type descriptor taken from the type name."""
        return self.type_name.type_descr

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind ``cls(...)``: the call produces an instance of this class.

        Every positional argument and keyword value is visited, then the
        arguments are checked for primitives.
        """
        visitor.set_type(node, self.instance)
        for positional in node.args:
            visitor.visit(positional)
        for keyword in node.keywords:
            visitor.visit(keyword.value)
        self.check_args_for_primitives(node, visitor)
        return NO_EFFECT

    def can_assign_from(self, src: Class) -> bool:
        """Tell whether a value of class ``src`` may be assigned to this class.

        The class itself is always accepted. Otherwise an exact class
        accepts nothing else, primitives (``CType``) are rejected, and the
        remaining answer comes from ``issubclass``.
        """
        if src is self:
            return True
        if self.is_exact or isinstance(src, CType):
            return False
        return self.issubclass(src)

    def __repr__(self) -> str:
        return f"<{self.name} class>"

    def isinstance(self, src: Value) -> bool:
        """Same test as ``issubclass``, applied to the class of ``src``."""
        return self.issubclass(src.klass)

    def issubclass(self, src: Class) -> bool:
        """True when this class's type descriptor appears in ``src``'s MRO."""
        return self.type_descr in src.mro_type_descrs

    def finish_bind(self, module: ModuleTable) -> None:
        """Validate the declared members once declaration is complete.

        Raises TypedSyntaxError for conflicting slot redefinitions, for
        members hiding an inherited member of a different kind, for
        overrides of final functions and for final slots without an
        assignment. Slots that are also declared by a base class are
        removed from this class's own members.
        """
        # every recorded redefinition has to agree with the declared type
        for name, new_type_refs in self._slot_redefs.items():
            cur_slot = self.members[name]
            assert isinstance(cur_slot, Slot)
            declared = cur_slot.decl_type
            for type_ref in new_type_refs:
                if type_ref.resolved() != declared:
                    raise TypedSyntaxError(
                        f"conflicting type definitions for slot {name} in {self.name}"
                    )
        self._slot_redefs = {}

        inherited = set()
        for name, my_value in self.members.items():
            for base in self.mro[1:]:
                value = base.members.get(name)
                if value is not None and type(my_value) != type(value):
                    # TODO: There's more checking we should be doing to ensure
                    # this is a compatible override
                    raise TypedSyntaxError(
                        f"class cannot hide inherited member: {value!r}"
                    )
                if isinstance(value, Slot):
                    inherited.add(name)
                elif isinstance(value, (Function, StaticMethod)) and value.is_final:
                    raise TypedSyntaxError(
                        f"Cannot assign to a Final attribute of {self.instance.name}:{name}"
                    )
                # NOTE(review): the indentation of this check was lost in the
                # listing this block was restored from; it is assumed to sit
                # inside the per-base loop -- confirm against upstream.
                final_without_value = (
                    isinstance(my_value, Slot)
                    and my_value.is_final
                    and not my_value.assignment
                )
                if final_without_value:
                    raise TypedSyntaxError(
                        f"Final attribute not initialized: {self.instance.name}:{name}"
                    )

        # drop the slots a base class already declares
        for name in inherited:
            assert type(self.members[name]) is Slot
            del self.members[name]

    def define_slot(
        self,
        name: str,
        type_ref: Optional[TypeRef] = None,
        assignment: Optional[AST] = None,
    ) -> None:
        """Declare a slot called ``name``.

        A new slot is created (dynamic when no ``type_ref`` is given). For an
        existing slot, a missing assignment is filled in and a supplied
        ``type_ref`` is recorded for the check in finish_bind. Raises
        TypedSyntaxError when the name belongs to a non-slot member.
        """
        existing = self.members.get(name)
        if existing is None:
            slot_type = type_ref or ResolvedTypeRef(DYNAMIC_TYPE)
            self.members[name] = Slot(slot_type, name, self, assignment)
            return

        if not isinstance(existing, Slot):
            raise TypedSyntaxError(
                f"slot conflicts with other member {name} in {self.name}"
            )

        if not existing.assignment:
            existing.assignment = assignment
        if type_ref is not None:
            self._slot_redefs.setdefault(name, []).append(type_ref)

    def define_function(
        self,
        name: str,
        func: Function | StaticMethod,
        visitor: DeclarationVisitor,
    ) -> None:
        """Register ``func`` as member ``name`` and make this class its container.

        Raises TypedSyntaxError when the name is already taken. ``visitor``
        is accepted but not used here.
        """
        if name in self.members:
            raise TypedSyntaxError(
                f"function conflicts with other member {name} in {self.name}"
            )

        func.set_container_type(self)
        self.members[name] = func

    @property
    def mro(self) -> Sequence[Class]:
        """Method resolution order, computed on first use and cached.

        The order is empty when any entry of ``bases`` is falsy.
        """
        cached = self._mro
        if cached is not None:
            return cached

        # TODO: We can't compile w/ unknown bases
        computed = _mro(self) if all(self.bases) else []
        self._mro = computed
        return computed

    @property
    def mro_type_descrs(self) -> Collection[TypeDescr]:
        """Type descriptors of every class in the MRO, computed once."""
        descrs = self._mro_type_descrs
        if descrs is None:
            descrs = {klass.type_descr for klass in self.mro}
            self._mro_type_descrs = descrs
        return descrs

    def bind_generics(
        self,
        name: GenericTypeName,
        generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
    ) -> Class:
        """A non-generic class has nothing to bind and returns itself."""
        return self

    def get_own_member(self, name: str) -> Optional[Value]:
        """Look ``name`` up among the members declared on this class only."""
        return self.members.get(name)

    def get_parent_member(self, name: str) -> Optional[Value]:
        """Return the first truthy member called ``name`` found on the bases."""
        # the first entry of mro is the class itself
        for ancestor in self.mro[1:]:
            candidate = ancestor.members.get(name, None)
            if candidate:
                return candidate
        return None

    def get_member(self, name: str) -> Optional[Value]:
        """Look ``name`` up on this class first, then on its bases."""
        return self.get_own_member(name) or self.get_parent_member(name)
1336
1337
class GenericClass(Class):
    """A class that has generic type arguments."""

    type_name: GenericTypeName
    # when True, bind_subscr skips its checks on the number of type arguments
    is_variadic = False

    def __init__(
        self,
        name: GenericTypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[Object[Class]] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        type_def: Optional[GenericClass] = None,
        is_exact: bool = False,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Create the class; ``type_def`` is the definition it was made from,
        or None when this class is itself the definition."""
        super().__init__(name, bases, instance, klass, members, is_exact, pytype)
        self.gen_name = name
        self.type_def = type_def

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a constructor call; refused while type parameters are unbound."""
        if not self.contains_generic_parameters:
            return super().bind_call(node, visitor, type_ctx)
        raise visitor.syntax_error(
            f"cannot create instances of a generic {self.name}", node
        )

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Bind ``Generic[args]`` to the matching instantiation.

        The node is typed as dynamic when any argument cannot be resolved
        as an annotation. Raises a syntax error for slices and, unless the
        class is variadic, for a wrong number of arguments.
        """
        index_node = node.slice
        if not isinstance(index_node, ast.Index):
            raise visitor.syntax_error("can't slice generic types", node)

        visitor.visit(node.slice)
        val = index_node.value
        fixed_arity = not self.is_variadic

        if isinstance(val, ast.Tuple):
            resolved: List[Class] = []
            for elt in val.elts:
                elt_class = visitor.cur_mod.resolve_annotation(elt)
                if elt_class is None:
                    visitor.set_type(node, DYNAMIC)
                    return
                resolved.append(elt_class)

            index = tuple(resolved)
            # the count is only checked once every argument has resolved
            if fixed_arity and len(val.elts) != len(self.gen_name.args):
                raise visitor.syntax_error(
                    "incorrect number of generic arguments", node
                )
        else:
            if fixed_arity and len(self.gen_name.args) != 1:
                raise visitor.syntax_error(
                    "incorrect number of generic arguments", node
                )

            only = visitor.cur_mod.resolve_annotation(val)
            if only is None:
                visitor.set_type(node, DYNAMIC)
                return

            index = (only,)

        bound = self.make_generic_type(index, visitor.symtable.generic_types)
        visitor.set_type(node, bound)

    @property
    def type_args(self) -> Sequence[Class]:
        """The type arguments carried by the type name."""
        return self.type_name.args

    @property
    def contains_generic_parameters(self) -> bool:
        """True when at least one type argument is a generic parameter."""
        return any(arg.is_generic_parameter for arg in self.gen_name.args)

    @property
    def is_generic_type(self) -> bool:
        return True

    @property
    def is_generic_type_definition(self) -> bool:
        """True when this class was not made from another definition."""
        return self.type_def is None

    @property
    def generic_type_def(self) -> Optional[Class]:
        """The generic type definition this class was made from, if any."""
        return self.type_def

    def make_generic_type(
        self,
        index: Tuple[Class, ...],
        generic_types: GenericTypesDict,
    ) -> Class:
        """Return the instantiation of this class for the arguments ``index``.

        Instantiations are cached in ``generic_types`` per (class, index).
        A new one gets the bases (bound where they contain generic
        parameters), a copy of the instance value and the members of this
        class made generic for the new type name.
        """
        instantiations = generic_types.get(self)
        if instantiations is None:
            generic_types[self] = instantiations = {}
        else:
            cached = instantiations.get(index)
            if cached is not None:
                return cached

        type_name = GenericTypeName(
            self.type_name.module, self.type_name.name, index
        )

        # bases that bind to nothing are dropped
        bases: List[Class] = []
        for base in self.bases:
            if base.contains_generic_parameters:
                bound_base = base.make_generic_type(index, generic_types)
            else:
                bound_base = base
            if bound_base is not None:
                bases.append(bound_base)

        # copy the instance value without running its __init__
        instance_cls = type(self.instance)
        instance = instance_cls.__new__(instance_cls)
        instance.__dict__.update(self.instance.__dict__)

        concrete = type(self)(
            type_name,
            bases,
            instance,
            self.klass,
            {},
            is_exact=self.is_exact,
            type_def=self,
        )
        instance.klass = concrete

        # registered before the members are bound
        instantiations[index] = concrete
        bound_members = {
            member_name: member.make_generic(concrete, type_name, generic_types)
            for member_name, member in self.members.items()
        }
        concrete.members.update(bound_members)
        return concrete

    def bind_generics(
        self,
        name: GenericTypeName,
        generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
    ) -> Class:
        """Substitute the arguments of ``name`` for this class's parameters.

        Returns the class unchanged when it has no generic parameters.
        """
        if not self.contains_generic_parameters:
            return self

        type_args = [
            arg for arg in self.type_name.args if isinstance(arg, GenericParameter)
        ]
        assert len(type_args) == len(self.type_name.args)
        # map the generic type parameters for the type to the parameters provided.
        # We don't yet support generic methods, so all of the generic parameters
        # are coming from the type definition.
        bind_args = tuple(name.args[arg.index] for arg in type_args)
        return self.make_generic_type(bind_args, generic_types)
1498
1499
class GenericParameter(Class):
    """A type parameter of a generic class, identified by its position."""

    def __init__(self, name: str, index: int) -> None:
        """Create the parameter; ``index`` is its position among the type
        arguments it is later replaced from."""
        super().__init__(TypeName("", name), [], None, None, {})
        self.index = index

    @property
    def name(self) -> str:
        """The bare parameter name."""
        return self.type_name.name

    @property
    def is_generic_parameter(self) -> bool:
        return True

    def bind_generics(
        self,
        name: GenericTypeName,
        generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
    ) -> Class:
        """Return the type argument of ``name`` found at this parameter's index."""
        supplied = name.args
        return supplied[self.index]
1519
1520
class CType(Class):
    """Base class for primitives that aren't heap allocated."""

    suppress_exact = True

    def __init__(
        self,
        type_name: TypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[CInstance[Class]] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        is_exact: bool = True,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Same as the base constructor, except ``is_exact`` defaults to True."""
        super().__init__(
            type_name,
            bases=bases,
            instance=instance,
            klass=klass,
            members=members,
            is_exact=is_exact,
            pytype=pytype,
        )

    @property
    def can_be_narrowed(self) -> bool:
        return False

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind an explicit conversion such as ``int32(int8(5))``.

        Close to the base class method, but each positional argument is
        visited with this type's instance as its expected type, which lets
        arguments be primitives.
        """
        target = self.instance
        visitor.set_type(node, target)
        for positional in node.args:
            visitor.visit(positional, self.instance)
        return NO_EFFECT
1554
1555
class DynamicClass(Class):
    """The class of dynamically typed values."""

    instance: DynamicInstance

    def __init__(self) -> None:
        # any references to dynamic at runtime are object
        runtime_name = TypeName("builtins", "object")
        super().__init__(
            runtime_name,
            bases=[OBJECT_TYPE],
            instance=DynamicInstance(self),
        )

    @property
    def qualname(self) -> str:
        return "dynamic"

    def can_assign_from(self, src: Class) -> bool:
        """Accept everything except primitives: there is no automatic boxing
        to the dynamic type."""
        is_primitive = isinstance(src, CType)
        return not is_primitive
1574
1575
class DynamicInstance(Object[DynamicClass]):
    """A value whose type is dynamic."""

    def __init__(self, klass: DynamicClass) -> None:
        super().__init__(klass)

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Type a constant from CONSTANT_TYPES, falling back to the dynamic
        instance for value types the table does not list."""
        fallback = DYNAMIC_TYPE.instance
        visitor.set_type(node, CONSTANT_TYPES.get(type(node.value), fallback))

    def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a binary op, using the default visit unless
        ``maybe_emit_sequence_repeat`` already handled the node."""
        if not maybe_emit_sequence_repeat(node, code_gen):
            code_gen.defaultVisit(node)
1589
1590
class NoneType(Class):
    """The class of ``None``: exact, named ``builtins.None``, based on object."""

    suppress_exact = True

    def __init__(self) -> None:
        none_name = TypeName("builtins", "None")
        super().__init__(
            none_name,
            bases=[OBJECT_TYPE],
            instance=NoneInstance(self),
            is_exact=True,
        )
1601
1602
# Source-level symbol for each unary operator, used when building error
# messages. `not` (ast.Not) is deliberately absent from this table.
UNARY_SYMBOLS: Mapping[Type[ast.unaryop], str] = {
    ast.UAdd: "+",
    ast.USub: "-",
    ast.Invert: "~",
}
1608
1609
class NoneInstance(Object[NoneType]):
    """The value ``None``; most operations on it are compile-time errors."""

    def bind_attr(
        self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Attribute access on None is always a syntax error."""
        raise visitor.syntax_error(
            f"'NoneType' object has no attribute '{node.attr}'", node
        )

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Calling None is always a syntax error."""
        raise visitor.syntax_error("'NoneType' object is not callable", node)

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Subscripting None is always a syntax error."""
        raise visitor.syntax_error("'NoneType' object is not subscriptable", node)

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """``not None`` is a bool; every other unary operator is an error."""
        if isinstance(node.op, ast.Not):
            visitor.set_type(node, BOOL_TYPE.instance)
            return
        raise visitor.syntax_error(
            f"bad operand type for unary {UNARY_SYMBOLS[type(node.op)]}: 'NoneType'",
            node,
        )

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind a binary op whose left operand is None."""
        # support `None | int` as a union type; None is special in that it is
        # not a type but can be used synonymously with NoneType for typing.
        if isinstance(node.op, ast.BitOr):
            return self.klass.bind_binop(node, visitor, type_ctx)
        return super().bind_binop(node, visitor, type_ctx)

    def bind_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Allow ``==``, ``!=``, ``is`` and ``is not``; reject other comparisons."""
        if not isinstance(op, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
            ltype = visitor.get_type(left)
            rtype = visitor.get_type(right)
            raise visitor.syntax_error(
                f"'{CMPOP_SIGILS[type(op)]}' not supported between '{ltype.name}' and '{rtype.name}'",
                node,
            )
        return super().bind_compare(node, left, op, right, visitor, type_ctx)

    def bind_reverse_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Reverse-operand counterpart of ``bind_compare`` with the same rules."""
        if not isinstance(op, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
            ltype = visitor.get_type(left)
            rtype = visitor.get_type(right)
            raise visitor.syntax_error(
                f"'{CMPOP_SIGILS[type(op)]}' not supported between '{ltype.name}' and '{rtype.name}'",
                node,
            )
        return super().bind_reverse_compare(
            node, left, op, right, visitor, type_ctx
        )
1685
1686
1687# https://www.python.org/download/releases/2.3/mro/
1688def _merge(seqs: Iterable[List[Class]]) -> List[Class]:
1689 res = []
1690 i = 0
1691 while True:
1692 nonemptyseqs = [seq for seq in seqs if seq]
1693 if not nonemptyseqs:
1694 return res
1695 i += 1
1696 cand = None
1697 for seq in nonemptyseqs: # find merge candidates among seq heads
1698 cand = seq[0]
1699 nothead = [s for s in nonemptyseqs if cand in s[1:]]
1700 if nothead:
1701 cand = None # reject candidate
1702 else:
1703 break
1704 if not cand:
1705 types = {seq[0]: None for seq in nonemptyseqs}
1706 raise SyntaxError(
1707 "Cannot create a consistent method resolution order (MRO) for bases: "
1708 + ", ".join(t.name for t in types)
1709 )
1710 res.append(cand)
1711 for seq in nonemptyseqs: # remove cand
1712 if seq[0] == cand:
1713 del seq[0]
1714
1715
def _mro(C: Class) -> List[Class]:
    """Compute the class precedence list (mro) of ``C`` according to C3.

    The result merges ``[C]``, the precedence list of every base and the
    list of bases itself.
    """
    parent_orders = [_mro(base) for base in C.bases]
    return _merge([[C], *parent_orders, list(C.bases)])
1719
1720
class Parameter:
    """One declared parameter of a callable."""

    def __init__(
        self,
        name: str,
        idx: int,
        type_ref: TypeRef,
        has_default: bool,
        default_val: object,
        is_kwonly: bool,
    ) -> None:
        """Store the description; ``idx`` is kept as the ``index`` attribute."""
        self.name = name
        self.index = idx
        self.type_ref = type_ref
        self.has_default = has_default
        self.default_val = default_val
        self.is_kwonly = is_kwonly

    def __repr__(self) -> str:
        text = (
            f"<Parameter name={self.name}, ref={self.type_ref}, "
            f"index={self.index}, has_default={self.has_default}>"
        )
        return text

    def bind_generics(
        self,
        name: GenericTypeName,
        generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
    ) -> Parameter:
        """Return this parameter with generics bound in its type.

        The same object comes back when binding leaves the resolved type
        unchanged; otherwise a copy with the bound type is returned.
        """
        klass = self.type_ref.resolved().bind_generics(name, generic_types)
        if klass is self.type_ref.resolved():
            return self

        return Parameter(
            name=self.name,
            idx=self.index,
            type_ref=ResolvedTypeRef(klass),
            has_default=self.has_default,
            default_val=self.default_val,
            is_kwonly=self.is_kwonly,
        )
1761
1762
def is_subsequence(a: Iterable[object], b: Iterable[object]) -> bool:
    """Return True when the items of ``b`` occur in ``a`` in the same order.

    Other items may sit in between; an empty ``b`` always matches.
    """
    # A single shared iterator over `a`: each membership test resumes where
    # the previous one stopped, which is what enforces the ordering.
    remaining = iter(a)
    return all(item in remaining for item in b)
1771
1772
class ArgMapping:
    """Matches the arguments of one call against a callable's parameters.

    ``bind_args`` type-checks the arguments and fills ``emitters`` with the
    objects that later emit the code for each parameter.
    """

    def __init__(
        self,
        callable: Callable[TClass],
        call: ast.Call,
        self_arg: Optional[ast.expr],
    ) -> None:
        """Collect the call's arguments.

        ``self_arg``, when given, becomes the first positional argument.
        """
        self.callable = callable
        self.call = call
        pos_args: List[ast.expr] = []
        if self_arg is not None:
            pos_args.append(self_arg)
        pos_args.extend(call.args)
        self.args: List[ast.expr] = pos_args

        # (name, value) pairs; the name is None for a `**mapping` entry
        self.kwargs: List[Tuple[Optional[str], ast.expr]] = [
            (kwarg.arg, kwarg.value) for kwarg in call.keywords
        ]
        self.self_arg = self_arg
        self.emitters: List[ArgEmitter] = []
        self.nvariadic = 0
        self.nseen = 0
        # keyword values stored into temporaries, keyed by keyword position
        self.spills: Dict[int, SpillArg] = {}

    def bind_args(self, visitor: TypeBinder) -> None:
        """Bind every argument of the call to a parameter of the callable.

        Raises a syntax error when a positional argument lands on a
        keyword-only parameter, when a keyword names no parameter, when a
        required parameter gets no value, or when the number of arguments
        does not match and no `*args` is involved.
        """
        # TODO: handle duplicate args and other weird stuff a-la
        # https://fburl.com/diffusion/q6tpinw8

        # Process provided position arguments to expected parameters
        for idx, (param, arg) in enumerate(zip(self.callable.args, self.args)):
            if param.is_kwonly:
                raise visitor.syntax_error(
                    f"{self.callable.qualname} takes {idx} positional args but "
                    f"{len(self.args)} {'was' if len(self.args) == 1 else 'were'} given",
                    self.call,
                )
            elif isinstance(arg, Starred):
                # Skip type verification here, f(a, b, *something)
                # TODO: add support for this by implementing type constrained tuples
                self.nvariadic += 1
                star_params = self.callable.args[idx:]
                self.emitters.append(StarredArg(arg.value, star_params))
                self.nseen = len(self.callable.args)
                # visit the starred argument and everything after it
                for remaining in self.args[idx:]:
                    visitor.visit(remaining)
                break

            resolved_type = self.visit_arg(visitor, param, arg, "positional")
            self.emitters.append(PositionArg(arg, resolved_type))
            self.nseen += 1

        self.bind_kwargs(visitor)

        for argname, argvalue in self.kwargs:
            if argname is None:
                visitor.visit(argvalue)
                continue

            if argname not in self.callable.args_by_name:
                raise visitor.syntax_error(
                    f"Given argument {argname} "
                    f"does not exist in the definition of {self.callable.qualname}",
                    self.call,
                )

        # nseen must equal number of defined args if no variadic args are used
        if self.nvariadic == 0 and (self.nseen != len(self.callable.args)):
            raise visitor.syntax_error(
                f"Mismatched number of args for {self.callable.name}. "
                f"Expected {len(self.callable.args)}, got {self.nseen}",
                self.call,
            )

    def bind_kwargs(self, visitor: TypeBinder) -> None:
        """Bind the parameters that the positional arguments did not cover.

        Each such parameter is filled from a keyword argument, from a
        `**mapping` argument or from its default value. Raises a syntax
        error when a parameter has none of these.
        """
        # where the spill emitters get inserted once all parameters are bound
        spill_start = len(self.emitters)
        # Process unhandled arguments which can be populated via defaults,
        # keyword arguments, or **mapping.
        # Position of the next keyword while keywords still arrive in
        # parameter order; None once one had to be spilled.
        cur_kw_arg: Optional[int] = 0
        for idx in range(self.nseen, len(self.callable.args)):
            param = self.callable.args[idx]
            name = param.name
            if (
                cur_kw_arg is not None
                and cur_kw_arg < len(self.kwargs)
                and self.kwargs[cur_kw_arg][0] == name
            ):
                # keyword arg hit, with the keyword arguments still in order...
                arg = self.kwargs[cur_kw_arg][1]
                resolved_type = self.visit_arg(visitor, param, arg, "keyword")
                cur_kw_arg += 1

                self.emitters.append(KeywordArg(arg, resolved_type))
                self.nseen += 1
                continue

            variadic_idx = None
            for candidate_kw, (kw_name, kw_value) in enumerate(self.kwargs):
                if name == kw_name:
                    tmp_name = f"{_TMP_VAR_PREFIX}{name}"
                    self.spills[candidate_kw] = SpillArg(kw_value, tmp_name)

                    if cur_kw_arg is not None:
                        cur_kw_arg = None
                        spill_start = len(self.emitters)

                    resolved_type = self.visit_arg(
                        visitor, param, kw_value, "keyword"
                    )
                    self.emitters.append(SpilledKeywordArg(tmp_name, resolved_type))
                    break
                elif kw_name is None:
                    variadic_idx = candidate_kw
            else:
                # no keyword carries this parameter's name
                if variadic_idx is not None:
                    # We have a f(**something), if the arg is unavailable, we
                    # load it from the mapping
                    if variadic_idx not in self.spills:
                        self.spills[variadic_idx] = SpillArg(
                            self.kwargs[variadic_idx][1], f"{_TMP_VAR_PREFIX}**"
                        )

                    if cur_kw_arg is not None:
                        cur_kw_arg = None
                        spill_start = len(self.emitters)

                    self.emitters.append(
                        KeywordMappingArg(param, f"{_TMP_VAR_PREFIX}**")
                    )
                elif param.has_default:
                    self.emitters.append(DefaultArg(param.default_val))
                else:
                    # It's an error if this arg did not have a default value in the definition
                    raise visitor.syntax_error(
                        f"Function {self.callable.qualname} expects a value for "
                        f"argument {param.name}",
                        self.call,
                    )

            self.nseen += 1

        if self.spills:
            # spills are inserted in keyword order
            self.emitters[spill_start:spill_start] = [
                self.spills[position] for position in sorted(self.spills)
            ]

    def visit_arg(
        self, visitor: TypeBinder, param: Parameter, arg: expr, arg_style: str
    ) -> Class:
        """Visit ``arg`` against the type of ``param`` and return that type.

        ``arg_style`` ("positional" or "keyword") prefixes the mismatch
        message. A TypedSyntaxError raised by the visit is re-raised only
        after the assignability check, which may report a better error.
        """
        resolved_type = param.type_ref.resolved()
        exc = None
        try:
            visitor.visit(arg, resolved_type.instance if resolved_type else None)
        except TypedSyntaxError as e:
            # We may report a better error message below...
            exc = e
        visitor.check_can_assign_from(
            resolved_type,
            visitor.get_type(arg).klass,
            arg,
            f"{arg_style} argument type mismatch",
        )
        if exc is not None:
            raise exc
        return resolved_type
1938
1939
class ArgEmitter:
    """Base class for the objects that emit one argument of a call."""

    def __init__(self, argument: expr, type: Class) -> None:
        self.type = type
        self.argument = argument

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        """Emit nothing; the subclasses provide the real code."""
        return None
1948
1949
class PositionArg(ArgEmitter):
    """Emits a positional argument followed by a check against its parameter type."""

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        # the argument's type is looked up before its code is emitted
        bound_type = code_gen.get_type(self.argument)
        code_gen.visit(self.argument)
        code_gen.emit_type_check(self.type, bound_type.klass, node)

    def __repr__(self) -> str:
        return f"PositionArg({to_expr(self.argument)}, {self.type})"
1963
1964
class StarredArg(ArgEmitter):
    """Emits the parameters that are filled from a ``*iterable`` argument."""

    def __init__(self, argument: expr, params: List[Parameter]) -> None:
        """``argument`` is the starred expression's value; ``params`` are the
        parameters it has to fill, in order."""
        self.argument = argument
        self.params = params

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        """Emit the iterable, then one LOAD_ITERABLE_ARG per parameter.

        A loaded value is cast to the parameter's type unless that type is
        unknown or dynamic.
        """
        code_gen.visit(self.argument)
        for idx, param in enumerate(self.params):
            code_gen.emit("LOAD_ITERABLE_ARG", idx)

            declared = param.type_ref.resolved()
            # The resolved type is a class, so it has to be compared with
            # DYNAMIC_TYPE. The previous comparison with the DYNAMIC value
            # is kept as well, although it never matched a class.
            if (
                declared is not None
                and declared is not DYNAMIC
                and declared is not DYNAMIC_TYPE
            ):
                code_gen.emit("ROT_TWO")
                code_gen.emit("CAST", declared.type_descr)
                code_gen.emit("ROT_TWO")

        # Remove the tuple from TOS
        code_gen.emit("POP_TOP")
1986
1987
class SpillArg(ArgEmitter):
    """Evaluates an argument and stores it in a temporary local."""

    def __init__(self, argument: expr, temporary: str) -> None:
        self.temporary = temporary
        self.argument = argument

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        value_expr, slot = self.argument, self.temporary
        code_gen.visit(value_expr)
        code_gen.emit("STORE_FAST", slot)

    def __repr__(self) -> str:
        return f"SpillArg(..., {self.temporary})"
1999
2000
class SpilledKeywordArg(ArgEmitter):
    """Loads a keyword argument back from its temporary and type-checks it."""

    def __init__(self, temporary: str, type: Class) -> None:
        self.type = type
        self.temporary = temporary

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        code_gen.emit("LOAD_FAST", self.temporary)
        # the value read from the temporary is checked as dynamic
        code_gen.emit_type_check(self.type, DYNAMIC_TYPE, node)

    def __repr__(self) -> str:
        return f"SpilledKeywordArg({self.temporary})"
2016
2017
class KeywordArg(ArgEmitter):
    """Emits a keyword argument in place, followed by a type check."""

    def __init__(self, argument: expr, type: Class) -> None:
        self.type = type
        self.argument = argument

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        value_expr = self.argument
        code_gen.visit(value_expr)
        # the argument's type is looked up after its code has been emitted
        source_class = code_gen.get_type(value_expr).klass
        code_gen.emit_type_check(self.type, source_class, node)
2030
2031
class KeywordMappingArg(ArgEmitter):
    """Loads a parameter's value out of a spilled ``**mapping`` argument."""

    def __init__(self, param: Parameter, variadic: str) -> None:
        """``variadic`` names the temporary that holds the mapping."""
        self.variadic = variadic
        self.param = param

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        param = self.param
        with_default = param.has_default
        # the default, when there is one, is loaded ahead of the mapping
        if with_default:
            code_gen.emit("LOAD_CONST", param.default_val)
        code_gen.emit("LOAD_FAST", self.variadic)
        code_gen.emit("LOAD_CONST", param.name)
        code_gen.emit("LOAD_MAPPING_ARG", 3 if with_default else 2)
        code_gen.emit_type_check(
            param.type_ref.resolved() or DYNAMIC_TYPE, DYNAMIC_TYPE, node
        )
2049 )
2050
2051
class DefaultArg(ArgEmitter):
    """Emits a parameter's default value as a constant."""

    def __init__(self, value: object) -> None:
        self.value = value

    def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
        default = self.value
        code_gen.emit("LOAD_CONST", default)
2058
2059
2060class Callable(Object[TClass]):
2061 def __init__(
2062 self,
2063 klass: Class,
2064 func_name: str,
2065 module_name: str,
2066 args: List[Parameter],
2067 args_by_name: Dict[str, Parameter],
2068 num_required_args: int,
2069 vararg: Optional[Parameter],
2070 kwarg: Optional[Parameter],
2071 return_type: TypeRef,
2072 ) -> None:
2073 super().__init__(klass)
2074 self.func_name = func_name
2075 self.module_name = module_name
2076 self.container_type: Optional[Class] = None
2077 self.args = args
2078 self.args_by_name = args_by_name
2079 self.num_required_args = num_required_args
2080 self.has_vararg: bool = vararg is not None
2081 self.has_kwarg: bool = kwarg is not None
2082 self.return_type = return_type
2083 self.is_final = False
2084
2085 @property
2086 def qualname(self) -> str:
2087 cont = self.container_type
2088 if cont:
2089 return f"{cont.qualname}.{self.func_name}"
2090 return f"{self.module_name}.{self.func_name}"
2091
2092 @property
2093 def type_descr(self) -> TypeDescr:
2094 cont = self.container_type
2095 if cont:
2096 return cont.type_descr + (self.func_name,)
2097 return (self.module_name, self.func_name)
2098
2099 def set_container_type(self, klass: Optional[Class]) -> None:
2100 self.container_type = klass
2101
2102 def bind_call(
2103 self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
2104 ) -> NarrowingEffect:
2105 # Careful adding logic here, MethodType.bind_call() will bypass it
2106 return self.bind_call_self(node, visitor, type_ctx)
2107
2108 def bind_call_self(
2109 self,
2110 node: ast.Call,
2111 visitor: TypeBinder,
2112 type_ctx: Optional[Class],
2113 self_expr: Optional[ast.expr] = None,
2114 ) -> NarrowingEffect:
2115 if self.has_vararg or self.has_kwarg:
2116 return super().bind_call(node, visitor, type_ctx)
2117
2118 if type_ctx is not None:
2119 visitor.check_can_assign_from(
2120 type_ctx.klass,
2121 self.return_type.resolved(),
2122 node,
2123 "is an invalid return type, expected",
2124 )
2125
2126 arg_mapping = ArgMapping(self, node, self_expr)
2127 arg_mapping.bind_args(visitor)
2128
2129 visitor.set_type(node, self.return_type.resolved().instance)
2130 visitor.set_node_data(node, ArgMapping, arg_mapping)
2131 return NO_EFFECT
2132
2133 def _emit_kwarg_temps(
2134 self, keywords: List[ast.keyword], code_gen: Static38CodeGenerator
2135 ) -> Dict[str, str]:
2136 temporaries = {}
2137 for each in keywords:
2138 name = each.arg
2139 if name is not None:
2140 code_gen.visit(each.value)
2141 temp_var_name = f"{_TMP_VAR_PREFIX}{name}"
2142 code_gen.emit("STORE_FAST", temp_var_name)
2143 temporaries[name] = temp_var_name
2144 return temporaries
2145
2146 def _find_provided_kwargs(
2147 self, node: ast.Call
2148 ) -> Tuple[Dict[int, int], Optional[int]]:
2149 # This is a mapping of indices from index in the function definition --> node.keywords
2150 provided_kwargs: Dict[int, int] = {}
2151 # Index of `**something` in the call
2152 variadic_idx: Optional[int] = None
2153 for idx, argument in enumerate(node.keywords):
2154 name = argument.arg
2155 if name is not None:
2156 provided_kwargs[self.args_by_name[name].index] = idx
2157 else:
2158 # Because of the constraints above, we will only ever reach here once
2159 variadic_idx = idx
2160 return provided_kwargs, variadic_idx
2161
2162 def can_call_self(self, node: ast.Call, has_self: bool) -> bool:
2163 if self.has_vararg or self.has_kwarg:
2164 return False
2165
2166 has_default_args = self.num_required_args < len(self.args)
2167 has_star_args = False
2168 for a in node.args:
2169 if isinstance(a, ast.Starred):
2170 if has_star_args:
2171 # We don't support f(*a, *b)
2172 return False
2173 has_star_args = True
2174 elif has_star_args:
2175 # We don't support f(*a, b)
2176 return False
2177
2178 num_star_args = [isinstance(a, ast.Starred) for a in node.args].count(True)
2179 num_dstar_args = [(a.arg is None) for a in node.keywords].count(True)
2180 num_kwonly = len([arg for arg in self.args if arg.is_kwonly])
2181
2182 start = 1 if has_self else 0
2183 for arg in self.args[start + len(node.args) :]:
2184 if arg.has_default and isinstance(arg.default_val, ast.expr):
2185 for kw_arg in node.keywords:
2186 if kw_arg.arg == arg.name:
2187 break
2188 else:
2189 return False
2190 if (
2191 # We don't support f(**a, **b)
2192 num_dstar_args > 1
2193 # We don't support f(1, 2, *a) if f has any default arg values
2194 or (has_default_args and has_star_args)
2195 or num_kwonly
2196 ):
2197 return False
2198
2199 return True
2200
2201 def emit_call_self(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
2202 arg_mapping: ArgMapping = code_gen.get_node_data(node, ArgMapping)
2203 for emitter in arg_mapping.emitters:
2204 emitter.emit(node, code_gen)
2205 self_expr = arg_mapping.self_arg
2206 if (
2207 self_expr is None
2208 or code_gen.get_type(self_expr).klass.is_exact
2209 or code_gen.get_type(self_expr).klass.is_final
2210 ):
2211 code_gen.emit("EXTENDED_ARG", 0)
2212 code_gen.emit("INVOKE_FUNCTION", (self.type_descr, len(self.args)))
2213 else:
2214 code_gen.emit_invoke_method(self.type_descr, len(self.args) - 1)
2215
2216
class ContainerTypeRef(TypeRef):
    """Type reference that resolves to the container type of a function.

    ``Function.process_args`` uses it for an unannotated first parameter.
    """

    def __init__(self, func: Function) -> None:
        # Only the function is stored; the base-class initializer is not run.
        self.func = func

    def resolved(self, is_declaration: bool = False) -> Class:
        """Return the function's container type, or ``DYNAMIC_TYPE`` if unset."""
        container = self.func.container_type
        return DYNAMIC_TYPE if container is None else container
2226
2227
class InlineRewriter(ASTRewriter):
    """AST rewriter that substitutes names according to a replacement map.

    ``visit`` never hands back the very object it was given: an unchanged
    node is cloned and an unchanged sequence is copied into a new list.
    """

    def __init__(self, replacements: Dict[str, ast.expr]) -> None:
        super().__init__()
        # Maps an identifier to the expression that should stand in for it.
        self.replacements = replacements

    def visit(
        self, node: Union[TAst, Sequence[AST]], *args: object
    ) -> Union[AST, Sequence[AST]]:
        """Visit ``node`` and guarantee the result is a distinct object."""
        rewritten = super().visit(node, *args)
        if rewritten is not node:
            return rewritten

        if isinstance(node, AST):
            return self.clone_node(node)
        return list(node)

    def visitName(self, node: ast.Name) -> AST:
        """Return the replacement for ``node.id``, or a clone if none exists."""
        substitute = self.replacements.get(node.id)
        return self.clone_node(node) if substitute is None else substitute
2251
2252
class InlinedCall:
    """Result of binding a call to an inlined function.

    Attributes are plain data consumed later during code generation.
    """

    def __init__(
        self,
        expr: ast.expr,
        replacements: Dict[str, ast.expr],
        spills: Dict[str, Tuple[ast.expr, ast.Name]],
    ) -> None:
        """Store the pieces of an inlined call.

        Args:
            expr: The rewritten expression that replaces the call.
            replacements: Parameter name -> expression substituted for it.
                (Keyed by ``str``: the producer in ``Function.bind_inline_call``
                indexes this mapping by parameter name.)
            spills: Temporary name -> ``(argument expression, store target)``
                for arguments that are evaluated into a temporary first.
        """
        self.expr = expr
        self.replacements = replacements
        self.spills = spills
2263
2264
class Function(Callable[Class]):
    """A function (or method body) declared with ``def`` / ``async def``.

    Besides the signature information kept by ``Callable``, it remembers its
    AST node and module, and the ``inline`` / ``donotcompile`` flags that the
    corresponding decorators set.
    """

    def __init__(
        self,
        node: Union[AsyncFunctionDef, FunctionDef],
        module: ModuleTable,
        ret_type: TypeRef,
    ) -> None:
        super().__init__(
            FUNCTION_TYPE,
            node.name,
            module.name,
            [],
            {},
            0,
            None,
            None,
            ret_type,
        )
        self.node = node
        self.module = module
        # Fills in args / args_by_name / num_required_args and the
        # has_vararg / has_kwarg flags from the AST signature.
        self.process_args(module)
        # Flipped by InlineFunctionDecorator / DoNotCompileDecorator.
        self.inline = False
        self.donotcompile = False

    @property
    def name(self) -> str:
        return f"function {self.qualname}"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a call; additionally bind the inlined body when applicable.

        Inlining is only attempted for ``inline`` functions when the binder's
        ``optimize`` level is 2. The regular binding always runs first and its
        effect is used whenever the inline binding yields nothing.
        """
        res = super().bind_call(node, visitor, type_ctx)
        if self.inline and visitor.optimize == 2:
            assert isinstance(self.node.body[0], ast.Return)

            return self.bind_inline_call(node, visitor, type_ctx) or res

        return res

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit the call: generic fallback, inlined body, or static invoke."""
        if not self.can_call_self(node, False):
            return super().emit_call(node, code_gen)

        if self.inline and code_gen.optimization_lvl == 2:
            return self.emit_inline_call(node, code_gen)

        return self.emit_call_self(node, code_gen)

    def bind_inline_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> Optional[NarrowingEffect]:
        """Prepare ``node`` to be replaced by this function's return expression.

        On success an ``InlinedCall`` is stored as node data under
        ``Optional[InlinedCall]``; when the call cannot be inlined, ``None``
        is stored instead so ``emit_inline_call`` falls back to a normal
        call. The method itself always returns ``None``.
        """
        args = visitor.get_node_data(node, ArgMapping)
        arg_replacements = {}
        spills = {}

        # Give up on deeply nested inlining.
        if visitor.inline_depth > 20:
            visitor.set_node_data(node, Optional[InlinedCall], None)
            return None

        visitor.inline_depth += 1
        # try/finally: every exit path (including the early bail-out below
        # and exceptions raised while binding) must restore the depth,
        # otherwise each rejected call would permanently deepen the counter.
        try:
            for idx, arg in enumerate(args.emitters):
                name = self.node.args.args[idx].arg

                if isinstance(arg, DefaultArg):
                    arg_replacements[name] = ast.Constant(arg.value)
                    continue
                elif not isinstance(arg, (PositionArg, KeywordArg)):
                    # We don't support complicated calls to inline functions
                    visitor.set_node_data(node, Optional[InlinedCall], None)
                    return None

                if (
                    isinstance(arg.argument, ast.Constant)
                    or visitor.get_final_literal(arg.argument) is not None
                ):
                    # Constants and final literals are substituted directly.
                    arg_replacements[name] = arg.argument
                    continue

                # store to a temporary...
                tmp_name = f"{_TMP_VAR_PREFIX}{visitor.inline_depth}{name}"
                cur_scope = visitor.symbols.scopes[visitor.scope]
                cur_scope.add_def(tmp_name)

                store = ast.Name(tmp_name, ast.Store())
                visitor.set_type(store, visitor.get_type(arg.argument))
                spills[tmp_name] = arg.argument, store

                replacement = ast.Name(tmp_name, ast.Load())
                visitor.assign_value(replacement, visitor.get_type(arg.argument))

                arg_replacements[name] = replacement

            # re-write node body with replacements...
            return_stmt = self.node.body[0]
            assert isinstance(return_stmt, Return)
            ret_value = return_stmt.value
            if ret_value is not None:
                new_node = InlineRewriter(arg_replacements).visit(ret_value)
            else:
                # A bare `return` inlines to the constant None.
                new_node = ast.Constant(None)
            new_node = AstOptimizer().visit(new_node)

            inlined_call = InlinedCall(new_node, arg_replacements, spills)
            visitor.visit(new_node)
            visitor.set_node_data(node, Optional[InlinedCall], inlined_call)
        finally:
            visitor.inline_depth -= 1

        return None

    def emit_inline_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit the inlined body recorded by ``bind_inline_call``.

        Falls back to ``emit_call_self`` when binding decided not to inline.
        """
        assert isinstance(self.node.body[0], ast.Return)
        inlined_call = code_gen.get_node_data(node, Optional[InlinedCall])
        if inlined_call is None:
            return self.emit_call_self(node, code_gen)

        # Evaluate each spilled argument and store it into its temporary.
        for arg, store in inlined_call.spills.values():
            code_gen.visit(arg)

            code_gen.get_type(store).emit_name(store, code_gen)

        code_gen.visit(inlined_call.expr)

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Type an attribute access that resolves to this function.

        Without an instance the node is typed as the function itself; with
        one it is typed as a ``MethodType`` wrapping this function.
        """
        if inst is None:
            visitor.set_type(node, self)
        else:
            visitor.set_type(node, MethodType(ctx.type_name, self.node, node, self))

    def register_arg(
        self,
        name: str,
        idx: int,
        ref: TypeRef,
        has_default: bool,
        default_val: object,
        is_kwonly: bool,
    ) -> None:
        """Append a ``Parameter`` and index it by name.

        ``num_required_args`` is incremented for parameters without a default.
        """
        parameter = Parameter(name, idx, ref, has_default, default_val, is_kwonly)
        self.args.append(parameter)
        self.args_by_name[name] = parameter
        if not has_default:
            self.num_required_args += 1

    def process_args(
        self: Function,
        module: ModuleTable,
    ) -> None:
        """
        Register type-refs for each function argument, assume DYNAMIC if annotation is missing.

        An unannotated first positional parameter gets a ``ContainerTypeRef``
        instead. ``*args`` / ``**kwargs`` are not registered as parameters;
        they only set ``has_vararg`` / ``has_kwarg``.
        """
        arguments = self.node.args
        # Defaults align with the *last* positional parameters, so pad the
        # front with None for the parameters that have no default.
        nrequired = len(arguments.args) - len(arguments.defaults)
        no_defaults = cast(List[Optional[ast.expr]], [None] * nrequired)
        defaults = no_defaults + cast(List[Optional[ast.expr]], arguments.defaults)
        idx = 0
        for idx, (argument, default) in enumerate(zip(arguments.args, defaults)):
            annotation = argument.annotation
            default_val = None
            has_default = False
            if default is not None:
                has_default = True
                default_val = get_default_value(default)

            if annotation:
                ref = TypeRef(module, annotation)
            elif idx == 0:
                ref = ContainerTypeRef(self)
            else:
                ref = ResolvedTypeRef(DYNAMIC_TYPE)
            self.register_arg(argument.arg, idx, ref, has_default, default_val, False)

        # Keyword-only parameters are numbered after the last positional
        # index, leaving one slot for *args when present.
        base_idx = idx

        vararg = arguments.vararg
        if vararg:
            base_idx += 1
            self.has_vararg = True

        for argument, default in zip(arguments.kwonlyargs, arguments.kw_defaults):
            annotation = argument.annotation
            default_val = None
            has_default = default is not None
            if default is not None:
                default_val = get_default_value(default)
            if annotation:
                ref = TypeRef(module, annotation)
            else:
                ref = ResolvedTypeRef(DYNAMIC_TYPE)
            base_idx += 1
            self.register_arg(
                argument.arg, base_idx, ref, has_default, default_val, True
            )

        kwarg = arguments.kwarg
        if kwarg:
            self.has_kwarg = True

    def __repr__(self) -> str:
        return f"<{self.name} '{self.name}' instance, args={self.args}>"
2470
2471
class MethodType(Object[Class]):
    """A function accessed through an instance, i.e. a bound method."""

    def __init__(
        self,
        bound_type_name: TypeName,
        node: Union[AsyncFunctionDef, FunctionDef],
        target: ast.Attribute,
        function: Function,
    ) -> None:
        super().__init__(METHOD_TYPE)
        # TODO currently this type (the type the bound method was accessed
        # from) is unused, and we just end up deferring to the type where the
        # function was defined. This is fine until we want to fully support a
        # method defined in one class being also referenced as a method in
        # another class.
        self.bound_type_name = bound_type_name
        self.node = node
        # The attribute expression; its ``value`` is the receiver.
        self.target = target
        self.function = function

    @property
    def name(self) -> str:
        return f"method {self.function.qualname}"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind the call on the underlying function, passing the receiver."""
        receiver = self.target.value
        effect = self.function.bind_call_self(node, visitor, type_ctx, receiver)
        self.check_args_for_primitives(node, visitor)
        return effect

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit a static method call when supported, else the generic call."""
        func = self.function
        if func.can_call_self(node, True):
            code_gen.update_lineno(node)
            func.emit_call_self(node, code_gen)
            return None

        return super().emit_call(node, code_gen)
2511
2512
class StaticMethod(Object[Class]):
    """Wrapper marking a ``Function`` as a static method.

    Name, finality and container bookkeeping are forwarded to the wrapped
    function.
    """

    def __init__(
        self,
        function: Function,
    ) -> None:
        super().__init__(STATIC_METHOD_TYPE)
        self.function = function

    @property
    def name(self) -> str:
        return f"staticmethod {self.function.qualname}"

    @property
    def func_name(self) -> str:
        wrapped = self.function
        return wrapped.func_name

    @property
    def is_final(self) -> bool:
        wrapped = self.function
        return wrapped.is_final

    def set_container_type(self, container_type: Optional[Class]) -> None:
        """Forward the container type to the wrapped function."""
        self.function.set_container_type(container_type)

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Type the access as the plain function, with or without an instance."""
        wrapped = self.function
        visitor.set_type(node, wrapped)
2545
2546
class TypingFinalDecorator(Class):
    """Decorator type that marks functions and classes as final."""

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Value:
        """Set ``is_final`` on the function (unwrapping a static method)."""
        # StaticMethod only exposes is_final as a read-only property that
        # forwards to the wrapped function, so the flag is set on the function.
        target = fn.function if isinstance(fn, StaticMethod) else fn
        target.is_final = True
        return fn

    def bind_decorate_class(self, klass: Class) -> Class:
        """Set ``is_final`` on the class and return it."""
        klass.is_final = True
        return klass
2560
2561
class AllowWeakrefsDecorator(Class):
    """Decorator type that flags a class with ``allow_weakrefs``."""

    def bind_decorate_class(self, klass: Class) -> Class:
        """Set ``klass.allow_weakrefs`` to True and return the same class."""
        setattr(klass, "allow_weakrefs", True)
        return klass
2566
2567
class DynamicReturnDecorator(Class):
    """Decorator type that forces a function's return type to dynamic."""

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Value:
        """Replace the return type and register the function with its module.

        The function's ``args`` AST node is added to the module's
        ``dynamic_returns`` collection. The decorated value is returned
        unchanged.
        """
        if isinstance(fn, StaticMethod):
            target = fn.function
        else:
            target = fn
        target.return_type = ResolvedTypeRef(DYNAMIC_TYPE)
        target.module.dynamic_returns.add(target.node.args)
        return fn
2576
2577
class StaticMethodDecorator(Class):
    """Decorator type that turns a function into a ``StaticMethod``."""

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Value:
        """Wrap ``fn`` in ``StaticMethod``; an existing wrapper is kept as is."""
        return fn if isinstance(fn, StaticMethod) else StaticMethod(fn)
2586
2587
class InlineFunctionDecorator(Class):
    """Decorator type that marks a function as inlinable."""

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Value:
        """Set ``inline`` on the function after validating its body.

        Raises:
            The error built by ``visitor.syntax_error`` when the first
            statement of the function body is not a ``return``.
        """
        if isinstance(fn, StaticMethod):
            target = fn.function
        else:
            target = fn

        first_stmt = target.node.body[0]
        if not isinstance(first_stmt, ast.Return):
            raise visitor.syntax_error(
                "@inline only supported on functions with simple return", target.node
            )

        target.inline = True
        return fn
2600
2601
class DoNotCompileDecorator(Class):
    """Decorator type that sets the ``donotcompile`` flag."""

    def bind_decorate_function(
        self, visitor: DeclarationVisitor, fn: Function | StaticMethod
    ) -> Optional[Value]:
        """Flag the function (unwrapping a static method) and return ``fn``."""
        if isinstance(fn, StaticMethod):
            target = fn.function
        else:
            target = fn
        target.donotcompile = True
        return fn

    def bind_decorate_class(self, klass: Class) -> Class:
        """Flag the class and return it."""
        setattr(klass, "donotcompile", True)
        return klass
2613
2614
class BuiltinFunction(Callable[Class]):
    """A builtin function, optionally carrying a typed signature.

    ``args`` is ``None`` when no signature is known; the return type then
    defaults to dynamic.
    """

    def __init__(
        self,
        func_name: str,
        module: str,
        args: Optional[Tuple[Parameter, ...]] = None,
        return_type: Optional[TypeRef] = None,
    ) -> None:
        assert return_type is None or isinstance(return_type, TypeRef)
        effective_return = return_type or ResolvedTypeRef(DYNAMIC_TYPE)
        super().__init__(
            BUILTIN_METHOD_DESC_TYPE,
            func_name,
            module,
            args,
            {},
            0,
            None,
            None,
            effective_return,
        )

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit a static invoke, or the generic call when that is not possible.

        The generic path is taken for any call with keywords, and for typed
        signatures that ``can_call_self`` rejects.
        """
        if node.keywords:
            return super().emit_call(node, code_gen)
        if self.args is not None and not self.can_call_self(node, True):
            return super().emit_call(node, code_gen)

        code_gen.update_lineno(node)
        self.emit_call_self(node, code_gen)
2644
2645
class BuiltinMethodDescriptor(Callable[Class]):
    """A method descriptor of a builtin type, optionally with a typed signature.

    ``args`` is ``None`` when no signature is known; the return type then
    defaults to dynamic.
    """

    def __init__(
        self,
        func_name: str,
        container_type: Class,
        args: Optional[Tuple[Parameter, ...]] = None,
        return_type: Optional[TypeRef] = None,
    ) -> None:
        assert return_type is None or isinstance(return_type, TypeRef)
        effective_return = return_type or ResolvedTypeRef(DYNAMIC_TYPE)
        super().__init__(
            BUILTIN_METHOD_DESC_TYPE,
            func_name,
            container_type.type_name.module,
            args,
            {},
            0,
            None,
            None,
            effective_return,
        )
        self.set_container_type(container_type)

    def bind_call_self(
        self,
        node: ast.Call,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
        self_expr: Optional[expr] = None,
    ) -> NarrowingEffect:
        """Bind a call through the descriptor.

        Typed descriptors use the inherited signature-aware binding. Untyped
        ones fall back to the inherited generic ``bind_call`` when keywords
        are present; otherwise the call is typed as dynamic and each
        positional argument is visited.
        """
        if self.args is not None:
            return super().bind_call_self(node, visitor, type_ctx, self_expr)
        if node.keywords:
            return super().bind_call(node, visitor, type_ctx)

        visitor.set_type(node, DYNAMIC)
        for positional in node.args:
            visitor.visit(positional)

        return NO_EFFECT

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Type the access as the descriptor itself or as a ``BuiltinMethod``.

        An access without an instance yields the descriptor; an access
        through an instance yields a ``BuiltinMethod`` for this attribute node.
        """
        bound = self if inst is None else BuiltinMethod(self, node)
        visitor.set_type(node, bound)

    def make_generic(
        self, new_type: Class, name: GenericTypeName, generic_types: GenericTypesDict
    ) -> Value:
        """Build the descriptor for a generic instantiation ``new_type``.

        When both a signature and a return type are present, every parameter
        and the resolved return type have their generics bound; otherwise an
        untyped descriptor on ``new_type`` is returned.
        """
        params = self.args
        ret_ref = self.return_type
        if params is None or ret_ref is None:
            return BuiltinMethodDescriptor(self.func_name, new_type)

        bound_params = tuple(p.bind_generics(name, generic_types) for p in params)
        bound_ret = ret_ref.resolved().bind_generics(name, generic_types)
        return BuiltinMethodDescriptor(
            self.func_name,
            new_type,
            bound_params,
            ResolvedTypeRef(bound_ret),
        )
2715
2716
class BuiltinMethod(Callable[Class]):
    """A builtin method descriptor accessed through an instance.

    It copies name, module, signature and return type from the descriptor and
    remembers the attribute node so the receiver (``target.value``) is
    available when binding and emitting.
    """

    def __init__(self, desc: BuiltinMethodDescriptor, target: ast.Attribute) -> None:
        super().__init__(
            BUILTIN_METHOD_TYPE,
            desc.func_name,
            desc.module_name,
            desc.args,
            {},
            0,
            None,
            None,
            desc.return_type,
        )
        self.desc = desc
        self.target = target
        self.set_container_type(desc.container_type)

    @property
    def name(self) -> str:
        return self.qualname

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a call on the bound builtin method.

        With a (non-empty) signature the inherited signature-aware binding is
        used with the receiver as ``self``. Without one, keyword calls go to
        ``Object.bind_call``; otherwise the node gets the declared return
        type's instance and the receiver and positional arguments are visited.
        """
        receiver = self.target.value
        if self.args:
            return super().bind_call_self(node, visitor, type_ctx, receiver)
        if node.keywords:
            return Object.bind_call(self, node, visitor, type_ctx)

        visitor.set_type(node, self.return_type.resolved().instance)
        visitor.visit(receiver)
        for positional in node.args:
            visitor.visit(positional)
        self.check_args_for_primitives(node, visitor)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit a static invoke where possible, else the generic call."""
        if node.keywords:
            return super().emit_call(node, code_gen)
        if self.args is not None and not self.desc.can_call_self(node, True):
            return super().emit_call(node, code_gen)

        code_gen.update_lineno(node)

        if self.args is not None:
            self.desc.emit_call_self(node, code_gen)
            return None

        # Untyped method, we can still do an INVOKE_METHOD
        receiver = self.target.value
        code_gen.visit(receiver)

        code_gen.update_lineno(node)
        for positional in node.args:
            code_gen.visit(positional)

        argc = len(node.args)
        if code_gen.get_type(receiver).klass.is_exact:
            # Exact receiver class: invoke as a function, counting the
            # receiver as an extra argument.
            code_gen.emit("INVOKE_FUNCTION", (self.type_descr, argc + 1))
        else:
            code_gen.emit_invoke_method(self.type_descr, argc)
2776
2777
class StrictBuiltins(Object[Class]):
    """Dict-typed value whose string-literal subscripts resolve to known builtins."""

    def __init__(self, builtins: Dict[str, Value]) -> None:
        super().__init__(DICT_TYPE)
        self.builtins = builtins

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Type ``node`` as the named builtin, or as dynamic.

        Only an ``ast.Index`` slice holding a string literal is looked up in
        ``self.builtins``; anything else, including unknown names, is dynamic.
        The incoming ``type`` argument is not consulted.
        """
        resolved = DYNAMIC
        index = node.slice
        if isinstance(index, ast.Index):
            key_node = index.value
            key = None
            if isinstance(key_node, ast.Str):
                key = key_node.s
            elif isinstance(key_node, ast.Constant) and isinstance(
                key_node.value, str
            ):
                key = key_node.value

            if key is not None:
                found = self.builtins.get(key)
                if found is not None:
                    resolved = found

        visitor.set_type(node, resolved)
2802
2803
def get_default_value(default: expr) -> object:
    """Turn a default-value expression into a Python value where possible.

    Expressions that are not already literal nodes are first run through
    ``AstOptimizer``. Literal nodes then yield their Python value; any other
    node is returned as-is (still an AST expression).
    """
    literal_nodes = (Constant, Str, Num, Bytes, NameConstant, ast.Ellipsis)
    if not isinstance(default, literal_nodes):
        default = AstOptimizer().visit(default)

    if isinstance(default, Str):
        return default.s
    if isinstance(default, Num):
        return default.n
    if isinstance(default, Bytes):
        return default.s
    if isinstance(default, ast.Ellipsis):
        return ...
    if isinstance(default, (ast.Constant, ast.NameConstant)):
        return default.value
    return default
2821
2822
# Bringing up the type system is a little special as we have dependencies
# amongst type and object
# ``Class.__new__`` allocates the object without running ``Class.__init__``,
# so every attribute used below is assigned by hand.
TYPE_TYPE = Class.__new__(Class)
TYPE_TYPE.type_name = TypeName("builtins", "type")
# ``type`` refers back to itself as both its class and its instance value.
TYPE_TYPE.klass = TYPE_TYPE
TYPE_TYPE.instance = TYPE_TYPE
TYPE_TYPE.members = {}
TYPE_TYPE.is_exact = False
TYPE_TYPE.is_final = False
# Cached MRO data starts out unset.
TYPE_TYPE._mro = None
TYPE_TYPE._mro_type_descrs = None
2834
2835
class Slot(Object[TClassInv]):
    """A declared attribute slot on a class."""

    def __init__(
        self,
        type_ref: TypeRef,
        name: str,
        container_type: Class,
        assignment: Optional[AST] = None,
    ) -> None:
        super().__init__(MEMBER_TYPE)
        self.container_type = container_type
        self.slot_name = name
        # Resolved on demand by decl_type / is_final.
        self._type_ref = type_ref
        self.assignment = assignment

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Type the access as the slot itself (no instance) or its declared type."""
        accessed = self if inst is None else self.decl_type.instance
        visitor.set_type(node, accessed)

    @property
    def decl_type(self) -> Class:
        """The declared type, with a ``FinalClass`` wrapper unwrapped."""
        declared = self._type_ref.resolved(is_declaration=True)
        return declared.inner_type() if isinstance(declared, FinalClass) else declared

    @property
    def is_final(self) -> bool:
        """True when the declaration resolves to a ``FinalClass``."""
        declared = self._type_ref.resolved(is_declaration=True)
        return isinstance(declared, FinalClass)

    @property
    def type_descr(self) -> TypeDescr:
        declared_klass = self.decl_type
        return declared_klass.type_descr
2878
2879
# TODO (aniketpanse): move these to a better place
# Root ``builtins.object`` class and its instance value.
OBJECT_TYPE = Class(TypeName("builtins", "object"))
OBJECT = OBJECT_TYPE.instance

# The dynamic type and its instance value, used wherever no static type is
# known.
DYNAMIC_TYPE = DynamicClass()
DYNAMIC = DYNAMIC_TYPE.instance
2886
2887
class BoxFunction(Object[Class]):
    """The ``box`` helper: converts a primitive value to its boxed object type."""

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type ``box(x)`` from the primitive type of ``x``.

        C integers box to bool (when their constant is ``TYPED_BOOL``) or to
        exact int; C doubles box to exact float.

        Raises:
            The error built by ``visitor.syntax_error`` for a wrong argument
            count or a non-primitive argument.
        """
        if len(node.args) != 1:
            raise visitor.syntax_error("box only accepts a single argument", node)

        operand = node.args[0]
        visitor.visit(operand)
        operand_type = visitor.get_type(operand)

        if isinstance(operand_type, CIntInstance):
            if operand_type.constant == TYPED_BOOL:
                boxed_klass = BOOL_TYPE
            else:
                boxed_klass = INT_EXACT_TYPE
        elif isinstance(operand_type, CDoubleInstance):
            boxed_klass = FLOAT_EXACT_TYPE
        else:
            raise visitor.syntax_error(
                f"can't box non-primitive: {operand_type.name}", node
            )

        visitor.set_type(node, boxed_klass.instance)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Delegate emission to the argument type's ``emit_box``."""
        operand = node.args[0]
        code_gen.get_type(operand).emit_box(operand, code_gen)
2912
2913
class UnboxFunction(Object[Class]):
    """The ``unbox`` helper: converts a boxed value to a primitive."""

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type ``unbox(x)`` as ``type_ctx`` when given, else ``INT64_VALUE``.

        Raises:
            The error built by ``visitor.syntax_error`` unless exactly one
            positional argument is passed.
        """
        if len(node.args) != 1:
            raise visitor.syntax_error("unbox only accepts a single argument", node)

        # Exactly one argument at this point; it is visited with a dynamic
        # type context.
        visitor.visit(node.args[0], DYNAMIC)
        self.check_args_for_primitives(node, visitor)
        result_type = type_ctx or INT64_VALUE
        visitor.set_type(node, result_type)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Delegate emission to the result type's ``emit_unbox``."""
        result_type = code_gen.get_type(node)
        result_type.emit_unbox(node.args[0], code_gen)
2929
2930
class LenFunction(Object[Class]):
    """The ``len`` builtin (``boxed=True``) or its primitive ``clen`` variant."""

    def __init__(self, klass: Class, boxed: bool) -> None:
        super().__init__(klass)
        # True: result is an exact int object; False: result is an int64.
        self.boxed = boxed

    @property
    def name(self) -> str:
        return f"{'' if self.boxed else 'c'}len function"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type a ``len``/``clen`` call.

        Raises:
            The error built by ``visitor.syntax_error`` when the call does
            not have exactly one positional argument, or when the unboxed
            variant is applied to a type without a fast length.
        """
        if len(node.args) != 1:
            # The error must actually be raised: otherwise a call with no
            # arguments fails on node.args[0] below and extra arguments are
            # silently accepted.
            raise visitor.syntax_error(
                f"len() does not accept more than one arguments ({len(node.args)} given)",
                node,
            )
        arg = node.args[0]
        visitor.visit(arg)
        arg_type = visitor.get_type(arg)
        if not self.boxed and arg_type.get_fast_len_type() is None:
            raise visitor.syntax_error(
                f"bad argument type '{arg_type.name}' for clen()", arg
            )
        self.check_args_for_primitives(node, visitor)
        output_type = INT_EXACT_TYPE.instance if self.boxed else INT64_TYPE.instance
        visitor.set_type(node, output_type)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Delegate emission to the argument type's ``emit_len``."""
        code_gen.get_type(node.args[0]).emit_len(node, code_gen, boxed=self.boxed)
2962
2963
class SortedFunction(Object[Class]):
    """The ``sorted`` builtin, whose result is known to be an exact list."""

    @property
    def name(self) -> str:
        return "sorted function"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type a ``sorted(...)`` call as an exact list.

        Raises:
            The error built by ``visitor.syntax_error`` when the call does
            not have exactly one positional argument.
        """
        if len(node.args) != 1:
            # The error must actually be raised: otherwise ``sorted()`` fails
            # on node.args[0] below and extra positionals are never visited.
            raise visitor.syntax_error(
                f"sorted() accepts one positional argument ({len(node.args)} given)",
                node,
            )
        visitor.visit(node.args[0])
        for kw in node.keywords:
            visitor.visit(kw.value)
        self.check_args_for_primitives(node, visitor)
        visitor.set_type(node, LIST_EXACT_TYPE.instance)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit the generic call, then refine the result type to exact list."""
        super().emit_call(node, code_gen)
        code_gen.emit("REFINE_TYPE", LIST_EXACT_TYPE.type_descr)
2987
2988
class ExtremumFunction(Object[Class]):
    """The ``min`` / ``max`` builtins, specialized for two positional arguments."""

    def __init__(self, klass: Class, is_min: bool) -> None:
        super().__init__(klass)
        self.is_min = is_min

    @property
    def _extremum(self) -> str:
        if self.is_min:
            return "min"
        return "max"

    @property
    def name(self) -> str:
        return f"{self._extremum} function"

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit ``min(a, b)`` / ``max(a, b)`` as an inline compare-and-select.

        Any other call shape (argument count other than two, keywords, or a
        starred argument) goes through the generic call path.
        """
        if len(node.args) != 2:
            # We only specialize for two args
            return super().emit_call(node, code_gen)
        if len(node.keywords) > 0:
            # We don't support specialization if any kwargs are present
            return super().emit_call(node, code_gen)
        if any(isinstance(a, ast.Starred) for a in node.args):
            # If we have any *args, we skip specialization
            return super().emit_call(node, code_gen)

        # Compile `min(a, b)` to a ternary expression, `a if a <= b else b`.
        # Similar for `max(a, b)`.
        kind = self._extremum
        endblock = code_gen.newBlock(f"{kind}_end")
        elseblock = code_gen.newBlock(f"{kind}_else")

        for operand in node.args:
            code_gen.visit(operand)

        comparison = "<=" if self.is_min else ">="

        code_gen.emit("DUP_TOP_TWO")
        code_gen.emit("COMPARE_OP", comparison)
        code_gen.emit("POP_JUMP_IF_FALSE", elseblock)
        # Comparison held: drop `b`, leaving `a` as the result
        code_gen.emit("POP_TOP")
        code_gen.emit("JUMP_FORWARD", endblock)
        code_gen.nextBlock(elseblock)
        # Comparison failed: drop `a`, leaving `b` as the result
        code_gen.emit("ROT_TWO")
        code_gen.emit("POP_TOP")
        code_gen.nextBlock(endblock)
3037
3038
class IsInstanceFunction(Object[Class]):
    """The ``isinstance`` builtin, with type narrowing on a named first argument."""

    def __init__(self) -> None:
        super().__init__(FUNCTION_TYPE)

    @property
    def name(self) -> str:
        return "isinstance function"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type an ``isinstance`` call as bool and compute its narrowing effect.

        Returns:
            An ``IsInstanceEffect`` when the call has two arguments, the
            first is a plain name, and the second resolves to a class (or a
            tuple of classes); ``NO_EFFECT`` otherwise.

        Raises:
            The error built by ``visitor.syntax_error`` when keyword
            arguments are passed.
        """
        if node.keywords:
            # Raise (as the issubclass counterpart does) instead of merely
            # constructing the error and carrying on.
            raise visitor.syntax_error(
                "isinstance() does not accept keyword arguments", node
            )
        for arg in node.args:
            visitor.visit(arg)
        self.check_args_for_primitives(node, visitor)
        visitor.set_type(node, BOOL_TYPE.instance)
        if len(node.args) == 2:
            arg0 = node.args[0]
            # Narrowing is only tracked for simple names.
            if not isinstance(arg0, ast.Name):
                return NO_EFFECT

            arg1 = node.args[1]
            klass_type = None
            if isinstance(arg1, ast.Tuple):
                # isinstance(x, (A, B)) narrows to a union of the classes.
                types = tuple(visitor.get_type(el) for el in arg1.elts)
                if all(isinstance(t, Class) for t in types):
                    klass_type = UNION_TYPE.make_generic_type(
                        types, visitor.symtable.generic_types
                    )
            else:
                arg1_type = visitor.get_type(node.args[1])
                if isinstance(arg1_type, Class):
                    klass_type = inexact(arg1_type)

            if klass_type is not None:
                return IsInstanceEffect(
                    arg0.id,
                    visitor.get_type(arg0),
                    inexact(klass_type.instance),
                    visitor,
                )

        return NO_EFFECT
3083
3084
class IsSubclassFunction(Object[Class]):
    """The ``issubclass`` builtin; its result is typed as bool."""

    def __init__(self) -> None:
        super().__init__(FUNCTION_TYPE)

    @property
    def name(self) -> str:
        return "issubclass function"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Visit the arguments and type the call as bool.

        Raises:
            The error built by ``visitor.syntax_error`` when keyword
            arguments are passed.
        """
        if node.keywords:
            raise visitor.syntax_error(
                "issubclass() does not accept keyword arguments", node
            )

        for positional in node.args:
            visitor.visit(positional)

        result_type = BOOL_TYPE.instance
        visitor.set_type(node, result_type)
        self.check_args_for_primitives(node, visitor)
        return NO_EFFECT
3105
3106
class RevealTypeFunction(Object[Class]):
    """The ``reveal_type`` debugging helper.

    Binding a call always ends in an error whose message describes the
    inferred type of the argument.
    """

    def __init__(self) -> None:
        super().__init__(FUNCTION_TYPE)

    @property
    def name(self) -> str:
        return "reveal_type function"

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Report the argument's type through a raised error.

        Raises:
            The error built by ``visitor.syntax_error``: for keyword
            arguments, for a wrong argument count, and otherwise carrying
            the revealed type (plus declared and local type for a declared
            name).
        """
        if node.keywords:
            raise visitor.syntax_error(
                "reveal_type() does not accept keyword arguments", node
            )
        if len(node.args) != 1:
            raise visitor.syntax_error(
                "reveal_type() accepts exactly one argument", node
            )

        target = node.args[0]
        visitor.visit(target)
        inferred = visitor.get_type(target)
        parts = [f"reveal_type({to_expr(target)}): '{inferred.name}'"]
        if isinstance(target, ast.Name) and target.id in visitor.decl_types:
            decl_type = visitor.decl_types[target.id].type
            local_type = visitor.local_types[target.id]
            parts.append(
                f", '{target.id}' has declared type '{decl_type.name}' and local type '{local_type.name}'"
            )
        # This method never returns normally.
        raise visitor.syntax_error("".join(parts), node)
3136
3137
class NumClass(Class):
    """Class for numeric types, optionally pinned to a single literal value."""

    def __init__(
        self,
        name: TypeName,
        pytype: Optional[Type[object]] = None,
        is_exact: bool = False,
        literal_value: Optional[int] = None,
    ) -> None:
        # A literal type is always exact and derives from exact int; every
        # other numeric class derives directly from object.
        if literal_value is None:
            bases: List[Class] = [OBJECT_TYPE]
        else:
            is_exact = True
            bases = [INT_EXACT_TYPE]

        if is_exact:
            instance = NumExactInstance(self)
        else:
            instance = NumInstance(self)

        super().__init__(
            name,
            bases,
            instance,
            pytype=pytype,
            is_exact=is_exact,
        )
        self.literal_value = literal_value

    def can_assign_from(self, src: Class) -> bool:
        """Decide whether a value of class ``src`` may be assigned to this one.

        For a numeric ``src``: a literal target accepts only the same literal
        value, and two exact classes with equal type descriptors are
        compatible. Everything else defers to the inherited rule.
        """
        if not isinstance(src, NumClass):
            return super().can_assign_from(src)

        if self.literal_value is not None:
            return src.literal_value == self.literal_value

        both_exact = self.is_exact and src.is_exact
        if both_exact and self.type_descr == src.type_descr:
            return True
        return super().can_assign_from(src)
3167
3168
class NumInstance(Object[NumClass]):
    """Instance value of a ``NumClass``."""

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Type a unary operation: ``-``, ``~`` and ``+`` keep this type, ``not`` is bool."""
        if not isinstance(node.op, (ast.USub, ast.Invert, ast.UAdd)):
            assert isinstance(node.op, ast.Not)
            visitor.set_type(node, BOOL_TYPE.instance)
            return
        visitor.set_type(node, self)

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Bind a constant node using its own value."""
        self._bind_constant(node.value, node, visitor)

    def _bind_constant(
        self, value: object, node: ast.expr, visitor: TypeBinder
    ) -> None:
        """Type ``node`` from ``value`` and check it is assignable to this class.

        The value's type is looked up in ``CONSTANT_TYPES`` by its exact
        Python type, defaulting to this instance.
        """
        constant_inst = CONSTANT_TYPES.get(type(value), self)
        visitor.set_type(node, constant_inst)
        visitor.check_can_assign_from(self.klass, constant_inst.klass, node)
3188
3189
class NumExactInstance(NumInstance):
    """Instance value of an exact ``NumClass`` (including literal types)."""

    @property
    def name(self) -> str:
        literal = self.klass.literal_value
        if literal is None:
            return super().name
        return f"Literal[{self.klass.literal_value}]"

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type a binary operation when both operands are exact ints.

        Returns:
            True when the node was typed here (exact float for true division,
            exact int otherwise); False when either operand is not assignable
            to exact int.
        """
        left_klass = visitor.get_type(node.left).klass
        right_klass = visitor.get_type(node.right).klass
        if not (
            INT_EXACT_TYPE.can_assign_from(left_klass)
            and INT_EXACT_TYPE.can_assign_from(right_klass)
        ):
            return False

        result_klass = (
            FLOAT_EXACT_TYPE if isinstance(node.op, ast.Div) else INT_EXACT_TYPE
        )
        visitor.set_type(node, result_klass.instance)
        return True
3211
3212
def parse_param(info: Dict[str, object], idx: int) -> Parameter:
    """Build a ``Parameter`` from one argument entry of a typed signature.

    ``info`` supplies the optional ``"name"`` (default ``""``) and optional
    ``"default"`` keys and is also handed to ``parse_type`` to obtain the
    parameter type. ``idx`` is stored as the parameter's index.
    """
    param_name = info.get("name", "")
    assert isinstance(param_name, str)

    type_ref = ResolvedTypeRef(parse_type(info))
    has_default = "default" in info
    default_value = info.get("default")
    # The trailing False is the final positional argument of Parameter, as in
    # the other Parameter constructions in this file.
    return Parameter(param_name, idx, type_ref, has_default, default_value, False)
3225
3226
def parse_typed_signature(
    sig: Dict[str, object], klass: Optional[Class] = None
) -> Tuple[Tuple[Parameter, ...], Class]:
    """Convert a typed-signature dict into parameters and a return class.

    ``sig["args"]`` must be a list of per-argument dicts and ``sig["return"]``
    a dict describing the return type. When ``klass`` is given, a leading
    ``self`` parameter (index 0) of that class is prepended. Declared
    arguments are always numbered starting at 1.
    """
    arg_infos = sig["args"]
    assert isinstance(arg_infos, list)

    params = []
    if klass is not None:
        params.append(
            Parameter("self", 0, ResolvedTypeRef(klass), False, None, False)
        )
    params.extend(
        parse_param(arg_info, position)
        for position, arg_info in enumerate(arg_infos, start=1)
    )

    return_info = sig["return"]
    assert isinstance(return_info, dict)
    return_type = parse_type(return_info)
    return tuple(params), return_type
3243
3244
def reflect_builtin_function(obj: BuiltinFunctionType) -> BuiltinFunction:
    """Create a ``BuiltinFunction`` describing the builtin function ``obj``.

    If ``obj`` exposes ``__typed_signature__``, the parsed parameters and
    return type are attached; otherwise only the name and module are recorded.
    """
    typed_sig = getattr(obj, "__typed_signature__", None)
    if typed_sig is None:
        return BuiltinFunction(obj.__name__, obj.__module__)

    params, return_type = parse_typed_signature(typed_sig)
    return BuiltinFunction(
        obj.__name__,
        obj.__module__,
        params,
        ResolvedTypeRef(return_type),
    )
3258
3259
def reflect_method_desc(
    obj: MethodDescriptorType, klass: Class
) -> BuiltinMethodDescriptor:
    """Create a ``BuiltinMethodDescriptor`` for method descriptor ``obj`` on ``klass``.

    If ``obj`` exposes ``__typed_signature__``, it is parsed with ``klass`` as
    the type of ``self`` and the resulting parameters and return type are
    attached; otherwise only the name and owning class are recorded.
    """
    typed_sig = getattr(obj, "__typed_signature__", None)
    if typed_sig is None:
        return BuiltinMethodDescriptor(obj.__name__, klass)

    params, return_type = parse_typed_signature(typed_sig, klass)
    return BuiltinMethodDescriptor(
        obj.__name__,
        klass,
        params,
        ResolvedTypeRef(return_type),
    )
3276
3277
def make_type_dict(klass: Class, t: Type[object]) -> Dict[str, Value]:
    """Reflect the method descriptors defined directly on ``t``.

    Every name in ``t.__dict__`` is resolved with ``getattr``; those that
    resolve to a ``MethodDescriptorType`` are converted with
    ``reflect_method_desc`` and returned keyed by attribute name.
    """
    resolved = ((attr_name, getattr(t, attr_name)) for attr_name in t.__dict__)
    return {
        attr_name: reflect_method_desc(attr, klass)
        for attr_name, attr in resolved
        if isinstance(attr, MethodDescriptorType)
    }
3286
3287
def common_sequence_emit_len(
    node: ast.Call, code_gen: Static38CodeGenerator, oparg: int, boxed: bool
) -> None:
    """Emit the shared ``len(seq)`` sequence for a call with one argument.

    Visits the argument, emits ``FAST_LEN`` with ``oparg`` and, when ``boxed``
    is true, follows it with ``PRIMITIVE_BOX``. Raises the code generator's
    syntax error when the call does not have exactly one positional argument.
    """
    if len(node.args) != 1:
        raise code_gen.syntax_error(
            "Can only pass a single argument when checking sequence length", node
        )

    (seq_arg,) = node.args
    code_gen.visit(seq_arg)
    code_gen.emit("FAST_LEN", oparg)
    if not boxed:
        return
    # Oparg 1 means "signed".
    code_gen.emit("PRIMITIVE_BOX", 1)
3300
3301
def common_sequence_emit_jumpif(
    test: AST,
    next: Block,
    is_if_true: bool,
    code_gen: Static38CodeGenerator,
    oparg: int,
) -> None:
    """Emit a conditional jump to ``next`` based on the length of ``test``.

    Visits ``test``, emits ``FAST_LEN`` with ``oparg``, then emits
    ``POP_JUMP_IF_NONZERO`` when ``is_if_true`` is set and
    ``POP_JUMP_IF_ZERO`` otherwise.
    """
    code_gen.visit(test)
    code_gen.emit("FAST_LEN", oparg)
    if is_if_true:
        jump_opcode = "POP_JUMP_IF_NONZERO"
    else:
        jump_opcode = "POP_JUMP_IF_ZERO"
    code_gen.emit(jump_opcode, next)
3312
3313
def common_sequence_emit_forloop(
    node: ast.For, code_gen: Static38CodeGenerator, oparg: int
) -> None:
    """Emit an index-driven ``for`` loop over the sequence ``node.iter``.

    Instead of the iterator protocol, the emitted code keeps a primitive
    int64 counter in a local slot obtained from ``code_gen.new_loopidx()``,
    compares it against ``FAST_LEN`` of the sequence each iteration, fetches
    the element by index, increments the counter and then runs the loop
    target and body.

    ``oparg`` is the FAST_LEN oparg for the sequence; when it equals
    ``FAST_LEN_LIST`` the element is fetched with an unchecked
    ``SEQUENCE_GET``, otherwise the index is boxed and ``BINARY_SUBSCR`` is
    used.
    """
    # Type descriptor used for the loop-counter local.
    descr = ("__static__", "int64")
    start = code_gen.newBlock(f"seq_forloop_start")
    anchor = code_gen.newBlock(f"seq_forloop_anchor")
    after = code_gen.newBlock(f"seq_forloop_after")
    with code_gen.new_loopidx() as loop_idx:
        code_gen.set_lineno(node)
        code_gen.push_loop(FOR_LOOP, start, after)
        code_gen.visit(node.iter)

        # counter = 0
        code_gen.emit("PRIMITIVE_LOAD_CONST", (0, TYPED_INT64))
        code_gen.emit("STORE_LOCAL", (loop_idx, descr))
        code_gen.nextBlock(start)
        code_gen.emit("DUP_TOP")  # used for SEQUENCE_GET
        code_gen.emit("DUP_TOP")  # used for FAST_LEN
        # Loop test: compare len(seq) with the counter; leave via `anchor`
        # when the comparison result is zero.
        code_gen.emit("FAST_LEN", oparg)
        code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
        code_gen.emit("INT_COMPARE_OP", PRIM_OP_GT_INT)
        code_gen.emit("POP_JUMP_IF_ZERO", anchor)
        # Fetch seq[counter].
        code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
        if oparg == FAST_LEN_LIST:
            code_gen.emit("SEQUENCE_GET", SEQ_LIST | SEQ_SUBSCR_UNCHECKED)
        else:
            # todo - we need to implement TUPLE_GET which supports primitive index
            code_gen.emit("PRIMITIVE_BOX", 1)  # 1 is for signed
            code_gen.emit("BINARY_SUBSCR", 2)
        # counter += 1
        code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
        code_gen.emit("PRIMITIVE_LOAD_CONST", (1, TYPED_INT64))
        code_gen.emit("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
        code_gen.emit("STORE_LOCAL", (loop_idx, descr))
        # Store the fetched element into the loop target, then run the body.
        code_gen.visit(node.target)
        code_gen.visit(node.body)
        code_gen.emit("JUMP_ABSOLUTE", start)
        code_gen.nextBlock(anchor)
        code_gen.emit("POP_TOP")  # Pop loop index
        code_gen.emit("POP_TOP")  # Pop list
        code_gen.pop_loop()

        if node.orelse:
            code_gen.visit(node.orelse)
        code_gen.nextBlock(after)
3357
3358
class TupleClass(Class):
    """Static type for ``builtins.tuple`` (exact or inexact)."""

    def __init__(self, is_exact: bool = False) -> None:
        """Create the class; the instance type depends on ``is_exact``."""
        instance_type = TupleExactInstance if is_exact else TupleInstance
        tuple_instance = instance_type(self)
        super().__init__(
            type_name=TypeName("builtins", "tuple"),
            bases=[OBJECT_TYPE],
            instance=tuple_instance,
            is_exact=is_exact,
            pytype=tuple,
        )
3369
3370
class TupleInstance(Object[TupleClass]):
    """Instance value of a ``TupleClass``."""

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_TUPLE`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_TUPLE | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len()`` via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_len(node, code_gen, fast_len, boxed=boxed)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a length-based conditional jump via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_jumpif(test, next, is_if_true, code_gen, fast_len)

    def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a binary op, using the sequence-repeat emission when it applies."""
        if not maybe_emit_sequence_repeat(node, code_gen):
            code_gen.defaultVisit(node)
3393
3394
class TupleExactInstance(TupleInstance):
    """Instance value of the exact tuple class."""

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type ``tuple * n`` as exact tuple when ``n`` is an int or signed primitive int."""
        right_klass = visitor.get_type(node.right).klass
        is_repeat = isinstance(node.op, ast.Mult) and (
            INT_TYPE.can_assign_from(right_klass) or right_klass in SIGNED_CINT_TYPES
        )
        if not is_repeat:
            return super().bind_binop(node, visitor, type_ctx)
        visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
        return True

    def bind_reverse_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type ``n * tuple`` as exact tuple when ``n`` is an int or signed primitive int."""
        left_klass = visitor.get_type(node.left).klass
        is_repeat = isinstance(node.op, ast.Mult) and (
            INT_TYPE.can_assign_from(left_klass) or left_klass in SIGNED_CINT_TYPES
        )
        if not is_repeat:
            return super().bind_reverse_binop(node, visitor, type_ctx)
        visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
        return True

    def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
        """Emit the index-based loop when the target is a plain name."""
        if isinstance(node.target, ast.Name):
            return common_sequence_emit_forloop(node, code_gen, FAST_LEN_TUPLE)
        # We don't yet support `for a, b in my_tuple: ...`
        return super().emit_forloop(node, code_gen)
3424
3425
class SetClass(Class):
    """Static type for ``builtins.set`` (exact or inexact)."""

    def __init__(self, is_exact: bool = False) -> None:
        """Create the class with a ``SetInstance`` as its instance value.

        ``is_exact`` is forwarded to the base class unchanged.
        """
        super().__init__(
            type_name=TypeName("builtins", "set"),
            bases=[OBJECT_TYPE],
            instance=SetInstance(self),
            is_exact=is_exact,
            # The runtime type backing this class is ``set``. This previously
            # passed ``tuple`` (copied from TupleClass), which associated the
            # set type with the wrong Python type.
            pytype=set,
        )
3435
3436
class SetInstance(Object[SetClass]):
    """Instance value of a ``SetClass``."""

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_SET`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_SET | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len(a_set)``: visit the argument, ``FAST_LEN``, optional box.

        Raises the code generator's syntax error unless exactly one
        positional argument was passed.
        """
        if len(node.args) != 1:
            raise code_gen.syntax_error(
                "Can only pass a single argument when checking set length", node
            )
        (set_arg,) = node.args
        code_gen.visit(set_arg)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if not boxed:
            return
        # Oparg 1 means "signed".
        code_gen.emit("PRIMITIVE_BOX", 1)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a jump to ``next`` depending on whether the set's length is zero."""
        code_gen.visit(test)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if is_if_true:
            jump_opcode = "POP_JUMP_IF_NONZERO"
        else:
            jump_opcode = "POP_JUMP_IF_ZERO"
        code_gen.emit(jump_opcode, next)
3460
3461
def maybe_emit_sequence_repeat(
    node: ast.BinOp, code_gen: Static38CodeGenerator
) -> bool:
    """Emit ``SEQUENCE_REPEAT`` for ``seq * n`` / ``n * seq`` when possible.

    Applies only to ``*`` where one operand is assignable to tuple or list
    and the other is a signed primitive int or assignable to int. The oparg
    combines the sequence kind with flags for a primitive count, an inexact
    sequence, an inexact count, and the reversed (``n * seq``) form.

    Returns True when bytecode was emitted, False when the caller must emit
    the operation some other way.
    """
    if not isinstance(node.op, ast.Mult):
        return False

    # Try `left * right` first, then the reversed reading.
    operand_orders = (
        (node.left, node.right, 0),
        (node.right, node.left, SEQ_REPEAT_REVERSED),
    )
    for seq_node, count_node, reversed_flag in operand_orders:
        seq_klass = code_gen.get_type(seq_node).klass
        count_klass = code_gen.get_type(count_node).klass

        if TUPLE_TYPE.can_assign_from(seq_klass):
            flags = SEQ_TUPLE
        elif LIST_TYPE.can_assign_from(seq_klass):
            flags = SEQ_LIST
        else:
            continue

        if count_klass in SIGNED_CINT_TYPES:
            flags |= SEQ_REPEAT_PRIMITIVE_NUM
        elif not INT_TYPE.can_assign_from(count_klass):
            continue

        if not seq_klass.is_exact:
            flags |= SEQ_REPEAT_INEXACT_SEQ
        if not count_klass.is_exact:
            flags |= SEQ_REPEAT_INEXACT_NUM
        flags |= reversed_flag

        # NOTE(review): in the reversed form this visits the right operand
        # before the left one - confirm that is the intended evaluation order.
        code_gen.visit(seq_node)
        code_gen.visit(count_node)
        code_gen.emit("SEQUENCE_REPEAT", flags)
        return True
    return False
3494
3495
class ListAppendMethod(BuiltinMethodDescriptor):
    """Descriptor for ``list.append`` that binds to ``ListAppendBuiltinMethod``."""

    def bind_descr_get(
        self,
        node: ast.Attribute,
        inst: Optional[Object[TClassInv]],
        ctx: TClassInv,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> None:
        """Type the attribute access ``node``.

        Without an instance the node is typed as this descriptor; with one it
        is typed as a ``ListAppendBuiltinMethod`` wrapping this descriptor.
        """
        if inst is not None:
            visitor.set_type(node, ListAppendBuiltinMethod(self, node))
            return
        visitor.set_type(node, self)
3509
3510
class ListAppendBuiltinMethod(BuiltinMethod):
    """Bound ``list.append`` that emits ``LIST_APPEND`` for simple calls."""

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit the call.

        A call with exactly one positional argument and no keywords becomes:
        visit the receiver, visit the argument, ``LIST_APPEND 1``. Any other
        call shape uses the base-class emission.
        """
        is_simple_append = len(node.args) == 1 and not node.keywords
        if not is_simple_append:
            return super().emit_call(node, code_gen)

        code_gen.visit(self.target.value)
        code_gen.visit(node.args[0])
        code_gen.emit("LIST_APPEND", 1)
3520
3521
class ListClass(Class):
    """Static type for ``builtins.list`` (exact or inexact)."""

    def __init__(self, is_exact: bool = False) -> None:
        """Create the class; exact lists get the specialised ``append`` member."""
        instance_type = ListExactInstance if is_exact else ListInstance
        list_instance = instance_type(self)
        super().__init__(
            type_name=TypeName("builtins", "list"),
            bases=[OBJECT_TYPE],
            instance=list_instance,
            is_exact=is_exact,
            pytype=list,
        )
        if not is_exact:
            return
        self.members["append"] = ListAppendMethod("append", self)
3534
3535
class ListInstance(Object[ListClass]):
    """Instance value of a ``ListClass``."""

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_LIST`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_LIST | inexact_bit

    def get_subscr_type(self) -> int:
        """Return the SEQUENCE_GET/SET oparg for this (inexact) list."""
        return SEQ_LIST_INEXACT

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len()`` via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_len(node, code_gen, fast_len, boxed=boxed)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a length-based conditional jump via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_jumpif(test, next, is_if_true, code_gen, fast_len)

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Type ``lst[index]`` as dynamic.

        The base-class binding runs first unless the index is a signed
        primitive int; in every case the node ends up typed ``DYNAMIC``.
        """
        index_is_primitive = type.klass in SIGNED_CINT_TYPES
        if not index_is_primitive:
            super().bind_subscr(node, type, visitor)
        visitor.set_type(node, DYNAMIC)

    def emit_subscr(
        self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a subscript; signed primitive int indices use the sequence opcodes.

        Other index types use the base-class emission. For a primitive
        index, load emits ``SEQUENCE_GET``, store emits ``SEQUENCE_SET`` and
        delete emits ``LIST_DEL``.
        """
        index_klass = code_gen.get_type(node.slice).klass
        if index_klass not in SIGNED_CINT_TYPES:
            return super().emit_subscr(node, aug_flag, code_gen)

        code_gen.update_lineno(node)
        code_gen.visit(node.value)
        code_gen.visit(node.slice)

        ctx = node.ctx
        if isinstance(ctx, ast.Load):
            code_gen.emit("SEQUENCE_GET", self.get_subscr_type())
        elif isinstance(ctx, ast.Store):
            code_gen.emit("SEQUENCE_SET", self.get_subscr_type())
        elif isinstance(ctx, ast.Del):
            code_gen.emit("LIST_DEL")

    def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a binary op, using the sequence-repeat emission when it applies."""
        if not maybe_emit_sequence_repeat(node, code_gen):
            code_gen.defaultVisit(node)
3585
3586
class ListExactInstance(ListInstance):
    """Instance value of the exact list class."""

    def get_subscr_type(self) -> int:
        """Return the SEQUENCE_GET/SET oparg for an exact list."""
        return SEQ_LIST

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type ``list * n`` as exact list when ``n`` is an int or signed primitive int."""
        right_klass = visitor.get_type(node.right).klass
        is_repeat = isinstance(node.op, ast.Mult) and (
            INT_TYPE.can_assign_from(right_klass) or right_klass in SIGNED_CINT_TYPES
        )
        if not is_repeat:
            return super().bind_binop(node, visitor, type_ctx)
        visitor.set_type(node, LIST_EXACT_TYPE.instance)
        return True

    def bind_reverse_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Type ``n * list`` as exact list when ``n`` is an int or signed primitive int."""
        left_klass = visitor.get_type(node.left).klass
        is_repeat = isinstance(node.op, ast.Mult) and (
            INT_TYPE.can_assign_from(left_klass) or left_klass in SIGNED_CINT_TYPES
        )
        if not is_repeat:
            return super().bind_reverse_binop(node, visitor, type_ctx)
        visitor.set_type(node, LIST_EXACT_TYPE.instance)
        return True

    def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
        """Emit the index-based loop when the target is a plain name."""
        if isinstance(node.target, ast.Name):
            return common_sequence_emit_forloop(node, code_gen, FAST_LEN_LIST)
        # We don't yet support `for a, b in my_list: ...`
        return super().emit_forloop(node, code_gen)
3619
3620
class StrClass(Class):
    """Static type for ``builtins.str`` (exact or inexact)."""

    def __init__(self, is_exact: bool = False) -> None:
        """Create the class with a ``StrInstance`` as its instance value."""
        str_name = TypeName("builtins", "str")
        str_instance = StrInstance(self)
        super().__init__(
            type_name=str_name,
            bases=[OBJECT_TYPE],
            instance=str_instance,
            is_exact=is_exact,
            pytype=str,
        )
3630
3631
class StrInstance(Object[StrClass]):
    """Instance value of a ``StrClass``."""

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_STR`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_STR | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len()`` via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_len(node, code_gen, fast_len, boxed=boxed)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a length-based conditional jump via the shared sequence helper."""
        fast_len = self.get_fast_len_type()
        return common_sequence_emit_jumpif(test, next, is_if_true, code_gen, fast_len)
3649
3650
class DictClass(Class):
    """Static type for ``builtins.dict`` (exact or inexact)."""

    def __init__(self, is_exact: bool = False) -> None:
        """Create the class with a ``DictInstance`` as its instance value."""
        dict_name = TypeName("builtins", "dict")
        dict_instance = DictInstance(self)
        super().__init__(
            type_name=dict_name,
            bases=[OBJECT_TYPE],
            instance=dict_instance,
            is_exact=is_exact,
            pytype=dict,
        )
3660
3661
class DictInstance(Object[DictClass]):
    """Instance value of a ``DictClass``."""

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_DICT`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_DICT | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len(a_dict)``: visit the argument, ``FAST_LEN``, optional box.

        Raises the code generator's syntax error unless exactly one
        positional argument was passed.
        """
        if len(node.args) != 1:
            raise code_gen.syntax_error(
                "Can only pass a single argument when checking dict length", node
            )
        (dict_arg,) = node.args
        code_gen.visit(dict_arg)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if not boxed:
            return
        # Oparg 1 means "signed".
        code_gen.emit("PRIMITIVE_BOX", 1)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a jump to ``next`` depending on whether the dict's length is zero."""
        code_gen.visit(test)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if is_if_true:
            jump_opcode = "POP_JUMP_IF_NONZERO"
        else:
            jump_opcode = "POP_JUMP_IF_ZERO"
        code_gen.emit(jump_opcode, next)
3685
3686
# Module-level singleton classes. Several of the class definitions above refer
# to these names from inside their methods, so the names only need to exist
# by the time those methods run.

# Plain classes for function-like and descriptor types from `types`, plus
# `arg` and `slice`.
FUNCTION_TYPE = Class(TypeName("types", "FunctionType"))
METHOD_TYPE = Class(TypeName("types", "MethodType"))
MEMBER_TYPE = Class(TypeName("types", "MemberDescriptorType"))
BUILTIN_METHOD_DESC_TYPE = Class(TypeName("types", "MethodDescriptorType"))
BUILTIN_METHOD_TYPE = Class(TypeName("types", "BuiltinMethodType"))
ARG_TYPE = Class(TypeName("builtins", "arg"))
SLICE_TYPE = Class(TypeName("builtins", "slice"))

# builtin types
# Most builtins come as a pair: the general class and an `*_EXACT_TYPE`
# variant constructed with is_exact=True.
NONE_TYPE = NoneType()
STR_TYPE = StrClass()
STR_EXACT_TYPE = StrClass(is_exact=True)
INT_TYPE = NumClass(TypeName("builtins", "int"), pytype=int)
INT_EXACT_TYPE = NumClass(TypeName("builtins", "int"), pytype=int, is_exact=True)
FLOAT_TYPE = NumClass(TypeName("builtins", "float"), pytype=float)
FLOAT_EXACT_TYPE = NumClass(TypeName("builtins", "float"), pytype=float, is_exact=True)
COMPLEX_TYPE = NumClass(TypeName("builtins", "complex"), pytype=complex)
COMPLEX_EXACT_TYPE = NumClass(
    TypeName("builtins", "complex"), pytype=complex, is_exact=True
)
BYTES_TYPE = Class(TypeName("builtins", "bytes"), [OBJECT_TYPE], pytype=bytes)
BOOL_TYPE = Class(TypeName("builtins", "bool"), [OBJECT_TYPE], pytype=bool)
ELLIPSIS_TYPE = Class(TypeName("builtins", "ellipsis"), [OBJECT_TYPE], pytype=type(...))
DICT_TYPE = DictClass(is_exact=False)
DICT_EXACT_TYPE = DictClass(is_exact=True)
TUPLE_TYPE = TupleClass()
TUPLE_EXACT_TYPE = TupleClass(is_exact=True)
SET_TYPE = SetClass()
SET_EXACT_TYPE = SetClass(is_exact=True)
LIST_TYPE = ListClass()
LIST_EXACT_TYPE = ListClass(is_exact=True)

# Exception hierarchy roots.
BASE_EXCEPTION_TYPE = Class(TypeName("builtins", "BaseException"), pytype=BaseException)
EXCEPTION_TYPE = Class(
    TypeName("builtins", "Exception"),
    bases=[BASE_EXCEPTION_TYPE],
    pytype=Exception,
)
# Decorator types (the decorator classes are defined earlier in the file).
STATIC_METHOD_TYPE = StaticMethodDecorator(
    TypeName("builtins", "staticmethod"),
    bases=[OBJECT_TYPE],
    pytype=staticmethod,
)
FINAL_METHOD_TYPE = TypingFinalDecorator(TypeName("typing", "final"))
ALLOW_WEAKREFS_TYPE = AllowWeakrefsDecorator(TypeName("__static__", "allow_weakrefs"))
DYNAMIC_RETURN_TYPE = DynamicReturnDecorator(TypeName("__static__", "dynamic_return"))
INLINE_TYPE = InlineFunctionDecorator(TypeName("__static__", "inline"))
DONOTCOMPILE_TYPE = DoNotCompileDecorator(TypeName("__static__", "_donotcompile"))

# Pre-built resolved type references for a few common types.
RESOLVED_INT_TYPE = ResolvedTypeRef(INT_TYPE)
RESOLVED_STR_TYPE = ResolvedTypeRef(STR_TYPE)
RESOLVED_NONE_TYPE = ResolvedTypeRef(NONE_TYPE)

# TYPE_TYPE is created earlier in the file; its bases are assigned here.
# NOTE(review): presumably deferred because of definition order - confirm.
TYPE_TYPE.bases = [OBJECT_TYPE]

# Maps the Python type of a constant's value to the instance value used as
# its static type (consulted by NumInstance._bind_constant).
CONSTANT_TYPES: Mapping[Type[object], Value] = {
    str: STR_EXACT_TYPE.instance,
    int: INT_EXACT_TYPE.instance,
    float: FLOAT_EXACT_TYPE.instance,
    complex: COMPLEX_EXACT_TYPE.instance,
    bytes: BYTES_TYPE.instance,
    bool: BOOL_TYPE.instance,
    type(None): NONE_TYPE.instance,
    tuple: TUPLE_EXACT_TYPE.instance,
    type(...): ELLIPSIS_TYPE.instance,
}

NAMED_TUPLE_TYPE = Class(TypeName("typing", "NamedTuple"))
3755
3756
class FinalClass(GenericClass):
    """Generic class for ``Final``; accepts at most one type argument."""

    is_variadic = True

    def make_generic_type(
        self,
        index: Tuple[Class, ...],
        generic_types: GenericTypesDict,
    ) -> Class:
        """Instantiate ``Final[...]``, rejecting more than one type argument.

        Raises ``TypedSyntaxError`` when ``index`` has more than one element.
        """
        if len(index) <= 1:
            return super().make_generic_type(index, generic_types)
        raise TypedSyntaxError(
            f"Final types can only have a single type arg. Given: {str(index)}"
        )

    def inner_type(self) -> Class:
        """Return the wrapped type argument, or ``DYNAMIC_TYPE`` when there is none."""
        return self.type_args[0] if self.type_args else DYNAMIC_TYPE
3777
3778
class UnionTypeName(GenericTypeName):
    """Type name for union instantiations, with Optional-aware helpers."""

    @property
    def opt_type(self) -> Optional[Class]:
        """If we're an Optional (i.e. Union[T, None]), return T, otherwise None."""
        # Assumes well-formed union (no duplicate elements, >1 element)
        if len(self.args) != 2:
            return None
        first, second = self.args[0], self.args[1]
        if first is NONE_TYPE:
            return second
        if second is NONE_TYPE:
            return first
        return None

    @property
    def type_descr(self) -> TypeDescr:
        """Runtime type descriptor: ``T``'s descriptor plus ``"?"`` for Optional[T]."""
        wrapped = self.opt_type
        if wrapped is None:
            # the runtime does not support unions beyond optional, so just
            # fall back to dynamic for runtime purposes
            return DYNAMIC_TYPE.type_descr
        return wrapped.type_descr + ("?",)

    @property
    def friendly_name(self) -> str:
        """Human-readable name; Optional unions render as ``Optional[T]``."""
        wrapped = self.opt_type
        if wrapped is None:
            return super().friendly_name
        return f"Optional[{wrapped.instance.name}]"
3807
3808
class UnionType(GenericClass):
    """Generic class for ``Union``; instantiations are simplified and cached."""

    type_name: UnionTypeName
    # Union is a variadic generic, so we don't give the unbound Union any
    # GenericParameters, and we allow it to accept any number of type args.
    is_variadic = True

    def __init__(
        self,
        type_name: Optional[UnionTypeName] = None,
        type_def: Optional[GenericClass] = None,
        instance_type: Optional[Type[Object[Class]]] = None,
        generic_types: Optional[GenericTypesDict] = None,
    ) -> None:
        """Create an unbound or instantiated union.

        ``type_name`` defaults to the unbound ``typing.Union`` name and
        ``instance_type`` to ``UnionInstance``. ``generic_types`` is stored on
        the instance.
        """
        make_instance = instance_type or UnionInstance
        super().__init__(
            type_name or UnionTypeName("typing", "Union", ()),
            bases=[],
            instance=make_instance(self),
            type_def=type_def,
        )
        self.generic_types = generic_types

    @property
    def opt_type(self) -> Optional[Class]:
        """The ``T`` of ``Union[T, None]``, or None when this is not an Optional."""
        return self.type_name.opt_type

    def issubclass(self, src: Class) -> bool:
        """Check this union against ``src``.

        For a union ``src`` every member of ``src`` must satisfy
        ``self.issubclass``; otherwise at least one member of this union must
        satisfy ``member.issubclass(src)``.
        """
        if isinstance(src, UnionType):
            for src_member in src.type_args:
                if not self.issubclass(src_member):
                    return False
            return True
        for member in self.type_args:
            if member.issubclass(src):
                return True
        return False

    def make_generic_type(
        self,
        index: Tuple[Class, ...],
        generic_types: GenericTypesDict,
    ) -> Class:
        """Return the class for ``Union[index...]``.

        Cached instantiations in ``generic_types`` are reused. Otherwise the
        arguments are flattened and simplified; a single remaining
        non-generic-parameter argument is returned as-is, primitive members
        raise ``TypedSyntaxError``, and Optional-shaped unions are built as
        ``OptionalType``.
        """
        cache = generic_types.get(self)
        if cache is None:
            cache = generic_types[self] = {}
        else:
            cached = cache.get(index)
            if cached is not None:
                return cached

        simplified = self._simplify_args(index)
        if len(simplified) == 1 and not simplified[0].is_generic_parameter:
            return simplified[0]

        union_name = UnionTypeName(
            self.type_name.module, self.type_name.name, simplified
        )
        for member in simplified:
            if isinstance(member, CType):
                raise TypedSyntaxError(
                    f"invalid union type {union_name.friendly_name}; unions cannot include primitive types"
                )

        union_class = OptionalType if union_name.opt_type is not None else type(self)
        concrete = union_class(
            union_name,
            type_def=self,
            generic_types=generic_types,
        )
        cache[index] = concrete
        return concrete

    def _simplify_args(self, args: Sequence[Class]) -> Tuple[Class, ...]:
        """Flatten nested unions and drop members already covered by another."""
        flat = self._flatten_args(args)
        redundant = set()
        for outer_pos, outer in enumerate(flat):
            if outer_pos in redundant:
                continue
            for inner_pos, inner in enumerate(flat):
                if inner_pos == outer_pos:
                    continue
                # TODO this should be is_subtype_of once we split that from can_assign_from
                if outer.can_assign_from(inner):
                    redundant.add(inner_pos)
        return tuple(arg for pos, arg in enumerate(flat) if pos not in redundant)

    def _flatten_args(self, args: Sequence[Class]) -> Sequence[Class]:
        """Return ``args`` with every nested union replaced by its own members."""
        flat = []
        for arg in args:
            if isinstance(arg, UnionType):
                flat += self._flatten_args(arg.type_args)
            else:
                flat.append(arg)
        return flat
3891
3892
class UnionInstance(Object[UnionType]):
    """Instance value of a union: every bind is applied to each member type."""

    def _generic_bind(
        self,
        node: ast.AST,
        callback: typingCallable[[Class], object],
        description: str,
        visitor: TypeBinder,
    ) -> List[object]:
        """Run ``callback`` for each union member and type ``node`` as the union of results.

        After each callback the type recorded for ``node`` is collected; the
        union of those classes becomes the final type of ``node``. Returns the
        callback return values in member order.

        Raises the visitor's syntax error for an unbound Union, and re-raises
        any ``TypedSyntaxError`` from a callback prefixed with this value's name.
        """
        if self.klass.is_generic_type_definition:
            raise visitor.syntax_error(f"cannot {description} unbound Union", node)

        member_result_types: List[Class] = []
        callback_results: List[object] = []
        try:
            for member in self.klass.type_args:
                callback_results.append(callback(member))
                member_result_types.append(visitor.get_type(node).klass)
        except TypedSyntaxError as e:
            raise visitor.syntax_error(f"{self.name}: {e.msg}", node)

        combined = UNION_TYPE.make_generic_type(
            tuple(member_result_types), visitor.symtable.generic_types
        )
        visitor.set_type(node, combined.instance)
        return callback_results

    def bind_attr(
        self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind attribute access against every member of the union."""
        self._generic_bind(
            node,
            lambda member: member.instance.bind_attr(node, visitor, type_ctx),
            "access attribute from",
            visitor,
        )

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind a call against every member; always reports ``NO_EFFECT``."""
        self._generic_bind(
            node,
            lambda member: member.instance.bind_call(node, visitor, type_ctx),
            "call",
            visitor,
        )
        return NO_EFFECT

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Bind a subscript against every member of the union."""
        self._generic_bind(
            node,
            lambda member: member.instance.bind_subscr(node, type, visitor),
            "subscript",
            visitor,
        )

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Bind a unary operation against every member of the union."""
        self._generic_bind(
            node,
            lambda member: member.instance.bind_unaryop(node, visitor, type_ctx),
            "unary op",
            visitor,
        )

    def bind_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind a comparison against every member; True only if all members return true."""
        outcomes = self._generic_bind(
            node,
            lambda member: member.instance.bind_compare(
                node, left, op, right, visitor, type_ctx
            ),
            "compare",
            visitor,
        )
        return all(outcomes)

    def bind_reverse_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind a reversed comparison against every member; True only if all return true."""
        outcomes = self._generic_bind(
            node,
            lambda member: member.instance.bind_reverse_compare(
                node, left, op, right, visitor, type_ctx
            ),
            "compare",
            visitor,
        )
        return all(outcomes)
3992
3993
class OptionalType(UnionType):
    """UnionType for instantiations with [T, None], and to support Optional[T] special form."""

    is_variadic = False

    def __init__(
        self,
        type_name: Optional[UnionTypeName] = None,
        type_def: Optional[GenericClass] = None,
        generic_types: Optional[GenericTypesDict] = None,
    ) -> None:
        """Create an Optional class whose instance value is an ``OptionalInstance``.

        Without ``type_name`` the unbound ``typing.Optional[T]`` name is used.
        """
        resolved_name = type_name or UnionTypeName(
            "typing", "Optional", (GenericParameter("T", 0),)
        )
        super().__init__(
            resolved_name,
            type_def=type_def,
            instance_type=OptionalInstance,
            generic_types=generic_types,
        )

    @property
    def opt_type(self) -> Class:
        """Return the wrapped ``T``; raise ``TypeError`` if the type args are not Optional-shaped."""
        wrapped = self.type_name.opt_type
        if wrapped is not None:
            return wrapped
        params = ", ".join(t.name for t in self.type_args)
        raise TypeError(f"OptionalType has invalid type parameters {params}")

    def make_generic_type(
        self, index: Tuple[Class, ...], generic_types: GenericTypesDict
    ) -> Class:
        """Instantiate ``Optional[T]`` by delegating to the union machinery."""
        assert len(index) == 1
        is_concrete = not index[0].is_generic_parameter
        if is_concrete:
            # Optional[T] is syntactic sugar for Union[T, None]
            index = index + (NONE_TYPE,)
        return super().make_generic_type(index, generic_types)
4029
4030
class OptionalInstance(UnionInstance):
    """Only exists for typing purposes (so we know .klass is OptionalType).

    Defines no behavior of its own; everything is inherited from
    ``UnionInstance``.
    """

    # Narrows the declared type of the inherited ``klass`` attribute.
    klass: OptionalType
4035
4036
class ArrayInstance(Object["ArrayClass"]):
    """Instance value of an ``ArrayClass`` instantiation."""

    def _seq_type(self) -> int:
        """Return the SEQ_ARRAY_* oparg matching the element type's size and signedness.

        Raises ``SyntaxError`` when the element type is not a ``CIntType`` or
        its size is not one of 0-3.
        """
        element = self.klass.index
        if not isinstance(element, CIntType):
            # should never happen
            raise SyntaxError(f"Invalid Array type: {element}")

        # size -> (signed oparg, unsigned oparg)
        by_size = {
            0: (SEQ_ARRAY_INT8, SEQ_ARRAY_UINT8),
            1: (SEQ_ARRAY_INT16, SEQ_ARRAY_UINT16),
            2: (SEQ_ARRAY_INT32, SEQ_ARRAY_UINT32),
            3: (SEQ_ARRAY_INT64, SEQ_ARRAY_UINT64),
        }
        size = element.size
        if size not in by_size:
            raise SyntaxError(f"Invalid Array size: {size}")
        signed_oparg, unsigned_oparg = by_size[size]
        return signed_oparg if element.signed else unsigned_oparg

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Type ``arr[x]``: slicing keeps the array type, indexing yields the element type."""
        if type == SLICE_TYPE.instance:
            # Slicing preserves type
            result = self
        else:
            result = self.klass.index.instance
        visitor.set_type(node, result)

    def emit_subscr(
        self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a subscript, using SEQUENCE_GET/SET for integer indices.

        Deletion and non-integer indices use the base-class emission. A
        Python-int index is unboxed to int64 first. For augmented assignment
        the operands are duplicated before the load.
        """
        index_klass = code_gen.get_type(node.slice).klass
        ctx = node.ctx
        python_int_index = INT_TYPE.can_assign_from(index_klass)
        primitive_int_index = isinstance(index_klass, CIntType)

        # ARRAY_{GET,SET} support only integer indices and don't support del;
        # otherwise defer to the usual bytecode
        if isinstance(ctx, ast.Del) or (
            not python_int_index and not primitive_int_index
        ):
            return super().emit_subscr(node, aug_flag, code_gen)

        code_gen.update_lineno(node)
        code_gen.visit(node.value)
        code_gen.visit(node.slice)

        if python_int_index:
            # If the index is not a primitive, unbox its value to an int64, our implementation of
            # SEQUENCE_{GET/SET} expects the index to be a primitive int.
            code_gen.emit("PRIMITIVE_UNBOX", INT64_TYPE.instance.as_oparg())

        if isinstance(ctx, ast.Store) and not aug_flag:
            code_gen.emit("SEQUENCE_SET", self._seq_type())
        elif aug_flag or isinstance(ctx, ast.Load):
            if aug_flag:
                code_gen.emit("DUP_TOP_TWO")
            code_gen.emit("SEQUENCE_GET", self._seq_type())

    def emit_store_subscr(
        self, node: ast.Subscript, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit the store half of a subscript: ``ROT_THREE`` then ``SEQUENCE_SET``."""
        code_gen.emit("ROT_THREE")
        code_gen.emit("SEQUENCE_SET", self._seq_type())

    def __repr__(self) -> str:
        """Render as ``<type name>[<repr of element type name>]``."""
        klass = self.klass
        return f"{klass.type_name.name}[{klass.index.name!r}]"

    def get_fast_len_type(self) -> int:
        """Return the FAST_LEN oparg: ``FAST_LEN_ARRAY`` with bit 4 set when inexact."""
        inexact_bit = int(not self.klass.is_exact) << 4
        return FAST_LEN_ARRAY | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``len(an_array)``: visit the argument, ``FAST_LEN``, optional box.

        Raises the code generator's syntax error unless exactly one
        positional argument was passed.
        """
        if len(node.args) != 1:
            raise code_gen.syntax_error(
                "Can only pass a single argument when checking array length", node
            )
        (array_arg,) = node.args
        code_gen.visit(array_arg)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if not boxed:
            return
        # Oparg 1 means "signed".
        code_gen.emit("PRIMITIVE_BOX", 1)
4117
4118
class ArrayClass(GenericClass):
    """Generic class for arrays whose element type must be in ``ALLOWED_ARRAY_TYPES``."""

    def __init__(
        self,
        name: GenericTypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[Object[Class]] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        type_def: Optional[GenericClass] = None,
        is_exact: bool = False,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Create the class, defaulting bases to ``[OBJECT_TYPE]`` and the
        instance to a new ``ArrayInstance`` when they are not supplied."""
        fallback_bases: List[Class] = [OBJECT_TYPE]
        fallback_instance: Object[Class] = ArrayInstance(self)
        chosen_bases = bases or fallback_bases
        chosen_instance = instance or fallback_instance
        super().__init__(
            name,
            chosen_bases,
            chosen_instance,
            klass,
            members,
            type_def,
            is_exact,
            pytype,
        )

    @property
    def index(self) -> Class:
        """The element type (first type argument)."""
        return self.type_args[0]

    def make_generic_type(
        self, index: Tuple[Class, ...], generic_types: GenericTypesDict
    ) -> Class:
        """Instantiate the array type after validating every type argument.

        Raises ``TypedSyntaxError`` for the first argument that is not in
        ``ALLOWED_ARRAY_TYPES``.
        """
        for element_type in index:
            if element_type in ALLOWED_ARRAY_TYPES:
                continue
            raise TypedSyntaxError(
                f"Invalid {self.gen_name.name} element type: {element_type.instance.name}"
            )
        return super().make_generic_type(index, generic_types)
4157
4158
class VectorClass(ArrayClass):
    """Array class that additionally exposes a typed ``append`` method."""

    def __init__(
        self,
        name: GenericTypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[Object[Class]] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        type_def: Optional[GenericClass] = None,
        is_exact: bool = False,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Create the class and register ``append(self, v)`` in its members.

        All arguments are forwarded unchanged to ``ArrayClass.__init__``.
        """
        super().__init__(
            name,
            bases,
            instance,
            klass,
            members,
            type_def,
            is_exact,
            pytype,
        )
        self.members["append"] = BuiltinMethodDescriptor(
            "append",
            self,
            (
                Parameter("self", 0, ResolvedTypeRef(self), False, None, False),
                Parameter(
                    "v",
                    # Positional index 1: `self` occupies index 0, matching
                    # the numbering parse_typed_signature uses for method
                    # parameters. (This was 0, duplicating `self`'s index.)
                    1,
                    ResolvedTypeRef(VECTOR_TYPE_PARAM),
                    False,
                    None,
                    False,
                ),
            ),
        )
4196
4197
# Registry passed as ``generic_types`` to ``make_generic_type`` when
# instantiating built-in generics (see ``parse_type``).
BUILTIN_GENERICS: Dict[Class, Dict[GenericTypeIndex, Class]] = {}
UNION_TYPE = UnionType()
OPTIONAL_TYPE = OptionalType()
FINAL_TYPE = FinalClass(GenericTypeName("typing", "Final", ()))
# Generic name ``__static__.chkdict[K, V]`` shared by the checked-dict classes.
CHECKED_DICT_TYPE_NAME = GenericTypeName(
    "__static__", "chkdict", (GenericParameter("K", 0), GenericParameter("V", 1))
)
4205
4206
class CheckedDict(GenericClass):
    """Generic class whose default instance type is ``CheckedDictInstance``."""

    def __init__(
        self,
        name: GenericTypeName,
        bases: Optional[List[Class]] = None,
        instance: Optional[Object[Class]] = None,
        klass: Optional[Class] = None,
        members: Optional[Dict[str, Value]] = None,
        type_def: Optional[GenericClass] = None,
        is_exact: bool = False,
        pytype: Optional[Type[object]] = None,
    ) -> None:
        """Initialize the class; a ``CheckedDictInstance`` is created only
        when ``instance`` is ``None``."""
        effective_instance = (
            CheckedDictInstance(self) if instance is None else instance
        )
        super().__init__(
            name,
            bases,
            effective_instance,
            klass,
            members,
            type_def,
            is_exact,
            pytype,
        )
4231
4232
class CheckedDictInstance(Object[CheckedDict]):
    """Instance-level binding and code generation for ``CheckedDict``."""

    def bind_subscr(
        self, node: ast.Subscript, type: Value, visitor: TypeBinder
    ) -> None:
        """Visit the subscript slice with the first generic argument's instance
        as type context and type the expression as the second's instance."""
        type_args = self.klass.gen_name.args
        visitor.visit(node.slice, type_args[0].instance)
        visitor.set_type(node, type_args[1].instance)

    def emit_subscr(
        self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit a direct ``__getitem__`` / ``__setitem__`` method invocation
        for load / store contexts; defer to the default visitor otherwise."""
        ctx = node.ctx
        if isinstance(ctx, ast.Load):
            code_gen.visit(node.value)
            code_gen.visit(node.slice)
            getitem_descr = self.klass.type_descr + ("__getitem__",)
            code_gen.emit_invoke_method(getitem_descr, 1)
            return

        if isinstance(ctx, ast.Store):
            # Two ROT_TWOs interleaved with the pushes rearrange the stack
            # before the two-argument __setitem__ invocation.
            code_gen.visit(node.value)
            code_gen.emit("ROT_TWO")
            code_gen.visit(node.slice)
            code_gen.emit("ROT_TWO")
            setitem_descr = self.klass.type_descr + ("__setitem__",)
            code_gen.emit_invoke_method(setitem_descr, 2)
            # Drop the value left behind by the invocation.
            code_gen.emit("POP_TOP")
            return

        code_gen.defaultVisit(node, aug_flag)

    def get_fast_len_type(self) -> int:
        """Return the ``FAST_LEN`` oparg: ``FAST_LEN_DICT`` with bit 4 set
        when the class is not exact."""
        inexact_bit = 0 if self.klass.is_exact else 1 << 4
        return FAST_LEN_DICT | inexact_bit

    def emit_len(
        self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
    ) -> None:
        """Emit ``FAST_LEN`` for a single-argument ``len()`` call, boxing the
        result when ``boxed`` is true.

        Raises a syntax error (via ``code_gen.syntax_error``) unless exactly
        one positional argument is given.
        """
        if len(node.args) != 1:
            raise code_gen.syntax_error(
                "Can only pass a single argument when checking dict length", node
            )
        (target,) = node.args
        code_gen.visit(target)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        if not boxed:
            return
        # The box oparg is the integer form of the "signed" flag (True).
        box_oparg = int(True)
        code_gen.emit("PRIMITIVE_BOX", box_oparg)

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Branch on the value's length: jump to ``next`` on non-zero when
        ``is_if_true``, otherwise on zero."""
        code_gen.visit(test)
        code_gen.emit("FAST_LEN", self.get_fast_len_type())
        jump_op = "POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO"
        code_gen.emit(jump_op, next)
4283
4284
class CastFunction(Object[Class]):
    """Binding and code generation for a two-argument ``cast(type, value)`` call."""

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Type the call as an instance of the annotation given as first argument.

        Raises a syntax error when the call does not have exactly two
        positional arguments or when the first one cannot be resolved as a
        type annotation.
        """
        call_args = node.args
        if len(call_args) != 2:
            raise visitor.syntax_error(
                "cast requires two parameters: type and value", node
            )

        for call_arg in call_args:
            visitor.visit(call_arg)
        self.check_args_for_primitives(node, visitor)

        target_class = visitor.cur_mod.resolve_annotation(call_args[0])
        if target_class is None:
            raise visitor.syntax_error("cast to unknown type", node)

        visitor.set_type(node, target_class.instance)
        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Evaluate the value argument and emit ``CAST`` with the type
        descriptor of the class recorded for the call node."""
        value_arg = node.args[1]
        code_gen.visit(value_arg)
        target_descr = code_gen.get_type(node).klass.type_descr
        code_gen.emit("CAST", target_descr)
4308
4309
# Maps the name of each sized primitive integer type to its ``TYPED_*``
# constant imported from ``_static``.
prim_name_to_type: Mapping[str, int] = {
    "int8": TYPED_INT8,
    "int16": TYPED_INT16,
    "int32": TYPED_INT32,
    "int64": TYPED_INT64,
    "uint8": TYPED_UINT8,
    "uint16": TYPED_UINT16,
    "uint32": TYPED_UINT32,
    "uint64": TYPED_UINT64,
}
4320
4321
class CInstance(Value, Generic[TClass]):
    """Shared behaviour for instances of primitive types.

    Subclasses supply ``get_op_id`` to translate AST operator nodes into
    primitive opcode ids; this base class provides the common binary-op and
    augmented-assignment code generation plus error-message helpers.
    """

    # Human-readable verb for each binary operator, used in error messages.
    _op_name: Dict[Type[ast.operator], str] = {
        ast.Add: "add",
        ast.Sub: "subtract",
        ast.Mult: "multiply",
        ast.FloorDiv: "divide",
        ast.Div: "divide",
        ast.Mod: "modulus",
        ast.LShift: "left shift",
        ast.RShift: "right shift",
        ast.BitOr: "bitwise or",
        ast.BitXor: "xor",
        ast.BitAnd: "bitwise and",
    }

    @property
    def name(self) -> str:
        """The instance name of the underlying class."""
        return self.klass.instance_name

    def binop_error(self, left: Value, right: Value, op: ast.operator) -> str:
        """Build the "cannot <op> <left> and <right>" error message.

        Operators that are missing from ``_op_name`` (for example ``ast.Pow``
        or ``ast.MatMult``) fall back to the AST node's class name so that
        building the message never raises ``KeyError`` and hides the real
        type error from the user.
        """
        op_type = type(op)
        op_name = self._op_name.get(op_type, op_type.__name__)
        return f"cannot {op_name} {left.name} and {right.name}"

    def bind_reverse_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind ``left <op> self`` by visiting the left operand with ``self``
        as its type context and typing the whole expression as ``self``.

        Raises a syntax error carrying the ``binop_error`` message when the
        left operand cannot be bound in that context.
        """
        try:
            visitor.visit(node.left, self)
        except TypedSyntaxError:
            raise visitor.syntax_error(
                self.binop_error(visitor.get_type(node.left), self, node.op), node
            )
        visitor.set_type(node, self)
        return True

    def get_op_id(self, op: AST) -> int:
        """Return the primitive opcode id for ``op``; subclasses must override."""
        raise NotImplementedError("Must be implemented in the subclass")

    def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
        """Emit both operands followed by ``PRIMITIVE_BINARY_OP``.

        Each operand whose type differs from the type recorded for the whole
        expression is followed by a conversion emitted through the result
        type's ``emit_convert``.
        """
        code_gen.update_lineno(node)
        common_type = code_gen.get_type(node)
        code_gen.visit(node.left)
        ltype = code_gen.get_type(node.left)
        if ltype != common_type:
            common_type.emit_convert(ltype, code_gen)
        code_gen.visit(node.right)
        rtype = code_gen.get_type(node.right)
        if rtype != common_type:
            common_type.emit_convert(rtype, code_gen)
        op = self.get_op_id(node.op)
        code_gen.emit("PRIMITIVE_BINARY_OP", op)

    def emit_augassign(
        self, node: ast.AugAssign, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit ``target <op>= value`` as load, value, primitive op, store."""
        code_gen.set_lineno(node)
        # The wrapped target is visited twice: once to load, once to store.
        aug_node = wrap_aug(node.target)
        code_gen.visit(aug_node, "load")
        code_gen.visit(node.value)
        code_gen.emit("PRIMITIVE_BINARY_OP", self.get_op_id(node.op))
        code_gen.visit(aug_node, "store")
4383
4384
class CIntInstance(CInstance["CIntType"]):
    """Instance of a primitive integer type (including ``cbool``).

    ``constant`` is the type's ``TYPED_*`` constant (also used as the oparg),
    ``size`` is the size index (the bit width is ``8 << size``) and ``signed``
    tells whether the type is signed.
    """

    def __init__(self, klass: CIntType, constant: int, size: int, signed: bool) -> None:
        super().__init__(klass)
        self.constant = constant
        self.size = size
        self.signed = signed

    def as_oparg(self) -> int:
        """Return the oparg identifying this primitive type."""
        return self.constant

    # Operator -> opcode id for signed operands.
    _int_binary_opcode_signed: Mapping[Type[ast.AST], int] = {
        ast.Lt: PRIM_OP_LT_INT,
        ast.Gt: PRIM_OP_GT_INT,
        ast.Eq: PRIM_OP_EQ_INT,
        ast.NotEq: PRIM_OP_NE_INT,
        ast.LtE: PRIM_OP_LE_INT,
        ast.GtE: PRIM_OP_GE_INT,
        ast.Add: PRIM_OP_ADD_INT,
        ast.Sub: PRIM_OP_SUB_INT,
        ast.Mult: PRIM_OP_MUL_INT,
        ast.FloorDiv: PRIM_OP_DIV_INT,
        ast.Div: PRIM_OP_DIV_INT,
        ast.Mod: PRIM_OP_MOD_INT,
        ast.LShift: PRIM_OP_LSHIFT_INT,
        ast.RShift: PRIM_OP_RSHIFT_INT,
        ast.BitOr: PRIM_OP_OR_INT,
        ast.BitXor: PRIM_OP_XOR_INT,
        ast.BitAnd: PRIM_OP_AND_INT,
    }

    # Operator -> opcode id for unsigned operands.  ``ast.RShift`` used to be
    # listed twice; only the last (unsigned) entry ever took effect, so the
    # dead signed entry has been dropped.
    _int_binary_opcode_unsigned: Mapping[Type[ast.AST], int] = {
        ast.Lt: PRIM_OP_LT_UN_INT,
        ast.Gt: PRIM_OP_GT_UN_INT,
        ast.Eq: PRIM_OP_EQ_INT,
        ast.NotEq: PRIM_OP_NE_INT,
        ast.LtE: PRIM_OP_LE_UN_INT,
        ast.GtE: PRIM_OP_GE_UN_INT,
        ast.Add: PRIM_OP_ADD_INT,
        ast.Sub: PRIM_OP_SUB_INT,
        ast.Mult: PRIM_OP_MUL_INT,
        ast.FloorDiv: PRIM_OP_DIV_UN_INT,
        ast.Div: PRIM_OP_DIV_UN_INT,
        ast.Mod: PRIM_OP_MOD_UN_INT,
        ast.LShift: PRIM_OP_LSHIFT_INT,
        ast.RShift: PRIM_OP_RSHIFT_UN_INT,
        ast.BitOr: PRIM_OP_OR_INT,
        ast.BitXor: PRIM_OP_XOR_INT,
        ast.BitAnd: PRIM_OP_AND_INT,
    }

    def get_op_id(self, op: AST) -> int:
        """Look up the opcode id for ``op`` in the signed or unsigned table.

        Raises ``KeyError`` for operators absent from the selected table.
        """
        table = (
            self._int_binary_opcode_signed
            if self.signed
            else self._int_binary_opcode_unsigned
        )
        return table[type(op)]

    def validate_mixed_math(self, other: Value) -> Optional[Value]:
        """Return the type in which ``self`` and ``other`` can be combined.

        Returns ``None`` when no such type exists: either side is ``cbool``,
        ``other`` is not a primitive integer, or the signs differ and no
        signed type up to 64 bits is wide enough.
        """
        if self.constant == TYPED_BOOL:
            return None
        if other is self:
            return self
        elif isinstance(other, CIntInstance):
            if other.constant == TYPED_BOOL:
                return None
            if self.signed == other.signed:
                # signs match, we can just treat this as a comparison of the larger type
                if self.size > other.size:
                    return self
                else:
                    return other
            else:
                # The unsigned side needs one extra size step to fit in a
                # signed type.
                new_size = max(
                    self.size if self.signed else self.size + 1,
                    other.size if other.signed else other.size + 1,
                )

                if new_size <= TYPED_INT_64BIT:
                    # signs don't match, but we can promote to the next highest data type
                    return SIGNED_CINT_TYPES[new_size].instance

        return None

    def bind_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind ``self <op> right``.

        The comparison type is recorded on ``op`` and the expression itself is
        typed as ``cbool``.  Raises a syntax error when the operands cannot be
        compared.
        """
        rtype = visitor.get_type(right)
        if rtype != self and not isinstance(rtype, CIntInstance):
            try:
                visitor.visit(right, self)
            except TypedSyntaxError:
                # Report a better error message than the generic can't be used
                raise visitor.syntax_error(
                    f"can't compare {self.name} to {visitor.get_type(right).name}",
                    node,
                )

        compare_type = self.validate_mixed_math(visitor.get_type(right))
        if compare_type is None:
            raise visitor.syntax_error(
                f"can't compare {self.name} to {visitor.get_type(right).name}", node
            )

        visitor.set_type(op, compare_type)
        visitor.set_type(node, CBOOL_TYPE.instance)
        return True

    def bind_reverse_compare(
        self,
        node: ast.Compare,
        left: expr,
        op: cmpop,
        right: expr,
        visitor: TypeBinder,
        type_ctx: Optional[Class],
    ) -> bool:
        """Bind ``left <op> self`` when ``left`` is not a primitive integer.

        Returns ``False`` (leaving the node unbound) when ``left`` already is
        a primitive integer.  Raises a syntax error when the operands cannot
        be compared.
        """
        if not isinstance(visitor.get_type(left), CIntInstance):
            try:
                visitor.visit(left, self)
            except TypedSyntaxError:
                # Report a better error message than the generic can't be used.
                # The operand that failed to bind is the left one, so name it
                # (this matches the message raised just below).
                raise visitor.syntax_error(
                    f"can't compare {visitor.get_type(left).name} to {self.name}", node
                )

            compare_type = self.validate_mixed_math(visitor.get_type(left))
            if compare_type is None:
                raise visitor.syntax_error(
                    f"can't compare {visitor.get_type(left).name} to {self.name}", node
                )

            visitor.set_type(op, compare_type)
            visitor.set_type(node, CBOOL_TYPE.instance)
            return True

        return False

    def emit_compare(self, op: cmpop, code_gen: Static38CodeGenerator) -> None:
        """Emit ``INT_COMPARE_OP`` with the opcode id for ``op``."""
        code_gen.emit("INT_COMPARE_OP", self.get_op_id(op))

    def emit_augname(
        self, node: AugName, code_gen: Static38CodeGenerator, mode: str
    ) -> None:
        """Emit a typed local load or store for an augmented-assignment
        target; any ``mode`` other than "load"/"store" emits nothing."""
        if mode == "load":
            code_gen.emit("LOAD_LOCAL", (node.id, self.klass.type_descr))
        elif mode == "store":
            code_gen.emit("STORE_LOCAL", (node.id, self.klass.type_descr))

    def validate_int(self, val: object, node: ast.AST, visitor: TypeBinder) -> None:
        """Check that ``val`` is an int representable in this type.

        Raises a syntax error if ``val`` is not an ``int`` or lies outside the
        range implied by the type's width and signedness.
        """
        if not isinstance(val, int):
            raise visitor.syntax_error(
                f"{type(val).__name__} cannot be used in a context where an int is expected",
                node,
            )

        bits = 8 << self.size
        if self.signed:
            low = -(1 << (bits - 1))
            high = (1 << (bits - 1)) - 1
        else:
            low = 0
            high = (1 << bits) - 1

        if not low <= val <= high:
            # We set a type here so that when call handles the syntax error and tries to
            # improve the error message to "positional argument type mismatch" it can
            # successfully get the type
            visitor.set_type(node, INT_TYPE.instance)
            raise visitor.syntax_error(
                f"constant {val} is outside of the range {low} to {high} for {self.name}",
                node,
            )

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Validate the constant's value and type the node as ``self``."""
        self.validate_int(node.value, node, visitor)
        visitor.set_type(node, self)

    def emit_constant(
        self, node: ast.Constant, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit ``PRIMITIVE_LOAD_CONST``; ``cbool`` values are coerced to ``bool``."""
        val = node.value
        if self.constant == TYPED_BOOL:
            val = bool(val)
        code_gen.emit("PRIMITIVE_LOAD_CONST", (val, self.as_oparg()))

    def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
        """Emit a typed local load or store for ``node``.

        Raises:
            TypedSyntaxError: for any context other than load or store.
        """
        if isinstance(node.ctx, ast.Load):
            code_gen.emit("LOAD_LOCAL", (node.id, self.klass.type_descr))
        elif isinstance(node.ctx, ast.Store):
            code_gen.emit("STORE_LOCAL", (node.id, self.klass.type_descr))
        else:
            raise TypedSyntaxError("unsupported op")

    def emit_jumpif(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Evaluate ``test`` and pop-jump to ``next`` on non-zero (when
        ``is_if_true``) or on zero."""
        code_gen.visit(test)
        code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)

    def emit_jumpif_pop(
        self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
    ) -> None:
        """Evaluate ``test`` and emit the jump-or-pop variant of the branch."""
        code_gen.visit(test)
        code_gen.emit(
            "JUMP_IF_NONZERO_OR_POP" if is_if_true else "JUMP_IF_ZERO_OR_POP", next
        )

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind ``self <op> right`` and record the result type on ``node``.

        Raises:
            TypedSyntaxError: if ``self`` is ``cbool`` or the operand types
                cannot be combined.
        """
        if self.constant == TYPED_BOOL:
            # Fall back to the operator's class name so that operators missing
            # from ``_op_name`` still yield a type error, not a KeyError.
            op_type = type(node.op)
            op_name = self._op_name.get(op_type, op_type.__name__)
            raise TypedSyntaxError(
                f"cbool is not a valid operand type for {op_name}"
            )
        rinst = visitor.get_type(node.right)
        if rinst != self:
            # An exact list/tuple on the right types the expression as that
            # exact list/tuple.
            if rinst.klass == LIST_EXACT_TYPE:
                visitor.set_type(node, LIST_EXACT_TYPE.instance)
                return True
            if rinst.klass == TUPLE_EXACT_TYPE:
                visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
                return True

            try:
                visitor.visit(node.right, type_ctx or INT64_VALUE)
            except TypedSyntaxError:
                # Report a better error message than the generic can't be used
                raise visitor.syntax_error(
                    self.binop_error(self, visitor.get_type(node.right), node.op),
                    node,
                )

        if type_ctx is None:
            type_ctx = self.validate_mixed_math(visitor.get_type(node.right))
            if type_ctx is None:
                raise visitor.syntax_error(
                    self.binop_error(self, visitor.get_type(node.right), node.op),
                    node,
                )

        visitor.set_type(node, type_ctx)
        return True

    def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
        """Evaluate ``node`` and box it with this type's oparg.

        Raises:
            RuntimeError: if ``node`` is not typed as a primitive integer.
        """
        code_gen.visit(node)
        type = code_gen.get_type(node)
        if isinstance(type, CIntInstance):
            code_gen.emit("PRIMITIVE_BOX", self.as_oparg())
        else:
            raise RuntimeError("unsupported box type: " + type.name)

    def emit_unbox(self, node: expr, code_gen: Static38CodeGenerator) -> None:
        """Emit code producing the primitive value of ``node``.

        Final literals and classes carrying a literal value are loaded
        directly as primitive constants; everything else is evaluated and
        unboxed.
        """
        final_val = code_gen.get_final_literal(node)
        if final_val is not None:
            return self.emit_constant(final_val, code_gen)
        typ = code_gen.get_type(node).klass
        if isinstance(typ, NumClass) and typ.literal_value is not None:
            code_gen.emit("PRIMITIVE_LOAD_CONST", (typ.literal_value, self.as_oparg()))
            return
        code_gen.visit(node)
        code_gen.emit("PRIMITIVE_UNBOX", self.as_oparg())

    def bind_unaryop(
        self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> None:
        """Type ``-x``, ``~x`` and ``+x`` as ``self`` and ``not x`` as bool."""
        if isinstance(node.op, (ast.USub, ast.Invert, ast.UAdd)):
            visitor.set_type(node, self)
        else:
            assert isinstance(node.op, ast.Not)
            visitor.set_type(node, BOOL_TYPE.instance)

    def emit_unaryop(self, node: ast.UnaryOp, code_gen: Static38CodeGenerator) -> None:
        """Emit a primitive unary operation.

        Raises:
            NotImplementedError: for ``not``, which has no primitive emission
                here.
        """
        code_gen.update_lineno(node)
        if isinstance(node.op, ast.USub):
            code_gen.visit(node.operand)
            code_gen.emit("PRIMITIVE_UNARY_OP", PRIM_OP_NEG_INT)
        elif isinstance(node.op, ast.Invert):
            code_gen.visit(node.operand)
            code_gen.emit("PRIMITIVE_UNARY_OP", PRIM_OP_INV_INT)
        elif isinstance(node.op, ast.UAdd):
            # Unary plus is a no-op: just evaluate the operand.
            code_gen.visit(node.operand)
        elif isinstance(node.op, ast.Not):
            raise NotImplementedError()

    def emit_convert(self, to_type: Value, code_gen: Static38CodeGenerator) -> None:
        """Emit ``CONVERT_PRIMITIVE`` between ``to_type`` and ``self``.

        NOTE(review): going by the nibble comment below and the call sites in
        this file (which pass the operand's current type), ``to_type`` appears
        to be the *source* type and ``self`` the destination, despite the
        parameter's name.
        """
        assert isinstance(to_type, CIntInstance)
        # Lower nibble is type-from, higher nibble is type-to.
        code_gen.emit("CONVERT_PRIMITIVE", (self.as_oparg() << 4) | to_type.as_oparg())
4680
4681
class CIntType(CType):
    """Class object for a primitive integer type described by a ``TYPED_*``
    constant."""

    instance: CIntInstance

    def __init__(self, constant: int, name_override: Optional[str] = None) -> None:
        """Decode size and signedness from ``constant`` and register the type
        under ``__static__.<name>``.

        The name defaults to ``[u]int<bits>`` unless ``name_override`` is given.
        """
        self.constant = constant
        # See TYPED_SIZE macro
        self.size: int = (constant >> 1) & 3
        self.signed: bool = bool(constant & 1)
        if name_override is not None:
            type_name = name_override
        else:
            sign_prefix = "" if self.signed else "u"
            type_name = f"{sign_prefix}int{8 << self.size}"
        int_instance = CIntInstance(self, self.constant, self.size, self.signed)
        super().__init__(TypeName("__static__", type_name), [], int_instance)

    def can_assign_from(self, src: Class) -> bool:
        """Allow assignment from primitive ints that fit, otherwise defer to
        the parent class's rule."""
        if isinstance(src, CIntType) and (
            # same or larger size with the same sign
            (src.size <= self.size and src.signed == self.signed)
            # strictly larger signed destination
            or (src.size < self.size and self.signed)
        ):
            return True

        return super().can_assign_from(src)

    def bind_call(
        self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> NarrowingEffect:
        """Bind ``<type>(arg)``: the call is typed as this type's instance.

        The argument is first visited with this type as context; if that
        fails it is visited without context, and the original error is
        re-raised unless the argument turns out to be ``int``, dynamic or
        ``object``.
        """
        arg_count = len(node.args)
        if arg_count != 1:
            raise visitor.syntax_error(
                f"{self.name} requires a single argument ({len(node.args)} given)", node
            )

        visitor.set_type(node, self.instance)
        only_arg = node.args[0]
        try:
            visitor.visit(only_arg, self.instance)
        except TypedSyntaxError:
            visitor.visit(only_arg)
            found = visitor.get_type(only_arg)
            unboxable = (
                found is INT_TYPE.instance or found is DYNAMIC or found is OBJECT
            )
            if not unboxable:
                raise

        return NO_EFFECT

    def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
        """Emit ``<type>(arg)``: convert between primitive ints, or unbox a
        non-primitive argument."""
        if len(node.args) != 1:
            raise code_gen.syntax_error(
                f"{self.name} requires a single argument ({len(node.args)} given)", node
            )

        only_arg = node.args[0]
        found = code_gen.get_type(only_arg)
        if not isinstance(found, CIntInstance):
            self.instance.emit_unbox(only_arg, code_gen)
            return

        code_gen.visit(only_arg)
        if found != self.instance:
            self.instance.emit_convert(found, code_gen)
4751
4752
class CDoubleInstance(CInstance["CDoubleType"]):
    """Instance of the primitive ``double`` type."""

    # Binary operators supported on doubles and their opcode ids.
    _double_binary_opcode_signed: Mapping[Type[ast.AST], int] = {
        ast.Add: PRIM_OP_ADD_DBL,
        ast.Sub: PRIM_OP_SUB_DBL,
        ast.Mult: PRIM_OP_MUL_DBL,
        ast.Div: PRIM_OP_DIV_DBL,
    }

    def get_op_id(self, op: AST) -> int:
        """Return the opcode id for ``op`` (``KeyError`` if unsupported)."""
        op_kind = type(op)
        return self._double_binary_opcode_signed[op_kind]

    def as_oparg(self) -> int:
        """Return the oparg identifying the double type."""
        return TYPED_DOUBLE

    def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
        """Emit a typed local load or store; any other context raises
        ``TypedSyntaxError``."""
        ctx = node.ctx
        if isinstance(ctx, ast.Load):
            opcode = "LOAD_LOCAL"
        elif isinstance(ctx, ast.Store):
            opcode = "STORE_LOCAL"
        else:
            raise TypedSyntaxError("unsupported op")
        code_gen.emit(opcode, (node.id, self.klass.type_descr))

    def bind_binop(
        self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
    ) -> bool:
        """Bind ``self <op> right``; both operands must be this double type
        and the operator must be one of the supported ones."""
        rtype = visitor.get_type(node.right)
        unsupported = (
            rtype != self or type(node.op) not in self._double_binary_opcode_signed
        )
        if unsupported:
            raise visitor.syntax_error(self.binop_error(self, rtype, node.op), node)

        visitor.set_type(node, self)
        return True

    def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
        """Type the constant node as this double type."""
        visitor.set_type(node, self)

    def emit_constant(
        self, node: ast.Constant, code_gen: Static38CodeGenerator
    ) -> None:
        """Emit ``PRIMITIVE_LOAD_CONST`` with the value converted to ``float``."""
        const_arg = (float(node.value), self.as_oparg())
        code_gen.emit("PRIMITIVE_LOAD_CONST", const_arg)

    def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
        """Evaluate ``node`` and box it; raises ``RuntimeError`` if ``node``
        is not typed as a double."""
        code_gen.visit(node)
        type = code_gen.get_type(node)
        if not isinstance(type, CDoubleInstance):
            raise RuntimeError("unsupported box type: " + type.name)
        code_gen.emit("PRIMITIVE_BOX", self.as_oparg())
4801
4802
class CDoubleType(CType):
    """Class object for the primitive ``__static__.double`` type."""

    def __init__(self) -> None:
        """Register the type with ``object`` as its only base and a
        ``CDoubleInstance`` as its instance."""
        type_name = TypeName("__static__", "double")
        bases = [OBJECT_TYPE]
        double_instance = CDoubleInstance(self)
        super().__init__(type_name, bases, double_instance)
4810
4811
# Primitive integer types.  ``CIntType`` derives the size and signedness of
# each type from the ``TYPED_*`` constant it is given.
CBOOL_TYPE = CIntType(TYPED_BOOL, name_override="cbool")

INT8_TYPE = CIntType(TYPED_INT8)
INT16_TYPE = CIntType(TYPED_INT16)
INT32_TYPE = CIntType(TYPED_INT32)
INT64_TYPE = CIntType(TYPED_INT64)

UINT8_TYPE = CIntType(TYPED_UINT8)
UINT16_TYPE = CIntType(TYPED_UINT16)
UINT32_TYPE = CIntType(TYPED_UINT32)
UINT64_TYPE = CIntType(TYPED_UINT64)

# Fallback type context used when binding the right operand of a primitive
# integer binary operation (see ``CIntInstance.bind_binop``).
INT64_VALUE = INT64_TYPE.instance

# ``char`` shares the ``TYPED_INT8`` constant with ``int8`` but is a separate
# type object with its own name.
CHAR_TYPE = CIntType(TYPED_INT8, name_override="char")
DOUBLE_TYPE = CDoubleType()
ARRAY_TYPE = ArrayClass(
    GenericTypeName("__static__", "Array", (GenericParameter("T", 0),))
)
ARRAY_EXACT_TYPE = ArrayClass(
    GenericTypeName("__static__", "Array", (GenericParameter("T", 0),)), is_exact=True
)

# Vectors are currently just a special kind of array that supports
# methods that resize them.
VECTOR_TYPE_PARAM = GenericParameter("T", 0)
VECTOR_TYPE_NAME = GenericTypeName("__static__", "Vector", (VECTOR_TYPE_PARAM,))

VECTOR_TYPE = VectorClass(VECTOR_TYPE_NAME, is_exact=True)


# Element types accepted by ``ArrayClass.make_generic_type``.
ALLOWED_ARRAY_TYPES: List[Class] = [
    INT8_TYPE,
    INT16_TYPE,
    INT32_TYPE,
    INT64_TYPE,
    UINT8_TYPE,
    UINT16_TYPE,
    UINT32_TYPE,
    UINT64_TYPE,
    CHAR_TYPE,
    DOUBLE_TYPE,
    FLOAT_TYPE,
]

# Signed types in increasing width; ``CIntInstance.validate_mixed_math``
# indexes this list by size when promoting mixed-sign operands.
SIGNED_CINT_TYPES = [INT8_TYPE, INT16_TYPE, INT32_TYPE, INT64_TYPE]
UNSIGNED_CINT_TYPES: List[CIntType] = [
    UINT8_TYPE,
    UINT16_TYPE,
    UINT32_TYPE,
    UINT64_TYPE,
]
ALL_CINT_TYPES: Sequence[CIntType] = SIGNED_CINT_TYPES + UNSIGNED_CINT_TYPES

# Type names understood by ``parse_type``.
NAME_TO_TYPE: Mapping[object, Class] = {
    "NoneType": NONE_TYPE,
    "object": OBJECT_TYPE,
    "str": STR_TYPE,
    "__static__.int8": INT8_TYPE,
    "__static__.int16": INT16_TYPE,
    "__static__.int32": INT32_TYPE,
    "__static__.int64": INT64_TYPE,
    "__static__.uint8": UINT8_TYPE,
    "__static__.uint16": UINT16_TYPE,
    "__static__.uint32": UINT32_TYPE,
    "__static__.uint64": UINT64_TYPE,
}
4879
4880
def parse_type(info: Dict[str, object]) -> Class:
    """Translate a type-description dict into a ``Class``.

    Keys read from ``info``:
      * ``"type"``: a name looked up in ``NAME_TO_TYPE``; when missing or
        falsy, ``"type_param"`` is used instead.
      * ``"type_param"``: integer index of a generic parameter, turned into
        ``GenericParameter("T<index>", index)``.
      * ``"optional"``: when truthy, the result is wrapped in an optional type.

    Raises:
        NotImplementedError: if ``"type"`` names a type that is not in
            ``NAME_TO_TYPE``.
    """
    is_optional = info.get("optional", False)
    type_name = info.get("type")
    if not type_name:
        param_index = info.get("type_param")
        assert isinstance(param_index, int)
        klass = GenericParameter(f"T{param_index!s}", param_index)
    else:
        klass = NAME_TO_TYPE.get(type_name)
        if klass is None:
            raise NotImplementedError(f"unsupported type: {type_name!s}")

    if not is_optional:
        return klass

    return OPTIONAL_TYPE.make_generic_type((klass,), BUILTIN_GENERICS)
4897
4898
# Inexact and exact variants of the checked dict generic; both share the same
# generic name and Python type.
CHECKED_DICT_TYPE = CheckedDict(CHECKED_DICT_TYPE_NAME, [OBJECT_TYPE], pytype=chkdict)

CHECKED_DICT_EXACT_TYPE = CheckedDict(
    CHECKED_DICT_TYPE_NAME, [OBJECT_TYPE], pytype=chkdict, is_exact=True
)

# Inexact class -> its exact counterpart.
EXACT_TYPES: Mapping[Class, Class] = {
    ARRAY_TYPE: ARRAY_EXACT_TYPE,
    LIST_TYPE: LIST_EXACT_TYPE,
    TUPLE_TYPE: TUPLE_EXACT_TYPE,
    INT_TYPE: INT_EXACT_TYPE,
    FLOAT_TYPE: FLOAT_EXACT_TYPE,
    COMPLEX_TYPE: COMPLEX_EXACT_TYPE,
    DICT_TYPE: DICT_EXACT_TYPE,
    CHECKED_DICT_TYPE: CHECKED_DICT_EXACT_TYPE,
    SET_TYPE: SET_EXACT_TYPE,
    STR_TYPE: STR_EXACT_TYPE,
}

# Same mapping, expressed on the classes' instances.
EXACT_INSTANCES: Mapping[Value, Value] = {
    k.instance: v.instance for k, v in EXACT_TYPES.items()
}

# Reverse direction: exact class -> inexact class.
INEXACT_TYPES: Mapping[Class, Class] = {v: k for k, v in EXACT_TYPES.items()}

# Reverse direction: exact instance -> inexact instance.
INEXACT_INSTANCES: Mapping[Value, Value] = {v: k for k, v in EXACT_INSTANCES.items()}
4925
4926
def exact(maybe_inexact: Value) -> Value:
    """Return the exact counterpart of ``maybe_inexact``.

    Union instances are converted through ``exact_type``; any other value
    is looked up in ``EXACT_INSTANCES`` and returned unchanged when that
    lookup yields nothing truthy.
    """
    if not isinstance(maybe_inexact, UnionInstance):
        return EXACT_INSTANCES.get(maybe_inexact) or maybe_inexact
    return exact_type(maybe_inexact.klass).instance
4932
4933
def inexact(maybe_exact: Value) -> Value:
    """Return the inexact counterpart of ``maybe_exact``.

    Union instances are converted through ``inexact_type``; any other value
    is looked up in ``INEXACT_INSTANCES`` and returned unchanged when that
    lookup yields nothing truthy.
    """
    if not isinstance(maybe_exact, UnionInstance):
        return INEXACT_INSTANCES.get(maybe_exact) or maybe_exact
    return inexact_type(maybe_exact.klass).instance
4939
4940
def exact_type(maybe_inexact: Class) -> Class:
    """Return the exact counterpart of the class ``maybe_inexact``.

    A union that carries ``generic_types`` is rebuilt from the exact form of
    each of its type arguments.  Everything else (including a union without
    ``generic_types``) is looked up in ``EXACT_TYPES`` and returned
    unchanged when that lookup yields nothing truthy.
    """
    if isinstance(maybe_inexact, UnionType):
        registry = maybe_inexact.generic_types
        if registry is not None:
            exact_args = tuple(exact_type(arg) for arg in maybe_inexact.type_args)
            return UNION_TYPE.make_generic_type(exact_args, registry)
    return EXACT_TYPES.get(maybe_inexact) or maybe_inexact
4950
4951
def inexact_type(maybe_exact: Class) -> Class:
    """Return the inexact counterpart of the class ``maybe_exact``.

    A union that carries ``generic_types`` is rebuilt from the inexact form
    of each of its type arguments.  Everything else (including a union
    without ``generic_types``) is looked up in ``INEXACT_TYPES`` and returned
    unchanged when that lookup yields nothing truthy.
    """
    if isinstance(maybe_exact, UnionType):
        registry = maybe_exact.generic_types
        if registry is not None:
            inexact_args = tuple(inexact_type(arg) for arg in maybe_exact.type_args)
            return UNION_TYPE.make_generic_type(inexact_args, registry)
    return INEXACT_TYPES.get(maybe_exact) or maybe_exact
4961
4962
# Types from the ``xxclassloader`` module are only declared when ``spamobj``
# is available.
if spamobj is not None:
    SPAM_OBJ = GenericClass(
        GenericTypeName("xxclassloader", "spamobj", (GenericParameter("T", 0),)),
        pytype=spamobj,
    )
    XXGENERIC_T = GenericParameter("T", 0)
    XXGENERIC_U = GenericParameter("U", 1)
    XXGENERIC_TYPE_NAME = GenericTypeName(
        "xxclassloader", "XXGeneric", (XXGENERIC_T, XXGENERIC_U)
    )

    class XXGeneric(GenericClass):
        """Two-parameter generic class exposing a typed ``foo(self, t, u)``."""

        def __init__(
            self,
            name: GenericTypeName,
            bases: Optional[List[Class]] = None,
            instance: Optional[Object[Class]] = None,
            klass: Optional[Class] = None,
            members: Optional[Dict[str, Value]] = None,
            type_def: Optional[GenericClass] = None,
            is_exact: bool = False,
            pytype: Optional[Type[object]] = None,
        ) -> None:
            """Initialize the generic class and register its ``foo`` method."""
            super().__init__(
                name,
                bases,
                instance,
                klass,
                members,
                type_def,
                is_exact,
                pytype,
            )
            # Parameters are built in declaration order: self, t, u.
            receiver_param = Parameter(
                "self", 0, ResolvedTypeRef(self), False, None, False
            )
            t_param = Parameter(
                "t", 0, ResolvedTypeRef(XXGENERIC_T), False, None, False
            )
            u_param = Parameter(
                "u", 0, ResolvedTypeRef(XXGENERIC_U), False, None, False
            )
            foo_descr = BuiltinMethodDescriptor(
                "foo", self, (receiver_param, t_param, u_param)
            )
            self.members["foo"] = foo_descr

    XX_GENERIC_TYPE = XXGeneric(XXGENERIC_TYPE_NAME)
else:
    SPAM_OBJ: Optional[GenericClass] = None
5023
5024
class GenericVisitor(ASTVisitor):
    """AST visitor that knows which module and file it is processing and
    visits individual nodes inside an ``error_context``."""

    def __init__(self, module_name: str, filename: str) -> None:
        super().__init__()
        self.filename = filename
        self.module_name = module_name

    def visit(self, node: Union[AST, Sequence[AST]], *args: object) -> Optional[object]:
        """Visit ``node``, wrapping single AST nodes in ``error_context``."""
        # if we have a sequence of nodes, don't catch TypedSyntaxError here;
        # walk_list will call us back with each individual node in turn and we
        # can catch errors and add node info then.
        if isinstance(node, AST):
            guard = error_context(self.filename, node)
        else:
            guard = nullcontext()
        with guard:
            return super().visit(node, *args)

    def syntax_error(self, msg: str, node: AST) -> TypedSyntaxError:
        """Build (without raising) a syntax error for ``node`` in this file."""
        return syntax_error(msg, self.filename, node)
5045
5046
class InitVisitor(ASTVisitor):
    """Walks the body of an ``__init__`` function and defines a slot on the
    class for every attribute assigned on the first parameter."""

    def __init__(
        self, module: ModuleTable, klass: Class, init_func: FunctionDef
    ) -> None:
        super().__init__()
        self.module = module
        self.klass = klass
        self.init_func = init_func

    def _is_receiver(self, value: AST) -> bool:
        """True when ``value`` is a bare name equal to the name of
        ``init_func``'s first positional parameter."""
        return (
            isinstance(value, ast.Name)
            and value.id == self.init_func.args.args[0].arg
        )

    def visitAnnAssign(self, node: AnnAssign) -> None:
        """Handle ``<receiver>.attr: T [= value]`` by defining a typed slot."""
        target = node.target
        if not isinstance(target, Attribute):
            return
        if not self._is_receiver(target.value):
            return
        self.klass.define_slot(
            target.attr,
            TypeRef(self.module, node.annotation),
            assignment=node,
        )

    def visitAssign(self, node: Assign) -> None:
        """Handle ``<receiver>.attr = value`` by defining an untyped slot for
        each matching target."""
        for target in node.targets:
            if isinstance(target, Attribute) and self._is_receiver(target.value):
                self.klass.define_slot(target.attr, assignment=node)
5082
5083
class DeclarationVisitor(GenericVisitor):
    """Collects a module's top-level declarations into a new ``ModuleTable``
    that is registered in the symbol table under the module's name."""

    def __init__(self, mod_name: str, filename: str, symbols: SymbolTable) -> None:
        super().__init__(mod_name, filename)
        self.symbols = symbols
        module_table = ModuleTable(mod_name, filename, symbols)
        self.module = module_table
        symbols[mod_name] = module_table

    def finish_bind(self) -> None:
        """Forward to the module table's ``finish_bind``."""
        self.module.finish_bind()

    def visitAnnAssign(self, node: AnnAssign) -> None:
        """Record a module-level annotated assignment declaration."""
        self.module.decls.append((node, None))

    def _declare_class_member(self, klass: Class, item: AST) -> None:
        """Declare one class-body statement (function or annotated name) on
        ``klass``; other statements are ignored."""
        if isinstance(item, (AsyncFunctionDef, FunctionDef)):
            function = self._make_function(item)
            if not function:
                return
            klass.define_function(item.name, function, self)
            # Only a synchronous ``__init__`` that takes at least one
            # positional parameter is scanned for attribute assignments.
            if (
                item.name == "__init__"
                and item.args.args
                and isinstance(item, FunctionDef)
            ):
                InitVisitor(self.module, klass, item).visit(item.body)
        elif isinstance(item, AnnAssign):
            # class C:
            #    x: foo
            target = item.target
            if isinstance(target, ast.Name):
                klass.define_slot(
                    target.id,
                    TypeRef(self.module, item.annotation),
                    # Note down whether the slot has been assigned a value.
                    assignment=item if item.value else None,
                )

    def visitClassDef(self, node: ClassDef) -> None:
        """Declare a class, its members and decorators, and register the
        result as a child of the module.

        Raises a syntax error if any base class is final.
        """
        resolve = self.module.resolve_type
        bases = [resolve(base) or DYNAMIC_TYPE for base in node.bases] or [OBJECT_TYPE]
        klass = Class(TypeName(self.module_name, node.name), bases)
        self.module.decls.append((node, klass))
        for item in node.body:
            with error_context(self.filename, item):
                self._declare_class_member(klass, item)

        for base in bases:
            if base is NAMED_TUPLE_TYPE:
                # In named tuples, the fields are actually elements
                # of the tuple, so we can't do any advanced binding against it.
                klass = DYNAMIC_TYPE
                break

            if base.is_final:
                raise self.syntax_error(
                    f"Class `{klass.instance.name}` cannot subclass a Final class: `{base.instance.name}`",
                    node,
                )

        for decorator_node in node.decorator_list:
            # Once the class has become dynamic no further decorators apply.
            if klass is DYNAMIC_TYPE:
                break
            with error_context(self.filename, decorator_node):
                decorator = self.module.resolve_type(decorator_node) or DYNAMIC_TYPE
                klass = decorator.bind_decorate_class(klass)

        self.module.children[node.name] = klass

    def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
        """Register a module-level function under its ``func_name``, unless
        decoration produced something that is not tracked."""
        function = self._make_function(node)
        if not function:
            return
        self.module.children[function.func_name] = function

    def _make_function(
        self, node: Union[FunctionDef, AsyncFunctionDef]
    ) -> Function | StaticMethod | None:
        """Build a ``Function`` for ``node`` and apply its decorators in
        ``decorator_list`` order.

        Returns ``None`` as soon as a decorator yields something other than a
        ``Function`` or ``StaticMethod``.
        """
        result = Function(node, self.module, self.type_ref(node.returns))
        for decorator_node in node.decorator_list:
            decorator_type = self.module.resolve_type(decorator_node) or DYNAMIC_TYPE
            result = decorator_type.bind_decorate_function(self, result)
            if not isinstance(result, (Function, StaticMethod)):
                return None
        return result

    def visitFunctionDef(self, node: FunctionDef) -> None:
        self._visitFunc(node)

    def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        self._visitFunc(node)

    def type_ref(self, ann: Optional[expr]) -> TypeRef:
        """Return a reference to the annotated type, or to the dynamic type
        when there is no annotation."""
        if ann:
            return TypeRef(self.module, ann)
        return ResolvedTypeRef(DYNAMIC_TYPE)

    def visitImport(self, node: Import) -> None:
        """Ask the symbol table to import every module named by ``import``."""
        for alias in node.names:
            self.symbols.import_module(alias.name)

    def visitImportFrom(self, node: ImportFrom) -> None:
        """Import the source module and copy each imported name that it
        declares into this module's children (under its alias, if any).

        Raises:
            NotImplementedError: for relative imports.
        """
        mod_name = node.module
        if not mod_name or node.level:
            raise NotImplementedError("relative imports aren't supported")
        self.symbols.import_module(mod_name)
        source_module = self.symbols.modules.get(mod_name)
        if source_module is None:
            return
        for alias in node.names:
            declared = source_module.children.get(alias.name)
            if declared is None:
                continue
            self.module.children[alias.asname or alias.name] = declared

    # We don't pick up declarations in nested statements
    def visitFor(self, node: For) -> None:
        return None

    def visitAsyncFor(self, node: AsyncFor) -> None:
        return None

    def visitWhile(self, node: While) -> None:
        return None

    def visitIf(self, node: If) -> None:
        """Only descend into the body of ``if TYPE_CHECKING:`` blocks."""
        test = node.test
        if not isinstance(test, Name):
            return
        if test.id == "TYPE_CHECKING":
            self.visit(node.body)

    def visitWith(self, node: With) -> None:
        return None

    def visitAsyncWith(self, node: AsyncWith) -> None:
        return None

    def visitTry(self, node: Try) -> None:
        return None
5217
5218
class TypedSyntaxError(SyntaxError):
    """SyntaxError subclass used by the static compiler for typing errors."""
5221
5222
class LocalsBranch:
    """Handles branching and merging local variable types.

    A branch snapshots the scope's `local_types` when it is created so that
    the locals can later be restored to, or merged with, that entry state.
    """

    def __init__(self, scope: BindingScope) -> None:
        self.scope = scope
        # Shallow copy of the locals as they were when the branch was taken.
        self.entry_locals: Dict[str, Value] = dict(scope.local_types)

    def copy(self) -> Dict[str, Value]:
        """Make a copy of the current local state"""
        return dict(self.scope.local_types)

    def restore(self, state: Optional[Dict[str, Value]] = None) -> None:
        """Restore the locals to `state`, or to the entry state when omitted.

        An explicitly passed empty dict is honored: it is a legitimate
        snapshot (e.g. from `copy()` on a scope without locals) and must not
        be confused with "no state given".
        """
        if state is None:
            state = self.entry_locals
        local_types = self.scope.local_types
        local_types.clear()
        local_types.update(state)

    def merge(self, entry_locals: Optional[Dict[str, Value]] = None) -> None:
        """Merge the entry locals, or a specific copy, into the current locals"""
        # TODO: What about del's?
        if entry_locals is None:
            entry_locals = self.entry_locals
        local_types = self.scope.local_types
        for key, value in entry_locals.items():
            if key in local_types:
                if value != local_types[key]:
                    # Keep the wider of the two types; when neither side
                    # subsumes the other, fall back to the declared type.
                    widest = self._widest_type(value, local_types[key])
                    local_types[key] = widest or self.scope.decl_types[key].type
                continue

        for key in local_types:
            # If a value isn't definitely assigned we can safely turn it
            # back into the declared type
            if key not in entry_locals and key in self.scope.decl_types:
                local_types[key] = self.scope.decl_types[key].type

    def _widest_type(self, *types: Value) -> Optional[Value]:
        """Return the type all of `types` can be assigned to, if one of them is.

        Returns DYNAMIC as soon as a DYNAMIC value is seen, and None when two
        of the types are unrelated (neither can be assigned from the other).
        """
        # TODO: this should be a join, rather than just reverting to decl_type
        # if neither type is greater than the other
        if len(types) == 1:
            return types[0]

        widest_type = None
        for src in types:
            if src == DYNAMIC:
                return DYNAMIC

            if widest_type is None or src.klass.can_assign_from(widest_type.klass):
                widest_type = src
            elif widest_type is not None and not widest_type.klass.can_assign_from(
                src.klass
            ):
                return None

        return widest_type
5277
5278
class TypeDeclaration:
    """Pairs a declared type with a flag telling whether the declaration is final."""

    def __init__(self, typ: Value, is_final: bool = False) -> None:
        self.type, self.is_final = typ, is_final
5283
5284
class BindingScope:
    """Per-AST-node record of declared types and current local types by name."""

    def __init__(self, node: AST) -> None:
        self.node = node
        # Current (possibly narrowed) type of each local.
        self.local_types: Dict[str, Value] = dict()
        # Declaration recorded for each name.
        self.decl_types: Dict[str, TypeDeclaration] = dict()

    def branch(self) -> LocalsBranch:
        """Return a LocalsBranch built over this scope."""
        return LocalsBranch(self)

    def declare(self, name: str, typ: Value, is_final: bool = False) -> TypeDeclaration:
        """Declare `name` with type `typ` and return the new declaration.

        The declared type also becomes the name's current local type.
        """
        declaration = self.decl_types[name] = TypeDeclaration(typ, is_final)
        self.local_types[name] = typ
        return declaration
5299
5300
class ModuleBindingScope(BindingScope):
    """Binding scope for a module, kept in sync with the module table's children."""

    def __init__(self, node: ast.Module, module: ModuleTable) -> None:
        super().__init__(node)
        self.module = module
        # Seed the scope with everything the module table already knows about.
        for child_name, child_type in self.module.children.items():
            self.declare(child_name, child_type)

    def declare(self, name: str, typ: Value, is_final: bool = False) -> TypeDeclaration:
        """Declare `name` in this scope and store `typ` in the module's children."""
        self.module.children[name] = typ
        declaration = super().declare(name, typ, is_final)
        return declaration
5311
5312
class NarrowingEffect:
    """Base class for type-narrowing effects on variables.

    The base implementation of `apply` and `undo` does nothing, and
    `reverse` falls back to `undo`.
    """

    def and_(self, other: NarrowingEffect) -> NarrowingEffect:
        """Combine this effect with `other` into an AndEffect."""
        # NOTE(review): this is an identity check against the NoEffect *class*,
        # so it does not match NoEffect instances. Kept as-is to preserve
        # behavior.
        return self if other is NoEffect else AndEffect(self, other)

    def or_(self, other: NarrowingEffect) -> NarrowingEffect:
        """Combine this effect with `other` into an OrEffect."""
        # NOTE(review): same class-identity check as in `and_`.
        return self if other is NoEffect else OrEffect(self, other)

    def not_(self) -> NarrowingEffect:
        """Return the negation of this effect."""
        return NegationEffect(self)

    def apply(self, local_types: Dict[str, Value]) -> None:
        """applies the given effect in the target scope"""
        return None

    def undo(self, local_types: Dict[str, Value]) -> None:
        """restores the type to its original value"""
        return None

    def reverse(self, local_types: Dict[str, Value]) -> None:
        """applies the reverse of the scope or reverts it if
        there is no reverse"""
        self.undo(local_types)
5343
5344
class AndEffect(NarrowingEffect):
    """Conjunction of narrowing effects: applying or undoing it applies or
    undoes every contained effect in order."""

    def __init__(self, *effects: NarrowingEffect) -> None:
        self.effects: Sequence[NarrowingEffect] = effects

    def and_(self, other: NarrowingEffect) -> NarrowingEffect:
        """Return a flattened conjunction of this effect and `other`."""
        # A NoEffect contributes nothing to a conjunction (its apply/undo are
        # no-ops), so it is dropped. The check must be against instances:
        # comparing with the NoEffect class itself never matches.
        if isinstance(other, NoEffect):
            return self
        if isinstance(other, AndEffect):
            return AndEffect(*self.effects, *other.effects)

        return AndEffect(*self.effects, other)

    def apply(self, local_types: Dict[str, Value]) -> None:
        """applies every contained effect in the target scope"""
        for effect in self.effects:
            effect.apply(local_types)

    def undo(self, local_types: Dict[str, Value]) -> None:
        """restores the type to its original value"""
        for effect in self.effects:
            effect.undo(local_types)
5365
5366
class OrEffect(NarrowingEffect):
    """Disjunction of narrowing effects.

    Applying it narrows nothing (the base no-op `apply` is inherited);
    reversing it reverses every contained effect.
    """

    def __init__(self, *effects: NarrowingEffect) -> None:
        self.effects: Sequence[NarrowingEffect] = effects

    def or_(self, other: NarrowingEffect) -> NarrowingEffect:
        """Return a flattened disjunction of this effect and `other`.

        This combinator used to be named `and_`, which made `and`-ing onto a
        disjunction produce another disjunction and left `or_` un-flattened.
        `and_` is now inherited from NarrowingEffect and yields an AndEffect.
        """
        # A NoEffect adds nothing to a disjunction (its undo/reverse are
        # no-ops); the check must be against instances, not the class.
        if isinstance(other, NoEffect):
            return self
        if isinstance(other, OrEffect):
            return OrEffect(*self.effects, *other.effects)

        return OrEffect(*self.effects, other)

    def reverse(self, local_types: Dict[str, Value]) -> None:
        """applies the reverse of every contained effect"""
        for effect in self.effects:
            effect.reverse(local_types)

    def undo(self, local_types: Dict[str, Value]) -> None:
        """restores the type to its original value"""
        for effect in self.effects:
            effect.undo(local_types)
5387
5388
class NoEffect(NarrowingEffect):
    """Effect that narrows nothing; it inherits the no-op apply/undo/reverse."""

    def union(self, other: NarrowingEffect) -> NarrowingEffect:
        """Return `other` unchanged."""
        return other
5392
5393
# Singleton instance for no effects. Code in this file compares against it by
# identity (e.g. `effect is not NO_EFFECT`) and uses it as the default when a
# visit returns no narrowing effect.
NO_EFFECT = NoEffect()
5396
5397
class NegationEffect(NarrowingEffect):
    """Logical negation of another effect: apply and reverse are swapped."""

    def __init__(self, negated: NarrowingEffect) -> None:
        self.negated = negated

    def not_(self) -> NarrowingEffect:
        """Negating a negation gives back the wrapped effect."""
        return self.negated

    def apply(self, local_types: Dict[str, Value]) -> None:
        """Apply the reverse of the wrapped effect."""
        wrapped = self.negated
        wrapped.reverse(local_types)

    def undo(self, local_types: Dict[str, Value]) -> None:
        """Undo the wrapped effect."""
        wrapped = self.negated
        wrapped.undo(local_types)

    def reverse(self, local_types: Dict[str, Value]) -> None:
        """Apply the wrapped effect itself."""
        wrapped = self.negated
        wrapped.apply(local_types)
5413
5414
class IsInstanceEffect(NarrowingEffect):
    """Narrows variable `var` from `prev` to `inst`.

    When `prev` is a union instance, the reverse effect narrows to a union of
    the members that `inst`'s class cannot be assigned from; otherwise the
    reverse simply keeps `prev`.
    """

    def __init__(self, var: str, prev: Value, inst: Value, visitor: TypeBinder) -> None:
        self.var, self.prev, self.inst = var, prev, inst
        negated_type = prev
        if isinstance(prev, UnionInstance):
            # Keep only the union members not covered by the narrowed type.
            remaining = tuple(
                member
                for member in prev.klass.type_args
                if not inst.klass.can_assign_from(member)
            )
            remaining_union = UNION_TYPE.make_generic_type(
                remaining, visitor.symtable.generic_types
            )
            negated_type = remaining_union.instance
        self.rev: Value = negated_type

    def apply(self, local_types: Dict[str, Value]) -> None:
        """Set the variable to the narrowed type."""
        local_types[self.var] = self.inst

    def undo(self, local_types: Dict[str, Value]) -> None:
        """Set the variable back to its previous type."""
        local_types[self.var] = self.prev

    def reverse(self, local_types: Dict[str, Value]) -> None:
        """Set the variable to the type computed for the negated check."""
        local_types[self.var] = self.rev
5438
5439
class TerminalKind(IntEnum):
    """Integer-valued kinds stored per AST node in the binder's `terminals` map."""

    NonTerminal = 0
    BreakOrContinue = 1
    Return = 2
5444
5445
5446class TypeBinder(GenericVisitor):
    """Walks an AST and produces an optionally strongly typed AST, reporting errors when
    operations are occurring that are not sound. Strong types are based upon places where
    annotations occur which opt in to the strong typing"""
5450
5451 def __init__(
5452 self,
5453 symbols: SymbolVisitor,
5454 filename: str,
5455 symtable: SymbolTable,
5456 module_name: str,
5457 optimize: int = 0,
5458 ) -> None:
5459 super().__init__(module_name, filename)
5460 self.symbols = symbols
5461 self.scopes: List[BindingScope] = []
5462 self.symtable = symtable
5463 self.cur_mod: ModuleTable = symtable[module_name]
5464 self.optimize = optimize
5465 self.terminals: Dict[AST, TerminalKind] = {}
5466 self.inline_depth = 0
5467
5468 @property
5469 def local_types(self) -> Dict[str, Value]:
5470 return self.binding_scope.local_types
5471
5472 @property
5473 def decl_types(self) -> Dict[str, TypeDeclaration]:
5474 return self.binding_scope.decl_types
5475
5476 @property
5477 def binding_scope(self) -> BindingScope:
5478 return self.scopes[-1]
5479
5480 @property
5481 def scope(self) -> AST:
5482 return self.binding_scope.node
5483
5484 def maybe_set_local_type(self, name: str, local_type: Value) -> Value:
5485 decl_type = self.decl_types[name].type
5486 if local_type is DYNAMIC or not decl_type.klass.can_be_narrowed:
5487 local_type = decl_type
5488 self.local_types[name] = local_type
5489 return local_type
5490
5491 def maybe_get_current_class(self) -> Optional[Class]:
5492 scope = self.scope
5493 if isinstance(scope, ClassDef):
5494 klass = self.cur_mod.resolve_name(scope.name)
5495 assert isinstance(klass, Class)
5496 return klass
5497
5498 def visit(
5499 self, node: Union[AST, Sequence[AST]], *args: object
5500 ) -> Optional[NarrowingEffect]:
5501 """This override is only here to give Pyre the return type information."""
5502 ret = super().visit(node, *args)
5503 if ret is not None:
5504 assert isinstance(ret, NarrowingEffect)
5505 return ret
5506 return None
5507
5508 def get_final_literal(self, node: AST) -> Optional[ast.Constant]:
5509 return self.cur_mod.get_final_literal(node, self.symbols.scopes[self.scope])
5510
5511 def declare_local(
5512 self, target: ast.Name, typ: Value, is_final: bool = False
5513 ) -> None:
5514 if target.id in self.decl_types:
5515 raise self.syntax_error(
5516 f"Cannot redefine local variable {target.id}", target
5517 )
5518 if isinstance(typ, CInstance):
5519 self.check_primitive_scope(target)
5520 self.binding_scope.declare(target.id, typ, is_final)
5521
5522 def check_static_import_flags(self, node: Module) -> None:
5523 saw_doc_str = False
5524 for stmt in node.body:
5525 if isinstance(stmt, ast.Expr):
5526 val = stmt.value
5527 if isinstance(val, ast.Constant) and isinstance(val.value, str):
5528 if saw_doc_str:
5529 break
5530 saw_doc_str = True
5531 else:
5532 break
5533 elif isinstance(stmt, ast.Import):
5534 continue
5535 elif isinstance(stmt, ast.ImportFrom):
5536 if stmt.module == "__static__.compiler_flags":
5537 for name in stmt.names:
5538 if name.name == "nonchecked_dicts":
5539 self.cur_mod.nonchecked_dicts = True
5540 elif name.name == "noframe":
5541 self.cur_mod.noframe = True
5542
5543 def visitModule(self, node: Module) -> None:
5544 self.scopes.append(ModuleBindingScope(node, self.cur_mod))
5545
5546 self.check_static_import_flags(node)
5547
5548 for stmt in node.body:
5549 self.visit(stmt)
5550
5551 self.scopes.pop()
5552
5553 def set_param(self, arg: ast.arg, arg_type: Class, scope: BindingScope) -> None:
5554 scope.declare(arg.arg, arg_type.instance)
5555 self.set_type(arg, arg_type.instance)
5556
5557 def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
5558 scope = BindingScope(node)
5559 for decorator in node.decorator_list:
5560 self.visit(decorator)
5561 cur_scope = self.scope
5562
5563 if (
5564 not node.decorator_list
5565 and isinstance(cur_scope, ClassDef)
5566 and node.args.args
5567 ):
5568 # Handle type of "self"
5569 klass = self.cur_mod.resolve_name(cur_scope.name)
5570 if isinstance(klass, Class):
5571 self.set_param(node.args.args[0], klass, scope)
5572 else:
5573 self.set_param(node.args.args[0], DYNAMIC_TYPE, scope)
5574
5575 for arg in node.args.posonlyargs:
5576 ann = arg.annotation
5577 if ann:
5578 self.visit(ann)
5579 arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
5580 elif arg.arg in scope.decl_types:
5581 # Already handled self
5582 continue
5583 else:
5584 arg_type = DYNAMIC_TYPE
5585 self.set_param(arg, arg_type, scope)
5586
5587 for arg in node.args.args:
5588 ann = arg.annotation
5589 if ann:
5590 self.visit(ann)
5591 arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
5592 elif arg.arg in scope.decl_types:
5593 # Already handled self
5594 continue
5595 else:
5596 arg_type = DYNAMIC_TYPE
5597
5598 self.set_param(arg, arg_type, scope)
5599
5600 if node.args.defaults:
5601 for default in node.args.defaults:
5602 self.visit(default)
5603
5604 if node.args.kw_defaults:
5605 for default in node.args.kw_defaults:
5606 if default is not None:
5607 self.visit(default)
5608
5609 vararg = node.args.vararg
5610 if vararg:
5611 ann = vararg.annotation
5612 if ann:
5613 self.visit(ann)
5614
5615 self.set_param(vararg, TUPLE_EXACT_TYPE, scope)
5616
5617 for arg in node.args.kwonlyargs:
5618 ann = arg.annotation
5619 if ann:
5620 self.visit(ann)
5621 arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
5622 else:
5623 arg_type = DYNAMIC_TYPE
5624
5625 self.set_param(arg, arg_type, scope)
5626
5627 kwarg = node.args.kwarg
5628 if kwarg:
5629 ann = kwarg.annotation
5630 if ann:
5631 self.visit(ann)
5632 self.set_param(kwarg, DICT_EXACT_TYPE, scope)
5633
5634 returns = None if node.args in self.cur_mod.dynamic_returns else node.returns
5635 if returns:
5636 # We store the return type on the node for the function as we otherwise
5637 # don't need to store type information for it
5638 expected = self.cur_mod.resolve_annotation(returns) or DYNAMIC_TYPE
5639 self.set_type(node, expected.instance)
5640 self.visit(returns)
5641 else:
5642 self.set_type(node, DYNAMIC)
5643
5644 self.scopes.append(scope)
5645
5646 for stmt in node.body:
5647 self.visit(stmt)
5648
5649 self.scopes.pop()
5650
5651 def visitFunctionDef(self, node: FunctionDef) -> None:
5652 self._visitFunc(node)
5653
5654 def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
5655 self._visitFunc(node)
5656
5657 def visitClassDef(self, node: ClassDef) -> None:
5658 for decorator in node.decorator_list:
5659 self.visit(decorator)
5660
5661 for kwarg in node.keywords:
5662 self.visit(kwarg.value)
5663
5664 for base in node.bases:
5665 self.visit(base)
5666
5667 self.scopes.append(BindingScope(node))
5668
5669 for stmt in node.body:
5670 self.visit(stmt)
5671
5672 self.scopes.pop()
5673
5674 def set_type(self, node: AST, type: Value) -> None:
5675 self.cur_mod.types[node] = type
5676
5677 def get_type(self, node: AST) -> Value:
5678 assert node in self.cur_mod.types, f"node not found: {node}, {node.lineno}"
5679 return self.cur_mod.types[node]
5680
5681 def get_node_data(
5682 self, key: Union[AST, Delegator], data_type: Type[TType]
5683 ) -> TType:
5684 return cast(TType, self.cur_mod.node_data[key, data_type])
5685
5686 def set_node_data(
5687 self, key: Union[AST, Delegator], data_type: Type[TType], value: TType
5688 ) -> None:
5689 self.cur_mod.node_data[key, data_type] = value
5690
5691 def check_primitive_scope(self, node: Name) -> None:
5692 cur_scope = self.symbols.scopes[self.scope]
5693 var_scope = cur_scope.check_name(node.id)
5694 if var_scope != SC_LOCAL or isinstance(self.scope, Module):
5695 raise self.syntax_error(
5696 "cannot use primitives in global or closure scope", node
5697 )
5698
5699 def get_var_scope(self, var_id: str) -> Optional[int]:
5700 cur_scope = self.symbols.scopes[self.scope]
5701 var_scope = cur_scope.check_name(var_id)
5702 return var_scope
5703
5704 def _check_final_attribute_reassigned(
5705 self,
5706 target: AST,
5707 assignment: Optional[AST],
5708 ) -> None:
5709 member = None
5710 klass = None
5711 member_name = None
5712
5713 # Try to look up the Class and associated Slot
5714 scope = self.scope
5715 if isinstance(target, ast.Name) and isinstance(scope, ast.ClassDef):
5716 klass = self.cur_mod.resolve_name(scope.name)
5717 assert isinstance(klass, Class)
5718 member_name = target.id
5719 member = klass.get_member(member_name)
5720 elif isinstance(target, ast.Attribute):
5721 klass = self.get_type(target.value).klass
5722 member_name = target.attr
5723 member = klass.get_member(member_name)
5724
5725 # Ensure we don't reassign to Finals
5726 if (
5727 klass is not None
5728 and member is not None
5729 and (
5730 (
5731 isinstance(member, Slot)
5732 and member.is_final
5733 and member.assignment != assignment
5734 )
5735 or (isinstance(member, Function) and member.is_final)
5736 )
5737 ):
5738 raise self.syntax_error(
5739 f"Cannot assign to a Final attribute of {klass.instance.name}:{member_name}",
5740 target,
5741 )
5742
5743 def visitAnnAssign(self, node: AnnAssign) -> None:
5744 self.visit(node.annotation)
5745
5746 target = node.target
5747 comp_type = (
5748 self.cur_mod.resolve_annotation(node.annotation, is_declaration=True)
5749 or DYNAMIC_TYPE
5750 )
5751 is_final = False
5752 if isinstance(comp_type, FinalClass):
5753 is_final = True
5754 comp_type = comp_type.inner_type()
5755
5756 if isinstance(target, Name):
5757 self.declare_local(target, comp_type.instance, is_final)
5758 self.set_type(target, comp_type.instance)
5759
5760 self.visit(target)
5761
5762 value = node.value
5763 if value:
5764 self.visit(value, comp_type.instance)
5765 if isinstance(target, Name):
5766 # We could be narrowing the type after the assignment, so we update it here
5767 # even though we assigned it above (but we never narrow primtives)
5768 new_type = self.get_type(value)
5769 local_type = self.maybe_set_local_type(target.id, new_type)
5770 self.set_type(target, local_type)
5771
5772 self.check_can_assign_from(comp_type, self.get_type(value).klass, node)
5773 self._check_final_attribute_reassigned(target, node)
5774
5775 def visitAugAssign(self, node: AugAssign) -> None:
5776 self.visit(node.target)
5777 target_type = inexact(self.get_type(node.target))
5778 self.visit(node.value, target_type)
5779 self.set_type(node, target_type)
5780
5781 def visitAssign(self, node: Assign) -> None:
5782 # Sometimes, we need to propagate types from the target to the value to allow primitives to be handled
5783 # correctly. So we compute the narrowest target type. (Other checks do happen later).
5784 # e.g: `x: int8 = 1` means we need `1` to be of type `int8`
5785 narrowest_target_type = None
5786 for target in reversed(node.targets):
5787 cur_type = None
5788 if isinstance(target, ast.Name):
5789 # This is a name, it could be unassigned still
5790 decl_type = self.decl_types.get(target.id)
5791 if decl_type is not None:
5792 cur_type = decl_type.type
5793 elif isinstance(target, (ast.Tuple, ast.List)):
5794 # TODO: We should walk into the tuple/list and use it to infer
5795 # types down on the RHS if we can
5796 self.visit(target)
5797 else:
5798 # This is an attribute or subscript, the assignment can't change the type
5799 self.visit(target)
5800 cur_type = self.get_type(target)
5801
5802 if cur_type is not None and (
5803 narrowest_target_type is None
5804 or narrowest_target_type.klass.can_assign_from(cur_type.klass)
5805 ):
5806 narrowest_target_type = cur_type
5807
5808 self.visit(node.value, narrowest_target_type)
5809 value_type = self.get_type(node.value)
5810 for target in reversed(node.targets):
5811 self.assign_value(target, value_type, src=node.value, assignment=node)
5812
5813 self.set_type(node, value_type)
5814
5815 def check_can_assign_from(
5816 self, dest: Class, src: Class, node: AST, reason: str = "cannot be assigned to"
5817 ) -> None:
5818 if not dest.can_assign_from(src) and src is not DYNAMIC_TYPE:
5819 raise self.syntax_error(
5820 f"type mismatch: {src.instance.name} {reason} {dest.instance.name} ",
5821 node,
5822 )
5823
5824 def visitBoolOp(
5825 self, node: BoolOp, type_ctx: Optional[Class] = None
5826 ) -> NarrowingEffect:
5827 effect = NO_EFFECT
5828 final_type = None
5829 if isinstance(node.op, And):
5830 for value in node.values:
5831 new_effect = self.visit(value) or NO_EFFECT
5832 effect = effect.and_(new_effect)
5833 final_type = self.widen(final_type, self.get_type(value))
5834
5835 # apply the new effect as short circuiting would
5836 # eliminate it.
5837 new_effect.apply(self.local_types)
5838
5839 # we undo the effect as we have no clue what context we're in
5840 # but then we return the combined effect in case we're being used
5841 # in a conditional context
5842 effect.undo(self.local_types)
5843 elif isinstance(node.op, ast.Or):
5844 for value in node.values:
5845 new_effect = self.visit(value) or NO_EFFECT
5846 effect = effect.or_(new_effect)
5847 final_type = self.widen(final_type, self.get_type(value))
5848
5849 new_effect.reverse(self.local_types)
5850
5851 effect.undo(self.local_types)
5852 else:
5853 for value in node.values:
5854 self.visit(value)
5855 final_type = self.widen(final_type, self.get_type(value))
5856
5857 self.set_type(node, final_type or DYNAMIC)
5858 return effect
5859
5860 def visitBinOp(
5861 self, node: BinOp, type_ctx: Optional[Class] = None
5862 ) -> NarrowingEffect:
5863 # In order to interpret numeric literals as primitives within a
5864 # primitive type context, we want to try to pass the type context down
5865 # to each side, but we can't require this, otherwise things like `List:
5866 # List * int` would fail.
5867 try:
5868 self.visit(node.left, type_ctx)
5869 except TypedSyntaxError:
5870 self.visit(node.left)
5871 try:
5872 self.visit(node.right, type_ctx)
5873 except TypedSyntaxError:
5874 self.visit(node.right)
5875
5876 ltype = self.get_type(node.left)
5877 rtype = self.get_type(node.right)
5878
5879 tried_right = False
5880 if ltype.klass in rtype.klass.mro[1:]:
5881 if rtype.bind_reverse_binop(node, self, type_ctx):
5882 return NO_EFFECT
5883 tried_right = True
5884
5885 if ltype.bind_binop(node, self, type_ctx):
5886 return NO_EFFECT
5887
5888 if not tried_right:
5889 rtype.bind_reverse_binop(node, self, type_ctx)
5890
5891 return NO_EFFECT
5892
5893 def visitUnaryOp(
5894 self, node: UnaryOp, type_ctx: Optional[Class] = None
5895 ) -> NarrowingEffect:
5896 effect = self.visit(node.operand, type_ctx)
5897 self.get_type(node.operand).bind_unaryop(node, self, type_ctx)
5898 if (
5899 effect is not None
5900 and effect is not NO_EFFECT
5901 and isinstance(node.op, ast.Not)
5902 ):
5903 return effect.not_()
5904 return NO_EFFECT
5905
5906 def visitLambda(
5907 self, node: Lambda, type_ctx: Optional[Class] = None
5908 ) -> NarrowingEffect:
5909 self.visit(node.body)
5910 self.set_type(node, DYNAMIC)
5911 return NO_EFFECT
5912
5913 def visitIfExp(
5914 self, node: IfExp, type_ctx: Optional[Class] = None
5915 ) -> NarrowingEffect:
5916 effect = self.visit(node.test) or NO_EFFECT
5917 effect.apply(self.local_types)
5918 self.visit(node.body)
5919 effect.reverse(self.local_types)
5920 self.visit(node.orelse)
5921 effect.undo(self.local_types)
5922
5923 # Select the most compatible types that we can, or fallback to
5924 # dynamic if we can coerce to dynamic, otherwise report an error.
5925 body_t = self.get_type(node.body)
5926 else_t = self.get_type(node.orelse)
5927 if body_t.klass.can_assign_from(else_t.klass):
5928 self.set_type(node, body_t)
5929 elif else_t.klass.can_assign_from(body_t.klass):
5930 self.set_type(node, else_t)
5931 elif DYNAMIC_TYPE.can_assign_from(
5932 body_t.klass
5933 ) and DYNAMIC_TYPE.can_assign_from(else_t.klass):
5934 self.set_type(node, DYNAMIC)
5935 else:
5936 raise self.syntax_error(
5937 f"if expression has incompatible types: {body_t.name} and {else_t.name}",
5938 node,
5939 )
5940 return NO_EFFECT
5941
5942 def visitSlice(
5943 self, node: Slice, type_ctx: Optional[Class] = None
5944 ) -> NarrowingEffect:
5945 lower = node.lower
5946 if lower:
5947 self.visit(lower, type_ctx)
5948 upper = node.upper
5949 if upper:
5950 self.visit(upper, type_ctx)
5951 step = node.step
5952 if step:
5953 self.visit(step, type_ctx)
5954 self.set_type(node, SLICE_TYPE.instance)
5955 return NO_EFFECT
5956
5957 def widen(self, existing: Optional[Value], new: Value) -> Value:
5958 if existing is None or new.klass.can_assign_from(existing.klass):
5959 return new
5960 elif existing.klass.can_assign_from(new.klass):
5961 return existing
5962
5963 res = UNION_TYPE.make_generic_type(
5964 (existing.klass, new.klass), self.symtable.generic_types
5965 ).instance
5966 return res
5967
5968 def visitDict(
5969 self, node: ast.Dict, type_ctx: Optional[Class] = None
5970 ) -> NarrowingEffect:
5971 key_type: Optional[Value] = None
5972 value_type: Optional[Value] = None
5973 for k, v in zip(node.keys, node.values):
5974 if k:
5975 self.visit(k)
5976 key_type = self.widen(key_type, self.get_type(k))
5977 self.visit(v)
5978 value_type = self.widen(value_type, self.get_type(v))
5979 else:
5980 self.visit(v, type_ctx)
5981 d_type = self.get_type(v).klass
5982 if (
5983 d_type.generic_type_def is CHECKED_DICT_TYPE
5984 or d_type.generic_type_def is CHECKED_DICT_EXACT_TYPE
5985 ):
5986 assert isinstance(d_type, GenericClass)
5987 key_type = self.widen(key_type, d_type.type_args[0].instance)
5988 value_type = self.widen(value_type, d_type.type_args[1].instance)
5989 elif d_type in (DICT_TYPE, DICT_EXACT_TYPE, DYNAMIC_TYPE):
5990 key_type = DYNAMIC
5991 value_type = DYNAMIC
5992
5993 self.set_dict_type(node, key_type, value_type, type_ctx, is_exact=True)
5994 return NO_EFFECT
5995
5996 def set_dict_type(
5997 self,
5998 node: ast.expr,
5999 key_type: Optional[Value],
6000 value_type: Optional[Value],
6001 type_ctx: Optional[Class],
6002 is_exact: bool = False,
6003 ) -> Value:
6004 if self.cur_mod.nonchecked_dicts or not isinstance(
6005 type_ctx, CheckedDictInstance
6006 ):
6007 # This is not a checked dict, or the user opted out of checked dicts
6008 if type_ctx in (DICT_TYPE.instance, DICT_EXACT_TYPE.instance):
6009 typ = type_ctx
6010 elif is_exact:
6011 typ = DICT_EXACT_TYPE.instance
6012 else:
6013 typ = DICT_TYPE.instance
6014 assert typ is not None
6015 self.set_type(node, typ)
6016 return typ
6017
6018 # Calculate the type that is inferred by the keys and values
6019 if key_type is None:
6020 key_type = OBJECT_TYPE.instance
6021
6022 if value_type is None:
6023 value_type = OBJECT_TYPE.instance
6024
6025 checked_dict_typ = CHECKED_DICT_EXACT_TYPE if is_exact else CHECKED_DICT_TYPE
6026
6027 gen_type = checked_dict_typ.make_generic_type(
6028 (key_type.klass, value_type.klass), self.symtable.generic_types
6029 )
6030
6031 if type_ctx is not None:
6032 type_class = type_ctx.klass
6033 if type_class.generic_type_def in (
6034 CHECKED_DICT_EXACT_TYPE,
6035 CHECKED_DICT_TYPE,
6036 ):
6037 assert isinstance(type_class, GenericClass)
6038 self.set_type(node, type_ctx)
6039 # We can use the type context to have a type which is wider than the
6040 # inferred types. But we need to make sure that the keys/values are compatible
6041 # with the wider type, and if not, we'll report that the inferred type isn't
6042 # compatible.
6043 if not type_class.type_args[0].can_assign_from(
6044 key_type.klass
6045 ) or not type_class.type_args[1].can_assign_from(value_type.klass):
6046 self.check_can_assign_from(type_class, gen_type, node)
6047 return type_ctx
6048 else:
6049 # Otherwise we allow something that would assign to dynamic, but not something
6050 # that would assign to an unrelated type (e.g. int)
6051 self.set_type(node, gen_type.instance)
6052 self.check_can_assign_from(type_class, gen_type, node)
6053 else:
6054 self.set_type(node, gen_type.instance)
6055
6056 return gen_type.instance
6057
6058 def visitSet(
6059 self, node: ast.Set, type_ctx: Optional[Class] = None
6060 ) -> NarrowingEffect:
6061 for elt in node.elts:
6062 self.visit(elt)
6063 self.set_type(node, SET_EXACT_TYPE.instance)
6064 return NO_EFFECT
6065
6066 def visitGeneratorExp(
6067 self, node: GeneratorExp, type_ctx: Optional[Class] = None
6068 ) -> NarrowingEffect:
6069 self.visit_comprehension(node, node.generators, node.elt)
6070 self.set_type(node, DYNAMIC)
6071 return NO_EFFECT
6072
6073 def visitListComp(
6074 self, node: ListComp, type_ctx: Optional[Class] = None
6075 ) -> NarrowingEffect:
6076 self.visit_comprehension(node, node.generators, node.elt)
6077 self.set_type(node, LIST_EXACT_TYPE.instance)
6078 return NO_EFFECT
6079
6080 def visitSetComp(
6081 self, node: SetComp, type_ctx: Optional[Class] = None
6082 ) -> NarrowingEffect:
6083 self.visit_comprehension(node, node.generators, node.elt)
6084 self.set_type(node, SET_EXACT_TYPE.instance)
6085 return NO_EFFECT
6086
    def assign_value(
        self,
        target: expr,
        value: Value,
        src: Optional[expr] = None,
        assignment: Optional[AST] = None,
    ) -> None:
        """Bind the assignment of a value of type `value` to `target`.

        target: the expression being assigned to (a name, a tuple/list of
            targets, or anything else such as an attribute or subscript).
        value: the type of the value being assigned.
        src: the source expression, when available; used to match tuple/list
            targets element-wise against tuple/list displays or constants.
        assignment: the enclosing assignment statement, forwarded to the
            Final-attribute check.

        Raises a syntax error when assigning to a Final variable or
        attribute, or when the value's class cannot be assigned to the
        target's declared class.
        """
        if isinstance(target, Name):
            decl_type = self.decl_types.get(target.id)
            if decl_type is None:
                # This var is not declared in the current scope, but it might be a
                # global or nonlocal. In that case, we need to check whether it's a Final.
                scope_type = self.get_var_scope(target.id)
                if scope_type == SC_GLOBAL_EXPLICIT or scope_type == SC_GLOBAL_IMPLICIT:
                    # scopes[0] is the outermost scope on the stack
                    declared_type = self.scopes[0].decl_types.get(target.id, None)
                    if declared_type is not None and declared_type.is_final:
                        raise self.syntax_error(
                            "Cannot assign to a Final variable", target
                        )

                # For an inferred exact type, we want to declare the inexact
                # type; the exact type is useful local inference information,
                # but we should still allow assignment of a subclass later.
                self.declare_local(target, inexact(value))
            else:
                if decl_type.is_final:
                    raise self.syntax_error("Cannot assign to a Final variable", target)
                self.check_can_assign_from(decl_type.type.klass, value.klass, target)

            local_type = self.maybe_set_local_type(target.id, value)
            self.set_type(target, local_type)
        elif isinstance(target, (ast.Tuple, ast.List)):
            if isinstance(src, (ast.Tuple, ast.List)) and len(target.elts) == len(
                src.elts
            ):
                # NOTE: the loop variable rebinds `target`; after the loop it
                # refers to the last element, not the tuple/list itself.
                for target, inner_value in zip(target.elts, src.elts):
                    self.assign_value(
                        target, self.get_type(inner_value), src=inner_value
                    )
            elif isinstance(src, ast.Constant):
                t = src.value
                if isinstance(t, tuple) and len(t) == len(target.elts):
                    # Same rebinding of `target` as above.
                    for target, inner_value in zip(target.elts, t):
                        self.assign_value(target, CONSTANT_TYPES[type(inner_value)])
                else:
                    for val in target.elts:
                        self.assign_value(val, DYNAMIC)
            else:
                # Nothing is known about the individual elements.
                for val in target.elts:
                    self.assign_value(val, DYNAMIC)
        else:
            self.check_can_assign_from(self.get_type(target).klass, value.klass, target)
        self._check_final_attribute_reassigned(target, assignment)
6140
def visitDictComp(
    self, node: DictComp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Bind types for a dict comprehension and record the resulting dict type.

    The first generator's iterable is visited before the comprehension's own
    BindingScope is pushed; targets, conditions, key and value are visited
    inside that scope.  After the scope is popped, the key/value types are
    handed to ``set_dict_type`` together with ``type_ctx``.

    Returns NO_EFFECT.
    """
    first_gen = node.generators[0]
    self.visit(first_gen.iter)

    scope = BindingScope(node)
    self.scopes.append(scope)

    iter_type = self.get_type(first_gen.iter).get_iter_type(first_gen.iter, self)

    self.assign_value(first_gen.target, iter_type)
    for if_ in first_gen.ifs:
        self.visit(if_)

    for gen in node.generators[1:]:
        self.visit(gen.iter)
        iter_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)
        self.assign_value(gen.target, iter_type)
        # Each generator's own conditions must be bound; visiting the first
        # generator's conditions again would leave these nodes untyped.
        for if_ in gen.ifs:
            self.visit(if_)

    self.visit(node.key)
    self.visit(node.value)

    self.scopes.pop()

    key_type = self.get_type(node.key)
    value_type = self.get_type(node.value)
    self.set_dict_type(node, key_type, value_type, type_ctx, is_exact=True)

    return NO_EFFECT
6174
def visit_comprehension(
    self, node: ast.expr, generators: List[ast.comprehension], *elts: ast.expr
) -> None:
    """Bind types for the generators and element expressions of a comprehension.

    The first generator's iterable is visited before a new BindingScope for
    ``node`` is pushed; every target, condition and element expression in
    ``elts`` is visited inside that scope, which is popped before returning.
    """
    first_gen = generators[0]
    self.visit(first_gen.iter)

    scope = BindingScope(node)
    self.scopes.append(scope)

    iter_type = self.get_type(first_gen.iter).get_iter_type(first_gen.iter, self)

    self.assign_value(first_gen.target, iter_type)
    for if_ in first_gen.ifs:
        self.visit(if_)

    for gen in generators[1:]:
        self.visit(gen.iter)
        iter_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)
        self.assign_value(gen.target, iter_type)
        # Each generator's own conditions must be bound; visiting the first
        # generator's conditions again would leave these nodes untyped.
        for if_ in gen.ifs:
            self.visit(if_)

    for elt in elts:
        self.visit(elt)

    self.scopes.pop()
6202
def visitAwait(
    self, node: Await, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the awaited operand and type the await expression as DYNAMIC."""
    awaited = node.value
    self.visit(awaited)
    self.set_type(node, DYNAMIC)
    return NO_EFFECT
6209
def visitYield(
    self, node: Yield, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the yielded value (if present) and type the yield as DYNAMIC."""
    yielded = node.value
    if yielded is not None:
        self.visit(yielded)
    self.set_type(node, DYNAMIC)
    return NO_EFFECT
6218
def visitYieldFrom(
    self, node: YieldFrom, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the delegated-to expression and type the ``yield from`` as DYNAMIC."""
    delegate = node.value
    self.visit(delegate)
    self.set_type(node, DYNAMIC)
    return NO_EFFECT
6225
def visitIndex(
    self, node: Index, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the wrapped value with ``type_ctx`` and give the Index node its type."""
    inner = node.value
    self.visit(inner, type_ctx)
    inner_type = self.get_type(inner)
    self.set_type(node, inner_type)
    return NO_EFFECT
6232
def visitCompare(
    self, node: Compare, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Bind types for a comparison and, where possible, return a narrowing effect.

    A single ``is`` / ``is not`` comparison of a Name against ``None`` returns
    an IsInstanceEffect when that name's type is a non-generic UnionInstance.
    Every other case falls through to the general path, which asks the
    operand types to bind each comparison and returns NO_EFFECT.
    """
    if len(node.ops) == 1 and isinstance(node.ops[0], (Is, IsNot)):
        left = node.left
        right = node.comparators[0]
        other = None

        self.set_type(node, BOOL_TYPE.instance)
        self.set_type(node.ops[0], BOOL_TYPE.instance)

        self.visit(left)
        self.visit(right)

        # Find the operand that is *not* the literal None, if either side is.
        if isinstance(left, (Constant, NameConstant)) and left.value is None:
            other = right
        elif isinstance(right, (Constant, NameConstant)) and right.value is None:
            other = left

        if other is not None and isinstance(other, Name):
            var_type = self.get_type(other)

            if (
                isinstance(var_type, UnionInstance)
                and not var_type.klass.is_generic_type_definition
            ):
                effect = IsInstanceEffect(
                    other.id, var_type, NONE_TYPE.instance, self
                )
                # `x is not None` is the negation of the `x is None` effect.
                if isinstance(node.ops[0], IsNot):
                    effect = effect.not_()
                return effect

    # General path; also reached by is/is-not comparisons that produced no
    # narrowing effect above (their operands are visited a second time here).
    self.visit(node.left)
    left = node.left
    ltype = self.get_type(node.left)
    # Replace each operator with a fresh instance of the same operator class.
    node.ops = [type(op)() for op in node.ops]
    for comparator, op in zip(node.comparators, node.ops):
        self.visit(comparator)
        rtype = self.get_type(comparator)

        tried_right = False
        # Taken when the left class appears among the right class's bases.
        if ltype.klass in rtype.klass.mro[1:]:
            if ltype.bind_reverse_compare(
                node, left, op, comparator, self, type_ctx
            ):
                continue
            tried_right = True

        if ltype.bind_compare(node, left, op, comparator, self, type_ctx):
            continue

        if not tried_right:
            rtype.bind_reverse_compare(node, left, op, comparator, self, type_ctx)

        # NOTE(review): `left` is never advanced and `right` is never read, so
        # chained comparisons keep passing the first operand as `left`; the
        # `continue` paths above also skip the `ltype` update. Confirm whether
        # `left = comparator` was intended.
        ltype = rtype
        right = comparator
    return NO_EFFECT
6291
def visitCall(
    self, node: Call, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the callee and let the callee's type bind the call.

    Returns whatever ``bind_call`` on the callee's type returns.
    """
    callee = node.func
    self.visit(callee)
    callee_type = self.get_type(callee)
    return callee_type.bind_call(node, self, type_ctx)
6298
def visitFormattedValue(
    self, node: FormattedValue, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Bind the parts of one ``{...}`` f-string field; the field is DYNAMIC.

    Both the formatted expression and, when present, the format spec are
    visited so that every sub-expression has a type recorded for it.

    Returns NO_EFFECT.
    """
    self.visit(node.value)
    # The format spec (e.g. ``{x:{width}}``) can contain further expressions
    # that also need types bound.
    format_spec = node.format_spec
    if format_spec is not None:
        self.visit(format_spec)
    self.set_type(node, DYNAMIC)
    return NO_EFFECT
6305
def visitJoinedStr(
    self, node: JoinedStr, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit every piece of an f-string and type the result as an exact str."""
    for piece in node.values:
        self.visit(piece)
    self.set_type(node, STR_EXACT_TYPE.instance)
    return NO_EFFECT
6314
def visitConstant(
    self, node: Constant, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Bind a constant using ``type_ctx`` when given, otherwise DYNAMIC."""
    binding_type = DYNAMIC if type_ctx is None else type_ctx
    binding_type.bind_constant(node, self)
    return NO_EFFECT
6323
def visitAttribute(
    self, node: Attribute, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the receiver and let its type bind the attribute access."""
    receiver = node.value
    self.visit(receiver)
    receiver_type = self.get_type(receiver)
    receiver_type.bind_attr(node, self, type_ctx)
    return NO_EFFECT
6330
def visitSubscript(
    self, node: Subscript, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit container and slice, then let the container type bind the subscript."""
    container, index = node.value, node.slice
    self.visit(container)
    self.visit(index)
    container_type = self.get_type(container)
    index_type = self.get_type(index)
    container_type.bind_subscr(node, index_type, self)
    return NO_EFFECT
6339
def visitStarred(
    self, node: Starred, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit the starred operand and type the starred expression as DYNAMIC."""
    operand = node.value
    self.visit(operand)
    self.set_type(node, DYNAMIC)
    return NO_EFFECT
6346
def visitName(
    self, node: Name, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Record the type of a name reference.

    Local names outside module scope take their type from ``local_types``
    (DYNAMIC when unknown) and are checked against ``type_ctx`` when one is
    given; all other names are resolved through the current module, falling
    back to DYNAMIC.  A name whose type is a non-generic UnionInstance yields
    the negated None-IsInstanceEffect; otherwise NO_EFFECT is returned.
    """
    symbol_scope = self.symbols.scopes[self.scope]
    name_kind = symbol_scope.check_name(node.id)
    is_function_local = name_kind == SC_LOCAL and not isinstance(self.scope, Module)

    if is_function_local:
        var_type = self.local_types.get(node.id, DYNAMIC)
        self.set_type(node, var_type)
        if type_ctx is not None:
            self.check_can_assign_from(type_ctx.klass, var_type.klass, node)
    else:
        resolved = self.cur_mod.resolve_name(node.id) or DYNAMIC
        self.set_type(node, resolved)

    bound_type = self.get_type(node)
    if not isinstance(bound_type, UnionInstance):
        return NO_EFFECT
    if bound_type.klass.is_generic_type_definition:
        return NO_EFFECT

    return IsInstanceEffect(node.id, bound_type, NONE_TYPE.instance, self).not_()
6369
def visitList(
    self, node: ast.List, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit each element with a DYNAMIC context; the literal is an exact list."""
    for element in node.elts:
        self.visit(element, DYNAMIC)
    list_instance = LIST_EXACT_TYPE.instance
    self.set_type(node, list_instance)
    return NO_EFFECT
6377
def visitTuple(
    self, node: ast.Tuple, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
    """Visit each element with a DYNAMIC context; the literal is an exact tuple."""
    for element in node.elts:
        self.visit(element, DYNAMIC)
    tuple_instance = TUPLE_EXACT_TYPE.instance
    self.set_type(node, tuple_instance)
    return NO_EFFECT
6385
def set_terminal_kind(self, node: AST, level: TerminalKind) -> None:
    """Raise the terminal kind recorded for ``node`` to ``level``.

    The stored kind is only replaced when ``level`` is greater than the kind
    currently recorded (nodes without an entry count as NonTerminal).
    """
    recorded = self.terminals.get(node, TerminalKind.NonTerminal)
    if not recorded < level:
        return
    self.terminals[node] = level
6390
def visitContinue(self, node: ast.Continue) -> None:
    """Mark a ``continue`` statement as a BreakOrContinue terminal."""
    kind = TerminalKind.BreakOrContinue
    self.set_terminal_kind(node, kind)
6393
def visitBreak(self, node: ast.Break) -> None:
    """Mark a ``break`` statement as a BreakOrContinue terminal."""
    kind = TerminalKind.BreakOrContinue
    self.set_terminal_kind(node, kind)
6396
def visitReturn(self, node: Return) -> None:
    """Mark the statement as a Return terminal and check the returned value.

    The expected type is DYNAMIC unless the enclosing binding scope is a
    function with a return annotation, in which case the resolved annotation
    (or DYNAMIC_TYPE when it cannot be resolved) is used.

    Raises the binder's syntax error when the returned class is neither
    DYNAMIC_TYPE nor assignable to the expected class.
    """
    self.set_terminal_kind(node, TerminalKind.Return)
    value = node.value
    if value is None:
        return

    func = self.binding_scope.node
    expected = DYNAMIC
    if isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
        annotation = func.returns
        if annotation:
            resolved = self.cur_mod.resolve_annotation(annotation) or DYNAMIC_TYPE
            expected = resolved.instance

    self.visit(value, expected)
    returned = self.get_type(value).klass
    if returned is DYNAMIC_TYPE:
        return
    if expected.klass.can_assign_from(returned):
        return

    raise self.syntax_error(
        f"return type must be {expected.name}, not {self.get_type(value).name!s}",
        node,
    )
6421
def visitImportFrom(self, node: ImportFrom) -> None:
    """Validate a ``from ... import ...`` statement.

    Raises NotImplementedError for relative imports (or a missing module
    name).  For ``from __static__ import ...`` a syntax error is raised for
    ``*`` and for any name not found in ``self.symtable.statics.children``.
    Imports from other modules are accepted without further checks.
    """
    module_name = node.module
    if node.level or not module_name:
        raise NotImplementedError("relative imports aren't supported")

    if module_name != "__static__":
        return

    for alias in node.names:
        imported = alias.name
        if imported == "*":
            raise self.syntax_error("from __static__ import * is disallowed", node)
        if imported not in self.symtable.statics.children:
            raise self.syntax_error(f"unsupported static import {imported}", node)
6436
def visit_until_terminates(self, nodes: List[ast.stmt]) -> TerminalKind:
    """Visit statements in order, stopping after the first terminal one.

    Returns the terminal kind recorded for the statement that stopped the
    walk, or TerminalKind.NonTerminal when every statement was visited.
    """
    for statement in nodes:
        self.visit(statement)
        if statement not in self.terminals:
            continue
        # Anything after a terminal statement is not visited.
        return self.terminals[statement]

    return TerminalKind.NonTerminal
6444
def visitIf(self, node: If) -> None:
    """Bind an ``if`` statement, tracking narrowing and local-type branches.

    The test's narrowing effect is applied for the body and reversed for the
    ``else`` part (or for the code following an ``if`` whose body
    terminates).  When both arms terminate, the ``if`` itself is recorded as
    terminal with the lesser of the two kinds.
    """
    branch = self.binding_scope.branch()

    effect = self.visit(node.test) or NO_EFFECT
    effect.apply(self.local_types)

    body_terminates = self.visit_until_terminates(node.body)

    if not node.orelse:
        if body_terminates:
            effect.reverse(self.local_types)
        else:
            # Merge end of if w/ opening (with test effect reversed)
            branch.merge(effect.reverse(branch.entry_locals))
        return

    body_end = branch.copy()
    branch.restore()

    effect.reverse(self.local_types)
    else_terminates = self.visit_until_terminates(node.orelse)

    if else_terminates and body_terminates:
        # We're the least severe terminal of our two children
        self.terminals[node] = min(body_terminates, else_terminates)
    elif else_terminates:
        branch.restore(body_end)
    elif not body_terminates:
        # Merge end of orelse with end of if
        branch.merge(body_end)
6473
def visitTry(self, node: Try) -> None:
    """Bind a ``try`` statement and merge local types across its paths.

    The state after the body is merged and snapshotted; the ``else`` block
    and each handler contribute a snapshot that is merged back into the
    post-try state before the ``finally`` block (if any) is visited.
    """
    branch = self.binding_scope.branch()
    self.visit(node.body)

    branch.merge()
    after_body = branch.copy()
    path_ends = []

    if node.orelse:
        self.visit(node.orelse)
        path_ends.append(branch.copy())

    for handler in node.handlers:
        # Every handler starts from the post-try state.
        branch.restore(after_body)
        self.visit(handler)
        path_ends.append(branch.copy())

    branch.restore(after_body)
    for path_end in path_ends:
        branch.merge(path_end)

    if node.finalbody:
        self.visit(node.finalbody)
6497
def visitExceptHandler(self, node: ast.ExceptHandler) -> None:
    """Bind an ``except`` clause, declaring its ``as`` name for the body.

    When the handler has both a type and a name, the name is declared in the
    binding scope with an instance of the handler class (DYNAMIC_TYPE when
    the bound type is DYNAMIC or not a Class).  After the body is visited the
    name is removed from ``decl_types`` and ``local_types`` again.

    Raises the binder's syntax error when the name is declared Final.
    """
    bound_name = None
    exc_expr = node.type
    if exc_expr:
        self.visit(exc_expr)
        handler_type = self.get_type(exc_expr)
        bound_name = node.name
        if bound_name:
            if handler_type is DYNAMIC or not isinstance(handler_type, Class):
                handler_type = DYNAMIC_TYPE

            declared = self.decl_types.get(bound_name)
            if declared and declared.is_final:
                raise self.syntax_error("Cannot assign to a Final variable", node)

            self.binding_scope.declare(bound_name, handler_type.instance)

    self.visit(node.body)

    if bound_name is None:
        return
    # The handler name does not outlive the handler body.
    del self.decl_types[bound_name]
    del self.local_types[bound_name]
6519
def visitWhile(self, node: While) -> None:
    """Bind a ``while`` loop, applying and reversing the test's narrowing."""
    branch = self.scopes[-1].branch()

    effect = self.visit(node.test) or NO_EFFECT
    effect.apply(self.local_types)

    body_kind = self.visit_until_terminates(node.body)
    if body_kind == TerminalKind.Return:
        # The body always returns: continue from the entry state with the
        # test's effect reversed.
        branch.restore()
        effect.reverse(self.local_types)
    else:
        branch.merge(effect.reverse(branch.entry_locals))

    if not node.orelse:
        return

    # The or-else can happen after the while body, or without executing
    # it, but it can only happen after the while condition evaluates to
    # False.
    effect.reverse(self.local_types)
    self.visit(node.orelse)

    branch.merge()
6541
def visitFor(self, node: For) -> None:
    """Bind a ``for`` loop: iterable, target assignment, body and else."""
    iterable = node.iter
    self.visit(iterable)
    element_type = self.get_type(iterable).get_iter_type(iterable, self)

    loop_target = node.target
    self.visit(loop_target)
    self.assign_value(loop_target, element_type)

    for block in (node.body, node.orelse):
        self.visit(block)
6549
def visitwithitem(self, node: ast.withitem) -> None:
    """Visit a ``with`` item; an ``as`` target, if present, is assigned DYNAMIC."""
    self.visit(node.context_expr)
    as_target = node.optional_vars
    if not as_target:
        return
    self.visit(as_target)
    self.assign_value(as_target, DYNAMIC)
6556
6557
class PyFlowGraph38Static(PyFlowGraphCinder):
    """PyFlowGraphCinder variant that uses the ``opcode38static`` opcode table."""

    # Opcode table consulted by the flow graph; taken from opcode38static.
    opcode: Opcode = opcode38static.opcode
6560
6561
6562class Static38CodeGenerator(CinderCodeGenerator):
6563 flow_graph = PyFlowGraph38Static
6564 _default_cache: Dict[Type[ast.AST], typingCallable[[...], None]] = {}
6565
def __init__(
    self,
    parent: Optional[CodeGenerator],
    node: AST,
    symbols: SymbolVisitor,
    graph: PyFlowGraph,
    symtable: SymbolTable,
    modname: str,
    flags: int = 0,
    optimization_lvl: int = 0,
) -> None:
    """Initialise the generator and look up the module table for ``modname``.

    ``symtable`` and ``modname`` are stored on the instance; the remaining
    arguments are forwarded to the base-class initialiser.
    """
    super().__init__(parent, node, symbols, graph, flags, optimization_lvl)
    self.modname = modname
    self.symtable = symtable
    # Use this counter to allocate temporaries for loop indices
    self._tmpvar_loopidx_count = 0
    self.cur_mod: ModuleTable = symtable.modules[modname]
6583
6584 def _is_static_compiler_disabled(self, node: AST) -> bool:
6585 if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef)):
6586 # Static compilation can only be disabled for functions and classes.
6587 return False
6588 scope = self.scope
6589 fn = None
6590 if isinstance(scope, ClassScope):
6591 klass = self.cur_mod.resolve_name(scope.name)
6592 if klass:
6593 assert isinstance(klass, Class)
6594 if klass.donotcompile:
6595 # If static compilation is disabled on the entire class, it's skipped for all contained
6596 # methods too.
6597 return True
6598 fn = klass.get_own_member(node.name)
6599
6600 if fn is None:
6601 # Wasn't a method, let's check if it's a module level function
6602 fn = self.cur_mod.resolve_name(node.name)
6603
6604 if isinstance(fn, (Function, StaticMethod)):
6605 return (
6606 fn.donotcompile
6607 if isinstance(fn, Function)
6608 else fn.function.donotcompile
6609 )
6610
6611 return False
6612
def make_child_codegen(
    self,
    tree: AST,
    graph: PyFlowGraph,
    codegen_type: Optional[Type[CinderCodeGenerator]] = None,
) -> CodeGenerator:
    """Create the code generator used for a nested scope.

    Trees for which static compilation is disabled are delegated to the base
    class with ``CinderCodeGenerator``.  Otherwise the graph is flagged as
    statically compiled (and frameless when the module asks for it), a
    StaticCodeGenerator is created, and argument checks are emitted for
    everything except class bodies.
    """
    if self._is_static_compiler_disabled(tree):
        return super().make_child_codegen(
            tree, graph, codegen_type=CinderCodeGenerator
        )

    consts = self.consts
    graph.setFlag(consts.CO_STATICALLY_COMPILED)
    if self.cur_mod.noframe:
        graph.setFlag(self.consts.CO_NO_FRAME)

    child = StaticCodeGenerator(
        self,
        tree,
        self.symbols,
        graph,
        symtable=self.symtable,
        modname=self.modname,
        optimization_lvl=self.optimization_lvl,
    )
    if isinstance(tree, ast.ClassDef):
        return child
    self._processArgTypes(tree, child)
    return child
6638
def _processArgTypes(self, node: AST, gen: Static38CodeGenerator) -> None:
    """Emit a CHECK_ARGS instruction describing the typed arguments of ``node``.

    The oparg is a flat tuple of ``(index, type_descr)`` pairs, one pair per
    argument whose bound type is neither DYNAMIC nor OBJECT.  Indices come
    from ``_calculate_idx`` and count positional-only, regular and
    keyword-only arguments in that order.
    """
    cellvars = gen.graph.cellvars
    # pyre-fixme[16]: When node is a comprehension (i.e., not a FunctionDef
    #  or Lambda), our caller manually adds an args attribute.
    args: ast.arguments = node.args
    is_comprehension = not isinstance(
        node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)
    )

    # Comprehension nodes don't have arguments when they're typed; make
    # up for that by treating their regular arguments as DYNAMIC.
    arg_groups = (
        (args.posonlyargs, False),
        (args.args, is_comprehension),
        (args.kwonlyargs, False),
    )

    arg_checks = []
    position = 0
    for group, force_dynamic in arg_groups:
        for arg in group:
            arg_type = DYNAMIC if force_dynamic else self.get_type(arg)
            if arg_type is not DYNAMIC and arg_type is not OBJECT:
                arg_checks.append(self._calculate_idx(arg.arg, position, cellvars))
                arg_checks.append(arg_type.klass.type_descr)
            position += 1

    # we should never emit arg checks for object
    assert not any(td == ("builtins", "object") for td in arg_checks[1::2])

    gen.emit("CHECK_ARGS", tuple(arg_checks))
6681
def get_type(self, node: Union[AST, Delegator]) -> Value:
    """Return the type recorded for ``node`` in the current module's type map."""
    bound_types = self.cur_mod.types
    return bound_types[node]
6684
def get_node_data(
    self, key: Union[AST, Delegator], data_type: Type[TType]
) -> TType:
    """Return the data stored for ``(key, data_type)`` in the current module."""
    stored = self.cur_mod.node_data[key, data_type]
    return cast(TType, stored)
6689
def set_node_data(
    self, key: Union[AST, Delegator], data_type: Type[TType], value: TType
) -> None:
    """Store ``value`` under ``(key, data_type)`` in the current module."""
    slot = (key, data_type)
    self.cur_mod.node_data[slot] = value
6694
@classmethod
# pyre-fixme[14]: `make_code_gen` overrides method defined in
#  `Python37CodeGenerator` inconsistently.
def make_code_gen(
    cls,
    module_name: str,
    tree: AST,
    filename: str,
    flags: int,
    optimize: int,
    peephole_enabled: bool = True,
    ast_optimizer_enabled: bool = True,
) -> Static38CodeGenerator:
    """Run the whole static pipeline on ``tree`` and return the code generator.

    Steps, in order: declaration visiting, ``finish_bind`` on every module
    table, optional AST optimisation, symbol visiting, flow-graph creation,
    type binding and finally code generation.
    """
    # TODO: Parsing here should really be that we run declaration visitor over all nodes,
    # and then perform post processing on the symbol table, and then proceed to analysis
    # and compilation
    symtable = SymbolTable()
    DeclarationVisitor(module_name, filename, symtable).visit(tree)

    for module_table in symtable.modules.values():
        module_table.finish_bind()

    if ast_optimizer_enabled:
        tree = AstOptimizer(optimize=optimize > 0).visit(tree)

    symbol_visitor = symbols.SymbolVisitor()
    symbol_visitor.visit(tree)

    graph = cls.flow_graph(
        module_name,
        filename,
        symbol_visitor.scopes[tree],
        peephole_enabled=peephole_enabled,
    )
    graph.setFlag(cls.consts.CO_STATICALLY_COMPILED)

    TypeBinder(symbol_visitor, filename, symtable, module_name, optimize).visit(tree)

    code_gen = cls(
        None, tree, symbol_visitor, graph, symtable, module_name, flags, optimize
    )
    code_gen.visit(tree)
    return code_gen
6735
def make_function_graph(
    self,
    func: FunctionDef,
    filename: str,
    scopes: Dict[AST, Scope],
    class_name: str,
    name: str,
    first_lineno: int,
) -> PyFlowGraph:
    """Build the function's flow graph and append its return type descriptor."""
    graph = super().make_function_graph(
        func, filename, scopes, class_name, name, first_lineno
    )

    # we tagged the graph as CO_STATICALLY_COMPILED, and the last co_const entry
    # will inform the runtime of the return type for the code object.
    return_descr = self.get_type(func).klass.type_descr
    graph.extra_consts.append(return_descr)
    return graph
6755
@contextmanager
def new_loopidx(self) -> Generator[str, None, None]:
    """Yield a temporary loop-index variable name for the ``with`` block.

    The name is built from ``_TMP_VAR_PREFIX`` and a nesting counter that is
    incremented on entry and decremented again on exit.
    """
    self._tmpvar_loopidx_count += 1
    try:
        tmp_name = f"{_TMP_VAR_PREFIX}.{self._tmpvar_loopidx_count}"
        yield tmp_name
    finally:
        self._tmpvar_loopidx_count -= 1
6763
def store_type_name_and_flags(self, node: ClassDef) -> None:
    """Invoke ``_static.set_type_static`` with one argument, then store the class name."""
    set_type_static = ("_static", "set_type_static")
    self.emit("INVOKE_FUNCTION", (set_type_static, 1))
    self.storeName(node.name)
6767
def walkClassBody(self, node: ClassDef, gen: CodeGenerator) -> None:
    """Emit ``__slots__`` and ``__slot_types__`` after the regular class body.

    Nothing extra is emitted when the class name does not resolve to a Class
    or resolves to DYNAMIC_TYPE.  ``__slot_types__`` maps slot names to type
    descriptors and is only stored when at least one slot has a declared
    type other than DYNAMIC_TYPE.
    """
    super().walkClassBody(node, gen)
    klass = self.symtable.modules[self.modname].resolve_name(node.name)
    if klass is DYNAMIC_TYPE or not isinstance(klass, Class):
        return

    slots = {
        name: member
        for name, member in klass.members.items()
        if isinstance(member, Slot)
    }

    slot_names = list(slots)
    if klass.allow_weakrefs:
        slot_names.append("__weakref__")

    # In the future we may want a compatibility mode where we add
    # __dict__ and __weakref__
    gen.emit("LOAD_CONST", tuple(slot_names))
    gen.emit("STORE_NAME", "__slots__")

    typed_slots = 0
    for name, slot in slots.items():
        if slot.decl_type is DYNAMIC_TYPE:
            continue
        gen.emit("LOAD_CONST", name)
        gen.emit("LOAD_CONST", slot.type_descr)
        typed_slots += 1

    if typed_slots:
        gen.emit("BUILD_MAP", typed_slots)
        gen.emit("STORE_NAME", "__slot_types__")
6801
def visitModule(self, node: Module) -> None:
    """Generate the module, first binding ``dict`` to ``_static.chkdict``.

    The import prologue is skipped when the module table has
    ``nonchecked_dicts`` set.
    """
    if not self.cur_mod.nonchecked_dicts:
        prologue = (
            ("LOAD_CONST", 0),
            ("LOAD_CONST", ("chkdict",)),
            ("IMPORT_NAME", "_static"),
            ("IMPORT_FROM", "chkdict"),
            ("STORE_NAME", "dict"),
        )
        for opcode, oparg in prologue:
            self.emit(opcode, oparg)

    super().visitModule(node)
6811
def emit_module_return(self, node: ast.Module) -> None:
    """Store the names of the module's finals in ``__final_constants__``, then return."""
    final_names = tuple(self.cur_mod.named_finals.keys())
    self.emit("LOAD_CONST", final_names)
    self.emit("STORE_NAME", "__final_constants__")
    super().emit_module_return(node)
6816
def visitAugAttribute(self, node: AugAttribute, mode: str) -> None:
    """Emit the load or store half of an augmented attribute assignment.

    In ``"load"`` mode the receiver is evaluated and duplicated, and a
    Load-context copy of the attribute is emitted through the receiver's
    type.  In ``"store"`` mode the two top stack entries are swapped before
    the receiver's type emits the attribute for ``node`` itself.  Any other
    mode emits nothing.
    """
    receiver = node.value
    if mode == "load":
        self.visit(receiver)
        self.emit("DUP_TOP")
        load_node = ast.Attribute(receiver, node.attr, ast.Load())
        load_node.lineno, load_node.col_offset = node.lineno, node.col_offset
        self.get_type(receiver).emit_attr(load_node, self)
    elif mode == "store":
        self.emit("ROT_TWO")
        self.get_type(receiver).emit_attr(node, self)
6828
def visitAugSubscript(self, node: AugSubscript, mode: str) -> None:
    """Emit the load or store half of an augmented subscript assignment.

    The container's type emits the subscript for ``node.obj``; modes other
    than ``"load"`` and ``"store"`` emit nothing.
    """
    if mode not in ("load", "store"):
        return
    container_type = self.get_type(node.value)
    if mode == "load":
        container_type.emit_subscr(node.obj, 1, self)
    else:
        container_type.emit_store_subscr(node.obj, self)
6834
def visitAttribute(self, node: Attribute) -> None:
    """Emit an attribute access.

    Loads through a ``super()`` call use LOAD_ATTR_SUPER; everything else
    evaluates the receiver and lets the receiver's type emit the access.
    """
    self.update_lineno(node)
    receiver = node.value
    if not (isinstance(node.ctx, ast.Load) and self._is_super_call(receiver)):
        self.visit(receiver)
        self.get_type(receiver).emit_attr(node, self)
        return

    self.emit("LOAD_GLOBAL", "super")
    self.emit("LOAD_ATTR_SUPER", self._emit_args_for_super(receiver, node.attr))
6844
def emit_type_check(self, dest: Class, src: Class, node: AST) -> None:
    """Check that ``src`` may be stored into ``dest``, emitting a CAST if needed.

    A DYNAMIC_TYPE source going into a destination that is neither
    OBJECT_TYPE nor DYNAMIC_TYPE gets a runtime CAST, unless the destination
    is a CType, which is rejected.  Any other combination is rejected when
    ``dest.can_assign_from(src)`` is false.

    Raises a syntax error located at ``node`` for rejected assignments.
    """
    needs_cast = (
        src is DYNAMIC_TYPE and dest is not OBJECT_TYPE and dest is not DYNAMIC_TYPE
    )
    if needs_cast:
        rejected = isinstance(dest, CType)
    else:
        rejected = not dest.can_assign_from(src)

    if rejected:
        # TODO raise this in type binding instead
        raise syntax_error(
            f"Cannot assign a {src.instance.name} to {dest.instance.name}",
            self.graph.filename,
            node,
        )

    if needs_cast:
        self.emit("CAST", dest.type_descr)
6862
def visitAssignTarget(
    self, elt: expr, stmt: AST, value: Optional[expr] = None
) -> None:
    """Emit the type check and store for one assignment target.

    Tuple/list targets are unpacked and handled element by element; the
    elements of ``value`` are paired with them only when ``value`` is a tuple
    literal of the same length.  Any other target is type-checked against
    ``value`` (or DYNAMIC_TYPE when ``value`` is None) and then visited.
    """
    if not isinstance(elt, (ast.Tuple, ast.List)):
        dest_klass = self.get_type(elt).klass
        src_klass = DYNAMIC_TYPE if value is None else self.get_type(value).klass
        self.emit_type_check(dest_klass, src_klass, stmt)
        self.visit(elt)
        return

    self._visitUnpack(elt)
    if isinstance(value, ast.Tuple) and len(value.elts) == len(elt.elts):
        sources = value.elts
    else:
        sources = [None] * len(elt.elts)
    for target, source in zip(elt.elts, sources):
        self.visitAssignTarget(target, stmt, source)
6882
def visitAssign(self, node: Assign) -> None:
    """Emit an assignment: evaluate the value once, then store into each target.

    The value is duplicated on the stack before every target except the last.
    """
    self.set_lineno(node)
    self.visit(node.value)
    last_index = len(node.targets) - 1
    for index, target in enumerate(node.targets):
        if index < last_index:
            self.emit("DUP_TOP")
        if isinstance(target, ast.AST):
            self.visitAssignTarget(target, node, node.value)
6893
def visitAnnAssign(self, node: ast.AnnAssign) -> None:
    """Emit an annotated assignment.

    When a value is present it is evaluated, type-checked against the
    target's bound type and stored.  Simple names in a module or class body
    additionally store their annotation; attribute and subscript targets
    without a value only have their sub-expressions checked.  Non-simple
    targets have their annotation checked as well.

    Raises:
        SystemError: if the target is not a Name, Attribute or Subscript.
    """
    self.set_lineno(node)
    value = node.value
    if value:
        self.visit(value)
        self.emit_type_check(
            self.get_type(node.target).klass, self.get_type(value).klass, node
        )
        self.visit(node.target)
    target = node.target
    if isinstance(target, ast.Name):
        # If we have a simple name in a module or class, store the annotation
        if node.simple and isinstance(self.tree, (ast.Module, ast.ClassDef)):
            self.emitStoreAnnotation(target.id, node.annotation)
    elif isinstance(target, ast.Attribute):
        if not node.value:
            self.checkAnnExpr(target.value)
    elif isinstance(target, ast.Subscript):
        if not node.value:
            self.checkAnnExpr(target.value)
            self.checkAnnSubscr(target.slice)
    else:
        # Report the offending target's type; `node` itself is always an
        # AnnAssign, which would make the message useless.
        raise SystemError(
            f"invalid node type {type(target).__name__} for annotated assignment"
        )

    if not node.simple:
        self.checkAnnotation(node)
6922
def visitConstant(self, node: Constant) -> None:
    """Let the constant's bound type emit it."""
    constant_type = self.get_type(node)
    constant_type.emit_constant(node, self)
6925
def get_final_literal(self, node: AST) -> Optional[ast.Constant]:
    """Ask the current module for ``node``'s final literal in the current scope."""
    module_table = self.cur_mod
    return module_table.get_final_literal(node, self.scope)
6928
def visitName(self, node: Name) -> None:
    """Emit a name, substituting its final literal when one is known."""
    literal = self.get_final_literal(node)
    if literal is None:
        self.get_type(node).emit_name(node, self)
        return None
    # visit the constant directly
    return self.defaultVisit(literal)
6935
def visitAugAssign(self, node: AugAssign) -> None:
    """Let the target's bound type emit the augmented assignment."""
    target_type = self.get_type(node.target)
    target_type.emit_augassign(node, self)
6938
def visitAugName(self, node: AugName, mode: str) -> None:
    """Let the name's bound type emit the augmented name access for ``mode``."""
    name_type = self.get_type(node)
    name_type.emit_augname(node, self, mode)
6941
def visitCompare(self, node: Compare) -> None:
    """Emit a (possibly chained) comparison.

    For every operator, operands whose bound type differs from the
    operator's bound type are converted through the operator type before the
    comparison is emitted.  Chained comparisons share a ``cleanup`` block
    that is emitted after the final comparison.
    """

    def convert_if_needed(operand_type: Value, op_type: Value) -> None:
        # Only operands that differ from the operator's type are converted.
        if operand_type != op_type:
            op_type.emit_convert(operand_type, self)

    self.update_lineno(node)
    self.visit(node.left)
    cleanup = self.newBlock("cleanup")
    previous = node.left

    for op, comparator in zip(node.ops[:-1], node.comparators[:-1]):
        op_type = self.get_type(op)
        convert_if_needed(self.get_type(previous), op_type)
        self.emitChainedCompareStep(op, op_type, comparator, cleanup)
        previous = comparator

    # now do the last comparison
    if node.ops:
        op = node.ops[-1]
        op_type = self.get_type(op)
        convert_if_needed(self.get_type(previous), op_type)
        comparator = node.comparators[-1]
        self.visit(comparator)
        convert_if_needed(self.get_type(comparator), op_type)
        op_type.emit_compare(op, self)

    if len(node.ops) > 1:
        end = self.newBlock("end")
        self.emit("JUMP_FORWARD", end)
        self.nextBlock(cleanup)
        self.emit("ROT_TWO")
        self.emit("POP_TOP")
        self.nextBlock(end)
6974
def emitChainedCompareStep(
    self,
    op: cmpop,
    optype: Value,
    value: AST,
    cleanup: Block,
    jump: str = "JUMP_IF_ZERO_OR_POP",
) -> None:
    """Emit one non-final link of a chained comparison.

    The right operand is evaluated (and converted when its type differs from
    ``optype``), duplicated and rotated under the left operand; after the
    comparison, ``jump`` targets ``cleanup``.
    """
    self.visit(value)
    operand_type = self.get_type(value)
    if operand_type != optype:
        optype.emit_convert(operand_type, self)
    for stack_op in ("DUP_TOP", "ROT_THREE"):
        self.emit(stack_op)
    optype.emit_compare(op, self)
    self.emit(jump, cleanup)
    self.nextBlock(label="compare_or_cleanup")
6992
def visitBoolOp(self, node: BoolOp) -> None:
    """Emit a short-circuiting ``and`` / ``or`` expression.

    Every operand but the last emits a jump-or-pop to the shared end block
    through its bound type (jumping on true for ``or``); the last operand is
    evaluated normally.
    """
    end = self.newBlock()
    *leading, final = node.values
    for operand in leading:
        jump_if_true = type(node.op) is ast.Or
        self.get_type(operand).emit_jumpif_pop(operand, end, jump_if_true, self)
        self.nextBlock()
    self.visit(final)
    self.nextBlock(end)
7002
def visitBinOp(self, node: BinOp) -> None:
    """Let the binary operation's bound type emit it."""
    result_type = self.get_type(node)
    result_type.emit_binop(node, self)
7005
def visitUnaryOp(self, node: UnaryOp, type_ctx: Optional[Class] = None) -> None:
    """Let the unary operation's bound type emit it; ``type_ctx`` is not used."""
    result_type = self.get_type(node)
    result_type.emit_unaryop(node, self)
7008
def visitCall(self, node: Call) -> None:
    """Let the callee's bound type emit the call."""
    callee_type = self.get_type(node.func)
    callee_type.emit_call(node, self)
7011
def visitSubscript(self, node: ast.Subscript, aug_flag: bool = False) -> None:
    """Let the container's bound type emit the subscript operation."""
    container_type = self.get_type(node.value)
    container_type.emit_subscr(node, aug_flag, self)
7014
def _visitReturnValue(self, value: ast.AST, expected: Class) -> None:
    """Evaluate a return value, casting DYNAMIC values to a non-dynamic ``expected``."""
    self.visit(value)
    needs_cast = expected is not DYNAMIC_TYPE and self.get_type(value) is DYNAMIC
    if needs_cast:
        self.emit("CAST", expected.type_descr)
7019
def visitReturn(self, node: ast.Return) -> None:
    """Emit a ``return`` statement.

    A bare ``return`` unwinds and returns ``None``.  Constant values are
    evaluated after unwinding the setup entries, other values before it
    (keeping the value on top of the stack).  When the enclosing tree's type
    is a CType, RETURN_INT with that type's oparg is emitted instead of
    RETURN_VALUE.
    """
    self.checkReturn(node)
    expected = self.get_type(self.tree).klass
    self.set_lineno(node)
    value = node.value

    if not value:
        self.unwind_setup_entries(preserve_tos=False)
        self.emit("LOAD_CONST", None)
        self.emit("RETURN_VALUE", 0)
        return

    if isinstance(value, ast.Constant):
        self.unwind_setup_entries(preserve_tos=False)
        self._visitReturnValue(value, expected)
    else:
        self._visitReturnValue(value, expected)
        self.unwind_setup_entries(preserve_tos=True)

    if isinstance(expected, CType):
        self.emit("RETURN_INT", expected.instance.as_oparg())
    else:
        self.emit("RETURN_VALUE", 0)
7043
def visitDictComp(self, node: DictComp) -> None:
    """Emit a dict comprehension, building a checked dict where required.

    Comprehensions bound to the plain dict types are delegated to the base
    class; all others must be checked-dict instantiations and are compiled
    with BUILD_CHECKED_MAP.
    """
    dict_type = self.get_type(node)
    plain_dicts = (DICT_TYPE.instance, DICT_EXACT_TYPE.instance)
    if dict_type in plain_dicts:
        return super().visitDictComp(node)

    klass = dict_type.klass
    is_checked_dict = isinstance(klass, GenericClass) and (
        klass.type_def is CHECKED_DICT_TYPE
        or klass.type_def is CHECKED_DICT_EXACT_TYPE
    )
    assert is_checked_dict, dict_type

    self.compile_comprehension(
        node,
        sys.intern("<dictcomp>"),
        node.key,
        node.value,
        "BUILD_CHECKED_MAP",
        (klass.type_descr, 0),
    )
7062
def compile_subgendict(
    self, node: ast.Dict, begin: int, end: int, dict_descr: TypeDescr
) -> None:
    """Emit keys/values ``begin`` to ``end - 1`` of ``node`` as one checked map.

    Every key in the range must be a real key (not a ``**`` unpacking).
    """
    item_count = end - begin
    for index in range(begin, end):
        key = node.keys[index]
        assert key is not None
        self.visit(key)
        self.visit(node.values[index])

    self.emit("BUILD_CHECKED_MAP", (dict_descr, item_count))
7074
def visitDict(self, node: ast.Dict) -> None:
    """Emit a dict display, building a checked dict where required.

    Displays bound to the plain dict types are delegated to the base class;
    all others must be checked-dict instantiations.

    This is similar to the normal dict code generation, but instead of
    relying upon an opcode for BUILD_MAP_UNPACK we invoke the ``update``
    method on the underlying dict type.  Therefore the first dict that we
    create becomes the final dict, and every later run of key/value pairs or
    ``**`` unpacking is merged into it.  This allows us to not introduce a
    new opcode, but we should also be able to dispatch the INVOKE_METHOD
    rather efficiently.
    """
    dict_type = self.get_type(node)
    if dict_type in (DICT_TYPE.instance, DICT_EXACT_TYPE.instance):
        return super().visitDict(node)
    klass = dict_type.klass

    assert isinstance(klass, GenericClass) and (
        klass.type_def is CHECKED_DICT_TYPE
        or klass.type_def is CHECKED_DICT_EXACT_TYPE
    ), dict_type

    self.update_lineno(node)
    elements = 0
    built_final_dict = False

    dict_descr = dict_type.klass.type_descr
    update_descr = dict_descr + ("update",)

    def flush_pending(stop: int) -> None:
        # Build the pending run of key/value pairs ending just before `stop`.
        # The first run becomes the final dict; later runs are merged into it
        # via `update` so that exactly one dict remains on the stack.
        nonlocal elements, built_final_dict
        if built_final_dict:
            self.emit("DUP_TOP")
        self.compile_subgendict(node, stop - elements, stop, dict_descr)
        if built_final_dict:
            self.emit_invoke_method(update_descr, 1)
            self.emit("POP_TOP")
        built_final_dict = True
        elements = 0

    for i, (k, v) in enumerate(zip(node.keys, node.values)):
        is_unpacking = k is None
        if elements == 0xFFFF or (elements and is_unpacking):
            flush_pending(i)

        if is_unpacking:
            if not built_final_dict:
                # {**foo, ...}, we need to generate the empty dict
                self.emit("BUILD_CHECKED_MAP", (dict_descr, 0))
                built_final_dict = True
            self.emit("DUP_TOP")
            self.visit(v)

            self.emit_invoke_method(update_descr, 1)
            self.emit("POP_TOP")
        else:
            elements += 1

    # Trailing pairs, or an entirely empty display.
    if elements or not built_final_dict:
        flush_pending(len(node.keys))
7127
def visitFor(self, node: ast.For) -> None:
    """Let the iterable's bound type emit the whole ``for`` loop."""
    iterable_type = self.get_type(node.iter)
    emitted = iterable_type.emit_forloop(node, self)
    return emitted
7131
def emit_invoke_method(self, descr: TypeDescr, arg_count: int) -> None:
    """Emit INVOKE_METHOD for ``descr`` with ``arg_count`` arguments."""
    # Emit a zero EXTENDED_ARG before so that we can optimize and insert the
    # arg count
    instructions = (
        ("EXTENDED_ARG", 0),
        ("INVOKE_METHOD", (descr, arg_count)),
    )
    for opcode, oparg in instructions:
        self.emit(opcode, oparg)
7137
def defaultVisit(self, node: object, *args: object) -> None:
    """Visit ``node`` with the parent class's visitor, bypassing static overrides.

    The ``visit<ClassName>`` function is looked up on the class after
    Static38CodeGenerator in its MRO (falling back to ``generic_visit``) and
    cached per node class in ``_default_cache``.
    """
    self.node = node
    node_class = node.__class__
    visitor = self._default_cache.get(node_class, None)
    if visitor is None:
        parent = super(Static38CodeGenerator, Static38CodeGenerator)
        visitor = getattr(
            parent,
            "visit" + node_class.__name__,
            StaticCodeGenerator.generic_visit,
        )
        self._default_cache[node_class] = visitor
    # The cached function is unbound, so pass self explicitly.
    return visitor(self, node, *args)
7151
def compileJumpIf(self, test: AST, next: Block, is_if_true: bool) -> None:
    """Let the test expression's bound type emit the conditional jump to ``next``."""
    test_type = self.get_type(test)
    test_type.emit_jumpif(test, next, is_if_true, self)
7154
7155 def _calculate_idx(
7156 self, arg_name: str, non_cellvar_pos: int, cellvars: IndexedSet
7157 ) -> int:
7158 try:
7159 offset = cellvars.index(arg_name)
7160 except ValueError:
7161 return non_cellvar_pos
7162 else:
7163 # the negative sign indicates to the runtime/JIT that this is a cellvar
7164 return -(offset + 1)
7165
7166
# Alias: StaticCodeGenerator is the same class object as Static38CodeGenerator.
StaticCodeGenerator = Static38CodeGenerator