this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2import ast
3import asyncio
4import dis
5import gc
6import inspect
7import itertools
8import re
9import string
10import sys
11import unittest
12import warnings
13from array import array
14from collections import UserDict
15from compiler import consts, walk
16from compiler.consts38 import CO_NO_FRAME, CO_STATICALLY_COMPILED
17from compiler.optimizer import AstOptimizer
18from compiler.pycodegen import PythonCodeGenerator, make_compiler
19from compiler.static import (
20 prim_name_to_type,
21 BOOL_TYPE,
22 BYTES_TYPE,
23 COMPLEX_EXACT_TYPE,
24 Class,
25 Object,
26 DICT_TYPE,
27 DYNAMIC,
28 FLOAT_TYPE,
29 INT_TYPE,
30 LIST_EXACT_TYPE,
31 LIST_TYPE,
32 NONE_TYPE,
33 OBJECT_TYPE,
34 PRIM_OP_ADD_INT,
35 PRIM_OP_DIV_INT,
36 PRIM_OP_GT_INT,
37 PRIM_OP_LT_INT,
38 STR_EXACT_TYPE,
39 STR_TYPE,
40 DeclarationVisitor,
41 Function,
42 StaticCodeGenerator,
43 SymbolTable,
44 TypeBinder,
45 TypedSyntaxError,
46 Value,
47 TUPLE_EXACT_TYPE,
48 TUPLE_TYPE,
49 INT_EXACT_TYPE,
50 FLOAT_EXACT_TYPE,
51 SET_EXACT_TYPE,
52 ELLIPSIS_TYPE,
53 FAST_LEN_ARRAY,
54 FAST_LEN_LIST,
55 FAST_LEN_TUPLE,
56 FAST_LEN_INEXACT,
57 FAST_LEN_DICT,
58 FAST_LEN_SET,
59 SEQ_ARRAY_INT8,
60 SEQ_ARRAY_INT16,
61 SEQ_ARRAY_INT32,
62 SEQ_ARRAY_INT64,
63 SEQ_ARRAY_UINT8,
64 SEQ_ARRAY_UINT16,
65 SEQ_ARRAY_UINT32,
66 SEQ_ARRAY_UINT64,
67 SEQ_LIST,
68 SEQ_LIST_INEXACT,
69 SEQ_TUPLE,
70 SEQ_REPEAT_INEXACT_SEQ,
71 SEQ_REPEAT_INEXACT_NUM,
72 SEQ_REPEAT_PRIMITIVE_NUM,
73 SEQ_REPEAT_REVERSED,
74 SEQ_SUBSCR_UNCHECKED,
75 TYPED_BOOL,
76 TYPED_INT8,
77 TYPED_INT16,
78 TYPED_INT32,
79 TYPED_INT64,
80 TYPED_UINT8,
81 TYPED_UINT16,
82 TYPED_UINT32,
83 TYPED_UINT64,
84 FAST_LEN_STR,
85 DICT_EXACT_TYPE,
86 TYPED_DOUBLE,
87)
88from compiler.symbols import SymbolVisitor
89from contextlib import contextmanager
90from copy import deepcopy
91from io import StringIO
92from os import path
93from test.support import maybe_get_event_loop_policy
94from textwrap import dedent
95from types import CodeType, MemberDescriptorType, ModuleType
96from typing import Generic, Optional, Tuple, TypeVar
97from unittest import TestCase, skip, skipIf
98from unittest.mock import Mock, patch
99
100import cinder
101import xxclassloader
102from __static__ import (
103 Array,
104 Vector,
105 chkdict,
106 int32,
107 int64,
108 int8,
109 make_generic_type,
110 StaticGeneric,
111 is_type_static,
112)
113from cinder import StrictModule
114
115from .common import CompilerTest
116
# Whether the Cinder JIT is running in multithreaded-compile test mode. When it
# is, the module helpers on StaticTestBase skip their sys.modules cleanup.
IS_MULTITHREADED_COMPILE_TEST = False
try:
    import cinderjit

    IS_MULTITHREADED_COMPILE_TEST = cinderjit.is_test_multithreaded_compile_enabled()
except ImportError:
    # No JIT available: assert_jitted / assert_not_jitted become no-ops.
    cinderjit = None


# Location of the static Richards benchmark: three directories up from this
# file, then Tools/benchmarks/richards_static.py.
RICHARDS_PATH = path.join(
    path.dirname(__file__),
    *([".."] * 3),
    "Tools",
    "benchmarks",
    "richards_static.py",
)


# Maps the __static__ primitive type names to their TYPED_* type codes.
PRIM_NAME_TO_TYPE = dict(
    cbool=TYPED_BOOL,
    int8=TYPED_INT8,
    int16=TYPED_INT16,
    int32=TYPED_INT32,
    int64=TYPED_INT64,
    uint8=TYPED_UINT8,
    uint16=TYPED_UINT16,
    uint32=TYPED_UINT32,
    uint64=TYPED_UINT64,
)
148
149
def type_mismatch(from_type: str, to_type: str) -> str:
    """Build a regex that matches the compiler's type-mismatch message literally.

    The message is passed through ``re.escape`` so dotted or bracketed type
    names (``foo.C``, ``Optional[int]``) are matched verbatim, not as regex syntax.
    """
    message = "type mismatch: {} cannot be assigned to {}".format(from_type, to_type)
    return re.escape(message)
152
153
def optional(type: str) -> str:
    """Wrap the type name *type* as ``Optional[<type>]``."""
    return "Optional[{}]".format(type)
156
157
def init_xxclassloader():
    """Statically compile a generic ``XXGeneric`` class and install it as
    ``xxclassloader.XXGeneric``.

    The compiled ``__class_getitem__`` memoizes, in ``XXGeneric.d``, one plain
    ``object`` subclass per subscript key, each reusing ``XXGeneric.foo``.
    """
    source = """
        from typing import Generic, TypeVar, _tp_cache
        from __static__.compiler_flags import nonchecked_dicts
        # Setup a test for typing
        T = TypeVar('T')
        U = TypeVar('U')


        class XXGeneric(Generic[T, U]):
            d = {}

            def foo(self, t: T, u: U) -> str:
                return str(t) + str(u)

            @classmethod
            def __class_getitem__(cls, elem_type):
                if elem_type in XXGeneric.d:
                    return XXGeneric.d[elem_type]

                XXGeneric.d[elem_type] = type(
                    f"XXGeneric[{elem_type[0].__name__}, {elem_type[1].__name__}]",
                    (object, ),
                    {
                        "foo": XXGeneric.foo,
                        "__slots__":(),
                    }
                )
                return XXGeneric.d[elem_type]
    """
    namespace = {}
    module_code = make_compiler(
        inspect.cleandoc(source),
        "",
        "exec",
        generator=StaticCodeGenerator,
        modname="xxclassloader",
    ).getCode()
    exec(module_code, namespace, namespace)

    xxclassloader.XXGeneric = namespace["XXGeneric"]
200
201
class StaticTestBase(CompilerTest):
    """Common helpers for static-compiler tests.

    Provides static-pipeline compilation, execution of compiled code in
    temporary (optionally strict) modules, and JIT-compilation assertions.
    """

    def compile(
        self,
        code,
        generator=StaticCodeGenerator,
        modname="<module>",
        optimize=0,
        peephole_enabled=True,
        ast_optimizer_enabled=True,
    ):
        """Compile *code* and return the module code object.

        Any non-default configuration (a different generator, or either
        optimizer disabled) is delegated to ``CompilerTest.compile``. The
        default case compiles through a fresh ``SymbolTable``.
        """
        if (
            not peephole_enabled
            or not ast_optimizer_enabled
            or generator is not StaticCodeGenerator
        ):
            return super().compile(
                code,
                generator,
                modname,
                optimize,
                peephole_enabled,
                ast_optimizer_enabled,
            )

        symtable = SymbolTable()
        # The leading newline makes cleandoc include the first real line in its
        # common-indentation computation instead of stripping it separately.
        code = inspect.cleandoc("\n" + code)
        tree = ast.parse(code)
        return symtable.compile(modname, f"{modname}.py", tree, optimize)

    def type_error(self, code, pattern):
        """Assert that compiling *code* raises TypedSyntaxError matching *pattern*."""
        with self.assertRaisesRegex(TypedSyntaxError, pattern):
            self.compile(code)

    # Incremented for every generated module name so that names never collide.
    _temp_mod_num = 0

    def _temp_mod_name(self):
        """Return a fresh module name: a caller's code name plus a counter.

        The prefix is the code name three frames up. From in_module or
        in_strict_module (entered via contextmanager's ``__enter__``) that is
        the test method. From run_code it is whatever invoked the test. The
        counter suffix is what makes the name unique.
        """
        StaticTestBase._temp_mod_num += 1
        return sys._getframe().f_back.f_back.f_back.f_code.co_name + str(
            StaticTestBase._temp_mod_num
        )

    @contextmanager
    def in_module(self, code, name=None, code_gen=StaticCodeGenerator, optimize=0):
        """Compile *code*, execute it as module *name*, and yield its globals dict.

        The module is registered in ``sys.modules`` while the ``with`` body
        runs. Unless multithreaded-compile testing is enabled, on exit it is
        unregistered, its namespace is cleared, and a GC pass is run.
        """
        if name is None:
            name = self._temp_mod_name()

        # Bound before the try so cleanup can tell whether we got as far as
        # creating the module; a compile error leaves it None.
        d = None
        try:
            compiled = self.compile(code, code_gen, name, optimize)
            m = type(sys)(name)
            d = m.__dict__
            sys.modules[name] = m
            exec(compiled, d)
            d["__name__"] = name

            yield d
        finally:
            # If we failed to compile, nothing was created or registered.
            # Don't raise a new exception (e.g. on an unbound namespace) that
            # would mask the original one.
            if not IS_MULTITHREADED_COMPILE_TEST and d is not None:
                sys.modules.pop(name, None)
                d.clear()
                gc.collect()

    @contextmanager
    def in_strict_module(
        self,
        code,
        name=None,
        code_gen=StaticCodeGenerator,
        optimize=0,
        enable_patching=False,
    ):
        """Compile *code*, execute it in a ``StrictModule``, and yield that module.

        *enable_patching* is passed through to the ``StrictModule`` constructor.
        Registration and cleanup of ``sys.modules`` are the same as in_module.
        """
        if name is None:
            name = self._temp_mod_name()

        # None until the namespace exists, so compile failures skip cleanup.
        d = None
        try:
            compiled = self.compile(code, code_gen, name, optimize)
            d = {"__name__": name}
            m = StrictModule(d, enable_patching)
            sys.modules[name] = m
            exec(compiled, d)

            yield m
        finally:
            # Don't raise a new exception if we failed to compile.
            if not IS_MULTITHREADED_COMPILE_TEST and d is not None:
                sys.modules.pop(name, None)
                d.clear()
                gc.collect()

    def run_code(self, code, generator=None, modname=None, peephole_enabled=True):
        """Run *code* via ``CompilerTest.run_code`` and return the result.

        A unique module name is generated when *modname* is None. In
        multithreaded-compile test mode the result is also stored in
        ``sys.modules[modname]``.
        """
        if modname is None:
            modname = self._temp_mod_name()
        d = super().run_code(code, generator, modname, peephole_enabled)
        if IS_MULTITHREADED_COMPILE_TEST:
            sys.modules[modname] = d
        return d

    @property
    def base_size(self):
        """``sys.getsizeof`` of an instance of a class with empty ``__slots__``."""

        class C:
            __slots__ = ()

        return sys.getsizeof(C())

    @property
    def ptr_size(self):
        """Pointer size in bytes, inferred from ``sys.maxsize`` (8 or 4)."""
        return 8 if sys.maxsize > 2 ** 32 else 4

    def assert_jitted(self, func):
        """Assert *func* was JIT-compiled. This is a no-op without cinderjit."""
        if cinderjit is None:
            return

        self.assertTrue(cinderjit.is_jit_compiled(func), func.__name__)

    def assert_not_jitted(self, func):
        """Assert *func* was not JIT-compiled. This is a no-op without cinderjit."""
        if cinderjit is None:
            return

        self.assertFalse(cinderjit.is_jit_compiled(func), func.__name__)

    def setUp(self):
        # ensure clean classloader/vtable slate for all tests
        cinder.clear_classloader_caches()
        # ensure our async tests don't change the event loop policy
        policy = maybe_get_event_loop_policy()
        self.addCleanup(lambda: asyncio.set_event_loop_policy(policy))

    def subTest(self, **kwargs):
        """Clear classloader caches before each subtest, then delegate to unittest."""
        cinder.clear_classloader_caches()
        return super().subTest(**kwargs)

    def make_async_func_hot(self, func):
        """Await ``func()`` 50 times inside a single ``asyncio.run`` call.

        Presumably this pushes *func* past the JIT's hotness threshold.
        TODO confirm.
        """

        async def make_hot():
            for _ in range(50):
                await func()

        asyncio.run(make_hot())
347
348
349class StaticCompilationTests(StaticTestBase):
@classmethod
def setUpClass(cls):
    """Install the statically compiled ``xxclassloader.XXGeneric`` once for the class."""
    return init_xxclassloader()
353
@classmethod
def tearDownClass(cls):
    """Remove ``xxclassloader.XXGeneric`` unless in multithreaded-compile test mode."""
    if IS_MULTITHREADED_COMPILE_TEST:
        return
    del xxclassloader.XXGeneric
358
def test_static_import_unknown(self) -> None:
    """Importing a name that ``__static__`` does not define fails to compile."""
    src = """
        from __static__ import does_not_exist
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(src, StaticCodeGenerator, modname="foo")
365
def test_static_import_star(self) -> None:
    """A star import from ``__static__`` fails to compile."""
    src = """
        from __static__ import *
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(src, StaticCodeGenerator, modname="foo")
372
def test_reveal_type(self) -> None:
    """reveal_type() reports the inferred type of ``x or None`` via TypedSyntaxError."""
    expected = r"reveal_type\(x or None\): 'Optional\[int\]'"
    src = """
        def f(x: int):
            reveal_type(x or None)
    """
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src)
383
def test_reveal_type_local(self) -> None:
    """reveal_type() on a narrowed local reports both declared and local types."""
    expected = r"reveal_type\(x\): 'int', 'x' has declared type 'Optional\[int\]' and local type 'int'"
    src = """
        def f(x: int | None):
            if x is not None:
                reveal_type(x)
    """
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src)
395
def test_redefine_local_type(self) -> None:
    """Re-annotating a local with a different class is a TypedSyntaxError."""
    src = """
        class C: pass
        class D: pass

        def f():
            x: C = C()
            x: D = D()
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(src, StaticCodeGenerator, modname="foo")
407
def test_mixed_chain_assign(self) -> None:
    """In ``x = y = D()`` the D value cannot be stored into a local declared as C."""
    src = """
        class C: pass
        class D: pass

        def f():
            x: C = C()
            y: D = D()
            x = y = D()
    """
    expected = type_mismatch("foo.D", "foo.C")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src, StaticCodeGenerator, modname="foo")
420
def test_bool_cast(self) -> None:
    """``cast(bool, x)`` of an untyped argument compiles into a bool local."""
    self.compile(
        """
        from __static__ import cast
        class D: pass

        def f(x) -> bool:
            y: bool = cast(bool, x)
            return y
        """,
        StaticCodeGenerator,
        modname="foo",
    )
431
def test_typing_overload(self) -> None:
    """An ``@overload`` stub is ignored rather than clashing with the real method."""
    src = """
        from typing import Optional, overload

        class C:
            @overload
            def foo(self, x: int) -> int:
                ...

            def foo(self, x: Optional[int]) -> Optional[int]:
                return x

        def f(x: int) -> Optional[int]:
            return C().foo(x)
    """
    self.assertReturns(src, "Optional[int]")
449
def test_mixed_binop(self):
    """Adding an ssize_t local and a boxed int local is rejected in either order."""
    cases = (
        ("x + y", "cannot add int64 and Exact\\[int\\]"),
        ("y + x", "cannot add Exact\\[int\\] and int64"),
    )
    for expr, pattern in cases:
        with self.assertRaisesRegex(TypedSyntaxError, pattern):
            self.bind_module(
                f"""
                from __static__ import ssize_t

                def f():
                    x: ssize_t = 1
                    y = 1
                    {expr}
                """
            )
478
def test_mixed_binop_okay(self):
    """``ssize_t + 1`` (literal on the right) compiles and boxes to 2."""
    src = """
        from __static__ import ssize_t, box

        def f():
            x: ssize_t = 1
            y = x + 1
            return box(y)
    """
    with self.in_module(src) as ns:
        self.assertEqual(ns["f"](), 2)
491
def test_mixed_binop_okay_1(self):
    """``1 + ssize_t`` (literal on the left) compiles and boxes to 2."""
    src = """
        from __static__ import ssize_t, box

        def f():
            x: ssize_t = 1
            y = 1 + x
            return box(y)
    """
    with self.in_module(src) as ns:
        self.assertEqual(ns["f"](), 2)
504
def test_inferred_primitive_type(self):
    """An unannotated copy of an ssize_t local can still be boxed."""
    src = """
        from __static__ import ssize_t, box

        def f():
            x: ssize_t = 1
            y = x
            return box(y)
    """
    with self.in_module(src) as ns:
        self.assertEqual(ns["f"](), 1)
517
@skipIf(cinderjit is None, "not jitting")
def test_deep_attr_chain(self):
    """Binding a long attribute chain must not revisit nodes exponentially."""
    src = """
        def f(x):
            return x.x.x.x.x.x.x

    """

    class C:
        def __init__(self):
            self.x = self

    real_bind_attr = Object.bind_attr
    calls = 0

    def counting_bind_attr(*args):
        nonlocal calls
        calls += 1
        return real_bind_attr(*args)

    with patch("compiler.static.Object.bind_attr", counting_bind_attr), self.in_module(
        src
    ) as ns:
        obj = C()
        self.assertEqual(ns["f"](obj), obj)
        # Initially this would be 63 when we were double visiting
        self.assertLess(calls, 10)
546
@skipIf(cinderjit is None, "not jitting")
def test_no_frame(self):
    """The ``noframe`` flag sets CO_NO_FRAME, and the function still runs and is JITted."""
    src = """
        from __static__.compiler_flags import noframe

        def f():
            return 456
    """
    with self.in_module(src) as ns:
        func = ns["f"]
        flags = func.__code__.co_flags
        self.assertTrue(flags & CO_NO_FRAME)
        self.assertEqual(func(), 456)
        self.assert_jitted(func)
560
@skipIf(cinderjit is None, "not jitting")
def test_no_frame_generator(self):
    """A ``noframe`` function that consumes a generator still works and is JITted."""
    src = """
        from __static__.compiler_flags import noframe

        def g():
            for i in range(10):
                yield i
        def f():
            return list(g())
    """
    with self.in_module(src) as ns:
        func = ns["f"]
        flags = func.__code__.co_flags
        self.assertTrue(flags & CO_NO_FRAME)
        self.assertEqual(func(), list(range(10)))
        self.assert_jitted(func)
577
def test_subclass_binop(self):
    """Adding user-class instances falls back to a generic BINARY_ADD."""
    src = """
        class C: pass
        class D(C): pass

        def f(x: C, y: D):
            return x + y
    """
    module_code = self.compile(src, StaticCodeGenerator, modname="foo")
    self.assertInBytecode(self.find_code(module_code, "f"), "BINARY_ADD")
589
def test_exact_invoke_function(self):
    """``", ".join(...)`` on an exact str compiles to INVOKE_FUNCTION of str.join."""
    codestr = """
        def f() -> str:
            return ", ".join(['1','2','3'])
    """
    with self.in_module(codestr) as mod:
        f = mod["f"]
        self.assertInBytecode(
            f, "INVOKE_FUNCTION", (("builtins", "str", "join"), 2)
        )
        # Check the invoked call produces the same result as str.join.
        self.assertEqual(f(), "1, 2, 3")
602
def test_multiply_list_exact_by_int(self):
    """len() of ``[1, 2, 3] * 2`` uses FAST_LEN with FAST_LEN_LIST and returns 6."""
    src = """
        def f() -> int:
            l = [1, 2, 3] * 2
            return len(l)
    """
    f_code = self.find_code(self.compile(src))
    self.assertInBytecode(f_code, "FAST_LEN", FAST_LEN_LIST)
    with self.in_module(src) as ns:
        self.assertEqual(ns["f"](), 6)
613
def test_multiply_list_exact_by_int_reverse(self):
    """len() of ``2 * [1, 2, 3]`` uses FAST_LEN with FAST_LEN_LIST and returns 6."""
    src = """
        def f() -> int:
            l = 2 * [1, 2, 3]
            return len(l)
    """
    f_code = self.find_code(self.compile(src))
    self.assertInBytecode(f_code, "FAST_LEN", FAST_LEN_LIST)
    with self.in_module(src) as ns:
        self.assertEqual(ns["f"](), 6)
624
def test_int_bad_assign(self):
    """Assigning a str literal to an ssize_t local is rejected at compile time."""
    with self.assertRaisesRegex(
        TypedSyntaxError, "str cannot be used in a context where an int is expected"
    ):
        self.compile(
            """
            from __static__ import ssize_t
            def f():
                x: ssize_t = 'abc'
            """,
            StaticCodeGenerator,
        )
637
def test_sign_extend(self):
    """Widening a negative int16 into an int64 preserves its value (-40)."""
    codestr = """
        from __static__ import int16, int64, box
        def testfunc():
            x: int16 = -40
            y: int64 = x
            return box(y)
    """
    # Compile once on its own as well; only a compile error matters here.
    self.compile(codestr, StaticCodeGenerator)
    f = self.run_code(codestr, StaticCodeGenerator)["testfunc"]
    self.assertEqual(f(), -40)
649
def test_field_size(self):
    """Two primitive fields of every integer width can be stored and read back."""
    widths = ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64")
    for prim in widths:
        src = f"""
            from __static__ import {prim}, box
            class C{prim}:
                def __init__(self):
                    self.a: {prim} = 1
                    self.b: {prim} = 1

            def testfunc(c: C{prim}):
                c.a = 2
                c.b = 3
                return box(c.a + c.b)
        """
        with self.subTest(type=prim), self.in_module(src) as ns:
            instance = ns["C" + prim]()
            self.assertEqual(ns["testfunc"](instance), 5)
678
def test_field_sign_ext(self):
    """A primitive field boxed on load returns exactly the value stored into it."""
    cases = (
        ("int32", 65537),
        ("int16", 256),
        ("int8", 0x7F),
        ("uint32", 65537),
    )
    for prim, val in cases:
        src = f"""
            from __static__ import {prim}, box
            class C{prim}:
                def __init__(self):
                    self.value: {prim} = {val}

            def testfunc(c: C{prim}):
                return box(c.value)
        """
        with self.subTest(type=prim, val=val), self.in_module(src) as ns:
            instance = ns["C" + prim]()
            self.assertEqual(ns["testfunc"](instance), val)
701
def test_field_unsign_ext(self):
    """Unsigned fields must be zero-extended (not sign-extended) when loaded.

    A uint32 field compared against a negative int64 local must not look
    smaller than it.
    """
    for prim, val, test in [("uint32", 65537, -1)]:
        codestr = f"""
            from __static__ import {prim}, int64, box
            class C{prim}:
                def __init__(self):
                    self.value: {prim} = {val}

            def testfunc(c: C{prim}):
                z: int64 = {test}
                if c.value < z:
                    return True
                return False
        """
        with self.subTest(type=prim, val=val, test=test):
            with self.in_module(codestr) as mod:
                C = mod["C" + prim]
                f = mod["testfunc"]
                self.assertEqual(f(C()), False)
722
def test_field_sign_compare(self):
    """A negative int32 field compares equal to the same negative literal."""
    for prim, val, test in (("int32", -1, -1),):
        src = f"""
            from __static__ import {prim}, box
            class C{prim}:
                def __init__(self):
                    self.value: {prim} = {val}

            def testfunc(c: C{prim}):
                if c.value == {test}:
                    return True
                return False
        """
        with self.subTest(type=prim, val=val, test=test), self.in_module(src) as ns:
            instance = ns["C" + prim]()
            self.assertTrue(ns["testfunc"](instance))
741
def test_mixed_binop_sign(self):
    """Mixed uint8/int8 division (either order) uses signed PRIM_OP_DIV_INT."""
    for x_type, y_type in (("uint8", "int8"), ("int8", "uint8")):
        src = f"""
            from __static__ import int8, uint8, box
            def testfunc():
                x: {x_type} = 42
                y: {y_type} = 2
                return box(x / y)
        """
        func_code = self.find_code(self.compile(src, StaticCodeGenerator))
        self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_DIV_INT)
        with self.in_module(src) as ns:
            self.assertEqual(ns["testfunc"](), 21)
771
def test_mixed_cmpop_sign(self):
    """Mixed uint8/int8 ``<`` (either order) uses signed PRIM_OP_LT_INT."""
    for x_type, y_type in (("uint8", "int8"), ("int8", "uint8")):
        src = f"""
            from __static__ import int8, uint8, box
            def testfunc(tst=False):
                x: {x_type} = 42
                y: {y_type} = 2
                if tst:
                    x += 1
                    y += 1

                if x < y:
                    return True
                return False
        """
        func_code = self.find_code(self.compile(src, StaticCodeGenerator))
        self.assertInBytecode(func_code, "INT_COMPARE_OP", PRIM_OP_LT_INT)
        with self.in_module(src) as ns:
            self.assertEqual(ns["testfunc"](), False)
813
def test_mixed_add_reversed(self):
    """int16 + int8 (wider operand on the left) is a primitive add giving 44."""
    src = """
        from __static__ import int8, uint8, int64, box, int16
        def testfunc(tst=False):
            x: int8 = 42
            y: int16 = 2
            if tst:
                x += 1
                y += 1

            return box(y + x)
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
    with self.in_module(src) as ns:
        self.assertEqual(ns["testfunc"](), 44)
832
def test_mixed_tri_add(self):
    """uint8 + int8 + int64 compiles to primitive adds and evaluates to 47."""
    src = """
        from __static__ import int8, uint8, int64, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            z: int64 = 3
            if tst:
                x += 1
                y += 1

            return box(x + y + z)
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
    with self.in_module(src) as ns:
        self.assertEqual(ns["testfunc"](), 47)
852
def test_mixed_tri_add_unsigned(self):
    """uint8 + int8 promotes to int16, which then cannot be added to a uint64."""
    src = """
        from __static__ import int8, uint8, uint64, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            z: uint64 = 3

            return box(x + y + z)
    """
    expected = "cannot add int16 and uint64"
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src, StaticCodeGenerator)
868
def test_store_signed_to_unsigned(self):
    """Assigning an int8 value to a uint8 local is a type mismatch."""
    src = """
        from __static__ import int8, uint8, uint64, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            x = y
    """
    expected = type_mismatch("int8", "uint8")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src, StaticCodeGenerator)
880
def test_store_unsigned_to_signed(self):
    """Assigning a uint8 value to an int8 local is a type mismatch."""

    codestr = """
        from __static__ import int8, uint8, uint64, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            y = x
    """
    with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("uint8", "int8")):
        self.compile(codestr, StaticCodeGenerator)
893
def test_mixed_assign_larger(self):
    """uint8 + int8 can initialize an int16 local (annotated assignment)."""
    src = """
        from __static__ import int8, uint8, int16, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            z: int16 = x + y

            return box(z)
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
    with self.in_module(src) as ns:
        self.assertEqual(ns["testfunc"](), 44)
912
def test_mixed_assign_larger_2(self):
    """uint8 + int8 can be assigned to a previously declared int16 local."""
    src = """
        from __static__ import int8, uint8, int16, box
        def testfunc(tst=False):
            x: uint8 = 42
            y: int8 = 2
            z: int16
            z = x + y

            return box(z)
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
    with self.in_module(src) as ns:
        self.assertEqual(ns["testfunc"](), 44)
932
@skipIf(True, "this isn't implemented yet")
def test_unwind(self):
    """An exception raised while a primitive local is live propagates normally."""
    codestr = """
        from __static__ import int32
        def raises():
            raise IndexError()

        def testfunc():
            x: int32 = 1
            raises()
            print(x)
    """

    with self.in_module(codestr) as mod:
        f = mod["testfunc"]
        with self.assertRaises(IndexError):
            f()
950
def test_int_constant_range(self):
    """Out-of-range integer constants are rejected for every primitive width."""
    cases = (
        ("int8", 128, -128, 127),
        ("int8", -129, -128, 127),
        ("int16", 32768, -32768, 32767),
        ("int16", -32769, -32768, 32767),
        ("int32", 2147483648, -2147483648, 2147483647),
        ("int32", -2147483649, -2147483648, 2147483647),
        ("int64", 9223372036854775808, -9223372036854775808, 9223372036854775807),
        ("int64", -9223372036854775809, -9223372036854775808, 9223372036854775807),
        ("uint8", 257, 0, 255),
        ("uint8", -1, 0, 255),
        ("uint16", 65537, 0, 65535),
        ("uint16", -1, 0, 65535),
        ("uint32", 4294967297, 0, 4294967295),
        ("uint32", -1, 0, 4294967295),
        ("uint64", 18446744073709551617, 0, 18446744073709551615),
        ("uint64", -1, 0, 18446744073709551615),
    )
    for prim, val, low, high in cases:
        src = f"""
            from __static__ import {prim}
            def testfunc(tst):
                x: {prim} = {val}
        """
        expected = f"constant {val} is outside of the range {low} to {high} for {prim}"
        with self.subTest(type=prim, val=val, low=low, high=high):
            with self.assertRaisesRegex(TypedSyntaxError, expected):
                self.compile(src, StaticCodeGenerator)
981
def test_int_assign_float(self):
    """Assigning a float literal to an int8 local is rejected."""
    codestr = """
        from __static__ import int8
        def testfunc(tst):
            x: int8 = 1.0
    """
    with self.assertRaisesRegex(
        TypedSyntaxError,
        "float cannot be used in a context where an int is expected",
    ):
        self.compile(codestr, StaticCodeGenerator)
993
def test_int_assign_str_constant(self):
    """Assigning a str concatenation to an int8 local is rejected."""
    codestr = """
        from __static__ import int8
        def testfunc(tst):
            x: int8 = 'abc' + 'def'
    """
    with self.assertRaisesRegex(
        TypedSyntaxError,
        "str cannot be used in a context where an int is expected",
    ):
        self.compile(codestr, StaticCodeGenerator)
1005
def test_int_large_int_constant(self):
    """``0x7FFFFFFF + 1`` becomes one PRIMITIVE_LOAD_CONST of 0x80000000 as int64."""
    src = """
        from __static__ import int64
        def testfunc(tst):
            x: int64 = 0x7FFFFFFF + 1
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    expected_arg = (0x80000000, TYPED_INT64)
    self.assertInBytecode(func_code, "PRIMITIVE_LOAD_CONST", expected_arg)
1015
def test_int_int_constant(self):
    """``0x7FFFFFFE + 1`` becomes one PRIMITIVE_LOAD_CONST of 0x7FFFFFFF as int64."""
    src = """
        from __static__ import int64
        def testfunc(tst):
            x: int64 = 0x7FFFFFFE + 1
    """
    func_code = self.find_code(self.compile(src, StaticCodeGenerator))
    expected_arg = (0x7FFFFFFF, TYPED_INT64)
    self.assertInBytecode(func_code, "PRIMITIVE_LOAD_CONST", expected_arg)
1025
def test_int_add_mixed_64(self):
    """uint64 + int64 has no common primitive type and is rejected."""
    src = """
        from __static__ import uint64, int64, box
        def testfunc(tst):
            x: uint64 = 0
            y: int64 = 1
            if tst:
                x = x + 1
                y = y + 2

            return box(x + y)
    """
    expected = "cannot add uint64 and int64"
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(src, StaticCodeGenerator)
1040
def test_int_overflow_add(self):
    """Primitive addition wraps around at each integer width."""
    tests = [
        ("int8", 100, 100, -56),
        ("int16", 200, 200, 400),
        ("int32", 200, 200, 400),
        ("int64", 200, 200, 400),
        ("int16", 20000, 20000, -25536),
        ("int32", 40000, 40000, 80000),
        ("int64", 40000, 40000, 80000),
        ("int32", 2000000000, 2000000000, -294967296),
        ("int64", 2000000000, 2000000000, 4000000000),
        ("int8", 127, 127, -2),
        ("int16", 32767, 32767, -2),
        ("int32", 2147483647, 2147483647, -2),
        ("int64", 9223372036854775807, 9223372036854775807, -2),
        ("uint8", 200, 200, 144),
        ("uint16", 200, 200, 400),
        ("uint32", 200, 200, 400),
        ("uint64", 200, 200, 400),
        ("uint16", 40000, 40000, 14464),
        ("uint32", 40000, 40000, 80000),
        ("uint64", 40000, 40000, 80000),
        ("uint32", 2000000000, 2000000000, 4000000000),
        ("uint64", 2000000000, 2000000000, 4000000000),
        ("uint8", 1 << 7, 1 << 7, 0),
        ("uint16", 1 << 15, 1 << 15, 0),
        ("uint32", 1 << 31, 1 << 31, 0),
        ("uint64", 1 << 63, 1 << 63, 0),
        ("uint8", 1 << 6, 1 << 6, 128),
        ("uint16", 1 << 14, 1 << 14, 32768),
        ("uint32", 1 << 30, 1 << 30, 2147483648),
        ("uint64", 1 << 62, 1 << 62, 9223372036854775808),
    ]

    # `prim` rather than `type` to avoid shadowing the builtin.
    for prim, x, y, res in tests:
        codestr = f"""
            from __static__ import {prim}, box
            def f():
                x: {prim} = {x}
                y: {prim} = {y}
                z: {prim} = x + y
                return box(z)
        """
        with self.subTest(type=prim, x=x, y=y, res=res):
            # Compile once on its own as well; only a compile error matters here.
            self.compile(codestr, StaticCodeGenerator)
            f = self.run_code(codestr, StaticCodeGenerator)["f"]
            self.assertEqual(f(), res, f"{prim} {x} {y} {res}")
1088
def test_int_unary(self):
    """Unary ``-`` and ``~`` on primitive ints; unsigned results wrap modulo 2**width."""
    tests = [
        ("int8", "-", 1, -1),
        ("uint8", "-", 1, (1 << 8) - 1),
        ("int16", "-", 1, -1),
        ("int16", "-", 256, -256),
        ("uint16", "-", 1, (1 << 16) - 1),
        ("int32", "-", 1, -1),
        ("int32", "-", 65536, -65536),
        ("uint32", "-", 1, (1 << 32) - 1),
        ("int64", "-", 1, -1),
        ("int64", "-", 1 << 32, -(1 << 32)),
        ("uint64", "-", 1, (1 << 64) - 1),
        ("int8", "~", 1, -2),
        ("uint8", "~", 1, (1 << 8) - 2),
        ("int16", "~", 1, -2),
        ("uint16", "~", 1, (1 << 16) - 2),
        ("int32", "~", 1, -2),
        ("uint32", "~", 1, (1 << 32) - 2),
        ("int64", "~", 1, -2),
        ("uint64", "~", 1, (1 << 64) - 2),
    ]
    # `prim` rather than `type` to avoid shadowing the builtin.
    for prim, op, x, res in tests:
        codestr = f"""
            from __static__ import {prim}, box
            def testfunc(tst):
                x: {prim} = {x}
                if tst:
                    x = x + 1
                x = {op}x
                return box(x)
        """
        with self.subTest(type=prim, op=op, x=x, res=res):
            # Compile once on its own as well; only a compile error matters here.
            self.compile(codestr, StaticCodeGenerator)
            f = self.run_code(codestr, StaticCodeGenerator)["testfunc"]
            self.assertEqual(f(False), res, f"{prim} {op} {x} {res}")
1125
def test_int_compare(self):
    """Primitive comparisons respect signedness at every integer width."""
    tests = [
        ("int8", 1, 2, "==", False),
        ("int8", 1, 2, "!=", True),
        ("int8", 1, 2, "<", True),
        ("int8", 1, 2, "<=", True),
        ("int8", 2, 1, "<", False),
        ("int8", 2, 1, "<=", False),
        ("int8", -1, 2, "==", False),
        ("int8", -1, 2, "!=", True),
        ("int8", -1, 2, "<", True),
        ("int8", -1, 2, "<=", True),
        ("int8", 2, -1, "<", False),
        ("int8", 2, -1, "<=", False),
        ("uint8", 1, 2, "==", False),
        ("uint8", 1, 2, "!=", True),
        ("uint8", 1, 2, "<", True),
        ("uint8", 1, 2, "<=", True),
        ("uint8", 2, 1, "<", False),
        ("uint8", 2, 1, "<=", False),
        ("uint8", 255, 2, "==", False),
        ("uint8", 255, 2, "!=", True),
        ("uint8", 255, 2, "<", False),
        ("uint8", 255, 2, "<=", False),
        ("uint8", 2, 255, "<", True),
        ("uint8", 2, 255, "<=", True),
        ("int16", 1, 2, "==", False),
        ("int16", 1, 2, "!=", True),
        ("int16", 1, 2, "<", True),
        ("int16", 1, 2, "<=", True),
        ("int16", 2, 1, "<", False),
        ("int16", 2, 1, "<=", False),
        ("int16", -1, 2, "==", False),
        ("int16", -1, 2, "!=", True),
        ("int16", -1, 2, "<", True),
        ("int16", -1, 2, "<=", True),
        ("int16", 2, -1, "<", False),
        ("int16", 2, -1, "<=", False),
        ("uint16", 1, 2, "==", False),
        ("uint16", 1, 2, "!=", True),
        ("uint16", 1, 2, "<", True),
        ("uint16", 1, 2, "<=", True),
        ("uint16", 2, 1, "<", False),
        ("uint16", 2, 1, "<=", False),
        ("uint16", 65535, 2, "==", False),
        ("uint16", 65535, 2, "!=", True),
        ("uint16", 65535, 2, "<", False),
        ("uint16", 65535, 2, "<=", False),
        ("uint16", 2, 65535, "<", True),
        ("uint16", 2, 65535, "<=", True),
        ("int32", 1, 2, "==", False),
        ("int32", 1, 2, "!=", True),
        ("int32", 1, 2, "<", True),
        ("int32", 1, 2, "<=", True),
        ("int32", 2, 1, "<", False),
        ("int32", 2, 1, "<=", False),
        ("int32", -1, 2, "==", False),
        ("int32", -1, 2, "!=", True),
        ("int32", -1, 2, "<", True),
        ("int32", -1, 2, "<=", True),
        ("int32", 2, -1, "<", False),
        ("int32", 2, -1, "<=", False),
        ("uint32", 1, 2, "==", False),
        ("uint32", 1, 2, "!=", True),
        ("uint32", 1, 2, "<", True),
        ("uint32", 1, 2, "<=", True),
        ("uint32", 2, 1, "<", False),
        ("uint32", 2, 1, "<=", False),
        # "==" on the max value was missing here, unlike the other widths.
        ("uint32", 4294967295, 2, "==", False),
        ("uint32", 4294967295, 2, "!=", True),
        ("uint32", 4294967295, 2, "<", False),
        ("uint32", 4294967295, 2, "<=", False),
        ("uint32", 2, 4294967295, "<", True),
        ("uint32", 2, 4294967295, "<=", True),
        ("int64", 1, 2, "==", False),
        ("int64", 1, 2, "!=", True),
        ("int64", 1, 2, "<", True),
        ("int64", 1, 2, "<=", True),
        ("int64", 2, 1, "<", False),
        ("int64", 2, 1, "<=", False),
        ("int64", -1, 2, "==", False),
        ("int64", -1, 2, "!=", True),
        ("int64", -1, 2, "<", True),
        ("int64", -1, 2, "<=", True),
        ("int64", 2, -1, "<", False),
        ("int64", 2, -1, "<=", False),
        ("uint64", 1, 2, "==", False),
        ("uint64", 1, 2, "!=", True),
        ("uint64", 1, 2, "<", True),
        ("uint64", 1, 2, "<=", True),
        ("uint64", 2, 1, "<", False),
        ("uint64", 2, 1, "<=", False),
        ("int64", 2, -1, ">", True),
        ("uint64", 2, 18446744073709551615, ">", False),
        ("int64", 2, -1, "<", False),
        ("uint64", 2, 18446744073709551615, "<", True),
        ("int64", 2, -1, ">=", True),
        ("uint64", 2, 18446744073709551615, ">=", False),
        ("int64", 2, -1, "<=", False),
        ("uint64", 2, 18446744073709551615, "<=", True),
    ]
    # `prim` rather than `type` to avoid shadowing the builtin.
    for prim, x, y, op, res in tests:
        codestr = f"""
            from __static__ import {prim}, box
            def testfunc(tst):
                x: {prim} = {x}
                y: {prim} = {y}
                if tst:
                    x = x + 1
                    y = y + 2

                if x {op} y:
                    return True
                return False
        """
        with self.subTest(type=prim, x=x, y=y, op=op, res=res):
            # Compile once on its own as well; only a compile error matters here.
            self.compile(codestr, StaticCodeGenerator)
            f = self.run_code(codestr, StaticCodeGenerator)["testfunc"]
            self.assertEqual(f(False), res, f"{prim} {x} {op} {y} {res}")
1244
def test_int_compare_unboxed(self):
    """`>` between two values unboxed into ssize_t branches via POP_JUMP_IF_ZERO."""
    source = """
    from __static__ import ssize_t, unbox
    def testfunc(x, y):
        x1: ssize_t = unbox(x)
        y1: ssize_t = unbox(y)

        if x1 > y1:
            return True
        return False
    """
    self.compile(source, StaticCodeGenerator)
    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertInBytecode(testfunc, "POP_JUMP_IF_ZERO")
    self.assertEqual(testfunc(1, 2), False)
1260
def test_int_compare_mixed(self):
    """A boxed primitive comparison combined with `and` on a dynamic global loops correctly and is JIT-compiled."""
    source = """
    from __static__ import box, ssize_t
    x = 1

    def testfunc():
        i: ssize_t = 0
        j = 0
        while box(i < 100) and x:
            i = i + 1
            j = j + 1
        return j
    """

    self.compile(source, StaticCodeGenerator)
    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(), 100)
    self.assert_jitted(testfunc)
1279
def test_int_compare_or(self):
    """`or` between two ssize_t comparisons short-circuits with JUMP_IF_NONZERO_OR_POP."""
    source = """
    from __static__ import box, ssize_t

    def testfunc():
        i: ssize_t = 0
        j = i > 2 or i < -2
        return box(j)
    """

    with self.in_module(source, code_gen=StaticCodeGenerator) as module:
        testfunc = module["testfunc"]
        self.assertInBytecode(testfunc, "JUMP_IF_NONZERO_OR_POP")
        self.assertIs(testfunc(), False)
1294
def test_int_compare_and(self):
    """`and` between two ssize_t comparisons short-circuits with JUMP_IF_ZERO_OR_POP."""
    source = """
    from __static__ import box, ssize_t

    def testfunc():
        i: ssize_t = 0
        j = i > 2 and i > 3
        return box(j)
    """

    with self.in_module(source, code_gen=StaticCodeGenerator) as module:
        testfunc = module["testfunc"]
        self.assertInBytecode(testfunc, "JUMP_IF_ZERO_OR_POP")
        self.assertIs(testfunc(), False)
1309
def test_disallow_prim_nonprim_union(self):
    """`int32 or int` would form a union containing a primitive, which is rejected."""
    source = """
    from __static__ import int32

    def f(y: int):
        x: int32 = 2
        z = x or y
        return z
    """
    expected_error = (
        r"invalid union type Union\[int32, int\]; "
        r"unions cannot include primitive types"
    )
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source)
1324
def test_int_binop(self):
    """Each primitive binary operator yields the expected result for every width.

    Each case is checked twice: once via an annotated local (`z: T = x op y`)
    and once via the bare expression passed straight to box().
    """
    tests = [
        ("int8", 1, 2, "/", 0),
        ("int8", 4, 2, "/", 2),
        ("int8", 4, -2, "/", -2),
        ("uint8", 0xFF, 0x7F, "/", 2),
        ("int16", 4, -2, "/", -2),
        ("uint16", 0xFF, 0x7F, "/", 2),
        ("uint16", 0xFFFF, 0x7FFF, "/", 2),
        ("uint32", 0xFFFF, 0x7FFF, "/", 2),
        ("int32", 4, -2, "/", -2),
        ("uint32", 0xFF, 0x7F, "/", 2),
        ("uint32", 0xFFFFFFFF, 0x7FFFFFFF, "/", 2),
        ("int64", 4, -2, "/", -2),
        ("uint64", 0xFF, 0x7F, "/", 2),
        ("uint64", 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, "/", 2),
        ("int8", 1, -2, "-", 3),
        ("int8", 1, 2, "-", -1),
        ("int16", 1, -2, "-", 3),
        ("int16", 1, 2, "-", -1),
        ("int32", 1, -2, "-", 3),
        ("int32", 1, 2, "-", -1),
        ("int64", 1, -2, "-", 3),
        ("int64", 1, 2, "-", -1),
        ("int8", 1, -2, "*", -2),
        ("int8", 1, 2, "*", 2),
        ("int16", 1, -2, "*", -2),
        ("int16", 1, 2, "*", 2),
        ("int32", 1, -2, "*", -2),
        ("int32", 1, 2, "*", 2),
        ("int64", 1, -2, "*", -2),
        ("int64", 1, 2, "*", 2),
        # Bitwise-and: the signed/unsigned pairing mirrors the "|" and "^"
        # sections below (previously the signed rows were duplicated and no
        # unsigned type was covered).
        ("int8", 1, -2, "&", 0),
        ("int8", 1, 3, "&", 1),
        ("uint8", 1, 3, "&", 1),
        ("int16", 1, 3, "&", 1),
        ("uint16", 1, 3, "&", 1),
        ("int32", 1, 3, "&", 1),
        ("uint32", 1, 3, "&", 1),
        ("int64", 1, 3, "&", 1),
        ("uint64", 1, 3, "&", 1),
        ("int8", 1, 2, "|", 3),
        ("uint8", 1, 2, "|", 3),
        ("int16", 1, 2, "|", 3),
        ("uint16", 1, 2, "|", 3),
        ("int32", 1, 2, "|", 3),
        ("uint32", 1, 2, "|", 3),
        ("int64", 1, 2, "|", 3),
        ("uint64", 1, 2, "|", 3),
        ("int8", 1, 3, "^", 2),
        ("uint8", 1, 3, "^", 2),
        ("int16", 1, 3, "^", 2),
        ("uint16", 1, 3, "^", 2),
        ("int32", 1, 3, "^", 2),
        ("uint32", 1, 3, "^", 2),
        ("int64", 1, 3, "^", 2),
        ("uint64", 1, 3, "^", 2),
        ("int8", 1, 3, "%", 1),
        ("uint8", 1, 3, "%", 1),
        ("int16", 1, 3, "%", 1),
        ("uint16", 1, 3, "%", 1),
        ("int32", 1, 3, "%", 1),
        ("uint32", 1, 3, "%", 1),
        ("int64", 1, 3, "%", 1),
        ("uint64", 1, 3, "%", 1),
        ("int8", 1, -3, "%", 1),
        ("uint8", 1, 0xFF, "%", 1),
        ("int16", 1, -3, "%", 1),
        ("uint16", 1, 0xFFFF, "%", 1),
        ("int32", 1, -3, "%", 1),
        ("uint32", 1, 0xFFFFFFFF, "%", 1),
        ("int64", 1, -3, "%", 1),
        ("uint64", 1, 0xFFFFFFFFFFFFFFFF, "%", 1),
        ("int8", 1, 2, "<<", 4),
        ("uint8", 1, 2, "<<", 4),
        ("int16", 1, 2, "<<", 4),
        ("uint16", 1, 2, "<<", 4),
        ("int32", 1, 2, "<<", 4),
        ("uint32", 1, 2, "<<", 4),
        ("int64", 1, 2, "<<", 4),
        ("uint64", 1, 2, "<<", 4),
        ("int8", 4, 1, ">>", 2),
        ("int8", -1, 1, ">>", -1),
        ("uint8", 0xFF, 1, ">>", 127),
        ("int16", 4, 1, ">>", 2),
        ("int16", -1, 1, ">>", -1),
        ("uint16", 0xFFFF, 1, ">>", 32767),
        ("int32", 4, 1, ">>", 2),
        ("int32", -1, 1, ">>", -1),
        ("uint32", 0xFFFFFFFF, 1, ">>", 2147483647),
        ("int64", 4, 1, ">>", 2),
        ("int64", -1, 1, ">>", -1),
        ("uint64", 0xFFFFFFFFFFFFFFFF, 1, ">>", 9223372036854775807),
    ]
    # `typ` rather than `type` so the builtin is not shadowed in this scope.
    for typ, x, y, op, res in tests:
        # The `tst` branch is never taken (f is called with False); it
        # reassigns x and y, presumably so they are not treated as constants.
        codestr = f"""
        from __static__ import {typ}, box
        def testfunc(tst):
            x: {typ} = {x}
            y: {typ} = {y}
            if tst:
                x = x + 1
                y = y + 2

            z: {typ} = x {op} y
            return box(z), box(x {op} y)
        """
        with self.subTest(type=typ, x=x, y=y, op=op, res=res):
            with self.in_module(codestr) as mod:
                f = mod["testfunc"]
                self.assertEqual(f(False), (res, res), f"{typ} {x} {op} {y} {res}")
1434
def test_primitive_arithmetic(self):
    """`*`, `//` and `%` on primitives give the expected results.

    Each case runs in three forms: both operands as parameters, the left
    operand as a literal, and the right operand as a literal.
    """
    cases = [
        ("int8", 127, "*", 1, 127),
        ("int8", -64, "*", 2, -128),
        ("int8", 0, "*", 4, 0),
        ("uint8", 51, "*", 5, 255),
        ("uint8", 5, "*", 0, 0),
        ("int16", 3123, "*", -10, -31230),
        ("int16", -32767, "*", -1, 32767),
        ("int16", -32768, "*", 1, -32768),
        ("int16", 3, "*", 0, 0),
        ("uint16", 65535, "*", 1, 65535),
        ("uint16", 0, "*", 4, 0),
        ("int32", (1 << 31) - 1, "*", 1, (1 << 31) - 1),
        ("int32", -(1 << 30), "*", 2, -(1 << 31)),
        ("int32", 0, "*", 1, 0),
        ("uint32", (1 << 32) - 1, "*", 1, (1 << 32) - 1),
        ("uint32", 0, "*", 4, 0),
        ("int64", (1 << 63) - 1, "*", 1, (1 << 63) - 1),
        ("int64", -(1 << 62), "*", 2, -(1 << 63)),
        ("int64", 0, "*", 1, 0),
        ("uint64", (1 << 64) - 1, "*", 1, (1 << 64) - 1),
        ("uint64", 0, "*", 4, 0),
        ("int8", 127, "//", 4, 31),
        ("int8", -128, "//", 4, -32),
        ("int8", 0, "//", 4, 0),
        ("uint8", 255, "//", 5, 51),
        ("uint8", 0, "//", 5, 0),
        ("int16", 32767, "//", -1000, -32),
        ("int16", -32768, "//", -1000, 32),
        ("int16", 0, "//", 4, 0),
        ("uint16", 65535, "//", 5, 13107),
        ("uint16", 0, "//", 4, 0),
        ("int32", (1 << 31) - 1, "//", (1 << 31) - 1, 1),
        ("int32", -(1 << 31), "//", 1, -(1 << 31)),
        ("int32", 0, "//", 1, 0),
        ("uint32", (1 << 32) - 1, "//", 500, 8589934),
        ("uint32", 0, "//", 4, 0),
        ("int64", (1 << 63) - 1, "//", 2, (1 << 62) - 1),
        ("int64", -(1 << 63), "//", 2, -(1 << 62)),
        ("int64", 0, "//", 1, 0),
        ("uint64", (1 << 64) - 1, "//", (1 << 64) - 1, 1),
        ("uint64", 0, "//", 4, 0),
        ("int8", 127, "%", 4, 3),
        ("int8", -128, "%", 4, 0),
        ("int8", 0, "%", 4, 0),
        ("uint8", 255, "%", 6, 3),
        ("uint8", 0, "%", 5, 0),
        ("int16", 32767, "%", -1000, 767),
        ("int16", -32768, "%", -1000, -768),
        ("int16", 0, "%", 4, 0),
        ("uint16", 65535, "%", 7, 1),
        ("uint16", 0, "%", 4, 0),
        ("int32", (1 << 31) - 1, "%", (1 << 31) - 1, 0),
        ("int32", -(1 << 31), "%", 1, 0),
        ("int32", 0, "%", 1, 0),
        ("uint32", (1 << 32) - 1, "%", 500, 295),
        ("uint32", 0, "%", 4, 0),
        ("int64", (1 << 63) - 1, "%", 2, 1),
        ("int64", -(1 << 63), "%", 2, 0),
        ("int64", 0, "%", 1, 0),
        ("uint64", (1 << 64) - 1, "%", (1 << 64) - 1, 0),
        ("uint64", 0, "%", 4, 0),
    ]
    for typ, a, op, b, res in cases:
        # const-mode -> (parameter list, returned expression, call arguments);
        # dict order keeps the noconst / constfirst / constsecond sequence.
        variants = {
            "noconst": (f"a: {typ}, b: {typ}", f"a {op} b", (a, b)),
            "constfirst": (f"b: {typ}", f"{a} {op} b", (b,)),
            "constsecond": (f"a: {typ}", f"a {op} {b}", (a,)),
        }
        for const, (params, expr, call_args) in variants.items():
            source = f"""
            from __static__ import {typ}

            def f({params}) -> {typ}:
                return {expr}
            """

            with self.subTest(typ=typ, a=a, op=op, b=b, res=res, const=const):
                with self.in_module(source) as module:
                    actual = module["f"](*call_args)
                    self.assertEqual(actual, res)
1534
def test_int_binop_type_context(self):
    """An int8 * int8 product assigned to an int16 local emits CONVERT_PRIMITIVE."""
    source = """
    from __static__ import box, int8, int16

    def f(x: int8, y: int8) -> int:
        z: int16 = x * y
        return box(z)
    """
    # The expected oparg combines TYPED_INT8 with TYPED_INT16 shifted left by 4.
    conversion_oparg = TYPED_INT8 | (TYPED_INT16 << 4)
    with self.in_module(source) as module:
        func = module["f"]
        self.assertInBytecode(func, "CONVERT_PRIMITIVE", conversion_oparg)
        self.assertEqual(func(120, 120), 14400)
1549
def test_int_compare_mixed_sign(self):
    """Comparisons between a signed and an unsigned primitive give the expected result."""
    tests = [
        ("uint16", 10000, "int16", -1, "<", False),
        ("uint16", 10000, "int16", -1, "<=", False),
        ("int16", -1, "uint16", 10000, ">", False),
        ("int16", -1, "uint16", 10000, ">=", False),
        ("uint32", 10000, "int16", -1, "<", False),
    ]
    for type1, x, type2, y, op, res in tests:
        codestr = f"""
        from __static__ import {type1}, {type2}, box
        def testfunc(tst):
            x: {type1} = {x}
            y: {type2} = {y}
            if tst:
                x = x + 1
                y = y + 2

            if x {op} y:
                return True
            return False
        """
        with self.subTest(type1=type1, x=x, type2=type2, y=y, op=op, res=res):
            self.compile(codestr, StaticCodeGenerator)
            f = self.run_code(codestr, StaticCodeGenerator)["testfunc"]
            # Name both operand types; the previous message interpolated the
            # builtin `type`, which always rendered as "<class 'type'>".
            self.assertEqual(
                f(False), res, f"{type1} {x} {op} {type2} {y} {res}"
            )
1576
def test_int_compare64_mixed_sign(self):
    """Comparing a uint64 with an int64 is rejected at compile time."""
    source = """
    from __static__ import uint64, int64
    def testfunc(tst):
        x: uint64 = 0
        y: int64 = 1
        if tst:
            x = x + 1
            y = y + 2

        if x < y:
            return True
        return False
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(source, StaticCodeGenerator)
1593
def test_compile_method(self):
    """A ssize_t local initialised to 42 is loaded via PRIMITIVE_LOAD_CONST as int64."""
    source = """
    from __static__ import ssize_t
    def f():
        x: ssize_t = 42
    """
    module_code = self.compile(source, StaticCodeGenerator)

    func_code = self.find_code(module_code)
    self.assertInBytecode(func_code, "PRIMITIVE_LOAD_CONST", (42, TYPED_INT64))
1606
def test_mixed_compare(self):
    """Comparing a primitive int64 with a dynamic value is a compile error."""
    source = """
    from __static__ import ssize_t, box, unbox
    def f(a):
        x: ssize_t = 0
        while x != a:
            pass
    """
    expected_error = "can't compare int64 to dynamic"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
1617
def test_unbox(self):
    """unbox() accepts values at or near each primitive width's limits; +1 round-trips through box()."""
    cases = [
        ("int8", 126),
        ("int8", -128),
        ("int16", 32766),
        ("int16", -32768),
        ("int32", 2147483646),
        ("int32", -2147483648),
        ("int64", 9223372036854775806),
        ("int64", -9223372036854775808),
        ("uint8", 254),
        ("uint16", 65534),
        ("uint32", 4294967294),
        ("uint64", 18446744073709551614),
    ]
    for size, val in cases:
        codestr = f"""
        from __static__ import {size}, box, unbox
        def f(x):
            y: {size} = unbox(x)
            y = y + 1
            return box(y)
        """

        # subTest so a failure reports which width/value broke instead of
        # aborting the remaining cases.
        with self.subTest(size=size, val=val):
            code = self.compile(codestr, StaticCodeGenerator)
            # Locate the function's code object, as the original flow did;
            # the result itself is not needed here.
            self.find_code(code)
            f = self.run_code(codestr, StaticCodeGenerator)["f"]
            self.assertEqual(f(val), val + 1)
1645
def test_int_loop_inplace(self):
    """`+=` on a ssize_t counter in a while loop counts up to 100."""
    source = """
    from __static__ import ssize_t, box
    def f():
        i: ssize_t = 0
        while i < 100:
            i += 1
        return box(i)
    """

    self.find_code(self.compile(source, StaticCodeGenerator))
    func = self.run_code(source, StaticCodeGenerator)["f"]
    self.assertEqual(func(), 100)
1660
def test_int_loop(self):
    """A ssize_t while-loop counter compiles to primitive load/add/compare opcodes."""
    source = """
    from __static__ import ssize_t, box
    def testfunc():
        i: ssize_t = 0
        while i < 100:
            i = i + 1
        return box(i)
    """

    self.find_code(self.compile(source, StaticCodeGenerator))

    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(), 100)

    expected_ops = [
        ("PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)),
        ("LOAD_LOCAL", (0, ("__static__", "int64"))),
        ("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT),
        ("INT_COMPARE_OP", PRIM_OP_LT_INT),
    ]
    for opname, oparg in expected_ops:
        self.assertInBytecode(testfunc, opname, oparg)
    self.assertInBytecode(testfunc, "POP_JUMP_IF_ZERO")
1682
def test_int_assert(self):
    """A passing `assert` on a primitive comparison uses POP_JUMP_IF_NONZERO and returns None."""
    source = """
    from __static__ import ssize_t, box
    def testfunc():
        i: ssize_t = 0
        assert i == 0, "hello there"
    """

    func_code = self.find_code(self.compile(source, StaticCodeGenerator))
    self.assertInBytecode(func_code, "POP_JUMP_IF_NONZERO")

    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(), None)
1697
def test_int_assert_raises(self):
    """A failing `assert` on a primitive comparison raises AssertionError when called."""
    codestr = """
    from __static__ import ssize_t, box
    def testfunc():
        i: ssize_t = 0
        assert i != 0, "hello there"
    """

    code = self.compile(codestr, StaticCodeGenerator)
    f = self.find_code(code)
    self.assertInBytecode(f, "POP_JUMP_IF_NONZERO")

    # Build the module outside the assertRaises block. Only the call should
    # be expected to raise. unittest assertion failures are themselves
    # AssertionError, so wrapping more than the call could mask unrelated
    # failures. Matching the message pins the assert under test.
    f = self.run_code(codestr, StaticCodeGenerator)["testfunc"]
    with self.assertRaisesRegex(AssertionError, "hello there"):
        f()
1713
def test_int_loop_reversed(self):
    """A loop condition written as `100 > i` compiles to a primitive GT compare."""
    source = """
    from __static__ import ssize_t, box
    def testfunc():
        i: ssize_t = 0
        while 100 > i:
            i = i + 1
        return box(i)
    """

    self.find_code(self.compile(source, StaticCodeGenerator))
    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(), 100)

    expected_ops = [
        ("PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)),
        ("LOAD_LOCAL", (0, ("__static__", "int64"))),
        ("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT),
        ("INT_COMPARE_OP", PRIM_OP_GT_INT),
    ]
    for opname, oparg in expected_ops:
        self.assertInBytecode(testfunc, opname, oparg)
    self.assertInBytecode(testfunc, "POP_JUMP_IF_ZERO")
1734
def test_int_loop_chained(self):
    """A chained comparison `-1 < i < 100` on ssize_t uses primitive compares."""
    source = """
    from __static__ import ssize_t, box
    def testfunc():
        i: ssize_t = 0
        while -1 < i < 100:
            i = i + 1
        return box(i)
    """

    self.find_code(self.compile(source, StaticCodeGenerator))
    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(), 100)

    expected_ops = [
        ("PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)),
        ("LOAD_LOCAL", (0, ("__static__", "int64"))),
        ("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT),
        ("INT_COMPARE_OP", PRIM_OP_LT_INT),
    ]
    for opname, oparg in expected_ops:
        self.assertInBytecode(testfunc, opname, oparg)
    self.assertInBytecode(testfunc, "POP_JUMP_IF_ZERO")
1755
def test_compare_subclass(self):
    """Comparing instances of user-defined classes emits the generic COMPARE_OP."""
    source = """
    class C: pass
    class D(C): pass

    x = C() > D()
    """
    module_code = self.compile(source, StaticCodeGenerator)
    self.assertInBytecode(module_code, "COMPARE_OP")
1765
def test_compat_int_math(self):
    """An int literal added to a ssize_t local uses primitive int64 arithmetic."""
    source = """
    from __static__ import ssize_t, box
    def f():
        x: ssize_t = 42
        z: ssize_t = 1 + x
        return box(z)
    """

    func_code = self.find_code(self.compile(source, StaticCodeGenerator))
    self.assertInBytecode(func_code, "PRIMITIVE_LOAD_CONST", (42, TYPED_INT64))
    self.assertInBytecode(func_code, "LOAD_LOCAL", (0, ("__static__", "int64")))
    self.assertInBytecode(func_code, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)

    func = self.run_code(source, StaticCodeGenerator)["f"]
    self.assertEqual(func(), 43)
1782
def test_unbox_long(self):
    """unbox() of an int literal into an int64 local compiles."""
    source = """
    from __static__ import unbox, int64
    def f():
        x:int64 = unbox(1)
    """

    self.compile(source, StaticCodeGenerator)
1791
def test_unbox_str(self):
    """unbox() of a str into int64 compiles but raises TypeError when run."""
    source = """
    from __static__ import unbox, int64
    def f():
        x:int64 = unbox('abc')
    """
    error_pattern = "(expected int, got str)|(an integer is required)"

    with self.in_module(source) as module:
        func = module["f"]
        with self.assertRaisesRegex(TypeError, error_pattern):
            func()
1805
def test_unbox_typed(self):
    """Calling int64(obj) unboxes via PRIMITIVE_UNBOX and rejects non-ints at runtime."""
    source = """
    from __static__ import int64, box
    def f(i: object):
        x = int64(i)
        return box(x)
    """
    error_pattern = "(expected int, got str)|(an integer is required)"

    with self.in_module(source) as module:
        func = module["f"]
        self.assertEqual(func(42), 42)
        self.assertInBytecode(func, "PRIMITIVE_UNBOX")
        with self.assertRaisesRegex(TypeError, error_pattern):
            self.assertEqual(func("abc"), 42)
1822
def test_unbox_typed_bool(self):
    """int64(obj) accepts bools, mapping True/False to 1/0."""
    source = """
    from __static__ import int64, box
    def f(i: object):
        x = int64(i)
        return box(x)
    """

    with self.in_module(source) as module:
        func = module["f"]
        self.assertEqual(func(42), 42)
        self.assertInBytecode(func, "PRIMITIVE_UNBOX")
        for flag, expected in ((True, 1), (False, 0)):
            self.assertEqual(func(flag), expected)
1837
def test_unbox_incompat_type(self):
    """int64() applied to a value statically typed as str is a compile error."""
    source = """
    from __static__ import int64, box
    def f(i: str):
        x:int64 = int64(i)
        return box(x)
    """

    mismatch = type_mismatch("str", "int64")
    with self.assertRaisesRegex(TypedSyntaxError, mismatch):
        self.compile(source)
1848
def test_uninit_value(self):
    """Boxing an int64 local before it is assigned yields 0."""
    source = """
    from __static__ import box, int64
    def f():
        x:int64
        return box(x)
        x = 0
    """
    func = self.run_code(source, StaticCodeGenerator)["f"]
    self.assertEqual(func(), 0)
1859
def test_uninit_value_2(self):
    """An int64 local assigned only on an untaken branch boxes to 0."""
    source = """
    from __static__ import box, int64
    def testfunc(x):
        if x:
            y:int64 = 42
        return box(y)
    """
    testfunc = self.run_code(source, StaticCodeGenerator)["testfunc"]
    self.assertEqual(testfunc(False), 0)
1870
def test_bad_box(self):
    """box() on a non-primitive (a str literal) is a compile error."""
    source = """
    from __static__ import box
    box('abc')
    """

    expected_error = "can't box non-primitive: Exact\\[str\\]"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
1881
def test_bad_unbox(self):
    """unbox() on a value that is already a primitive int64 is a compile error."""
    source = """
    from __static__ import unbox, int64
    def f():
        x:int64 = 42
        unbox(x)
    """

    mismatch = type_mismatch("int64", "dynamic")
    with self.assertRaisesRegex(TypedSyntaxError, mismatch):
        self.compile(source, StaticCodeGenerator)
1894
def test_bad_box_2(self):
    """box() with two arguments is a compile error."""
    source = """
    from __static__ import box
    box('abc', 'foo')
    """

    expected_error = "box only accepts a single argument"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
1905
def test_bad_unbox_2(self):
    """unbox() with two arguments is a compile error."""
    source = """
    from __static__ import unbox, int64
    def f():
        x:int64 = 42
        unbox(x, y)
    """

    expected_error = "unbox only accepts a single argument"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
1918
def test_int_reassign(self):
    """Reassigning a ssize_t local after use does not change the previously computed result."""
    source = """
    from __static__ import ssize_t, box
    def f():
        x: ssize_t = 42
        z: ssize_t = 1 + x
        x = 100
        x = x + x
        return box(z)
    """

    func_code = self.find_code(self.compile(source, StaticCodeGenerator))
    for opname, oparg in (
        ("PRIMITIVE_LOAD_CONST", (42, TYPED_INT64)),
        ("LOAD_LOCAL", (0, ("__static__", "int64"))),
        ("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT),
    ):
        self.assertInBytecode(func_code, opname, oparg)

    func = self.run_code(source, StaticCodeGenerator)["f"]
    self.assertEqual(func(), 43)
1937
def test_assign_to_object(self):
    """A local annotated `object` accepts values of every literal/builtin kind."""
    source = """
    def f():
        x: object
        x = None
        x = 1
        x = 'abc'
        x = []
        x = {}
        x = {1, 2}
        x = ()
        x = 1.0
        x = 1j
        x = b'foo'
        x = int
        x = True
        x = NotImplemented
        x = ...
    """

    self.compile(source, StaticCodeGenerator)
1959
def test_global_call_add(self) -> None:
    """Arithmetic on a module global bound from a call result compiles."""
    source = """
    X = ord(42)
    def f():
        y = X + 1
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
1967
def test_type_binder(self) -> None:
    """TypeBinder infers the expected static type for common expressions and statements."""
    exact_int = INT_EXACT_TYPE.instance
    exact_float = FLOAT_EXACT_TYPE.instance
    exact_str = STR_EXACT_TYPE.instance
    plain_bool = BOOL_TYPE.instance
    exact_dict = DICT_EXACT_TYPE.instance
    exact_set = SET_EXACT_TYPE.instance
    exact_list = LIST_EXACT_TYPE.instance

    # Literals, unresolved names and integer arithmetic.
    for source, expected in (
        ("42", exact_int),
        ("42.0", exact_float),
        ("'abc'", exact_str),
        ("b'abc'", BYTES_TYPE.instance),
        ("3j", COMPLEX_EXACT_TYPE.instance),
        ("None", NONE_TYPE.instance),
        ("True", plain_bool),
        ("False", plain_bool),
        ("...", ELLIPSIS_TYPE.instance),
        ("f''", exact_str),
        ("f'{x}'", exact_str),
        ("a", DYNAMIC),
        ("a.b", DYNAMIC),
        ("a + b", DYNAMIC),
        ("1 + 2", exact_int),
        ("1 - 2", exact_int),
        ("1 // 2", exact_int),
        ("1 * 2", exact_int),
        ("1 / 2", exact_float),
        ("1 % 2", exact_int),
        ("1 & 2", exact_int),
        ("1 | 2", exact_int),
        ("1 ^ 2", exact_int),
        ("1 << 2", exact_int),
        ("100 >> 2", exact_int),
    ):
        self.assertEqual(self.bind_expr(source), expected)

    self.assertEqual(self.bind_stmt("x = 1"), exact_int)
    # self.assertEqual(self.bind_stmt("x: foo = 1").target.comp_type, DYNAMIC)
    self.assertEqual(self.bind_stmt("x += 1"), DYNAMIC)

    # Operators, calls, subscripts, displays and comprehensions.
    for source, expected in (
        ("a or b", DYNAMIC),
        ("+a", DYNAMIC),
        ("not a", plain_bool),
        ("lambda: 42", DYNAMIC),
        ("a if b else c", DYNAMIC),
        ("x > y", DYNAMIC),
        ("x()", DYNAMIC),
        ("x(y)", DYNAMIC),
        ("x[y]", DYNAMIC),
        ("x[1:2]", DYNAMIC),
        ("x[1:2:3]", DYNAMIC),
        ("x[:]", DYNAMIC),
        ("{}", exact_dict),
        ("{2:3}", exact_dict),
        ("{1,2}", exact_set),
        ("[]", exact_list),
        ("[1,2]", exact_list),
        ("(1,2)", TUPLE_EXACT_TYPE.instance),
        ("[x for x in y]", exact_list),
        ("{x for x in y}", exact_set),
        ("{x:y for x in y}", exact_dict),
        ("(x for x in y)", DYNAMIC),
    ):
        self.assertEqual(self.bind_expr(source), expected)

    def first_stmt_value(stmt):
        return stmt.body[0].value

    # The first statement inside a function body.
    for source, expected in (
        ("def f(): return 42", exact_int),
        ("def f(): yield 42", DYNAMIC),
        ("def f(): yield from x", DYNAMIC),
        ("async def f(): await x", DYNAMIC),
    ):
        self.assertEqual(self.bind_stmt(source, getter=first_stmt_value), expected)

    self.assertEqual(self.bind_expr("object"), OBJECT_TYPE)
    self.assertEqual(self.bind_expr("1 + 2", optimize=True), exact_int)
2044
def test_if_exp(self) -> None:
    """An if-expression over unrelated classes binds as dynamic; over one class it keeps that class."""
    mixed_tree, mixed_syms = self.bind_module(
        """
        class C: pass
        class D: pass

        x = C() if a else D()
        """
    )
    mixed_types = mixed_syms.modules["foo"].types
    self.assertEqual(mixed_types[mixed_tree.body[-1]], DYNAMIC)

    same_tree, same_syms = self.bind_module(
        """
        class C: pass

        x = C() if a else C()
        """
    )
    same_types = same_syms.modules["foo"].types
    self.assertEqual(same_types[same_tree.body[-1].value].name, "foo.C")
2068
def test_cmpop(self):
    """`==` on an int32 uses INT_COMPARE_OP, while `==` on a plain int uses COMPARE_OP."""
    source = """
    from __static__ import int32
    def f():
        i: int32 = 0
        j: int = 0

        if i == 0:
            return 0
        if j == 0:
            return 1
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo")
    func_code = self.find_code(module_code, "f")
    self.assertInBytecode(func_code, "INT_COMPARE_OP", 0)
    self.assertInBytecode(func_code, "COMPARE_OP", "==")
2085
def test_bind_instance(self) -> None:
    """An annotated assignment target binds to an instance of the annotated class."""
    tree, syms = self.bind_module("class C: pass\na: C = C()")
    target = tree.body[1].target
    bound = syms.modules["foo"].types[target]
    self.assertEqual(bound.name, "foo.C")
    self.assertEqual(repr(bound), "<foo.C>")
2092
def test_bind_func_def(self) -> None:
    """A module-level `def` is recorded in the module table as a Function."""
    mod, syms = self.bind_module(
        """
        def f(x: object = None, y: object = None):
            pass
        """
    )
    modtable = syms.modules["foo"]
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(modtable.children["f"], Function)
2102
def assertReturns(self, code: str, typename: str) -> None:
    """Assert that the final returned value in ``code`` binds to a type named ``typename``."""
    self.assertEqual(self.bind_final_return(code).name, typename)
2106
def bind_final_return(self, code: str) -> Value:
    """Bind ``code`` and return the type of the value in the last statement of its last top-level definition."""
    tree, syms = self.bind_module(code)
    last_stmt = tree.body[-1].body[-1]
    return syms.modules["foo"].types[last_stmt.value]
2112
def bind_stmt(
    self, code: str, optimize: bool = False, getter=lambda stmt: stmt
) -> Value:
    """Bind single-statement ``code`` and return the type bound to a node.

    ``getter`` selects which node to look up, starting from the module's
    only statement (by default the statement itself).
    """
    mod, syms = self.bind_module(code, optimize)
    # assertEqual instead of a bare assert, so the check survives `python -O`.
    self.assertEqual(len(mod.body), 1)
    types = syms.modules["foo"].types
    return types[getter(mod.body[0])]
2120
def bind_expr(self, code: str, optimize: bool = False) -> Value:
    """Bind single-expression ``code`` and return the type bound to that expression."""
    mod, syms = self.bind_module(code, optimize)
    # assertEqual instead of a bare assert, so the check survives `python -O`.
    self.assertEqual(len(mod.body), 1)
    types = syms.modules["foo"].types
    return types[mod.body[0].value]
2126
def bind_module(
    self, code: str, optimize: bool = False
) -> Tuple[ast.Module, SymbolTable]:
    """Declare, type-bind and code-generate ``code`` as module "foo".

    Returns the (optionally AstOptimizer-processed) tree together with the
    SymbolTable, whose ``modules["foo"].types`` maps nodes to bound types.
    """
    module_ast = ast.parse(dedent(code))
    if optimize:
        module_ast = AstOptimizer().visit(module_ast)

    table = SymbolTable()
    declarations = DeclarationVisitor("foo", "foo.py", table)
    declarations.visit(module_ast)
    declarations.module.finish_bind()

    scopes = SymbolVisitor()
    walk(module_ast, scopes)

    binder = TypeBinder(scopes, "foo.py", table, "foo")
    binder.visit(module_ast)

    # Also run the static code generator over the whole tree; the generated
    # code is discarded, but this verifies every node can be compiled.
    flow_graph = StaticCodeGenerator.flow_graph(
        "foo", "foo.py", scopes.scopes[module_ast]
    )
    generator = StaticCodeGenerator(
        None, module_ast, scopes, flow_graph, table, "foo", optimize
    )
    generator.visit(module_ast)

    return module_ast, table
2152
def test_cross_module(self) -> None:
    """Calling a method on a class imported from another compiled module uses INVOKE_METHOD."""
    a_source = """
    class C:
        def f(self):
            return 42
    """
    b_source = """
    from a import C

    def f():
        x = C()
        return x.f()
    """
    table = SymbolTable()
    table.compile("a", "a.py", ast.parse(dedent(a_source)))
    b_code = table.compile("b", "b.py", ast.parse(dedent(b_source)))
    func_code = self.find_code(b_code, "f")
    self.assertInBytecode(func_code, "INVOKE_METHOD", (("a", "C", "f"), 0))
2171
def test_cross_module_import_time_resolution(self) -> None:
    """A module first requested via import_module while compiling another is still resolved statically."""
    acode = """
    class C:
        def f(self):
            return 42
    """
    bcode = """
    from a import C

    def f():
        x = C()
        return x.f()
    """

    class TestSymbolTable(SymbolTable):
        def import_module(self, name):
            # Register "a" lazily on *this* table. The previous version
            # reached out to the enclosing `symtable` variable through the
            # closure, which only worked because that name was bound later.
            if name == "a":
                self.add_module("a", "a.py", ast.parse(dedent(acode)))

    symtable = TestSymbolTable()
    bcomp = symtable.compile("b", "b.py", ast.parse(dedent(bcode)))
    x = self.find_code(bcomp, "f")
    self.assertInBytecode(x, "INVOKE_METHOD", (("a", "C", "f"), 0))
2194
def test_cross_module_type_checking(self) -> None:
    """A class imported only under TYPE_CHECKING still enables INVOKE_METHOD on annotated params."""
    a_source = """
    class C:
        def f(self):
            return 42
    """
    b_source = """
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from a import C

    def f(x: C):
        return x.f()
    """
    table = SymbolTable()
    for name, filename, text in (
        ("a", "a.py", a_source),
        ("b", "b.py", b_source),
    ):
        table.add_module(name, filename, ast.parse(dedent(text)))
    table.compile("a", "a.py", ast.parse(dedent(a_source)))
    b_code = table.compile("b", "b.py", ast.parse(dedent(b_source)))
    func_code = self.find_code(b_code, "f")
    self.assertInBytecode(func_code, "INVOKE_METHOD", (("a", "C", "f"), 0))
2217
def test_primitive_invoke(self) -> None:
    """Attribute access on an int8 primitive is a compile error."""
    source = """
    from __static__ import int8
    def f():
        x: int8 = 42
        print(x.__str__())
    """
    expected_error = "cannot load attribute from int8"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
2229
def test_primitive_call(self) -> None:
    """Calling an int8 primitive is a compile error."""
    source = """
    from __static__ import int8
    def f():
        x: int8 = 42
        print(x())
    """
    expected_error = "cannot call int8"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
2239
def test_primitive_subscr(self) -> None:
    """Subscripting an int8 primitive is a compile error."""
    source = """
    from __static__ import int8
    def f():
        x: int8 = 42
        print(x[42])
    """
    expected_error = "cannot index int8"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
2249
def test_primitive_iter(self) -> None:
    """Iterating over an int8 primitive is a compile error."""
    source = """
    from __static__ import int8
    def f():
        x: int8 = 42
        for a in x:
            pass
    """
    expected_error = "cannot iterate over int8"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(source, StaticCodeGenerator)
2260
def test_pseudo_strict_module(self) -> None:
    """`bool` rebound from `<builtins>["bool"]` (ast.Str index) still resolves to builtins.bool."""
    # Simulate a strict module, where builtins are read from `<builtins>`,
    # by prepending `bool = <builtins>["bool"]` to the parsed module.
    source = """
    def f(a):
        x: bool = a
    """
    bool_from_builtins = ast.Assign(
        [ast.Name("bool", ast.Store())],
        ast.Subscript(
            ast.Name("<builtins>", ast.Load()),
            ast.Index(ast.Str("bool")),
            ast.Load(),
        ),
        None,
    )
    module_ast = ast.parse(dedent(source))
    module_ast.body.insert(0, bool_from_builtins)

    table = SymbolTable()
    table.add_module("a", "a.py", module_ast)
    a_code = table.compile("a", "a.py", module_ast)
    func_code = self.find_code(a_code, "f")
    self.assertInBytecode(func_code, "CAST", ("builtins", "bool"))
2284
def test_aug_assign(self) -> None:
    """``l[0] += 1`` on a dynamic argument mutates the caller's list in place."""
    codestr = """
        def f(l):
            l[0] += 1
    """
    with self.in_module(codestr) as mod:
        increment_first = mod["f"]
        values = [1]
        increment_first(values)
        self.assertEqual(values[0], 2)
2295
def test_pseudo_strict_module_constant(self) -> None:
    """Same as the ast.Str variant, but the index is an ``ast.Constant``."""
    # simulate strict modules where the builtins come from <builtins>
    code = """
        def f(a):
            x: bool = a
    """
    builtins_lookup = ast.Subscript(
        ast.Name("<builtins>", ast.Load()),
        ast.Index(ast.Constant("bool")),
        ast.Load(),
    )
    prelude = ast.Assign([ast.Name("bool", ast.Store())], builtins_lookup, None)

    tree = ast.parse(dedent(code))
    tree.body.insert(0, prelude)

    symtable = SymbolTable()
    symtable.add_module("a", "a.py", tree)
    module_code = symtable.compile("a", "a.py", tree)
    func_code = self.find_code(module_code, "f")
    self.assertInBytecode(func_code, "CAST", ("builtins", "bool"))
2319
def test_cross_module_inheritance(self) -> None:
    """A call through a base class imported from another module uses INVOKE_METHOD."""
    acode = """
        class C:
            def f(self):
                return 42
    """
    bcode = """
        from a import C

        class D(C):
            def f(self):
                return 'abc'

        def f(y):
            x: C
            if y:
                x = D()
            else:
                x = C()
            return x.f()
    """
    sources = {"a": acode, "b": bcode}

    symtable = SymbolTable()
    # Register every module first so cross-module references resolve ...
    for modname, src in sources.items():
        symtable.add_module(modname, f"{modname}.py", ast.parse(dedent(src)))
    # ... then compile each one from a fresh parse, in the same order.
    compiled = {
        modname: symtable.compile(modname, f"{modname}.py", ast.parse(dedent(src)))
        for modname, src in sources.items()
    }

    func_code = self.find_code(compiled["b"], "f")
    self.assertInBytecode(func_code, "INVOKE_METHOD", (("a", "C", "f"), 0))
2348
def test_annotated_function(self):
    """Calls to an int-annotated method compile to INVOKE_METHOD and sum correctly."""
    codestr = """
        class C:
            def f(self) -> int:
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, cls = mod["x"], mod["C"]
        instance = cls()
        self.assertEqual(call_f_twice(instance), 2)
2369
def test_invoke_new_derived(self):
    """INVOKE_METHOD dispatches to a subclass defined after the caller first ran."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x

        a = x(C())

        class D(C):
            def f(self):
                return 2

        b = x(D())
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        base_result, derived_result = mod["a"], mod["b"]
        self.assertEqual(base_result, 2)
        self.assertEqual(derived_result, 4)
2398
def test_invoke_explicit_slots(self):
    """INVOKE_METHOD works on a class that declares ``__slots__ = ()``."""
    codestr = """
        class C:
            __slots__ = ()
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x

        a = x(C())
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        self.assertEqual(mod["a"], 2)
2421
def test_invoke_new_derived_nonfunc(self):
    """A subclass may replace ``f`` with a callable (non-function) instance.

    The callable is invoked with the receiver as its argument, so each of
    the two calls in ``x`` contributes 42.
    """
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    test_case = self

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self, obj):
                test_case.assertTrue(isinstance(obj, D))
                return 42

        class D(C):
            f = Callable()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
2452
def test_invoke_new_derived_nonfunc_slots(self):
    """Like the non-function override test, but the subclass uses ``__slots__ = ()``."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    test_case = self

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self, obj):
                test_case.assertTrue(isinstance(obj, D))
                return 42

        class D(C):
            __slots__ = ()
            f = Callable()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
2484
def test_invoke_new_derived_nonfunc_descriptor(self):
    """A subclass may replace ``f`` with a non-data descriptor returning a callable."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self):
                return 42

        class Descr:
            def __get__(self, inst, owner):
                return Callable()

        class D(C):
            f = Descr()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
2518
def test_invoke_new_derived_nonfunc_data_descriptor(self):
    """A subclass may replace ``f`` with a data descriptor (``__get__`` + ``__set__``)."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self):
                return 42

        class Descr:
            def __get__(self, inst, owner):
                return Callable()

            def __set__(self, inst, value):
                raise ValueError("no way")

        class D(C):
            f = Descr()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
2555
def test_invoke_new_derived_nonfunc_descriptor_inst_override(self):
    """An instance-dict ``f`` placed over a non-data descriptor is picked up."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self):
                return 42

        class Descr:
            def __get__(self, inst, owner):
                return Callable()

        class D(C):
            f = Descr()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
        # Descr defines only __get__, so an instance-dict entry takes precedence.
        instance.__dict__["f"] = lambda x: 100
        self.assertEqual(call_f_twice(instance), 200)
2591
def test_invoke_new_derived_nonfunc_descriptor_modified(self):
    """Deleting ``__get__`` from the descriptor type is observed by later calls."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class Callable:
            def __call__(self):
                return 42

        class Descr:
            def __get__(self, inst, owner):
                return Callable()

            def __call__(self, arg):
                return 23

        class D(C):
            f = Descr()

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
        # Without __get__, the Descr instance itself is called (23 per call).
        del Descr.__get__
        self.assertEqual(call_f_twice(instance), 46)
2630
def test_invoke_dict_override(self):
    """An ``f`` stored in the instance dict by ``__init__`` overrides C.f."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)

        class D(C):
            def __init__(self):
                self.f = lambda: 42

        instance = D()
        self.assertEqual(call_f_twice(instance), 84)
2657
def test_invoke_type_modified(self):
    """Reassigning ``C.f`` at runtime is observed by INVOKE_METHOD call sites."""
    codestr = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", (("foo", "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice, C = mod["x"], mod["C"]
        self.assertEqual(call_f_twice(C()), 2)
        C.f = lambda self: 42
        self.assertEqual(call_f_twice(C()), 84)
2679
def test_annotated_function_derived(self):
    """INVOKE_METHOD on C.f dispatches correctly to overriding and inheriting subclasses."""
    codestr = """
        class C:
            def f(self) -> int:
                return 1

        class D(C):
            def f(self) -> int:
                return 2

        class E(C):
            pass

        def x(c: C,):
            x = c.f()
            x += c.f()
            return x
    """

    modname = "test_annotated_function_derived"
    module_code = self.compile(codestr, StaticCodeGenerator, modname=modname)
    caller = self.find_code(module_code, "x")
    self.assertInBytecode(caller, "INVOKE_METHOD", ((modname, "C", "f"), 0))

    with self.in_module(codestr) as mod:
        call_f_twice = mod["x"]
        # D overrides f (2 per call); E inherits C.f (1 per call).
        for cls_name, expected in (("C", 2), ("D", 4), ("E", 2)):
            self.assertEqual(call_f_twice(mod[cls_name]()), expected)
2712
def test_conditional_init(self):
    """A primitive field assigned on only one ``__init__`` path reads as 0 otherwise.

    ``C(False)`` never stores ``value``; the test expects ``f`` to return 0
    (not raise) and ``f`` to read the field through LOAD_FIELD.
    """
    # A plain string, not an f-string: the fixture has no placeholders, and an
    # f-prefix would break as soon as the embedded source contained braces.
    codestr = """
        from __static__ import box, int64

        class C:
            def __init__(self, init=True):
                if init:
                    self.value: int64 = 1

            def f(self) -> int:
                return box(self.value)
    """

    with self.in_module(codestr) as mod:
        C = mod["C"]
        x = C()
        self.assertEqual(x.f(), 1)
        x = C(False)
        self.assertEqual(x.f(), 0)
        self.assertInBytecode(C.f, "LOAD_FIELD", (mod["__name__"], "C", "value"))
2733
def test_error_incompat_assign_local(self):
    """Binding a None-valued attribute to a local annotated ``"C"`` fails at runtime."""
    codestr = """
        class C:
            def __init__(self):
                self.x = None

            def f(self):
                x: "C" = self.x
    """
    with self.in_module(codestr) as mod:
        cls = mod["C"]
        expected_error = "expected 'C', got 'NoneType'"
        with self.assertRaisesRegex(TypeError, expected_error):
            cls().f()
2747
def test_error_incompat_field_non_dynamic(self):
    """Storing a str literal into a field annotated ``int`` fails to compile."""
    source = """
        class C:
            def __init__(self):
                self.x: int = 'abc'
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(source, StaticCodeGenerator)
2756
def test_error_incompat_field(self):
    """Storing a dynamic value into an ``int`` field is type-checked at runtime."""
    codestr = """
        class C:
            def __init__(self):
                self.x: int = 100

            def f(self, x):
                self.x = x
    """
    with self.in_module(codestr) as mod:
        cls = mod["C"]
        cls().f(42)  # an int is accepted
        with self.assertRaises(TypeError):
            cls().f("abc")
2771
def test_error_incompat_assign_dynamic(self):
    """Assigning None to attribute ``x`` declared as ``"C"`` fails to compile."""
    codestr = """
        class C:
            x: "C"
            def __init__(self):
                self.x = None
    """
    # Only the raised error matters; the compiled result was never used.
    with self.assertRaises(TypedSyntaxError):
        self.compile(codestr, StaticCodeGenerator)
2783
def test_annotated_class_var(self):
    """A bare class-level annotation (``x: int``) compiles under the static compiler.

    The test passes as long as compilation does not raise.
    """
    codestr = """
        class C:
            x: int
    """
    self.compile(codestr, StaticCodeGenerator, modname="test_annotated_class_var")
2792
def test_annotated_instance_var(self):
    """An annotated ``self.x`` assignment in ``__init__`` compiles to STORE_FIELD."""
    codestr = """
        class C:
            def __init__(self):
                self.x: str = 'abc'
    """
    module_code = self.compile(
        codestr, StaticCodeGenerator, modname="test_annotated_instance_var"
    )
    # get C from module, and then get __init__ from C
    class_body = self.find_code(module_code)
    init_code = self.find_code(class_body)
    self.assertInBytecode(init_code, "STORE_FIELD")
2805
def test_load_store_attr_value(self):
    """A declared ``int`` slot is written with STORE_FIELD and read with LOAD_FIELD."""
    codestr = """
        class C:
            x: int

            def __init__(self, value: int):
                self.x = value

            def f(self):
                return self.x
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    class_code = self.find_code(module_code)
    self.assertInBytecode(self.find_code(class_code, "__init__"), "STORE_FIELD")
    self.assertInBytecode(self.find_code(class_code, "f"), "LOAD_FIELD")

    with self.in_module(codestr) as mod:
        instance = mod["C"](42)
        self.assertEqual(instance.f(), 42)
2826
def test_load_store_attr(self):
    """A self-referential ``"C"`` slot can be stored and a method called through it."""
    codestr = """
        class C:
            x: "C"

            def __init__(self):
                self.x = self

            def g(self):
                return 42

            def f(self):
                return self.x.g()
    """
    with self.in_module(codestr) as mod:
        instance = mod["C"]()
        self.assertEqual(instance.f(), 42)
2845
def test_load_store_attr_init(self):
    """A field annotated ``C`` inside ``__init__`` can be used to call methods."""
    codestr = """
        class C:
            def __init__(self):
                self.x: C = self

            def g(self):
                return 42

            def f(self):
                return self.x.g()
    """
    # Static compilation must succeed; the result itself is not inspected.
    self.compile(codestr, StaticCodeGenerator, modname="foo")

    with self.in_module(codestr) as mod:
        C = mod["C"]
        a = C()
        self.assertEqual(a.f(), 42)
2864
def test_load_store_attr_init_no_ann(self):
    """An unannotated ``self.x = self`` in ``__init__`` still supports method calls."""
    codestr = """
        class C:
            def __init__(self):
                self.x = self

            def g(self):
                return 42

            def f(self):
                return self.x.g()
    """
    # Static compilation must succeed; the result itself is not inspected.
    self.compile(codestr, StaticCodeGenerator, modname="foo")

    with self.in_module(codestr) as mod:
        C = mod["C"]
        a = C()
        self.assertEqual(a.f(), 42)
2883
def test_unknown_annotation(self):
    """An annotation naming an undefined type is tolerated.

    Compilation succeeds, and attribute access through the annotated local
    works at runtime.
    """
    codestr = """
        def f(a):
            x: foo = a
            return x.bar
    """
    # Static compilation must succeed even though ``foo`` is undefined.
    self.compile(codestr, StaticCodeGenerator, modname="foo")

    class C:
        bar = 42

    f = self.run_code(codestr, StaticCodeGenerator)["f"]
    self.assertEqual(f(C()), 42)
2897
def test_class_method_invoke(self):
    """A field set via an explicit ``B.__init__(self, ...)`` call is a B field."""
    codestr = """
        class B:
            def __init__(self, value):
                self.value = value

        class D(B):
            def __init__(self, value):
                B.__init__(self, value)

            def f(self):
                return self.value
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    field_ref = ("foo", "B", "value")

    base_init = self.find_code(self.find_code(module_code, "B"), "__init__")
    self.assertInBytecode(base_init, "STORE_FIELD", field_ref)

    derived_f = self.find_code(self.find_code(module_code, "D"), "f")
    self.assertInBytecode(derived_f, "LOAD_FIELD", field_ref)

    with self.in_module(codestr) as mod:
        instance = mod["D"](42)
        self.assertEqual(instance.f(), 42)
2923
def test_slotification(self):
    """A class annotation with an unresolvable type becomes a member descriptor."""
    codestr = """
        class C:
            x: "unknown_type"
    """
    # Static compilation with an explicit module name must succeed.
    self.compile(codestr, StaticCodeGenerator, modname="foo")
    C = self.run_code(codestr, StaticCodeGenerator)["C"]
    self.assertEqual(type(C.x), MemberDescriptorType)
2932
def test_slotification_init(self):
    """An unknown-typed class annotation assigned in ``__init__`` is still slotified."""
    codestr = """
        class C:
            x: "unknown_type"
            def __init__(self):
                self.x = 42
    """
    # Static compilation with an explicit module name must succeed.
    self.compile(codestr, StaticCodeGenerator, modname="foo")
    C = self.run_code(codestr, StaticCodeGenerator)["C"]
    self.assertEqual(type(C.x), MemberDescriptorType)
2943
def test_slotification_ann_init(self):
    """Re-annotating an unknown-typed slot in ``__init__`` keeps it slotified."""
    codestr = """
        class C:
            x: "unknown_type"
            def __init__(self):
                self.x: "unknown_type" = 42
    """
    # Static compilation with an explicit module name must succeed.
    self.compile(codestr, StaticCodeGenerator, modname="foo")
    C = self.run_code(codestr, StaticCodeGenerator)["C"]
    self.assertEqual(type(C.x), MemberDescriptorType)
2954
def test_slotification_typed(self):
    """A slot with a known type (``int``) is not a plain member descriptor."""
    source = """
        class C:
            x: int
    """
    namespace = self.run_code(source, StaticCodeGenerator)
    self.assertNotEqual(type(namespace["C"].x), MemberDescriptorType)
2962
def test_slotification_init_typed(self):
    """A typed slot rejects assignment of the wrong type at runtime."""
    codestr = """
        class C:
            x: int
            def __init__(self):
                self.x = 42
    """
    with self.in_module(codestr) as mod:
        C = mod["C"]
        self.assertNotEqual(type(C.x), MemberDescriptorType)
        x = C()
        self.assertEqual(x.x, 42)
        # The exception object is not inspected, so it is not bound.
        with self.assertRaisesRegex(
            TypeError, "expected 'int', got 'str' for attribute 'x'"
        ):
            x.x = "abc"
2979
def test_slotification_ann_init_typed(self):
    """A typed slot re-annotated identically in ``__init__`` stays a typed slot."""
    source = """
        class C:
            x: int
            def __init__(self):
                self.x: int = 42
    """
    namespace = self.run_code(source, StaticCodeGenerator)
    self.assertNotEqual(type(namespace["C"].x), MemberDescriptorType)
2989
def test_slotification_conflicting_types(self):
    """Annotating a slot as ``object`` and then ``int`` in ``__init__`` is rejected."""
    codestr = """
        class C:
            x: object
            def __init__(self):
                self.x: int = 42
    """
    expected_error = r"conflicting type definitions for slot x in Type\[foo.C\]"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
3002
def test_slotification_conflicting_types_imported(self):
    """Conflicting ``Optional[...]`` slot annotations (typing imports) are rejected."""
    source = """
        from typing import Optional

        class C:
            x: Optional[int]
            def __init__(self):
                self.x: Optional[str] = "foo"
    """
    expected_error = r"conflicting type definitions for slot x in Type\[<module>.C\]"
    self.type_error(source, expected_error)
3015
def test_slotification_conflicting_members(self):
    """A slot annotation that follows a same-named method is rejected."""
    codestr = """
        class C:
            def x(self): pass
            x: object
    """
    expected_error = r"slot conflicts with other member x in Type\[foo.C\]"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
3026
def test_slotification_conflicting_function(self):
    """A method defined after a same-named slot annotation is rejected."""
    codestr = """
        class C:
            x: object
            def x(self): pass
    """
    expected_error = r"function conflicts with other member x in Type\[foo.C\]"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
3037
def test_slot_inheritance(self):
    """A subclass writing an inherited slot is seen by the base-class reader."""
    codestr = """
        class B:
            def __init__(self):
                self.x = 42

            def f(self):
                return self.x

        class D(B):
            def __init__(self):
                self.x = 100
    """
    # Static compilation must succeed; the result itself is not inspected.
    self.compile(codestr, StaticCodeGenerator, modname="foo")
    with self.in_module(codestr) as mod:
        D = mod["D"]
        inst = D()
        self.assertEqual(inst.f(), 100)
3056
def test_del_slot(self):
    """Deleting a declared slot attribute emits a generic DELETE_ATTR."""
    codestr = """
        class C:
            x: object

        def f(a: C):
            del a.x
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    func_code = self.find_code(module_code, name="f")
    self.assertInBytecode(func_code, "DELETE_ATTR", "x")
3068
def test_uninit_slot(self):
    """Reading a declared but never-assigned slot raises ``AttributeError('x')``."""
    codestr = """
        class C:
            x: object

        def f(a: C):
            return a.x
    """
    code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    # Only checks that ``f`` is present in the compiled module.
    self.find_code(code, name="f")
    with self.in_module(codestr) as mod:
        f, C = mod["f"], mod["C"]
        # Keep the assertRaises body to the single call expected to raise.
        with self.assertRaises(AttributeError) as e:
            f(C())

        self.assertEqual(e.exception.args[0], "x")
3085
def test_aug_add(self):
    """``a.x += 1`` on a known field uses LOAD_FIELD and STORE_FIELD."""
    codestr = """
        class C:
            def __init__(self):
                self.x = 1

        def f(a: C):
            a.x += 1
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    func_code = self.find_code(module_code, name="f")
    for opname in ("LOAD_FIELD", "STORE_FIELD"):
        self.assertInBytecode(func_code, opname, ("foo", "C", "x"))
3099
def test_untyped_attr(self):
    """Attribute ops on an unknown object use the generic *_ATTR opcodes."""
    codestr = """
        y = x.load
        x.store = 42
        del x.delete
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    expectations = (
        ("LOAD_ATTR", "load"),
        ("STORE_ATTR", "store"),
        ("DELETE_ATTR", "delete"),
    )
    for opname, attr in expectations:
        self.assertInBytecode(module_code, opname, attr)
3110
def test_incompat_override(self):
    """Overriding an ``int`` slot with a method in a subclass fails to compile."""
    source = """
        class C:
            x: int

        class D(C):
            def x(self): pass
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(source, StaticCodeGenerator, modname="foo")
3121
def test_redefine_type(self):
    """Re-annotating a local with a different class fails to compile."""
    source = """
        class C: pass
        class D: pass

        def f(a):
            x: C = C()
            x: D = D()
    """
    with self.assertRaises(TypedSyntaxError):
        self.compile(source, StaticCodeGenerator, modname="foo")
3133
def test_optional_error(self):
    """Accessing an attribute through an un-narrowed ``Optional`` is rejected."""
    codestr = """
        from typing import Optional
        class C:
            x: Optional["C"]
            def __init__(self, set):
                if set:
                    self.x = self
                else:
                    self.x = None

            def f(self) -> Optional["C"]:
                return self.x.x
    """
    expected_error = re.escape(
        "Optional[foo.C]: 'NoneType' object has no attribute 'x'"
    )
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
3153
def test_optional_subscript_error(self) -> None:
    """Subscripting an un-narrowed ``Optional[int]`` is rejected."""
    codestr = """
        from typing import Optional

        def f(a: Optional[int]):
            a[1]
    """
    expected_error = re.escape(
        "Optional[int]: 'NoneType' object is not subscriptable"
    )
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator)
3166
def test_optional_unary_error(self) -> None:
    """Negating an un-narrowed ``Optional[int]`` is rejected."""
    codestr = """
        from typing import Optional

        def f(a: Optional[int]):
            -a
    """
    expected_error = re.escape(
        "Optional[int]: bad operand type for unary -: 'NoneType'"
    )
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator)
3179
def test_optional_assign(self):
    """A narrowed Optional may be assigned back to an ``Optional`` local."""
    source = """
        from typing import Optional
        class C:
            def f(self, x: Optional["C"]):
                if x is None:
                    return self
                else:
                    p: Optional["C"] = x
    """
    # Success means the snippet compiles without a TypedSyntaxError.
    self.compile(source, StaticCodeGenerator, modname="foo")
3191
def test_nonoptional_load(self):
    """Reading a field set from an ``int`` parameter compiles to LOAD_FIELD."""
    codestr = """
        class C:
            def __init__(self, y: int):
                self.y = y

        def f(c: C) -> int:
            return c.y
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    func_code = self.find_code(module_code, "f")
    self.assertInBytecode(func_code, "LOAD_FIELD", ("foo", "C", "y"))
3204
def test_optional_assign_subclass(self):
    """A subclass instance may be assigned to ``Optional[Base]``."""
    source = """
        from typing import Optional
        class B: pass
        class D(B): pass

        def f(x: D):
            a: Optional[B] = x
    """
    # Success means the snippet compiles without a TypedSyntaxError.
    self.compile(source, StaticCodeGenerator, modname="foo")
3215
def test_optional_assign_subclass_opt(self):
    """``Optional[Derived]`` may be assigned to ``Optional[Base]``."""
    source = """
        from typing import Optional
        class B: pass
        class D(B): pass

        def f(x: Optional[D]):
            a: Optional[B] = x
    """
    # Success means the snippet compiles without a TypedSyntaxError.
    self.compile(source, StaticCodeGenerator, modname="foo")
3226
def test_optional_assign_none(self):
    """``None`` may be assigned to an ``Optional[B]`` local."""
    source = """
        from typing import Optional
        class B: pass

        def f(x: Optional[B]):
            a: Optional[B] = None
    """
    # Success means the snippet compiles without a TypedSyntaxError.
    self.compile(source, StaticCodeGenerator, modname="foo")
3236
def test_optional_union_syntax(self):
    """``Union[int, None]`` behaves like ``Optional[int]``: assignable and narrowable.

    The fixture must import ``Optional``. Without it, the ``y``/``z``
    annotations name an undefined symbol, which the compiler accepts without
    type-checking. The "can assign None" and "can assign subclass" lines
    would then verify nothing.
    """
    self.assertReturns(
        """
        from typing import Optional, Union
        class B: pass
        class C(B): pass

        def f(x: Union[int, None]) -> int:
            # can assign None
            y: Optional[int] = None
            # can assign subclass
            z: Optional[B] = C()
            # can narrow
            if x is None:
                return 1
            return x
        """,
        "int",
    )
3256
def test_optional_union_syntax_error(self):
    """Returning an un-narrowed ``Union[int, None]`` as ``int`` is a type error."""
    source = """
        from typing import Union

        def f(x: Union[int, None]) -> int:
            return x
    """
    self.type_error(source, type_mismatch("Optional[int]", "int"))
3267
def test_union_can_assign_to_broader_union(self):
    """A ``Union[int, str]`` value may be returned as ``Union[int, str, B]``."""
    source = """
        from typing import Union
        class B:
            pass

        def f(x: int, y: str) -> Union[int, str, B]:
            return x or y
    """
    expected_type = "Union[int, str]"
    self.assertReturns(source, expected_type)
3280
def test_union_can_assign_to_same_union(self):
    """A ``Union[int, str]`` value may be returned as ``Union[int, str]``."""
    source = """
        from typing import Union

        def f(x: int, y: str) -> Union[int, str]:
            return x or y
    """
    expected_type = "Union[int, str]"
    self.assertReturns(source, expected_type)
3291
def test_union_can_assign_from_individual_element(self):
    """A single member type (``int``) may be returned as ``Union[int, str]``."""
    source = """
        from typing import Union

        def f(x: int) -> Union[int, str]:
            return x
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3302
def test_union_cannot_assign_from_broader_union(self):
    """Returning a broader union than declared is currently accepted (see TODO)."""
    # TODO this should be a type error, but can't be safely
    # until we have runtime checking for unions
    source = """
        from typing import Union
        class B: pass

        def f(x: int, y: str, z: B) -> Union[int, str]:
            return x or y or z
    """
    expected_type = "Union[int, str, foo.B]"
    self.assertReturns(source, expected_type)
3316
def test_union_simplify_to_single_type(self):
    """``int or int`` simplifies to plain ``int``."""
    source = """
        from typing import Union

        def f(x: int, y: int) -> int:
            return x or y
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3327
def test_union_simplify_related(self):
    """A union of a class and its subclass simplifies to the base class."""
    source = """
        from typing import Union
        class B: pass
        class C(B): pass

        def f(x: B, y: C) -> B:
            return x or y
    """
    expected_type = "foo.B"
    self.assertReturns(source, expected_type)
3340
def test_union_flatten_nested(self):
    """Nested ``or`` expressions produce one flat union."""
    source = """
        from typing import Union
        class B: pass

        def f(x: int, y: str, z: B):
            return x or (y or z)
    """
    expected_type = "Union[int, str, foo.B]"
    self.assertReturns(source, expected_type)
3352
def test_union_deep_simplify(self):
    """Deeply repeated members collapse, so narrowing out None leaves ``int``."""
    source = """
        from typing import Union

        def f(x: int, y: None) -> int:
            z = (x or x) or (y or y) or (x or x)
            if z is None:
                return 1
            return z
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3366
def test_union_dynamic_element(self):
    """A union with an unknown (dynamic) member is itself dynamic."""
    source = """
        from somewhere import unknown

        def f(x: int, y: unknown):
            return x or y
    """
    expected_type = "dynamic"
    self.assertReturns(source, expected_type)
3377
def test_union_or_syntax(self):
    """``isinstance(x, int|str)`` narrows ``x`` to ``Union[int, str]``."""
    source = """
        def f(x) -> int:
            if isinstance(x, int|str):
                return x
            return 1
    """
    self.type_error(source, type_mismatch("Union[int, str]", "int"))
3388
def test_union_or_syntax_none(self):
    """``isinstance(x, int|None)`` narrows ``x`` to ``Optional[int]``."""
    source = """
        def f(x) -> int:
            if isinstance(x, int|None):
                return x
            return 1
    """
    self.type_error(source, type_mismatch("Optional[int]", "int"))
3399
def test_union_or_syntax_builtin_type(self):
    """``bytes | Iterator[bytes]`` inside isinstance compiles without error."""
    source = """
        from typing import Iterator
        def f(x) -> int:
            if isinstance(x, bytes | Iterator[bytes]):
                return 1
            return 2
    """
    # Success means the snippet compiles without a TypedSyntaxError.
    self.compile(source, StaticCodeGenerator, modname="foo.py")
3412
def test_union_or_syntax_none_first(self):
    """``None|int`` narrows to ``Optional[int]`` just like ``int|None``."""
    source = """
        def f(x) -> int:
            if isinstance(x, None|int):
                return x
            return 1
    """
    self.type_error(source, type_mismatch("Optional[int]", "int"))
3423
def test_union_or_syntax_annotation(self):
    """An ``int|str`` local annotation is a ``Union[int, str]`` type."""
    source = """
        def f(y: int, z: str) -> int:
            x: int|str = y or z
            return x
    """
    self.type_error(source, type_mismatch("Union[int, str]", "int"))
3433
def test_union_or_syntax_error(self):
    """``int | "foo"`` in an expression is an unsupported-operand type error.

    The expected message contains regex metacharacters (``(s)``, ``|``,
    ``[``). The previous hand-escaped pattern left ``|`` unescaped. That made
    it an alternation, so only the tail of the message was checked and the
    "unsupported operand" prefix never was. ``re.escape`` matches the whole
    message literally.
    """
    self.type_error(
        """
        def f():
            x = int | "foo"
        """,
        re.escape(
            "unsupported operand type(s) for |: Type[Exact[int]] and Exact[str]"
        ),
    )
3442
def test_union_or_syntax_annotation_bad_type(self):
    """``len | int`` as an annotation currently degrades to dynamic (see TODO)."""
    # TODO given that len is not unknown/dynamic, but is a known object
    # with type that is invalid in this position, this should really be an
    # error. But the current form of `resolve_annotations` doesn't let us
    # distinguish between unknown/dynamic and bad type. So for now we just
    # let this go as dynamic.
    source = """
        def f(x: len | int) -> int:
            return x
    """
    expected_type = "dynamic"
    self.assertReturns(source, expected_type)
3456
def test_union_attr(self):
    """An attribute read on a union yields the union of the member attribute types."""
    source = """
        class A:
            attr: int

        class B:
            attr: str

        def f(x: A, y: B):
            z = x or y
            return z.attr
    """
    expected_type = "Union[int, str]"
    self.assertReturns(source, expected_type)
3472
def test_union_attr_error(self):
    """Attribute access on ``A | None`` is rejected because None lacks it."""
    source = """
        class A:
            attr: int

        def f(x: A | None):
            return x.attr
    """
    expected_error = re.escape(
        "Optional[<module>.A]: 'NoneType' object has no attribute 'attr'"
    )
    self.type_error(source, expected_error)
3486
3487 # TODO add test_union_call when we have Type[] or Callable[] or
3488 # __call__ support. Right now we have no way to construct a Union of
3489 # callables that return different types.
3490
def test_union_call_error(self):
    """Calling an ``int | None`` value is rejected because None is not callable."""
    source = """
        def f(x: int | None):
            return x()
    """
    expected_error = re.escape("Optional[int]: 'NoneType' object is not callable")
    self.type_error(source, expected_error)
3499
def test_union_subscr(self):
    """Subscripting a union of CheckedDicts yields the union of their value types."""
    source = """
        from __static__ import CheckedDict

        def f(x: CheckedDict[int, int], y: CheckedDict[int, str]):
            return (x or y)[0]
    """
    expected_type = "Union[int, str]"
    self.assertReturns(source, expected_type)
3510
def test_union_unaryop(self):
    """Unary minus on ``int or complex`` keeps the union type."""
    source = """
        def f(x: int, y: complex):
            return -(x or y)
    """
    expected_type = "Union[int, complex]"
    self.assertReturns(source, expected_type)
3519
def test_union_isinstance_reverse_narrow(self):
    """After an early return on ``isinstance(z, str)``, ``z`` is narrowed to int."""
    source = """
        def f(x: int, y: str):
            z = x or y
            if isinstance(z, str):
                return 1
            return z
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3531
def test_union_isinstance_reverse_narrow_supertype(self):
    """Excluding a supertype (A) also removes its subclass (B) from the union."""
    source = """
        class A: pass
        class B(A): pass

        def f(x: int, y: B):
            o = x or y
            if isinstance(o, A):
                return 1
            return o
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3546
def test_union_isinstance_reverse_narrow_other_union(self):
    """Excluding ``A | B`` from ``A | B | C`` leaves ``C``."""
    source = """
        class A: pass
        class B: pass
        class C: pass

        def f(x: A, y: B, z: C):
            o = x or y or z
            if isinstance(o, A | B):
                return 1
            return o
    """
    expected_type = "foo.C"
    self.assertReturns(source, expected_type)
3562
def test_union_not_isinstance_narrow(self):
    """After an early return on ``not isinstance(o, int)``, ``o`` is int."""
    source = """
        def f(x: int, y: str):
            o = x or y
            if not isinstance(o, int):
                return 1
            return o
    """
    expected_type = "int"
    self.assertReturns(source, expected_type)
3574
def test_union_isinstance_tuple(self):
    """A tuple of classes in isinstance narrows the same way as ``A | B``."""
    source = """
        class A: pass
        class B: pass
        class C: pass

        def f(x: A, y: B, z: C):
            o = x or y or z
            if isinstance(o, (A, B)):
                return 1
            return o
    """
    expected_type = "foo.C"
    self.assertReturns(source, expected_type)
3590
def test_union_no_arg_check(self):
    """Union-typed args are not checked on entry; the ``int`` return is cast."""
    codestr = """
        def f(x: int | str) -> int:
            return x
    """
    with self.in_module(codestr) as mod:
        func = mod["f"]
        # no arg check for the union, it's just dynamic
        self.assertInBytecode(func, "CHECK_ARGS", ())
        # so we do have to check the return value
        self.assertInBytecode(func, "CAST", ("builtins", "int"))
        # runtime type error comes from return, not argument
        with self.assertRaisesRegex(TypeError, "expected 'int', got 'list'"):
            func([])
3605
def test_error_return_int(self):
    """Returning a primitive ssize_t from an unannotated function fails to compile."""
    codestr = """
        from __static__ import ssize_t
        def f():
            y: ssize_t = 1
            return y
    """
    # Only the raised error matters; the compiled result was never used.
    with self.assertRaisesRegex(
        TypedSyntaxError, "type mismatch: int64 cannot be assigned to dynamic"
    ):
        self.compile(codestr, StaticCodeGenerator)
3619
def test_error_mixed_math(self):
    """Adding a boxed int to a literal and storing it as ssize_t fails to compile."""
    codestr = """
        from __static__ import ssize_t
        def f():
            y = 1
            x: ssize_t = 42 + y
    """
    # Only the raised error matters; the compiled result was never used.
    with self.assertRaises(TypedSyntaxError):
        self.compile(codestr, StaticCodeGenerator)
3631
def test_error_incompat_return(self):
    """Returning a D from a method annotated ``-> "C"`` fails to compile."""
    codestr = """
        class D: pass
        class C:
            def __init__(self):
                self.x = None

            def f(self) -> "C":
                return D()
    """
    # Only the raised error matters; the compiled result was never used.
    with self.assertRaises(TypedSyntaxError):
        self.compile(codestr, StaticCodeGenerator)
3646
3647 def test_cast(self):
3648 for code_gen in (StaticCodeGenerator, PythonCodeGenerator):
3649 codestr = """
3650 from __static__ import cast
3651 class C:
3652 pass
3653
3654 a = C()
3655
3656 def f() -> C:
3657 return cast(C, a)
3658 """
3659 code = self.compile(codestr, code_gen)
3660 f = self.find_code(code, "f")
3661 if code_gen is StaticCodeGenerator:
3662 self.assertInBytecode(f, "CAST", ("<module>", "C"))
3663 with self.in_module(codestr, code_gen=code_gen) as mod:
3664 C = mod["C"]
3665 f = mod["f"]
3666 self.assertTrue(isinstance(f(), C))
3667 self.assert_jitted(f)
3668
3669 def test_cast_fail(self):
3670 for code_gen in (StaticCodeGenerator, PythonCodeGenerator):
3671 codestr = """
3672 from __static__ import cast
3673 class C:
3674 pass
3675
3676 def f() -> C:
3677 return cast(C, 42)
3678 """
3679 code = self.compile(codestr, code_gen)
3680 f = self.find_code(code, "f")
3681 if code_gen is StaticCodeGenerator:
3682 self.assertInBytecode(f, "CAST", ("<module>", "C"))
3683 with self.in_module(codestr) as mod:
3684 with self.assertRaises(TypeError):
3685 f = mod["f"]
3686 f()
3687 self.assert_jitted(f)
3688
    def test_cast_wrong_args(self):
        """Calling `cast` with a single argument is a compile-time TypedSyntaxError."""
        codestr = """
            from __static__ import cast
            def f():
                cast(42)
        """
        with self.assertRaises(TypedSyntaxError):
            self.compile(codestr, StaticCodeGenerator)
3697
    def test_cast_unknown_type(self):
        """Casting to an undefined name (`abc`) is a compile-time TypedSyntaxError."""
        codestr = """
            from __static__ import cast
            def f():
                cast(abc, 42)
        """
        with self.assertRaises(TypedSyntaxError):
            self.compile(codestr, StaticCodeGenerator)
3706
3707 def test_cast_optional(self):
3708 for code_gen in (StaticCodeGenerator, PythonCodeGenerator):
3709 codestr = """
3710 from __static__ import cast
3711 from typing import Optional
3712
3713 class C:
3714 pass
3715
3716 def f(x) -> Optional[C]:
3717 return cast(Optional[C], x)
3718 """
3719 code = self.compile(codestr, code_gen)
3720 f = self.find_code(code, "f")
3721 if code_gen is StaticCodeGenerator:
3722 self.assertInBytecode(f, "CAST", ("<module>", "C", "?"))
3723 with self.in_module(codestr, code_gen=code_gen) as mod:
3724 C = mod["C"]
3725 f = mod["f"]
3726 self.assertTrue(isinstance(f(C()), C))
3727 self.assertEqual(f(None), None)
3728 self.assert_jitted(f)
3729
3730 def test_code_flags(self):
3731 codestr = """
3732 def func():
3733 print("hi")
3734
3735 func()
3736 """
3737 module = self.compile(codestr, StaticCodeGenerator)
3738 self.assertTrue(module.co_flags & CO_STATICALLY_COMPILED)
3739 self.assertTrue(
3740 self.find_code(module, name="func").co_flags & CO_STATICALLY_COMPILED
3741 )
3742
    def test_invoke_kws(self):
        """Passing a method's positional parameter by keyword (`a.f(a=2)`) returns the value."""
        codestr = """
            class C:
                def f(self, a):
                    return a

            def func():
                a = C()
                return a.f(a=2)

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertEqual(f(), 2)
3757
    def test_invoke_str_method(self):
        """`a.split()` on a str local is emitted as INVOKE_FUNCTION on builtins.str.split with count 1."""
        codestr = """
            def func():
                a = 'a b c'
                return a.split()

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertInBytecode(
                f, "INVOKE_FUNCTION", (("builtins", "str", "split"), 1)
            )
            self.assertEqual(f(), ["a", "b", "c"])
3771
    def test_invoke_str_method_arg(self):
        """`a.split('a')` with one positional argument is INVOKE_FUNCTION with count 2."""
        codestr = """
            def func():
                a = 'a b c'
                return a.split('a')

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertInBytecode(
                f, "INVOKE_FUNCTION", (("builtins", "str", "split"), 2)
            )
            self.assertEqual(f(), ["", " b c"])
3785
    def test_invoke_str_method_kwarg(self):
        """`a.split(sep='a')` (keyword argument) is emitted as neither INVOKE_FUNCTION nor INVOKE_METHOD."""
        codestr = """
            def func():
                a = 'a b c'
                return a.split(sep='a')

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertNotInBytecode(f, "INVOKE_FUNCTION")
            self.assertNotInBytecode(f, "INVOKE_METHOD")
            self.assertEqual(f(), ["", " b c"])
3798
    def test_invoke_int_method(self):
        """`a.bit_length()` on an int local is INVOKE_FUNCTION on builtins.int.bit_length."""
        codestr = """
            def func():
                a = 42
                return a.bit_length()

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertInBytecode(
                f, "INVOKE_FUNCTION", (("builtins", "int", "bit_length"), 1)
            )
            self.assertEqual(f(), 6)
3812
    def test_invoke_chkdict_method(self):
        """`keys()` on a CheckedDict[int, int] is INVOKE_METHOD on the specialized chkdict and gets JIT-compiled."""
        codestr = """
            from __static__ import CheckedDict
            def dict_maker() -> CheckedDict[int, int]:
                return CheckedDict[int, int]({2:2})
            def func():
                a = dict_maker()
                return a.keys()

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]

            # The method descriptor names the generic type together with its
            # type arguments.
            self.assertInBytecode(
                f,
                "INVOKE_METHOD",
                (
                    (
                        "__static__",
                        "chkdict",
                        (("builtins", "int"), ("builtins", "int")),
                        "keys",
                    ),
                    0,
                ),
            )
            self.assertEqual(list(f()), [2])
            self.assert_jitted(f)
3841
    def test_invoke_method_non_static_base(self):
        """`self.f()` works at runtime on a class that derives from `Exception`."""
        codestr = """
            class C(Exception):
                def f(self):
                    return 42

                def g(self):
                    return self.f()
        """

        with self.in_module(codestr) as mod:
            C = mod["C"]
            self.assertEqual(C().g(), 42)
3855
3856 def test_invoke_builtin_func(self):
3857 codestr = """
3858 from xxclassloader import foo
3859 from __static__ import int64, box
3860
3861 def func():
3862 a: int64 = foo()
3863 return box(a)
3864
3865 """
3866 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
3867 f = mod["func"]
3868 self.assertInBytecode(f, "INVOKE_FUNCTION", ((("xxclassloader", "foo"), 0)))
3869 self.assertEqual(f(), 42)
3870 self.assert_jitted(f)
3871
    def test_invoke_builtin_func_ret_neg(self):
        """`neg()` from xxclassloader, stored into an int64 local and boxed, yields -1."""
        # setup xxclassloader as a built-in function for this test, so we can
        # do a direct invoke
        xxclassloader = sys.modules["xxclassloader"]
        try:
            sys.modules["xxclassloader"] = StrictModule(xxclassloader.__dict__, False)
            codestr = """
                from xxclassloader import neg
                from __static__ import int64, box

                def test():
                    x: int64 = neg()
                    return box(x)
            """
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                test = mod["test"]
                self.assertEqual(test(), -1)
        finally:
            # Restore the original module object regardless of outcome.
            sys.modules["xxclassloader"] = xxclassloader
3891
3892 def test_invoke_builtin_func_arg(self):
3893 codestr = """
3894 from xxclassloader import bar
3895 from __static__ import int64, box
3896
3897 def func():
3898 a: int64 = bar(42)
3899 return box(a)
3900
3901 """
3902 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
3903 f = mod["func"]
3904 self.assertInBytecode(f, "INVOKE_FUNCTION", ((("xxclassloader", "bar"), 1)))
3905 self.assertEqual(f(), 42)
3906 self.assert_jitted(f)
3907
    def test_invoke_meth_o(self):
        """`setstate_untyped(42)` on `spamobj[int]` is INVOKE_METHOD with one argument, and the state round-trips."""
        codestr = """
            from xxclassloader import spamobj

            def func():
                a = spamobj[int]()
                a.setstate_untyped(42)
                return a.getstate()

        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]

            self.assertInBytecode(
                f,
                "INVOKE_METHOD",
                (
                    (
                        "xxclassloader",
                        "spamobj",
                        (("builtins", "int"),),
                        "setstate_untyped",
                    ),
                    1,
                ),
            )
            self.assertEqual(f(), 42)
            self.assert_jitted(f)
3936
    def test_multi_generic(self):
        """A class generic over two type parameters (`XXGeneric[int, str]`) can be instantiated and called."""
        codestr = """
            from xxclassloader import XXGeneric

            def func():
                a = XXGeneric[int, str]()
                return a.foo(42, 'abc')
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["func"]
            self.assertEqual(f(), "42abc")
3948
    def test_verify_positional_args(self):
        """Swapped positional argument types are an 'argument type mismatch' at compile time."""
        codestr = """
            def x(a: int, b: str) -> None:
                pass
            x("a", 2)
        """
        with self.assertRaisesRegex(TypedSyntaxError, "argument type mismatch"):
            self.compile(codestr, StaticCodeGenerator)
3957
    def test_verify_positional_args_unordered(self):
        """A call to a function defined later in the module compiles without error."""
        codestr = """
            def x(a: int, b: str) -> None:
                return y(a, b)
            def y(a: int, b: str) -> None:
                pass
        """
        self.compile(codestr, StaticCodeGenerator)
3966
    def test_verify_kwargs(self):
        """Correctly typed keyword arguments passed out of declaration order compile."""
        codestr = """
            def x(a: int=1, b: str="hunter2") -> None:
                return
            x(b="lol", a=23)
        """
        self.compile(codestr, StaticCodeGenerator)
3974
    def test_verify_kwdefaults(self):
        """An explicit value for a keyword-only parameter overrides its default."""
        codestr = """
            def x(*, b: str="hunter2"):
                return b
            z = x(b="lol")
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            self.assertEqual(mod["z"], "lol")
3983
    def test_verify_kwdefaults_no_value(self):
        """Omitting a keyword-only argument uses its default; the call stays a plain CALL_FUNCTION."""
        codestr = """
            def x(*, b: str="hunter2"):
                return b
            a = x()
        """
        module = self.compile(codestr, StaticCodeGenerator)
        # we don't yet support optimized dispatch to kw-only functions
        self.assertInBytecode(module, "CALL_FUNCTION")
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            self.assertEqual(mod["a"], "hunter2")
3995
3996 def test_verify_arg_dynamic_type(self):
3997 codestr = """
3998 def x(v:str):
3999 return 'abc'
4000 def y(v):
4001 return x(v)
4002 """
4003 module = self.compile(codestr, StaticCodeGenerator)
4004 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
4005 y = mod["y"]
4006 with self.assertRaises(TypeError):
4007 y(42)
4008 self.assertEqual(y("foo"), "abc")
4009
    def test_verify_arg_unknown_type(self):
        """A parameter annotated with an unknown name (`foo`) gets an empty CHECK_ARGS, yet calls still use INVOKE_FUNCTION."""
        codestr = """
            def x(x:foo):
                return b
            x('abc')
        """
        module = self.compile(codestr, StaticCodeGenerator)
        self.assertInBytecode(module, "INVOKE_FUNCTION")
        x = self.find_code(module)
        self.assertInBytecode(x, "CHECK_ARGS", ())
4020
4021 def test_dict_invoke(self):
4022 codestr = """
4023 from __static__ import pydict
4024 def f(x):
4025 y: pydict = x
4026 return y.get('foo')
4027 """
4028 with self.in_module(codestr) as mod:
4029 f = mod["f"]
4030 self.assertInBytecode(f, "INVOKE_METHOD", (("builtins", "dict", "get"), 1))
4031 self.assertEqual(f({}), None)
4032
4033 def test_dict_invoke_ret(self):
4034 codestr = """
4035 from __static__ import pydict
4036 def g(): return None
4037 def f(x):
4038 y: pydict = x
4039 z = y.get('foo')
4040 z = None # should be typed to dynamic
4041 return z
4042 """
4043 with self.in_module(codestr) as mod:
4044 f = mod["f"]
4045 self.assertInBytecode(f, "INVOKE_METHOD", (("builtins", "dict", "get"), 1))
4046 self.assertEqual(f({}), None)
4047
    def test_verify_kwarg_unknown_type(self):
        """Keyword variant of the unknown-annotation test: empty CHECK_ARGS, and the call still uses INVOKE_FUNCTION."""
        codestr = """
            def x(x:foo):
                return b
            x(x='abc')
        """
        module = self.compile(codestr, StaticCodeGenerator)
        self.assertInBytecode(module, "INVOKE_FUNCTION")
        x = self.find_code(module)
        self.assertInBytecode(x, "CHECK_ARGS", ())
4058
    def test_verify_kwdefaults_too_many(self):
        """A positional argument passed to a function whose only parameter is keyword-only is a compile error."""
        codestr = """
            def x(*, b: str="hunter2") -> None:
                return
            x('abc')
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "x takes 0 positional args but 1 was given"
        ):
            self.compile(codestr, StaticCodeGenerator)
4069
    def test_verify_kwdefaults_too_many_class(self):
        """Method variant: `self` counts as the single positional argument, so one extra is rejected."""
        codestr = """
            class C:
                def x(self, *, b: str="hunter2") -> None:
                    return
            C().x('abc')
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "x takes 1 positional args but 2 were given"
        ):
            self.compile(codestr, StaticCodeGenerator)
4081
    def test_verify_kwonly_failure(self):
        """A str passed to the keyword-only `a: int` is a 'keyword argument type mismatch'."""
        codestr = """
            def x(*, a: int=1, b: str="hunter2") -> None:
                return
            x(a="hi", b="lol")
        """
        with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"):
            self.compile(codestr, StaticCodeGenerator)
4090
4091 def test_verify_kwonly_self_loaded_once(self):
4092 codestr = """
4093 class C:
4094 def x(self, *, a: int=1) -> int:
4095 return 43
4096
4097 def f():
4098 return C().x(a=1)
4099 """
4100 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
4101 f = mod["f"]
4102 io = StringIO()
4103 dis.dis(f, file=io)
4104 self.assertEqual(1, io.getvalue().count("LOAD_GLOBAL"))
4105
    def test_call_function_unknown_ret_type(self):
        """A function whose return annotation names an unknown type (`foo`) can still be called."""
        codestr = """
            from __future__ import annotations
            def g() -> foo:
                return 42

            def testfunc():
                return g()
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["testfunc"]
            self.assertEqual(f(), 42)
4118
    def test_verify_kwargs_failure(self):
        """A str passed by keyword to `a: int` is a 'keyword argument type mismatch'."""
        codestr = """
            def x(a: int=1, b: str="hunter2") -> None:
                return
            x(a="hi", b="lol")
        """
        with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"):
            self.compile(codestr, StaticCodeGenerator)
4127
    def test_verify_mixed_args(self):
        """A positional argument combined with out-of-order keyword arguments compiles."""
        codestr = """
            def x(a: int=1, b: str="hunter2", c: int=14) -> None:
                return
            x(12, c=56, b="lol")
        """
        self.compile(codestr, StaticCodeGenerator)
4135
    def test_kwarg_cast(self):
        """A dynamic value passed to the `b: str` keyword parameter is CAST to str."""
        codestr = """
            def x(a: int=1, b: str="hunter2", c: int=14) -> None:
                return

            def g(a):
                x(b=a)
        """
        code = self.find_code(self.compile(codestr, StaticCodeGenerator), "g")
        self.assertInBytecode(code, "CAST", ("builtins", "str"))
4146
    def test_kwarg_nocast(self):
        """A str literal passed to the `b: str` keyword parameter needs no CAST."""
        codestr = """
            def x(a: int=1, b: str="hunter2", c: int=14) -> None:
                return

            def g():
                x(b='abc')
        """
        code = self.find_code(self.compile(codestr, StaticCodeGenerator), "g")
        self.assertNotInBytecode(code, "CAST", ("builtins", "str"))
4157
    def test_verify_mixed_args_kw_failure(self):
        """With mixed arguments, a str passed by keyword to `c: int` is a keyword type mismatch."""
        codestr = """
            def x(a: int=1, b: str="hunter2", c: int=14) -> None:
                return
            x(12, c="hi", b="lol")
        """
        with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"):
            self.compile(codestr, StaticCodeGenerator)
4166
    def test_verify_mixed_args_positional_failure(self):
        """With mixed arguments, a str passed positionally to `a: int` is a positional type mismatch."""
        codestr = """
            def x(a: int=1, b: str="hunter2", c: int=14) -> None:
                return
            x("hi", b="lol")
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "positional argument type mismatch"
        ):
            self.compile(codestr, StaticCodeGenerator)
4177
    # Same tests as above, but for methods.
    def test_verify_positional_args_method(self):
        """Correctly typed positional arguments to a method compile."""
        codestr = """
            class C:
                def x(self, a: int, b: str) -> None:
                    pass
            C().x(2, "hi")
        """
        self.compile(codestr, StaticCodeGenerator)
4187
    def test_verify_positional_args_failure_method(self):
        """Swapped positional argument types in a method call are a positional type mismatch."""
        codestr = """
            class C:
                def x(self, a: int, b: str) -> None:
                    pass
            C().x("a", 2)
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "positional argument type mismatch"
        ):
            self.compile(codestr, StaticCodeGenerator)
4199
    def test_verify_mixed_args_method(self):
        """A method call mixing positional and out-of-order keyword arguments compiles."""
        codestr = """
            class C:
                def x(self, a: int=1, b: str="hunter2", c: int=14) -> None:
                    return
            C().x(12, c=56, b="lol")
        """
        self.compile(codestr, StaticCodeGenerator)
4208
    def test_starargs_invoked_once(self):
        """The `**f()` expression in a module-level method call is evaluated exactly once."""
        # f() increments the module global X on each evaluation.
        codestr = """
            X = 0

            def f():
                global X
                X += 1
                return {"a": 1, "b": "foo", "c": 42}

            class C:
                def x(self, a: int=1, b: str="hunter2", c: int=14) -> None:
                    return
            C().x(12, **f())
        """
        with self.in_module(codestr) as mod:
            x = mod["X"]
            self.assertEqual(x, 1)
4226
    def test_starargs_invoked_in_order(self):
        """`c=make_c()` is evaluated before `**f()` in a method call.

        X starts at 1. make_c() then f() gives (1 * 2) + 1 == 3, whereas the
        reverse order would give (1 + 1) * 2 == 4.
        """
        codestr = """
            X = 1

            def f():
                global X
                X += 1
                return {"a": 1, "b": "foo"}

            def make_c():
                global X
                X *= 2
                return 42

            class C:
                def x(self, a: int=1, b: str="hunter2", c: int=14) -> None:
                    return

            def test():
                C().x(12, c=make_c(), **f())
        """
        with self.in_module(codestr) as mod:
            test = mod["test"]
            test()
            x = mod["X"]
            self.assertEqual(x, 3)
4253
    def test_verify_mixed_args_kw_failure_method(self):
        """Method variant: bytes passed by keyword to `c: int` is a keyword type mismatch."""
        codestr = """
            class C:
                def x(self, a: int=1, b: str="hunter2", c: int=14) -> None:
                    return
            C().x(12, c=b'lol', b="lol")
        """
        with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"):
            self.compile(codestr, StaticCodeGenerator)
4263
    def test_verify_mixed_args_positional_failure_method(self):
        """Method variant: a str passed positionally to `a: int` is a positional type mismatch."""
        codestr = """
            class C:
                def x(self, a: int=1, b: str="hunter2", c: int=14) -> None:
                    return
            C().x("hi", b="lol")
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "positional argument type mismatch"
        ):
            self.compile(codestr, StaticCodeGenerator)
4275
    def test_generic_kwargs_unsupported(self):
        """A call to a function that takes `**kwargs` falls back to CALL_FUNCTION_KW."""
        # definition is allowed, we just don't do an optimal invoke
        codestr = """
            def f(a: int, b: str, **my_stuff) -> None:
                pass

            def g():
                return f(1, 'abc', x="y")
        """
        with self.in_module(codestr) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "CALL_FUNCTION_KW", 3)
4288
    def test_generic_kwargs_method_unsupported(self):
        """A call to a method that takes `**kwargs` falls back to CALL_FUNCTION_KW."""
        # definition is allowed, we just don't do an optimal invoke
        codestr = """
            class C:
                def f(self, a: int, b: str, **my_stuff) -> None:
                    pass

            def g():
                return C().f(1, 'abc', x="y")
        """
        with self.in_module(codestr) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "CALL_FUNCTION_KW", 3)
4302
    def test_generic_varargs_unsupported(self):
        """A call to a function that takes `*args` falls back to CALL_FUNCTION."""
        # definition is allowed, we just don't do an optimal invoke
        codestr = """
            def f(a: int, b: str, *my_stuff) -> None:
                pass

            def g():
                return f(1, 'abc', "foo")
        """
        with self.in_module(codestr) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "CALL_FUNCTION", 3)
4315
    def test_generic_varargs_method_unsupported(self):
        """A call to a method that takes `*args` falls back to CALL_METHOD."""
        # definition is allowed, we just don't do an optimal invoke
        codestr = """
            class C:
                def f(self, a: int, b: str, *my_stuff) -> None:
                    pass

            def g():
                return C().f(1, 'abc', "foo")
        """
        with self.in_module(codestr) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "CALL_METHOD", 3)
4329
    def test_kwargs_get(self):
        """`.get()` on a `**kwargs` parameter is INVOKE_FUNCTION on builtins.dict.get."""
        codestr = """
            def test(**foo):
                print(foo.get('bar'))
        """

        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            test = mod["test"]
            self.assertInBytecode(
                test, "INVOKE_FUNCTION", (("builtins", "dict", "get"), 2)
            )
4341
    def test_varargs_count(self):
        """`.count()` on a `*args` parameter is INVOKE_FUNCTION on builtins.tuple.count."""
        codestr = """
            def test(*foo):
                print(foo.count('bar'))
        """

        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            test = mod["test"]
            self.assertInBytecode(
                test, "INVOKE_FUNCTION", (("builtins", "tuple", "count"), 2)
            )
4353
    def test_varargs_call(self):
        """Calling a `*args` function packs the positional arguments into a tuple."""
        codestr = """
            def g(*foo):
                return foo

            def testfunc():
                return g(2)
        """

        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            test = mod["testfunc"]
            self.assertEqual(test(), (2,))
4366
    def test_kwargs_call(self):
        """Calling a `**kwargs` function packs the keyword arguments into a dict."""
        codestr = """
            def g(**foo):
                return foo

            def testfunc():
                return g(x=2)
        """

        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            test = mod["testfunc"]
            self.assertEqual(test(), {"x": 2})
4379
    def test_index_by_int(self):
        """Indexing an untyped value with an int32 primitive is a compile-time TypedSyntaxError."""
        codestr = """
            from __static__ import int32
            def f(x):
                i: int32 = 0
                return x[i]
        """
        with self.assertRaises(TypedSyntaxError):
            self.compile(codestr, StaticCodeGenerator)
4389
    def test_pydict_arg_annotation(self):
        """The type arguments of a `PyDict[str, int]` annotation are not enforced at compile time or runtime."""
        codestr = """
            from __static__ import PyDict
            def f(d: PyDict[str, int]) -> str:
                # static python ignores the untrusted generic types
                return d[3]
        """
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"]({3: "foo"}), "foo")
4399
    def test_list_get_primitive_int(self):
        """Reading from a list with an int8 index compiles to SEQUENCE_GET with SEQ_LIST."""
        codestr = """
            from __static__ import int8
            def f():
                l = [1, 2, 3]
                x: int8 = 1
                return l[x]
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST)
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"](), 2)
4412
    def test_list_set_primitive_int(self):
        """Storing into a list with an int8 index compiles to SEQUENCE_SET."""
        codestr = """
            from __static__ import int8
            def f():
                l = [1, 2, 3]
                x: int8 = 1
                l[x] = 5
                return l
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "SEQUENCE_SET")
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"](), [1, 5, 3])
4426
    def test_list_set_primitive_int_2(self):
        """Storing into a `[None] * n` list with an int64 loop counter compiles to SEQUENCE_SET."""
        codestr = """
            from __static__ import int64
            def f(l1):
                l2 = [None] * len(l1)
                i: int64 = 0
                for item in l1:
                    l2[i] = item + 1
                    i += 1
                return l2
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "SEQUENCE_SET")
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"]([1, 2000]), [2, 2001])
4442
    def test_list_del_primitive_int(self):
        """`del l[x]` with an int8 index compiles to LIST_DEL."""
        codestr = """
            from __static__ import int8
            def f():
                l = [1, 2, 3]
                x: int8 = 1
                del l[x]
                return l
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "LIST_DEL")
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"](), [1, 3])
4456
    def test_list_append(self):
        """`l.append(4)` on a list-literal local compiles to LIST_APPEND."""
        codestr = """
            from __static__ import int8
            def f():
                l = [1, 2, 3]
                l.append(4)
                return l
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "LIST_APPEND", 1)
        with self.in_module(codestr) as mod:
            self.assertEqual(mod["f"](), [1, 2, 3, 4])
4469
    def test_unknown_type_unary(self):
        """Negating an untyped value uses the generic UNARY_NEGATIVE opcode."""
        codestr = """
            def x(y):
                z = -y
        """
        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "UNARY_NEGATIVE")
4477
    def test_unknown_type_binary(self):
        """Adding two untyped values uses the generic BINARY_ADD opcode."""
        codestr = """
            def x(a, b):
                z = a + b
        """
        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "BINARY_ADD")
4485
    def test_unknown_type_compare(self):
        """Comparing two untyped values uses the generic COMPARE_OP opcode."""
        codestr = """
            def x(a, b):
                z = a > b
        """
        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "COMPARE_OP")
4493
    def test_async_func_ret_type(self):
        """Returning an untyped value from an `async def ... -> int` emits a CAST."""
        codestr = """
            async def x(a) -> int:
                return a
        """
        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST")
4501
    def test_async_func_arg_types(self):
        """An async function's `x: int` parameter gets a CHECK_ARGS entry for argument 0."""
        codestr = """
            async def f(x: int):
                pass
        """
        f = self.find_code(self.compile(codestr))
        self.assertInBytecode(f, "CHECK_ARGS", (0, ("builtins", "int")))
4509
    def test_assign_prim_to_class(self):
        """Assigning an int64 primitive to a local declared as class `C` is a type mismatch."""
        codestr = """
            from __static__ import int64
            class C: pass

            def f():
                x: C = C()
                y: int64 = 42
                x = y
        """
        with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("int64", "foo.C")):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4522
    def test_field_refcount(self):
        """Overwriting and then dropping an untyped field releases the referenced objects."""
        codestr = """
            class C:
                def __init__(self):
                    self.x = None

                def set_x(self, x):
                    self.x = x
        """
        # Number of live X instances, maintained by X.__init__ / X.__del__.
        count = 0
        with self.in_module(codestr) as mod:
            C = mod["C"]

            class X:
                def __init__(self):
                    nonlocal count
                    count += 1

                def __del__(self):
                    nonlocal count
                    count -= 1

            c = C()
            c.set_x(X())
            # The second set_x replaces the first X, which must be finalized.
            c.set_x(X())
            self.assertEqual(count, 1)
            # Dropping `c` must release the X still stored in its field. This
            # relies on immediate refcount-driven finalization.
            del c
            self.assertEqual(count, 0)
4551
4552 def test_typed_field_del(self):
4553 codestr = """
4554 class D:
4555 def __init__(self, counter):
4556 self.counter = counter
4557 self.counter[0] += 1
4558
4559 def __del__(self):
4560 self.counter[0] -= 1
4561
4562 class C:
4563 def __init__(self, value: D):
4564 self.x: D = value
4565
4566 def __del__(self):
4567 del self.x
4568 """
4569 count = 0
4570 with self.in_module(codestr) as mod:
4571 D = mod["D"]
4572 counter = [0]
4573 d = D(counter)
4574
4575 C = mod["C"]
4576 a = C(d)
4577 del d
4578 self.assertEqual(counter[0], 1)
4579 del a
4580 self.assertEqual(counter[0], 0)
4581
4582 def test_typed_field_deleted_attr(self):
4583 codestr = """
4584 class C:
4585 def __init__(self, value: str):
4586 self.x: str = value
4587 """
4588 count = 0
4589 with self.in_module(codestr) as mod:
4590 C = mod["C"]
4591 a = C("abc")
4592 del a.x
4593 with self.assertRaises(AttributeError):
4594 a.x
4595
4596 def test_generic_method_ret_type(self):
4597 codestr = """
4598 from __static__ import CheckedDict
4599
4600 from typing import Optional
4601 MAP: CheckedDict[str, Optional[str]] = CheckedDict[str, Optional[str]]({'abc': 'foo', 'bar': None})
4602 def f(x: str) -> Optional[str]:
4603 return MAP.get(x)
4604 """
4605
4606 with self.in_module(codestr) as mod:
4607 f = mod["f"]
4608 self.assertInBytecode(
4609 f,
4610 "INVOKE_METHOD",
4611 (
4612 (
4613 "__static__",
4614 "chkdict",
4615 (("builtins", "str"), ("builtins", "str", "?")),
4616 "get",
4617 ),
4618 2,
4619 ),
4620 )
4621 self.assertEqual(f("abc"), "foo")
4622 self.assertEqual(f("bar"), None)
4623
4624
    def test_attr_generic_optional(self):
        """Attribute access through a bare (unparameterized) `Optional` annotation is a compile error."""
        codestr = """
            from typing import Optional
            def f(x: Optional):
                return x.foo
        """

        with self.assertRaisesRegex(
            TypedSyntaxError, "cannot access attribute from unbound Union"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4636
    def test_assign_generic_optional(self):
        """Assigning an int literal to a bare `Optional` local is an Exact[int] -> Optional[T] mismatch."""
        codestr = """
            from typing import Optional
            def f():
                x: Optional = 42
        """

        with self.assertRaisesRegex(
            TypedSyntaxError, type_mismatch("Exact[int]", "Optional[T]")
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4648
    def test_assign_generic_optional_2(self):
        """Assigning an int expression (`42 + 1`) to a bare `Optional` local is rejected."""
        codestr = """
            from typing import Optional
            def f():
                x: Optional = 42 + 1
        """

        with self.assertRaises(TypedSyntaxError):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4658
    def test_assign_from_generic_optional(self):
        """A value typed as bare `Optional` cannot be assigned to an `Optional[C]` local."""
        codestr = """
            from typing import Optional
            class C: pass

            def f(x: Optional):
                y: Optional[C] = x
        """

        with self.assertRaisesRegex(
            TypedSyntaxError, type_mismatch("Optional[T]", optional("foo.C"))
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4672
    def test_list_of_dynamic(self):
        """`len()` of a `List[Thread]` parameter (element type from a non-static module) uses FAST_LEN."""
        codestr = """
            from threading import Thread
            from typing import List

            def f(threads: List[Thread]) -> int:
                return len(threads)
        """
        f = self.find_code(self.compile(codestr), "f")
        self.assertInBytecode(f, "FAST_LEN")
4683
    def test_int_swap(self):
        """Tuple-swapping two int64 primitives (`x, y = y, x`) is an int64 -> dynamic type error."""
        codestr = """
            from __static__ import int64, box

            def test():
                x: int64 = 42
                y: int64 = 100
                x, y = y, x
                return box(x), box(y)
        """

        with self.assertRaisesRegex(
            TypedSyntaxError, type_mismatch("int64", "dynamic")
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4699
    def test_typed_swap(self):
        """Unpacking `1, a` into int/str locals CASTs only the dynamic element to str."""
        codestr = """
            def test(a):
                x: int
                y: str
                x, y = 1, a
        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "str"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "int"))
4711
    def test_typed_swap_2(self):
        """Unpacking `a, 'abc'` into int/str locals CASTs only the dynamic element to int."""
        codestr = """
            def test(a):
                x: int
                y: str
                x, y = a, 'abc'

        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "int"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "str"))
4724
    def test_typed_swap_member(self):
        """Unpacking into a typed attribute target (`C().x`) CASTs the dynamic value to int."""
        codestr = """
            class C:
                def __init__(self):
                    self.x: int = 42

            def test(a):
                x: int
                y: str
                C().x, y = a, 'abc'

        """

        f = self.find_code(
            self.compile(codestr, StaticCodeGenerator, modname="foo"), "test"
        )
        self.assertInBytecode(f, "CAST", ("builtins", "int"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "str"))
4743
    def test_typed_swap_list(self):
        """A list-shaped unpacking target (`[x, y] = ...`) CASTs only the dynamic element."""
        codestr = """
            def test(a):
                x: int
                y: str
                [x, y] = a, 'abc'
        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "int"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "str"))
4755
    def test_typed_swap_nested(self):
        """Nested unpacking CASTs only the dynamic `a` (to int); str literals need no CAST."""
        codestr = """
            def test(a):
                x: int
                y: str
                z: str
                ((x, y), z) = (a, 'abc'), 'foo'
        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "int"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "str"))
4768
    def test_typed_swap_nested_2(self):
        """Nested unpacking CASTs only the dynamic `a` (to str); the int literal needs no CAST."""
        codestr = """
            def test(a):
                x: int
                y: str
                z: str
                ((x, y), z) = (1, a), 'foo'

        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "str"))
        self.assertNotInBytecode(f, "CAST", ("builtins", "int"))
4782
    def test_typed_swap_nested_3(self):
        """Nested unpacking with a constant inner tuple: `a` is CAST to str, and the ints are currently CAST too."""
        codestr = """
            def test(a):
                x: int
                y: int
                z: str
                ((x, y), z) = (1, 2), a

        """

        f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo"))
        self.assertInBytecode(f, "CAST", ("builtins", "str"))
        # Currently because the tuple gets turned into a constant this is less than
        # ideal:
        self.assertInBytecode(f, "CAST", ("builtins", "int"))
4798
    def test_if_optional(self):
        """Under `if x is not None`, attribute access on an Optional[C] compiles."""
        codestr = """
            from typing import Optional
            class C:
                def __init__(self):
                    self.field = 42

            def f(x: Optional[C]):
                if x is not None:
                    return x.field

                return None
        """

        self.compile(codestr, StaticCodeGenerator, modname="foo")
4814
    def test_return_outside_func(self):
        """A module-level `return` is reported as a SyntaxError."""
        codestr = """
            return 42
        """
        with self.assertRaisesRegex(SyntaxError, "'return' outside function"):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
4821
4822 def test_double_decl(self):
4823 codestr = """
4824 def f():
4825 x: int
4826 x: str
4827 """
4828 with self.assertRaisesRegex(
4829 TypedSyntaxError, "Cannot redefine local variable x"
4830 ):
4831 self.compile(codestr, StaticCodeGenerator, modname="foo")
4832
4833 codestr = """
4834 def f():
4835 x = 42
4836 x: str
4837 """
4838 with self.assertRaisesRegex(
4839 TypedSyntaxError, "Cannot redefine local variable x"
4840 ):
4841 self.compile(codestr, StaticCodeGenerator, modname="foo")
4842
4843 codestr = """
4844 def f():
4845 x = 42
4846 x: int
4847 """
4848 with self.assertRaisesRegex(
4849 TypedSyntaxError, "Cannot redefine local variable x"
4850 ):
4851 self.compile(codestr, StaticCodeGenerator, modname="foo")
4852
    def test_if_else_optional(self):
        """Nested `is None` checks, including one after reassignment, narrow Optional[C] so `g(x)` compiles."""
        codestr = """
            from typing import Optional
            class C:
                def __init__(self):
                    self.field = self

            def g(x: C):
                pass

            def f(x: Optional[C], y: Optional[C]):
                if x is None:
                    x = y
                    if x is None:
                        return None
                    else:
                        return g(x)
                else:
                    return g(x)


                return None
        """

        self.compile(codestr, StaticCodeGenerator, modname="foo")
4878
4879 def test_if_else_optional_return(self):
4880 codestr = """
4881 from typing import Optional
4882 class C:
4883 def __init__(self):
4884 self.field = self
4885
4886 def f(x: Optional[C]):
4887 if x is None:
4888 return 0
4889 return x.field
4890 """
4891
4892 self.compile(codestr, StaticCodeGenerator, modname="foo")
4893
4894 def test_if_else_optional_return_two_branches(self):
4895 codestr = """
4896 from typing import Optional
4897 class C:
4898 def __init__(self):
4899 self.field = self
4900
4901 def f(x: Optional[C]):
4902 if x is None:
4903 if a:
4904 return 0
4905 else:
4906 return 2
4907 return x.field
4908 """
4909
4910 self.compile(codestr, StaticCodeGenerator, modname="foo")
4911
4912 def test_if_else_optional_return_in_else(self):
4913 codestr = """
4914 from typing import Optional
4915
4916 def f(x: Optional[int]) -> int:
4917 if x is not None:
4918 pass
4919 else:
4920 return 2
4921 return x
4922 """
4923
4924 self.compile(codestr, StaticCodeGenerator, modname="foo")
4925
4926 def test_if_else_optional_return_in_else_assignment_in_if(self):
4927 codestr = """
4928 from typing import Optional
4929
4930 def f(x: Optional[int]) -> int:
4931 if x is None:
4932 x = 1
4933 else:
4934 return 2
4935 return x
4936 """
4937
4938 self.compile(codestr, StaticCodeGenerator, modname="foo")
4939
4940 def test_if_else_optional_return_in_if_assignment_in_else(self):
4941 codestr = """
4942 from typing import Optional
4943
4944 def f(x: Optional[int]) -> int:
4945 if x is not None:
4946 return 2
4947 else:
4948 x = 1
4949 return x
4950 """
4951
4952 self.compile(codestr, StaticCodeGenerator, modname="foo")
4953
4954 def test_narrow_conditional(self):
4955 codestr = """
4956 class B:
4957 def f(self):
4958 return 42
4959 class D(B):
4960 def f(self):
4961 return 'abc'
4962
4963 def testfunc(abc):
4964 x = B()
4965 if abc:
4966 x = D()
4967 return x.f()
4968 """
4969
4970 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
4971 f = self.find_code(code, "testfunc")
4972 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0))
4973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
4974 test = mod["testfunc"]
4975 self.assertEqual(test(True), "abc")
4976 self.assertEqual(test(False), None)
4977
4978 def test_no_narrow_to_dynamic(self):
4979 codestr = """
4980 def f():
4981 return 42
4982
4983 def g():
4984 x: int = 100
4985 x = f()
4986 return x.bit_length()
4987 """
4988
4989 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
4990 g = mod["g"]
4991 self.assertInBytecode(g, "CAST", ("builtins", "int"))
4992 self.assertInBytecode(
4993 g, "INVOKE_METHOD", (("builtins", "int", "bit_length"), 0)
4994 )
4995 self.assertEqual(g(), 6)
4996
4997 def test_global_uses_decl_type(self):
4998 codestr = """
4999 # even though we can locally infer G must be None,
5000 # it's not Final so nested scopes can't assume it
5001 # remains None
5002 G: int | None = None
5003
5004 def f() -> int:
5005 global G
5006 # if we use the local_type for G's type,
5007 # x would have a local type of None
5008 x: int | None = G
5009 if x is None:
5010 x = G = 1
5011 return x
5012 """
5013 with self.in_strict_module(codestr) as mod:
5014 self.assertEqual(mod.f(), 1)
5015
5016 def test_module_level_type_narrow(self):
5017 codestr = """
5018 def a() -> int | None:
5019 return 1
5020
5021 G = a()
5022 if G is not None:
5023 G += 1
5024
5025 def f() -> int:
5026 if G is None:
5027 return 0
5028 reveal_type(G)
5029 """
5030 with self.assertRaisesRegex(TypedSyntaxError, r"Optional\[int\]"):
5031 self.compile(codestr)
5032
5033 def test_narrow_conditional_widened(self):
5034 codestr = """
5035 class B:
5036 def f(self):
5037 return 42
5038 class D(B):
5039 def f(self):
5040 return 'abc'
5041
5042 def testfunc(abc):
5043 x = B()
5044 if abc:
5045 x = D()
5046 return x.f()
5047 """
5048
5049 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5050 f = self.find_code(code, "testfunc")
5051 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5052 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5053 test = mod["testfunc"]
5054 self.assertEqual(test(True), "abc")
5055 self.assertEqual(test(False), 42)
5056
5057 def test_widen_to_dynamic(self):
5058 self.assertReturns(
5059 """
5060 def f(x, flag):
5061 if flag:
5062 x = 3
5063 return x
5064 """,
5065 "dynamic",
5066 )
5067
5068 def test_assign_conditional_both_sides(self):
5069 codestr = """
5070 class B:
5071 def f(self):
5072 return 42
5073 class D(B):
5074 def f(self):
5075 return 'abc'
5076
5077 def testfunc(abc):
5078 x = B()
5079 if abc:
5080 x = D()
5081 else:
5082 x = D()
5083 return x.f()
5084 """
5085
5086 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5087 f = self.find_code(code, "testfunc")
5088 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0))
5089 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5090 test = mod["testfunc"]
5091 self.assertEqual(test(True), "abc")
5092
5093 def test_assign_conditional_invoke_in_else(self):
5094 codestr = """
5095 class B:
5096 def f(self):
5097 return 42
5098 class D(B):
5099 def f(self):
5100 return 'abc'
5101
5102 def testfunc(abc):
5103 x = B()
5104 if abc:
5105 x = D()
5106 else:
5107 return x.f()
5108 """
5109
5110 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5111 f = self.find_code(code, "testfunc")
5112 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5113 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5114 test = mod["testfunc"]
5115 self.assertEqual(test(True), None)
5116
5117 def test_assign_else_only(self):
5118 codestr = """
5119 class B:
5120 def f(self):
5121 return 42
5122 class D(B):
5123 def f(self):
5124 return 'abc'
5125
5126 def testfunc(abc):
5127 if abc:
5128 pass
5129 else:
5130 x = B()
5131 return x.f()
5132 """
5133
5134 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5135 f = self.find_code(code, "testfunc")
5136 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5137 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5138 test = mod["testfunc"]
5139 self.assertEqual(test(False), 42)
5140
5141 def test_assign_test_var(self):
5142 codestr = """
5143 from typing import Optional
5144
5145 def f(x: Optional[int]) -> int:
5146 if x is None:
5147 x = 1
5148 return x
5149 """
5150
5151 self.compile(codestr, StaticCodeGenerator, modname="foo")
5152
5153 def test_assign_while(self):
5154 codestr = """
5155 class B:
5156 def f(self):
5157 return 42
5158 class D(B):
5159 def f(self):
5160 return 'abc'
5161
5162 def testfunc(abc):
5163 x = B()
5164 while abc:
5165 x = D()
5166 return x.f()
5167 """
5168
5169 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5170 f = self.find_code(code, "testfunc")
5171 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5172 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5173 test = mod["testfunc"]
5174 self.assertEqual(test(False), 42)
5175
5176 def test_assign_while_test_var(self):
5177 codestr = """
5178 from typing import Optional
5179
5180 def f(x: Optional[int]) -> int:
5181 while x is None:
5182 x = 1
5183 return x
5184 """
5185 self.compile(codestr, StaticCodeGenerator, modname="foo")
5186
5187 def test_assign_while_returns(self):
5188 codestr = """
5189 from typing import Optional
5190
5191 def f(x: Optional[int]) -> int:
5192 while x is None:
5193 return 1
5194 return x
5195 """
5196 self.compile(codestr, StaticCodeGenerator, modname="foo")
5197
5198 def test_assign_while_returns_but_assigns_first(self):
5199 codestr = """
5200 from typing import Optional
5201
5202 def f(x: Optional[int]) -> int:
5203 y: Optional[int] = 1
5204 while x is None:
5205 y = None
5206 return 1
5207 return y
5208 """
5209 self.compile(codestr, StaticCodeGenerator, modname="foo")
5210
5211 def test_while_else_reverses_condition(self):
5212 codestr = """
5213 from typing import Optional
5214
5215 def f(x: Optional[int]) -> int:
5216 while x is None:
5217 pass
5218 else:
5219 return x
5220 return 1
5221 """
5222 self.compile(codestr, StaticCodeGenerator, modname="foo")
5223
5224 def test_continue_condition(self):
5225 codestr = """
5226 from typing import Optional
5227
5228 def f(x: Optional[str]) -> str:
5229 while True:
5230 if x is None:
5231 continue
5232 return x
5233 """
5234 self.compile(codestr, StaticCodeGenerator, modname="foo")
5235
5236 def test_break_condition(self):
5237 codestr = """
5238 from typing import Optional
5239
5240 def f(x: Optional[str]) -> str:
5241 while True:
5242 if x is None:
5243 break
5244 return x
5245 """
5246 self.compile(codestr, StaticCodeGenerator, modname="foo")
5247
5248 def test_assign_but_annotated(self):
5249 codestr = """
5250 class B:
5251 def f(self):
5252 return 42
5253 class D(B):
5254 def f(self):
5255 return 'abc'
5256
5257 def testfunc(abc):
5258 x: B = D()
5259 return x.f()
5260 """
5261
5262 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5263 f = self.find_code(code, "testfunc")
5264 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0))
5265 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5266 test = mod["testfunc"]
5267 self.assertEqual(test(False), "abc")
5268
5269 def test_assign_while_2(self):
5270 codestr = """
5271 class B:
5272 def f(self):
5273 return 42
5274 class D(B):
5275 def f(self):
5276 return 'abc'
5277
5278 def testfunc(abc):
5279 x: B = D()
5280 while abc:
5281 x = B()
5282 return x.f()
5283 """
5284
5285 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5286 f = self.find_code(code, "testfunc")
5287 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5288 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5289 test = mod["testfunc"]
5290 self.assertEqual(test(False), "abc")
5291
5292 def test_assign_while_else(self):
5293 codestr = """
5294 class B:
5295 def f(self):
5296 return 42
5297 class D(B):
5298 def f(self):
5299 return 'abc'
5300
5301 def testfunc(abc):
5302 x = B()
5303 while abc:
5304 pass
5305 else:
5306 x = D()
5307 return x.f()
5308 """
5309
5310 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5311 f = self.find_code(code, "testfunc")
5312 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5313 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5314 test = mod["testfunc"]
5315 self.assertEqual(test(False), "abc")
5316
5317 def test_assign_while_else_2(self):
5318 codestr = """
5319 class B:
5320 def f(self):
5321 return 42
5322 class D(B):
5323 def f(self):
5324 return 'abc'
5325
5326 def testfunc(abc):
5327 x: B = D()
5328 while abc:
5329 pass
5330 else:
5331 x = B()
5332 return x.f()
5333 """
5334
5335 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5336 f = self.find_code(code, "testfunc")
5337 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5338 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5339 test = mod["testfunc"]
5340 self.assertEqual(test(False), 42)
5341
5342 def test_assign_try_except_no_initial(self):
5343 codestr = """
5344 class B:
5345 def f(self):
5346 return 42
5347 class D(B):
5348 def f(self):
5349 return 'abc'
5350
5351 def testfunc():
5352 try:
5353 x: B = D()
5354 except:
5355 x = B()
5356 return x.f()
5357 """
5358
5359 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5360 f = self.find_code(code, "testfunc")
5361 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5362 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5363 test = mod["testfunc"]
5364 self.assertEqual(test(), "abc")
5365
5366 def test_narrow_or(self):
5367 codestr = """
5368 def f(x: int | None) -> int:
5369 if x is None or x > 1:
5370 x = 1
5371 return x
5372 """
5373 self.compile(codestr)
5374
5375 def test_type_of_or(self):
5376 codestr = """
5377 def f(x: int, y: str) -> int | str:
5378 return x or y
5379 """
5380 self.compile(codestr)
5381
5382 def test_none_annotation(self):
5383 codestr = """
5384 from typing import Optional
5385
5386 def f(x: Optional[int]) -> None:
5387 return x
5388 """
5389 with self.assertRaisesRegex(
5390 TypedSyntaxError,
5391 type_mismatch("Optional[int]", "None"),
5392 ):
5393 self.compile(codestr, StaticCodeGenerator, modname="foo")
5394
5395 def test_none_compare(self):
5396 codestr = """
5397 def f(x: int | None):
5398 if x > 1:
5399 x = 1
5400 return x
5401 """
5402 with self.assertRaisesRegex(
5403 TypedSyntaxError,
5404 r"'>' not supported between 'Optional\[int\]' and 'Exact\[int\]'",
5405 ):
5406 self.compile(codestr)
5407
5408 def test_none_compare_reverse(self):
5409 codestr = """
5410 def f(x: int | None):
5411 if 1 > x:
5412 x = 1
5413 return x
5414 """
5415 with self.assertRaisesRegex(
5416 TypedSyntaxError,
5417 r"'>' not supported between 'Exact\[int\]' and 'Optional\[int\]'",
5418 ):
5419 self.compile(codestr)
5420
5421 def test_union_compare(self):
5422 codestr = """
5423 def f(x: int | float) -> bool:
5424 return x > 0
5425 """
5426 with self.in_strict_module(codestr) as mod:
5427 self.assertEqual(mod.f(3), True)
5428 self.assertEqual(mod.f(3.1), True)
5429 self.assertEqual(mod.f(-3), False)
5430 self.assertEqual(mod.f(-3.1), False)
5431
5432 def test_global_int(self):
5433 codestr = """
5434 X: int = 60 * 60 * 24
5435 """
5436
5437 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5438 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5439 X = mod["X"]
5440 self.assertEqual(X, 60 * 60 * 24)
5441
5442 def test_with_traceback(self):
5443 codestr = """
5444 def f():
5445 x = Exception()
5446 return x.with_traceback(None)
5447 """
5448
5449 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5450 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5451 f = mod["f"]
5452 self.assertEqual(type(f()), Exception)
5453 self.assertInBytecode(
5454 f, "INVOKE_METHOD", (("builtins", "BaseException", "with_traceback"), 1)
5455 )
5456
5457 def test_assign_num_to_object(self):
5458 codestr = """
5459 def f():
5460 x: object = 42
5461 """
5462
5463 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5464 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5465 f = mod["f"]
5466 self.assertNotInBytecode(f, "CAST", ("builtins", "object"))
5467
5468 def test_assign_num_to_dynamic(self):
5469 codestr = """
5470 def f():
5471 x: foo = 42
5472 """
5473
5474 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5475 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5476 f = mod["f"]
5477 self.assertNotInBytecode(f, "CAST", ("builtins", "object"))
5478
5479 def test_assign_dynamic_to_object(self):
5480 codestr = """
5481 def f(C):
5482 x: object = C()
5483 """
5484
5485 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5486 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5487 f = mod["f"]
5488 self.assertNotInBytecode(f, "CAST", ("builtins", "object"))
5489
5490 def test_assign_dynamic_to_dynamic(self):
5491 codestr = """
5492 def f(C):
5493 x: unknown = C()
5494 """
5495
5496 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5497 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5498 f = mod["f"]
5499 self.assertNotInBytecode(f, "CAST", ("builtins", "object"))
5500
5501 def test_assign_constant_to_object(self):
5502 codestr = """
5503 def f():
5504 x: object = 42 + 1
5505 """
5506
5507 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5508 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5509 f = mod["f"]
5510 self.assertNotInBytecode(f, "CAST", ("builtins", "object"))
5511
5512 def test_assign_try_except_typing(self):
5513 codestr = """
5514 def testfunc():
5515 try:
5516 pass
5517 except Exception as e:
5518 pass
5519 return 42
5520 """
5521
5522 # We don't do anything special w/ Exception type yet, but it should compile
5523 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5524
5525 def test_assign_try_except_typing_predeclared(self):
5526 codestr = """
5527 def testfunc():
5528 e: Exception
5529 try:
5530 pass
5531 except Exception as e:
5532 pass
5533 return 42
5534 """
5535 # We don't do anything special w/ Exception type yet, but it should compile
5536 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5537
5538 def test_assign_try_except_typing_narrowed(self):
5539 codestr = """
5540 class E(Exception):
5541 pass
5542
5543 def testfunc():
5544 e: Exception
5545 try:
5546 pass
5547 except E as e:
5548 pass
5549 return 42
5550 """
5551 # We don't do anything special w/ Exception type yet, but it should compile
5552 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5553
5554 def test_assign_try_except_typing_redeclared_after(self):
5555 codestr = """
5556 def testfunc():
5557 try:
5558 pass
5559 except Exception as e:
5560 pass
5561 e: int = 42
5562 return 42
5563 """
5564 # We don't do anything special w/ Exception type yet, but it should compile
5565 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5566
5567 def test_assign_try_except_redeclare(self):
5568 codestr = """
5569 def testfunc():
5570 e: int
5571 try:
5572 pass
5573 except Exception as e:
5574 pass
5575 return 42
5576 """
5577
5578 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5579
5580 def test_assign_try_except_redeclare_unknown_type(self):
5581 codestr = """
5582 def testfunc():
5583 e: int
5584 try:
5585 pass
5586 except UnknownException as e:
5587 pass
5588 return 42
5589 """
5590
5591 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5592
5593 def test_assign_try_assign_in_except(self):
5594 codestr = """
5595 class B:
5596 def f(self):
5597 return 42
5598 class D(B):
5599 def f(self):
5600 return 'abc'
5601
5602 def testfunc():
5603 x: B = D()
5604 try:
5605 pass
5606 except:
5607 x = B()
5608 return x.f()
5609 """
5610
5611 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5612 f = self.find_code(code, "testfunc")
5613 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5614 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5615 test = mod["testfunc"]
5616 self.assertEqual(test(), "abc")
5617
5618 def test_assign_try_assign_in_second_except(self):
5619 codestr = """
5620 class B:
5621 def f(self):
5622 return 42
5623 class D(B):
5624 def f(self):
5625 return 'abc'
5626
5627 def testfunc():
5628 x: B = D()
5629 try:
5630 pass
5631 except TypeError:
5632 pass
5633 except:
5634 x = B()
5635 return x.f()
5636 """
5637
5638 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5639 f = self.find_code(code, "testfunc")
5640 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5641 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5642 test = mod["testfunc"]
5643 self.assertEqual(test(), "abc")
5644
5645 def test_assign_try_assign_in_except_with_var(self):
5646 codestr = """
5647 class B:
5648 def f(self):
5649 return 42
5650 class D(B):
5651 def f(self):
5652 return 'abc'
5653
5654 def testfunc():
5655 x: B = D()
5656 try:
5657 pass
5658 except TypeError as e:
5659 x = B()
5660 return x.f()
5661 """
5662
5663 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5664 f = self.find_code(code, "testfunc")
5665 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5666 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5667 test = mod["testfunc"]
5668 self.assertEqual(test(), "abc")
5669
5670 def test_try_except_finally(self):
5671 codestr = """
5672 class B:
5673 def f(self):
5674 return 42
5675 class D(B):
5676 def f(self):
5677 return 'abc'
5678
5679 def testfunc():
5680 x: B = D()
5681 try:
5682 pass
5683 except TypeError:
5684 pass
5685 finally:
5686 x = B()
5687 return x.f()
5688 """
5689
5690 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5691 f = self.find_code(code, "testfunc")
5692 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5693 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5694 test = mod["testfunc"]
5695 self.assertEqual(test(), 42)
5696
5697 def test_assign_try_assign_in_try(self):
5698 codestr = """
5699 class B:
5700 def f(self):
5701 return 42
5702 class D(B):
5703 def f(self):
5704 return 'abc'
5705
5706 def testfunc():
5707 x: B = D()
5708 try:
5709 x = B()
5710 except:
5711 pass
5712 return x.f()
5713 """
5714
5715 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5716 f = self.find_code(code, "testfunc")
5717 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5718 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5719 test = mod["testfunc"]
5720 self.assertEqual(test(), 42)
5721
5722 def test_assign_try_assign_in_finally(self):
5723 codestr = """
5724 class B:
5725 def f(self):
5726 return 42
5727 class D(B):
5728 def f(self):
5729 return 'abc'
5730
5731 def testfunc():
5732 x: B = D()
5733 try:
5734 pass
5735 finally:
5736 x = B()
5737 return x.f()
5738 """
5739
5740 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5741 f = self.find_code(code, "testfunc")
5742 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5743 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5744 test = mod["testfunc"]
5745 self.assertEqual(test(), 42)
5746
5747 def test_assign_try_assign_in_else(self):
5748 codestr = """
5749 class B:
5750 def f(self):
5751 return 42
5752 class D(B):
5753 def f(self):
5754 return 'abc'
5755
5756 def testfunc():
5757 x: B = D()
5758 try:
5759 pass
5760 except:
5761 pass
5762 else:
5763 x = B()
5764 return x.f()
5765 """
5766
5767 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5768 f = self.find_code(code, "testfunc")
5769 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0))
5770 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5771 test = mod["testfunc"]
5772 self.assertEqual(test(), 42)
5773
5774 def test_if_optional_reassign(self):
5775 codestr = """
5776 class C: pass
5777
5778 def testfunc(abc: Optional[C]):
5779 if abc is not None:
5780 abc = None
5781 """
5782 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5783
5784 def test_widening_assign(self):
5785 codestr = """
5786 from __static__ import int8, int16, box
5787
5788 def testfunc():
5789 x: int16
5790 y: int8
5791 x = y = 42
5792 return box(x), box(y)
5793 """
5794 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5795 test = mod["testfunc"]
5796 self.assertEqual(test(), (42, 42))
5797
5798 def test_unknown_imported_annotation(self):
5799 codestr = """
5800 from unknown_mod import foo
5801
5802 def testfunc():
5803 x: foo = 42
5804 return x
5805 """
5806 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
5807
5808 def test_widening_assign_reassign(self):
5809 codestr = """
5810 from __static__ import int8, int16, box
5811
5812 def testfunc():
5813 x: int16
5814 y: int8
5815 x = y = 42
5816 x = 257
5817 return box(x), box(y)
5818 """
5819 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5820 test = mod["testfunc"]
5821 self.assertEqual(test(), (257, 42))
5822
5823 def test_widening_assign_reassign_error(self):
5824 codestr = """
5825 from __static__ import int8, int16, box
5826
5827 def testfunc():
5828 x: int16
5829 y: int8
5830 x = y = 42
5831 y = 128
5832 return box(x), box(y)
5833 """
5834 with self.assertRaisesRegex(
5835 TypedSyntaxError,
5836 "constant 128 is outside of the range -128 to 127 for int8",
5837 ):
5838 self.compile(codestr, StaticCodeGenerator, modname="foo")
5839
5840 def test_narrowing_assign_literal(self):
5841 codestr = """
5842 from __static__ import int8, int16, box
5843
5844 def testfunc():
5845 x: int8
5846 y: int16
5847 x = y = 42
5848 return box(x), box(y)
5849 """
5850 self.compile(codestr, StaticCodeGenerator, modname="foo")
5851
5852 def test_narrowing_assign_out_of_range(self):
5853 codestr = """
5854 from __static__ import int8, int16, box
5855
5856 def testfunc():
5857 x: int8
5858 y: int16
5859 x = y = 300
5860 return box(x), box(y)
5861 """
5862 with self.assertRaisesRegex(
5863 TypedSyntaxError,
5864 "constant 300 is outside of the range -128 to 127 for int8",
5865 ):
5866 self.compile(codestr, StaticCodeGenerator, modname="foo")
5867
5868 def test_module_primitive(self):
5869 codestr = """
5870 from __static__ import int8
5871 x: int8
5872 """
5873 with self.assertRaisesRegex(
5874 TypedSyntaxError, "cannot use primitives in global or closure scope"
5875 ):
5876 self.compile(codestr, StaticCodeGenerator, modname="foo")
5877
5878 def test_implicit_module_primitive(self):
5879 codestr = """
5880 from __static__ import int8
5881 x = y = int8(0)
5882 """
5883 with self.assertRaisesRegex(
5884 TypedSyntaxError, "cannot use primitives in global or closure scope"
5885 ):
5886 self.compile(codestr, StaticCodeGenerator, modname="foo")
5887
5888 def test_chained_primitive_to_non_primitive(self):
5889 codestr = """
5890 from __static__ import int8
5891 def f():
5892 x: object
5893 y: int8 = 42
5894 x = y = 42
5895 """
5896 with self.assertRaisesRegex(
5897 TypedSyntaxError, "int8 cannot be assigned to object"
5898 ):
5899 self.compile(codestr, StaticCodeGenerator, modname="foo")
5900
5901 def test_closure_primitive(self):
5902 codestr = """
5903 from __static__ import int8
5904 def f():
5905 x: int8 = 0
5906 def g():
5907 return x
5908 """
5909 with self.assertRaisesRegex(
5910 TypedSyntaxError, "cannot use primitives in global or closure scope"
5911 ):
5912 self.compile(codestr, StaticCodeGenerator, modname="foo")
5913
5914 def test_nonlocal_primitive(self):
5915 codestr = """
5916 from __static__ import int8
5917 def f():
5918 x: int8 = 0
5919 def g():
5920 nonlocal x
5921 x = 1
5922 """
5923 with self.assertRaisesRegex(
5924 TypedSyntaxError, "cannot use primitives in global or closure scope"
5925 ):
5926 self.compile(codestr, StaticCodeGenerator, modname="foo")
5927
5928 def test_dynamic_chained_assign_param(self):
5929 codestr = """
5930 from __static__ import int16
5931 def testfunc(y):
5932 x: int16
5933 x = y = 42
5934 return box(x)
5935 """
5936 with self.assertRaisesRegex(
5937 TypedSyntaxError, type_mismatch("Exact[int]", "int16")
5938 ):
5939 self.compile(codestr, StaticCodeGenerator, modname="foo")
5940
5941 def test_dynamic_chained_assign_param_2(self):
5942 codestr = """
5943 from __static__ import int16
5944 def testfunc(y):
5945 x: int16
5946 y = x = 42
5947 """
5948 with self.assertRaisesRegex(
5949 TypedSyntaxError, type_mismatch("int16", "dynamic")
5950 ):
5951 self.compile(codestr, StaticCodeGenerator, modname="foo")
5952
5953 def test_dynamic_chained_assign_1(self):
5954 codestr = """
5955 from __static__ import int16, box
5956 def testfunc():
5957 x: int16
5958 x = y = 42
5959 return box(x)
5960 """
5961 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5962 test = mod["testfunc"]
5963 self.assertEqual(test(), 42)
5964
5965 def test_dynamic_chained_assign_2(self):
5966 codestr = """
5967 from __static__ import int16, box
5968 def testfunc():
5969 x: int16
5970 y = x = 42
5971 return box(y)
5972 """
5973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
5974 test = mod["testfunc"]
5975 self.assertEqual(test(), 42)
5976
5977 def test_tuple_assign_list(self):
5978 codestr = """
5979 from __static__ import int16, box
5980 def testfunc(a: int, b: int):
5981 x: int
5982 y: str
5983 x, y = [a, b]
5984 """
5985 with self.assertRaisesRegex(TypedSyntaxError, "int cannot be assigned to str"):
5986 self.compile(codestr, StaticCodeGenerator, modname="foo")
5987
5988 def test_tuple_assign_tuple(self):
5989 codestr = """
5990 from __static__ import int16, box
5991 def testfunc(a: int, b: int):
5992 x: int
5993 y: str
5994 x, y = a, b
5995 """
5996 with self.assertRaisesRegex(TypedSyntaxError, "int cannot be assigned to str"):
5997 self.compile(codestr, StaticCodeGenerator, modname="foo")
5998
5999 def test_tuple_assign_constant(self):
6000 codestr = """
6001 from __static__ import int16, box
6002 def testfunc():
6003 x: int
6004 y: str
6005 x, y = 1, 1
6006 """
6007 with self.assertRaisesRegex(
6008 TypedSyntaxError,
6009 r"type mismatch: Exact\[int\] cannot be assigned to str",
6010 ):
6011 self.compile(codestr, StaticCodeGenerator, modname="foo")
6012
6013 def test_if_optional_cond(self):
6014 codestr = """
6015 from typing import Optional
6016 class C:
6017 def __init__(self):
6018 self.field = 42
6019
6020 def f(x: Optional[C]):
6021 return x.field if x is not None else None
6022 """
6023
6024 self.compile(codestr, StaticCodeGenerator, modname="foo")
6025
6026 def test_while_optional_cond(self):
6027 codestr = """
6028 from typing import Optional
6029 class C:
6030 def __init__(self):
6031 self.field: Optional["C"] = self
6032
6033 def f(x: Optional[C]):
6034 while x is not None:
6035 val: Optional[C] = x.field
6036 if val is not None:
6037 x = val
6038 """
6039
6040 self.compile(codestr, StaticCodeGenerator, modname="foo")
6041
6042 def test_if_optional_dependent_conditions(self):
6043 codestr = """
6044 from typing import Optional
6045 class C:
6046 def __init__(self):
6047 self.field: Optional[C] = None
6048
6049 def f(x: Optional[C]) -> C:
6050 if x is not None and x.field is not None:
6051 return x
6052
6053 if x is None:
6054 return C()
6055
6056 return x
6057 """
6058
6059 self.compile(codestr, StaticCodeGenerator, modname="foo")
6060
6061 def test_none_attribute_error(self):
6062 codestr = """
6063 def f():
6064 x = None
6065 return x.foo
6066 """
6067
6068 with self.assertRaisesRegex(
6069 TypedSyntaxError, "'NoneType' object has no attribute 'foo'"
6070 ):
6071 self.compile(codestr, StaticCodeGenerator, modname="foo")
6072
6073 def test_none_call(self):
6074 codestr = """
6075 def f():
6076 x = None
6077 return x()
6078 """
6079
6080 with self.assertRaisesRegex(
6081 TypedSyntaxError, "'NoneType' object is not callable"
6082 ):
6083 self.compile(codestr, StaticCodeGenerator, modname="foo")
6084
6085 def test_none_subscript(self):
6086 codestr = """
6087 def f():
6088 x = None
6089 return x[0]
6090 """
6091
6092 with self.assertRaisesRegex(
6093 TypedSyntaxError, "'NoneType' object is not subscriptable"
6094 ):
6095 self.compile(codestr, StaticCodeGenerator, modname="foo")
6096
6097 def test_none_unaryop(self):
6098 codestr = """
6099 def f():
6100 x = None
6101 return -x
6102 """
6103
6104 with self.assertRaisesRegex(
6105 TypedSyntaxError, "bad operand type for unary -: 'NoneType'"
6106 ):
6107 self.compile(codestr, StaticCodeGenerator, modname="foo")
6108
6109 def test_vector_import(self):
6110 codestr = """
6111 from __static__ import int64, Vector
6112
6113 def test() -> Vector[int64]:
6114 x: Vector[int64] = Vector[int64]()
6115 x.append(1)
6116 return x
6117 """
6118
6119 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
6120 test = mod["test"]
6121 self.assertEqual(test(), array("L", [1]))
6122
6123 def test_vector_assign_non_primitive(self):
6124 codestr = """
6125 from __static__ import int64, Vector
6126
6127 def test(abc) -> Vector[int64]:
6128 x: Vector[int64] = Vector[int64](2)
6129 i: int64 = 0
6130 x[i] = abc
6131 """
6132
6133 with self.assertRaisesRegex(
6134 TypedSyntaxError, "Cannot assign a dynamic to int64"
6135 ):
6136 self.compile(codestr)
6137
def test_vector_sizes(self):
    """Vector accepts every signed/unsigned 8, 16, 32 and 64-bit integer type."""
    signs = ["int", "uint"]
    sizes = ["8", "16", "32", "64"]
    for signed, size in itertools.product(signs, sizes):
        with self.subTest(size=size, signed=signed):
            int_type = f"{signed}{size}"
            codestr = f"""
                from __static__ import {int_type}, Vector

                def test() -> Vector[{int_type}]:
                    x: Vector[{int_type}] = Vector[{int_type}]()
                    y: {int_type} = 1
                    x.append(y)
                    return x
            """

            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                self.assertEqual(list(mod["test"]()), [1])
6157
def test_vector_invalid_literal(self):
    """Appending the literal 128 to a Vector[int8] is rejected.

    128 is outside the int8 range, and the error reports the argument as a
    plain int that does not match the int8 parameter.
    """
    # Plain string: the source contains no interpolation placeholders.
    codestr = """
        from __static__ import int8, Vector

        def test() -> Vector[int8]:
            x: Vector[int8] = Vector[int8]()
            x.append(128)
            return x
    """
    with self.assertRaisesRegex(
        TypedSyntaxError,
        "type mismatch: int positional argument type mismatch int8",
    ):
        self.compile(codestr)
6172
def test_vector_wrong_size(self):
    """Appending an int16 value to a Vector[int8] is a type mismatch."""
    # Plain string: the source contains no interpolation placeholders.
    codestr = """
        from __static__ import int8, int16, Vector

        def test() -> Vector[int8]:
            y: int16 = 1
            x: Vector[int8] = Vector[int8]()
            x.append(y)
            return x
    """

    with self.assertRaisesRegex(
        TypedSyntaxError,
        "type mismatch: int16 positional argument type mismatch int8",
    ):
        self.compile(codestr)
6189
def test_vector_presized(self):
    """Vector[int8](4) is zero-filled with length 4 and supports item assignment."""
    # Plain string: the source contains no interpolation placeholders.
    codestr = """
        from __static__ import int8, Vector

        def test() -> Vector[int8]:
            x: Vector[int8] = Vector[int8](4)
            x[1] = 1
            return x
    """

    with self.in_module(codestr) as mod:
        f = mod["test"]
        self.assertEqual(f(), array("b", [0, 1, 0, 0]))
6203
def test_array_import(self):
    """A pre-sized Array[int64] supports item assignment from static code.

    The expected array uses typecode "q" (signed 64-bit) to mirror int64.
    Array equality compares elements, so the typecode only documents intent.
    """
    codestr = """
        from __static__ import int64, Array

        def test() -> Array[int64]:
            x: Array[int64] = Array[int64](1)
            x[0] = 1
            return x
    """

    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        test = mod["test"]
        self.assertEqual(test(), array("q", [1]))
6217
def test_array_create(self):
    """Array[int64] can be constructed directly from a list literal."""
    source = """
        from __static__ import int64, Array

        def test() -> Array[int64]:
            x: Array[int64] = Array[int64]([1, 3, 5])
            return x
    """
    expected = array("l", [1, 3, 5])
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["test"](), expected)
6230
def test_array_create_failure(self):
    """An Array whose element type is a user-defined class is rejected."""
    # TODO: this is expected to be supported in the future; for now it must fail.
    source = """
        from __static__ import int64, Array

        class C: pass

        def test() -> Array[C]:
            return Array[C]([1, 3, 5])
    """
    expected = "Invalid Array element type: foo.C"
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6245
def test_array_call_unbound(self):
    """Calling the unparameterized generic Array is rejected."""
    expected = r"create instances of a generic Type\[Exact\[Array\[T\]\]\]"
    source = """
        from __static__ import Array

        def f() -> Array:
            return Array([1, 2, 3])
    """
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6258
def test_array_assign_wrong_type(self):
    """An Array[char] value cannot be assigned to an Array[int64] local."""
    source = """
        from __static__ import int64, char, Array

        def test() -> None:
            x: Array[int64] = Array[char]([48])
    """
    expected = type_mismatch("Exact[Array[char]]", "Array[int64]")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6274
def test_array_subclass_assign(self):
    """A local inferred as Exact[Array[int64]] rejects a possibly-subclassed Array[int64]."""
    source = """
        from __static__ import int64, Array

        class MyArray(Array):
            pass

        def y(inexact: Array[int64]):
            exact = Array[int64]([1])
            exact = inexact
    """
    expected = type_mismatch("Array[int64]", "Exact[Array[int64]]")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6294
def test_array_types(self):
    """Every supported Array element type round-trips its values and first element.

    Expected arrays use the `array` typecode matching each element type
    (int8="b", uint8="B", int16="h", uint16="H", int32="i", uint32="I",
    int64="q", uint64="Q", char="b", double="d"). Array equality compares
    elements, so typecodes document intent without affecting the comparison.
    The float entry keeps "f" as before.
    """
    codestr = """
        from __static__ import (
            int8,
            int16,
            int32,
            int64,
            uint8,
            uint16,
            uint32,
            uint64,
            char,
            double,
            Array
        )
        from typing import Tuple

        def test() -> Tuple[Array[int64], Array[char], Array[double]]:
            x1: Array[int8] = Array[int8]([1, 3, -5])
            x2: Array[uint8] = Array[uint8]([1, 3, 5])
            x3: Array[int16] = Array[int16]([1, -3, 5])
            x4: Array[uint16] = Array[uint16]([1, 3, 5])
            x5: Array[int32] = Array[int32]([1, 3, 5])
            x6: Array[uint32] = Array[uint32]([1, 3, 5])
            x7: Array[int64] = Array[int64]([1, 3, 5])
            x8: Array[uint64] = Array[uint64]([1, 3, 5])
            x9: Array[char] = Array[char]([ord('a')])
            x10: Array[double] = Array[double]([1.1, 3.3, 5.5])
            x11: Array[float] = Array[float]([1.1, 3.3, 5.5])
            arrays = [
                x1,
                x2,
                x3,
                x4,
                x5,
                x6,
                x7,
                x8,
                x9,
                x10,
                x11,
            ]
            first_elements = []
            for ar in arrays:
                first_elements.append(ar[0])
            return (arrays, first_elements)
    """

    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        test = mod["test"]
        arrays, first_elements = test()
        exp_arrays = [
            array(*args)
            for args in [
                ("b", [1, 3, -5]),
                ("B", [1, 3, 5]),
                ("h", [1, -3, 5]),
                ("H", [1, 3, 5]),
                ("i", [1, 3, 5]),
                ("I", [1, 3, 5]),
                ("q", [1, 3, 5]),
                ("Q", [1, 3, 5]),
                ("b", [ord("a")]),
                ("d", [1.1, 3.3, 5.5]),
                ("f", [1.1, 3.3, 5.5]),
            ]
        ]
        exp_first_elements = [ar[0] for ar in exp_arrays]
        # zip() stops at the shorter input, so check the lengths explicitly;
        # otherwise a missing array would go unnoticed.
        self.assertEqual(len(arrays), len(exp_arrays))
        for result, expectation in zip(arrays, exp_arrays):
            self.assertEqual(result, expectation)
        self.assertEqual(len(first_elements), len(exp_first_elements))
        for result, expectation in zip(first_elements, exp_first_elements):
            self.assertEqual(result, expectation)
6367
def test_assign_type_propagation(self):
    """A local inferred from an int literal satisfies an `int` return annotation."""
    source = """
        def test() -> int:
            x = 5
            return x
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
6375
def test_assign_subtype_handling(self):
    """A local first bound to a B may be rebound to a subclass D and back to B."""
    source = """
        class B: pass
        class D(B): pass

        def f():
            b = B()
            b = D()
            b = B()
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
6387
def test_assign_subtype_handling_fail(self):
    """A local first bound to a subclass D cannot be rebound to its base B."""
    source = """
        class B: pass
        class D(B): pass

        def f():
            d = D()
            d = B()
    """
    expected = type_mismatch("foo.B", "foo.D")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6399
def test_assign_chained(self):
    """A chained assignment of a str to a str-annotated target compiles."""
    source = """
        def test() -> str:
            x: str = "hi"
            y = x = "hello"
            return y
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
6408
def test_assign_chained_failure_wrong_target_type(self):
    """A chained assignment fails when one target was already inferred as int."""
    source = """
        def test() -> str:
            x = 1
            y = x = "hello"
            return y
    """
    expected = type_mismatch("Exact[str]", "int")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6420
def test_chained_assign_type_propagation(self):
    """Both targets of a chained assignment take the Array[char] type of the value."""
    source = """
        from __static__ import int64, char, Array

        def test2() -> Array[char]:
            x = y = Array[char]([48])
            return y
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
6430
def test_chained_assign_type_propagation_failure_redefine(self):
    """A chained assignment cannot rebind an Array[int64] local to Array[char] (x first)."""
    source = """
        from __static__ import int64, char, Array

        def test2() -> Array[char]:
            x = Array[int64]([54])
            x = y = Array[char]([48])
            return y
    """
    expected = type_mismatch("Exact[Array[char]]", "Exact[Array[int64]]")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6448
def test_chained_assign_type_propagation_failure_redefine_2(self):
    """A chained assignment cannot rebind an Array[int64] local to Array[char] (x last)."""
    source = """
        from __static__ import int64, char, Array

        def test2() -> Array[char]:
            x = Array[int64]([54])
            y = x = Array[char]([48])
            return y
    """
    expected = type_mismatch("Exact[Array[char]]", "Exact[Array[int64]]")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6466
def test_chained_assign_type_inference(self):
    """Targets of `y = x = 4` are inferred as int, so later binding x to a str fails."""
    source = """
        from __static__ import int64, char, Array

        def test2():
            y = x = 4
            x = "hello"
            return y
    """
    expected = type_mismatch("Exact[str]", "int")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6480
def test_chained_assign_type_inference_2(self):
    """Targets of `y = x = 4` are inferred as int, so later binding y to a str fails."""
    source = """
        from __static__ import int64, char, Array

        def test2():
            y = x = 4
            y = "hello"
            return x
    """
    expected = type_mismatch("Exact[str]", "int")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
6494
def test_array_inplace_assign(self):
    """`+=` on an Array[int8] element updates that element."""
    source = """
        from __static__ import Array, int8

        def m() -> Array[int8]:
            a = Array[int8]([1, 3, -5, -1, 7, 22])
            a[0] += 1
            return a
    """
    with self.in_module(source) as mod:
        result = mod["m"]()
        self.assertEqual(result[0], 2)
6507
def test_array_subscripting_slice(self):
    """Slicing an Array[int8] type-checks as Array[int8]."""
    source = """
        from __static__ import Array, int8

        def m() -> Array[int8]:
            a = Array[int8]([1, 3, -5, -1, 7, 22])
            return a[1:3]
    """
    self.compile(source, StaticCodeGenerator, modname="foo")
6517
@skipIf(cinderjit is not None, "can't report error from JIT")
def test_load_uninit_module(self):
    """Verify we don't crash if we receive a module without a dictionary.

    The module registered under the test module's name is replaced by a
    ModuleType subclass whose __init__ skips ModuleType.__init__. Constructing
    C afterwards must raise a TypeError from the class loader, not crash.
    """
    codestr = """
        class C:
            def __init__(self):
                self.x: Optional[C] = None

    """
    with self.in_module(codestr) as mod:
        C = mod["C"]

        class UninitModule(ModuleType):
            def __init__(self):
                # don't call super init
                pass

        # NOTE(review): judging by the expected message, C() presumably
        # resolves ('<module>', 'C') through sys.modules, so swapping the
        # entry makes that lookup fail. Confirm against the class loader.
        sys.modules[mod["__name__"]] = UninitModule()
        with self.assertRaisesRegex(
            TypeError,
            r"bad name provided for class loader: \('"
            + mod["__name__"]
            + r"', 'C'\), not a class",
        ):
            C()
6543
def test_module_subclass(self):
    """A ModuleType subclass that provides C via __getattr__ still lets C() work.

    The test module's sys.modules entry is replaced by a CustomModule whose
    only source of attributes is __getattr__, which returns C for "C".
    """
    codestr = """
        class C:
            def __init__(self):
                self.x: Optional[C] = None

    """
    with self.in_module(codestr) as mod:
        C = mod["C"]

        class CustomModule(ModuleType):
            # Returns C for "C" and implicitly None for any other name.
            def __getattr__(self, name):
                if name == "C":
                    return C

        sys.modules[mod["__name__"]] = CustomModule(mod["__name__"])
        c = C()
        self.assertEqual(c.x, None)
6562
def test_invoke_and_raise_noframe_strictmod(self):
    """An exception raised by a noframe strict-module function reaches the caller.

    The caller is compiled to INVOKE_FUNCTION, and the callee is asserted to be
    JIT-compiled.
    """
    source = """
        from __static__.compiler_flags import noframe

        def x():
            raise TypeError()

        def y():
            return x()
    """
    with self.in_strict_module(source) as mod:
        caller, callee = mod.y, mod.x
        with self.assertRaises(TypeError):
            caller()
        self.assert_jitted(callee)
        expected_arg = ((mod.__name__, "x"), 0)
        self.assertInBytecode(caller, "INVOKE_FUNCTION", expected_arg)
6584
def test_override_okay(self):
    """A dynamic subclass override returning a compatible value is accepted.

    B.f is annotated to return "B". D.f returns self (a D, and therefore a B),
    so the static caller must accept the result and return D's instance.
    """
    codestr = """
        class B:
            def f(self) -> "B":
                return self

        def f(x: B):
            return x.f()
    """
    with self.in_module(codestr) as mod:
        B = mod["B"]
        f = mod["f"]

        class D(B):
            def f(self):
                return self

        d = D()
        # Assert on the result so the test checks that D's override ran and
        # its value came back, not merely that the call did not raise.
        self.assertIs(f(d), d)
6603
def test_override_override_inherited(self):
    """Patching f onto a subclass that inherited it affects only that subclass."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "Optional[B]":
                return self

        class D(B):
            pass

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, D, f = mod["B"], mod["D"], mod["f"]

        base, derived = B(), D()
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), derived)

        D.f = lambda self: None
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), None)
6630
def test_override_bad_ret(self):
    """A dynamic override returning the wrong type raises TypeError at the static call."""
    source = """
        class B:
            def f(self) -> "B":
                return self

        def f(x: B):
            return x.f()
    """
    expected = "unexpected return type from D.f, expected B, got int"
    with self.in_module(source) as mod:
        B, f = mod["B"], mod["f"]

        class D(B):
            def f(self):
                return 42

        with self.assertRaisesRegex(TypeError, expected):
            f(D())
6652
def test_dynamic_base(self):
    """A static class with a non-static base uses INVOKE_METHOD with correct lookup precedence."""
    dynamic_source = """
        class Foo:
            pass
    """

    with self.in_module(
        dynamic_source, code_gen=PythonCodeGenerator, name="nonstatic"
    ):
        static_source = """
            from nonstatic import Foo

            class A(Foo):
                def __init__(self):
                    self.x = 1

                def f(self) -> int:
                    return self.x

            def f(x: A) -> int:
                return x.f()
        """
        with self.in_module(static_source) as mod:
            func = mod["f"]
            self.assertInBytecode(func, "INVOKE_METHOD")
            obj = mod["A"]()
            self.assertEqual(func(obj), 1)
            # x is a data descriptor, it takes precedence
            obj.__dict__["x"] = 100
            self.assertEqual(func(obj), 1)
            # but methods are normal descriptors, instance
            # attributes should take precedence
            obj.__dict__["f"] = lambda: 42
            self.assertEqual(func(obj), 42)
6687
def test_invoke_type_modified(self):
    """INVOKE_METHOD picks up a method replaced on the class after earlier calls."""
    source = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    compiled = self.compile(source, StaticCodeGenerator, modname="foo")
    self.assertInBytecode(
        self.find_code(compiled, "x"), "INVOKE_METHOD", (("foo", "C", "f"), 0)
    )

    with self.in_module(source) as mod:
        x, C = mod["x"], mod["C"]
        self.assertEqual(x(C()), 2)
        C.f = lambda self: 42
        self.assertEqual(x(C()), 84)
6709
def test_invoke_type_modified_pre_invoke(self):
    """INVOKE_METHOD picks up a method replaced on the class before any call."""
    source = """
        class C:
            def f(self):
                return 1

        def x(c: C):
            x = c.f()
            x += c.f()
            return x
    """

    compiled = self.compile(source, StaticCodeGenerator, modname="foo")
    self.assertInBytecode(
        self.find_code(compiled, "x"), "INVOKE_METHOD", (("foo", "C", "f"), 0)
    )

    with self.in_module(source) as mod:
        x, C = mod["x"], mod["C"]
        C.f = lambda self: 42
        self.assertEqual(x(C()), 84)
6730
def test_override_modified_base_class(self):
    """A subclass override wins even after the base method was patched."""
    source = """
        class B:
            def f(self):
                return 1

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, f = mod["B"], mod["f"]
        B.f = lambda self: 2

        class D(B):
            def f(self):
                return 3

        self.assertEqual(f(D()), 3)
6751
def test_override_remove_base_method(self):
    """Deleting the base method makes static calls raise AttributeError for base and subclass."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "B":
                return self

        class D(B): pass

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, D, f = mod["B"], mod["D"], mod["f"]
        base, derived = B(), D()
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), derived)
        del B.f

        with self.assertRaises(AttributeError):
            f(base)
        with self.assertRaises(AttributeError):
            f(derived)
6778
def test_override_remove_derived_method(self):
    """Deleting a subclass override makes static calls fall back to the base method."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "Optional[B]":
                return self

        class D(B):
            def f(self) -> Optional["B"]:
                return None

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, D, f = mod["B"], mod["D"], mod["f"]
        base, derived = B(), D()
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), None)
        del D.f

        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), derived)
6805
def test_override_remove_method(self):
    """Deleting the only definition of a method makes static calls raise AttributeError."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "Optional[B]":
                return self

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, f = mod["B"], mod["f"]
        base = B()
        self.assertEqual(f(base), base)
        del B.f

        with self.assertRaises(AttributeError):
            f(base)
6825
def test_override_remove_method_add_type_check(self):
    """After deleting B.f, a re-added f returning None fails the `-> "B"` return check."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "B":
                return self

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, f = mod["B"], mod["f"]
        base = B()
        self.assertEqual(f(base), base)
        del B.f

        with self.assertRaises(AttributeError):
            f(base)

        B.f = lambda self: None
        with self.assertRaises(TypeError):
            f(base)
6849
def test_override_update_derived(self):
    """Patching the base method is seen through a directly inheriting subclass."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "Optional[B]":
                return self

        class D(B):
            pass

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, D, f = mod["B"], mod["D"], mod["f"]

        base, derived = B(), D()
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), derived)

        B.f = lambda self: None
        self.assertEqual(f(base), None)
        self.assertEqual(f(derived), None)
6876
def test_override_update_derived_2(self):
    """Patching the base method is seen through an intermediate subclass (B -> D1 -> D)."""
    source = """
        from typing import Optional
        class B:
            def f(self) -> "Optional[B]":
                return self

        class D1(B): pass

        class D(D1):
            pass

        def f(x: B):
            return x.f()
    """
    with self.in_module(source) as mod:
        B, D, f = mod["B"], mod["D"], mod["f"]

        base, derived = B(), D()
        self.assertEqual(f(base), base)
        self.assertEqual(f(derived), derived)

        B.f = lambda self: None
        self.assertEqual(f(base), None)
        self.assertEqual(f(derived), None)
6905
def test_method_prologue(self):
    """An annotated positional parameter gets a CHECK_ARGS entry enforced at call time."""
    source = """
        def f(x: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (0, ("builtins", "str")))
        expected = ".*expected 'str' for argument x, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func(42)
6918
def test_method_prologue_2(self):
    """Only the annotated second parameter (index 1) appears in CHECK_ARGS."""
    source = """
        def f(x, y: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (1, ("builtins", "str")))
        expected = ".*expected 'str' for argument y, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func("abc", 42)
6931
def test_method_prologue_3(self):
    """Two annotated parameters produce two CHECK_ARGS entries in order."""
    source = """
        def f(x: int, y: str):
            return 42
    """
    checks = (0, ("builtins", "int"), 1, ("builtins", "str"))
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", checks)
        expected = ".*expected 'str' for argument y, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func(42, 42)
6946
def test_method_prologue_posonly(self):
    """Positional-only annotated parameters are checked like ordinary ones."""
    source = """
        def f(x: int, /, y: str):
            return 42
    """
    checks = (0, ("builtins", "int"), 1, ("builtins", "str"))
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", checks)
        expected = ".*expected 'str' for argument y, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func(42, 42)
6961
def test_method_prologue_shadowcode(self):
    """The argument check still fires after many successful calls (index 1 annotated)."""
    source = """
        def f(x, y: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (1, ("builtins", "str")))
        # Warm the function up with well-typed calls first.
        for _ in range(100):
            self.assertEqual(func("abc", "abc"), 42)
        expected = ".*expected 'str' for argument y, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func("abc", 42)
6976
def test_method_prologue_shadowcode_2(self):
    """The argument check still fires after many successful calls (index 0 annotated)."""
    source = """
        def f(x: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (0, ("builtins", "str")))
        # Warm the function up with well-typed calls first.
        for _ in range(100):
            self.assertEqual(func("abc"), 42)
        expected = ".*expected 'str' for argument x, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func(42)
6991
def test_method_prologue_no_annotation(self):
    """An unannotated parameter yields an empty CHECK_ARGS and accepts anything."""
    source = """
        def f(x):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", ())
        self.assertEqual(func("abc"), 42)
7001
def test_method_prologue_kwonly(self):
    """An annotated keyword-only parameter is type-checked when passed by keyword."""
    source = """
        def f(*, x: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (0, ("builtins", "str")))
        expected = "f expected 'str' for argument x, got 'int'"
        with self.assertRaisesRegex(TypeError, expected):
            func(x=42)
7014
def test_method_prologue_kwonly_2(self):
    """A keyword-only annotated parameter after a positional one is checked at index 1."""
    source = """
        def f(x, *, y: str):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (1, ("builtins", "str")))
        expected = "f expected 'str' for argument y, got 'object'"
        with self.assertRaisesRegex(TypeError, expected):
            func(1, y=object())
7027
def test_method_prologue_kwonly_3(self):
    """An unannotated keyword-only default parameter adds nothing to CHECK_ARGS."""
    source = """
        def f(x, *, y: str, z=1):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (1, ("builtins", "str")))
        expected = "f expected 'str' for argument y, got 'object'"
        with self.assertRaisesRegex(TypeError, expected):
            func(1, y=object())
7040
def test_method_prologue_kwonly_4(self):
    """A **kwargs catch-all does not disturb the check on keyword-only y."""
    source = """
        def f(x, *, y: str, **rest):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", (1, ("builtins", "str")))
        expected = "f expected 'str' for argument y, got 'object'"
        with self.assertRaisesRegex(TypeError, expected):
            func(1, y=object(), z=2)
7053
def test_method_prologue_kwonly_no_annotation(self):
    """An unannotated keyword-only parameter yields an empty CHECK_ARGS."""
    source = """
        def f(*, x):
            return 42
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "CHECK_ARGS", ())
        func(x=42)
7063
def test_package_no_parent(self):
    """A module with a dotted name whose parent package does not exist still works."""
    source = """
        class C:
            def f(self):
                return 42
    """
    module_name = "package_no_parent.child"
    with self.in_module(
        source, code_gen=StaticCodeGenerator, name=module_name
    ) as mod:
        C = mod["C"]
        self.assertInBytecode(C.f, "CHECK_ARGS", (0, (module_name, "C")))
        self.assertEqual(C().f(), 42)
7078
def test_direct_super_init(self):
    """Calling C.__init__ directly with None as `self` is a static type error."""
    # The unused locals and the placeholder-free f-string prefix were removed.
    codestr = """
        class Obj:
            pass

        class C:
            def __init__(self, x: Obj):
                pass

        class D:
            def __init__(self):
                C.__init__(None)
    """
    with self.assertRaisesRegex(
        TypedSyntaxError,
        "type mismatch: None positional argument type mismatch foo.C",
    ):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
7099
def test_class_unknown_attr(self):
    """An attribute the class does not declare compiles to a plain LOAD_ATTR."""
    codestr = """
        class C:
            pass

        def f():
            return C.foo
    """
    with self.in_module(codestr) as mod:
        f = mod["f"]
        self.assertInBytecode(f, "LOAD_ATTR", "foo")
7113
def test_descriptor_access(self):
    """Reading a declared field through the class uses LOAD_ATTR, not LOAD_FIELD."""
    codestr = """
        class Obj:
            abc: int

        class C:
            x: Obj

        def f():
            return C.x.abc
    """
    with self.in_module(codestr) as mod:
        f = mod["f"]
        self.assertInBytecode(f, "LOAD_ATTR", "abc")
        self.assertNotInBytecode(f, "LOAD_FIELD")
7131
@skipIf(not path.exists(RICHARDS_PATH), "richards not found")
def test_richards(self):
    """Load the richards benchmark via in_module and run one iteration successfully."""
    with open(RICHARDS_PATH) as src_file:
        source = src_file.read()

    with self.in_module(source) as mod:
        self.assertTrue(mod["Richards"]().run(1))
7140
def test_unknown_isinstance_bool_ret(self):
    """isinstance() against a static class can be returned from a `-> bool` method."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

            def __eq__(self, other: Any) -> bool:
                return isinstance(other, C)

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]
        left, right = C("abc"), C("foo")
        self.assertTrue(left == right)
7158
def test_unknown_issubclass_bool_ret(self):
    """issubclass() on type(other) can be returned from a `-> bool` method."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

            def __eq__(self, other: Any) -> bool:
                return issubclass(type(other), C)

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]
        left, right = C("abc"), C("foo")
        self.assertTrue(left == right)
7176
def test_unknown_isinstance_narrows(self):
    """isinstance(x, C) narrows an untyped x so `x.x` compiles to LOAD_FIELD."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

        def testfunc(x):
            if isinstance(x, C):
                return x.x
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        field = (mod["__name__"], "C", "x")
        self.assertInBytecode(mod["testfunc"], "LOAD_FIELD", field)
7192
def test_unknown_isinstance_narrows_class_attr(self):
    """isinstance(other, self.__class__) narrows `other` to C, giving LOAD_FIELD."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

            def f(self, other) -> str:
                if isinstance(other, self.__class__):
                    return other.x
                return ''
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        field = (mod["__name__"], "C", "x")
        self.assertInBytecode(mod["C"].f, "LOAD_FIELD", field)
7213
def test_unknown_isinstance_narrows_class_attr_dynamic(self):
    """isinstance against an unknown object's __class__ does not narrow; LOAD_ATTR is used."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

            def f(self, other, unknown):
                if isinstance(other, unknown.__class__):
                    return other.x
                return ''
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertInBytecode(mod["C"].f, "LOAD_ATTR", "x")
7230
def test_unknown_isinstance_narrows_else_correct(self):
    """The else branch of isinstance(x, C) does not narrow x to C (no LOAD_FIELD)."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

        def testfunc(x):
            if isinstance(x, C):
                pass
            else:
                return x.x
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        field = (mod["__name__"], "C", "x")
        self.assertNotInBytecode(mod["testfunc"], "LOAD_FIELD", field)
7250
def test_narrow_while_break(self):
    """Breaking out of `while x is None` leaves x Optional, so returning it as int fails."""
    source = """
        from typing import Optional
        def f(x: Optional[int]) -> int:
            while x is None:
                break
            return x
    """
    expected = type_mismatch("Optional[int]", "int")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source)
7264
def test_narrow_while_if_break_else_return(self):
    """A conditional break inside `while x is None` leaves x Optional after the loop."""
    source = """
        from typing import Optional
        def f(x: Optional[int], y: int) -> int:
            while x is None:
                if y > 0:
                    break
                else:
                    return 42
            return x
    """
    expected = type_mismatch("Optional[int]", "int")
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source)
7281
def test_narrow_while_break_if(self):
    """After `if x is None: break` inside the loop, x is narrowed to int."""
    source = """
        from typing import Optional
        def f(x: Optional[int]) -> int:
            while True:
                if x is None:
                    break
                return x
    """
    self.compile(source)
7292
def test_narrow_while_continue_if(self):
    """After `if x is None: continue` inside the loop, x is narrowed to int."""
    source = """
        from typing import Optional
        def f(x: Optional[int]) -> int:
            while True:
                if x is None:
                    continue
                return x
    """
    self.compile(source)
7303
def test_unknown_param_ann(self):
    """A parameter annotated Any is not checked; only `self` appears in CHECK_ARGS."""
    source = """
        from typing import Any

        class C:
            def __init__(self, x: str):
                self.x: str = x

            def __eq__(self, other: Any) -> bool:
                return False

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]
        instance = C("abc")
        self_check = (0, (mod["__name__"], "C"))
        self.assertInBytecode(C.__eq__, "CHECK_ARGS", self_check)
        self.assertNotEqual(instance, instance)
7321
def test_class_init_kw(self):
    """Constructing a class with a keyword argument compiles to CALL_FUNCTION_KW."""
    source = """
        class C:
            def __init__(self, x: str):
                self.x: str = x

        def f():
            x = C(x='abc')

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertInBytecode(mod["f"], "CALL_FUNCTION_KW", 1)
7335
def test_ret_type_cast(self):
    """Returning `x == y` from a `-> bool` function emits a CAST to bool."""
    source = """
        from typing import Any

        def testfunc(x: str, y: str) -> bool:
            return x == y
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["testfunc"]
        self.assertEqual(func("abc", "abc"), True)
        self.assertInBytecode(func, "CAST", ("builtins", "bool"))
7347
def test_bind_boolop_type(self):
    """`and`/`or` of two bool-returning method calls evaluate correctly."""
    source = """
        from typing import Any

        class C:
            def f(self) -> bool:
                return True

            def g(self) -> bool:
                return False

            def x(self) -> bool:
                return self.f() and self.g()

            def y(self) -> bool:
                return self.f() or self.g()
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        instance = mod["C"]()
        self.assertEqual(instance.x(), False)
        self.assertEqual(instance.y(), True)
7370
def test_decorated_function_ignored_class(self):
    """A call through a @property-decorated attribute is not compiled to INVOKE_METHOD."""
    source = """
        class C:
            @property
            def x(self):
                return lambda: 42

            def y(self):
                return self.x()

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]
        self.assertNotInBytecode(C.y, "INVOKE_METHOD")
        self.assertEqual(C().y(), 42)
7386
def test_decorated_function_ignored(self):
    """Calling a decorated module function is not compiled to INVOKE_FUNCTION.

    The decorator replaces f with the class C, so g() returns a C instance.
    """
    source = """
        class C: pass

        def mydecorator(x):
            return C

        @mydecorator
        def f():
            return 42

        def g():
            return f()

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        C, g = mod["C"], mod["g"]
        self.assertNotInBytecode(g, "INVOKE_FUNCTION")
        self.assertEqual(type(g()), C)
7407
def test_static_function_invoke(self):
    """Calling a @staticmethod through the class compiles to INVOKE_FUNCTION."""
    source = """
        class C:
            @staticmethod
            def f():
                return 42

        def f():
            return C.f()
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        target = ((mod["__name__"], "C", "f"), 0)
        self.assertInBytecode(func, "INVOKE_FUNCTION", target)
        self.assertEqual(func(), 42)
7424
def test_static_function_invoke_on_instance(self):
    """Calling a @staticmethod through an instance also compiles to INVOKE_FUNCTION."""
    source = """
        class C:
            @staticmethod
            def f():
                return 42

        def f():
            return C().f()
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        target = ((mod["__name__"], "C", "f"), 0)
        self.assertInBytecode(func, "INVOKE_FUNCTION", target)
        self.assertEqual(func(), 42)
7443
def test_spamobj_no_params(self):
    """Instantiating the unparameterized generic spamobj is rejected."""
    expected = (
        r"cannot create instances of a generic Type\[xxclassloader.spamobj\[T\]\]"
    )
    source = """
        from xxclassloader import spamobj

        def f():
            x = spamobj()
    """
    with self.assertRaisesRegex(TypedSyntaxError, expected):
        self.compile(source, StaticCodeGenerator, modname="foo")
7456
def test_spamobj_error(self):
    """spamobj[int].error(1) raises TypeError('no way!') at runtime."""
    source = """
        from xxclassloader import spamobj

        def f():
            x = spamobj[int]()
            return x.error(1)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        with self.assertRaisesRegex(TypeError, "no way!"):
            func()
7469
def test_spamobj_no_error(self):
    """spamobj[int].error(0) returns None."""
    source = """
        from xxclassloader import spamobj

        def testfunc():
            x = spamobj[int]()
            return x.error(0)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["testfunc"](), None)
7481
7482 def test_generic_type_box_box(self):
7483 codestr = """
7484 from xxclassloader import spamobj
7485
7486 def testfunc():
7487 x = spamobj[str]()
7488 return (x.getint(), )
7489 """
7490
7491 with self.assertRaisesRegex(
7492 TypedSyntaxError,
7493 "type mismatch: int64 is an invalid return type, expected dynamic",
7494 ):
7495 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7496
7497 def test_generic_type(self):
7498 codestr = """
7499 from xxclassloader import spamobj
7500 from __static__ import box
7501
7502 def testfunc():
7503 x = spamobj[str]()
7504 x.setstate('abc')
7505 x.setint(42)
7506 return (x.getstate(), box(x.getint()))
7507 """
7508
7509 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7510 f = self.find_code(code, "testfunc")
7511 self.assertInBytecode(
7512 f,
7513 "INVOKE_METHOD",
7514 ((("xxclassloader", "spamobj", (("builtins", "str"),), "setstate"), 1)),
7515 )
7516 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7517 test = mod["testfunc"]
7518 self.assertEqual(test(), ("abc", 42))
7519
7520 def test_ret_void(self):
7521 codestr = """
7522 from xxclassloader import spamobj
7523 from __static__ import box
7524
7525 def testfunc():
7526 x = spamobj[str]()
7527 y = x.setstate('abc')
7528 return y
7529 """
7530
7531 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7532 f = self.find_code(code, "testfunc")
7533 self.assertInBytecode(
7534 f,
7535 "INVOKE_METHOD",
7536 ((("xxclassloader", "spamobj", (("builtins", "str"),), "setstate"), 1)),
7537 )
7538 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7539 test = mod["testfunc"]
7540 self.assertEqual(test(), None)
7541
7542 def test_user_enumerate_list(self):
7543 codestr = """
7544 from __static__ import int64, box, clen
7545
7546 def f(x: list):
7547 i: int64 = 0
7548 res = []
7549 while i < clen(x):
7550 elem = x[i]
7551 res.append((box(i), elem))
7552 i += 1
7553 return res
7554 """
7555 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7556 f = mod["f"]
7557 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT)
7558 res = f([1, 2, 3])
7559 self.assertEqual(res, [(0, 1), (1, 2), (2, 3)])
7560
7561 def test_user_enumerate_list_nooverride(self):
7562 class mylist(list):
7563 pass
7564
7565 codestr = """
7566 from __static__ import int64, box, clen
7567
7568 def f(x: list):
7569 i: int64 = 0
7570 res = []
7571 while i < clen(x):
7572 elem = x[i]
7573 res.append((box(i), elem))
7574 i += 1
7575 return res
7576 """
7577 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7578 f = mod["f"]
7579 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT)
7580 res = f(mylist([1, 2, 3]))
7581 self.assertEqual(res, [(0, 1), (1, 2), (2, 3)])
7582
7583 def test_user_enumerate_list_subclass(self):
7584 class mylist(list):
7585 def __getitem__(self, idx):
7586 return list.__getitem__(self, idx) + 1
7587
7588 codestr = """
7589 from __static__ import int64, box, clen
7590
7591 def f(x: list):
7592 i: int64 = 0
7593 res = []
7594 while i < clen(x):
7595 elem = x[i]
7596 res.append((box(i), elem))
7597 i += 1
7598 return res
7599 """
7600 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7601 f = mod["f"]
7602 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT)
7603 res = f(mylist([1, 2, 3]))
7604 self.assertEqual(res, [(0, 2), (1, 3), (2, 4)])
7605
7606 def test_list_assign_subclass(self):
7607 class mylist(list):
7608 def __setitem__(self, idx, value):
7609 return list.__setitem__(self, idx, value + 1)
7610
7611 codestr = """
7612 from __static__ import int64, box, clen
7613
7614 def f(x: list):
7615 i: int64 = 0
7616 x[i] = 42
7617 """
7618 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7619 f = mod["f"]
7620 self.assertInBytecode(f, "SEQUENCE_SET", SEQ_LIST_INEXACT)
7621 l = mylist([0])
7622 f(l)
7623 self.assertEqual(l[0], 43)
7624
7625 def test_inexact_list_negative(self):
7626 codestr = """
7627 from __static__ import int64, box, clen
7628
7629 def f(x: list):
7630 i: int64 = 1
7631 return x[-i]
7632 """
7633 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7634 f = mod["f"]
7635 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT)
7636 res = f([1, 2, 3])
7637 self.assertEqual(res, 3)
7638
7639 def test_inexact_list_negative_small_int(self):
7640 codestr = """
7641 from __static__ import int64, box, clen
7642
7643 def f(x: list):
7644 i: int8 = 1
7645 return x[-i]
7646 """
7647 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7648 f = mod["f"]
7649 res = f([1, 2, 3])
7650 self.assertEqual(res, 3)
7651
7652 def test_inexact_list_large_unsigned(self):
7653 codestr = """
7654 from __static__ import uint64
7655 def f(x: list):
7656 i: uint64 = 0xffffffffffffffff
7657 return x[i]
7658 """
7659 with self.assertRaisesRegex(
7660 TypedSyntaxError, "type mismatch: uint64 cannot be assigned to dynamic"
7661 ):
7662 self.compile(codestr)
7663
7664 def test_named_tuple(self):
7665 codestr = """
7666 from typing import NamedTuple
7667
7668 class C(NamedTuple):
7669 x: int
7670 y: str
7671
7672 def myfunc(x: C):
7673 return x.x
7674 """
7675 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7676 f = mod["myfunc"]
7677 self.assertNotInBytecode(f, "LOAD_FIELD")
7678
7679 def test_generic_type_error(self):
7680 codestr = """
7681 from xxclassloader import spamobj
7682
7683 def testfunc():
7684 x = spamobj[str]()
7685 x.setstate(42)
7686 """
7687
7688 with self.assertRaisesRegex(
7689 TypedSyntaxError,
7690 "type mismatch: Exact\\[int\\] positional argument type mismatch str",
7691 ):
7692 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7693
7694 def test_generic_optional_type_param(self):
7695 codestr = """
7696 from xxclassloader import spamobj
7697
7698 def testfunc():
7699 x = spamobj[str]()
7700 x.setstateoptional(None)
7701 """
7702
7703 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7704
7705 def test_generic_optional_type_param_2(self):
7706 codestr = """
7707 from xxclassloader import spamobj
7708
7709 def testfunc():
7710 x = spamobj[str]()
7711 x.setstateoptional('abc')
7712 """
7713
7714 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7715
7716 def test_generic_optional_type_param_error(self):
7717 codestr = """
7718 from xxclassloader import spamobj
7719
7720 def testfunc():
7721 x = spamobj[str]()
7722 x.setstateoptional(42)
7723 """
7724
7725 with self.assertRaisesRegex(
7726 TypedSyntaxError,
7727 "type mismatch: Exact\\[int\\] positional argument type mismatch Optional\\[str\\]",
7728 ):
7729 code = self.compile(codestr, StaticCodeGenerator, modname="foo")
7730
7731 def test_compile_nested_dict(self):
7732 codestr = """
7733 from __static__ import CheckedDict
7734
7735 class B: pass
7736 class D(B): pass
7737
7738 def testfunc():
7739 x = CheckedDict[B, int]({B():42, D():42})
7740 y = CheckedDict[int, CheckedDict[B, int]]({42: x})
7741 return y
7742 """
7743 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7744 test = mod["testfunc"]
7745 B = mod["B"]
7746 self.assertEqual(type(test()), chkdict[int, chkdict[B, int]])
7747
7748 def test_compile_dict_setdefault(self):
7749 codestr = """
7750 from __static__ import CheckedDict
7751 def testfunc():
7752 x = CheckedDict[int, str]({42: 'abc', })
7753 x.setdefault(100, 43)
7754 """
7755 with self.assertRaisesRegex(
7756 TypedSyntaxError,
7757 "type mismatch: Exact\\[int\\] positional argument type mismatch Optional\\[str\\]",
7758 ):
7759 self.compile(codestr, StaticCodeGenerator, modname="foo")
7760
7761 def test_compile_dict_get(self):
7762 codestr = """
7763 from __static__ import CheckedDict
7764 def testfunc():
7765 x = CheckedDict[int, str]({42: 'abc', })
7766 x.get(42, 42)
7767 """
7768 with self.assertRaisesRegex(
7769 TypedSyntaxError,
7770 "type mismatch: Exact\\[int\\] positional argument type mismatch Optional\\[str\\]",
7771 ):
7772 self.compile(codestr, StaticCodeGenerator, modname="foo")
7773
7774 codestr = """
7775 from __static__ import CheckedDict
7776
7777 class B: pass
7778 class D(B): pass
7779
7780 def testfunc():
7781 x = CheckedDict[B, int]({B():42, D():42})
7782 return x
7783 """
7784 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7785 test = mod["testfunc"]
7786 B = mod["B"]
7787 self.assertEqual(type(test()), chkdict[B, int])
7788
7789 def test_compile_dict_setitem(self):
7790 codestr = """
7791 from __static__ import CheckedDict
7792
7793 def testfunc():
7794 x = CheckedDict[int, str]({1:'abc'})
7795 x.__setitem__(2, 'def')
7796 return x
7797 """
7798 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7799 test = mod["testfunc"]
7800 x = test()
7801 self.assertInBytecode(
7802 test,
7803 "INVOKE_FUNCTION",
7804 (
7805 (
7806 "__static__",
7807 "chkdict",
7808 (("builtins", "int"), ("builtins", "str")),
7809 "__setitem__",
7810 ),
7811 3,
7812 ),
7813 )
7814 self.assertEqual(x, {1: "abc", 2: "def"})
7815
7816 def test_compile_dict_setitem_subscr(self):
7817 codestr = """
7818 from __static__ import CheckedDict
7819
7820 def testfunc():
7821 x = CheckedDict[int, str]({1:'abc'})
7822 x[2] = 'def'
7823 return x
7824 """
7825 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7826 test = mod["testfunc"]
7827 x = test()
7828 self.assertInBytecode(
7829 test,
7830 "INVOKE_METHOD",
7831 (
7832 (
7833 "__static__",
7834 "chkdict",
7835 (("builtins", "int"), ("builtins", "str")),
7836 "__setitem__",
7837 ),
7838 2,
7839 ),
7840 )
7841 self.assertEqual(x, {1: "abc", 2: "def"})
7842
7843 def test_compile_generic_dict_getitem_bad_type(self):
7844 codestr = """
7845 from __static__ import CheckedDict
7846
7847 def testfunc():
7848 x = CheckedDict[str, int]({"abc": 42})
7849 return x[42]
7850 """
7851 with self.assertRaisesRegex(
7852 TypedSyntaxError,
7853 type_mismatch("Exact[int]", "str"),
7854 ):
7855 self.compile(codestr, StaticCodeGenerator, modname="foo")
7856
7857 def test_compile_generic_dict_setitem_bad_type(self):
7858 codestr = """
7859 from __static__ import CheckedDict
7860
7861 def testfunc():
7862 x = CheckedDict[str, int]({"abc": 42})
7863 x[42] = 42
7864 """
7865 with self.assertRaisesRegex(
7866 TypedSyntaxError,
7867 type_mismatch("Exact[int]", "str"),
7868 ):
7869 self.compile(codestr, StaticCodeGenerator, modname="foo")
7870
7871 def test_compile_generic_dict_setitem_bad_type_2(self):
7872 codestr = """
7873 from __static__ import CheckedDict
7874
7875 def testfunc():
7876 x = CheckedDict[str, int]({"abc": 42})
7877 x["foo"] = "abc"
7878 """
7879 with self.assertRaisesRegex(
7880 TypedSyntaxError,
7881 type_mismatch("Exact[str]", "int"),
7882 ):
7883 self.compile(codestr, StaticCodeGenerator, modname="foo")
7884
7885 def test_compile_checked_dict_shadowcode(self):
7886 codestr = """
7887 from __static__ import CheckedDict
7888
7889 class B: pass
7890 class D(B): pass
7891
7892 def testfunc():
7893 x = CheckedDict[B, int]({B():42, D():42})
7894 return x
7895 """
7896 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7897 test = mod["testfunc"]
7898 B = mod["B"]
7899 for i in range(200):
7900 self.assertEqual(type(test()), chkdict[B, int])
7901
7902 def test_compile_checked_dict_optional(self):
7903 codestr = """
7904 from __static__ import CheckedDict
7905 from typing import Optional
7906
7907 def testfunc():
7908 x = CheckedDict[str, str | None]({
7909 'x': None,
7910 'y': 'z'
7911 })
7912 return x
7913 """
7914 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7915 f = mod["testfunc"]
7916 x = f()
7917 x["z"] = None
7918 self.assertEqual(type(x), chkdict[str, str | None])
7919
7920 def test_compile_checked_dict_bad_annotation(self):
7921 codestr = """
7922 from __static__ import CheckedDict
7923
7924 def testfunc():
7925 x: 42 = CheckedDict[str, str]({'abc':'abc'})
7926 return x
7927 """
7928 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7929 test = mod["testfunc"]
7930 self.assertEqual(type(test()), chkdict[str, str])
7931
7932 def test_compile_checked_dict_ann_differs(self):
7933 codestr = """
7934 from __static__ import CheckedDict
7935
7936 def testfunc():
7937 x: CheckedDict[int, int] = CheckedDict[str, str]({'abc':'abc'})
7938 return x
7939 """
7940 with self.assertRaisesRegex(
7941 TypedSyntaxError,
7942 type_mismatch(
7943 "Exact[chkdict[str, str]]",
7944 "chkdict[int, int]",
7945 ),
7946 ):
7947 self.compile(codestr, StaticCodeGenerator, modname="foo")
7948
7949 def test_compile_checked_dict_ann_differs_2(self):
7950 codestr = """
7951 from __static__ import CheckedDict
7952
7953 def testfunc():
7954 x: int = CheckedDict[str, str]({'abc':'abc'})
7955 return x
7956 """
7957 with self.assertRaisesRegex(
7958 TypedSyntaxError,
7959 type_mismatch("Exact[chkdict[str, str]]", "int"),
7960 ):
7961 self.compile(codestr, StaticCodeGenerator, modname="foo")
7962
7963 def test_compile_checked_dict_opt_out(self):
7964 codestr = """
7965 from __static__.compiler_flags import nonchecked_dicts
7966 class B: pass
7967 class D(B): pass
7968
7969 def testfunc():
7970 x = {B():42, D():42}
7971 return x
7972 """
7973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7974 test = mod["testfunc"]
7975 B = mod["B"]
7976 self.assertEqual(type(test()), dict)
7977
7978 def test_compile_checked_dict_explicit_dict(self):
7979 codestr = """
7980 from __static__ import pydict
7981 class B: pass
7982 class D(B): pass
7983
7984 def testfunc():
7985 x: pydict = {B():42, D():42}
7986 return x
7987 """
7988 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
7989 test = mod["testfunc"]
7990 self.assertEqual(type(test()), dict)
7991
7992 def test_compile_checked_dict_reversed(self):
7993 codestr = """
7994 from __static__ import CheckedDict
7995
7996 class B: pass
7997 class D(B): pass
7998
7999 def testfunc():
8000 x = CheckedDict[B, int]({D():42, B():42})
8001 return x
8002 """
8003 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8004 test = mod["testfunc"]
8005 B = mod["B"]
8006 self.assertEqual(type(test()), chkdict[B, int])
8007
8008 def test_compile_checked_dict_type_specified(self):
8009 codestr = """
8010 from __static__ import CheckedDict
8011
8012 class B: pass
8013 class D(B): pass
8014
8015 def testfunc():
8016 x: CheckedDict[B, int] = CheckedDict[B, int]({D():42})
8017 return x
8018 """
8019 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8020 test = mod["testfunc"]
8021 B = mod["B"]
8022 self.assertEqual(type(test()), chkdict[B, int])
8023
8024 def test_compile_checked_dict_with_annotation_comprehension(self):
8025 codestr = """
8026 from __static__ import CheckedDict
8027
8028 def testfunc():
8029 x: CheckedDict[int, object] = {int(i): object() for i in range(1, 5)}
8030 return x
8031 """
8032 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8033 test = mod["testfunc"]
8034 self.assertEqual(type(test()), chkdict[int, object])
8035
8036 def test_compile_checked_dict_with_annotation(self):
8037 codestr = """
8038 from __static__ import CheckedDict
8039
8040 class B: pass
8041
8042 def testfunc():
8043 x: CheckedDict[B, int] = {B():42}
8044 return x
8045 """
8046 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8047 test = mod["testfunc"]
8048 B = mod["B"]
8049 self.assertEqual(type(test()), chkdict[B, int])
8050
8051 def test_compile_checked_dict_with_annotation_wrong_value_type(self):
8052 codestr = """
8053 from __static__ import CheckedDict
8054
8055 class B: pass
8056
8057 def testfunc():
8058 x: CheckedDict[B, int] = {B():'hi'}
8059 return x
8060 """
8061 with self.assertRaisesRegex(
8062 TypedSyntaxError,
8063 type_mismatch(
8064 "Exact[chkdict[foo.B, Exact[str]]]",
8065 "chkdict[foo.B, int]",
8066 ),
8067 ):
8068 self.compile(codestr, modname="foo")
8069
8070 def test_compile_checked_dict_with_annotation_wrong_key_type(self):
8071 codestr = """
8072 from __static__ import CheckedDict
8073
8074 class B: pass
8075
8076 def testfunc():
8077 x: CheckedDict[B, int] = {object():42}
8078 return x
8079 """
8080 with self.assertRaisesRegex(
8081 TypedSyntaxError,
8082 type_mismatch(
8083 "Exact[chkdict[object, Exact[int]]]",
8084 "chkdict[foo.B, int]",
8085 ),
8086 ):
8087 self.compile(codestr, modname="foo")
8088
8089 def test_compile_checked_dict_wrong_unknown_type(self):
8090 codestr = """
8091 def f(x: int):
8092 return x
8093
8094 def testfunc(iter):
8095 return f({x:42 for x in iter})
8096
8097 """
8098 with self.assertRaisesRegex(
8099 TypedSyntaxError, "positional argument type mismatch"
8100 ):
8101 self.compile(codestr, StaticCodeGenerator, modname="foo")
8102
8103 def test_compile_checked_dict_opt_out_dict_call(self):
8104 codestr = """
8105 from __static__.compiler_flags import nonchecked_dicts
8106
8107 def testfunc():
8108 x = dict(x=42)
8109 return x
8110 """
8111 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8112 test = mod["testfunc"]
8113 self.assertEqual(type(test()), dict)
8114
8115 def test_compile_checked_dict_explicit_dict_as_dict(self):
8116 codestr = """
8117 from __static__ import pydict as dict
8118 class B: pass
8119 class D(B): pass
8120
8121 def testfunc():
8122 x: dict = {B():42, D():42}
8123 return x
8124 """
8125 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8126 test = mod["testfunc"]
8127 self.assertEqual(type(test()), dict)
8128
8129 def test_compile_checked_dict_from_dict_call(self):
8130 codestr = """
8131 def testfunc():
8132 x = dict(x=42)
8133 return x
8134 """
8135 with self.assertRaisesRegex(
8136 TypeError, "cannot create 'dict\\[K, V\\]' instances"
8137 ):
8138 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8139 test = mod["testfunc"]
8140 test()
8141
8142 def test_compile_checked_dict_from_dict_call_2(self):
8143 codestr = """
8144 def testfunc():
8145 x = dict[str, int](x=42)
8146 return x
8147 """
8148 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8149 test = mod["testfunc"]
8150 self.assertEqual(type(test()), chkdict[str, int])
8151
8152 def test_compile_checked_dict_from_dict_call_3(self):
8153 # we emit the chkdict import first before future annotations, but that
8154 # should be fine as we're the compiler.
8155 codestr = """
8156 from __future__ import annotations
8157
8158 def testfunc():
8159 x = dict[str, int](x=42)
8160 return x
8161 """
8162 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8163 test = mod["testfunc"]
8164 self.assertEqual(type(test()), chkdict[str, int])
8165
8166 def test_patch_function(self):
8167 codestr = """
8168 def f():
8169 return 42
8170
8171 def g():
8172 return f()
8173 """
8174 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8175 g = mod["g"]
8176 for i in range(100):
8177 g()
8178 with patch(f"{mod['__name__']}.f", autospec=True, return_value=100) as p:
8179 self.assertEqual(g(), 100)
8180
8181 def test_patch_async_function(self):
8182 codestr = """
8183 class C:
8184 async def f(self) -> int:
8185 return 42
8186
8187 def g(self):
8188 return self.f()
8189 """
8190 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8191 C = mod["C"]
8192 c = C()
8193 for i in range(100):
8194 try:
8195 c.g().send(None)
8196 except StopIteration as e:
8197 self.assertEqual(e.args[0], 42)
8198
8199 with patch(f"{mod['__name__']}.C.f", autospec=True, return_value=100) as p:
8200 try:
8201 c.g().send(None)
8202 except StopIteration as e:
8203 self.assertEqual(e.args[0], 100)
8204
8205 def test_patch_parentclass_slot(self):
8206 codestr = """
8207 class A:
8208 def f(self) -> int:
8209 return 3
8210
8211 class B(A):
8212 pass
8213
8214 def a_f_invoker() -> int:
8215 return A().f()
8216
8217 def b_f_invoker() -> int:
8218 return B().f()
8219 """
8220 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8221 A = mod["A"]
8222 a_f_invoker = mod["a_f_invoker"]
8223 b_f_invoker = mod["b_f_invoker"]
8224 setattr(A, "f", lambda _: 7)
8225
8226 self.assertEqual(a_f_invoker(), 7)
8227 self.assertEqual(b_f_invoker(), 7)
8228
    def test_self_patching_function(self):
        # When called with a truthy argument, f rebinds the module global `f` to
        # `x` (via removeit). g must then call the new binding on its next call
        # rather than the original f.
        codestr = """
            def x(d, d2=1): pass
            def removeit(d):
                global f
                f = x

            def f(d):
                if d:
                    removeit(d)
                return 42

            def g(d):
                return f(d)
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            g = mod["g"]
            f = mod["f"]
            import weakref

            # The callback runs when the original f is collected and checks
            # that this only happens while i == -1, i.e. during the g(True)
            # call that swaps f out. `wr` must stay referenced, otherwise the
            # weakref (and its callback) would be discarded first.
            # NOTE(review): exceptions raised inside weakref callbacks are
            # reported via sys.unraisablehook rather than propagated, so this
            # check may not fail the test by itself; confirm.
            wr = weakref.ref(f, lambda *args: self.assertEqual(i, -1))
            # Drop this frame's strong reference so the original f can be freed
            # once the module rebinds the name.
            del f
            for i in range(100):
                g(False)
            i = -1
            # This call runs the original f, which replaces itself with x.
            self.assertEqual(g(True), 42)
            i = 0
            # g now reaches x, whose body is `pass`, so it returns None.
            self.assertEqual(g(True), None)
8257
8258 def test_patch_function_unwatchable_dict(self):
8259 codestr = """
8260 def f():
8261 return 42
8262
8263 def g():
8264 return f()
8265 """
8266 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8267 g = mod["g"]
8268 for i in range(100):
8269 g()
8270 with patch(
8271 f"{mod['__name__']}.f",
8272 autospec=True,
8273 return_value=100,
8274 ) as p:
8275 mod[42] = 1
8276 self.assertEqual(g(), 100)
8277
8278 def test_patch_function_deleted_func(self):
8279 codestr = """
8280 def f():
8281 return 42
8282
8283 def g():
8284 return f()
8285 """
8286 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8287 g = mod["g"]
8288 for i in range(100):
8289 g()
8290 del mod["f"]
8291 with self.assertRaisesRegex(
8292 TypeError,
8293 re.escape(f"unknown function ('{mod['__name__']}', 'f')"),
8294 ):
8295 g()
8296
8297 def test_patch_static_function(self):
8298 codestr = """
8299 class C:
8300 @staticmethod
8301 def f():
8302 return 42
8303
8304 def g():
8305 return C.f()
8306 """
8307 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8308 g = mod["g"]
8309 for i in range(100):
8310 self.assertEqual(g(), 42)
8311 with patch(f"{mod['__name__']}.C.f", autospec=True, return_value=100) as p:
8312 self.assertEqual(g(), 100)
8313
8314 def test_patch_static_function_non_autospec(self):
8315 codestr = """
8316 class C:
8317 @staticmethod
8318 def f():
8319 return 42
8320
8321 def g():
8322 return C.f()
8323 """
8324 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8325 g = mod["g"]
8326 for i in range(100):
8327 g()
8328 with patch(f"{mod['__name__']}.C.f", return_value=100) as p:
8329 self.assertEqual(g(), 100)
8330
8331 def test_patch_primitive_ret_type(self):
8332 for type_name, value, patched in [
8333 ("cbool", True, False),
8334 ("cbool", False, True),
8335 ("int8", 0, 1),
8336 ("int16", 0, 1),
8337 ("int32", 0, 1),
8338 ("int64", 0, 1),
8339 ("uint8", 0, 1),
8340 ("uint16", 0, 1),
8341 ("uint32", 0, 1),
8342 ("uint64", 0, 1),
8343 ]:
8344 with self.subTest(type_name=type, value=value, patched=patched):
8345 codestr = f"""
8346 from __static__ import {type_name}, box
8347 class C:
8348 def f(self) -> {type_name}:
8349 return {value!r}
8350
8351 def g():
8352 return box(C().f())
8353 """
8354 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8355 g = mod["g"]
8356 for i in range(100):
8357 self.assertEqual(g(), value)
8358 with patch(f"{mod['__name__']}.C.f", return_value=patched) as p:
8359 self.assertEqual(g(), patched)
8360
8361 def test_patch_primitive_ret_type_overflow(self):
8362 codestr = f"""
8363 from __static__ import int8, box
8364 class C:
8365 def f(self) -> int8:
8366 return 1
8367
8368 def g():
8369 return box(C().f())
8370 """
8371 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8372 g = mod["g"]
8373 for i in range(100):
8374 self.assertEqual(g(), 1)
8375 with patch(f"{mod['__name__']}.C.f", return_value=256) as p:
8376 with self.assertRaisesRegex(
8377 OverflowError,
8378 "unexpected return type from C.f, expected "
8379 "int8, got out-of-range int \\(256\\)",
8380 ):
8381 g()
8382
8383 def test_invoke_frozen_type(self):
8384 codestr = """
8385 from cinder import freeze_type
8386
8387 @freeze_type
8388 class C:
8389 @staticmethod
8390 def f():
8391 return 42
8392
8393 def g():
8394 return C.f()
8395 """
8396 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8397 g = mod["g"]
8398 for i in range(100):
8399 self.assertEqual(g(), 42)
8400
8401 def test_invoke_strict_module(self):
8402 codestr = """
8403 def f():
8404 return 42
8405
8406 def g():
8407 return f()
8408 """
8409 with self.in_strict_module(codestr) as mod:
8410 g = mod.g
8411 for i in range(100):
8412 self.assertEqual(g(), 42)
8413 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0))
8414
8415 def test_invoke_with_cell(self):
8416 codestr = """
8417 def f(l: list):
8418 x = 2
8419 return [x + y for y in l]
8420
8421 def g():
8422 return f([1,2,3])
8423 """
8424 with self.in_strict_module(codestr) as mod:
8425 g = mod.g
8426 self.assertEqual(g(), [3, 4, 5])
8427 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 1))
8428
8429 def test_invoke_with_cell_arg(self):
8430 codestr = """
8431 def f(l: list, x: int):
8432 return [x + y for y in l]
8433
8434 def g():
8435 return f([1,2,3], 2)
8436 """
8437 with self.in_strict_module(codestr) as mod:
8438 g = mod.g
8439 self.assertEqual(g(), [3, 4, 5])
8440 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 2))
8441
8442 def test_invoke_all_reg_args(self):
8443 codestr = """
8444 def target(a, b, c, d, e, f):
8445 return a * 2 + b * 3 + c * 4 + d * 5 + e * 6 + f * 7
8446
8447 def testfunc():
8448 return target(1,2,3,4,5,6)
8449 """
8450 with self.in_strict_module(codestr) as mod:
8451 f = mod.testfunc
8452 self.assertInBytecode(
8453 f,
8454 "INVOKE_FUNCTION",
8455 ((mod.__name__, "target"), 6),
8456 )
8457 self.assertEqual(f(), 112)
8458
8459 def test_invoke_all_extra_args(self):
8460 codestr = """
8461 def target(a, b, c, d, e, f, g):
8462 return a * 2 + b * 3 + c * 4 + d * 5 + e * 6 + f * 7 + g
8463
8464 def testfunc():
8465 return target(1,2,3,4,5,6,7)
8466 """
8467 with self.in_strict_module(codestr) as mod:
8468 f = mod.testfunc
8469 self.assertInBytecode(
8470 f,
8471 "INVOKE_FUNCTION",
8472 ((mod.__name__, "target"), 7),
8473 )
8474 self.assertEqual(f(), 119)
8475
8476 def test_invoke_strict_module_deep(self):
8477 codestr = """
8478 def f0(): return 42
8479 def f1(): return f0()
8480 def f2(): return f1()
8481 def f3(): return f2()
8482 def f4(): return f3()
8483 def f5(): return f4()
8484 def f6(): return f5()
8485 def f7(): return f6()
8486 def f8(): return f7()
8487 def f9(): return f8()
8488 def f10(): return f9()
8489 def f11(): return f10()
8490
8491 def g():
8492 return f11()
8493 """
8494 with self.in_strict_module(codestr) as mod:
8495 g = mod.g
8496 self.assertEqual(g(), 42)
8497 self.assertEqual(g(), 42)
8498 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f11"), 0))
8499
8500 def test_invoke_strict_module_deep_unjitable(self):
8501 codestr = """
8502 def f0(): return 42
8503 def f1():
8504 from sys import *
8505 return f0()
8506 def f2(): return f1()
8507 def f3(): return f2()
8508 def f4(): return f3()
8509 def f5(): return f4()
8510 def f6(): return f5()
8511 def f7(): return f6()
8512 def f8(): return f7()
8513 def f9(): return f8()
8514 def f10(): return f9()
8515 def f11(): return f10()
8516
8517 def g(x):
8518 if x: return 0
8519
8520 return f11()
8521 """
8522 with self.in_strict_module(codestr) as mod:
8523 g = mod.g
8524 self.assertEqual(g(True), 0)
8525 # we should have done some level of pre-jitting
8526 self.assert_not_jitted(mod.f2)
8527 self.assert_not_jitted(mod.f1)
8528 self.assert_not_jitted(mod.f0)
8529 [self.assert_jitted(getattr(mod, f"f{i}")) for i in range(3, 12)]
8530 self.assertEqual(g(False), 42)
8531 self.assertInBytecode(
8532 g,
8533 "INVOKE_FUNCTION",
8534 ((mod.__name__, "f11"), 0),
8535 )
8536
8537 def test_invoke_strict_module_deep_unjitable_many_args(self):
8538 codestr = """
8539 def f0(): return 42
8540 def f1(a, b, c, d, e, f, g, h):
8541 from sys import *
8542 return f0() - a + b - c + d - e + f - g + h - 4
8543
8544 def f2(): return f1(1,2,3,4,5,6,7,8)
8545 def f3(): return f2()
8546 def f4(): return f3()
8547 def f5(): return f4()
8548 def f6(): return f5()
8549 def f7(): return f6()
8550 def f8(): return f7()
8551 def f9(): return f8()
8552 def f10(): return f9()
8553 def f11(): return f10()
8554
8555 def g():
8556 return f11()
8557 """
8558 with self.in_strict_module(codestr) as mod:
8559 g = mod.g
8560 f1 = mod.f1
8561 self.assertEqual(g(), 42)
8562 self.assertEqual(g(), 42)
8563 self.assertInBytecode(
8564 g,
8565 "INVOKE_FUNCTION",
8566 ((mod.__name__, "f11"), 0),
8567 )
8568 self.assert_not_jitted(f1)
8569
8570 def test_invoke_strict_module_recursive(self):
8571 codestr = """
8572 def fib(number):
8573 if number <= 1:
8574 return number
8575 return(fib(number-1) + fib(number-2))
8576 """
8577 with self.in_strict_module(codestr) as mod:
8578 fib = mod.fib
8579 self.assertInBytecode(
8580 fib,
8581 "INVOKE_FUNCTION",
8582 ((mod.__name__, "fib"), 1),
8583 )
8584 self.assertEqual(fib(4), 3)
8585
8586 def test_invoke_strict_module_mutual_recursive(self):
8587 codestr = """
8588 def fib1(number):
8589 if number <= 1:
8590 return number
8591 return(fib(number-1) + fib(number-2))
8592
8593 def fib(number):
8594 if number <= 1:
8595 return number
8596 return(fib1(number-1) + fib1(number-2))
8597 """
8598 with self.in_strict_module(codestr) as mod:
8599 fib = mod.fib
8600 fib1 = mod.fib1
8601 self.assertInBytecode(
8602 fib,
8603 "INVOKE_FUNCTION",
8604 ((mod.__name__, "fib1"), 1),
8605 )
8606 self.assertInBytecode(
8607 fib1,
8608 "INVOKE_FUNCTION",
8609 ((mod.__name__, "fib"), 1),
8610 )
8611 self.assertEqual(fib(0), 0)
8612 self.assert_jitted(fib1)
8613 self.assertEqual(fib(4), 3)
8614
8615 def test_invoke_strict_module_pre_invoked(self):
8616 codestr = """
8617 def f():
8618 return 42
8619
8620 def g():
8621 return f()
8622 """
8623 with self.in_strict_module(codestr) as mod:
8624 self.assertEqual(mod.f(), 42)
8625 self.assert_jitted(mod.f)
8626 g = mod.g
8627 self.assertEqual(g(), 42)
8628 self.assertInBytecode(
8629 g,
8630 "INVOKE_FUNCTION",
8631 ((mod.__name__, "f"), 0),
8632 )
8633
8634 def test_invoke_strict_module_patching(self):
8635 codestr = """
8636 def f():
8637 return 42
8638
8639 def g():
8640 return f()
8641 """
8642 with self.in_strict_module(codestr, enable_patching=True) as mod:
8643 g = mod.g
8644 for i in range(100):
8645 self.assertEqual(g(), 42)
8646 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0))
8647 mod.patch("f", lambda: 100)
8648 self.assertEqual(g(), 100)
8649
8650 def test_invoke_patch_non_vectorcall(self):
8651 codestr = """
8652 def f():
8653 return 42
8654
8655 def g():
8656 return f()
8657 """
8658 with self.in_strict_module(codestr, enable_patching=True) as mod:
8659 g = mod.g
8660 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0))
8661 self.assertEqual(g(), 42)
8662 mod.patch("f", Mock(return_value=100))
8663 self.assertEqual(g(), 100)
8664
8665 def test_patch_method(self):
8666 codestr = """
8667 class C:
8668 def f(self):
8669 pass
8670
8671 def g():
8672 return C().f()
8673 """
8674 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8675 g = mod["g"]
8676 C = mod["C"]
8677 orig = C.f
8678 C.f = lambda *args: args
8679 for i in range(100):
8680 v = g()
8681 self.assertEqual(type(v), tuple)
8682 self.assertEqual(type(v[0]), C)
8683 C.f = orig
8684 self.assertEqual(g(), None)
8685
8686 def test_patch_method_ret_none_error(self):
8687 codestr = """
8688 class C:
8689 def f(self) -> None:
8690 pass
8691
8692 def g():
8693 return C().f()
8694 """
8695 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8696 g = mod["g"]
8697 C = mod["C"]
8698 C.f = lambda *args: args
8699 with self.assertRaisesRegex(
8700 TypeError,
8701 "unexpected return type from C.f, expected NoneType, got tuple",
8702 ):
8703 v = g()
8704
8705 def test_patch_method_ret_none(self):
8706 codestr = """
8707 class C:
8708 def f(self) -> None:
8709 pass
8710
8711 def g():
8712 return C().f()
8713 """
8714 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8715 g = mod["g"]
8716 C = mod["C"]
8717 C.f = lambda *args: None
8718 self.assertEqual(g(), None)
8719
8720 def test_patch_method_bad_ret(self):
8721 codestr = """
8722 class C:
8723 def f(self) -> int:
8724 return 42
8725
8726 def g():
8727 return C().f()
8728 """
8729 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
8730 g = mod["g"]
8731 C = mod["C"]
8732 C.f = lambda *args: "abc"
8733 with self.assertRaisesRegex(
8734 TypeError, "unexpected return type from C.f, expected int, got str"
8735 ):
8736 v = g()
8737
8738 def test_primitive_args_funcdef(self):
8739 codestr = """
8740 from __static__ import int8, box
8741
8742 def n(val: int8):
8743 return box(val)
8744
8745 def x():
8746 y: int8 = 42
8747 return n(y)
8748 """
8749 with self.in_strict_module(codestr) as mod:
8750 n = mod.n
8751 x = mod.x
8752 self.assertEqual(x(), 42)
8753 self.assertEqual(mod.n(-128), -128)
8754 self.assertEqual(mod.n(127), 127)
8755 with self.assertRaises(OverflowError):
8756 print(mod.n(-129))
8757 with self.assertRaises(OverflowError):
8758 print(mod.n(128))
8759
8760 def test_primitive_args_funcdef_unjitable(self):
8761 codestr = """
8762 from __static__ import int8, box
8763
8764 def n(val: int8):
8765 try:
8766 from sys import *
8767 except:
8768 pass
8769 return box(val)
8770
8771 def x():
8772 y: int8 = 42
8773 return n(y)
8774 """
8775 with self.in_strict_module(codestr) as mod:
8776 n = mod.n
8777 x = mod.x
8778 self.assertEqual(x(), 42)
8779 self.assertEqual(mod.n(-128), -128)
8780 self.assertEqual(mod.n(127), 127)
8781 with self.assertRaises(OverflowError):
8782 print(mod.n(-129))
8783 with self.assertRaises(OverflowError):
8784 print(mod.n(128))
8785
8786 def test_primitive_args_funcdef_too_many_args(self):
8787 codestr = """
8788 from __static__ import int8, box
8789
8790 def n(x: int8):
8791 return box(x)
8792 """
8793 with self.in_strict_module(codestr) as mod:
8794 n = mod.n
8795 with self.assertRaises(TypeError):
8796 print(mod.n(-128, x=2))
8797 with self.assertRaises(TypeError):
8798 print(mod.n(-128, 2))
8799
8800 def test_primitive_args_funcdef_missing_starargs(self):
8801 codestr = """
8802 from __static__ import int8, box
8803
8804 def x(val: int8, *foo):
8805 return box(val), foo
8806 def y(val: int8, **foo):
8807 return box(val), foo
8808 """
8809 with self.in_strict_module(codestr) as mod:
8810 self.assertEqual(mod.x(-128), (-128, ()))
8811 self.assertEqual(mod.y(-128), (-128, {}))
8812
8813 def test_primitive_args_many_args(self):
8814 codestr = """
8815 from __static__ import int8, int16, int32, int64, uint8, uint16, uint32, uint64, box
8816
8817 def x(i8: int8, i16: int16, i32: int32, i64: int64, u8: uint8, u16: uint16, u32: uint32, u64: uint64):
8818 return box(i8), box(i16), box(i32), box(i64), box(u8), box(u16), box(u32), box(u64)
8819
8820 def y():
8821 return x(1,2,3,4,5,6,7,8)
8822 """
8823 with self.in_strict_module(codestr) as mod:
8824 self.assertInBytecode(mod.y, "INVOKE_FUNCTION", ((mod.__name__, "x"), 8))
8825 self.assertEqual(mod.y(), (1, 2, 3, 4, 5, 6, 7, 8))
8826 self.assertEqual(mod.x(1, 2, 3, 4, 5, 6, 7, 8), (1, 2, 3, 4, 5, 6, 7, 8))
8827
8828 def test_primitive_args_sizes(self):
8829 cases = [
8830 ("cbool", True, False),
8831 ("cbool", False, False),
8832 ("int8", (1 << 7), True),
8833 ("int8", (-1 << 7) - 1, True),
8834 ("int8", -1 << 7, False),
8835 ("int8", (1 << 7) - 1, False),
8836 ("int16", (1 << 15), True),
8837 ("int16", (-1 << 15) - 1, True),
8838 ("int16", -1 << 15, False),
8839 ("int16", (1 << 15) - 1, False),
8840 ("int32", (1 << 31), True),
8841 ("int32", (-1 << 31) - 1, True),
8842 ("int32", -1 << 31, False),
8843 ("int32", (1 << 31) - 1, False),
8844 ("int64", (1 << 63), True),
8845 ("int64", (-1 << 63) - 1, True),
8846 ("int64", -1 << 63, False),
8847 ("int64", (1 << 63) - 1, False),
8848 ("uint8", (1 << 8), True),
8849 ("uint8", -1, True),
8850 ("uint8", (1 << 8) - 1, False),
8851 ("uint8", 0, False),
8852 ("uint16", (1 << 16), True),
8853 ("uint16", -1, True),
8854 ("uint16", (1 << 16) - 1, False),
8855 ("uint16", 0, False),
8856 ("uint32", (1 << 32), True),
8857 ("uint32", -1, True),
8858 ("uint32", (1 << 32) - 1, False),
8859 ("uint32", 0, False),
8860 ("uint64", (1 << 64), True),
8861 ("uint64", -1, True),
8862 ("uint64", (1 << 64) - 1, False),
8863 ("uint64", 0, False),
8864 ]
8865 for type, val, overflows in cases:
8866 codestr = f"""
8867 from __static__ import {type}, box
8868
8869 def x(val: {type}):
8870 return box(val)
8871 """
8872 with self.subTest(type=type, val=val, overflows=overflows):
8873 with self.in_strict_module(codestr) as mod:
8874 if overflows:
8875 with self.assertRaises(OverflowError):
8876 mod.x(val)
8877 else:
8878 self.assertEqual(mod.x(val), val)
8879
    def test_primitive_args_funcdef_missing_kw_call(self):
        """A primitive-arg function can take its non-primitive arg by keyword."""
        codestr = """
            from __static__ import int8, box

            def testfunc(x: int8, foo):
                return box(x), foo
        """
        with self.in_strict_module(codestr) as mod:
            self.assertEqual(mod.testfunc(-128, foo=42), (-128, 42))
8889
    def test_primitive_args_funccall(self):
        """Passing a primitive to an unannotated (dynamic) parameter is a compile error."""
        codestr = """
            from __static__ import int8

            def f(foo):
                pass

            def n() -> int:
                x: int8 = 3
                return f(x)
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "type mismatch: int8 positional argument type mismatch dynamic",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo.py")
8906
    def test_primitive_args_funccall_int(self):
        """Passing an int8 primitive to an `int` parameter is a compile error (no implicit boxing)."""
        codestr = """
            from __static__ import int8

            def f(foo: int):
                pass

            def n() -> int:
                x: int8 = 3
                return f(x)
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "type mismatch: int8 positional argument type mismatch int",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo.py")
8923
    def test_primitive_args_typecall(self):
        """Passing a primitive positionally to a type call (`int(x)`) is a compile error."""
        codestr = """
            from __static__ import int8

            def n() -> int:
                x: int8 = 3
                return int(x)
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Call argument cannot be a primitive"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo.py")
8936
    def test_primitive_args_typecall_kwarg(self):
        """Passing a primitive as a keyword arg to a type call (`dict(a=x)`) is a compile error."""
        codestr = """
            from __static__ import int8

            def n() -> int:
                x: int8 = 3
                return dict(a=x)
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Call argument cannot be a primitive"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo.py")
8949
    def test_primitive_args_nonstrict(self):
        """Primitive args/returns also work in a non-strict module loaded via in_module."""
        codestr = """
            from __static__ import int8, int16, box

            def f(x: int8, y: int16) -> int16:
                return x + y

            def g() -> int:
                return box(f(1, 300))
        """
        with self.in_module(codestr) as mod:
            # 300 does not fit in int8, so y must really be passed as int16.
            self.assertEqual(mod["g"](), 301)
8962
8963 def test_primitive_args_and_return(self):
8964 cases = [
8965 ("cbool", 1),
8966 ("cbool", 0),
8967 ("int8", -1 << 7),
8968 ("int8", (1 << 7) - 1),
8969 ("int16", -1 << 15),
8970 ("int16", (1 << 15) - 1),
8971 ("int32", -1 << 31),
8972 ("int32", (1 << 31) - 1),
8973 ("int64", -1 << 63),
8974 ("int64", (1 << 63) - 1),
8975 ("uint8", (1 << 8) - 1),
8976 ("uint8", 0),
8977 ("uint16", (1 << 16) - 1),
8978 ("uint16", 0),
8979 ("uint32", (1 << 32) - 1),
8980 ("uint32", 0),
8981 ("uint64", (1 << 64) - 1),
8982 ("uint64", 0),
8983 ]
8984 for typ, val in cases:
8985 if typ == "cbool":
8986 op = "or"
8987 expected = True
8988 other = "cbool(True)"
8989 boxed = "bool"
8990 else:
8991 op = "+" if val <= 0 else "-"
8992 expected = val + (1 if op == "+" else -1)
8993 other = "1"
8994 boxed = "int"
8995 with self.subTest(typ=typ, val=val, op=op, expected=expected):
8996 codestr = f"""
8997 from __static__ import {typ}, box
8998
8999 def f(x: {typ}, y: {typ}) -> {typ}:
9000 return x {op} y
9001
9002 def g() -> {boxed}:
9003 return box(f({val}, {other}))
9004 """
9005 with self.in_strict_module(codestr) as mod:
9006 self.assertEqual(mod.g(), expected)
9007
9008 def test_primitive_return(self):
9009 cases = [
9010 ("cbool", True),
9011 ("cbool", False),
9012 ("int8", -1 << 7),
9013 ("int8", (1 << 7) - 1),
9014 ("int16", -1 << 15),
9015 ("int16", (1 << 15) - 1),
9016 ("int32", -1 << 31),
9017 ("int32", (1 << 31) - 1),
9018 ("int64", -1 << 63),
9019 ("int64", (1 << 63) - 1),
9020 ("uint8", (1 << 8) - 1),
9021 ("uint8", 0),
9022 ("uint16", (1 << 16) - 1),
9023 ("uint16", 0),
9024 ("uint32", (1 << 32) - 1),
9025 ("uint32", 0),
9026 ("uint64", (1 << 64) - 1),
9027 ("uint64", 0),
9028 ]
9029 tf = [True, False]
9030 for (type, val), box, strict, error, unjitable in itertools.product(
9031 cases, [False], [True], [False], [True]
9032 ):
9033 if type == "cbool":
9034 op = "or"
9035 other = "False"
9036 boxed = "bool"
9037 else:
9038 op = "*"
9039 other = "1"
9040 boxed = "int"
9041 unjitable_code = "from sys import *" if unjitable else ""
9042 codestr = f"""
9043 from __static__ import {type}, box
9044
9045 def f(error: bool) -> {type}:
9046 {unjitable_code}
9047 if error:
9048 raise RuntimeError("boom")
9049 return {val}
9050 """
9051 if box:
9052 codestr += f"""
9053
9054 def g() -> {boxed}:
9055 return box(f({error}) {op} {type}({other}))
9056 """
9057 else:
9058 codestr += f"""
9059
9060 def g() -> {type}:
9061 return f({error}) {op} {type}({other})
9062 """
9063 ctx = self.in_strict_module if strict else self.in_module
9064 oparg = PRIM_NAME_TO_TYPE[type]
9065 with self.subTest(
9066 type=type,
9067 val=val,
9068 strict=strict,
9069 box=box,
9070 error=error,
9071 unjitable=unjitable,
9072 ):
9073 with ctx(codestr) as mod:
9074 f = mod.f if strict else mod["f"]
9075 g = mod.g if strict else mod["g"]
9076 self.assertInBytecode(f, "RETURN_INT", oparg)
9077 if box:
9078 self.assertNotInBytecode(g, "RETURN_INT")
9079 else:
9080 self.assertInBytecode(g, "RETURN_INT", oparg)
9081 if error:
9082 with self.assertRaisesRegex(RuntimeError, "boom"):
9083 g()
9084 else:
9085 self.assertEqual(g(), val)
9086 self.assert_jitted(g)
9087 if unjitable:
9088 self.assert_not_jitted(f)
9089 else:
9090 self.assert_jitted(f)
9091
    def test_primitive_return_recursive(self):
        """A recursive function with int32 args/return invokes itself directly and is jitted."""
        codestr = """
            from __static__ import int32

            def fib(n: int32) -> int32:
                if n <= 1:
                    return n
                return fib(n-1) + fib(n-2)
        """
        with self.in_strict_module(codestr) as mod:
            # The self-call compiles to a direct INVOKE_FUNCTION with one arg.
            self.assertInBytecode(
                mod.fib,
                "INVOKE_FUNCTION",
                ((mod.__name__, "fib"), 1),
            )
            self.assertEqual(mod.fib(2), 1)
            self.assert_jitted(mod.fib)
9109
    def test_primitive_return_unannotated(self):
        """Returning a primitive from a function with no return annotation is a compile error."""
        codestr = """
            from __static__ import int32

            def f():
                x: int32 = 1
                return x
        """
        # The unannotated return type is dynamic, which a primitive cannot flow into.
        with self.assertRaisesRegex(
            TypedSyntaxError, "type mismatch: int32 cannot be assigned to dynamic"
        ):
            self.compile(codestr)
9122
    def test_module_level_final_decl(self):
        """A module-level `Final` declaration without a value is rejected."""
        codestr = """
            from typing import Final

            x: Final
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Must assign a value when declaring a Final"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9133
9134 def test_int_compare_to_cbool(self):
9135 codestr = """
9136 from __static__ import int64, cbool
9137 def foo(i: int64) -> cbool:
9138 return i == 0
9139 """
9140 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
9141 foo = mod["foo"]
9142 self.assertEqual(foo(0), True)
9143 self.assertEqual(foo(1), False)
9144
9145 def test_int_compare_to_cbool_reversed(self):
9146 codestr = """
9147 from __static__ import int64, cbool
9148 def foo(i: int64) -> cbool:
9149 return 0 == i
9150 """
9151 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
9152 foo = mod["foo"]
9153 self.assertEqual(foo(0), True)
9154 self.assertEqual(foo(1), False)
9155
9156 def test_inline_primitive(self):
9157 codestr = """
9158 from __static__ import int64, cbool, inline
9159
9160 @inline
9161 def x(i: int64) -> cbool:
9162 return i == 1
9163
9164 def foo(i: int64) -> cbool:
9165 return i >0 and x(i)
9166 """
9167 with self.in_module(codestr, optimize=2) as mod:
9168 foo = mod["foo"]
9169 self.assertEqual(foo(0), False)
9170 self.assertEqual(foo(1), True)
9171 self.assertEqual(foo(2), False)
9172 self.assertNotInBytecode(foo, "STORE_FAST")
9173 self.assertInBytecode(foo, "STORE_LOCAL")
9174
    def test_final_multiple_typeargs(self):
        """`Final[...]` with more than one type argument is rejected."""
        codestr = """
            from typing import Final
            from something import hello

            x: Final[int, str] = hello()
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Final types can only have a single type arg"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9186
    def test_final_annotation_nesting(self):
        """`Final` nested inside another annotation (subscript or union) is rejected."""
        # Final as a generic type argument.
        with self.assertRaisesRegex(
            TypedSyntaxError, "Final annotation is only valid in initial declaration"
        ):
            self.compile(
                """
                from typing import Final, List

                x: List[Final[str]] = []
                """,
                StaticCodeGenerator,
                modname="foo",
            )

        # Final as a member of a union inside a generic.
        with self.assertRaisesRegex(
            TypedSyntaxError, "Final annotation is only valid in initial declaration"
        ):
            self.compile(
                """
                from typing import Final, List
                x: List[int | Final] = []
                """,
                StaticCodeGenerator,
                modname="foo",
            )
9212
    def test_final(self):
        """A bare `Final` with an initial value compiles."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9220
    def test_final_generic(self):
        """`Final[int]` with an initial value compiles."""
        codestr = """
            from typing import Final

            x: Final[int] = 0xdeadbeef
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9228
    def test_final_generic_types(self):
        """A local `Final[int]` can be passed where an `int` is expected."""
        codestr = """
            from typing import Final

            def g(i: int) -> int:
                return i

            def f() -> int:
                x: Final[int] = 0xdeadbeef
                return g(x)
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9241
    def test_final_uninitialized(self):
        """Declaring a module-level `Final` without a value is rejected."""
        codestr = """
            from typing import Final

            x: Final
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Must assign a value when declaring a Final"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9252
    def test_final_reassign(self):
        """Plain reassignment of a module-level `Final` name is rejected."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef
            x = "something"
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9264
    def test_final_reassign_explicit_global(self):
        """Assigning a `Final` through a `global` declaration in a nested function is rejected."""
        codestr = """
            from typing import Final

            a: Final = 1337

            def fn():
                def fn2():
                    global a
                    a = 0
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9280
    def test_final_reassign_explicit_global_shadowed(self):
        """`global a` still targets the module `Final` even when an enclosing local `a` exists."""
        codestr = """
            from typing import Final

            a: Final = 1337

            def fn():
                a = 2
                def fn2():
                    global a
                    a = 0
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9297
    def test_final_reassign_nonlocal(self):
        """A `nonlocal` assignment that resolves to the module-level `Final` is rejected."""
        # NOTE(review): there is no enclosing-function binding for `a` here, so
        # this relies on the static compiler reporting the Final error (the
        # expected message) rather than a plain nonlocal SyntaxError.
        codestr = """
            from typing import Final

            a: Final = 1337

            def fn():
                def fn2():
                    nonlocal a
                    a = 0
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9313
    def test_final_reassign_nonlocal_shadowed(self):
        """`nonlocal a` bound to an enclosing local (not the module `Final`) may be assigned."""
        codestr = """
            from typing import Final

            a: Final = 1337

            def fn():
                a = 3
                def fn2():
                    nonlocal a
                    # should be allowed, we're assigning to the shadowed
                    # value
                    a = 0
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9329
    def test_final_reassigned_in_tuple(self):
        """A `Final` name used as a tuple-unpacking target is rejected."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef
            y = 3
            x, y = 4, 5
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9342
    def test_final_reassigned_in_loop(self):
        """A `Final` name used as a `for` loop target is rejected."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef

            for x in [1, 3, 5]:
                pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9356
    def test_final_reassigned_in_except(self):
        """A local `Final` name reused as an `except ... as` target is rejected."""
        codestr = """
            from typing import Final

            def f():
                e: Final = 3
                try:
                    x = 1 + "2"
                except Exception as e:
                    pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9372
    def test_final_reassigned_in_loop_target_tuple(self):
        """A `Final` name inside a tuple `for` loop target is rejected."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef

            for x, y in [(1, 2)]:
                pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9386
    def test_final_reassigned_in_ctxmgr(self):
        """A `Final` name used as a `with ... as` target is rejected."""
        codestr = """
            from typing import Final

            x: Final = 0xdeadbeef

            with open("lol") as x:
                pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9400
    def test_final_generic_reassign(self):
        """Reassigning a `Final[int]`, even with a same-typed value, is rejected."""
        codestr = """
            from typing import Final

            x: Final[int] = 0xdeadbeef
            x = 0x5ca1ab1e
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9412
    def test_class_level_final_decl(self):
        """A class-level `Final[int]` declared without a value may be initialized in `__init__`."""
        codestr = """
            from typing import Final

            class C:
                x: Final[int]

                def __init__(self):
                    self.x = 3
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9424
    def test_class_level_final_decl_uninitialized(self):
        """A class-level `Final` that is never initialized anywhere is rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final
        """
        # The expected message contains regex metacharacters ('.'), hence re.escape.
        with self.assertRaisesRegex(
            TypedSyntaxError,
            re.escape("Final attribute not initialized: foo.C:x"),
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9437
    def test_class_level_final_reinitialized(self):
        """Reassigning a class-level `Final` within the class body is rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3
                x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final variable"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9450
    def test_class_level_final_reinitialized_directly(self):
        """Assigning to a class-level `Final` through the class object is rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

            C.x = 4
        """
        # Note - this will raise even without the Final, we don't allow assignments to slots
        # (`type_mismatch` is a message-building helper defined elsewhere in this file.)
        with self.assertRaisesRegex(
            TypedSyntaxError,
            type_mismatch("Exact[int]", "types.MemberDescriptorType"),
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9466
    def test_class_level_final_reinitialized_in_instance(self):
        """Assigning to a class-level `Final` through an instance is rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

            C().x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.C:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9481
    def test_class_level_final_reinitialized_in_method(self):
        """Assigning to an already-initialized class-level `Final` via `self` in a method is rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

                def something(self) -> None:
                    self.x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final attribute of foo.C:x"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9496
    def test_class_level_final_reinitialized_in_subclass_without_annotation(self):
        """A subclass may not redefine an inherited `Final` attribute in its body."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

            class D(C):
                x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.D:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9512
    def test_class_level_final_reinitialized_in_subclass_with_annotation(self):
        """Re-declaring an inherited `Final` as `Final` in a subclass is still rejected."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

            class D(C):
                x: Final[int] = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.D:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9528
    def test_class_level_final_reinitialized_in_subclass_init(self):
        """A subclass `__init__` may not assign an inherited `Final` attribute."""
        codestr = """
            from typing import Final

            class C:
                x: Final = 3

            class D(C):
                def __init__(self):
                    self.x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.D:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9545
    def test_class_level_final_reinitialized_in_subclass_init_with_annotation(self):
        """An annotated `self.x: Final[int]` in a subclass `__init__` cannot override an inherited Final."""
        codestr = """
            from typing import Final

            class C:
                x: Final[int] = 3

            class D(C):
                def __init__(self):
                    self.x: Final[int] = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.D:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9562
    def test_class_level_final_decl_in_init_reinitialized_in_subclass_init(self):
        """A `Final` initialized in the base `__init__` may not be reassigned in a subclass `__init__`."""
        codestr = """
            from typing import Final

            class C:
                x: Final[int]

                def __init__(self) -> None:
                    self.x = 3

            class D(C):
                def __init__(self) -> None:
                    self.x = 4
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Cannot assign to a Final attribute of foo.D:x",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9582
    def test_final_in_args(self):
        """`Final` is not allowed as a parameter annotation."""
        codestr = """
            from typing import Final

            def f(a: Final) -> None:
                pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Final annotation is only valid in initial declaration",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9595
    def test_final_returns(self):
        """`Final[...]` is not allowed as a return annotation."""
        codestr = """
            from typing import Final

            def f() -> Final[int]:
                return 1
        """
        with self.assertRaisesRegex(
            TypedSyntaxError,
            "Final annotation is only valid in initial declaration",
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9608
    def test_final_decorator(self):
        """A method decorated with `@final` compiles."""
        codestr = """
            from typing import final

            class C:
                @final
                def f():
                    pass
        """
        self.compile(codestr, StaticCodeGenerator, modname="foo")
9619
    def test_final_decorator_override(self):
        """Overriding a `@final` method with a method in a subclass is rejected."""
        codestr = """
            from typing import final

            class C:
                @final
                def f():
                    pass

            class D(C):
                def f():
                    pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final attribute of foo.D:f"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9637
    def test_final_decorator_override_with_assignment(self):
        """Overriding a `@final` method with a class-body assignment is rejected."""
        codestr = """
            from typing import final

            class C:
                @final
                def f():
                    pass

            class D(C):
                f = print
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final attribute of foo.D:f"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9654
    def test_final_decorator_override_transitivity(self):
        """The `@final` restriction is inherited through intermediate subclasses."""
        codestr = """
            from typing import final

            class C:
                @final
                def f():
                    pass

            class D(C):
                pass

            class E(D):
                def f():
                    pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Cannot assign to a Final attribute of foo.E:f"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9675
    def test_final_decorator_class(self):
        """Calling a method of a `@final` class compiles to a direct INVOKE_FUNCTION."""
        codestr = """
            from typing import final

            @final
            class C:
                def f(self):
                    pass

            def f():
                return C().f()
        """
        c = self.compile(codestr, StaticCodeGenerator, modname="foo")
        # `find_code` presumably locates the module-level function `f`'s code
        # object (not the method) — confirm against the helper.
        f = self.find_code(c, "f")
        self.assertInBytecode(f, "INVOKE_FUNCTION")
9691
    def test_final_decorator_class_inheritance(self):
        """Subclassing a `@final` class is rejected."""
        codestr = """
            from typing import final

            @final
            class C:
                pass

            class D(C):
                pass
        """
        with self.assertRaisesRegex(
            TypedSyntaxError, "Class `foo.D` cannot subclass a Final class: `foo.C`"
        ):
            self.compile(codestr, StaticCodeGenerator, modname="foo")
9707
    def test_final_decorator_class_dynamic(self):
        """We should never mark DYNAMIC_TYPE as final.

        `@final` is applied to a NamedTuple subclass, and another class then
        subclasses `Generic`. Compilation must succeed.
        """
        codestr = """
            from typing import final, Generic, NamedTuple

            @final
            class NT(NamedTuple):
                x: int

            class C(Generic):
                pass
        """
        # No TypedSyntaxError "cannot inherit from Final class 'dynamic'"
        self.compile(codestr)
9722
    def test_slotification_decorated(self):
        """A class behind an arbitrary decorator gets no direct INVOKE_* calls.

        The decorator replaces `C` with an unrelated class. Static calls to
        `C().f()` must therefore not be bound to `C.f`.
        """
        codestr = """
            class _Inner():
                pass

            def something(klass):
                return _Inner

            @something
            class C:
                def f(self):
                    pass

            def f():
                return C().f()
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            f = mod["f"]
            self.assertNotInBytecode(f, "INVOKE_FUNCTION")
            self.assertNotInBytecode(f, "INVOKE_METHOD")
9743
9744 def test_inline_func(self):
9745 codestr = """
9746 from __static__ import inline
9747
9748 @inline
9749 def f(x, y):
9750 return x + y
9751
9752 def g():
9753 return f(1,2)
9754 """
9755 # we only inline at opt level 2 to avoid test patching problems
9756 # TODO longer term we might need something better here (e.g. emit both
9757 # inlined code and call and a guard to choose); assuming
9758 # non-patchability at opt 2 works for IG but isn't generally valid
9759 for opt in [0, 1, 2]:
9760 with self.subTest(opt=opt):
9761 with self.in_module(codestr, optimize=opt) as mod:
9762 g = mod["g"]
9763 if opt == 2:
9764 self.assertInBytecode(g, "LOAD_CONST", 3)
9765 else:
9766 self.assertInBytecode(
9767 g, "INVOKE_FUNCTION", ((mod["__name__"], "f"), 2)
9768 )
9769 self.assertEqual(g(), 3)
9770
    def test_inline_kwarg(self):
        """An `@inline` call using keyword arguments is inlined and constant-folded."""
        codestr = """
            from __static__ import inline

            @inline
            def f(x, y):
                return x + y

            def g():
                return f(x=1,y=2)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "LOAD_CONST", 3)
            self.assertEqual(g(), 3)
9786
    def test_inline_bare_return(self):
        """Inlining a function with a bare `return` yields None."""
        codestr = """
            from __static__ import inline

            @inline
            def f(x, y):
                return

            def g():
                return f(x=1,y=2)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "LOAD_CONST", None)
            self.assertEqual(g(), None)
9802
    def test_inline_final(self):
        """A module `Final` referenced by an inlined function is loaded as a constant."""
        codestr = """
            from __static__ import inline
            from typing import Final

            Y: Final[int] = 42
            @inline
            def f(x):
                return x + Y

            def g():
                return f(1)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            # We don't currently inline math with finals
            # (so 42 appears as a constant rather than a folded 43).
            self.assertInBytecode(g, "LOAD_CONST", 42)
            self.assertEqual(g(), 43)
9821
    def test_inline_nested(self):
        """An inline function that calls another inline function is fully inlined and folded."""
        codestr = """
            from __static__ import inline

            @inline
            def e(x, y):
                return x + y

            @inline
            def f(x, y):
                return e(x, 3)

            def g():
                return f(1,2)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            # f(1, 2) -> e(1, 3) -> 1 + 3.
            self.assertInBytecode(g, "LOAD_CONST", 4)
            self.assertEqual(g(), 4)
9841
    def test_inline_nested_arg(self):
        """Nested inlining with non-constant args emits the add instead of folding it."""
        codestr = """
            from __static__ import inline

            @inline
            def e(x, y):
                return x + y

            @inline
            def f(x, y):
                return e(x, 3)

            def g(a,b):
                return f(a,b)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            # The constant 3 is substituted; `a + 3` is computed at runtime.
            self.assertInBytecode(g, "LOAD_CONST", 3)
            self.assertInBytecode(g, "BINARY_ADD")
            self.assertEqual(g(1, 2), 4)
9862
9863 def test_inline_recursive(self):
9864 codestr = """
9865 from __static__ import inline
9866
9867 @inline
9868 def f(x, y):
9869 return f(x, y)
9870
9871 def g():
9872 return f(1,2)
9873 """
9874 with self.in_module(codestr, optimize=2) as mod:
9875 g = mod["g"]
9876 self.assertInBytecode(g, "INVOKE_FUNCTION", (((mod["__name__"], "f"), 2)))
9877
    def test_inline_func_default(self):
        """Inlining fills in omitted parameters from their default values."""
        codestr = """
            from __static__ import inline

            @inline
            def f(x, y = 2):
                return x + y

            def g():
                return f(1)
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            # 1 + default 2, folded.
            self.assertInBytecode(g, "LOAD_CONST", 3)

            self.assertEqual(g(), 3)
9894
    def test_inline_arg_type(self):
        """Inlining preserves primitive argument typing: `box(x)` still emits PRIMITIVE_BOX."""
        codestr = """
            from __static__ import box, inline, int64, int32

            @inline
            def f(x: int64) -> int:
                return box(x)

            def g(arg: int) -> int:
                return f(int64(arg))
        """
        with self.in_module(codestr, optimize=2) as mod:
            g = mod["g"]
            self.assertInBytecode(g, "PRIMITIVE_BOX")
            self.assertEqual(g(3), 3)
9910
    def test_augassign_primitive_int(self):
        """`+=` on an int8 local compiles to a primitive binary op."""
        codestr = """
            from __static__ import int8, box, unbox

            def a(i: int) -> int:
                j: int8 = unbox(i)
                j += 2
                return box(j)
        """
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            a = mod["a"]
            # NOTE(review): oparg 0 is presumably PRIM_OP_ADD_INT (imported in the
            # header) — consider asserting against the named constant.
            self.assertInBytecode(a, "PRIMITIVE_BINARY_OP", 0)
            self.assertEqual(a(3), 5)
9924
9925 def test_primitive_compare_immediate_no_branch_on_result(self):
9926 for rev in [True, False]:
9927 compare = "0 == xp" if rev else "xp == 0"
9928 codestr = f"""
9929 from __static__ import box, int64, int32
9930
9931 def f(x: int) -> bool:
9932 xp = int64(x)
9933 y = {compare}
9934 return box(y)
9935 """
9936 with self.subTest(rev=rev):
9937 with self.in_module(codestr) as mod:
9938 f = mod["f"]
9939 self.assertEqual(f(3), 0)
9940 self.assertEqual(f(0), 1)
9941 self.assertIs(f(0), True)
9942
    def test_type_type_final(self):
        """Subclassing `type` (defining a metaclass) compiles without error."""
        codestr = """
            class A(type):
                pass
        """
        self.compile(codestr)
9949
9950
9951class StaticRuntimeTests(StaticTestBase):
    def test_bad_slots_qualname_conflict(self):
        """A typed slot whose name collides with a class attribute raises ValueError."""
        with self.assertRaises(ValueError):

            # `x` is both a slot and a class variable.
            class C:
                __slots__ = ("x",)
                __slot_types__ = {"x": ("__static__", "int32")}
                x = 42
9959
    def test_typed_slots_bad_inst(self):
        """A typed slot descriptor raises TypeError when read from an instance of an unrelated class."""
        class C:
            __slots__ = ("a",)
            __slot_types__ = {"a": ("__static__", "int32")}

        class D:
            pass

        with self.assertRaises(TypeError):
            # Bind C's descriptor to a D instance.
            C.a.__get__(D(), D)
9970
    def test_typed_slots_bad_slots(self):
        """A non-dict `__slot_types__` (here None) raises TypeError at class creation."""
        with self.assertRaises(TypeError):

            class C:
                __slots__ = ("a",)
                __slot_types__ = None
9977
    def test_typed_slots_bad_slot_dict(self):
        """Giving the special `__dict__` slot a type raises TypeError."""
        with self.assertRaises(TypeError):

            class C:
                __slots__ = ("__dict__",)
                __slot_types__ = {"__dict__": "object"}
9984
9985 def test_typed_slots_bad_slot_weakerf(self):
9986 with self.assertRaises(TypeError):
9987
9988 class C:
9989 __slots__ = ("__weakref__",)
9990 __slot_types__ = {"__weakref__": "object"}
9991
9992 def test_typed_slots_object(self):
9993 codestr = """
9994 class C:
9995 __slots__ = ('a', )
9996 __slot_types__ = {'a': (__name__, 'C')}
9997
9998 inst = C()
9999 """
10000
10001 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10002 inst, C = mod["inst"], mod["C"]
10003 self.assertEqual(C.a.__class__.__name__, "typed_descriptor")
10004 with self.assertRaises(TypeError):
10005 # type is checked
10006 inst.a = 42
10007 with self.assertRaises(TypeError):
10008 inst.a = None
10009 with self.assertRaises(AttributeError):
10010 # is initially unassigned
10011 inst.a
10012
10013 # can assign correct type
10014 inst.a = inst
10015
10016 # __sizeof__ doesn't include GC header size
10017 self.assertEqual(inst.__sizeof__(), self.base_size + self.ptr_size)
10018 # size is +2 words for GC header, one word for reference
10019 self.assertEqual(sys.getsizeof(inst), self.base_size + (self.ptr_size * 3))
10020
10021 # subclasses are okay
10022 class D(C):
10023 pass
10024
10025 inst.a = D()
10026
    def test_allow_weakrefs(self):
        """@allow_weakrefs lets instances of a static class be weakly referenced.

        Also checks that the decorated class ends up with
        ``__slots__ == ("__weakref__",)``.
        """
        codestr = """
            from __static__ import allow_weakrefs
            import weakref

            @allow_weakrefs
            class C:
                pass

            def f(c: C):
                return weakref.ref(c)
        """
        with self.in_module(codestr) as mod:
            C = mod["C"]
            c = C()
            ref = mod["f"](c)
            self.assertIs(ref(), c)
            # After dropping the local strong reference, the referent is
            # expected to be gone immediately (presumably relies on CPython's
            # refcount-based collection).
            del c
            self.assertIs(ref(), None)
            self.assertEqual(C.__slots__, ("__weakref__",))
10047
10048 def test_dynamic_return(self):
10049 codestr = """
10050 from __future__ import annotations
10051 from __static__ import allow_weakrefs, dynamic_return
10052 import weakref
10053
10054 singletons = []
10055
10056 @allow_weakrefs
10057 class C:
10058 @dynamic_return
10059 @staticmethod
10060 def make() -> C:
10061 return weakref.proxy(singletons[0])
10062
10063 def g(self) -> int:
10064 return 1
10065
10066 singletons.append(C())
10067
10068 def f() -> int:
10069 c = C.make()
10070 return c.g()
10071 """
10072 with self.in_strict_module(codestr) as mod:
10073 # We don't try to cast the return type of make
10074 self.assertNotInBytecode(mod.C.make, "CAST")
10075 # We can statically invoke make
10076 self.assertInBytecode(
10077 mod.f, "INVOKE_FUNCTION", ((mod.__name__, "C", "make"), 0)
10078 )
10079 # But we can't statically invoke a method on the returned instance
10080 self.assertNotInBytecode(mod.f, "INVOKE_METHOD")
10081 self.assertEqual(mod.f(), 1)
10082 # We don't mess with __annotations__
10083 self.assertEqual(mod.C.make.__annotations__, {"return": "C"})
10084
10085 def test_generic_type_def_no_create(self):
10086 from xxclassloader import spamobj
10087
10088 with self.assertRaises(TypeError):
10089 spamobj()
10090
10091 def test_generic_type_def_bad_args(self):
10092 from xxclassloader import spamobj
10093
10094 with self.assertRaises(TypeError):
10095 spamobj[str, int]
10096
10097 def test_generic_type_def_non_type(self):
10098 from xxclassloader import spamobj
10099
10100 with self.assertRaises(TypeError):
10101 spamobj[42]
10102
10103 def test_generic_type_inst_okay(self):
10104 from xxclassloader import spamobj
10105
10106 o = spamobj[str]()
10107 o.setstate("abc")
10108
10109 def test_generic_type_inst_optional_okay(self):
10110 from xxclassloader import spamobj
10111
10112 o = spamobj[Optional[str]]()
10113 o.setstate("abc")
10114 o.setstate(None)
10115
10116 def test_generic_type_inst_non_optional_error(self):
10117 from xxclassloader import spamobj
10118
10119 o = spamobj[str]()
10120 with self.assertRaises(TypeError):
10121 o.setstate(None)
10122
10123 def test_generic_type_inst_bad_type(self):
10124 from xxclassloader import spamobj
10125
10126 o = spamobj[str]()
10127 with self.assertRaises(TypeError):
10128 o.setstate(42)
10129
10130 def test_generic_type_inst_name(self):
10131 from xxclassloader import spamobj
10132
10133 self.assertEqual(spamobj[str].__name__, "spamobj[str]")
10134
10135 def test_generic_type_inst_name_optional(self):
10136 from xxclassloader import spamobj
10137
10138 self.assertEqual(spamobj[Optional[str]].__name__, "spamobj[Optional[str]]")
10139
10140 def test_generic_type_inst_okay_func(self):
10141 from xxclassloader import spamobj
10142
10143 o = spamobj[str]()
10144 f = o.setstate
10145 f("abc")
10146
10147 def test_generic_type_inst_optional_okay_func(self):
10148 from xxclassloader import spamobj
10149
10150 o = spamobj[Optional[str]]()
10151 f = o.setstate
10152 f("abc")
10153 f(None)
10154
10155 def test_generic_type_inst_non_optional_error_func(self):
10156 from xxclassloader import spamobj
10157
10158 o = spamobj[str]()
10159 f = o.setstate
10160 with self.assertRaises(TypeError):
10161 f(None)
10162
10163 def test_generic_type_inst_bad_type_func(self):
10164 from xxclassloader import spamobj
10165
10166 o = spamobj[str]()
10167 f = o.setstate
10168 with self.assertRaises(TypeError):
10169 f(42)
10170
10171 def test_generic_int_funcs(self):
10172 from xxclassloader import spamobj
10173
10174 o = spamobj[str]()
10175 o.setint(42)
10176 self.assertEqual(o.getint8(), 42)
10177 self.assertEqual(o.getint16(), 42)
10178 self.assertEqual(o.getint32(), 42)
10179
10180 def test_generic_uint_funcs(self):
10181 from xxclassloader import spamobj
10182
10183 o = spamobj[str]()
10184 o.setuint64(42)
10185 self.assertEqual(o.getuint8(), 42)
10186 self.assertEqual(o.getuint16(), 42)
10187 self.assertEqual(o.getuint32(), 42)
10188 self.assertEqual(o.getuint64(), 42)
10189
10190 def test_generic_int_funcs_overflow(self):
10191 from xxclassloader import spamobj
10192
10193 o = spamobj[str]()
10194 o.setuint64(42)
10195 for i, f in enumerate([o.setint8, o.setint16, o.setint32, o.setint]):
10196 with self.assertRaises(OverflowError):
10197 x = -(1 << ((8 << i) - 1)) - 1
10198 f(x)
10199 with self.assertRaises(OverflowError):
10200 x = 1 << ((8 << i) - 1)
10201 f(x)
10202
10203 def test_generic_uint_funcs_overflow(self):
10204 from xxclassloader import spamobj
10205
10206 o = spamobj[str]()
10207 o.setuint64(42)
10208 for f in [o.setuint8, o.setuint16, o.setuint32, o.setuint64]:
10209 with self.assertRaises(OverflowError):
10210 f(-1)
10211 for i, f in enumerate([o.setuint8, o.setuint16, o.setuint32, o.setuint64]):
10212 with self.assertRaises(OverflowError):
10213 x = (1 << (8 << i)) + 1
10214 f(x)
10215
10216 def test_generic_type_int_func(self):
10217 from xxclassloader import spamobj
10218
10219 o = spamobj[str]()
10220 o.setint(42)
10221 self.assertEqual(o.getint(), 42)
10222 with self.assertRaises(TypeError):
10223 o.setint("abc")
10224
10225 def test_generic_type_str_func(self):
10226 from xxclassloader import spamobj
10227
10228 o = spamobj[str]()
10229 o.setstr("abc")
10230 self.assertEqual(o.getstr(), "abc")
10231 with self.assertRaises(TypeError):
10232 o.setstr(42)
10233
10234 def test_generic_type_bad_arg_cnt(self):
10235 from xxclassloader import spamobj
10236
10237 o = spamobj[str]()
10238 with self.assertRaises(TypeError):
10239 o.setstr()
10240 with self.assertRaises(TypeError):
10241 o.setstr("abc", "abc")
10242
10243 def test_generic_type_bad_arg_cnt(self):
10244 from xxclassloader import spamobj
10245
10246 o = spamobj[str]()
10247 self.assertEqual(o.twoargs(1, 2), 3)
10248
10249 def test_typed_slots_one_missing(self):
10250 codestr = """
10251 class C:
10252 __slots__ = ('a', 'b')
10253 __slot_types__ = {'a': (__name__, 'C')}
10254
10255 inst = C()
10256 """
10257
10258 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10259 inst, C = mod["inst"], mod["C"]
10260 self.assertEqual(C.a.__class__.__name__, "typed_descriptor")
10261 with self.assertRaises(TypeError):
10262 # type is checked
10263 inst.a = 42
10264
10265 def test_typed_slots_optional_object(self):
10266 codestr = """
10267 class C:
10268 __slots__ = ('a', )
10269 __slot_types__ = {'a': (__name__, 'C', '?')}
10270
10271 inst = C()
10272 """
10273
10274 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10275 inst, C = mod["inst"], mod["C"]
10276 inst.a = None
10277 self.assertEqual(inst.a, None)
10278
10279 def test_typed_slots_private(self):
10280 codestr = """
10281 class C:
10282 __slots__ = ('__a', )
10283 __slot_types__ = {'__a': (__name__, 'C', '?')}
10284 def __init__(self):
10285 self.__a = None
10286
10287 inst = C()
10288 """
10289
10290 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10291 inst, C = mod["inst"], mod["C"]
10292 self.assertEqual(inst._C__a, None)
10293 inst._C__a = inst
10294 self.assertEqual(inst._C__a, inst)
10295 inst._C__a = None
10296 self.assertEqual(inst._C__a, None)
10297
10298 def test_typed_slots_optional_not_defined(self):
10299 codestr = """
10300 class C:
10301 __slots__ = ('a', )
10302 __slot_types__ = {'a': (__name__, 'D', '?')}
10303
10304 def __init__(self):
10305 self.a = None
10306
10307 inst = C()
10308
10309 class D:
10310 pass
10311 """
10312
10313 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10314 inst, C = mod["inst"], mod["C"]
10315 inst.a = None
10316 self.assertEqual(inst.a, None)
10317
10318 def test_typed_slots_alignment(self):
10319 return
10320 codestr = """
10321 class C:
10322 __slots__ = ('a', 'b')
10323 __slot_types__ {'a': ('__static__', 'int16')}
10324
10325 inst = C()
10326 """
10327
10328 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod:
10329 inst, C = mod["inst"], mod["C"]
10330 inst.a = None
10331 self.assertEqual(inst.a, None)
10332
10333 def test_typed_slots_primitives(self):
10334 slot_types = [
10335 # signed
10336 (
10337 ("__static__", "byte"),
10338 0,
10339 1,
10340 [(1 << 7) - 1, -(1 << 7)],
10341 [1 << 8],
10342 ["abc"],
10343 ),
10344 (
10345 ("__static__", "int8"),
10346 0,
10347 1,
10348 [(1 << 7) - 1, -(1 << 7)],
10349 [1 << 8],
10350 ["abc"],
10351 ),
10352 (
10353 ("__static__", "int16"),
10354 0,
10355 2,
10356 [(1 << 15) - 1, -(1 << 15)],
10357 [1 << 15, -(1 << 15) - 1],
10358 ["abc"],
10359 ),
10360 (
10361 ("__static__", "int32"),
10362 0,
10363 4,
10364 [(1 << 31) - 1, -(1 << 31)],
10365 [1 << 31, -(1 << 31) - 1],
10366 ["abc"],
10367 ),
10368 (("__static__", "int64"), 0, 8, [(1 << 63) - 1, -(1 << 63)], [], [1 << 63]),
10369 # unsigned
10370 (("__static__", "uint8"), 0, 1, [(1 << 8) - 1, 0], [1 << 8, -1], ["abc"]),
10371 (
10372 ("__static__", "uint16"),
10373 0,
10374 2,
10375 [(1 << 16) - 1, 0],
10376 [1 << 16, -1],
10377 ["abc"],
10378 ),
10379 (
10380 ("__static__", "uint32"),
10381 0,
10382 4,
10383 [(1 << 32) - 1, 0],
10384 [1 << 32, -1],
10385 ["abc"],
10386 ),
10387 (("__static__", "uint64"), 0, 8, [(1 << 64) - 1, 0], [], [1 << 64]),
10388 # pointer
10389 (
10390 ("__static__", "ssize_t"),
10391 0,
10392 self.ptr_size,
10393 [1, sys.maxsize, -sys.maxsize - 1],
10394 [],
10395 [sys.maxsize + 1, -sys.maxsize - 2],
10396 ),
10397 # floating point
10398 (("__static__", "single"), 0.0, 4, [1.0], [], ["abc"]),
10399 (("__static__", "double"), 0.0, 8, [1.0], [], ["abc"]),
10400 # misc
10401 (("__static__", "char"), "\x00", 1, ["a"], [], ["abc"]),
10402 (("__static__", "cbool"), False, 1, [True], [], ["abc", 1]),
10403 ]
10404
10405 base_size = self.base_size
10406 for type_spec, default, size, test_vals, warn_vals, err_vals in slot_types:
10407 with self.subTest(
10408 type_spec=type_spec,
10409 default=default,
10410 size=size,
10411 test_vals=test_vals,
10412 warn_vals=warn_vals,
10413 err_vals=err_vals,
10414 ):
10415
10416 class C:
10417 __slots__ = ("a",)
10418 __slot_types__ = {"a": type_spec}
10419
10420 a = C()
10421 self.assertEqual(sys.getsizeof(a), base_size + size, type)
10422 self.assertEqual(a.a, default)
10423 self.assertEqual(type(a.a), type(default))
10424 for val in test_vals:
10425 a.a = val
10426 self.assertEqual(a.a, val)
10427
10428 with warnings.catch_warnings():
10429 warnings.simplefilter("error", category=RuntimeWarning)
10430 for val in warn_vals:
10431 with self.assertRaises(RuntimeWarning):
10432 a.a = val
10433
10434 for val in err_vals:
10435 with self.assertRaises((TypeError, OverflowError)):
10436 a.a = val
10437
10438 def test_invoke_function(self):
10439 my_int = "12345"
10440 codestr = f"""
10441 def x(a: str, b: int) -> str:
10442 return a + str(b)
10443
10444 def test() -> str:
10445 return x("hello", {my_int})
10446 """
10447 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
10448 test = self.find_code(c, "test")
10449 self.assertInBytecode(test, "INVOKE_FUNCTION", (("foo.py", "x"), 2))
10450 with self.in_module(codestr) as mod:
10451 test_callable = mod["test"]
10452 self.assertEqual(test_callable(), "hello" + my_int)
10453
10454 def test_awaited_invoke_function(self):
10455 codestr = """
10456 async def f() -> int:
10457 return 1
10458
10459 async def g() -> int:
10460 return await f()
10461 """
10462 with self.in_strict_module(codestr) as mod:
10463 self.assertInBytecode(mod.g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0))
10464 self.assertEqual(asyncio.run(mod.g()), 1)
10465
10466 def test_awaited_invoke_function_unjitable(self):
10467 codestr = """
10468 async def f() -> int:
10469 from os.path import *
10470 return 1
10471
10472 async def g() -> int:
10473 return await f()
10474 """
10475 with self.in_strict_module(codestr) as mod:
10476 self.assertInBytecode(
10477 mod.g,
10478 "INVOKE_FUNCTION",
10479 ((mod.__name__, "f"), 0),
10480 )
10481 self.assertEqual(asyncio.run(mod.g()), 1)
10482
10483 def test_awaited_invoke_function_with_args(self):
10484 codestr = """
10485 async def f(a: int, b: int) -> int:
10486 return a + b
10487
10488 async def g() -> int:
10489 return await f(1, 2)
10490 """
10491 with self.in_strict_module(codestr) as mod:
10492 self.assertInBytecode(
10493 mod.g,
10494 "INVOKE_FUNCTION",
10495 ((mod.__name__, "f"), 2),
10496 )
10497 self.assertEqual(asyncio.run(mod.g()), 3)
10498
10499 # exercise shadowcode, INVOKE_FUNCTION_CACHED
10500 self.make_async_func_hot(mod.g)
10501 self.assertEqual(asyncio.run(mod.g()), 3)
10502
10503 def test_awaited_invoke_function_indirect_with_args(self):
10504 codestr = """
10505 async def f(a: int, b: int) -> int:
10506 return a + b
10507
10508 async def g() -> int:
10509 return await f(1, 2)
10510 """
10511 with self.in_module(codestr) as mod:
10512 g = mod["g"]
10513 self.assertInBytecode(
10514 g,
10515 "INVOKE_FUNCTION",
10516 ((mod["__name__"], "f"), 2),
10517 )
10518 self.assertEqual(asyncio.run(g()), 3)
10519
10520 # exercise shadowcode, INVOKE_FUNCTION_INDIRECT_CACHED
10521 self.make_async_func_hot(g)
10522 self.assertEqual(asyncio.run(g()), 3)
10523
10524 def test_awaited_invoke_function_future(self):
10525 codestr = """
10526 from asyncio import ensure_future
10527
10528 async def h() -> int:
10529 return 1
10530
10531 async def g() -> None:
10532 await ensure_future(h())
10533
10534 async def f():
10535 await g()
10536 """
10537 with self.in_strict_module(codestr) as mod:
10538 self.assertInBytecode(
10539 mod.f,
10540 "INVOKE_FUNCTION",
10541 ((mod.__name__, "g"), 0),
10542 )
10543 asyncio.run(mod.f())
10544
10545 # exercise shadowcode
10546 self.make_async_func_hot(mod.f)
10547 asyncio.run(mod.f())
10548
10549 def test_awaited_invoke_method(self):
10550 codestr = """
10551 class C:
10552 async def f(self) -> int:
10553 return 1
10554
10555 async def g(self) -> int:
10556 return await self.f()
10557 """
10558 with self.in_strict_module(codestr) as mod:
10559 self.assertInBytecode(
10560 mod.C.g, "INVOKE_METHOD", ((mod.__name__, "C", "f"), 0)
10561 )
10562 self.assertEqual(asyncio.run(mod.C().g()), 1)
10563
10564 def test_awaited_invoke_method_with_args(self):
10565 codestr = """
10566 class C:
10567 async def f(self, a: int, b: int) -> int:
10568 return a + b
10569
10570 async def g(self) -> int:
10571 return await self.f(1, 2)
10572 """
10573 with self.in_strict_module(codestr) as mod:
10574 self.assertInBytecode(
10575 mod.C.g,
10576 "INVOKE_METHOD",
10577 ((mod.__name__, "C", "f"), 2),
10578 )
10579 self.assertEqual(asyncio.run(mod.C().g()), 3)
10580
10581 # exercise shadowcode, INVOKE_METHOD_CACHED
10582 async def make_hot():
10583 c = mod.C()
10584 for i in range(50):
10585 await c.g()
10586
10587 asyncio.run(make_hot())
10588 self.assertEqual(asyncio.run(mod.C().g()), 3)
10589
10590 def test_awaited_invoke_method_future(self):
10591 codestr = """
10592 from asyncio import ensure_future
10593
10594 async def h() -> int:
10595 return 1
10596
10597 class C:
10598 async def g(self) -> None:
10599 await ensure_future(h())
10600
10601 async def f():
10602 c = C()
10603 await c.g()
10604 """
10605 with self.in_strict_module(codestr) as mod:
10606 self.assertInBytecode(
10607 mod.f,
10608 "INVOKE_METHOD",
10609 ((mod.__name__, "C", "g"), 0),
10610 )
10611 asyncio.run(mod.f())
10612
10613 # exercise shadowcode, INVOKE_METHOD_CACHED
10614 self.make_async_func_hot(mod.f)
10615 asyncio.run(mod.f())
10616
10617 def test_load_iterable_arg(self):
10618 codestr = """
10619 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10620 return 7
10621
10622 def y() -> int:
10623 p = ("hi", 0.1, 0.2)
10624 return x(1, 3, *p)
10625 """
10626 y = self.find_code(
10627 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10628 )
10629 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0)
10630 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1)
10631 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2)
10632 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10633 with self.in_module(codestr) as mod:
10634 y_callable = mod["y"]
10635 self.assertEqual(y_callable(), 7)
10636
10637 def test_load_iterable_arg_default_overridden(self):
10638 codestr = """
10639 def x(a: int, b: int, c: str, d: float = 10.1, e: float = 20.1) -> bool:
10640 return bool(
10641 a == 1
10642 and b == 3
10643 and c == "hi"
10644 and d == 0.1
10645 and e == 0.2
10646 )
10647
10648 def y() -> bool:
10649 p = ("hi", 0.1, 0.2)
10650 return x(1, 3, *p)
10651 """
10652 y = self.find_code(
10653 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10654 )
10655 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10656 self.assertNotInBytecode(y, "LOAD_MAPPING_ARG", 3)
10657 with self.in_module(codestr) as mod:
10658 y_callable = mod["y"]
10659 self.assertTrue(y_callable())
10660
10661 def test_load_iterable_arg_multi_star(self):
10662 codestr = """
10663 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10664 return 7
10665
10666 def y() -> int:
10667 p = (1, 3)
10668 q = ("hi", 0.1, 0.2)
10669 return x(*p, *q)
10670 """
10671 y = self.find_code(
10672 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10673 )
10674 # we should fallback to the normal Python compiler for this
10675 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG")
10676 with self.in_module(codestr) as mod:
10677 y_callable = mod["y"]
10678 self.assertEqual(y_callable(), 7)
10679
10680 def test_load_iterable_arg_star_not_last(self):
10681 codestr = """
10682 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10683 return 7
10684
10685 def y() -> int:
10686 p = (1, 3, 'abc', 0.1)
10687 return x(*p, 1.0)
10688 """
10689 y = self.find_code(
10690 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10691 )
10692 # we should fallback to the normal Python compiler for this
10693 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG")
10694 with self.in_module(codestr) as mod:
10695 y_callable = mod["y"]
10696 self.assertEqual(y_callable(), 7)
10697
10698 def test_load_iterable_arg_failure(self):
10699 codestr = """
10700 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10701 return 7
10702
10703 def y() -> int:
10704 p = ("hi", 0.1)
10705 return x(1, 3, *p)
10706 """
10707 y = self.find_code(
10708 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10709 )
10710 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0)
10711 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1)
10712 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2)
10713 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10714 with self.in_module(codestr) as mod:
10715 y_callable = mod["y"]
10716 with self.assertRaises(IndexError):
10717 y_callable()
10718
10719 def test_load_iterable_arg_sequence(self):
10720 codestr = """
10721 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10722 return 7
10723
10724 def y() -> int:
10725 p = ["hi", 0.1, 0.2]
10726 return x(1, 3, *p)
10727 """
10728 y = self.find_code(
10729 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10730 )
10731 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0)
10732 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1)
10733 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2)
10734 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10735 with self.in_module(codestr) as mod:
10736 y_callable = mod["y"]
10737 self.assertEqual(y_callable(), 7)
10738
10739 def test_load_iterable_arg_sequence_1(self):
10740 codestr = """
10741 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10742 return 7
10743
10744 def gen():
10745 for i in ["hi", 0.05, 0.2]:
10746 yield i
10747
10748 def y() -> int:
10749 g = gen()
10750 return x(1, 3, *g)
10751 """
10752 y = self.find_code(
10753 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10754 )
10755 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0)
10756 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1)
10757 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2)
10758 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10759 with self.in_module(codestr) as mod:
10760 y_callable = mod["y"]
10761 self.assertEqual(y_callable(), 7)
10762
10763 def test_load_iterable_arg_sequence_failure(self):
10764 codestr = """
10765 def x(a: int, b: int, c: str, d: float, e: float) -> int:
10766 return 7
10767
10768 def y() -> int:
10769 p = ["hi", 0.1]
10770 return x(1, 3, *p)
10771 """
10772 y = self.find_code(
10773 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10774 )
10775 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0)
10776 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1)
10777 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2)
10778 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3)
10779 with self.in_module(codestr) as mod:
10780 y_callable = mod["y"]
10781 with self.assertRaises(IndexError):
10782 y_callable()
10783
10784 def test_load_mapping_arg(self):
10785 codestr = """
10786 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool:
10787 return bool(f == "yo" and d == 1.0 and e == 1.1)
10788
10789 def y() -> bool:
10790 d = {"d": 1.0}
10791 return x(1, 3, "hi", f="yo", **d)
10792 """
10793 y = self.find_code(
10794 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10795 )
10796 self.assertInBytecode(y, "LOAD_MAPPING_ARG", 3)
10797 with self.in_module(codestr) as mod:
10798 y_callable = mod["y"]
10799 self.assertTrue(y_callable())
10800
10801 def test_load_mapping_and_iterable_args_failure_1(self):
10802 """
10803 Fails because we don't supply enough positional args
10804 """
10805
10806 codestr = """
10807 def x(a: int, b: int, c: str, d: float=2.2, e: float=1.1, f: str="something") -> bool:
10808 return bool(a == 1 and b == 3 and f == "yo" and d == 2.2 and e == 1.1)
10809
10810 def y() -> bool:
10811 return x(1, 3, f="yo")
10812 """
10813 with self.assertRaisesRegex(
10814 SyntaxError, "Function foo.x expects a value for argument c"
10815 ):
10816 self.compile(codestr, StaticCodeGenerator, modname="foo")
10817
10818 def test_load_mapping_arg_failure(self):
10819 """
10820 Fails because we supply an extra kwarg
10821 """
10822 codestr = """
10823 def x(a: int, b: int, c: str, d: float=2.2, e: float=1.1, f: str="something") -> bool:
10824 return bool(a == 1 and b == 3 and f == "yo" and d == 2.2 and e == 1.1)
10825
10826 def y() -> bool:
10827 return x(1, 3, "hi", f="yo", g="lol")
10828 """
10829 with self.assertRaisesRegex(
10830 TypedSyntaxError,
10831 "Given argument g does not exist in the definition of foo.x",
10832 ):
10833 self.compile(codestr, StaticCodeGenerator, modname="foo")
10834
10835 def test_load_mapping_arg_custom_class(self):
10836 """
10837 Fails because we supply a custom class for the mapped args, instead of a dict
10838 """
10839 codestr = """
10840 def x(a: int, b: int, c: str="hello") -> bool:
10841 return bool(a == 1 and b == 3 and c == "hello")
10842
10843 class C:
10844 def __getitem__(self, key: str) -> str:
10845 if key == "c":
10846 return "hi"
10847
10848 def keys(self):
10849 return ["c"]
10850
10851 def y() -> bool:
10852 return x(1, 3, **C())
10853 """
10854 with self.in_module(codestr) as mod:
10855 y_callable = mod["y"]
10856 with self.assertRaisesRegex(
10857 TypeError, r"argument after \*\* must be a dict, not C"
10858 ):
10859 self.assertTrue(y_callable())
10860
10861 def test_load_mapping_arg_use_defaults(self):
10862 codestr = """
10863 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool:
10864 return bool(f == "yo" and d == -0.1 and e == 1.1)
10865
10866 def y() -> bool:
10867 d = {"d": 1.0}
10868 return x(1, 3, "hi", f="yo")
10869 """
10870 y = self.find_code(
10871 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10872 )
10873 self.assertInBytecode(y, "LOAD_CONST", 1.1)
10874 with self.in_module(codestr) as mod:
10875 y_callable = mod["y"]
10876 self.assertTrue(y_callable())
10877
10878 def test_default_arg_non_const(self):
10879 codestr = """
10880 class C: pass
10881 def x(val=C()) -> C:
10882 return val
10883
10884 def f() -> C:
10885 return x()
10886 """
10887 with self.in_module(codestr) as mod:
10888 f = mod["f"]
10889 self.assertInBytecode(f, "CALL_FUNCTION")
10890
10891 def test_default_arg_non_const_kw_provided(self):
10892 codestr = """
10893 class C: pass
10894 def x(val:object=C()):
10895 return val
10896
10897 def f():
10898 return x(val=42)
10899 """
10900
10901 with self.in_module(codestr) as mod:
10902 f = mod["f"]
10903 self.assertEqual(f(), 42)
10904
10905 def test_load_mapping_arg_order(self):
10906 codestr = """
10907 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool:
10908 return bool(
10909 a == 1
10910 and b == 3
10911 and c == "hi"
10912 and d == 1.1
10913 and e == 3.3
10914 and f == "hmm"
10915 )
10916
10917 stuff = []
10918 def q() -> float:
10919 stuff.append("q")
10920 return 1.1
10921
10922 def r() -> float:
10923 stuff.append("r")
10924 return 3.3
10925
10926 def s() -> str:
10927 stuff.append("s")
10928 return "hmm"
10929
10930 def y() -> bool:
10931 return x(1, 3, "hi", f=s(), d=q(), e=r())
10932 """
10933 y = self.find_code(
10934 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10935 )
10936 self.assertInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d")
10937 self.assertInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d")
10938 with self.in_module(codestr) as mod:
10939 y_callable = mod["y"]
10940 self.assertTrue(y_callable())
10941 self.assertEqual(["s", "q", "r"], mod["stuff"])
10942
10943 def test_load_mapping_arg_order_with_variadic_kw_args(self):
10944 codestr = """
10945 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something", g: str="look-here") -> bool:
10946 return bool(
10947 a == 1
10948 and b == 3
10949 and c == "hi"
10950 and d == 1.1
10951 and e == 3.3
10952 and f == "hmm"
10953 and g == "overridden"
10954 )
10955
10956 stuff = []
10957 def q() -> float:
10958 stuff.append("q")
10959 return 1.1
10960
10961 def r() -> float:
10962 stuff.append("r")
10963 return 3.3
10964
10965 def s() -> str:
10966 stuff.append("s")
10967 return "hmm"
10968
10969 def y() -> bool:
10970 kw = {"g": "overridden"}
10971 return x(1, 3, "hi", f=s(), **kw, d=q(), e=r())
10972 """
10973 y = self.find_code(
10974 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
10975 )
10976 self.assertInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d")
10977 self.assertInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d")
10978 with self.in_module(codestr) as mod:
10979 y_callable = mod["y"]
10980 self.assertTrue(y_callable())
10981 self.assertEqual(["s", "q", "r"], mod["stuff"])
10982
10983 def test_load_mapping_arg_order_with_variadic_kw_args_one_positional(self):
10984 codestr = """
10985 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something", g: str="look-here") -> bool:
10986 return bool(
10987 a == 1
10988 and b == 3
10989 and c == "hi"
10990 and d == 1.1
10991 and e == 3.3
10992 and f == "hmm"
10993 and g == "overridden"
10994 )
10995
10996 stuff = []
10997 def q() -> float:
10998 stuff.append("q")
10999 return 1.1
11000
11001 def r() -> float:
11002 stuff.append("r")
11003 return 3.3
11004
11005 def s() -> str:
11006 stuff.append("s")
11007 return "hmm"
11008
11009
11010 def y() -> bool:
11011 kw = {"g": "overridden"}
11012 return x(1, 3, "hi", 1.1, f=s(), **kw, e=r())
11013 """
11014 y = self.find_code(
11015 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y"
11016 )
11017 self.assertNotInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d")
11018 self.assertNotInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d")
11019 with self.in_module(codestr) as mod:
11020 y_callable = mod["y"]
11021 self.assertTrue(y_callable())
11022 self.assertEqual(["s", "r"], mod["stuff"])
11023
11024 def test_vector_generics(self):
11025 T = TypeVar("T")
11026 VT = Vector[T]
11027 VT2 = VT[int64]
11028 a = VT2()
11029 a.append(42)
11030 with self.assertRaisesRegex(TypeError, "Cannot create plain Vector"):
11031 VT()
11032
11033 def test_vector_invalid_type(self):
11034 class C:
11035 pass
11036
11037 with self.assertRaisesRegex(
11038 TypeError, "Invalid type for ArrayElement: C when instantiating Vector"
11039 ):
11040 Vector[C]
11041
11042 def test_vector_wrong_arg_count(self):
11043 class C:
11044 pass
11045
11046 with self.assertRaisesRegex(
11047 TypeError, "Incorrect number of type arguments for Vector"
11048 ):
11049 Vector[int64, int64]
11050
11051 def test_generic_type_args(self):
11052 T = TypeVar("T")
11053 U = TypeVar("U")
11054
11055 class C(StaticGeneric[T, U]):
11056 pass
11057
11058 c_t = make_generic_type(C, (T, int))
11059 self.assertEqual(c_t.__parameters__, (T,))
11060 c_t_s = make_generic_type(c_t, (str,))
11061 self.assertEqual(c_t_s.__name__, "C[str, int]")
11062 c_u = make_generic_type(C, (int, U))
11063 self.assertEqual(c_u.__parameters__, (U,))
11064 c_u_t = make_generic_type(c_u, (str,))
11065 self.assertEqual(c_u_t.__name__, "C[int, str]")
11066 self.assertFalse(hasattr(c_u_t, "__parameters__"))
11067
11068 c_u_t_1 = make_generic_type(c_u, (int,))
11069 c_u_t_2 = make_generic_type(c_t, (int,))
11070 self.assertEqual(c_u_t_1.__name__, "C[int, int]")
11071 self.assertIs(c_u_t_1, c_u_t_2)
11072
11073 def test_array_slice(self):
11074 v = Array[int64]([1, 2, 3, 4])
11075 self.assertEqual(v[1:3], Array[int64]([2, 3]))
11076 self.assertEqual(type(v[1:2]), Array[int64])
11077
11078 def test_vector_slice(self):
11079 v = Vector[int64]([1, 2, 3, 4])
11080 self.assertEqual(v[1:3], Vector[int64]([2, 3]))
11081 self.assertEqual(type(v[1:2]), Vector[int64])
11082
11083 def test_array_deepcopy(self):
11084 v = Array[int64]([1, 2, 3, 4])
11085 self.assertEqual(v, deepcopy(v))
11086 self.assertIsNot(v, deepcopy(v))
11087 self.assertEqual(type(v), type(deepcopy(v)))
11088
11089 def test_vector_deepcopy(self):
11090 v = Vector[int64]([1, 2, 3, 4])
11091 self.assertEqual(v, deepcopy(v))
11092 self.assertIsNot(v, deepcopy(v))
11093 self.assertEqual(type(v), type(deepcopy(v)))
11094
11095 def test_nested_generic(self):
11096 S = TypeVar("S")
11097 T = TypeVar("T")
11098 U = TypeVar("U")
11099
11100 class F(StaticGeneric[U]):
11101 pass
11102
11103 class C(StaticGeneric[T]):
11104 pass
11105
11106 A = F[S]
11107 self.assertEqual(A.__parameters__, (S,))
11108 X = C[F[T]]
11109 self.assertEqual(X.__parameters__, (T,))
11110
def test_array_len(self):
    """len() of a freshly constructed Array compiles to FAST_LEN with the array flag."""
    codestr = """
        from __static__ import int64, char, double, Array
        from array import array

        def y():
            return len(Array[int64]([1, 3, 5]))
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    self.assertInBytecode(
        self.find_code(module_code, name="y"), "FAST_LEN", FAST_LEN_ARRAY
    )
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["y"](), 3)
11126
def test_array_isinstance(self):
    """Array instantiations are distinct classes keyed on their element type."""
    value = Array[int64](0)
    self.assertIsInstance(value, Array[int64])
    self.assertNotIsInstance(value, Array[int32])
    self.assertTrue(issubclass(Array[int64], Array[int64]))
    self.assertFalse(issubclass(Array[int64], Array[int32]))
11133
def test_array_weird_type_constrution(self):
    """A trailing comma inside Array[...] yields the very same Array type.

    The "constrution" spelling in the method name is kept so the test id
    stays stable.
    """
    plain = Array[int64]
    self.assertIs(plain, Array[int64,])
11141
def test_array_not_subclassable(self):
    """Neither a concrete Array[...] nor the bare Array may be used as a base."""
    with self.assertRaises(TypeError):

        class FromConcrete(Array[int64]):
            pass

    with self.assertRaises(TypeError):

        class FromBare(Array):
            pass
11154
def test_array_enum(self):
    """A clen()-bounded while loop sums an Array[int64]; passing None raises TypeError."""
    codestr = """
        from __static__ import Array, clen, int64, box

        def f(x: Array[int64]):
            i: int64 = 0
            j: int64 = 0
            while i < clen(x):
                j += x[i]
                i+=1
            return box(j)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        summer = mod["f"]
        self.assertEqual(summer(Array[int64]([1, 2, 3, 4])), 10)
        # The parameter is annotated as a non-optional Array[int64].
        with self.assertRaises(TypeError):
            summer(None)
11173
def test_optional_array_enum(self):
    """After an `is None` guard, an Optional[Array[int64]] can be iterated with clen()."""
    codestr = """
        from __static__ import Array, clen, int64, box
        from typing import Optional

        def f(x: Optional[Array[int64]]):
            if x is None:
                return 42

            i: int64 = 0
            j: int64 = 0
            while i < clen(x):
                j += x[i]
                i+=1
            return box(j)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        summer = mod["f"]
        self.assertEqual(summer(Array[int64]([1, 2, 3, 4])), 10)
        self.assertEqual(summer(None), 42)
11195
def test_array_len_subclass(self):
    """len() of an Array[int64] parameter emits FAST_LEN with the inexact flag."""
    codestr = """
        from __static__ import int64, Array

        def y(ar: Array[int64]):
            return len(ar)
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    self.assertInBytecode(
        self.find_code(module_code, name="y"),
        "FAST_LEN",
        FAST_LEN_ARRAY | FAST_LEN_INEXACT,
    )

    # TODO the below requires Array to be a generic type in C, or else
    # support for generic annotations for not-generic-in-C types. For now
    # it's sufficient to validate we emitted FAST_LEN_INEXACT flag.

    # class MyArray(Array):
    #     def __len__(self):
    #         return 123

    # with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
    #     y = mod["y"]
    #     self.assertEqual(y(MyArray[int64]([1])), 123)
11219
def test_nonarray_len(self):
    """len() of a user class is left to the generic path and honours __len__."""
    codestr = """
        class Lol:
            def __len__(self):
                return 421

        def y():
            return len(Lol())
    """
    module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
    self.assertNotInBytecode(self.find_code(module_code, name="y"), "FAST_LEN")
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["y"](), 421)
11236
def test_clen(self):
    """clen() of a List[int] parameter uses inexact FAST_LEN and respects __len__ overrides."""
    codestr = """
        from __static__ import box, clen, int64
        from typing import List

        def f(l: List[int]):
            x: int64 = clen(l)
            return box(x)
    """
    with self.in_module(codestr) as mod:
        length_of = mod["f"]
        self.assertInBytecode(
            length_of, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT
        )
        self.assertEqual(length_of([1, 2, 3]), 3)

        class RiggedList(list):
            def __len__(self):
                return 99

        self.assertEqual(length_of(RiggedList([1, 2])), 99)
11256
def test_clen_bad_arg(self):
    """clen() on an argument whose static type is dynamic is a compile-time error."""
    codestr = """
        from __static__ import clen

        def f(l):
            clen(l)
    """
    # The expected message contains literal parentheses. Escape it so "()"
    # is not parsed as an empty regex group, which would let a message
    # lacking the parentheses pass unnoticed.
    with self.assertRaisesRegex(
        TypedSyntaxError, re.escape("bad argument type 'dynamic' for clen()")
    ):
        self.compile(codestr)
11268
def test_seq_repeat_list(self):
    """list * int emits SEQUENCE_REPEAT with the exact-list flag."""
    codestr = """
        def f():
            l = [1, 2]
            return l * 2
    """
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", SEQ_LIST
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), [1, 2, 1, 2])
11279
def test_seq_repeat_list_reversed(self):
    """int * list emits SEQUENCE_REPEAT with the list and reversed-operand flags."""
    codestr = """
        def f():
            l = [1, 2]
            return 2 * l
    """
    expected_oparg = SEQ_LIST | SEQ_REPEAT_REVERSED
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), [1, 2, 1, 2])
11290
def test_seq_repeat_primitive(self):
    """list * int64 emits SEQUENCE_REPEAT flagged for a primitive count."""
    codestr = """
        from __static__ import int64

        def f():
            x: int64 = 2
            l = [1, 2]
            return l * x
    """
    expected_oparg = SEQ_LIST | SEQ_REPEAT_PRIMITIVE_NUM
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), [1, 2, 1, 2])
11304
def test_seq_repeat_primitive_reversed(self):
    """int64 * list emits SEQUENCE_REPEAT with list, reversed and primitive flags."""
    codestr = """
        from __static__ import int64

        def f():
            x: int64 = 2
            l = [1, 2]
            return x * l
    """
    expected_oparg = SEQ_LIST | SEQ_REPEAT_REVERSED | SEQ_REPEAT_PRIMITIVE_NUM
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), [1, 2, 1, 2])
11322
def test_seq_repeat_tuple(self):
    """tuple * int emits SEQUENCE_REPEAT with the exact-tuple flag."""
    codestr = """
        def f():
            t = (1, 2)
            return t * 2
    """
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", SEQ_TUPLE
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), (1, 2, 1, 2))
11333
def test_seq_repeat_tuple_reversed(self):
    """int * tuple emits SEQUENCE_REPEAT with the tuple and reversed-operand flags."""
    codestr = """
        def f():
            t = (1, 2)
            return 2 * t
    """
    expected_oparg = SEQ_TUPLE | SEQ_REPEAT_REVERSED
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        self.assertEqual(mod["f"](), (1, 2, 1, 2))
11344
def test_seq_repeat_inexact_list(self):
    """A List[int] parameter gets the inexact-sequence flag, so subclass __mul__ wins."""
    codestr = """
        from typing import List

        def f(l: List[int]):
            return l * 2
    """
    expected_oparg = SEQ_LIST | SEQ_REPEAT_INEXACT_SEQ
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        repeat = mod["f"]
        self.assertEqual(repeat([1, 2]), [1, 2, 1, 2])

        class OverridingList(list):
            def __mul__(self, other):
                return "RESULT"

        self.assertEqual(repeat(OverridingList([1, 2])), "RESULT")
11362
def test_seq_repeat_inexact_tuple(self):
    """A Tuple[int] parameter gets the inexact-sequence flag, so subclass __mul__ wins."""
    codestr = """
        from typing import Tuple

        def f(t: Tuple[int]):
            return t * 2
    """
    expected_oparg = SEQ_TUPLE | SEQ_REPEAT_INEXACT_SEQ
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        repeat = mod["f"]
        self.assertEqual(repeat((1, 2)), (1, 2, 1, 2))

        class OverridingTuple(tuple):
            def __mul__(self, other):
                return "RESULT"

        self.assertEqual(repeat(OverridingTuple((1, 2))), "RESULT")
11381
def test_seq_repeat_inexact_num(self):
    """An int parameter times a list gets the inexact-number flag; int subclass __mul__ wins."""
    codestr = """
        def f(num: int):

            return num * [1, 2]
    """
    expected_oparg = SEQ_LIST | SEQ_REPEAT_INEXACT_NUM | SEQ_REPEAT_REVERSED
    self.assertInBytecode(
        self.find_code(self.compile(codestr)), "SEQUENCE_REPEAT", expected_oparg
    )
    with self.in_module(codestr) as mod:
        repeat = mod["f"]
        self.assertEqual(repeat(2), [1, 2, 1, 2])

        class OverridingInt(int):
            def __mul__(self, other):
                return "RESULT"

        self.assertEqual(repeat(OverridingInt(2)), "RESULT")
11402
def test_load_int_const_sizes(self):
    """Integer constants just outside a primitive's range are compile errors.

    For every width, both values one past each end of the range must be
    rejected, and both endpoints themselves must load and box correctly.
    """
    cases = []
    # Signed widths: [-(2**(bits-1)), 2**(bits-1) - 1].
    for bits in (8, 16, 32, 64):
        name = f"int{bits}"
        low, high = -1 << (bits - 1), (1 << (bits - 1)) - 1
        cases += [
            (name, high + 1, True),
            (name, low - 1, True),
            (name, low, False),
            (name, high, False),
        ]
    # Unsigned widths: [0, 2**bits - 1].
    for bits in (8, 16, 32, 64):
        name = f"uint{bits}"
        high = (1 << bits) - 1
        cases += [
            (name, high + 1, True),
            (name, -1, True),
            (name, high, False),
            (name, 0, False),
        ]

    for type_name, val, overflows in cases:
        codestr = f"""
            from __static__ import {type_name}, box

            def f() -> int:
                x: {type_name} = {val}
                return box(x)
        """
        with self.subTest(type=type_name, val=val, overflows=overflows):
            if not overflows:
                with self.in_strict_module(codestr) as mod:
                    self.assertEqual(mod.f(), val)
                continue
            with self.assertRaisesRegex(TypedSyntaxError, "outside of the range"):
                self.compile(codestr)
11455
def test_load_int_const_signed(self):
    """Signed primitive locals initialised from (negated) literals use PRIMITIVE_LOAD_CONST."""
    combos = itertools.product(
        ("int8", "int16", "int32", "int64"), ("-", ""), (12,)
    )
    for type_name, sign, value in combos:
        expected = -value if sign else value

        codestr = f"""
            from __static__ import {type_name}, box

            def y() -> int:
                x: {type_name} = {sign}{value}
                return box(x)
        """
        with self.subTest(type=type_name, sign=sign, value=value):
            module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
            self.assertInBytecode(
                self.find_code(module_code, name="y"), "PRIMITIVE_LOAD_CONST"
            )
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                self.assertEqual(mod["y"](), expected)
11484
def test_load_int_const_unsigned(self):
    """Unsigned primitive constructors applied to literals use PRIMITIVE_LOAD_CONST."""
    value = 12
    for type_name in ("uint8", "uint16", "uint32", "uint64"):
        codestr = f"""
            from __static__ import {type_name}, box

            def y() -> int:
                return box({type_name}({value}))
        """
        with self.subTest(type=type_name, value=value):
            module_code = self.compile(codestr, StaticCodeGenerator, modname="foo")
            self.assertInBytecode(
                self.find_code(module_code, name="y"), "PRIMITIVE_LOAD_CONST"
            )
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                self.assertEqual(mod["y"](), value)
11510
def test_primitive_out_of_range(self):
    """int8(255) is rejected at compile time with a descriptive range error."""
    # Plain string: the source has no placeholders, so the former f-prefix
    # was redundant.
    codestr = """
        from __static__ import int8, box

        def f() -> int:
            x = int8(255)
            return box(x)
    """
    with self.assertRaisesRegex(
        TypedSyntaxError,
        "constant 255 is outside of the range -128 to 127 for int8",
    ):
        self.compile(codestr)
11524
def test_primitive_conversions(self):
    """Converting between primitive int types truncates or extends as C would.

    Each row is (source type, destination type, literal, expected boxed value).
    """
    template = """
        from __static__ import {src}, {dest}, box

        def y() -> int:
            x = {dest}({src}({val}))
            return box(x)
    """
    cases = [
        ("int8", "int8", 5, 5),
        ("int8", "int16", 5, 5),
        ("int8", "int32", 5, 5),
        ("int8", "int64", 5, 5),
        ("int8", "uint8", -1, 255),
        ("int8", "uint8", 12, 12),
        ("int8", "uint16", -1, 65535),
        ("int8", "uint16", 12, 12),
        ("int8", "uint32", -1, 4294967295),
        ("int8", "uint32", 12, 12),
        ("int8", "uint64", -1, 18446744073709551615),
        ("int8", "uint64", 12, 12),
        ("int16", "int8", 5, 5),
        ("int16", "int8", -1, -1),
        ("int16", "int8", 32767, -1),
        ("int16", "int16", 5, 5),
        ("int16", "int32", -5, -5),
        ("int16", "int64", -6, -6),
        ("int16", "uint8", 32767, 255),
        ("int16", "uint8", -1, 255),
        ("int16", "uint16", 32767, 32767),
        ("int16", "uint16", -1, 65535),
        ("int16", "uint32", 1000, 1000),
        ("int16", "uint32", -1, 4294967295),
        ("int16", "uint64", 1414, 1414),
        ("int16", "uint64", -1, 18446744073709551615),
        ("int32", "int8", 5, 5),
        ("int32", "int8", -1, -1),
        ("int32", "int8", 2147483647, -1),
        ("int32", "int16", 5, 5),
        ("int32", "int16", -1, -1),
        ("int32", "int16", 2147483647, -1),
        ("int32", "int32", 5, 5),
        ("int32", "int64", 5, 5),
        ("int32", "uint8", 5, 5),
        ("int32", "uint8", 65535, 255),
        ("int32", "uint8", -1, 255),
        ("int32", "uint16", 5, 5),
        ("int32", "uint16", 2147483647, 65535),
        ("int32", "uint16", -1, 65535),
        ("int32", "uint32", 5, 5),
        ("int32", "uint32", -1, 4294967295),
        ("int32", "uint64", 5, 5),
        ("int32", "uint64", -1, 18446744073709551615),
        ("int64", "int8", 5, 5),
        ("int64", "int8", -1, -1),
        ("int64", "int8", 65535, -1),
        ("int64", "int16", 5, 5),
        ("int64", "int16", -1, -1),
        ("int64", "int16", 4294967295, -1),
        ("int64", "int32", 5, 5),
        ("int64", "int32", -1, -1),
        ("int64", "int32", 9223372036854775807, -1),
        ("int64", "int64", 5, 5),
        ("int64", "uint8", 5, 5),
        ("int64", "uint8", 65535, 255),
        ("int64", "uint8", -1, 255),
        ("int64", "uint16", 5, 5),
        ("int64", "uint16", 4294967295, 65535),
        ("int64", "uint16", -1, 65535),
        ("int64", "uint32", 5, 5),
        ("int64", "uint32", 9223372036854775807, 4294967295),
        ("int64", "uint32", -1, 4294967295),
        ("int64", "uint64", 5, 5),
        ("int64", "uint64", -1, 18446744073709551615),
        ("uint8", "int8", 5, 5),
        ("uint8", "int8", 255, -1),
        ("uint8", "int16", 255, 255),
        ("uint8", "int32", 255, 255),
        ("uint8", "int64", 255, 255),
        ("uint8", "uint8", 5, 5),
        ("uint8", "uint16", 255, 255),
        ("uint8", "uint32", 255, 255),
        ("uint8", "uint64", 255, 255),
        ("uint16", "int8", 5, 5),
        ("uint16", "int8", 65535, -1),
        ("uint16", "int16", 5, 5),
        ("uint16", "int16", 65535, -1),
        ("uint16", "int32", 65535, 65535),
        ("uint16", "int64", 65535, 65535),
        ("uint16", "uint8", 65535, 255),
        ("uint16", "uint16", 65535, 65535),
        ("uint16", "uint32", 65535, 65535),
        ("uint16", "uint64", 65535, 65535),
        ("uint32", "int8", 4, 4),
        ("uint32", "int8", 4294967295, -1),
        ("uint32", "int16", 5, 5),
        ("uint32", "int16", 4294967295, -1),
        ("uint32", "int32", 65535, 65535),
        ("uint32", "int32", 4294967295, -1),
        ("uint32", "int64", 4294967295, 4294967295),
        ("uint32", "uint8", 4, 4),
        ("uint32", "uint8", 65535, 255),
        ("uint32", "uint16", 4294967295, 65535),
        ("uint32", "uint32", 5, 5),
        ("uint32", "uint64", 4294967295, 4294967295),
        ("uint64", "int8", 4, 4),
        ("uint64", "int8", 18446744073709551615, -1),
        ("uint64", "int16", 4, 4),
        ("uint64", "int16", 18446744073709551615, -1),
        ("uint64", "int32", 4, 4),
        ("uint64", "int32", 18446744073709551615, -1),
        ("uint64", "int64", 4, 4),
        ("uint64", "int64", 18446744073709551615, -1),
        ("uint64", "uint8", 5, 5),
        ("uint64", "uint8", 65535, 255),
        ("uint64", "uint16", 4294967295, 65535),
        ("uint64", "uint32", 18446744073709551615, 4294967295),
        ("uint64", "uint64", 5, 5),
    ]

    for src, dest, val, expected in cases:
        codestr = template.format(src=src, dest=dest, val=val)
        with self.subTest(src=src, dest=dest, val=val, expected=expected):
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                actual = mod["y"]()
                failure_note = f"failing case: {[src, dest, val, actual, expected]}"
                self.assertEqual(actual, expected, failure_note)
11655
def test_no_cast_after_box(self):
    """int64(x) + 1 on an int argument emits no CAST and loads 1 as a typed constant."""
    codestr = """
        from __static__ import int64, box

        def f(x: int) -> int:
            y = int64(x) + 1
            return box(y)
    """
    with self.in_module(codestr) as mod:
        add_one = mod["f"]
        self.assertNotInBytecode(add_one, "CAST")
        self.assertInBytecode(add_one, "PRIMITIVE_LOAD_CONST", (1, TYPED_INT64))
        self.assertEqual(add_one(3), 4)
11669
def test_rand(self):
    """rand() stored in an int64 and boxed comes back as a Python int."""
    codestr = """
        from __static__ import rand, RAND_MAX, box, int64

        def test():
            x: int64 = rand()
            return box(x)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        result = mod["test"]()
        self.assertEqual(type(result), int)
11681
def test_rand_max_inlined(self):
    """int64(RAND_MAX) is emitted as a PRIMITIVE_LOAD_CONST."""
    codestr = """
        from __static__ import rand, RAND_MAX, box, int64

        def f() -> int:
            x: int64 = rand() // int64(RAND_MAX)
            return box(x)
    """
    with self.in_module(codestr) as mod:
        scaled = mod["f"]
        self.assertInBytecode(scaled, "PRIMITIVE_LOAD_CONST")
        self.assertIsInstance(scaled(), int)
11694
def test_array_get_primitive_idx(self):
    """Indexing Array[int8] with an int8 index uses a typed SEQUENCE_GET."""
    codestr = """
        from __static__ import Array, int8, box

        def m() -> int:
            content = list(range(121))
            a = Array[int8](content)
            return box(a[int8(111)])
    """
    getter = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
    )
    self.assertInBytecode(getter, "PRIMITIVE_LOAD_CONST", (111, TYPED_INT8))
    self.assertInBytecode(getter, "SEQUENCE_GET", SEQ_ARRAY_INT8)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["m"](), 111)
11712
def test_array_get_nonprimitive_idx(self):
    """A plain int index is loaded boxed and then unboxed before SEQUENCE_GET."""
    codestr = """
        from __static__ import Array, int8, box

        def m() -> int:
            content = list(range(121))
            a = Array[int8](content)
            return box(a[111])
    """
    getter = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
    )
    self.assertInBytecode(getter, "LOAD_CONST", 111)
    self.assertNotInBytecode(getter, "PRIMITIVE_LOAD_CONST")
    self.assertInBytecode(getter, "PRIMITIVE_UNBOX")
    self.assertInBytecode(getter, "SEQUENCE_GET", SEQ_ARRAY_INT8)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["m"](), 111)
11732
def test_array_get_failure(self):
    """Reading past the end of an Array raises IndexError."""
    codestr = """
        from __static__ import Array, int8, box

        def m() -> int:
            a = Array[int8]([1, 3, -5])
            return box(a[20])
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        read_out_of_bounds = mod["m"]
        with self.assertRaisesRegex(IndexError, "index out of range"):
            read_out_of_bounds()
11745
def test_array_get_negative_idx(self):
    """A negative index reads from the end of the Array."""
    codestr = """
        from __static__ import Array, int8, box

        def m() -> int:
            a = Array[int8]([1, 3, -5])
            return box(a[-1])
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        last_item = mod["m"]()
        self.assertEqual(last_item, -5)
11757
def test_array_set_signed(self):
    """Storing a (negated) literal into Array[intN] uses a typed constant and SEQUENCE_SET."""
    seq_types = {
        "int8": SEQ_ARRAY_INT8,
        "int16": SEQ_ARRAY_INT16,
        "int32": SEQ_ARRAY_INT32,
        "int64": SEQ_ARRAY_INT64,
    }
    value = 77

    for (type_name, seq_kind), sign in itertools.product(
        seq_types.items(), ("-", "")
    ):
        codestr = f"""
            from __static__ import Array, {type_name}

            def m() -> Array[{type_name}]:
                a = Array[{type_name}]([1, 3, -5])
                a[1] = {sign}{value}
                return a
        """
        with self.subTest(type=type_name, sign=sign):
            stored = -value if sign else value
            setter = self.find_code(
                self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
            )
            self.assertInBytecode(
                setter,
                "PRIMITIVE_LOAD_CONST",
                (stored, prim_name_to_type[type_name]),
            )
            self.assertInBytecode(setter, "LOAD_CONST", 1)
            self.assertInBytecode(setter, "SEQUENCE_SET", seq_kind)
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                # The reference is always an int64 ("q") array, regardless of
                # the element width under test.
                self.assertEqual(
                    mod["m"](),
                    array("q", [1, stored, -5]),
                    f"Failing case: {type_name}, {sign}",
                )
11804
def test_array_set_unsigned(self):
    """Storing a literal into Array[uintN] uses a typed constant and SEQUENCE_SET."""
    value = 77
    seq_types = {
        "uint8": SEQ_ARRAY_UINT8,
        "uint16": SEQ_ARRAY_UINT16,
        "uint32": SEQ_ARRAY_UINT32,
        "uint64": SEQ_ARRAY_UINT64,
    }
    for type_name, seq_kind in seq_types.items():
        codestr = f"""
            from __static__ import Array, {type_name}

            def m() -> Array[{type_name}]:
                a = Array[{type_name}]([1, 3, 5])
                a[1] = {value}
                return a
        """
        with self.subTest(type=type_name):
            setter = self.find_code(
                self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
            )
            self.assertInBytecode(
                setter, "PRIMITIVE_LOAD_CONST", (value, prim_name_to_type[type_name])
            )
            self.assertInBytecode(setter, "LOAD_CONST", 1)
            self.assertInBytecode(setter, "SEQUENCE_SET", seq_kind)
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
                self.assertEqual(
                    mod["m"](),
                    array("q", [1, value, 5]),
                    f"Failing case: {type_name}",
                )
11843
def test_array_set_negative_idx(self):
    """A negative index writes relative to the end of the Array."""
    codestr = """
        from __static__ import Array, int8

        def m() -> Array[int8]:
            a = Array[int8]([1, 3, -5])
            a[-2] = 7
            return a
    """
    setter = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
    )
    for opname, arg in (
        ("PRIMITIVE_LOAD_CONST", (7, TYPED_INT8)),
        ("LOAD_CONST", -2),
        ("SEQUENCE_SET", SEQ_ARRAY_INT8),
    ):
        self.assertInBytecode(setter, opname, arg)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["m"](), array("h", [1, 7, -5]))
11861
def test_array_set_failure(self):
    """Writing at an index before the start of the Array raises IndexError.

    The former ``-> object`` annotation was misleading: like every unittest
    method, this test returns None.
    """
    codestr = """
        from __static__ import Array, int8

        def m() -> Array[int8]:
            a = Array[int8]([1, 3, -5])
            a[-100] = 7
            return a
    """
    setter = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
    )
    self.assertInBytecode(setter, "PRIMITIVE_LOAD_CONST", (7, TYPED_INT8))
    self.assertInBytecode(setter, "SEQUENCE_SET", SEQ_ARRAY_INT8)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        with self.assertRaisesRegex(IndexError, "index out of range"):
            mod["m"]()
11879
def test_array_set_failure_invalid_subscript(self):
    """Assigning through a non-integer subscript raises TypeError at runtime."""
    codestr = """
        from __static__ import Array, int8

        def x():
            return object()

        def m() -> Array[int8]:
            a = Array[int8]([1, 3, -5])
            a[x()] = 7
            return a
    """
    setter = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "m"
    )
    self.assertInBytecode(setter, "PRIMITIVE_LOAD_CONST", (7, TYPED_INT8))
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        bad_store = mod["m"]
        with self.assertRaisesRegex(TypeError, "array indices must be integers"):
            bad_store()
11899
def test_fast_len_list(self):
    """len() of a list literal compiles to FAST_LEN with the exact-list flag."""
    codestr = """
        def f():
            l = [1, 2, 3, 4, 5, 6, 7]
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_LIST)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"](), 7)
11912
def test_fast_len_str(self):
    """len() of a str literal compiles to FAST_LEN with the str flag."""
    codestr = """
        def f():
            l = "my str!"
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"](), 7)
11925
def test_fast_len_str_unicode_chars(self):
    """FAST_LEN on a str counts code points: a single astral emoji has length 1."""
    codestr = """
        def f():
            l = "\U0001F923" # ROFL emoji
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"](), 1)
11938
def test_fast_len_tuple(self):
    """len() of a tuple display compiles to FAST_LEN with the tuple flag."""
    codestr = """
        def f(a, b):
            l = (a, b)
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_TUPLE)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"]("a", "b"), 2)
11951
def test_fast_len_set(self):
    """len() of a set display compiles to FAST_LEN with the set flag."""
    codestr = """
        def f(a, b):
            l = {a, b}
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_SET)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"]("a", "b"), 2)
11964
def test_fast_len_dict(self):
    """len() of a dict display compiles to FAST_LEN with the dict flag."""
    codestr = """
        def f():
            l = {1: 'a', 2: 'b', 3: 'c', 4: 'd'}
            return len(l)
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_DICT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["f"](), 4)
11977
def test_fast_len_conditional_list(self):
    """Truth-testing a list in an `if` compiles to FAST_LEN with the list flag."""
    codestr = """
        def f(n: int) -> bool:
            l = [i for i in range(n)]
            if l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_LIST)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
11993
def test_fast_len_conditional_str(self):
    """Truth-testing a str in an `if` compiles to FAST_LEN with the str flag."""
    codestr = """
        def f(n: int) -> bool:
            l = f"{'a' * n}"
            if l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12009
def test_fast_len_loop_conditional_list(self):
    """Truth-testing a list as a `while` condition compiles to FAST_LEN with the list flag."""
    codestr = """
        def f(n: int) -> bool:
            l = [i for i in range(n)]
            while l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_LIST)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12025
def test_fast_len_loop_conditional_str(self):
    """Truth-testing a str as a `while` condition compiles to FAST_LEN with the str flag."""
    codestr = """
        def f(n: int) -> bool:
            l = f"{'a' * n}"
            while l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12041
def test_fast_len_loop_conditional_tuple(self):
    """Truth-testing a tuple as a `while` condition compiles to FAST_LEN with the tuple flag."""
    codestr = """
        def f(n: int) -> bool:
            l = tuple(i for i in range(n))
            while l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_TUPLE)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12057
def test_fast_len_loop_conditional_set(self):
    """Truth-testing a set as a `while` condition compiles to FAST_LEN with the set flag."""
    codestr = """
        def f(n: int) -> bool:
            l = {i for i in range(n)}
            while l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_SET)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12073
def test_fast_len_conditional_tuple(self):
    """Truth-testing a tuple in an `if` compiles to FAST_LEN with the tuple flag."""
    codestr = """
        def f(n: int) -> bool:
            l = tuple(i for i in range(n))
            if l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_TUPLE)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12089
def test_fast_len_conditional_set(self):
    """Truth-testing a set in an `if` compiles to FAST_LEN with the set flag."""
    codestr = """
        def f(n: int) -> bool:
            l = {i for i in range(n)}
            if l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_SET)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12105
def test_fast_len_conditional_dict(self):
    """Truth-testing a dict in an `if` compiles to FAST_LEN with the dict flag."""
    codestr = """
        def f(n: int) -> bool:
            l = {i: i for i in range(n)}
            if l:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_DICT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_nonempty = mod["f"]
        for length in (0, 7):
            self.assertEqual(is_nonempty(length), length > 0)
12121
def test_fast_len_conditional_list_subclass(self):
    """A List[int]-typed local may hold a subclass, so the truth test uses inexact FAST_LEN."""
    codestr = """
        from typing import List

        class MyList(list):
            def __len__(self):
                return 1729

        def f(n: int, flag: bool) -> bool:
            x: List[int] = [i for i in range(n)]
            if flag:
                x = MyList([i for i in range(n)])
            if x:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(codestr, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        is_truthy = mod["f"]
        # When flag is set, MyList.__len__ (1729) makes x truthy even if empty.
        for flag in (True, False):
            for length in (0, 7):
                self.assertEqual(
                    is_truthy(length, flag),
                    length > 0 or flag,
                    f"length={length}, flag={flag}",
                )
12149
def test_fast_len_conditional_str_subclass(self):
    """A str local that may hold a str subclass uses inexact FAST_LEN.

    MyStr overrides __len__ to a non-zero value, so an empty MyStr must
    still test truthy.
    """
    source = """
        class MyStr(str):
            def __len__(self):
                return 1729

        def f(n: int, flag: bool) -> bool:
            x: str = f"{'a' * n}"
            if flag:
                x = MyStr(f"{'a' * n}")
            if x:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR | FAST_LEN_INEXACT)
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        for boolean in (True, False):
            for length in (0, 7):
                self.assertEqual(
                    f(length, boolean),
                    length > 0 or boolean,
                    f"length={length}, flag={boolean}",
                )
12175
def test_fast_len_conditional_tuple_subclass(self):
    """A tuple local that may hold a tuple subclass uses inexact FAST_LEN.

    Mytuple overrides __len__ to a non-zero value, so an empty Mytuple must
    still test truthy.
    """
    source = """
        class Mytuple(tuple):
            def __len__(self):
                return 1729

        def f(n: int, flag: bool) -> bool:
            x = tuple(i for i in range(n))
            if flag:
                x = Mytuple([i for i in range(n)])
            if x:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_TUPLE | FAST_LEN_INEXACT)
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        for boolean in (True, False):
            for length in (0, 7):
                self.assertEqual(
                    f(length, boolean),
                    length > 0 or boolean,
                    f"length={length}, flag={boolean}",
                )
12201
def test_fast_len_conditional_set_subclass(self):
    """A set local that may hold a set subclass uses inexact FAST_LEN.

    Myset overrides __len__ to a non-zero value, so an empty Myset must
    still test truthy.
    """
    source = """
        class Myset(set):
            def __len__(self):
                return 1729

        def f(n: int, flag: bool) -> bool:
            x = set(i for i in range(n))
            if flag:
                x = Myset([i for i in range(n)])
            if x:
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_SET | FAST_LEN_INEXACT)
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        for boolean in (True, False):
            for length in (0, 7):
                self.assertEqual(
                    f(length, boolean),
                    length > 0 or boolean,
                    f"length={length}, flag={boolean}",
                )
12227
def test_fast_len_conditional_list_funcarg(self):
    """No FAST_LEN appears in f when the list is only passed to another function."""
    source = """
        def z(b: object) -> bool:
            return bool(b)

        def f(n: int) -> bool:
            l = [i for i in range(n)]
            if z(l):
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    # f's `if` tests the result of z(l), not the list itself, so the
    # check must not be rewritten to FAST_LEN.
    self.assertNotInBytecode(func_code, "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        for size in (0, 7):
            self.assertEqual(func(size), size > 0)
12248
def test_fast_len_conditional_str_funcarg(self):
    """No FAST_LEN appears in f when the str is only passed to another function."""
    codestr = """
        def z(b: object) -> bool:
            return bool(b)

        def f(n: int) -> bool:
            l = f"{'a' * n}"
            if z(l):
                return True
            return False
    """
    c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
    f = self.find_code(c, "f")
    # Since the str is given to z(), do not optimize the check
    # with FAST_LEN
    self.assertNotInBytecode(f, "FAST_LEN")
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        for length in [0, 7]:
            self.assertEqual(f(length), length > 0)
12269
def test_fast_len_conditional_tuple_funcarg(self):
    """No FAST_LEN appears in f when the tuple is only passed to another function."""
    source = """
        def z(b: object) -> bool:
            return bool(b)

        def f(n: int) -> bool:
            l = tuple(i for i in range(n))
            if z(l):
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    # The tuple only flows into z(); f's branch tests z's bool result.
    self.assertNotInBytecode(func_code, "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        for size in (0, 7):
            self.assertEqual(func(size), size > 0)
12290
def test_fast_len_conditional_set_funcarg(self):
    """No FAST_LEN appears in f when the set is only passed to another function."""
    source = """
        def z(b: object) -> bool:
            return bool(b)

        def f(n: int) -> bool:
            l = set(i for i in range(n))
            if z(l):
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    # The set only flows into z(); f's branch tests z's bool result.
    self.assertNotInBytecode(func_code, "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        for size in (0, 7):
            self.assertEqual(func(size), size > 0)
12311
def test_fast_len_conditional_dict_funcarg(self):
    """No FAST_LEN appears in f when the dict is only passed to another function."""
    source = """
        def z(b) -> bool:
            return bool(b)

        def f(n: int) -> bool:
            l = {i: i for i in range(n)}
            if z(l):
                return True
            return False
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    # The dict only flows into z(); f's branch tests z's bool result.
    self.assertNotInBytecode(func_code, "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        for size in (0, 7):
            self.assertEqual(func(size), size > 0)
12332
def test_fast_len_list_subclass(self):
    """len() of a local bound only to a list subclass is not lowered to FAST_LEN."""
    source = """
        class mylist(list):
            def __len__(self):
                return 1111

        def f():
            l = mylist([1, 2])
            return len(l)
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(self.find_code(module_code, "f"), "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        # The overridden __len__ must be what len() reports.
        self.assertEqual(mod["f"](), 1111)
12349
def test_fast_len_str_subclass(self):
    """len() of a local bound only to a str subclass is not lowered to FAST_LEN."""
    source = """
        class mystr(str):
            def __len__(self):
                return 1111

        def f():
            s = mystr("a")
            return len(s)
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(self.find_code(module_code, "f"), "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        # The overridden __len__ must be what len() reports.
        self.assertEqual(mod["f"](), 1111)
12366
def test_fast_len_tuple_subclass(self):
    """len() of a local bound only to a tuple subclass is not lowered to FAST_LEN."""
    source = """
        class mytuple(tuple):
            def __len__(self):
                return 1111

        def f():
            l = mytuple([1, 2])
            return len(l)
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(self.find_code(module_code, "f"), "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        # The overridden __len__ must be what len() reports.
        self.assertEqual(mod["f"](), 1111)
12383
def test_fast_len_set_subclass(self):
    """len() of a local bound only to a set subclass is not lowered to FAST_LEN."""
    source = """
        class myset(set):
            def __len__(self):
                return 1111

        def f():
            l = myset([1, 2])
            return len(l)
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(self.find_code(module_code, "f"), "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        # The overridden __len__ must be what len() reports.
        self.assertEqual(mod["f"](), 1111)
12400
def test_fast_len_dict_subclass(self):
    """len() of a local bound only to a Dict subclass is not lowered to FAST_LEN."""
    source = """
        from typing import Dict

        class mydict(Dict[str, int]):
            def __len__(self):
                return 1111

        def f():
            l = mydict(a=1, b=2)
            return len(l)
    """
    module_code = self.compile(source, StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(self.find_code(module_code, "f"), "FAST_LEN")
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        # The overridden __len__ must be what len() reports.
        self.assertEqual(mod["f"](), 1111)
12419
def test_fast_len_list_subclass_2(self):
    """A local that may be a list or a list subclass uses inexact FAST_LEN.

    Both branches are exercised: the subclass with its __len__ override, and
    the plain exact-list value.
    """
    codestr = """
        class mylist(list):
            def __len__(self):
                return 1111

        def f(x):
            l = [1, 2]
            if x:
                l = mylist([1, 2])
            return len(l)
    """
    c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
    f = self.find_code(c, "f")
    self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        self.assertEqual(f(True), 1111)
        # Exact-list path, mirroring test_fast_len_str_subclass_2.
        self.assertEqual(f(False), 2)
12438
def test_fast_len_str_subclass_2(self):
    """A local that may be a str or a str subclass uses inexact FAST_LEN."""
    source = """
        class mystr(str):
            def __len__(self):
                return 1111

        def f(x):
            s = "abc"
            if x:
                s = mystr("pqr")
            return len(s)
    """
    func_code = self.find_code(
        self.compile(source, StaticCodeGenerator, modname="foo.py"), "f"
    )
    self.assertInBytecode(func_code, "FAST_LEN", FAST_LEN_STR | FAST_LEN_INEXACT)
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        for flag, expected in ((True, 1111), (False, 3)):
            self.assertEqual(func(flag), expected)
12458
def test_fast_len_tuple_subclass_2(self):
    """A local that may be a tuple or a tuple subclass uses inexact FAST_LEN.

    Both branches are exercised: the subclass with its __len__ override, and
    the plain exact-tuple value.
    """
    codestr = """
        class mytuple(tuple):
            def __len__(self):
                return 1111

        def f(x, a, b):
            l = (a, b)
            if x:
                l = mytuple([a, b])
            return len(l)
    """
    c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
    f = self.find_code(c, "f")
    self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE | FAST_LEN_INEXACT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        self.assertEqual(f(True, 1, 2), 1111)
        # Exact-tuple path: (a, b) has two elements.
        self.assertEqual(f(False, 1, 2), 2)
12477
def test_fast_len_dict_subclass_2(self):
    """A Dict local that may hold a Dict subclass uses inexact FAST_LEN.

    Both branches are exercised: the subclass with its __len__ override, and
    the plain exact-dict value.
    """
    codestr = """
        from typing import Dict

        class mydict(Dict[str, int]):
            def __len__(self):
                return 1111

        def f(x, a, b):
            l: Dict[str, int] = {'c': 3}
            if x:
                l = mydict(a=1, b=2)
            return len(l)
    """
    c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
    f = self.find_code(c, "f")
    self.assertInBytecode(f, "FAST_LEN", FAST_LEN_DICT | FAST_LEN_INEXACT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        self.assertEqual(f(True, 1, 2), 1111)
        # Exact-dict path: {'c': 3} has one entry.
        self.assertEqual(f(False, 1, 2), 1)
12498
def test_fast_len_set_subclass_2(self):
    """A local that may be a set or a set subclass uses inexact FAST_LEN.

    Both branches are exercised: the subclass with its __len__ override, and
    the plain exact-set value.
    """
    codestr = """
        class myset(set):
            def __len__(self):
                return 1111

        def f(x, a, b):
            l = {a, b}
            if x:
                l = myset([a, b])
            return len(l)
    """
    c = self.compile(codestr, StaticCodeGenerator, modname="foo.py")
    f = self.find_code(c, "f")
    self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET | FAST_LEN_INEXACT)
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        self.assertEqual(f(True, 1, 2), 1111)
        # Exact-set path: {1, 2} has two distinct elements.
        self.assertEqual(f(False, 1, 2), 2)
12517
def test_dynamic_type_param(self):
    """A DYNAMIC type argument to a generic does not turn the whole type DYNAMIC.

    Compiling succeeds only if Dict[Foo, int] is still recognized as a dict,
    since clen() is applied to it.
    """
    source = """
        from __static__ import int64, clen
        from nonstatic import Foo
        from typing import Dict

        def f(d: Dict[Foo, int]):
            x: int64 = clen(d)
    """
    self.compile(source)
12529
def test_checked_dict(self):
    """Well-typed stores succeed for several chkdict key/value parameterizations."""
    d = chkdict[str, str]()
    d["abc"] = "foo"
    self.assertEqual(repr(d), "{'abc': 'foo'}")
    for key_type, value_type, key, value in (
        (str, int, "abc", 42),
        (int, str, 42, "abc"),
    ):
        chkdict[key_type, value_type]()[key] = value
12538
def test_checked_dict_optional(self):
    """None is accepted as a key or a value when that slot is Optional."""
    for key_type, value_type, key, value in (
        (str, Optional[str], "abc", None),
        (Optional[str], str, None, "abc"),
    ):
        d = chkdict[key_type, value_type]()
        d[key] = value
12544
def test_checked_dict_nonoptional(self):
    """None is rejected with TypeError in whichever slot is not Optional."""
    for key_type, value_type, key, value in (
        (str, Optional[str], None, "abc"),
        (Optional[str], str, "abc", None),
    ):
        d = chkdict[key_type, value_type]()
        with self.assertRaises(TypeError):
            d[key] = value
12552
def test_checked_dict_types_enforced(self):
    """Mistyped keys or values raise TypeError and leave the dict empty."""
    cases = (
        ((str, str), ((42, "abc"), ("abc", 42))),
        ((str, int), ((42, 42), ("abc", "abc"))),
    )
    for (key_type, value_type), bad_items in cases:
        d = chkdict[key_type, value_type]()
        for key, value in bad_items:
            with self.assertRaises(TypeError):
                d[key] = value
            self.assertEqual(d, {})
12569
def test_checked_dict_ctor(self):
    """chkdict accepts kwargs, mappings and iterables of key/value pairs."""
    self.assertEqual(chkdict[str, str](x="abc"), {"x": "abc"})
    self.assertEqual(chkdict[str, int](x=42), {"x": 42})
    self.assertEqual(chkdict[str, str]({"x": "abc"}), {"x": "abc"})
    self.assertEqual(chkdict[str, str]([("a", "b")]), {"a": "b"})
    # A tuple of pairs, instead of repeating the list-of-pairs case verbatim.
    self.assertEqual(chkdict[str, str]((("a", "b"),)), {"a": "b"})
    self.assertEqual(chkdict[str, str](chkdict[str, str](x="abc")), {"x": "abc"})
    self.assertEqual(chkdict[str, str](chkdict[str, object](x="abc")), {"x": "abc"})
    self.assertEqual(chkdict[str, str](UserDict(x="abc")), {"x": "abc"})
    # Keyword arguments override entries from the positional mapping.
    self.assertEqual(chkdict[str, str](UserDict(x="abc"), x="foo"), {"x": "foo"})
12580
def test_checked_dict_bad_ctor(self):
    """Passing None as the positional initializer raises TypeError."""
    self.assertRaises(TypeError, chkdict[str, str], None)
12584
def test_checked_dict_setdefault(self):
    """setdefault() inserts a well-typed default for a missing key."""
    d = chkdict[str, str]()
    d.setdefault("abc", "foo")
    self.assertEqual(d, {"abc": "foo"})
12589
def test_checked_dict___module__(self):
    """A chkdict parameterized with a local class still reports module __static__."""

    class Lol:
        pass

    self.assertEqual(type(chkdict[int, Lol]()).__module__, "__static__")
12596
def test_checked_dict_setdefault_bad_values(self):
    """setdefault() with a mistyped key or default raises and inserts nothing."""
    d = chkdict[str, int]()
    for key, default in (("abc", "abc"), (42, 42)):
        with self.assertRaises(TypeError):
            d.setdefault(key, default)
        self.assertEqual(d, {})
12605
def test_checked_dict_fromkeys(self):
    """fromkeys() maps every character of the iterable to the given value."""
    expected = {key: 42 for key in "abc"}
    self.assertEqual(chkdict[str, int].fromkeys("abc", 42), expected)
12609
def test_checked_dict_fromkeys_optional(self):
    """fromkeys() accepts None keys/values when that slot is Optional."""
    with_none_key = chkdict[Optional[str], int].fromkeys(["a", "b", "c", None], 42)
    self.assertEqual(with_none_key, {"a": 42, "b": 42, "c": 42, None: 42})

    with_none_value = chkdict[str, Optional[int]].fromkeys("abc", None)
    self.assertEqual(with_none_value, {"a": None, "b": None, "c": None})
12616
def test_checked_dict_fromkeys_bad_types(self):
    """fromkeys() rejects a mistyped key, a mistyped value, and the implicit None value."""
    for args in (([2], 42), ("abc", object()), ("abc",)):
        with self.assertRaises(TypeError):
            chkdict[str, int].fromkeys(*args)
12626
def test_checked_dict_copy(self):
    """copy() yields an equal but independent dict.

    The original assertions only checked construction and never called
    copy(). Those assertions are kept, and copy() is now exercised too.
    """
    x = chkdict[str, str](x="abc")
    self.assertEqual(type(x), chkdict[str, str])
    self.assertEqual(x, {"x": "abc"})
    y = x.copy()
    self.assertIsNot(y, x)
    self.assertEqual(y, {"x": "abc"})
    # Mutating the copy must leave the original untouched.
    y["y"] = "def"
    self.assertEqual(x, {"x": "abc"})
12631
def test_checked_dict_clear(self):
    """clear() empties a populated chkdict."""
    d = chkdict[str, str](x="abc")
    d.clear()
    self.assertEqual(d, {})
12636
def test_checked_dict_update(self):
    """update() accepts both keyword arguments and a mapping."""
    d = chkdict[str, str](x="abc")
    d.update(y="foo")
    self.assertEqual(d, dict(x="abc", y="foo"))
    d.update({"z": "bar"})
    self.assertEqual(d, dict(x="abc", y="foo", z="bar"))
12643
def test_checked_dict_update_bad_type(self):
    """update() rejects mistyped values or keys with TypeError and stores nothing."""
    x = chkdict[str, int]()
    with self.assertRaises(TypeError):
        x.update(x="abc")
    self.assertEqual(x, {})
    with self.assertRaises(TypeError):
        x.update({"x": "abc"})
    # Verify the state after every rejected update, not only the last one.
    self.assertEqual(x, {})
    with self.assertRaises(TypeError):
        x.update({24: 42})
    self.assertEqual(x, {})
12654
def test_checked_dict_keys(self):
    """keys() lists keys in insertion order."""
    self.assertEqual(list(chkdict[str, int](x=2).keys()), ["x"])
    self.assertEqual(list(chkdict[str, int](x=2, y=3).keys()), ["x", "y"])
12660
def test_checked_dict_values(self):
    """values() lists values in insertion order."""
    values = chkdict[str, int](x=2, y=3).values()
    self.assertEqual(list(values), [2, 3])
12664
def test_checked_dict_items(self):
    """items() lists (key, value) pairs in insertion order."""
    for kwargs, expected in (
        ({"x": 2}, [("x", 2)]),
        ({"x": 2, "y": 3}, [("x", 2), ("y", 3)]),
    ):
        d = chkdict[str, int](**kwargs)
        self.assertEqual(list(d.items()), expected)
12675
def test_checked_dict_pop(self):
    """pop() returns the stored value and raises KeyError for a missing key."""
    d = chkdict[str, int](x=2)
    self.assertEqual(d.pop("x"), 2)
    with self.assertRaises(KeyError):
        d.pop("z")
12682
def test_checked_dict_popitem(self):
    """popitem() returns the sole pair, then raises KeyError once empty."""
    d = chkdict[str, int](x=2)
    self.assertEqual(d.popitem(), ("x", 2))
    with self.assertRaises(KeyError):
        d.popitem()
12689
def test_checked_dict_get(self):
    """get() returns the stored value, or the supplied default for a missing key."""
    d = chkdict[str, int](x=2)
    for args, expected in ((("x",), 2), (("y", 100), 100)):
        self.assertEqual(d.get(*args), expected)
12694
def test_checked_dict_errors(self):
    """get() type-checks both the key and the default value."""
    d = chkdict[str, int](x=2)
    for args in ((100,), ("x", "abc")):
        with self.assertRaises(TypeError):
            d.get(*args)
12701
def test_checked_dict_sizeof(self):
    """__sizeof__() returns an int."""
    size = chkdict[str, int](x=2).__sizeof__()
    self.assertIs(type(size), int)
12705
def test_checked_dict_getitem(self):
    """Calling __getitem__ explicitly returns the stored value."""
    self.assertEqual(chkdict[str, int](x=2).__getitem__("x"), 2)
12709
def test_checked_dict_free_list(self):
    """A freed chkdict's memory is reused by a chkdict of another parameterization.

    The test checks this by comparing id() values across a del/allocate pair.
    NOTE(review): this presumably relies on a free list shared by checked dicts
    and on CPython allocator behavior — confirm against the chkdict
    implementation.
    """
    # Parameterize both types up front, presumably so that no type-creation
    # allocation happens between `del x` and `t2()`.
    t1 = chkdict[str, int]
    t2 = chkdict[str, str]
    x = t1()
    x_id1 = id(x)
    del x
    # The next allocation is expected to land at the address just freed.
    x = t2()
    x_id2 = id(x)
    self.assertEqual(x_id1, x_id2)
12719
def test_check_args(self):
    """
    CHECK_ARGS handles a typed positional arg at index 0 that is captured
    in a Cell by a nested function.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int) -> object:

            def inner() -> None:
                use(x)

            return use(x)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["outer"](1), 1)
12740
def test_check_args_2(self):
    """
    CHECK_ARGS handles several typed positional args that are all captured
    in Cells by a nested function.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int, y: str) -> object:

            def inner() -> None:
                use(x)
                use(y)

            use(x)
            return use(y)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        fn = mod["outer"]
        self.assertEqual(fn(1, "yo"), "yo")
        # Passing y by keyword forces JIT-compiled code to run its argument
        # checks after keyword arguments have been bound.
        self.assertEqual(fn(1, y="yo"), "yo")
12766
def test_check_args_3(self):
    """
    CHECK_ARGS handles a Cell-captured typed positional arg at an index > 0.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int, y: str) -> object:

            def inner() -> None:
                use(y)

            use(x)
            return use(y)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        fn = mod["outer"]
        self.assertEqual(fn(1, "yo"), "yo")
        # Passing y by keyword forces JIT-compiled code to run its argument
        # checks after keyword arguments have been bound.
        self.assertEqual(fn(1, y="yo"), "yo")
12791
def test_check_args_4(self):
    """
    Tests whether CHECK_ARGS can handle variables which are in a Cell,
    and are a positional arg with a default value at index 0.
    """

    codestr = """
        def use(i: object) -> object:
            return i

        def outer(x: int = 0) -> object:

            def inner() -> None:
                use(x)

            return use(x)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        outer = mod["outer"]
        self.assertEqual(outer(1), 1)
        # Default-value path: x comes from the signature default.
        self.assertEqual(outer(), 0)
        # Force JIT-compiled code to go through argument checks after
        # keyword arg binding
        self.assertEqual(outer(x=1), 1)
12812
def test_check_args_5(self):
    """
    CHECK_ARGS handles a Cell-captured keyword-only arg.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int, *, y: str = "lol") -> object:

            def inner() -> None:
                use(y)

            return use(y)

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["outer"](1, y="hi"), "hi")
12833
def test_check_args_6(self):
    """
    CHECK_ARGS handles a Cell-captured arg that follows a positional-only arg.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int, /, y: str) -> object:

            def inner() -> None:
                use(y)

            return use(y)

    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["outer"](1, "hi"), "hi")
12854
def test_check_args_7(self):
    """
    CHECK_ARGS handles positional-only, positional and keyword-only args
    that are all captured in Cells.
    """
    source = """
        def use(i: object) -> object:
            return i

        def outer(x: int, /, y: int, *, z: str = "lol") -> object:

            def inner() -> None:
                use(x)
                use(y)
                use(z)

            return use(x), use(y), use(z)
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        result = mod["outer"](3, 2, z="hi")
        self.assertEqual(result, (3, 2, "hi"))
12877
def test_str_split(self):
    """Unpacking str.split(None, 1) into two names works under static compilation."""
    source = """
        def get_str() -> str:
            return "something here"

        def test() -> str:
            a, b = get_str().split(None, 1)
            return b
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        self.assertEqual(mod["test"](), "here")
12890
def test_vtable_shadow_builtin_subclass_after_init(self):
    """Shadowing methods on a subclass of list after vtables are initialized."""

    # Local list subclass; its `reverse` is patched only after vtables exist.
    class MyList(list):
        pass

    # Replacement `reverse` that returns a sentinel instead of mutating.
    def myreverse(self):
        return 1

    codestr = """
        def f(l: list):
            l.reverse()
            return l
    """
    f = self.find_code(self.compile(codestr), "f")
    # l.reverse() on a `list`-annotated arg is emitted as INVOKE_METHOD.
    self.assertInBytecode(
        f, "INVOKE_METHOD", ((("builtins", "list", "reverse"), 0))
    )
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        # Now cause vtables to be inited
        self.assertEqual(mod["f"]([1, 2]), [2, 1])

        # And now patch
        MyList.reverse = myreverse

        # The subclass must observe the patched method.
        self.assertEqual(MyList().reverse(), 1)
12917
def test_vtable_shadow_builtin_subclass_before_init(self):
    """Shadowing methods on a subclass of list before vtables are initialized."""
    # Create a subclass of list...
    class MyList(list):
        pass

    def myreverse(self):
        return 1

    # ... and override a slot from list with a non-static func
    MyList.reverse = myreverse

    codestr = """
        def f(l: list):
            l.reverse()
            return l
    """
    f = self.find_code(self.compile(codestr), "f")
    self.assertInBytecode(
        f, "INVOKE_METHOD", ((("builtins", "list", "reverse"), 0))
    )
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        # Now cause vtables to be inited
        self.assertEqual(mod["f"]([1, 2]), [2, 1])

        # ... and this should not blow up when we remove the override.
        del MyList.reverse

        # With the override gone, MyList falls back to list.reverse,
        # which returns None.
        self.assertEqual(MyList().reverse(), None)
12947
def test_vtable_shadow_static_subclass(self):
    """Shadowing a static type's method on a subtype before it is initialized must not bypass type checks."""
    # Define a static type and shadow a subtype method before invoking.
    codestr = """
        class StaticType:
            def foo(self) -> int:
                return 1

        class SubType(StaticType):
            pass

        def goodfoo(self):
            return 2

        SubType.foo = goodfoo

        def f(x: StaticType) -> int:
            return x.foo()
    """
    f = self.find_code(self.compile(codestr), "f")
    self.assertInBytecode(
        f, "INVOKE_METHOD", ((("<module>", "StaticType", "foo"), 0))
    )
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        SubType = mod["SubType"]
        # Now invoke:
        self.assertEqual(mod["f"](SubType()), 2)

        # And replace the function again, forcing us to find the right slot type:
        def badfoo(self):
            return "foo"

        SubType.foo = badfoo

        # foo is declared `-> int`, so a str result must be rejected.
        with self.assertRaisesRegex(TypeError, "expected int, got str"):
            mod["f"](SubType())
12984
def test_vtable_shadow_static_subclass_nonstatic_patch(self):
    """Shadowing a static type's method with a non-static function before init must not bypass type checks."""
    code1 = """
        def nonstaticfoo(self):
            return 2
    """
    # mod1 itself is not referenced. The module only needs to be importable
    # as `nonstatic` (presumably registered by in_module's name= argument)
    # for the static code below.
    with self.in_module(
        code1, code_gen=PythonCodeGenerator, name="nonstatic"
    ) as mod1:
        # Define a static type and shadow a subtype method with a non-static func before invoking.
        codestr = """
            from nonstatic import nonstaticfoo

            class StaticType:
                def foo(self) -> int:
                    return 1

            class SubType(StaticType):
                pass

            SubType.foo = nonstaticfoo

            def f(x: StaticType) -> int:
                return x.foo()

            def badfoo(self):
                return "foo"
        """
        code = self.compile(codestr)
        f = self.find_code(code, "f")
        self.assertInBytecode(
            f, "INVOKE_METHOD", ((("<module>", "StaticType", "foo"), 0))
        )
        with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
            SubType = mod["SubType"]
            badfoo = mod["badfoo"]

            # And replace the function again, forcing us to find the right slot type:
            SubType.foo = badfoo

            # foo is declared `-> int`, so a str result must be rejected.
            with self.assertRaisesRegex(TypeError, "expected int, got str"):
                mod["f"](SubType())
13027
def test_vtable_shadow_grandparent(self):
    """Patching a grandchild class after vtable init still type-checks the override."""
    codestr = """
        class Base:
            def foo(self) -> int:
                return 1

        class Sub(Base):
            pass

        class Grand(Sub):
            pass

        def f(x: Base) -> int:
            return x.foo()

        def grandfoo(self):
            return "foo"
    """
    f = self.find_code(self.compile(codestr), "f")
    self.assertInBytecode(f, "INVOKE_METHOD", ((("<module>", "Base", "foo"), 0)))
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        Grand = mod["Grand"]
        grandfoo = mod["grandfoo"]
        f = mod["f"]

        # init vtables
        self.assertEqual(f(Grand()), 1)

        # patch in an override of the grandparent method
        Grand.foo = grandfoo

        # Base.foo is declared `-> int`, so the str result must be rejected.
        with self.assertRaisesRegex(TypeError, "expected int, got str"):
            f(Grand())
13061
def test_for_iter_list(self):
    """A for-loop over a list-comprehension local compiles without FOR_ITER."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                acc.append(i + 1)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(4), [1, 2, 3, 4])
13077
def test_for_iter_tuple(self):
    """A for-loop over a tuple local compiles without FOR_ITER."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = tuple([i for i in range(n)])
            for i in l:
                acc.append(i + 1)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(4), [1, 2, 3, 4])
13093
def test_for_iter_sequence_orelse(self):
    """The `else` clause of a FOR_ITER-free sequence loop runs on normal exit."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                acc.append(i + 1)
            else:
                acc.append(999)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(4), [1, 2, 3, 4, 999])
13111
def test_for_iter_sequence_break(self):
    """`break` exits a FOR_ITER-free sequence loop at the right element."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                if i == 3:
                    break
                acc.append(i + 1)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(5), [1, 2, 3])
13129
def test_for_iter_sequence_orelse_break(self):
    """`break` in a FOR_ITER-free sequence loop skips its `else` clause."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                if i == 2:
                    break
                acc.append(i + 1)
            else:
                acc.append(999)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(4), [1, 2])
13149
def test_for_iter_sequence_return(self):
    """`return` from inside a FOR_ITER-free sequence loop yields the partial result."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                if i == 3:
                    return acc
                acc.append(i + 1)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(6), [1, 2, 3])
13167
def test_nested_for_iter_sequence(self):
    """Nested loops over the same list both compile without FOR_ITER."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                for j in l:
                    acc.append(i + j)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        expected = [i + j for i in range(3) for j in range(3)]
        self.assertEqual(func(3), expected)
13184
def test_nested_for_iter_sequence_break(self):
    """`break` in an inner FOR_ITER-free loop exits only that loop."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                for j in l:
                    if j == 2:
                        break
                    acc.append(i + j)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        expected = [i + j for i in range(3) for j in range(2)]
        self.assertEqual(func(3), expected)
13203
def test_nested_for_iter_sequence_return(self):
    """`return` from an inner FOR_ITER-free loop leaves both loops at once."""
    source = """
        from typing import List

        def f(n: int) -> List:
            acc = []
            l = [i for i in range(n)]
            for i in l:
                for j in l:
                    if j == 1:
                        return acc
                    acc.append(i + j)
            return acc
    """
    with self.in_module(source, code_gen=StaticCodeGenerator) as mod:
        func = mod["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        self.assertEqual(func(3), [0])
13222
def test_for_iter_unchecked_get(self):
    """Loop element loads skip bounds checks, since the index was just compared to the size."""
    source = """
        def f():
            l = [1, 2, 3]
            acc = []
            for x in l:
                acc.append(x)
            return acc
    """
    with self.in_module(source) as mod:
        func = mod["f"]
        self.assertInBytecode(func, "SEQUENCE_GET", SEQ_LIST | SEQ_SUBSCR_UNCHECKED)
        self.assertEqual(func(), [1, 2, 3])
13237
def test_for_iter_list_modified(self):
    """Shrinking the list during iteration ends the loop at the new length."""
    codestr = """
        def f():
            l = [1, 2, 3, 4, 5]
            acc = []
            for x in l:
                acc.append(x)
                l[2:] = []
            return acc
    """
    with self.in_module(codestr) as module:
        func = module["f"]
        self.assertNotInBytecode(func, "FOR_ITER")
        # After the first pass l is [1, 2], so only 1 and 2 are visited.
        self.assertEqual(func(), [1, 2])
13252
def test_sorted(self):
    """sorted() builtin returns an Exact[List].

    The loop over ``sorted(l)`` is refined to ``list`` and compiled without
    FOR_ITER. The function is also executed, on a non-empty and an empty
    input, so the refined iteration path actually runs at least once.
    """
    codestr = """
        from typing import Iterable

        def f(l: Iterable[int]):
            for x in sorted(l):
                pass
    """
    with self.in_module(codestr) as mod:
        f = mod["f"]
        self.assertNotInBytecode(f, "FOR_ITER")
        self.assertInBytecode(f, "REFINE_TYPE", ("builtins", "list"))
        # Execute the specialized loop; f has no return statement, so it yields None.
        self.assertIsNone(f([3, 1, 2]))
        self.assertIsNone(f([]))
13266
def test_min(self):
    """min(a, b) on two ints is inlined as a `<=` compare and a conditional jump."""
    codestr = """
        def f(a: int, b: int) -> int:
            return min(a, b)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "COMPARE_OP", "<=")
        self.assertInBytecode(func, "POP_JUMP_IF_FALSE")
        for args in ((1, 3), (3, 1)):
            self.assertEqual(func(*args), 1)
13278
def test_min_stability(self):
    """On a tie, the inlined min() returns its first argument object."""
    codestr = """
        def f(a: int, b: int) -> int:
            return min(a, b)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "COMPARE_OP", "<=")
        self.assertInBytecode(func, "POP_JUMP_IF_FALSE")
        # Two equal ints; the assertion below confirms they are distinct objects.
        first = int("11334455667")
        second = int("11334455667")
        self.assertIsNot(first, second)
        # Identity (not just equality) shows which argument was returned.
        self.assertIs(func(first, second), first)
        self.assertIs(func(second, first), second)
13295
def test_max(self):
    """max(a, b) on two ints is inlined as a `>=` compare and a conditional jump."""
    codestr = """
        def f(a: int, b: int) -> int:
            return max(a, b)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "COMPARE_OP", ">=")
        self.assertInBytecode(func, "POP_JUMP_IF_FALSE")
        for args in ((1, 3), (3, 1)):
            self.assertEqual(func(*args), 3)
13307
def test_max_stability(self):
    """On a tie, the inlined max() returns its first argument object."""
    codestr = """
        def f(a: int, b: int) -> int:
            return max(a, b)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "COMPARE_OP", ">=")
        self.assertInBytecode(func, "POP_JUMP_IF_FALSE")
        # Two equal ints; the assertion below confirms they are distinct objects.
        first = int("11334455667")
        second = int("11334455667")
        self.assertIsNot(first, second)
        # Identity (not just equality) shows which argument was returned.
        self.assertIs(func(first, second), first)
        self.assertIs(func(second, first), second)
13324
def test_extremum_primitive(self):
    """Passing primitive int8 values to min() is rejected at compile time."""
    codestr = """
        from __static__ import int8

        def f() -> None:
            a: int8 = 4
            b: int8 = 5
            min(a, b)
    """
    expected_error = "Call argument cannot be a primitive"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr, StaticCodeGenerator, modname="foo.py")
13338
def test_extremum_non_specialization_kwarg(self):
    """min() with a key= keyword is not specialized into an inline compare."""
    codestr = """
        def f() -> None:
            a = "4"
            b = "5"
            min(a, b, key=int)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        for opname in ("COMPARE_OP", "POP_JUMP_IF_FALSE"):
            self.assertNotInBytecode(func, opname)
13350
def test_extremum_non_specialization_stararg(self):
    """min(*args) is not specialized into an inline compare."""
    codestr = """
        def f() -> None:
            a = [3, 4]
            min(*a)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        for opname in ("COMPARE_OP", "POP_JUMP_IF_FALSE"):
            self.assertNotInBytecode(func, opname)
13361
def test_extremum_non_specialization_dstararg(self):
    """min(..., **kwargs) is not specialized into an inline compare."""
    codestr = """
        def f() -> None:
            k = {
                "default": 5
            }
            min(3, 4, **k)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        for opname in ("COMPARE_OP", "POP_JUMP_IF_FALSE"):
            self.assertNotInBytecode(func, opname)
13374
def test_try_return_finally(self):
    """A bare `return` inside `try` still runs the `finally` clause."""
    codestr = """
        from typing import List

        def f1(x: List):
            try:
                return
            finally:
                x.append("hi")
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        target: list = []
        module["f1"](target)
        self.assertEqual(target, ["hi"])
13390
def test_cbool(self):
    """A cbool local set from a literal is stored primitively and branched on."""
    expected_by_literal = {"True": 1, "False": 2}
    for literal, expected in expected_by_literal.items():
        codestr = f"""
            from __static__ import cbool

            def f() -> int:
                x: cbool = {literal}
                if x:
                    return 1
                else:
                    return 2
        """
        with self.subTest(b=literal):
            with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
                func = module["f"]
                self.assertInBytecode(func, "PRIMITIVE_LOAD_CONST")
                self.assertInBytecode(
                    func, "STORE_LOCAL", (0, ("__static__", "cbool"))
                )
                self.assertInBytecode(func, "POP_JUMP_IF_ZERO")
                self.assertEqual(func(), expected)
13412
def test_cbool_field(self):
    """A cbool attribute is read with LOAD_FIELD and surfaces as a real bool."""
    codestr = """
        from __static__ import cbool

        class C:
            def __init__(self, x: cbool) -> None:
                self.x: cbool = x

        def f(c: C):
            if c.x:
                return True
            return False
    """
    with self.in_module(codestr) as module:
        func, cls = module["f"], module["C"]
        self.assertInBytecode(func, "LOAD_FIELD", (module["__name__"], "C", "x"))
        self.assertInBytecode(func, "POP_JUMP_IF_ZERO")
        for flag in (True, False):
            self.assertIs(cls(flag).x, flag)
        for flag in (True, False):
            self.assertIs(func(cls(flag)), flag)
13434
def test_cbool_cast(self):
    """Assigning a boxed bool to a cbool local is a type-mismatch error."""
    codestr = """
        from __static__ import cbool

        def f(y: bool) -> int:
            x: cbool = y
            if x:
                return 1
            else:
                return 2
    """
    mismatch = type_mismatch("bool", "cbool")
    with self.assertRaisesRegex(TypedSyntaxError, mismatch):
        self.compile(codestr, StaticCodeGenerator, modname="foo")
13448
def test_primitive_compare_returns_cbool(self):
    """Comparing two int64 values yields a cbool that boxes to True/False."""
    codestr = """
        from __static__ import cbool, int64

        def f(x: int64, y: int64) -> cbool:
            return x == y
    """
    with self.in_module(codestr) as module:
        func = module["f"]
        for (lhs, rhs), expected in (((1, 1), True), ((1, 2), False)):
            self.assertIs(func(lhs, rhs), expected)
13460
def test_no_cbool_math(self):
    """Adding two cbool operands is rejected at compile time."""
    codestr = """
        from __static__ import cbool

        def f(x: cbool, y: cbool) -> cbool:
            return x + y
    """
    expected_error = "cbool is not a valid operand type for add"
    with self.assertRaisesRegex(TypedSyntaxError, expected_error):
        self.compile(codestr)
13472
def test_chkdict_del(self):
    """`del x[key]` on a dict local removes only that key."""
    codestr = """
        def f():
            x = {}
            x[1] = "a"
            x[2] = "b"
            del x[1]
            return x
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        result = module["f"]()
        self.assertNotIn(1, result)
        self.assertIn(2, result)
13487
def test_final_constant_folding_int(self):
    """A Final[int] module constant is loaded as a constant, not a global."""
    codestr = """
        from typing import Final

        X: Final[int] = 1337

        def plus_1337(i: int) -> int:
            return i + X
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        adder = module["plus_1337"]
        self.assertInBytecode(adder, "LOAD_CONST", 1337)
        self.assertNotInBytecode(adder, "LOAD_GLOBAL")
        self.assertEqual(adder(3), 1340)
13502
def test_final_constant_folding_bool(self):
    """A Final[bool] module constant is loaded as a constant, not a global."""
    codestr = """
        from typing import Final

        X: Final[bool] = True

        def f() -> bool:
            return not X
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "LOAD_CONST", True)
        self.assertNotInBytecode(func, "LOAD_GLOBAL")
        self.assertFalse(func())
13517
def test_final_constant_folding_str(self):
    """A Final[str] module constant is loaded as a constant, not a global."""
    codestr = """
        from typing import Final

        X: Final[str] = "omg"

        def f() -> str:
            return X[1]
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertInBytecode(func, "LOAD_CONST", "omg")
        self.assertNotInBytecode(func, "LOAD_GLOBAL")
        self.assertEqual(func(), "m")
13532
def test_final_constant_folding_disabled_on_nonfinals(self):
    """Without Final, a module-level str is still read via LOAD_GLOBAL."""
    codestr = """
        from typing import Final

        X: str = "omg"

        def f() -> str:
            return X[1]
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertNotInBytecode(func, "LOAD_CONST", "omg")
        self.assertInBytecode(func, "LOAD_GLOBAL", "X")
        self.assertEqual(func(), "m")
13547
def test_final_constant_folding_disabled_on_nonconstant_finals(self):
    """A Final initialized from a call result is not folded into a constant."""
    codestr = """
        from typing import Final

        def p() -> str:
            return "omg"

        X: Final[str] = p()

        def f() -> str:
            return X[1]
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        self.assertNotInBytecode(func, "LOAD_CONST", "omg")
        self.assertInBytecode(func, "LOAD_GLOBAL", "X")
        self.assertEqual(func(), "m")
13565
def test_final_constant_folding_shadowing(self):
    """A local that shadows a Final module constant must not be folded.

    Inside ``f`` the name ``X`` is a local bound to "lol". Neither a global
    load of ``X`` nor the module constant's value "omg" may appear in f's
    bytecode.
    """
    codestr = """
        from typing import Final

        X: Final[str] = "omg"

        def f() -> str:
            X = "lol"
            return X[1]
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        f = mod["f"]
        self.assertInBytecode(f, "LOAD_CONST", "lol")
        # LOAD_GLOBAL's argument is a *name*. Checking it against the value
        # "omg" could never match, so check the shadowed name instead.
        self.assertNotInBytecode(f, "LOAD_GLOBAL", "X")
        # The Final's folded value must not leak into the shadowing scope.
        self.assertNotInBytecode(f, "LOAD_CONST", "omg")
        self.assertEqual(f(), "o")
13581
def test_final_constant_folding_in_module_scope(self):
    """Module-level uses of a Final[int] are folded rather than looked up by name."""
    codestr = """
        from typing import Final

        X: Final[int] = 21
        y = X + 3
    """
    code = self.compile(codestr, generator=StaticCodeGenerator, modname="foo.py")
    self.assertNotInBytecode(code, "LOAD_NAME", "X")
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        self.assertEqual(module["y"], 24)
13593
def test_final_constant_in_module_scope(self):
    """A constant-valued Final is listed in the module's __final_constants__."""
    codestr = """
        from typing import Final

        X: Final[int] = 21
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        finals = module["__final_constants__"]
        self.assertEqual(finals, ("X",))
13602
def test_final_nonconstant_in_module_scope(self):
    """A Final initialized from a call is not listed in __final_constants__."""
    codestr = """
        from typing import Final

        def p() -> str:
            return "omg"

        X: Final[str] = p()
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        finals = module["__final_constants__"]
        self.assertEqual(finals, ())
13614
def test_double_load_const(self):
    """A double local is initialized with a typed PRIMITIVE_LOAD_CONST and JITs."""
    codestr = """
        from __static__ import double

        def t():
            pi: double = 3.14159
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["t"]
        typed_pi = (3.14159, TYPED_DOUBLE)
        self.assertInBytecode(func, "PRIMITIVE_LOAD_CONST", typed_pi)
        func()
        self.assert_jitted(func)
13627
def test_double_box(self):
    """box() of a double local returns the float value without emitting a CAST."""
    codestr = """
        from __static__ import double, box

        def t() -> float:
            pi: double = 3.14159
            return box(pi)
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["t"]
        typed_pi = (3.14159, TYPED_DOUBLE)
        self.assertInBytecode(func, "PRIMITIVE_LOAD_CONST", typed_pi)
        self.assertNotInBytecode(func, "CAST")
        self.assertEqual(func(), 3.14159)
        self.assert_jitted(func)
13642
def test_none_not(self):
    """`not x` on a None local branches with POP_JUMP_IF_TRUE and takes the True path."""
    codestr = """
        def t() -> bool:
            x = None
            if not x:
                return True
            else:
                return False
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["t"]
        self.assertInBytecode(func, "POP_JUMP_IF_TRUE")
        self.assertTrue(func())
13656
def test_unbox_final_inlined(self):
    """unbox()/int64() of a Final[int] becomes a typed primitive constant load."""
    for converter in ("unbox", "int64"):
        codestr = f"""
            from typing import Final
            from __static__ import int64, unbox

            MY_FINAL: Final[int] = 111

            def t() -> bool:
                i: int64 = 64
                if i < {converter}(MY_FINAL):
                    return True
                else:
                    return False
        """
        with self.subTest(func=converter):
            with self.in_module(codestr) as module:
                func = module["t"]
                self.assertInBytecode(
                    func, "PRIMITIVE_LOAD_CONST", (111, TYPED_INT64)
                )
                self.assertEqual(func(), True)
13677
def test_augassign_inexact(self):
    """`+=` with a plain `int` operand uses the generic INPLACE_ADD opcode."""
    codestr = """
        def something():
            return 3

        def t():
            a: int = something()

            b = 0
            b += a
            return b
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["t"]
        self.assertInBytecode(func, "INPLACE_ADD")
        self.assertEqual(func(), 3)
13694
def test_qualname(self):
    """Code objects report the qualified name of the scope that defines them."""
    codestr = """
        def f():
            pass


        class C:
            def x(self):
                pass

            @staticmethod
            def sm():
                pass

            @classmethod
            def cm():
                pass

            def f(self):
                class G:
                    def y(self):
                        pass
                return G.y
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        cls = module["C"]
        qualname_of = cinder._get_qualname

        self.assertEqual(qualname_of(func.__code__), "f")

        self.assertEqual(qualname_of(cls.x.__code__), "C.x")
        self.assertEqual(qualname_of(cls.sm.__code__), "C.sm")
        self.assertEqual(qualname_of(cls.cm.__code__), "C.cm")

        # G.y is defined in a class nested inside the method C.f.
        self.assertEqual(qualname_of(cls().f().__code__), "C.f.<locals>.G.y")
13729
def test_refine_optional_name(self):
    """An Optional[str] guarded by a truthiness check can call .encode()."""
    codestr = """
        from typing import Optional

        def f(s: Optional[str]) -> bytes:
            return s.encode("utf-8") if s else b""
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        func = module["f"]
        for arg, expected in (("A", b"A"), (None, b"")):
            self.assertEqual(func(arg), expected)
13741
def test_donotcompile_fn(self):
    """@_donotcompile leaves a function non-static; an undecorated one is static."""
    codestr = """
        from __static__ import _donotcompile

        def a() -> int:
            return 1

        @_donotcompile
        def fn() -> None:
            a() + 2

        def fn2() -> None:
            a() + 2
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        # Decorated: ordinary CALL_FUNCTION and no static flag.
        skipped = module["fn"]
        self.assertInBytecode(skipped, "CALL_FUNCTION")
        self.assertNotInBytecode(skipped, "INVOKE_FUNCTION")
        self.assertFalse(skipped.__code__.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(skipped(), None)

        # Undecorated: static INVOKE_FUNCTION and the static flag set.
        compiled = module["fn2"]
        self.assertNotInBytecode(compiled, "CALL_FUNCTION")
        self.assertInBytecode(compiled, "INVOKE_FUNCTION")
        self.assertTrue(compiled.__code__.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(compiled(), None)
13768
def test_donotcompile_method(self):
    """@_donotcompile on a method opts out only that method.

    The decorated ``C.fn`` must be compiled as regular Python (plain
    CALL_FUNCTION, no static flag). Its undecorated sibling ``C.fn2`` must
    still be statically compiled.
    """
    codestr = """
        from __static__ import _donotcompile

        def a() -> int:
            return 1

        class C:
            @_donotcompile
            def fn() -> None:
                a() + 2

            def fn2() -> None:
                a() + 2

        c = C()
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]

        # The decorated method was previously never checked. Verify it
        # really opted out, using the same checks as test_donotcompile_fn.
        fn = C.fn
        self.assertInBytecode(fn, "CALL_FUNCTION")
        self.assertNotInBytecode(fn, "INVOKE_FUNCTION")
        self.assertFalse(fn.__code__.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(fn(), None)

        fn2 = C.fn2
        self.assertNotInBytecode(fn2, "CALL_FUNCTION")
        self.assertInBytecode(fn2, "INVOKE_FUNCTION")
        self.assertTrue(fn2.__code__.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(fn2(), None)
13794
def test_donotcompile_class(self):
    """@_donotcompile on a class makes its methods non-static.

    ``C.fn`` must be compiled as regular Python. ``D`` has a class body that
    only performs a call. It must still be defined as a class once the
    module has executed.
    """
    codestr = """
        from __static__ import _donotcompile

        def a() -> int:
            return 1

        @_donotcompile
        class C:
            def fn() -> None:
                a() + 2

        @_donotcompile
        class D:
            a()

        c = C()
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod:
        C = mod["C"]
        fn = C.fn
        self.assertInBytecode(fn, "CALL_FUNCTION")
        self.assertNotInBytecode(fn, "INVOKE_FUNCTION")
        self.assertFalse(fn.__code__.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(fn(), None)

        # Previously fetched but never checked: assert the decorated class exists.
        D = mod["D"]
        self.assertIsInstance(D, type)
13822
def test_donotcompile_lambda(self):
    """Lambdas inherit @_donotcompile from their enclosing method."""
    codestr = """
        from __static__ import _donotcompile

        def a() -> int:
            return 1

        class C:
            @_donotcompile
            def fn() -> None:
                z = lambda: a() + 2
                z()

            def fn2() -> None:
                z = lambda: a() + 2
                z()

        c = C()
    """
    with self.in_module(codestr, code_gen=StaticCodeGenerator) as module:
        cls = module["C"]

        # The lambda nested in the decorated method is not static.
        skipped = cls.fn
        inner_skipped = self.find_code(skipped.__code__)
        self.assertNotInBytecode(inner_skipped, "INVOKE_FUNCTION")
        self.assertFalse(inner_skipped.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(skipped(), None)

        # The lambda nested in the undecorated method is static.
        compiled = cls.fn2
        inner_compiled = self.find_code(compiled.__code__)
        self.assertInBytecode(inner_compiled, "INVOKE_FUNCTION")
        self.assertTrue(inner_compiled.co_flags & CO_STATICALLY_COMPILED)
        self.assertEqual(compiled(), None)
13855
def test_double_binop(self):
    """Binary ops on ``double`` locals produce the expected float results.

    Each case is ``(x, y, op, res)``. ``res`` is the value that
    ``box(x op y)`` must return from the compiled function.
    """
    tests = [
        (1.732, 2.0, "+", 3.732),
        (1.732, 2.0, "-", -0.268),
        (1.732, 2.0, "/", 0.866),
        (1.732, 2.0, "*", 3.464),
        (1.732, 2, "+", 3.732),
    ]

    if cinderjit is not None:
        # Division by zero is only checked when the JIT is available; the
        # expected result is IEEE-754 infinity.
        tests.append((1.732, 0.0, "/", float("inf")))

    for x, y, op, res in tests:
        codestr = f"""
            from __static__ import double, box
            def testfunc(tst):
                x: double = {x}
                y: double = {y}

                z: double = x {op} y
                return box(z)
        """
        # Identify each case only by its operands, operator and expected
        # result. The builtin `type` has no per-case meaning here.
        with self.subTest(x=x, y=y, op=op, res=res):
            with self.in_module(codestr) as mod:
                f = mod["testfunc"]
                self.assertEqual(f(False), res, f"{x} {op} {y} should be {res}")
13883
def test_primitive_stack_spill(self):
    """Exercise primitive values that spill to the stack.

    Enough locals are created that some must be spilled. This covers
    shuffling stack-spilled values across basic-block transitions and
    field reads/writes with spilled values. Those paths can create
    mem->mem moves that otherwise wouldn't exist, and can trigger issues
    such as push/pop not supporting 8 or 32 bits on x64.
    """
    names = string.ascii_lowercase[:20]
    # The separators must match the column of each placeholder in the template.
    body_sep = "\n" + " " * 20
    nested_sep = "\n" + " " * 24
    for size in ["uint8", "int16", "int32", "int64"]:
        attrs = body_sep.join(f"{name}: {size}" for name in names)
        inits = body_sep.join(
            f"{name}: {size} = {idx}" for idx, name in enumerate(names)
        )
        assigns = body_sep.join(f"val.{name} = {name}" for name in names)
        reads = body_sep.join(f"{name} = val.{name}" for name in names)
        incrs = nested_sep.join(f"{name} += 1" for name in names)
        total = " + ".join(names)
        codestr = f"""
                from __static__ import {size}

                class C:
                    {attrs}

                def f(val: C, flag: {size}) -> {size}:
                    {inits}
                    if flag:
                        {incrs}
                    {assigns}
                    {reads}
                    return {total}
        """
        with self.subTest(size=size):
            with self.in_module(codestr) as module:
                func, cls = module["f"], module["C"]
                # flag == 0 keeps the initial values; flag == 1 bumps each by one.
                for flag in (0, 1):
                    inst = cls()
                    self.assertEqual(
                        func(inst, flag), sum(range(len(names) + flag))
                    )
                    for idx, name in enumerate(names):
                        self.assertEqual(getattr(inst, name), idx + flag)
13928
def test_class_static_tpflag(self):
    """Classes from static source are flagged static; ordinary classes are not."""
    codestr = """
        class A:
            pass
    """
    with self.in_module(codestr) as module:
        self.assertTrue(is_type_static(module["A"]))

        class B:
            pass

        self.assertFalse(is_type_static(B))
13942
13943
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly.
    unittest.main()