This repository has no description.
At trunk: 13945 lines, 438 kB (view raw).
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) 2import ast 3import asyncio 4import dis 5import gc 6import inspect 7import itertools 8import re 9import string 10import sys 11import unittest 12import warnings 13from array import array 14from collections import UserDict 15from compiler import consts, walk 16from compiler.consts38 import CO_NO_FRAME, CO_STATICALLY_COMPILED 17from compiler.optimizer import AstOptimizer 18from compiler.pycodegen import PythonCodeGenerator, make_compiler 19from compiler.static import ( 20 prim_name_to_type, 21 BOOL_TYPE, 22 BYTES_TYPE, 23 COMPLEX_EXACT_TYPE, 24 Class, 25 Object, 26 DICT_TYPE, 27 DYNAMIC, 28 FLOAT_TYPE, 29 INT_TYPE, 30 LIST_EXACT_TYPE, 31 LIST_TYPE, 32 NONE_TYPE, 33 OBJECT_TYPE, 34 PRIM_OP_ADD_INT, 35 PRIM_OP_DIV_INT, 36 PRIM_OP_GT_INT, 37 PRIM_OP_LT_INT, 38 STR_EXACT_TYPE, 39 STR_TYPE, 40 DeclarationVisitor, 41 Function, 42 StaticCodeGenerator, 43 SymbolTable, 44 TypeBinder, 45 TypedSyntaxError, 46 Value, 47 TUPLE_EXACT_TYPE, 48 TUPLE_TYPE, 49 INT_EXACT_TYPE, 50 FLOAT_EXACT_TYPE, 51 SET_EXACT_TYPE, 52 ELLIPSIS_TYPE, 53 FAST_LEN_ARRAY, 54 FAST_LEN_LIST, 55 FAST_LEN_TUPLE, 56 FAST_LEN_INEXACT, 57 FAST_LEN_DICT, 58 FAST_LEN_SET, 59 SEQ_ARRAY_INT8, 60 SEQ_ARRAY_INT16, 61 SEQ_ARRAY_INT32, 62 SEQ_ARRAY_INT64, 63 SEQ_ARRAY_UINT8, 64 SEQ_ARRAY_UINT16, 65 SEQ_ARRAY_UINT32, 66 SEQ_ARRAY_UINT64, 67 SEQ_LIST, 68 SEQ_LIST_INEXACT, 69 SEQ_TUPLE, 70 SEQ_REPEAT_INEXACT_SEQ, 71 SEQ_REPEAT_INEXACT_NUM, 72 SEQ_REPEAT_PRIMITIVE_NUM, 73 SEQ_REPEAT_REVERSED, 74 SEQ_SUBSCR_UNCHECKED, 75 TYPED_BOOL, 76 TYPED_INT8, 77 TYPED_INT16, 78 TYPED_INT32, 79 TYPED_INT64, 80 TYPED_UINT8, 81 TYPED_UINT16, 82 TYPED_UINT32, 83 TYPED_UINT64, 84 FAST_LEN_STR, 85 DICT_EXACT_TYPE, 86 TYPED_DOUBLE, 87) 88from compiler.symbols import SymbolVisitor 89from contextlib import contextmanager 90from copy import deepcopy 91from io import StringIO 92from os import path 93from test.support import maybe_get_event_loop_policy 94from textwrap 
import dedent 95from types import CodeType, MemberDescriptorType, ModuleType 96from typing import Generic, Optional, Tuple, TypeVar 97from unittest import TestCase, skip, skipIf 98from unittest.mock import Mock, patch 99 100import cinder 101import xxclassloader 102from __static__ import ( 103 Array, 104 Vector, 105 chkdict, 106 int32, 107 int64, 108 int8, 109 make_generic_type, 110 StaticGeneric, 111 is_type_static, 112) 113from cinder import StrictModule 114 115from .common import CompilerTest 116 117IS_MULTITHREADED_COMPILE_TEST = False 118try: 119 import cinderjit 120 121 IS_MULTITHREADED_COMPILE_TEST = cinderjit.is_test_multithreaded_compile_enabled() 122except ImportError: 123 cinderjit = None 124 125 126RICHARDS_PATH = path.join( 127 path.dirname(__file__), 128 "..", 129 "..", 130 "..", 131 "Tools", 132 "benchmarks", 133 "richards_static.py", 134) 135 136 137PRIM_NAME_TO_TYPE = { 138 "cbool": TYPED_BOOL, 139 "int8": TYPED_INT8, 140 "int16": TYPED_INT16, 141 "int32": TYPED_INT32, 142 "int64": TYPED_INT64, 143 "uint8": TYPED_UINT8, 144 "uint16": TYPED_UINT16, 145 "uint32": TYPED_UINT32, 146 "uint64": TYPED_UINT64, 147} 148 149 150def type_mismatch(from_type: str, to_type: str) -> str: 151 return re.escape(f"type mismatch: {from_type} cannot be assigned to {to_type}") 152 153 154def optional(type: str) -> str: 155 return f"Optional[{type}]" 156 157 158def init_xxclassloader(): 159 codestr = """ 160 from typing import Generic, TypeVar, _tp_cache 161 from __static__.compiler_flags import nonchecked_dicts 162 # Setup a test for typing 163 T = TypeVar('T') 164 U = TypeVar('U') 165 166 167 class XXGeneric(Generic[T, U]): 168 d = {} 169 170 def foo(self, t: T, u: U) -> str: 171 return str(t) + str(u) 172 173 @classmethod 174 def __class_getitem__(cls, elem_type): 175 if elem_type in XXGeneric.d: 176 return XXGeneric.d[elem_type] 177 178 XXGeneric.d[elem_type] = type( 179 f"XXGeneric[{elem_type[0].__name__}, {elem_type[1].__name__}]", 180 (object, ), 181 { 182 "foo": 
XXGeneric.foo, 183 "__slots__":(), 184 } 185 ) 186 return XXGeneric.d[elem_type] 187 """ 188 189 code = make_compiler( 190 inspect.cleandoc(codestr), 191 "", 192 "exec", 193 generator=StaticCodeGenerator, 194 modname="xxclassloader", 195 ).getCode() 196 d = {} 197 exec(code, d, d) 198 199 xxclassloader.XXGeneric = d["XXGeneric"] 200 201 202class StaticTestBase(CompilerTest): 203 def compile( 204 self, 205 code, 206 generator=StaticCodeGenerator, 207 modname="<module>", 208 optimize=0, 209 peephole_enabled=True, 210 ast_optimizer_enabled=True, 211 ): 212 if ( 213 not peephole_enabled 214 or not ast_optimizer_enabled 215 or generator is not StaticCodeGenerator 216 ): 217 return super().compile( 218 code, 219 generator, 220 modname, 221 optimize, 222 peephole_enabled, 223 ast_optimizer_enabled, 224 ) 225 226 symtable = SymbolTable() 227 code = inspect.cleandoc("\n" + code) 228 tree = ast.parse(code) 229 return symtable.compile(modname, f"{modname}.py", tree, optimize) 230 231 def type_error(self, code, pattern): 232 with self.assertRaisesRegex(TypedSyntaxError, pattern): 233 self.compile(code) 234 235 _temp_mod_num = 0 236 237 def _temp_mod_name(self): 238 StaticTestBase._temp_mod_num += 1 239 return sys._getframe().f_back.f_back.f_back.f_code.co_name + str( 240 StaticTestBase._temp_mod_num 241 ) 242 243 @contextmanager 244 def in_module(self, code, name=None, code_gen=StaticCodeGenerator, optimize=0): 245 if name is None: 246 name = self._temp_mod_name() 247 248 try: 249 compiled = self.compile(code, code_gen, name, optimize) 250 m = type(sys)(name) 251 d = m.__dict__ 252 sys.modules[name] = m 253 exec(compiled, d) 254 d["__name__"] = name 255 256 yield d 257 finally: 258 if not IS_MULTITHREADED_COMPILE_TEST: 259 # don't throw a new exception if we failed to compile 260 if name in sys.modules: 261 del sys.modules[name] 262 d.clear() 263 gc.collect() 264 265 @contextmanager 266 def in_strict_module( 267 self, 268 code, 269 name=None, 270 code_gen=StaticCodeGenerator, 
271 optimize=0, 272 enable_patching=False, 273 ): 274 if name is None: 275 name = self._temp_mod_name() 276 277 try: 278 compiled = self.compile(code, code_gen, name, optimize) 279 d = {"__name__": name} 280 m = StrictModule(d, enable_patching) 281 sys.modules[name] = m 282 exec(compiled, d) 283 284 yield m 285 finally: 286 if not IS_MULTITHREADED_COMPILE_TEST: 287 # don't throw a new exception if we failed to compile 288 if name in sys.modules: 289 del sys.modules[name] 290 d.clear() 291 gc.collect() 292 293 def run_code(self, code, generator=None, modname=None, peephole_enabled=True): 294 if modname is None: 295 modname = self._temp_mod_name() 296 d = super().run_code(code, generator, modname, peephole_enabled) 297 if IS_MULTITHREADED_COMPILE_TEST: 298 sys.modules[modname] = d 299 return d 300 301 @property 302 def base_size(self): 303 class C: 304 __slots__ = () 305 306 return sys.getsizeof(C()) 307 308 @property 309 def ptr_size(self): 310 return 8 if sys.maxsize > 2 ** 32 else 4 311 312 def assert_jitted(self, func): 313 if cinderjit is None: 314 return 315 316 self.assertTrue(cinderjit.is_jit_compiled(func), func.__name__) 317 318 def assert_not_jitted(self, func): 319 if cinderjit is None: 320 return 321 322 self.assertFalse(cinderjit.is_jit_compiled(func)) 323 324 def assert_not_jitted(self, func): 325 if cinderjit is None: 326 return 327 328 self.assertFalse(cinderjit.is_jit_compiled(func)) 329 330 def setUp(self): 331 # ensure clean classloader/vtable slate for all tests 332 cinder.clear_classloader_caches() 333 # ensure our async tests don't change the event loop policy 334 policy = maybe_get_event_loop_policy() 335 self.addCleanup(lambda: asyncio.set_event_loop_policy(policy)) 336 337 def subTest(self, **kwargs): 338 cinder.clear_classloader_caches() 339 return super().subTest(**kwargs) 340 341 def make_async_func_hot(self, func): 342 async def make_hot(): 343 for i in range(50): 344 await func() 345 346 asyncio.run(make_hot()) 347 348 349class 
StaticCompilationTests(StaticTestBase): 350 @classmethod 351 def setUpClass(cls): 352 init_xxclassloader() 353 354 @classmethod 355 def tearDownClass(cls): 356 if not IS_MULTITHREADED_COMPILE_TEST: 357 del xxclassloader.XXGeneric 358 359 def test_static_import_unknown(self) -> None: 360 codestr = """ 361 from __static__ import does_not_exist 362 """ 363 with self.assertRaises(TypedSyntaxError): 364 self.compile(codestr, StaticCodeGenerator, modname="foo") 365 366 def test_static_import_star(self) -> None: 367 codestr = """ 368 from __static__ import * 369 """ 370 with self.assertRaises(TypedSyntaxError): 371 self.compile(codestr, StaticCodeGenerator, modname="foo") 372 373 def test_reveal_type(self) -> None: 374 codestr = """ 375 def f(x: int): 376 reveal_type(x or None) 377 """ 378 with self.assertRaisesRegex( 379 TypedSyntaxError, 380 r"reveal_type\(x or None\): 'Optional\[int\]'", 381 ): 382 self.compile(codestr) 383 384 def test_reveal_type_local(self) -> None: 385 codestr = """ 386 def f(x: int | None): 387 if x is not None: 388 reveal_type(x) 389 """ 390 with self.assertRaisesRegex( 391 TypedSyntaxError, 392 r"reveal_type\(x\): 'int', 'x' has declared type 'Optional\[int\]' and local type 'int'", 393 ): 394 self.compile(codestr) 395 396 def test_redefine_local_type(self) -> None: 397 codestr = """ 398 class C: pass 399 class D: pass 400 401 def f(): 402 x: C = C() 403 x: D = D() 404 """ 405 with self.assertRaises(TypedSyntaxError): 406 self.compile(codestr, StaticCodeGenerator, modname="foo") 407 408 def test_mixed_chain_assign(self) -> None: 409 codestr = """ 410 class C: pass 411 class D: pass 412 413 def f(): 414 x: C = C() 415 y: D = D() 416 x = y = D() 417 """ 418 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("foo.D", "foo.C")): 419 self.compile(codestr, StaticCodeGenerator, modname="foo") 420 421 def test_bool_cast(self) -> None: 422 codestr = """ 423 from __static__ import cast 424 class D: pass 425 426 def f(x) -> bool: 427 y: bool = 
cast(bool, x) 428 return y 429 """ 430 self.compile(codestr, StaticCodeGenerator, modname="foo") 431 432 def test_typing_overload(self) -> None: 433 """Typing overloads are ignored, don't cause member name conflict.""" 434 codestr = """ 435 from typing import Optional, overload 436 437 class C: 438 @overload 439 def foo(self, x: int) -> int: 440 ... 441 442 def foo(self, x: Optional[int]) -> Optional[int]: 443 return x 444 445 def f(x: int) -> Optional[int]: 446 return C().foo(x) 447 """ 448 self.assertReturns(codestr, "Optional[int]") 449 450 def test_mixed_binop(self): 451 with self.assertRaisesRegex( 452 TypedSyntaxError, "cannot add int64 and Exact\\[int\\]" 453 ): 454 self.bind_module( 455 """ 456 from __static__ import ssize_t 457 458 def f(): 459 x: ssize_t = 1 460 y = 1 461 x + y 462 """ 463 ) 464 465 with self.assertRaisesRegex( 466 TypedSyntaxError, "cannot add Exact\\[int\\] and int64" 467 ): 468 self.bind_module( 469 """ 470 from __static__ import ssize_t 471 472 def f(): 473 x: ssize_t = 1 474 y = 1 475 y + x 476 """ 477 ) 478 479 def test_mixed_binop_okay(self): 480 codestr = """ 481 from __static__ import ssize_t, box 482 483 def f(): 484 x: ssize_t = 1 485 y = x + 1 486 return box(y) 487 """ 488 with self.in_module(codestr) as mod: 489 f = mod["f"] 490 self.assertEqual(f(), 2) 491 492 def test_mixed_binop_okay_1(self): 493 codestr = """ 494 from __static__ import ssize_t, box 495 496 def f(): 497 x: ssize_t = 1 498 y = 1 + x 499 return box(y) 500 """ 501 with self.in_module(codestr) as mod: 502 f = mod["f"] 503 self.assertEqual(f(), 2) 504 505 def test_inferred_primitive_type(self): 506 codestr = """ 507 from __static__ import ssize_t, box 508 509 def f(): 510 x: ssize_t = 1 511 y = x 512 return box(y) 513 """ 514 with self.in_module(codestr) as mod: 515 f = mod["f"] 516 self.assertEqual(f(), 1) 517 518 @skipIf(cinderjit is None, "not jitting") 519 def test_deep_attr_chain(self): 520 """this shouldn't explode exponentially""" 521 codestr = """ 522 
def f(x): 523 return x.x.x.x.x.x.x 524 525 """ 526 527 class C: 528 def __init__(self): 529 self.x = self 530 531 orig_bind_attr = Object.bind_attr 532 call_count = 0 533 534 def bind_attr(*args): 535 nonlocal call_count 536 call_count += 1 537 return orig_bind_attr(*args) 538 539 with patch("compiler.static.Object.bind_attr", bind_attr): 540 with self.in_module(codestr) as mod: 541 f = mod["f"] 542 x = C() 543 self.assertEqual(f(x), x) 544 # Initially this would be 63 when we were double visiting 545 self.assertLess(call_count, 10) 546 547 @skipIf(cinderjit is None, "not jitting") 548 def test_no_frame(self): 549 codestr = """ 550 from __static__.compiler_flags import noframe 551 552 def f(): 553 return 456 554 """ 555 with self.in_module(codestr) as mod: 556 f = mod["f"] 557 self.assertTrue(f.__code__.co_flags & CO_NO_FRAME) 558 self.assertEqual(f(), 456) 559 self.assert_jitted(f) 560 561 @skipIf(cinderjit is None, "not jitting") 562 def test_no_frame_generator(self): 563 codestr = """ 564 from __static__.compiler_flags import noframe 565 566 def g(): 567 for i in range(10): 568 yield i 569 def f(): 570 return list(g()) 571 """ 572 with self.in_module(codestr) as mod: 573 f = mod["f"] 574 self.assertTrue(f.__code__.co_flags & CO_NO_FRAME) 575 self.assertEqual(f(), list(range(10))) 576 self.assert_jitted(f) 577 578 def test_subclass_binop(self): 579 codestr = """ 580 class C: pass 581 class D(C): pass 582 583 def f(x: C, y: D): 584 return x + y 585 """ 586 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 587 f = self.find_code(code, "f") 588 self.assertInBytecode(f, "BINARY_ADD") 589 590 def test_exact_invoke_function(self): 591 codestr = """ 592 def f() -> str: 593 return ", ".join(['1','2','3']) 594 """ 595 f = self.find_code(self.compile(codestr)) 596 with self.in_module(codestr) as mod: 597 f = mod["f"] 598 self.assertInBytecode( 599 f, "INVOKE_FUNCTION", (("builtins", "str", "join"), 2) 600 ) 601 f() 602 603 def 
test_multiply_list_exact_by_int(self): 604 codestr = """ 605 def f() -> int: 606 l = [1, 2, 3] * 2 607 return len(l) 608 """ 609 f = self.find_code(self.compile(codestr)) 610 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST) 611 with self.in_module(codestr) as mod: 612 self.assertEqual(mod["f"](), 6) 613 614 def test_multiply_list_exact_by_int_reverse(self): 615 codestr = """ 616 def f() -> int: 617 l = 2 * [1, 2, 3] 618 return len(l) 619 """ 620 f = self.find_code(self.compile(codestr)) 621 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST) 622 with self.in_module(codestr) as mod: 623 self.assertEqual(mod["f"](), 6) 624 625 def test_int_bad_assign(self): 626 with self.assertRaisesRegex( 627 TypedSyntaxError, "str cannot be used in a context where an int is expected" 628 ): 629 code = self.compile( 630 """ 631 from __static__ import ssize_t 632 def f(): 633 x: ssize_t = 'abc' 634 """, 635 StaticCodeGenerator, 636 ) 637 638 def test_sign_extend(self): 639 codestr = f""" 640 from __static__ import int16, int64, box 641 def testfunc(): 642 x: int16 = -40 643 y: int64 = x 644 return box(y) 645 """ 646 code = self.compile(codestr, StaticCodeGenerator) 647 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 648 self.assertEqual(f(), -40) 649 650 def test_field_size(self): 651 for type in [ 652 "int8", 653 "int16", 654 "int32", 655 "int64", 656 "uint8", 657 "uint16", 658 "uint32", 659 "uint64", 660 ]: 661 codestr = f""" 662 from __static__ import {type}, box 663 class C{type}: 664 def __init__(self): 665 self.a: {type} = 1 666 self.b: {type} = 1 667 668 def testfunc(c: C{type}): 669 c.a = 2 670 c.b = 3 671 return box(c.a + c.b) 672 """ 673 with self.subTest(type=type): 674 with self.in_module(codestr) as mod: 675 C = mod["C" + type] 676 f = mod["testfunc"] 677 self.assertEqual(f(C()), 5) 678 679 def test_field_sign_ext(self): 680 """tests that we do the correct sign extension when loading from a field""" 681 for type, val in [ 682 ("int32", 65537), 683 
("int16", 256), 684 ("int8", 0x7F), 685 ("uint32", 65537), 686 ]: 687 codestr = f""" 688 from __static__ import {type}, box 689 class C{type}: 690 def __init__(self): 691 self.value: {type} = {val} 692 693 def testfunc(c: C{type}): 694 return box(c.value) 695 """ 696 with self.subTest(type=type, val=val): 697 with self.in_module(codestr) as mod: 698 C = mod["C" + type] 699 f = mod["testfunc"] 700 self.assertEqual(f(C()), val) 701 702 def test_field_unsign_ext(self): 703 """tests that we do the correct sign extension when loading from a field""" 704 for type, val, test in [("uint32", 65537, -1)]: 705 codestr = f""" 706 from __static__ import {type}, int64, box 707 class C{type}: 708 def __init__(self): 709 self.value: {type} = {val} 710 711 def testfunc(c: C{type}): 712 z: int64 = {test} 713 if c.value < z: 714 return True 715 return False 716 """ 717 with self.subTest(type=type, val=val, test=test): 718 with self.in_module(codestr) as mod: 719 C = mod["C" + type] 720 f = mod["testfunc"] 721 self.assertEqual(f(C()), False) 722 723 def test_field_sign_compare(self): 724 for type, val, test in [("int32", -1, -1)]: 725 codestr = f""" 726 from __static__ import {type}, box 727 class C{type}: 728 def __init__(self): 729 self.value: {type} = {val} 730 731 def testfunc(c: C{type}): 732 if c.value == {test}: 733 return True 734 return False 735 """ 736 with self.subTest(type=type, val=val, test=test): 737 with self.in_module(codestr) as mod: 738 C = mod["C" + type] 739 f = mod["testfunc"] 740 self.assertTrue(f(C())) 741 742 def test_mixed_binop_sign(self): 743 """mixed signed/unsigned ops should be promoted to signed""" 744 codestr = """ 745 from __static__ import int8, uint8, box 746 def testfunc(): 747 x: uint8 = 42 748 y: int8 = 2 749 return box(x / y) 750 """ 751 code = self.compile(codestr, StaticCodeGenerator) 752 f = self.find_code(code) 753 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_DIV_INT) 754 with self.in_module(codestr) as mod: 755 f = 
mod["testfunc"] 756 self.assertEqual(f(), 21) 757 758 codestr = """ 759 from __static__ import int8, uint8, box 760 def testfunc(): 761 x: int8 = 42 762 y: uint8 = 2 763 return box(x / y) 764 """ 765 code = self.compile(codestr, StaticCodeGenerator) 766 f = self.find_code(code) 767 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_DIV_INT) 768 with self.in_module(codestr) as mod: 769 f = mod["testfunc"] 770 self.assertEqual(f(), 21) 771 772 def test_mixed_cmpop_sign(self): 773 """mixed signed/unsigned ops should be promoted to signed""" 774 codestr = """ 775 from __static__ import int8, uint8, box 776 def testfunc(tst=False): 777 x: uint8 = 42 778 y: int8 = 2 779 if tst: 780 x += 1 781 y += 1 782 783 if x < y: 784 return True 785 return False 786 """ 787 code = self.compile(codestr, StaticCodeGenerator) 788 f = self.find_code(code) 789 self.assertInBytecode(f, "INT_COMPARE_OP", PRIM_OP_LT_INT) 790 with self.in_module(codestr) as mod: 791 f = mod["testfunc"] 792 self.assertEqual(f(), False) 793 794 codestr = """ 795 from __static__ import int8, uint8, box 796 def testfunc(tst=False): 797 x: int8 = 42 798 y: uint8 = 2 799 if tst: 800 x += 1 801 y += 1 802 803 if x < y: 804 return True 805 return False 806 """ 807 code = self.compile(codestr, StaticCodeGenerator) 808 f = self.find_code(code) 809 self.assertInBytecode(f, "INT_COMPARE_OP", PRIM_OP_LT_INT) 810 with self.in_module(codestr) as mod: 811 f = mod["testfunc"] 812 self.assertEqual(f(), False) 813 814 def test_mixed_add_reversed(self): 815 codestr = """ 816 from __static__ import int8, uint8, int64, box, int16 817 def testfunc(tst=False): 818 x: int8 = 42 819 y: int16 = 2 820 if tst: 821 x += 1 822 y += 1 823 824 return box(y + x) 825 """ 826 code = self.compile(codestr, StaticCodeGenerator) 827 f = self.find_code(code) 828 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 829 with self.in_module(codestr) as mod: 830 f = mod["testfunc"] 831 self.assertEqual(f(), 44) 832 833 def 
test_mixed_tri_add(self): 834 codestr = """ 835 from __static__ import int8, uint8, int64, box 836 def testfunc(tst=False): 837 x: uint8 = 42 838 y: int8 = 2 839 z: int64 = 3 840 if tst: 841 x += 1 842 y += 1 843 844 return box(x + y + z) 845 """ 846 code = self.compile(codestr, StaticCodeGenerator) 847 f = self.find_code(code) 848 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 849 with self.in_module(codestr) as mod: 850 f = mod["testfunc"] 851 self.assertEqual(f(), 47) 852 853 def test_mixed_tri_add_unsigned(self): 854 """promote int/uint to int, can't add to uint64""" 855 856 codestr = """ 857 from __static__ import int8, uint8, uint64, box 858 def testfunc(tst=False): 859 x: uint8 = 42 860 y: int8 = 2 861 z: uint64 = 3 862 863 return box(x + y + z) 864 """ 865 866 with self.assertRaisesRegex(TypedSyntaxError, "cannot add int16 and uint64"): 867 self.compile(codestr, StaticCodeGenerator) 868 869 def test_store_signed_to_unsigned(self): 870 871 codestr = """ 872 from __static__ import int8, uint8, uint64, box 873 def testfunc(tst=False): 874 x: uint8 = 42 875 y: int8 = 2 876 x = y 877 """ 878 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("int8", "uint8")): 879 self.compile(codestr, StaticCodeGenerator) 880 881 def test_store_unsigned_to_signed(self): 882 """promote int/uint to int, can't add to uint64""" 883 884 codestr = """ 885 from __static__ import int8, uint8, uint64, box 886 def testfunc(tst=False): 887 x: uint8 = 42 888 y: int8 = 2 889 y = x 890 """ 891 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("uint8", "int8")): 892 self.compile(codestr, StaticCodeGenerator) 893 894 def test_mixed_assign_larger(self): 895 """promote int/uint to int16""" 896 897 codestr = """ 898 from __static__ import int8, uint8, int16, box 899 def testfunc(tst=False): 900 x: uint8 = 42 901 y: int8 = 2 902 z: int16 = x + y 903 904 return box(z) 905 """ 906 code = self.compile(codestr, StaticCodeGenerator) 907 f = self.find_code(code) 908 
self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 909 with self.in_module(codestr) as mod: 910 f = mod["testfunc"] 911 self.assertEqual(f(), 44) 912 913 def test_mixed_assign_larger_2(self): 914 """promote int/uint to int16""" 915 916 codestr = """ 917 from __static__ import int8, uint8, int16, box 918 def testfunc(tst=False): 919 x: uint8 = 42 920 y: int8 = 2 921 z: int16 922 z = x + y 923 924 return box(z) 925 """ 926 code = self.compile(codestr, StaticCodeGenerator) 927 f = self.find_code(code) 928 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 929 with self.in_module(codestr) as mod: 930 f = mod["testfunc"] 931 self.assertEqual(f(), 44) 932 933 @skipIf(True, "this isn't implemented yet") 934 def test_unwind(self): 935 codestr = f""" 936 from __static__ import int32 937 def raises(): 938 raise IndexError() 939 940 def testfunc(): 941 x: int32 = 1 942 raises() 943 print(x) 944 """ 945 946 with self.in_module(codestr) as mod: 947 f = mod["testfunc"] 948 with self.assertRaises(IndexError): 949 f() 950 951 def test_int_constant_range(self): 952 for type, val, low, high in [ 953 ("int8", 128, -128, 127), 954 ("int8", -129, -128, 127), 955 ("int16", 32768, -32768, 32767), 956 ("int16", -32769, -32768, 32767), 957 ("int32", 2147483648, -2147483648, 2147483647), 958 ("int32", -2147483649, -2147483648, 2147483647), 959 ("int64", 9223372036854775808, -9223372036854775808, 9223372036854775807), 960 ("int64", -9223372036854775809, -9223372036854775808, 9223372036854775807), 961 ("uint8", 257, 0, 255), 962 ("uint8", -1, 0, 255), 963 ("uint16", 65537, 0, 65535), 964 ("uint16", -1, 0, 65535), 965 ("uint32", 4294967297, 0, 4294967295), 966 ("uint32", -1, 0, 4294967295), 967 ("uint64", 18446744073709551617, 0, 18446744073709551615), 968 ("uint64", -1, 0, 18446744073709551615), 969 ]: 970 codestr = f""" 971 from __static__ import {type} 972 def testfunc(tst): 973 x: {type} = {val} 974 """ 975 with self.subTest(type=type, val=val, low=low, 
high=high): 976 with self.assertRaisesRegex( 977 TypedSyntaxError, 978 f"constant {val} is outside of the range {low} to {high} for {type}", 979 ): 980 self.compile(codestr, StaticCodeGenerator) 981 982 def test_int_assign_float(self): 983 codestr = """ 984 from __static__ import int8 985 def testfunc(tst): 986 x: int8 = 1.0 987 """ 988 with self.assertRaisesRegex( 989 TypedSyntaxError, 990 f"float cannot be used in a context where an int is expected", 991 ): 992 self.compile(codestr, StaticCodeGenerator) 993 994 def test_int_assign_str_constant(self): 995 codestr = """ 996 from __static__ import int8 997 def testfunc(tst): 998 x: int8 = 'abc' + 'def' 999 """ 1000 with self.assertRaisesRegex( 1001 TypedSyntaxError, 1002 f"str cannot be used in a context where an int is expected", 1003 ): 1004 self.compile(codestr, StaticCodeGenerator) 1005 1006 def test_int_large_int_constant(self): 1007 codestr = """ 1008 from __static__ import int64 1009 def testfunc(tst): 1010 x: int64 = 0x7FFFFFFF + 1 1011 """ 1012 code = self.compile(codestr, StaticCodeGenerator) 1013 f = self.find_code(code) 1014 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (0x80000000, TYPED_INT64)) 1015 1016 def test_int_int_constant(self): 1017 codestr = """ 1018 from __static__ import int64 1019 def testfunc(tst): 1020 x: int64 = 0x7FFFFFFE + 1 1021 """ 1022 code = self.compile(codestr, StaticCodeGenerator) 1023 f = self.find_code(code) 1024 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (0x7FFFFFFF, TYPED_INT64)) 1025 1026 def test_int_add_mixed_64(self): 1027 codestr = """ 1028 from __static__ import uint64, int64, box 1029 def testfunc(tst): 1030 x: uint64 = 0 1031 y: int64 = 1 1032 if tst: 1033 x = x + 1 1034 y = y + 2 1035 1036 return box(x + y) 1037 """ 1038 with self.assertRaisesRegex(TypedSyntaxError, "cannot add uint64 and int64"): 1039 self.compile(codestr, StaticCodeGenerator) 1040 1041 def test_int_overflow_add(self): 1042 tests = [ 1043 ("int8", 100, 100, -56), 1044 ("int16", 200, 200, 
400), 1045 ("int32", 200, 200, 400), 1046 ("int64", 200, 200, 400), 1047 ("int16", 20000, 20000, -25536), 1048 ("int32", 40000, 40000, 80000), 1049 ("int64", 40000, 40000, 80000), 1050 ("int32", 2000000000, 2000000000, -294967296), 1051 ("int64", 2000000000, 2000000000, 4000000000), 1052 ("int8", 127, 127, -2), 1053 ("int16", 32767, 32767, -2), 1054 ("int32", 2147483647, 2147483647, -2), 1055 ("int64", 9223372036854775807, 9223372036854775807, -2), 1056 ("uint8", 200, 200, 144), 1057 ("uint16", 200, 200, 400), 1058 ("uint32", 200, 200, 400), 1059 ("uint64", 200, 200, 400), 1060 ("uint16", 40000, 40000, 14464), 1061 ("uint32", 40000, 40000, 80000), 1062 ("uint64", 40000, 40000, 80000), 1063 ("uint32", 2000000000, 2000000000, 4000000000), 1064 ("uint64", 2000000000, 2000000000, 4000000000), 1065 ("uint8", 1 << 7, 1 << 7, 0), 1066 ("uint16", 1 << 15, 1 << 15, 0), 1067 ("uint32", 1 << 31, 1 << 31, 0), 1068 ("uint64", 1 << 63, 1 << 63, 0), 1069 ("uint8", 1 << 6, 1 << 6, 128), 1070 ("uint16", 1 << 14, 1 << 14, 32768), 1071 ("uint32", 1 << 30, 1 << 30, 2147483648), 1072 ("uint64", 1 << 62, 1 << 62, 9223372036854775808), 1073 ] 1074 1075 for type, x, y, res in tests: 1076 codestr = f""" 1077 from __static__ import {type}, box 1078 def f(): 1079 x: {type} = {x} 1080 y: {type} = {y} 1081 z: {type} = x + y 1082 return box(z) 1083 """ 1084 with self.subTest(type=type, x=x, y=y, res=res): 1085 code = self.compile(codestr, StaticCodeGenerator) 1086 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1087 self.assertEqual(f(), res, f"{type} {x} {y} {res}") 1088 1089 def test_int_unary(self): 1090 tests = [ 1091 ("int8", "-", 1, -1), 1092 ("uint8", "-", 1, (1 << 8) - 1), 1093 ("int16", "-", 1, -1), 1094 ("int16", "-", 256, -256), 1095 ("uint16", "-", 1, (1 << 16) - 1), 1096 ("int32", "-", 1, -1), 1097 ("int32", "-", 65536, -65536), 1098 ("uint32", "-", 1, (1 << 32) - 1), 1099 ("int64", "-", 1, -1), 1100 ("int64", "-", 1 << 32, -(1 << 32)), 1101 ("uint64", "-", 1, (1 << 64) - 1), 
1102 ("int8", "~", 1, -2), 1103 ("uint8", "~", 1, (1 << 8) - 2), 1104 ("int16", "~", 1, -2), 1105 ("uint16", "~", 1, (1 << 16) - 2), 1106 ("int32", "~", 1, -2), 1107 ("uint32", "~", 1, (1 << 32) - 2), 1108 ("int64", "~", 1, -2), 1109 ("uint64", "~", 1, (1 << 64) - 2), 1110 ] 1111 for type, op, x, res in tests: 1112 codestr = f""" 1113 from __static__ import {type}, box 1114 def testfunc(tst): 1115 x: {type} = {x} 1116 if tst: 1117 x = x + 1 1118 x = {op}x 1119 return box(x) 1120 """ 1121 with self.subTest(type=type, op=op, x=x, res=res): 1122 code = self.compile(codestr, StaticCodeGenerator) 1123 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1124 self.assertEqual(f(False), res, f"{type} {op} {x} {res}") 1125 1126 def test_int_compare(self): 1127 tests = [ 1128 ("int8", 1, 2, "==", False), 1129 ("int8", 1, 2, "!=", True), 1130 ("int8", 1, 2, "<", True), 1131 ("int8", 1, 2, "<=", True), 1132 ("int8", 2, 1, "<", False), 1133 ("int8", 2, 1, "<=", False), 1134 ("int8", -1, 2, "==", False), 1135 ("int8", -1, 2, "!=", True), 1136 ("int8", -1, 2, "<", True), 1137 ("int8", -1, 2, "<=", True), 1138 ("int8", 2, -1, "<", False), 1139 ("int8", 2, -1, "<=", False), 1140 ("uint8", 1, 2, "==", False), 1141 ("uint8", 1, 2, "!=", True), 1142 ("uint8", 1, 2, "<", True), 1143 ("uint8", 1, 2, "<=", True), 1144 ("uint8", 2, 1, "<", False), 1145 ("uint8", 2, 1, "<=", False), 1146 ("uint8", 255, 2, "==", False), 1147 ("uint8", 255, 2, "!=", True), 1148 ("uint8", 255, 2, "<", False), 1149 ("uint8", 255, 2, "<=", False), 1150 ("uint8", 2, 255, "<", True), 1151 ("uint8", 2, 255, "<=", True), 1152 ("int16", 1, 2, "==", False), 1153 ("int16", 1, 2, "!=", True), 1154 ("int16", 1, 2, "<", True), 1155 ("int16", 1, 2, "<=", True), 1156 ("int16", 2, 1, "<", False), 1157 ("int16", 2, 1, "<=", False), 1158 ("int16", -1, 2, "==", False), 1159 ("int16", -1, 2, "!=", True), 1160 ("int16", -1, 2, "<", True), 1161 ("int16", -1, 2, "<=", True), 1162 ("int16", 2, -1, "<", False), 1163 
("int16", 2, -1, "<=", False), 1164 ("uint16", 1, 2, "==", False), 1165 ("uint16", 1, 2, "!=", True), 1166 ("uint16", 1, 2, "<", True), 1167 ("uint16", 1, 2, "<=", True), 1168 ("uint16", 2, 1, "<", False), 1169 ("uint16", 2, 1, "<=", False), 1170 ("uint16", 65535, 2, "==", False), 1171 ("uint16", 65535, 2, "!=", True), 1172 ("uint16", 65535, 2, "<", False), 1173 ("uint16", 65535, 2, "<=", False), 1174 ("uint16", 2, 65535, "<", True), 1175 ("uint16", 2, 65535, "<=", True), 1176 ("int32", 1, 2, "==", False), 1177 ("int32", 1, 2, "!=", True), 1178 ("int32", 1, 2, "<", True), 1179 ("int32", 1, 2, "<=", True), 1180 ("int32", 2, 1, "<", False), 1181 ("int32", 2, 1, "<=", False), 1182 ("int32", -1, 2, "==", False), 1183 ("int32", -1, 2, "!=", True), 1184 ("int32", -1, 2, "<", True), 1185 ("int32", -1, 2, "<=", True), 1186 ("int32", 2, -1, "<", False), 1187 ("int32", 2, -1, "<=", False), 1188 ("uint32", 1, 2, "==", False), 1189 ("uint32", 1, 2, "!=", True), 1190 ("uint32", 1, 2, "<", True), 1191 ("uint32", 1, 2, "<=", True), 1192 ("uint32", 2, 1, "<", False), 1193 ("uint32", 2, 1, "<=", False), 1194 ("uint32", 4294967295, 2, "!=", True), 1195 ("uint32", 4294967295, 2, "<", False), 1196 ("uint32", 4294967295, 2, "<=", False), 1197 ("uint32", 2, 4294967295, "<", True), 1198 ("uint32", 2, 4294967295, "<=", True), 1199 ("int64", 1, 2, "==", False), 1200 ("int64", 1, 2, "!=", True), 1201 ("int64", 1, 2, "<", True), 1202 ("int64", 1, 2, "<=", True), 1203 ("int64", 2, 1, "<", False), 1204 ("int64", 2, 1, "<=", False), 1205 ("int64", -1, 2, "==", False), 1206 ("int64", -1, 2, "!=", True), 1207 ("int64", -1, 2, "<", True), 1208 ("int64", -1, 2, "<=", True), 1209 ("int64", 2, -1, "<", False), 1210 ("int64", 2, -1, "<=", False), 1211 ("uint64", 1, 2, "==", False), 1212 ("uint64", 1, 2, "!=", True), 1213 ("uint64", 1, 2, "<", True), 1214 ("uint64", 1, 2, "<=", True), 1215 ("uint64", 2, 1, "<", False), 1216 ("uint64", 2, 1, "<=", False), 1217 ("int64", 2, -1, ">", True), 1218 
("uint64", 2, 18446744073709551615, ">", False), 1219 ("int64", 2, -1, "<", False), 1220 ("uint64", 2, 18446744073709551615, "<", True), 1221 ("int64", 2, -1, ">=", True), 1222 ("uint64", 2, 18446744073709551615, ">=", False), 1223 ("int64", 2, -1, "<=", False), 1224 ("uint64", 2, 18446744073709551615, "<=", True), 1225 ] 1226 for type, x, y, op, res in tests: 1227 codestr = f""" 1228 from __static__ import {type}, box 1229 def testfunc(tst): 1230 x: {type} = {x} 1231 y: {type} = {y} 1232 if tst: 1233 x = x + 1 1234 y = y + 2 1235 1236 if x {op} y: 1237 return True 1238 return False 1239 """ 1240 with self.subTest(type=type, x=x, y=y, op=op, res=res): 1241 code = self.compile(codestr, StaticCodeGenerator) 1242 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1243 self.assertEqual(f(False), res, f"{type} {x} {op} {y} {res}") 1244 1245 def test_int_compare_unboxed(self): 1246 codestr = f""" 1247 from __static__ import ssize_t, unbox 1248 def testfunc(x, y): 1249 x1: ssize_t = unbox(x) 1250 y1: ssize_t = unbox(y) 1251 1252 if x1 > y1: 1253 return True 1254 return False 1255 """ 1256 code = self.compile(codestr, StaticCodeGenerator) 1257 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1258 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 1259 self.assertEqual(f(1, 2), False) 1260 1261 def test_int_compare_mixed(self): 1262 codestr = """ 1263 from __static__ import box, ssize_t 1264 x = 1 1265 1266 def testfunc(): 1267 i: ssize_t = 0 1268 j = 0 1269 while box(i < 100) and x: 1270 i = i + 1 1271 j = j + 1 1272 return j 1273 """ 1274 1275 code = self.compile(codestr, StaticCodeGenerator) 1276 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1277 self.assertEqual(f(), 100) 1278 self.assert_jitted(f) 1279 1280 def test_int_compare_or(self): 1281 codestr = """ 1282 from __static__ import box, ssize_t 1283 1284 def testfunc(): 1285 i: ssize_t = 0 1286 j = i > 2 or i < -2 1287 return box(j) 1288 """ 1289 1290 with self.in_module(codestr, 
code_gen=StaticCodeGenerator) as mod: 1291 f = mod["testfunc"] 1292 self.assertInBytecode(f, "JUMP_IF_NONZERO_OR_POP") 1293 self.assertIs(f(), False) 1294 1295 def test_int_compare_and(self): 1296 codestr = """ 1297 from __static__ import box, ssize_t 1298 1299 def testfunc(): 1300 i: ssize_t = 0 1301 j = i > 2 and i > 3 1302 return box(j) 1303 """ 1304 1305 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 1306 f = mod["testfunc"] 1307 self.assertInBytecode(f, "JUMP_IF_ZERO_OR_POP") 1308 self.assertIs(f(), False) 1309 1310 def test_disallow_prim_nonprim_union(self): 1311 codestr = """ 1312 from __static__ import int32 1313 1314 def f(y: int): 1315 x: int32 = 2 1316 z = x or y 1317 return z 1318 """ 1319 with self.assertRaisesRegex( 1320 TypedSyntaxError, 1321 r"invalid union type Union\[int32, int\]; unions cannot include primitive types", 1322 ): 1323 self.compile(codestr) 1324 1325 def test_int_binop(self): 1326 tests = [ 1327 ("int8", 1, 2, "/", 0), 1328 ("int8", 4, 2, "/", 2), 1329 ("int8", 4, -2, "/", -2), 1330 ("uint8", 0xFF, 0x7F, "/", 2), 1331 ("int16", 4, -2, "/", -2), 1332 ("uint16", 0xFF, 0x7F, "/", 2), 1333 ("uint32", 0xFFFF, 0x7FFF, "/", 2), 1334 ("int32", 4, -2, "/", -2), 1335 ("uint32", 0xFF, 0x7F, "/", 2), 1336 ("uint32", 0xFFFFFFFF, 0x7FFFFFFF, "/", 2), 1337 ("int64", 4, -2, "/", -2), 1338 ("uint64", 0xFF, 0x7F, "/", 2), 1339 ("uint64", 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, "/", 2), 1340 ("int8", 1, -2, "-", 3), 1341 ("int8", 1, 2, "-", -1), 1342 ("int16", 1, -2, "-", 3), 1343 ("int16", 1, 2, "-", -1), 1344 ("int32", 1, -2, "-", 3), 1345 ("int32", 1, 2, "-", -1), 1346 ("int64", 1, -2, "-", 3), 1347 ("int64", 1, 2, "-", -1), 1348 ("int8", 1, -2, "*", -2), 1349 ("int8", 1, 2, "*", 2), 1350 ("int16", 1, -2, "*", -2), 1351 ("int16", 1, 2, "*", 2), 1352 ("int32", 1, -2, "*", -2), 1353 ("int32", 1, 2, "*", 2), 1354 ("int64", 1, -2, "*", -2), 1355 ("int64", 1, 2, "*", 2), 1356 ("int8", 1, -2, "&", 0), 1357 ("int8", 1, 3, "&", 1), 1358 
("int16", 1, 3, "&", 1), 1359 ("int16", 1, 3, "&", 1), 1360 ("int32", 1, 3, "&", 1), 1361 ("int32", 1, 3, "&", 1), 1362 ("int64", 1, 3, "&", 1), 1363 ("int64", 1, 3, "&", 1), 1364 ("int8", 1, 2, "|", 3), 1365 ("uint8", 1, 2, "|", 3), 1366 ("int16", 1, 2, "|", 3), 1367 ("uint16", 1, 2, "|", 3), 1368 ("int32", 1, 2, "|", 3), 1369 ("uint32", 1, 2, "|", 3), 1370 ("int64", 1, 2, "|", 3), 1371 ("uint64", 1, 2, "|", 3), 1372 ("int8", 1, 3, "^", 2), 1373 ("uint8", 1, 3, "^", 2), 1374 ("int16", 1, 3, "^", 2), 1375 ("uint16", 1, 3, "^", 2), 1376 ("int32", 1, 3, "^", 2), 1377 ("uint32", 1, 3, "^", 2), 1378 ("int64", 1, 3, "^", 2), 1379 ("uint64", 1, 3, "^", 2), 1380 ("int8", 1, 3, "%", 1), 1381 ("uint8", 1, 3, "%", 1), 1382 ("int16", 1, 3, "%", 1), 1383 ("uint16", 1, 3, "%", 1), 1384 ("int32", 1, 3, "%", 1), 1385 ("uint32", 1, 3, "%", 1), 1386 ("int64", 1, 3, "%", 1), 1387 ("uint64", 1, 3, "%", 1), 1388 ("int8", 1, -3, "%", 1), 1389 ("uint8", 1, 0xFF, "%", 1), 1390 ("int16", 1, -3, "%", 1), 1391 ("uint16", 1, 0xFFFF, "%", 1), 1392 ("int32", 1, -3, "%", 1), 1393 ("uint32", 1, 0xFFFFFFFF, "%", 1), 1394 ("int64", 1, -3, "%", 1), 1395 ("uint64", 1, 0xFFFFFFFFFFFFFFFF, "%", 1), 1396 ("int8", 1, 2, "<<", 4), 1397 ("uint8", 1, 2, "<<", 4), 1398 ("int16", 1, 2, "<<", 4), 1399 ("uint16", 1, 2, "<<", 4), 1400 ("int32", 1, 2, "<<", 4), 1401 ("uint32", 1, 2, "<<", 4), 1402 ("int64", 1, 2, "<<", 4), 1403 ("uint64", 1, 2, "<<", 4), 1404 ("int8", 4, 1, ">>", 2), 1405 ("int8", -1, 1, ">>", -1), 1406 ("uint8", 0xFF, 1, ">>", 127), 1407 ("int16", 4, 1, ">>", 2), 1408 ("int16", -1, 1, ">>", -1), 1409 ("uint16", 0xFFFF, 1, ">>", 32767), 1410 ("int32", 4, 1, ">>", 2), 1411 ("int32", -1, 1, ">>", -1), 1412 ("uint32", 0xFFFFFFFF, 1, ">>", 2147483647), 1413 ("int64", 4, 1, ">>", 2), 1414 ("int64", -1, 1, ">>", -1), 1415 ("uint64", 0xFFFFFFFFFFFFFFFF, 1, ">>", 9223372036854775807), 1416 ] 1417 for type, x, y, op, res in tests: 1418 codestr = f""" 1419 from __static__ import {type}, box 1420 def 
testfunc(tst): 1421 x: {type} = {x} 1422 y: {type} = {y} 1423 if tst: 1424 x = x + 1 1425 y = y + 2 1426 1427 z: {type} = x {op} y 1428 return box(z), box(x {op} y) 1429 """ 1430 with self.subTest(type=type, x=x, y=y, op=op, res=res): 1431 with self.in_module(codestr) as mod: 1432 f = mod["testfunc"] 1433 self.assertEqual(f(False), (res, res), f"{type} {x} {op} {y} {res}") 1434 1435 def test_primitive_arithmetic(self): 1436 cases = [ 1437 ("int8", 127, "*", 1, 127), 1438 ("int8", -64, "*", 2, -128), 1439 ("int8", 0, "*", 4, 0), 1440 ("uint8", 51, "*", 5, 255), 1441 ("uint8", 5, "*", 0, 0), 1442 ("int16", 3123, "*", -10, -31230), 1443 ("int16", -32767, "*", -1, 32767), 1444 ("int16", -32768, "*", 1, -32768), 1445 ("int16", 3, "*", 0, 0), 1446 ("uint16", 65535, "*", 1, 65535), 1447 ("uint16", 0, "*", 4, 0), 1448 ("int32", (1 << 31) - 1, "*", 1, (1 << 31) - 1), 1449 ("int32", -(1 << 30), "*", 2, -(1 << 31)), 1450 ("int32", 0, "*", 1, 0), 1451 ("uint32", (1 << 32) - 1, "*", 1, (1 << 32) - 1), 1452 ("uint32", 0, "*", 4, 0), 1453 ("int64", (1 << 63) - 1, "*", 1, (1 << 63) - 1), 1454 ("int64", -(1 << 62), "*", 2, -(1 << 63)), 1455 ("int64", 0, "*", 1, 0), 1456 ("uint64", (1 << 64) - 1, "*", 1, (1 << 64) - 1), 1457 ("uint64", 0, "*", 4, 0), 1458 ("int8", 127, "//", 4, 31), 1459 ("int8", -128, "//", 4, -32), 1460 ("int8", 0, "//", 4, 0), 1461 ("uint8", 255, "//", 5, 51), 1462 ("uint8", 0, "//", 5, 0), 1463 ("int16", 32767, "//", -1000, -32), 1464 ("int16", -32768, "//", -1000, 32), 1465 ("int16", 0, "//", 4, 0), 1466 ("uint16", 65535, "//", 5, 13107), 1467 ("uint16", 0, "//", 4, 0), 1468 ("int32", (1 << 31) - 1, "//", (1 << 31) - 1, 1), 1469 ("int32", -(1 << 31), "//", 1, -(1 << 31)), 1470 ("int32", 0, "//", 1, 0), 1471 ("uint32", (1 << 32) - 1, "//", 500, 8589934), 1472 ("uint32", 0, "//", 4, 0), 1473 ("int64", (1 << 63) - 1, "//", 2, (1 << 62) - 1), 1474 ("int64", -(1 << 63), "//", 2, -(1 << 62)), 1475 ("int64", 0, "//", 1, 0), 1476 ("uint64", (1 << 64) - 1, "//", (1 << 
64) - 1, 1), 1477 ("uint64", 0, "//", 4, 0), 1478 ("int8", 127, "%", 4, 3), 1479 ("int8", -128, "%", 4, 0), 1480 ("int8", 0, "%", 4, 0), 1481 ("uint8", 255, "%", 6, 3), 1482 ("uint8", 0, "%", 5, 0), 1483 ("int16", 32767, "%", -1000, 767), 1484 ("int16", -32768, "%", -1000, -768), 1485 ("int16", 0, "%", 4, 0), 1486 ("uint16", 65535, "%", 7, 1), 1487 ("uint16", 0, "%", 4, 0), 1488 ("int32", (1 << 31) - 1, "%", (1 << 31) - 1, 0), 1489 ("int32", -(1 << 31), "%", 1, 0), 1490 ("int32", 0, "%", 1, 0), 1491 ("uint32", (1 << 32) - 1, "%", 500, 295), 1492 ("uint32", 0, "%", 4, 0), 1493 ("int64", (1 << 63) - 1, "%", 2, 1), 1494 ("int64", -(1 << 63), "%", 2, 0), 1495 ("int64", 0, "%", 1, 0), 1496 ("uint64", (1 << 64) - 1, "%", (1 << 64) - 1, 0), 1497 ("uint64", 0, "%", 4, 0), 1498 ] 1499 for typ, a, op, b, res in cases: 1500 for const in ["noconst", "constfirst", "constsecond"]: 1501 if const == "noconst": 1502 codestr = f""" 1503 from __static__ import {typ} 1504 1505 def f(a: {typ}, b: {typ}) -> {typ}: 1506 return a {op} b 1507 """ 1508 elif const == "constfirst": 1509 codestr = f""" 1510 from __static__ import {typ} 1511 1512 def f(b: {typ}) -> {typ}: 1513 return {a} {op} b 1514 """ 1515 elif const == "constsecond": 1516 codestr = f""" 1517 from __static__ import {typ} 1518 1519 def f(a: {typ}) -> {typ}: 1520 return a {op} {b} 1521 """ 1522 1523 with self.subTest(typ=typ, a=a, op=op, b=b, res=res, const=const): 1524 with self.in_module(codestr) as mod: 1525 f = mod["f"] 1526 act = None 1527 if const == "noconst": 1528 act = f(a, b) 1529 elif const == "constfirst": 1530 act = f(b) 1531 elif const == "constsecond": 1532 act = f(a) 1533 self.assertEqual(act, res) 1534 1535 def test_int_binop_type_context(self): 1536 codestr = f""" 1537 from __static__ import box, int8, int16 1538 1539 def f(x: int8, y: int8) -> int: 1540 z: int16 = x * y 1541 return box(z) 1542 """ 1543 with self.in_module(codestr) as mod: 1544 f = mod["f"] 1545 self.assertInBytecode( 1546 f, 
"CONVERT_PRIMITIVE", TYPED_INT8 | (TYPED_INT16 << 4) 1547 ) 1548 self.assertEqual(f(120, 120), 14400) 1549 1550 def test_int_compare_mixed_sign(self): 1551 tests = [ 1552 ("uint16", 10000, "int16", -1, "<", False), 1553 ("uint16", 10000, "int16", -1, "<=", False), 1554 ("int16", -1, "uint16", 10000, ">", False), 1555 ("int16", -1, "uint16", 10000, ">=", False), 1556 ("uint32", 10000, "int16", -1, "<", False), 1557 ] 1558 for type1, x, type2, y, op, res in tests: 1559 codestr = f""" 1560 from __static__ import {type1}, {type2}, box 1561 def testfunc(tst): 1562 x: {type1} = {x} 1563 y: {type2} = {y} 1564 if tst: 1565 x = x + 1 1566 y = y + 2 1567 1568 if x {op} y: 1569 return True 1570 return False 1571 """ 1572 with self.subTest(type1=type1, x=x, type2=type2, y=y, op=op, res=res): 1573 code = self.compile(codestr, StaticCodeGenerator) 1574 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1575 self.assertEqual(f(False), res, f"{type} {x} {op} {y} {res}") 1576 1577 def test_int_compare64_mixed_sign(self): 1578 codestr = """ 1579 from __static__ import uint64, int64 1580 def testfunc(tst): 1581 x: uint64 = 0 1582 y: int64 = 1 1583 if tst: 1584 x = x + 1 1585 y = y + 2 1586 1587 if x < y: 1588 return True 1589 return False 1590 """ 1591 with self.assertRaises(TypedSyntaxError): 1592 self.compile(codestr, StaticCodeGenerator) 1593 1594 def test_compile_method(self): 1595 code = self.compile( 1596 """ 1597 from __static__ import ssize_t 1598 def f(): 1599 x: ssize_t = 42 1600 """, 1601 StaticCodeGenerator, 1602 ) 1603 1604 f = self.find_code(code) 1605 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (42, TYPED_INT64)) 1606 1607 def test_mixed_compare(self): 1608 codestr = """ 1609 from __static__ import ssize_t, box, unbox 1610 def f(a): 1611 x: ssize_t = 0 1612 while x != a: 1613 pass 1614 """ 1615 with self.assertRaisesRegex(TypedSyntaxError, "can't compare int64 to dynamic"): 1616 self.compile(codestr, StaticCodeGenerator) 1617 1618 def test_unbox(self): 
1619 for size, val in [ 1620 ("int8", 126), 1621 ("int8", -128), 1622 ("int16", 32766), 1623 ("int16", -32768), 1624 ("int32", 2147483646), 1625 ("int32", -2147483648), 1626 ("int64", 9223372036854775806), 1627 ("int64", -9223372036854775808), 1628 ("uint8", 254), 1629 ("uint16", 65534), 1630 ("uint32", 4294967294), 1631 ("uint64", 18446744073709551614), 1632 ]: 1633 codestr = f""" 1634 from __static__ import {size}, box, unbox 1635 def f(x): 1636 y: {size} = unbox(x) 1637 y = y + 1 1638 return box(y) 1639 """ 1640 1641 code = self.compile(codestr, StaticCodeGenerator) 1642 f = self.find_code(code) 1643 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1644 self.assertEqual(f(val), val + 1) 1645 1646 def test_int_loop_inplace(self): 1647 codestr = """ 1648 from __static__ import ssize_t, box 1649 def f(): 1650 i: ssize_t = 0 1651 while i < 100: 1652 i += 1 1653 return box(i) 1654 """ 1655 1656 code = self.compile(codestr, StaticCodeGenerator) 1657 f = self.find_code(code) 1658 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1659 self.assertEqual(f(), 100) 1660 1661 def test_int_loop(self): 1662 codestr = """ 1663 from __static__ import ssize_t, box 1664 def testfunc(): 1665 i: ssize_t = 0 1666 while i < 100: 1667 i = i + 1 1668 return box(i) 1669 """ 1670 1671 code = self.compile(codestr, StaticCodeGenerator) 1672 f = self.find_code(code) 1673 1674 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1675 self.assertEqual(f(), 100) 1676 1677 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)) 1678 self.assertInBytecode(f, "LOAD_LOCAL", (0, ("__static__", "int64"))) 1679 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 1680 self.assertInBytecode(f, "INT_COMPARE_OP", PRIM_OP_LT_INT) 1681 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 1682 1683 def test_int_assert(self): 1684 codestr = """ 1685 from __static__ import ssize_t, box 1686 def testfunc(): 1687 i: ssize_t = 0 1688 assert i == 0, "hello there" 1689 """ 1690 1691 
code = self.compile(codestr, StaticCodeGenerator) 1692 f = self.find_code(code) 1693 self.assertInBytecode(f, "POP_JUMP_IF_NONZERO") 1694 1695 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1696 self.assertEqual(f(), None) 1697 1698 def test_int_assert_raises(self): 1699 codestr = """ 1700 from __static__ import ssize_t, box 1701 def testfunc(): 1702 i: ssize_t = 0 1703 assert i != 0, "hello there" 1704 """ 1705 1706 code = self.compile(codestr, StaticCodeGenerator) 1707 f = self.find_code(code) 1708 self.assertInBytecode(f, "POP_JUMP_IF_NONZERO") 1709 1710 with self.assertRaises(AssertionError): 1711 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1712 self.assertEqual(f(), None) 1713 1714 def test_int_loop_reversed(self): 1715 codestr = """ 1716 from __static__ import ssize_t, box 1717 def testfunc(): 1718 i: ssize_t = 0 1719 while 100 > i: 1720 i = i + 1 1721 return box(i) 1722 """ 1723 1724 code = self.compile(codestr, StaticCodeGenerator) 1725 f = self.find_code(code) 1726 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1727 self.assertEqual(f(), 100) 1728 1729 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)) 1730 self.assertInBytecode(f, "LOAD_LOCAL", (0, ("__static__", "int64"))) 1731 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 1732 self.assertInBytecode(f, "INT_COMPARE_OP", PRIM_OP_GT_INT) 1733 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 1734 1735 def test_int_loop_chained(self): 1736 codestr = """ 1737 from __static__ import ssize_t, box 1738 def testfunc(): 1739 i: ssize_t = 0 1740 while -1 < i < 100: 1741 i = i + 1 1742 return box(i) 1743 """ 1744 1745 code = self.compile(codestr, StaticCodeGenerator) 1746 f = self.find_code(code) 1747 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1748 self.assertEqual(f(), 100) 1749 1750 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (0, TYPED_INT64)) 1751 self.assertInBytecode(f, "LOAD_LOCAL", (0, ("__static__", "int64"))) 1752 
self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 1753 self.assertInBytecode(f, "INT_COMPARE_OP", PRIM_OP_LT_INT) 1754 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 1755 1756 def test_compare_subclass(self): 1757 codestr = """ 1758 class C: pass 1759 class D(C): pass 1760 1761 x = C() > D() 1762 """ 1763 code = self.compile(codestr, StaticCodeGenerator) 1764 self.assertInBytecode(code, "COMPARE_OP") 1765 1766 def test_compat_int_math(self): 1767 codestr = """ 1768 from __static__ import ssize_t, box 1769 def f(): 1770 x: ssize_t = 42 1771 z: ssize_t = 1 + x 1772 return box(z) 1773 """ 1774 1775 code = self.compile(codestr, StaticCodeGenerator) 1776 f = self.find_code(code) 1777 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (42, TYPED_INT64)) 1778 self.assertInBytecode(f, "LOAD_LOCAL", (0, ("__static__", "int64"))) 1779 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 1780 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1781 self.assertEqual(f(), 43) 1782 1783 def test_unbox_long(self): 1784 codestr = """ 1785 from __static__ import unbox, int64 1786 def f(): 1787 x:int64 = unbox(1) 1788 """ 1789 1790 self.compile(codestr, StaticCodeGenerator) 1791 1792 def test_unbox_str(self): 1793 codestr = """ 1794 from __static__ import unbox, int64 1795 def f(): 1796 x:int64 = unbox('abc') 1797 """ 1798 1799 with self.in_module(codestr) as mod: 1800 f = mod["f"] 1801 with self.assertRaisesRegex( 1802 TypeError, "(expected int, got str)|(an integer is required)" 1803 ): 1804 f() 1805 1806 def test_unbox_typed(self): 1807 codestr = """ 1808 from __static__ import int64, box 1809 def f(i: object): 1810 x = int64(i) 1811 return box(x) 1812 """ 1813 1814 with self.in_module(codestr) as mod: 1815 f = mod["f"] 1816 self.assertEqual(f(42), 42) 1817 self.assertInBytecode(f, "PRIMITIVE_UNBOX") 1818 with self.assertRaisesRegex( 1819 TypeError, "(expected int, got str)|(an integer is required)" 1820 ): 1821 self.assertEqual(f("abc"), 42) 1822 1823 def 
test_unbox_typed_bool(self): 1824 codestr = """ 1825 from __static__ import int64, box 1826 def f(i: object): 1827 x = int64(i) 1828 return box(x) 1829 """ 1830 1831 with self.in_module(codestr) as mod: 1832 f = mod["f"] 1833 self.assertEqual(f(42), 42) 1834 self.assertInBytecode(f, "PRIMITIVE_UNBOX") 1835 self.assertEqual(f(True), 1) 1836 self.assertEqual(f(False), 0) 1837 1838 def test_unbox_incompat_type(self): 1839 codestr = """ 1840 from __static__ import int64, box 1841 def f(i: str): 1842 x:int64 = int64(i) 1843 return box(x) 1844 """ 1845 1846 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("str", "int64")): 1847 self.compile(codestr) 1848 1849 def test_uninit_value(self): 1850 codestr = """ 1851 from __static__ import box, int64 1852 def f(): 1853 x:int64 1854 return box(x) 1855 x = 0 1856 """ 1857 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1858 self.assertEqual(f(), 0) 1859 1860 def test_uninit_value_2(self): 1861 codestr = """ 1862 from __static__ import box, int64 1863 def testfunc(x): 1864 if x: 1865 y:int64 = 42 1866 return box(y) 1867 """ 1868 f = self.run_code(codestr, StaticCodeGenerator)["testfunc"] 1869 self.assertEqual(f(False), 0) 1870 1871 def test_bad_box(self): 1872 codestr = """ 1873 from __static__ import box 1874 box('abc') 1875 """ 1876 1877 with self.assertRaisesRegex( 1878 TypedSyntaxError, "can't box non-primitive: Exact\\[str\\]" 1879 ): 1880 self.compile(codestr, StaticCodeGenerator) 1881 1882 def test_bad_unbox(self): 1883 codestr = """ 1884 from __static__ import unbox, int64 1885 def f(): 1886 x:int64 = 42 1887 unbox(x) 1888 """ 1889 1890 with self.assertRaisesRegex( 1891 TypedSyntaxError, type_mismatch("int64", "dynamic") 1892 ): 1893 self.compile(codestr, StaticCodeGenerator) 1894 1895 def test_bad_box_2(self): 1896 codestr = """ 1897 from __static__ import box 1898 box('abc', 'foo') 1899 """ 1900 1901 with self.assertRaisesRegex( 1902 TypedSyntaxError, "box only accepts a single argument" 1903 ): 1904 
self.compile(codestr, StaticCodeGenerator) 1905 1906 def test_bad_unbox_2(self): 1907 codestr = """ 1908 from __static__ import unbox, int64 1909 def f(): 1910 x:int64 = 42 1911 unbox(x, y) 1912 """ 1913 1914 with self.assertRaisesRegex( 1915 TypedSyntaxError, "unbox only accepts a single argument" 1916 ): 1917 self.compile(codestr, StaticCodeGenerator) 1918 1919 def test_int_reassign(self): 1920 codestr = """ 1921 from __static__ import ssize_t, box 1922 def f(): 1923 x: ssize_t = 42 1924 z: ssize_t = 1 + x 1925 x = 100 1926 x = x + x 1927 return box(z) 1928 """ 1929 1930 code = self.compile(codestr, StaticCodeGenerator) 1931 f = self.find_code(code) 1932 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (42, TYPED_INT64)) 1933 self.assertInBytecode(f, "LOAD_LOCAL", (0, ("__static__", "int64"))) 1934 self.assertInBytecode(f, "PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT) 1935 f = self.run_code(codestr, StaticCodeGenerator)["f"] 1936 self.assertEqual(f(), 43) 1937 1938 def test_assign_to_object(self): 1939 codestr = """ 1940 def f(): 1941 x: object 1942 x = None 1943 x = 1 1944 x = 'abc' 1945 x = [] 1946 x = {} 1947 x = {1, 2} 1948 x = () 1949 x = 1.0 1950 x = 1j 1951 x = b'foo' 1952 x = int 1953 x = True 1954 x = NotImplemented 1955 x = ... 
1956 """ 1957 1958 self.compile(codestr, StaticCodeGenerator) 1959 1960 def test_global_call_add(self) -> None: 1961 codestr = """ 1962 X = ord(42) 1963 def f(): 1964 y = X + 1 1965 """ 1966 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 1967 1968 def test_type_binder(self) -> None: 1969 self.assertEqual(self.bind_expr("42"), INT_EXACT_TYPE.instance) 1970 self.assertEqual(self.bind_expr("42.0"), FLOAT_EXACT_TYPE.instance) 1971 self.assertEqual(self.bind_expr("'abc'"), STR_EXACT_TYPE.instance) 1972 self.assertEqual(self.bind_expr("b'abc'"), BYTES_TYPE.instance) 1973 self.assertEqual(self.bind_expr("3j"), COMPLEX_EXACT_TYPE.instance) 1974 self.assertEqual(self.bind_expr("None"), NONE_TYPE.instance) 1975 self.assertEqual(self.bind_expr("True"), BOOL_TYPE.instance) 1976 self.assertEqual(self.bind_expr("False"), BOOL_TYPE.instance) 1977 self.assertEqual(self.bind_expr("..."), ELLIPSIS_TYPE.instance) 1978 self.assertEqual(self.bind_expr("f''"), STR_EXACT_TYPE.instance) 1979 self.assertEqual(self.bind_expr("f'{x}'"), STR_EXACT_TYPE.instance) 1980 1981 self.assertEqual(self.bind_expr("a"), DYNAMIC) 1982 self.assertEqual(self.bind_expr("a.b"), DYNAMIC) 1983 self.assertEqual(self.bind_expr("a + b"), DYNAMIC) 1984 1985 self.assertEqual(self.bind_expr("1 + 2"), INT_EXACT_TYPE.instance) 1986 self.assertEqual(self.bind_expr("1 - 2"), INT_EXACT_TYPE.instance) 1987 self.assertEqual(self.bind_expr("1 // 2"), INT_EXACT_TYPE.instance) 1988 self.assertEqual(self.bind_expr("1 * 2"), INT_EXACT_TYPE.instance) 1989 self.assertEqual(self.bind_expr("1 / 2"), FLOAT_EXACT_TYPE.instance) 1990 self.assertEqual(self.bind_expr("1 % 2"), INT_EXACT_TYPE.instance) 1991 self.assertEqual(self.bind_expr("1 & 2"), INT_EXACT_TYPE.instance) 1992 self.assertEqual(self.bind_expr("1 | 2"), INT_EXACT_TYPE.instance) 1993 self.assertEqual(self.bind_expr("1 ^ 2"), INT_EXACT_TYPE.instance) 1994 self.assertEqual(self.bind_expr("1 << 2"), INT_EXACT_TYPE.instance) 1995 
self.assertEqual(self.bind_expr("100 >> 2"), INT_EXACT_TYPE.instance) 1996 1997 self.assertEqual(self.bind_stmt("x = 1"), INT_EXACT_TYPE.instance) 1998 # self.assertEqual(self.bind_stmt("x: foo = 1").target.comp_type, DYNAMIC) 1999 self.assertEqual(self.bind_stmt("x += 1"), DYNAMIC) 2000 self.assertEqual(self.bind_expr("a or b"), DYNAMIC) 2001 self.assertEqual(self.bind_expr("+a"), DYNAMIC) 2002 self.assertEqual(self.bind_expr("not a"), BOOL_TYPE.instance) 2003 self.assertEqual(self.bind_expr("lambda: 42"), DYNAMIC) 2004 self.assertEqual(self.bind_expr("a if b else c"), DYNAMIC) 2005 self.assertEqual(self.bind_expr("x > y"), DYNAMIC) 2006 self.assertEqual(self.bind_expr("x()"), DYNAMIC) 2007 self.assertEqual(self.bind_expr("x(y)"), DYNAMIC) 2008 self.assertEqual(self.bind_expr("x[y]"), DYNAMIC) 2009 self.assertEqual(self.bind_expr("x[1:2]"), DYNAMIC) 2010 self.assertEqual(self.bind_expr("x[1:2:3]"), DYNAMIC) 2011 self.assertEqual(self.bind_expr("x[:]"), DYNAMIC) 2012 self.assertEqual(self.bind_expr("{}"), DICT_EXACT_TYPE.instance) 2013 self.assertEqual(self.bind_expr("{2:3}"), DICT_EXACT_TYPE.instance) 2014 self.assertEqual(self.bind_expr("{1,2}"), SET_EXACT_TYPE.instance) 2015 self.assertEqual(self.bind_expr("[]"), LIST_EXACT_TYPE.instance) 2016 self.assertEqual(self.bind_expr("[1,2]"), LIST_EXACT_TYPE.instance) 2017 self.assertEqual(self.bind_expr("(1,2)"), TUPLE_EXACT_TYPE.instance) 2018 2019 self.assertEqual(self.bind_expr("[x for x in y]"), LIST_EXACT_TYPE.instance) 2020 self.assertEqual(self.bind_expr("{x for x in y}"), SET_EXACT_TYPE.instance) 2021 self.assertEqual(self.bind_expr("{x:y for x in y}"), DICT_EXACT_TYPE.instance) 2022 self.assertEqual(self.bind_expr("(x for x in y)"), DYNAMIC) 2023 2024 def body_get(stmt): 2025 return stmt.body[0].value 2026 2027 self.assertEqual( 2028 self.bind_stmt("def f(): return 42", getter=body_get), 2029 INT_EXACT_TYPE.instance, 2030 ) 2031 self.assertEqual(self.bind_stmt("def f(): yield 42", getter=body_get), DYNAMIC) 
2032 self.assertEqual( 2033 self.bind_stmt("def f(): yield from x", getter=body_get), DYNAMIC 2034 ) 2035 self.assertEqual( 2036 self.bind_stmt("async def f(): await x", getter=body_get), DYNAMIC 2037 ) 2038 2039 self.assertEqual(self.bind_expr("object"), OBJECT_TYPE) 2040 2041 self.assertEqual( 2042 self.bind_expr("1 + 2", optimize=True), INT_EXACT_TYPE.instance 2043 ) 2044 2045 def test_if_exp(self) -> None: 2046 mod, syms = self.bind_module( 2047 """ 2048 class C: pass 2049 class D: pass 2050 2051 x = C() if a else D() 2052 """ 2053 ) 2054 node = mod.body[-1] 2055 types = syms.modules["foo"].types 2056 self.assertEqual(types[node], DYNAMIC) 2057 2058 mod, syms = self.bind_module( 2059 """ 2060 class C: pass 2061 2062 x = C() if a else C() 2063 """ 2064 ) 2065 node = mod.body[-1] 2066 types = syms.modules["foo"].types 2067 self.assertEqual(types[node.value].name, "foo.C") 2068 2069 def test_cmpop(self): 2070 codestr = """ 2071 from __static__ import int32 2072 def f(): 2073 i: int32 = 0 2074 j: int = 0 2075 2076 if i == 0: 2077 return 0 2078 if j == 0: 2079 return 1 2080 """ 2081 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2082 x = self.find_code(code, "f") 2083 self.assertInBytecode(x, "INT_COMPARE_OP", 0) 2084 self.assertInBytecode(x, "COMPARE_OP", "==") 2085 2086 def test_bind_instance(self) -> None: 2087 mod, syms = self.bind_module("class C: pass\na: C = C()") 2088 assign = mod.body[1] 2089 types = syms.modules["foo"].types 2090 self.assertEqual(types[assign.target].name, "foo.C") 2091 self.assertEqual(repr(types[assign.target]), "<foo.C>") 2092 2093 def test_bind_func_def(self) -> None: 2094 mod, syms = self.bind_module( 2095 """ 2096 def f(x: object = None, y: object = None): 2097 pass 2098 """ 2099 ) 2100 modtable = syms.modules["foo"] 2101 self.assertTrue(isinstance(modtable.children["f"], Function)) 2102 2103 def assertReturns(self, code: str, typename: str) -> None: 2104 actual = self.bind_final_return(code).name 2105 
self.assertEqual(actual, typename) 2106 2107 def bind_final_return(self, code: str) -> Value: 2108 mod, syms = self.bind_module(code) 2109 types = syms.modules["foo"].types 2110 node = mod.body[-1].body[-1].value 2111 return types[node] 2112 2113 def bind_stmt( 2114 self, code: str, optimize: bool = False, getter=lambda stmt: stmt 2115 ) -> ast.stmt: 2116 mod, syms = self.bind_module(code, optimize) 2117 assert len(mod.body) == 1 2118 types = syms.modules["foo"].types 2119 return types[getter(mod.body[0])] 2120 2121 def bind_expr(self, code: str, optimize: bool = False) -> Value: 2122 mod, syms = self.bind_module(code, optimize) 2123 assert len(mod.body) == 1 2124 types = syms.modules["foo"].types 2125 return types[mod.body[0].value] 2126 2127 def bind_module( 2128 self, code: str, optimize: bool = False 2129 ) -> Tuple[ast.Module, SymbolTable]: 2130 tree = ast.parse(dedent(code)) 2131 if optimize: 2132 tree = AstOptimizer().visit(tree) 2133 2134 symtable = SymbolTable() 2135 decl_visit = DeclarationVisitor("foo", "foo.py", symtable) 2136 decl_visit.visit(tree) 2137 decl_visit.module.finish_bind() 2138 2139 s = SymbolVisitor() 2140 walk(tree, s) 2141 2142 type_binder = TypeBinder(s, "foo.py", symtable, "foo") 2143 type_binder.visit(tree) 2144 2145 # Make sure we can compile the code, just verifying all nodes are 2146 # visited. 
2147 graph = StaticCodeGenerator.flow_graph("foo", "foo.py", s.scopes[tree]) 2148 code_gen = StaticCodeGenerator(None, tree, s, graph, symtable, "foo", optimize) 2149 code_gen.visit(tree) 2150 2151 return tree, symtable 2152 2153 def test_cross_module(self) -> None: 2154 acode = """ 2155 class C: 2156 def f(self): 2157 return 42 2158 """ 2159 bcode = """ 2160 from a import C 2161 2162 def f(): 2163 x = C() 2164 return x.f() 2165 """ 2166 symtable = SymbolTable() 2167 acomp = symtable.compile("a", "a.py", ast.parse(dedent(acode))) 2168 bcomp = symtable.compile("b", "b.py", ast.parse(dedent(bcode))) 2169 x = self.find_code(bcomp, "f") 2170 self.assertInBytecode(x, "INVOKE_METHOD", (("a", "C", "f"), 0)) 2171 2172 def test_cross_module_import_time_resolution(self) -> None: 2173 class TestSymbolTable(SymbolTable): 2174 def import_module(self, name): 2175 if name == "a": 2176 symtable.add_module("a", "a.py", ast.parse(dedent(acode))) 2177 2178 acode = """ 2179 class C: 2180 def f(self): 2181 return 42 2182 """ 2183 bcode = """ 2184 from a import C 2185 2186 def f(): 2187 x = C() 2188 return x.f() 2189 """ 2190 symtable = TestSymbolTable() 2191 bcomp = symtable.compile("b", "b.py", ast.parse(dedent(bcode))) 2192 x = self.find_code(bcomp, "f") 2193 self.assertInBytecode(x, "INVOKE_METHOD", (("a", "C", "f"), 0)) 2194 2195 def test_cross_module_type_checking(self) -> None: 2196 acode = """ 2197 class C: 2198 def f(self): 2199 return 42 2200 """ 2201 bcode = """ 2202 from typing import TYPE_CHECKING 2203 2204 if TYPE_CHECKING: 2205 from a import C 2206 2207 def f(x: C): 2208 return x.f() 2209 """ 2210 symtable = SymbolTable() 2211 symtable.add_module("a", "a.py", ast.parse(dedent(acode))) 2212 symtable.add_module("b", "b.py", ast.parse(dedent(bcode))) 2213 acomp = symtable.compile("a", "a.py", ast.parse(dedent(acode))) 2214 bcomp = symtable.compile("b", "b.py", ast.parse(dedent(bcode))) 2215 x = self.find_code(bcomp, "f") 2216 self.assertInBytecode(x, "INVOKE_METHOD", (("a", 
"C", "f"), 0)) 2217 2218 def test_primitive_invoke(self) -> None: 2219 codestr = """ 2220 from __static__ import int8 2221 def f(): 2222 x: int8 = 42 2223 print(x.__str__()) 2224 """ 2225 with self.assertRaisesRegex( 2226 TypedSyntaxError, "cannot load attribute from int8" 2227 ): 2228 self.compile(codestr, StaticCodeGenerator) 2229 2230 def test_primitive_call(self) -> None: 2231 codestr = """ 2232 from __static__ import int8 2233 def f(): 2234 x: int8 = 42 2235 print(x()) 2236 """ 2237 with self.assertRaisesRegex(TypedSyntaxError, "cannot call int8"): 2238 self.compile(codestr, StaticCodeGenerator) 2239 2240 def test_primitive_subscr(self) -> None: 2241 codestr = """ 2242 from __static__ import int8 2243 def f(): 2244 x: int8 = 42 2245 print(x[42]) 2246 """ 2247 with self.assertRaisesRegex(TypedSyntaxError, "cannot index int8"): 2248 self.compile(codestr, StaticCodeGenerator) 2249 2250 def test_primitive_iter(self) -> None: 2251 codestr = """ 2252 from __static__ import int8 2253 def f(): 2254 x: int8 = 42 2255 for a in x: 2256 pass 2257 """ 2258 with self.assertRaisesRegex(TypedSyntaxError, "cannot iterate over int8"): 2259 self.compile(codestr, StaticCodeGenerator) 2260 2261 def test_pseudo_strict_module(self) -> None: 2262 # simulate strict modules where the builtins come from <builtins> 2263 code = """ 2264 def f(a): 2265 x: bool = a 2266 """ 2267 builtins = ast.Assign( 2268 [ast.Name("bool", ast.Store())], 2269 ast.Subscript( 2270 ast.Name("<builtins>", ast.Load()), 2271 ast.Index(ast.Str("bool")), 2272 ast.Load(), 2273 ), 2274 None, 2275 ) 2276 tree = ast.parse(dedent(code)) 2277 tree.body.insert(0, builtins) 2278 2279 symtable = SymbolTable() 2280 symtable.add_module("a", "a.py", tree) 2281 acomp = symtable.compile("a", "a.py", tree) 2282 x = self.find_code(acomp, "f") 2283 self.assertInBytecode(x, "CAST", ("builtins", "bool")) 2284 2285 def test_aug_assign(self) -> None: 2286 codestr = """ 2287 def f(l): 2288 l[0] += 1 2289 """ 2290 with 
self.in_module(codestr) as mod: 2291 f = mod["f"] 2292 l = [1] 2293 f(l) 2294 self.assertEqual(l[0], 2) 2295 2296 def test_pseudo_strict_module_constant(self) -> None: 2297 # simulate strict modules where the builtins come from <builtins> 2298 code = """ 2299 def f(a): 2300 x: bool = a 2301 """ 2302 builtins = ast.Assign( 2303 [ast.Name("bool", ast.Store())], 2304 ast.Subscript( 2305 ast.Name("<builtins>", ast.Load()), 2306 ast.Index(ast.Constant("bool")), 2307 ast.Load(), 2308 ), 2309 None, 2310 ) 2311 tree = ast.parse(dedent(code)) 2312 tree.body.insert(0, builtins) 2313 2314 symtable = SymbolTable() 2315 symtable.add_module("a", "a.py", tree) 2316 acomp = symtable.compile("a", "a.py", tree) 2317 x = self.find_code(acomp, "f") 2318 self.assertInBytecode(x, "CAST", ("builtins", "bool")) 2319 2320 def test_cross_module_inheritance(self) -> None: 2321 acode = """ 2322 class C: 2323 def f(self): 2324 return 42 2325 """ 2326 bcode = """ 2327 from a import C 2328 2329 class D(C): 2330 def f(self): 2331 return 'abc' 2332 2333 def f(y): 2334 x: C 2335 if y: 2336 x = D() 2337 else: 2338 x = C() 2339 return x.f() 2340 """ 2341 symtable = SymbolTable() 2342 symtable.add_module("a", "a.py", ast.parse(dedent(acode))) 2343 symtable.add_module("b", "b.py", ast.parse(dedent(bcode))) 2344 acomp = symtable.compile("a", "a.py", ast.parse(dedent(acode))) 2345 bcomp = symtable.compile("b", "b.py", ast.parse(dedent(bcode))) 2346 x = self.find_code(bcomp, "f") 2347 self.assertInBytecode(x, "INVOKE_METHOD", (("a", "C", "f"), 0)) 2348 2349 def test_annotated_function(self): 2350 codestr = """ 2351 class C: 2352 def f(self) -> int: 2353 return 1 2354 2355 def x(c: C): 2356 x = c.f() 2357 x += c.f() 2358 return x 2359 """ 2360 2361 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2362 x = self.find_code(code, "x") 2363 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2364 2365 with self.in_module(codestr) as mod: 2366 x, C = mod["x"], mod["C"] 2367 c = C() 
2368 self.assertEqual(x(c), 2) 2369 2370 def test_invoke_new_derived(self): 2371 codestr = """ 2372 class C: 2373 def f(self): 2374 return 1 2375 2376 def x(c: C): 2377 x = c.f() 2378 x += c.f() 2379 return x 2380 2381 a = x(C()) 2382 2383 class D(C): 2384 def f(self): 2385 return 2 2386 2387 b = x(D()) 2388 """ 2389 2390 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2391 x = self.find_code(code, "x") 2392 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2393 2394 with self.in_module(codestr) as mod: 2395 a, b = mod["a"], mod["b"] 2396 self.assertEqual(a, 2) 2397 self.assertEqual(b, 4) 2398 2399 def test_invoke_explicit_slots(self): 2400 codestr = """ 2401 class C: 2402 __slots__ = () 2403 def f(self): 2404 return 1 2405 2406 def x(c: C): 2407 x = c.f() 2408 x += c.f() 2409 return x 2410 2411 a = x(C()) 2412 """ 2413 2414 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2415 x = self.find_code(code, "x") 2416 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2417 2418 with self.in_module(codestr) as mod: 2419 a = mod["a"] 2420 self.assertEqual(a, 2) 2421 2422 def test_invoke_new_derived_nonfunc(self): 2423 codestr = """ 2424 class C: 2425 def f(self): 2426 return 1 2427 2428 def x(c: C): 2429 x = c.f() 2430 x += c.f() 2431 return x 2432 """ 2433 2434 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2435 x = self.find_code(code, "x") 2436 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2437 2438 with self.in_module(codestr) as mod: 2439 x, C = mod["x"], mod["C"] 2440 self.assertEqual(x(C()), 2) 2441 2442 class Callable: 2443 def __call__(self_, obj): 2444 self.assertTrue(isinstance(obj, D)) 2445 return 42 2446 2447 class D(C): 2448 f = Callable() 2449 2450 d = D() 2451 self.assertEqual(x(d), 84) 2452 2453 def test_invoke_new_derived_nonfunc_slots(self): 2454 codestr = """ 2455 class C: 2456 def f(self): 2457 return 1 2458 2459 def x(c: C): 2460 x = c.f() 2461 x += 
c.f() 2462 return x 2463 """ 2464 2465 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2466 x = self.find_code(code, "x") 2467 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2468 2469 with self.in_module(codestr) as mod: 2470 x, C = mod["x"], mod["C"] 2471 self.assertEqual(x(C()), 2) 2472 2473 class Callable: 2474 def __call__(self_, obj): 2475 self.assertTrue(isinstance(obj, D)) 2476 return 42 2477 2478 class D(C): 2479 __slots__ = () 2480 f = Callable() 2481 2482 d = D() 2483 self.assertEqual(x(d), 84) 2484 2485 def test_invoke_new_derived_nonfunc_descriptor(self): 2486 codestr = """ 2487 class C: 2488 def f(self): 2489 return 1 2490 2491 def x(c: C): 2492 x = c.f() 2493 x += c.f() 2494 return x 2495 """ 2496 2497 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2498 x = self.find_code(code, "x") 2499 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2500 2501 with self.in_module(codestr) as mod: 2502 x, C = mod["x"], mod["C"] 2503 self.assertEqual(x(C()), 2) 2504 2505 class Callable: 2506 def __call__(self): 2507 return 42 2508 2509 class Descr: 2510 def __get__(self, inst, ctx): 2511 return Callable() 2512 2513 class D(C): 2514 f = Descr() 2515 2516 d = D() 2517 self.assertEqual(x(d), 84) 2518 2519 def test_invoke_new_derived_nonfunc_data_descriptor(self): 2520 codestr = """ 2521 class C: 2522 def f(self): 2523 return 1 2524 2525 def x(c: C): 2526 x = c.f() 2527 x += c.f() 2528 return x 2529 """ 2530 2531 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2532 x = self.find_code(code, "x") 2533 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2534 2535 with self.in_module(codestr) as mod: 2536 x, C = mod["x"], mod["C"] 2537 self.assertEqual(x(C()), 2) 2538 2539 class Callable: 2540 def __call__(self): 2541 return 42 2542 2543 class Descr: 2544 def __get__(self, inst, ctx): 2545 return Callable() 2546 2547 def __set__(self, inst, value): 2548 raise ValueError("no 
way") 2549 2550 class D(C): 2551 f = Descr() 2552 2553 d = D() 2554 self.assertEqual(x(d), 84) 2555 2556 def test_invoke_new_derived_nonfunc_descriptor_inst_override(self): 2557 codestr = """ 2558 class C: 2559 def f(self): 2560 return 1 2561 2562 def x(c: C): 2563 x = c.f() 2564 x += c.f() 2565 return x 2566 """ 2567 2568 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2569 x = self.find_code(code, "x") 2570 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2571 2572 with self.in_module(codestr) as mod: 2573 x, C = mod["x"], mod["C"] 2574 self.assertEqual(x(C()), 2) 2575 2576 class Callable: 2577 def __call__(self): 2578 return 42 2579 2580 class Descr: 2581 def __get__(self, inst, ctx): 2582 return Callable() 2583 2584 class D(C): 2585 f = Descr() 2586 2587 d = D() 2588 self.assertEqual(x(d), 84) 2589 d.__dict__["f"] = lambda x: 100 2590 self.assertEqual(x(d), 200) 2591 2592 def test_invoke_new_derived_nonfunc_descriptor_modified(self): 2593 codestr = """ 2594 class C: 2595 def f(self): 2596 return 1 2597 2598 def x(c: C): 2599 x = c.f() 2600 x += c.f() 2601 return x 2602 """ 2603 2604 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2605 x = self.find_code(code, "x") 2606 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2607 2608 with self.in_module(codestr) as mod: 2609 x, C = mod["x"], mod["C"] 2610 self.assertEqual(x(C()), 2) 2611 2612 class Callable: 2613 def __call__(self): 2614 return 42 2615 2616 class Descr: 2617 def __get__(self, inst, ctx): 2618 return Callable() 2619 2620 def __call__(self, arg): 2621 return 23 2622 2623 class D(C): 2624 f = Descr() 2625 2626 d = D() 2627 self.assertEqual(x(d), 84) 2628 del Descr.__get__ 2629 self.assertEqual(x(d), 46) 2630 2631 def test_invoke_dict_override(self): 2632 codestr = """ 2633 class C: 2634 def f(self): 2635 return 1 2636 2637 def x(c: C): 2638 x = c.f() 2639 x += c.f() 2640 return x 2641 """ 2642 2643 code = self.compile(codestr, 
StaticCodeGenerator, modname="foo") 2644 x = self.find_code(code, "x") 2645 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2646 2647 with self.in_module(codestr) as mod: 2648 x, C = mod["x"], mod["C"] 2649 self.assertEqual(x(C()), 2) 2650 2651 class D(C): 2652 def __init__(self): 2653 self.f = lambda: 42 2654 2655 d = D() 2656 self.assertEqual(x(d), 84) 2657 2658 def test_invoke_type_modified(self): 2659 codestr = """ 2660 class C: 2661 def f(self): 2662 return 1 2663 2664 def x(c: C): 2665 x = c.f() 2666 x += c.f() 2667 return x 2668 """ 2669 2670 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2671 x = self.find_code(code, "x") 2672 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 2673 2674 with self.in_module(codestr) as mod: 2675 x, C = mod["x"], mod["C"] 2676 self.assertEqual(x(C()), 2) 2677 C.f = lambda self: 42 2678 self.assertEqual(x(C()), 84) 2679 2680 def test_annotated_function_derived(self): 2681 codestr = """ 2682 class C: 2683 def f(self) -> int: 2684 return 1 2685 2686 class D(C): 2687 def f(self) -> int: 2688 return 2 2689 2690 class E(C): 2691 pass 2692 2693 def x(c: C,): 2694 x = c.f() 2695 x += c.f() 2696 return x 2697 """ 2698 2699 code = self.compile( 2700 codestr, StaticCodeGenerator, modname="test_annotated_function_derived" 2701 ) 2702 x = self.find_code(code, "x") 2703 self.assertInBytecode( 2704 x, "INVOKE_METHOD", (("test_annotated_function_derived", "C", "f"), 0) 2705 ) 2706 2707 with self.in_module(codestr) as mod: 2708 x = mod["x"] 2709 self.assertEqual(x(mod["C"]()), 2) 2710 self.assertEqual(x(mod["D"]()), 4) 2711 self.assertEqual(x(mod["E"]()), 2) 2712 2713 def test_conditional_init(self): 2714 codestr = f""" 2715 from __static__ import box, int64 2716 2717 class C: 2718 def __init__(self, init=True): 2719 if init: 2720 self.value: int64 = 1 2721 2722 def f(self) -> int: 2723 return box(self.value) 2724 """ 2725 2726 with self.in_module(codestr) as mod: 2727 C = mod["C"] 2728 x = 
C() 2729 self.assertEqual(x.f(), 1) 2730 x = C(False) 2731 self.assertEqual(x.f(), 0) 2732 self.assertInBytecode(C.f, "LOAD_FIELD", (mod["__name__"], "C", "value")) 2733 2734 def test_error_incompat_assign_local(self): 2735 codestr = """ 2736 class C: 2737 def __init__(self): 2738 self.x = None 2739 2740 def f(self): 2741 x: "C" = self.x 2742 """ 2743 with self.in_module(codestr) as mod: 2744 C = mod["C"] 2745 with self.assertRaisesRegex(TypeError, "expected 'C', got 'NoneType'"): 2746 C().f() 2747 2748 def test_error_incompat_field_non_dynamic(self): 2749 codestr = """ 2750 class C: 2751 def __init__(self): 2752 self.x: int = 'abc' 2753 """ 2754 with self.assertRaises(TypedSyntaxError): 2755 self.compile(codestr, StaticCodeGenerator) 2756 2757 def test_error_incompat_field(self): 2758 codestr = """ 2759 class C: 2760 def __init__(self): 2761 self.x: int = 100 2762 2763 def f(self, x): 2764 self.x = x 2765 """ 2766 with self.in_module(codestr) as mod: 2767 C = mod["C"] 2768 C().f(42) 2769 with self.assertRaises(TypeError): 2770 C().f("abc") 2771 2772 def test_error_incompat_assign_dynamic(self): 2773 with self.assertRaises(TypedSyntaxError): 2774 code = self.compile( 2775 """ 2776 class C: 2777 x: "C" 2778 def __init__(self): 2779 self.x = None 2780 """, 2781 StaticCodeGenerator, 2782 ) 2783 2784 def test_annotated_class_var(self): 2785 codestr = """ 2786 class C: 2787 x: int 2788 """ 2789 code = self.compile( 2790 codestr, StaticCodeGenerator, modname="test_annotated_class_var" 2791 ) 2792 2793 def test_annotated_instance_var(self): 2794 codestr = """ 2795 class C: 2796 def __init__(self): 2797 self.x: str = 'abc' 2798 """ 2799 code = self.compile( 2800 codestr, StaticCodeGenerator, modname="test_annotated_instance_var" 2801 ) 2802 # get C from module, and then get __init__ from C 2803 code = self.find_code(self.find_code(code)) 2804 self.assertInBytecode(code, "STORE_FIELD") 2805 2806 def test_load_store_attr_value(self): 2807 codestr = """ 2808 class C: 2809 x: 
int 2810 2811 def __init__(self, value: int): 2812 self.x = value 2813 2814 def f(self): 2815 return self.x 2816 """ 2817 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2818 init = self.find_code(self.find_code(code), "__init__") 2819 self.assertInBytecode(init, "STORE_FIELD") 2820 f = self.find_code(self.find_code(code), "f") 2821 self.assertInBytecode(f, "LOAD_FIELD") 2822 with self.in_module(codestr) as mod: 2823 C = mod["C"] 2824 a = C(42) 2825 self.assertEqual(a.f(), 42) 2826 2827 def test_load_store_attr(self): 2828 codestr = """ 2829 class C: 2830 x: "C" 2831 2832 def __init__(self): 2833 self.x = self 2834 2835 def g(self): 2836 return 42 2837 2838 def f(self): 2839 return self.x.g() 2840 """ 2841 with self.in_module(codestr) as mod: 2842 C = mod["C"] 2843 a = C() 2844 self.assertEqual(a.f(), 42) 2845 2846 def test_load_store_attr_init(self): 2847 codestr = """ 2848 class C: 2849 def __init__(self): 2850 self.x: C = self 2851 2852 def g(self): 2853 return 42 2854 2855 def f(self): 2856 return self.x.g() 2857 """ 2858 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2859 2860 with self.in_module(codestr) as mod: 2861 C = mod["C"] 2862 a = C() 2863 self.assertEqual(a.f(), 42) 2864 2865 def test_load_store_attr_init_no_ann(self): 2866 codestr = """ 2867 class C: 2868 def __init__(self): 2869 self.x = self 2870 2871 def g(self): 2872 return 42 2873 2874 def f(self): 2875 return self.x.g() 2876 """ 2877 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2878 2879 with self.in_module(codestr) as mod: 2880 C = mod["C"] 2881 a = C() 2882 self.assertEqual(a.f(), 42) 2883 2884 def test_unknown_annotation(self): 2885 codestr = """ 2886 def f(a): 2887 x: foo = a 2888 return x.bar 2889 """ 2890 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2891 2892 class C: 2893 bar = 42 2894 2895 f = self.run_code(codestr, StaticCodeGenerator)["f"] 2896 self.assertEqual(f(C()), 42) 2897 2898 def 
test_class_method_invoke(self): 2899 codestr = """ 2900 class B: 2901 def __init__(self, value): 2902 self.value = value 2903 2904 class D(B): 2905 def __init__(self, value): 2906 B.__init__(self, value) 2907 2908 def f(self): 2909 return self.value 2910 """ 2911 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2912 2913 b_init = self.find_code(self.find_code(code, "B"), "__init__") 2914 self.assertInBytecode(b_init, "STORE_FIELD", ("foo", "B", "value")) 2915 2916 f = self.find_code(self.find_code(code, "D"), "f") 2917 self.assertInBytecode(f, "LOAD_FIELD", ("foo", "B", "value")) 2918 2919 with self.in_module(codestr) as mod: 2920 D = mod["D"] 2921 d = D(42) 2922 self.assertEqual(d.f(), 42) 2923 2924 def test_slotification(self): 2925 codestr = """ 2926 class C: 2927 x: "unknown_type" 2928 """ 2929 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2930 C = self.run_code(codestr, StaticCodeGenerator)["C"] 2931 self.assertEqual(type(C.x), MemberDescriptorType) 2932 2933 def test_slotification_init(self): 2934 codestr = """ 2935 class C: 2936 x: "unknown_type" 2937 def __init__(self): 2938 self.x = 42 2939 """ 2940 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2941 C = self.run_code(codestr, StaticCodeGenerator)["C"] 2942 self.assertEqual(type(C.x), MemberDescriptorType) 2943 2944 def test_slotification_ann_init(self): 2945 codestr = """ 2946 class C: 2947 x: "unknown_type" 2948 def __init__(self): 2949 self.x: "unknown_type" = 42 2950 """ 2951 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 2952 C = self.run_code(codestr, StaticCodeGenerator)["C"] 2953 self.assertEqual(type(C.x), MemberDescriptorType) 2954 2955 def test_slotification_typed(self): 2956 codestr = """ 2957 class C: 2958 x: int 2959 """ 2960 C = self.run_code(codestr, StaticCodeGenerator)["C"] 2961 self.assertNotEqual(type(C.x), MemberDescriptorType) 2962 2963 def test_slotification_init_typed(self): 2964 codestr = """ 2965 class C: 2966 
x: int 2967 def __init__(self): 2968 self.x = 42 2969 """ 2970 with self.in_module(codestr) as mod: 2971 C = mod["C"] 2972 self.assertNotEqual(type(C.x), MemberDescriptorType) 2973 x = C() 2974 self.assertEqual(x.x, 42) 2975 with self.assertRaisesRegex( 2976 TypeError, "expected 'int', got 'str' for attribute 'x'" 2977 ) as e: 2978 x.x = "abc" 2979 2980 def test_slotification_ann_init_typed(self): 2981 codestr = """ 2982 class C: 2983 x: int 2984 def __init__(self): 2985 self.x: int = 42 2986 """ 2987 C = self.run_code(codestr, StaticCodeGenerator)["C"] 2988 self.assertNotEqual(type(C.x), MemberDescriptorType) 2989 2990 def test_slotification_conflicting_types(self): 2991 codestr = """ 2992 class C: 2993 x: object 2994 def __init__(self): 2995 self.x: int = 42 2996 """ 2997 with self.assertRaisesRegex( 2998 TypedSyntaxError, 2999 r"conflicting type definitions for slot x in Type\[foo.C\]", 3000 ): 3001 self.compile(codestr, StaticCodeGenerator, modname="foo") 3002 3003 def test_slotification_conflicting_types_imported(self): 3004 self.type_error( 3005 """ 3006 from typing import Optional 3007 3008 class C: 3009 x: Optional[int] 3010 def __init__(self): 3011 self.x: Optional[str] = "foo" 3012 """, 3013 r"conflicting type definitions for slot x in Type\[<module>.C\]", 3014 ) 3015 3016 def test_slotification_conflicting_members(self): 3017 codestr = """ 3018 class C: 3019 def x(self): pass 3020 x: object 3021 """ 3022 with self.assertRaisesRegex( 3023 TypedSyntaxError, r"slot conflicts with other member x in Type\[foo.C\]" 3024 ): 3025 self.compile(codestr, StaticCodeGenerator, modname="foo") 3026 3027 def test_slotification_conflicting_function(self): 3028 codestr = """ 3029 class C: 3030 x: object 3031 def x(self): pass 3032 """ 3033 with self.assertRaisesRegex( 3034 TypedSyntaxError, r"function conflicts with other member x in Type\[foo.C\]" 3035 ): 3036 self.compile(codestr, StaticCodeGenerator, modname="foo") 3037 3038 def test_slot_inheritance(self): 3039 
codestr = """ 3040 class B: 3041 def __init__(self): 3042 self.x = 42 3043 3044 def f(self): 3045 return self.x 3046 3047 class D(B): 3048 def __init__(self): 3049 self.x = 100 3050 """ 3051 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3052 with self.in_module(codestr) as mod: 3053 D = mod["D"] 3054 inst = D() 3055 self.assertEqual(inst.f(), 100) 3056 3057 def test_del_slot(self): 3058 codestr = """ 3059 class C: 3060 x: object 3061 3062 def f(a: C): 3063 del a.x 3064 """ 3065 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3066 code = self.find_code(code, name="f") 3067 self.assertInBytecode(code, "DELETE_ATTR", "x") 3068 3069 def test_uninit_slot(self): 3070 codestr = """ 3071 class C: 3072 x: object 3073 3074 def f(a: C): 3075 return a.x 3076 """ 3077 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3078 code = self.find_code(code, name="f") 3079 with self.in_module(codestr) as mod: 3080 with self.assertRaises(AttributeError) as e: 3081 f, C = mod["f"], mod["C"] 3082 f(C()) 3083 3084 self.assertEqual(e.exception.args[0], "x") 3085 3086 def test_aug_add(self): 3087 codestr = """ 3088 class C: 3089 def __init__(self): 3090 self.x = 1 3091 3092 def f(a: C): 3093 a.x += 1 3094 """ 3095 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3096 code = self.find_code(code, name="f") 3097 self.assertInBytecode(code, "LOAD_FIELD", ("foo", "C", "x")) 3098 self.assertInBytecode(code, "STORE_FIELD", ("foo", "C", "x")) 3099 3100 def test_untyped_attr(self): 3101 codestr = """ 3102 y = x.load 3103 x.store = 42 3104 del x.delete 3105 """ 3106 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3107 self.assertInBytecode(code, "LOAD_ATTR", "load") 3108 self.assertInBytecode(code, "STORE_ATTR", "store") 3109 self.assertInBytecode(code, "DELETE_ATTR", "delete") 3110 3111 def test_incompat_override(self): 3112 codestr = """ 3113 class C: 3114 x: int 3115 3116 class D(C): 3117 def x(self): pass 3118 """ 
3119 with self.assertRaises(TypedSyntaxError): 3120 self.compile(codestr, StaticCodeGenerator, modname="foo") 3121 3122 def test_redefine_type(self): 3123 codestr = """ 3124 class C: pass 3125 class D: pass 3126 3127 def f(a): 3128 x: C = C() 3129 x: D = D() 3130 """ 3131 with self.assertRaises(TypedSyntaxError): 3132 self.compile(codestr, StaticCodeGenerator, modname="foo") 3133 3134 def test_optional_error(self): 3135 codestr = """ 3136 from typing import Optional 3137 class C: 3138 x: Optional["C"] 3139 def __init__(self, set): 3140 if set: 3141 self.x = self 3142 else: 3143 self.x = None 3144 3145 def f(self) -> Optional["C"]: 3146 return self.x.x 3147 """ 3148 with self.assertRaisesRegex( 3149 TypedSyntaxError, 3150 re.escape("Optional[foo.C]: 'NoneType' object has no attribute 'x'"), 3151 ): 3152 self.compile(codestr, StaticCodeGenerator, modname="foo") 3153 3154 def test_optional_subscript_error(self) -> None: 3155 codestr = """ 3156 from typing import Optional 3157 3158 def f(a: Optional[int]): 3159 a[1] 3160 """ 3161 with self.assertRaisesRegex( 3162 TypedSyntaxError, 3163 re.escape("Optional[int]: 'NoneType' object is not subscriptable"), 3164 ): 3165 self.compile(codestr, StaticCodeGenerator) 3166 3167 def test_optional_unary_error(self) -> None: 3168 codestr = """ 3169 from typing import Optional 3170 3171 def f(a: Optional[int]): 3172 -a 3173 """ 3174 with self.assertRaisesRegex( 3175 TypedSyntaxError, 3176 re.escape("Optional[int]: bad operand type for unary -: 'NoneType'"), 3177 ): 3178 self.compile(codestr, StaticCodeGenerator) 3179 3180 def test_optional_assign(self): 3181 codestr = """ 3182 from typing import Optional 3183 class C: 3184 def f(self, x: Optional["C"]): 3185 if x is None: 3186 return self 3187 else: 3188 p: Optional["C"] = x 3189 """ 3190 self.compile(codestr, StaticCodeGenerator, modname="foo") 3191 3192 def test_nonoptional_load(self): 3193 codestr = """ 3194 class C: 3195 def __init__(self, y: int): 3196 self.y = y 3197 3198 def 
f(c: C) -> int: 3199 return c.y 3200 """ 3201 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 3202 f = self.find_code(code, "f") 3203 self.assertInBytecode(f, "LOAD_FIELD", ("foo", "C", "y")) 3204 3205 def test_optional_assign_subclass(self): 3206 codestr = """ 3207 from typing import Optional 3208 class B: pass 3209 class D(B): pass 3210 3211 def f(x: D): 3212 a: Optional[B] = x 3213 """ 3214 self.compile(codestr, StaticCodeGenerator, modname="foo") 3215 3216 def test_optional_assign_subclass_opt(self): 3217 codestr = """ 3218 from typing import Optional 3219 class B: pass 3220 class D(B): pass 3221 3222 def f(x: Optional[D]): 3223 a: Optional[B] = x 3224 """ 3225 self.compile(codestr, StaticCodeGenerator, modname="foo") 3226 3227 def test_optional_assign_none(self): 3228 codestr = """ 3229 from typing import Optional 3230 class B: pass 3231 3232 def f(x: Optional[B]): 3233 a: Optional[B] = None 3234 """ 3235 self.compile(codestr, StaticCodeGenerator, modname="foo") 3236 3237 def test_optional_union_syntax(self): 3238 self.assertReturns( 3239 """ 3240 from typing import Union 3241 class B: pass 3242 class C(B): pass 3243 3244 def f(x: Union[int, None]) -> int: 3245 # can assign None 3246 y: Optional[int] = None 3247 # can assign subclass 3248 z: Optional[B] = C() 3249 # can narrow 3250 if x is None: 3251 return 1 3252 return x 3253 """, 3254 "int", 3255 ) 3256 3257 def test_optional_union_syntax_error(self): 3258 self.type_error( 3259 """ 3260 from typing import Union 3261 3262 def f(x: Union[int, None]) -> int: 3263 return x 3264 """, 3265 type_mismatch("Optional[int]", "int"), 3266 ) 3267 3268 def test_union_can_assign_to_broader_union(self): 3269 self.assertReturns( 3270 """ 3271 from typing import Union 3272 class B: 3273 pass 3274 3275 def f(x: int, y: str) -> Union[int, str, B]: 3276 return x or y 3277 """, 3278 "Union[int, str]", 3279 ) 3280 3281 def test_union_can_assign_to_same_union(self): 3282 self.assertReturns( 3283 """ 3284 from 
typing import Union 3285 3286 def f(x: int, y: str) -> Union[int, str]: 3287 return x or y 3288 """, 3289 "Union[int, str]", 3290 ) 3291 3292 def test_union_can_assign_from_individual_element(self): 3293 self.assertReturns( 3294 """ 3295 from typing import Union 3296 3297 def f(x: int) -> Union[int, str]: 3298 return x 3299 """, 3300 "int", 3301 ) 3302 3303 def test_union_cannot_assign_from_broader_union(self): 3304 # TODO this should be a type error, but can't be safely 3305 # until we have runtime checking for unions 3306 self.assertReturns( 3307 """ 3308 from typing import Union 3309 class B: pass 3310 3311 def f(x: int, y: str, z: B) -> Union[int, str]: 3312 return x or y or z 3313 """, 3314 "Union[int, str, foo.B]", 3315 ) 3316 3317 def test_union_simplify_to_single_type(self): 3318 self.assertReturns( 3319 """ 3320 from typing import Union 3321 3322 def f(x: int, y: int) -> int: 3323 return x or y 3324 """, 3325 "int", 3326 ) 3327 3328 def test_union_simplify_related(self): 3329 self.assertReturns( 3330 """ 3331 from typing import Union 3332 class B: pass 3333 class C(B): pass 3334 3335 def f(x: B, y: C) -> B: 3336 return x or y 3337 """, 3338 "foo.B", 3339 ) 3340 3341 def test_union_flatten_nested(self): 3342 self.assertReturns( 3343 """ 3344 from typing import Union 3345 class B: pass 3346 3347 def f(x: int, y: str, z: B): 3348 return x or (y or z) 3349 """, 3350 "Union[int, str, foo.B]", 3351 ) 3352 3353 def test_union_deep_simplify(self): 3354 self.assertReturns( 3355 """ 3356 from typing import Union 3357 3358 def f(x: int, y: None) -> int: 3359 z = (x or x) or (y or y) or (x or x) 3360 if z is None: 3361 return 1 3362 return z 3363 """, 3364 "int", 3365 ) 3366 3367 def test_union_dynamic_element(self): 3368 self.assertReturns( 3369 """ 3370 from somewhere import unknown 3371 3372 def f(x: int, y: unknown): 3373 return x or y 3374 """, 3375 "dynamic", 3376 ) 3377 3378 def test_union_or_syntax(self): 3379 self.type_error( 3380 """ 3381 def f(x) -> int: 
3382 if isinstance(x, int|str): 3383 return x 3384 return 1 3385 """, 3386 type_mismatch("Union[int, str]", "int"), 3387 ) 3388 3389 def test_union_or_syntax_none(self): 3390 self.type_error( 3391 """ 3392 def f(x) -> int: 3393 if isinstance(x, int|None): 3394 return x 3395 return 1 3396 """, 3397 type_mismatch("Optional[int]", "int"), 3398 ) 3399 3400 def test_union_or_syntax_builtin_type(self): 3401 self.compile( 3402 """ 3403 from typing import Iterator 3404 def f(x) -> int: 3405 if isinstance(x, bytes | Iterator[bytes]): 3406 return 1 3407 return 2 3408 """, 3409 StaticCodeGenerator, 3410 modname="foo.py", 3411 ) 3412 3413 def test_union_or_syntax_none_first(self): 3414 self.type_error( 3415 """ 3416 def f(x) -> int: 3417 if isinstance(x, None|int): 3418 return x 3419 return 1 3420 """, 3421 type_mismatch("Optional[int]", "int"), 3422 ) 3423 3424 def test_union_or_syntax_annotation(self): 3425 self.type_error( 3426 """ 3427 def f(y: int, z: str) -> int: 3428 x: int|str = y or z 3429 return x 3430 """, 3431 type_mismatch("Union[int, str]", "int"), 3432 ) 3433 3434 def test_union_or_syntax_error(self): 3435 self.type_error( 3436 """ 3437 def f(): 3438 x = int | "foo" 3439 """, 3440 r"unsupported operand type(s) for |: Type\[Exact\[int\]\] and Exact\[str\]", 3441 ) 3442 3443 def test_union_or_syntax_annotation_bad_type(self): 3444 # TODO given that len is not unknown/dynamic, but is a known object 3445 # with type that is invalid in this position, this should really be an 3446 # error. But the current form of `resolve_annotations` doesn't let us 3447 # distinguish between unknown/dynamic and bad type. So for now we just 3448 # let this go as dynamic. 
3449 self.assertReturns( 3450 """ 3451 def f(x: len | int) -> int: 3452 return x 3453 """, 3454 "dynamic", 3455 ) 3456 3457 def test_union_attr(self): 3458 self.assertReturns( 3459 """ 3460 class A: 3461 attr: int 3462 3463 class B: 3464 attr: str 3465 3466 def f(x: A, y: B): 3467 z = x or y 3468 return z.attr 3469 """, 3470 "Union[int, str]", 3471 ) 3472 3473 def test_union_attr_error(self): 3474 self.type_error( 3475 """ 3476 class A: 3477 attr: int 3478 3479 def f(x: A | None): 3480 return x.attr 3481 """, 3482 re.escape( 3483 "Optional[<module>.A]: 'NoneType' object has no attribute 'attr'" 3484 ), 3485 ) 3486 3487 # TODO add test_union_call when we have Type[] or Callable[] or 3488 # __call__ support. Right now we have no way to construct a Union of 3489 # callables that return different types. 3490 3491 def test_union_call_error(self): 3492 self.type_error( 3493 """ 3494 def f(x: int | None): 3495 return x() 3496 """, 3497 re.escape("Optional[int]: 'NoneType' object is not callable"), 3498 ) 3499 3500 def test_union_subscr(self): 3501 self.assertReturns( 3502 """ 3503 from __static__ import CheckedDict 3504 3505 def f(x: CheckedDict[int, int], y: CheckedDict[int, str]): 3506 return (x or y)[0] 3507 """, 3508 "Union[int, str]", 3509 ) 3510 3511 def test_union_unaryop(self): 3512 self.assertReturns( 3513 """ 3514 def f(x: int, y: complex): 3515 return -(x or y) 3516 """, 3517 "Union[int, complex]", 3518 ) 3519 3520 def test_union_isinstance_reverse_narrow(self): 3521 self.assertReturns( 3522 """ 3523 def f(x: int, y: str): 3524 z = x or y 3525 if isinstance(z, str): 3526 return 1 3527 return z 3528 """, 3529 "int", 3530 ) 3531 3532 def test_union_isinstance_reverse_narrow_supertype(self): 3533 self.assertReturns( 3534 """ 3535 class A: pass 3536 class B(A): pass 3537 3538 def f(x: int, y: B): 3539 o = x or y 3540 if isinstance(o, A): 3541 return 1 3542 return o 3543 """, 3544 "int", 3545 ) 3546 3547 def test_union_isinstance_reverse_narrow_other_union(self): 
3548 self.assertReturns( 3549 """ 3550 class A: pass 3551 class B: pass 3552 class C: pass 3553 3554 def f(x: A, y: B, z: C): 3555 o = x or y or z 3556 if isinstance(o, A | B): 3557 return 1 3558 return o 3559 """, 3560 "foo.C", 3561 ) 3562 3563 def test_union_not_isinstance_narrow(self): 3564 self.assertReturns( 3565 """ 3566 def f(x: int, y: str): 3567 o = x or y 3568 if not isinstance(o, int): 3569 return 1 3570 return o 3571 """, 3572 "int", 3573 ) 3574 3575 def test_union_isinstance_tuple(self): 3576 self.assertReturns( 3577 """ 3578 class A: pass 3579 class B: pass 3580 class C: pass 3581 3582 def f(x: A, y: B, z: C): 3583 o = x or y or z 3584 if isinstance(o, (A, B)): 3585 return 1 3586 return o 3587 """, 3588 "foo.C", 3589 ) 3590 3591 def test_union_no_arg_check(self): 3592 codestr = """ 3593 def f(x: int | str) -> int: 3594 return x 3595 """ 3596 with self.in_module(codestr) as mod: 3597 f = mod["f"] 3598 # no arg check for the union, it's just dynamic 3599 self.assertInBytecode(f, "CHECK_ARGS", ()) 3600 # so we do have to check the return value 3601 self.assertInBytecode(f, "CAST", ("builtins", "int")) 3602 # runtime type error comes from return, not argument 3603 with self.assertRaisesRegex(TypeError, "expected 'int', got 'list'"): 3604 f([]) 3605 3606 def test_error_return_int(self): 3607 with self.assertRaisesRegex( 3608 TypedSyntaxError, "type mismatch: int64 cannot be assigned to dynamic" 3609 ): 3610 code = self.compile( 3611 """ 3612 from __static__ import ssize_t 3613 def f(): 3614 y: ssize_t = 1 3615 return y 3616 """, 3617 StaticCodeGenerator, 3618 ) 3619 3620 def test_error_mixed_math(self): 3621 with self.assertRaises(TypedSyntaxError): 3622 code = self.compile( 3623 """ 3624 from __static__ import ssize_t 3625 def f(): 3626 y = 1 3627 x: ssize_t = 42 + y 3628 """, 3629 StaticCodeGenerator, 3630 ) 3631 3632 def test_error_incompat_return(self): 3633 with self.assertRaises(TypedSyntaxError): 3634 code = self.compile( 3635 """ 3636 class D: pass 
3637 class C: 3638 def __init__(self): 3639 self.x = None 3640 3641 def f(self) -> "C": 3642 return D() 3643 """, 3644 StaticCodeGenerator, 3645 ) 3646 3647 def test_cast(self): 3648 for code_gen in (StaticCodeGenerator, PythonCodeGenerator): 3649 codestr = """ 3650 from __static__ import cast 3651 class C: 3652 pass 3653 3654 a = C() 3655 3656 def f() -> C: 3657 return cast(C, a) 3658 """ 3659 code = self.compile(codestr, code_gen) 3660 f = self.find_code(code, "f") 3661 if code_gen is StaticCodeGenerator: 3662 self.assertInBytecode(f, "CAST", ("<module>", "C")) 3663 with self.in_module(codestr, code_gen=code_gen) as mod: 3664 C = mod["C"] 3665 f = mod["f"] 3666 self.assertTrue(isinstance(f(), C)) 3667 self.assert_jitted(f) 3668 3669 def test_cast_fail(self): 3670 for code_gen in (StaticCodeGenerator, PythonCodeGenerator): 3671 codestr = """ 3672 from __static__ import cast 3673 class C: 3674 pass 3675 3676 def f() -> C: 3677 return cast(C, 42) 3678 """ 3679 code = self.compile(codestr, code_gen) 3680 f = self.find_code(code, "f") 3681 if code_gen is StaticCodeGenerator: 3682 self.assertInBytecode(f, "CAST", ("<module>", "C")) 3683 with self.in_module(codestr) as mod: 3684 with self.assertRaises(TypeError): 3685 f = mod["f"] 3686 f() 3687 self.assert_jitted(f) 3688 3689 def test_cast_wrong_args(self): 3690 codestr = """ 3691 from __static__ import cast 3692 def f(): 3693 cast(42) 3694 """ 3695 with self.assertRaises(TypedSyntaxError): 3696 self.compile(codestr, StaticCodeGenerator) 3697 3698 def test_cast_unknown_type(self): 3699 codestr = """ 3700 from __static__ import cast 3701 def f(): 3702 cast(abc, 42) 3703 """ 3704 with self.assertRaises(TypedSyntaxError): 3705 self.compile(codestr, StaticCodeGenerator) 3706 3707 def test_cast_optional(self): 3708 for code_gen in (StaticCodeGenerator, PythonCodeGenerator): 3709 codestr = """ 3710 from __static__ import cast 3711 from typing import Optional 3712 3713 class C: 3714 pass 3715 3716 def f(x) -> Optional[C]: 3717 
return cast(Optional[C], x) 3718 """ 3719 code = self.compile(codestr, code_gen) 3720 f = self.find_code(code, "f") 3721 if code_gen is StaticCodeGenerator: 3722 self.assertInBytecode(f, "CAST", ("<module>", "C", "?")) 3723 with self.in_module(codestr, code_gen=code_gen) as mod: 3724 C = mod["C"] 3725 f = mod["f"] 3726 self.assertTrue(isinstance(f(C()), C)) 3727 self.assertEqual(f(None), None) 3728 self.assert_jitted(f) 3729 3730 def test_code_flags(self): 3731 codestr = """ 3732 def func(): 3733 print("hi") 3734 3735 func() 3736 """ 3737 module = self.compile(codestr, StaticCodeGenerator) 3738 self.assertTrue(module.co_flags & CO_STATICALLY_COMPILED) 3739 self.assertTrue( 3740 self.find_code(module, name="func").co_flags & CO_STATICALLY_COMPILED 3741 ) 3742 3743 def test_invoke_kws(self): 3744 codestr = """ 3745 class C: 3746 def f(self, a): 3747 return a 3748 3749 def func(): 3750 a = C() 3751 return a.f(a=2) 3752 3753 """ 3754 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3755 f = mod["func"] 3756 self.assertEqual(f(), 2) 3757 3758 def test_invoke_str_method(self): 3759 codestr = """ 3760 def func(): 3761 a = 'a b c' 3762 return a.split() 3763 3764 """ 3765 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3766 f = mod["func"] 3767 self.assertInBytecode( 3768 f, "INVOKE_FUNCTION", (("builtins", "str", "split"), 1) 3769 ) 3770 self.assertEqual(f(), ["a", "b", "c"]) 3771 3772 def test_invoke_str_method_arg(self): 3773 codestr = """ 3774 def func(): 3775 a = 'a b c' 3776 return a.split('a') 3777 3778 """ 3779 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3780 f = mod["func"] 3781 self.assertInBytecode( 3782 f, "INVOKE_FUNCTION", (("builtins", "str", "split"), 2) 3783 ) 3784 self.assertEqual(f(), ["", " b c"]) 3785 3786 def test_invoke_str_method_kwarg(self): 3787 codestr = """ 3788 def func(): 3789 a = 'a b c' 3790 return a.split(sep='a') 3791 3792 """ 3793 with self.in_module(codestr, 
code_gen=StaticCodeGenerator) as mod: 3794 f = mod["func"] 3795 self.assertNotInBytecode(f, "INVOKE_FUNCTION") 3796 self.assertNotInBytecode(f, "INVOKE_METHOD") 3797 self.assertEqual(f(), ["", " b c"]) 3798 3799 def test_invoke_int_method(self): 3800 codestr = """ 3801 def func(): 3802 a = 42 3803 return a.bit_length() 3804 3805 """ 3806 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3807 f = mod["func"] 3808 self.assertInBytecode( 3809 f, "INVOKE_FUNCTION", (("builtins", "int", "bit_length"), 1) 3810 ) 3811 self.assertEqual(f(), 6) 3812 3813 def test_invoke_chkdict_method(self): 3814 codestr = """ 3815 from __static__ import CheckedDict 3816 def dict_maker() -> CheckedDict[int, int]: 3817 return CheckedDict[int, int]({2:2}) 3818 def func(): 3819 a = dict_maker() 3820 return a.keys() 3821 3822 """ 3823 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3824 f = mod["func"] 3825 3826 self.assertInBytecode( 3827 f, 3828 "INVOKE_METHOD", 3829 ( 3830 ( 3831 "__static__", 3832 "chkdict", 3833 (("builtins", "int"), ("builtins", "int")), 3834 "keys", 3835 ), 3836 0, 3837 ), 3838 ) 3839 self.assertEqual(list(f()), [2]) 3840 self.assert_jitted(f) 3841 3842 def test_invoke_method_non_static_base(self): 3843 codestr = """ 3844 class C(Exception): 3845 def f(self): 3846 return 42 3847 3848 def g(self): 3849 return self.f() 3850 """ 3851 3852 with self.in_module(codestr) as mod: 3853 C = mod["C"] 3854 self.assertEqual(C().g(), 42) 3855 3856 def test_invoke_builtin_func(self): 3857 codestr = """ 3858 from xxclassloader import foo 3859 from __static__ import int64, box 3860 3861 def func(): 3862 a: int64 = foo() 3863 return box(a) 3864 3865 """ 3866 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3867 f = mod["func"] 3868 self.assertInBytecode(f, "INVOKE_FUNCTION", ((("xxclassloader", "foo"), 0))) 3869 self.assertEqual(f(), 42) 3870 self.assert_jitted(f) 3871 3872 def test_invoke_builtin_func_ret_neg(self): 3873 # setup 
xxclassloader as a built-in function for this test, so we can 3874 # do a direct invoke 3875 xxclassloader = sys.modules["xxclassloader"] 3876 try: 3877 sys.modules["xxclassloader"] = StrictModule(xxclassloader.__dict__, False) 3878 codestr = """ 3879 from xxclassloader import neg 3880 from __static__ import int64, box 3881 3882 def test(): 3883 x: int64 = neg() 3884 return box(x) 3885 """ 3886 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3887 test = mod["test"] 3888 self.assertEqual(test(), -1) 3889 finally: 3890 sys.modules["xxclassloader"] = xxclassloader 3891 3892 def test_invoke_builtin_func_arg(self): 3893 codestr = """ 3894 from xxclassloader import bar 3895 from __static__ import int64, box 3896 3897 def func(): 3898 a: int64 = bar(42) 3899 return box(a) 3900 3901 """ 3902 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3903 f = mod["func"] 3904 self.assertInBytecode(f, "INVOKE_FUNCTION", ((("xxclassloader", "bar"), 1))) 3905 self.assertEqual(f(), 42) 3906 self.assert_jitted(f) 3907 3908 def test_invoke_meth_o(self): 3909 codestr = """ 3910 from xxclassloader import spamobj 3911 3912 def func(): 3913 a = spamobj[int]() 3914 a.setstate_untyped(42) 3915 return a.getstate() 3916 3917 """ 3918 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3919 f = mod["func"] 3920 3921 self.assertInBytecode( 3922 f, 3923 "INVOKE_METHOD", 3924 ( 3925 ( 3926 "xxclassloader", 3927 "spamobj", 3928 (("builtins", "int"),), 3929 "setstate_untyped", 3930 ), 3931 1, 3932 ), 3933 ) 3934 self.assertEqual(f(), 42) 3935 self.assert_jitted(f) 3936 3937 def test_multi_generic(self): 3938 codestr = """ 3939 from xxclassloader import XXGeneric 3940 3941 def func(): 3942 a = XXGeneric[int, str]() 3943 return a.foo(42, 'abc') 3944 """ 3945 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3946 f = mod["func"] 3947 self.assertEqual(f(), "42abc") 3948 3949 def test_verify_positional_args(self): 3950 codestr = """ 3951 
def x(a: int, b: str) -> None: 3952 pass 3953 x("a", 2) 3954 """ 3955 with self.assertRaisesRegex(TypedSyntaxError, "argument type mismatch"): 3956 self.compile(codestr, StaticCodeGenerator) 3957 3958 def test_verify_positional_args_unordered(self): 3959 codestr = """ 3960 def x(a: int, b: str) -> None: 3961 return y(a, b) 3962 def y(a: int, b: str) -> None: 3963 pass 3964 """ 3965 self.compile(codestr, StaticCodeGenerator) 3966 3967 def test_verify_kwargs(self): 3968 codestr = """ 3969 def x(a: int=1, b: str="hunter2") -> None: 3970 return 3971 x(b="lol", a=23) 3972 """ 3973 self.compile(codestr, StaticCodeGenerator) 3974 3975 def test_verify_kwdefaults(self): 3976 codestr = """ 3977 def x(*, b: str="hunter2"): 3978 return b 3979 z = x(b="lol") 3980 """ 3981 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3982 self.assertEqual(mod["z"], "lol") 3983 3984 def test_verify_kwdefaults_no_value(self): 3985 codestr = """ 3986 def x(*, b: str="hunter2"): 3987 return b 3988 a = x() 3989 """ 3990 module = self.compile(codestr, StaticCodeGenerator) 3991 # we don't yet support optimized dispatch to kw-only functions 3992 self.assertInBytecode(module, "CALL_FUNCTION") 3993 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 3994 self.assertEqual(mod["a"], "hunter2") 3995 3996 def test_verify_arg_dynamic_type(self): 3997 codestr = """ 3998 def x(v:str): 3999 return 'abc' 4000 def y(v): 4001 return x(v) 4002 """ 4003 module = self.compile(codestr, StaticCodeGenerator) 4004 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4005 y = mod["y"] 4006 with self.assertRaises(TypeError): 4007 y(42) 4008 self.assertEqual(y("foo"), "abc") 4009 4010 def test_verify_arg_unknown_type(self): 4011 codestr = """ 4012 def x(x:foo): 4013 return b 4014 x('abc') 4015 """ 4016 module = self.compile(codestr, StaticCodeGenerator) 4017 self.assertInBytecode(module, "INVOKE_FUNCTION") 4018 x = self.find_code(module) 4019 self.assertInBytecode(x, 
"CHECK_ARGS", ()) 4020 4021 def test_dict_invoke(self): 4022 codestr = """ 4023 from __static__ import pydict 4024 def f(x): 4025 y: pydict = x 4026 return y.get('foo') 4027 """ 4028 with self.in_module(codestr) as mod: 4029 f = mod["f"] 4030 self.assertInBytecode(f, "INVOKE_METHOD", (("builtins", "dict", "get"), 1)) 4031 self.assertEqual(f({}), None) 4032 4033 def test_dict_invoke_ret(self): 4034 codestr = """ 4035 from __static__ import pydict 4036 def g(): return None 4037 def f(x): 4038 y: pydict = x 4039 z = y.get('foo') 4040 z = None # should be typed to dynamic 4041 return z 4042 """ 4043 with self.in_module(codestr) as mod: 4044 f = mod["f"] 4045 self.assertInBytecode(f, "INVOKE_METHOD", (("builtins", "dict", "get"), 1)) 4046 self.assertEqual(f({}), None) 4047 4048 def test_verify_kwarg_unknown_type(self): 4049 codestr = """ 4050 def x(x:foo): 4051 return b 4052 x(x='abc') 4053 """ 4054 module = self.compile(codestr, StaticCodeGenerator) 4055 self.assertInBytecode(module, "INVOKE_FUNCTION") 4056 x = self.find_code(module) 4057 self.assertInBytecode(x, "CHECK_ARGS", ()) 4058 4059 def test_verify_kwdefaults_too_many(self): 4060 codestr = """ 4061 def x(*, b: str="hunter2") -> None: 4062 return 4063 x('abc') 4064 """ 4065 with self.assertRaisesRegex( 4066 TypedSyntaxError, "x takes 0 positional args but 1 was given" 4067 ): 4068 self.compile(codestr, StaticCodeGenerator) 4069 4070 def test_verify_kwdefaults_too_many_class(self): 4071 codestr = """ 4072 class C: 4073 def x(self, *, b: str="hunter2") -> None: 4074 return 4075 C().x('abc') 4076 """ 4077 with self.assertRaisesRegex( 4078 TypedSyntaxError, "x takes 1 positional args but 2 were given" 4079 ): 4080 self.compile(codestr, StaticCodeGenerator) 4081 4082 def test_verify_kwonly_failure(self): 4083 codestr = """ 4084 def x(*, a: int=1, b: str="hunter2") -> None: 4085 return 4086 x(a="hi", b="lol") 4087 """ 4088 with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"): 4089 
self.compile(codestr, StaticCodeGenerator) 4090 4091 def test_verify_kwonly_self_loaded_once(self): 4092 codestr = """ 4093 class C: 4094 def x(self, *, a: int=1) -> int: 4095 return 43 4096 4097 def f(): 4098 return C().x(a=1) 4099 """ 4100 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4101 f = mod["f"] 4102 io = StringIO() 4103 dis.dis(f, file=io) 4104 self.assertEqual(1, io.getvalue().count("LOAD_GLOBAL")) 4105 4106 def test_call_function_unknown_ret_type(self): 4107 codestr = """ 4108 from __future__ import annotations 4109 def g() -> foo: 4110 return 42 4111 4112 def testfunc(): 4113 return g() 4114 """ 4115 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4116 f = mod["testfunc"] 4117 self.assertEqual(f(), 42) 4118 4119 def test_verify_kwargs_failure(self): 4120 codestr = """ 4121 def x(a: int=1, b: str="hunter2") -> None: 4122 return 4123 x(a="hi", b="lol") 4124 """ 4125 with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"): 4126 self.compile(codestr, StaticCodeGenerator) 4127 4128 def test_verify_mixed_args(self): 4129 codestr = """ 4130 def x(a: int=1, b: str="hunter2", c: int=14) -> None: 4131 return 4132 x(12, c=56, b="lol") 4133 """ 4134 self.compile(codestr, StaticCodeGenerator) 4135 4136 def test_kwarg_cast(self): 4137 codestr = """ 4138 def x(a: int=1, b: str="hunter2", c: int=14) -> None: 4139 return 4140 4141 def g(a): 4142 x(b=a) 4143 """ 4144 code = self.find_code(self.compile(codestr, StaticCodeGenerator), "g") 4145 self.assertInBytecode(code, "CAST", ("builtins", "str")) 4146 4147 def test_kwarg_nocast(self): 4148 codestr = """ 4149 def x(a: int=1, b: str="hunter2", c: int=14) -> None: 4150 return 4151 4152 def g(): 4153 x(b='abc') 4154 """ 4155 code = self.find_code(self.compile(codestr, StaticCodeGenerator), "g") 4156 self.assertNotInBytecode(code, "CAST", ("builtins", "str")) 4157 4158 def test_verify_mixed_args_kw_failure(self): 4159 codestr = """ 4160 def x(a: int=1, b: 
str="hunter2", c: int=14) -> None: 4161 return 4162 x(12, c="hi", b="lol") 4163 """ 4164 with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"): 4165 self.compile(codestr, StaticCodeGenerator) 4166 4167 def test_verify_mixed_args_positional_failure(self): 4168 codestr = """ 4169 def x(a: int=1, b: str="hunter2", c: int=14) -> None: 4170 return 4171 x("hi", b="lol") 4172 """ 4173 with self.assertRaisesRegex( 4174 TypedSyntaxError, "positional argument type mismatch" 4175 ): 4176 self.compile(codestr, StaticCodeGenerator) 4177 4178 # Same tests as above, but for methods. 4179 def test_verify_positional_args_method(self): 4180 codestr = """ 4181 class C: 4182 def x(self, a: int, b: str) -> None: 4183 pass 4184 C().x(2, "hi") 4185 """ 4186 self.compile(codestr, StaticCodeGenerator) 4187 4188 def test_verify_positional_args_failure_method(self): 4189 codestr = """ 4190 class C: 4191 def x(self, a: int, b: str) -> None: 4192 pass 4193 C().x("a", 2) 4194 """ 4195 with self.assertRaisesRegex( 4196 TypedSyntaxError, "positional argument type mismatch" 4197 ): 4198 self.compile(codestr, StaticCodeGenerator) 4199 4200 def test_verify_mixed_args_method(self): 4201 codestr = """ 4202 class C: 4203 def x(self, a: int=1, b: str="hunter2", c: int=14) -> None: 4204 return 4205 C().x(12, c=56, b="lol") 4206 """ 4207 self.compile(codestr, StaticCodeGenerator) 4208 4209 def test_starargs_invoked_once(self): 4210 codestr = """ 4211 X = 0 4212 4213 def f(): 4214 global X 4215 X += 1 4216 return {"a": 1, "b": "foo", "c": 42} 4217 4218 class C: 4219 def x(self, a: int=1, b: str="hunter2", c: int=14) -> None: 4220 return 4221 C().x(12, **f()) 4222 """ 4223 with self.in_module(codestr) as mod: 4224 x = mod["X"] 4225 self.assertEqual(x, 1) 4226 4227 def test_starargs_invoked_in_order(self): 4228 codestr = """ 4229 X = 1 4230 4231 def f(): 4232 global X 4233 X += 1 4234 return {"a": 1, "b": "foo"} 4235 4236 def make_c(): 4237 global X 4238 X *= 2 4239 return 42 4240 
4241 class C: 4242 def x(self, a: int=1, b: str="hunter2", c: int=14) -> None: 4243 return 4244 4245 def test(): 4246 C().x(12, c=make_c(), **f()) 4247 """ 4248 with self.in_module(codestr) as mod: 4249 test = mod["test"] 4250 test() 4251 x = mod["X"] 4252 self.assertEqual(x, 3) 4253 4254 def test_verify_mixed_args_kw_failure_method(self): 4255 codestr = """ 4256 class C: 4257 def x(self, a: int=1, b: str="hunter2", c: int=14) -> None: 4258 return 4259 C().x(12, c=b'lol', b="lol") 4260 """ 4261 with self.assertRaisesRegex(TypedSyntaxError, "keyword argument type mismatch"): 4262 self.compile(codestr, StaticCodeGenerator) 4263 4264 def test_verify_mixed_args_positional_failure_method(self): 4265 codestr = """ 4266 class C: 4267 def x(self, a: int=1, b: str="hunter2", c: int=14) -> None: 4268 return 4269 C().x("hi", b="lol") 4270 """ 4271 with self.assertRaisesRegex( 4272 TypedSyntaxError, "positional argument type mismatch" 4273 ): 4274 self.compile(codestr, StaticCodeGenerator) 4275 4276 def test_generic_kwargs_unsupported(self): 4277 # definition is allowed, we just don't do an optimal invoke 4278 codestr = """ 4279 def f(a: int, b: str, **my_stuff) -> None: 4280 pass 4281 4282 def g(): 4283 return f(1, 'abc', x="y") 4284 """ 4285 with self.in_module(codestr) as mod: 4286 g = mod["g"] 4287 self.assertInBytecode(g, "CALL_FUNCTION_KW", 3) 4288 4289 def test_generic_kwargs_method_unsupported(self): 4290 # definition is allowed, we just don't do an optimal invoke 4291 codestr = """ 4292 class C: 4293 def f(self, a: int, b: str, **my_stuff) -> None: 4294 pass 4295 4296 def g(): 4297 return C().f(1, 'abc', x="y") 4298 """ 4299 with self.in_module(codestr) as mod: 4300 g = mod["g"] 4301 self.assertInBytecode(g, "CALL_FUNCTION_KW", 3) 4302 4303 def test_generic_varargs_unsupported(self): 4304 # definition is allowed, we just don't do an optimal invoke 4305 codestr = """ 4306 def f(a: int, b: str, *my_stuff) -> None: 4307 pass 4308 4309 def g(): 4310 return f(1, 'abc', 
"foo") 4311 """ 4312 with self.in_module(codestr) as mod: 4313 g = mod["g"] 4314 self.assertInBytecode(g, "CALL_FUNCTION", 3) 4315 4316 def test_generic_varargs_method_unsupported(self): 4317 # definition is allowed, we just don't do an optimal invoke 4318 codestr = """ 4319 class C: 4320 def f(self, a: int, b: str, *my_stuff) -> None: 4321 pass 4322 4323 def g(): 4324 return C().f(1, 'abc', "foo") 4325 """ 4326 with self.in_module(codestr) as mod: 4327 g = mod["g"] 4328 self.assertInBytecode(g, "CALL_METHOD", 3) 4329 4330 def test_kwargs_get(self): 4331 codestr = """ 4332 def test(**foo): 4333 print(foo.get('bar')) 4334 """ 4335 4336 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4337 test = mod["test"] 4338 self.assertInBytecode( 4339 test, "INVOKE_FUNCTION", (("builtins", "dict", "get"), 2) 4340 ) 4341 4342 def test_varargs_count(self): 4343 codestr = """ 4344 def test(*foo): 4345 print(foo.count('bar')) 4346 """ 4347 4348 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4349 test = mod["test"] 4350 self.assertInBytecode( 4351 test, "INVOKE_FUNCTION", (("builtins", "tuple", "count"), 2) 4352 ) 4353 4354 def test_varargs_call(self): 4355 codestr = """ 4356 def g(*foo): 4357 return foo 4358 4359 def testfunc(): 4360 return g(2) 4361 """ 4362 4363 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4364 test = mod["testfunc"] 4365 self.assertEqual(test(), (2,)) 4366 4367 def test_kwargs_call(self): 4368 codestr = """ 4369 def g(**foo): 4370 return foo 4371 4372 def testfunc(): 4373 return g(x=2) 4374 """ 4375 4376 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4377 test = mod["testfunc"] 4378 self.assertEqual(test(), {"x": 2}) 4379 4380 def test_index_by_int(self): 4381 codestr = """ 4382 from __static__ import int32 4383 def f(x): 4384 i: int32 = 0 4385 return x[i] 4386 """ 4387 with self.assertRaises(TypedSyntaxError): 4388 self.compile(codestr, StaticCodeGenerator) 4389 4390 def 
test_pydict_arg_annotation(self): 4391 codestr = """ 4392 from __static__ import PyDict 4393 def f(d: PyDict[str, int]) -> str: 4394 # static python ignores the untrusted generic types 4395 return d[3] 4396 """ 4397 with self.in_module(codestr) as mod: 4398 self.assertEqual(mod["f"]({3: "foo"}), "foo") 4399 4400 def test_list_get_primitive_int(self): 4401 codestr = """ 4402 from __static__ import int8 4403 def f(): 4404 l = [1, 2, 3] 4405 x: int8 = 1 4406 return l[x] 4407 """ 4408 f = self.find_code(self.compile(codestr)) 4409 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST) 4410 with self.in_module(codestr) as mod: 4411 self.assertEqual(mod["f"](), 2) 4412 4413 def test_list_set_primitive_int(self): 4414 codestr = """ 4415 from __static__ import int8 4416 def f(): 4417 l = [1, 2, 3] 4418 x: int8 = 1 4419 l[x] = 5 4420 return l 4421 """ 4422 f = self.find_code(self.compile(codestr)) 4423 self.assertInBytecode(f, "SEQUENCE_SET") 4424 with self.in_module(codestr) as mod: 4425 self.assertEqual(mod["f"](), [1, 5, 3]) 4426 4427 def test_list_set_primitive_int_2(self): 4428 codestr = """ 4429 from __static__ import int64 4430 def f(l1): 4431 l2 = [None] * len(l1) 4432 i: int64 = 0 4433 for item in l1: 4434 l2[i] = item + 1 4435 i += 1 4436 return l2 4437 """ 4438 f = self.find_code(self.compile(codestr)) 4439 self.assertInBytecode(f, "SEQUENCE_SET") 4440 with self.in_module(codestr) as mod: 4441 self.assertEqual(mod["f"]([1, 2000]), [2, 2001]) 4442 4443 def test_list_del_primitive_int(self): 4444 codestr = """ 4445 from __static__ import int8 4446 def f(): 4447 l = [1, 2, 3] 4448 x: int8 = 1 4449 del l[x] 4450 return l 4451 """ 4452 f = self.find_code(self.compile(codestr)) 4453 self.assertInBytecode(f, "LIST_DEL") 4454 with self.in_module(codestr) as mod: 4455 self.assertEqual(mod["f"](), [1, 3]) 4456 4457 def test_list_append(self): 4458 codestr = """ 4459 from __static__ import int8 4460 def f(): 4461 l = [1, 2, 3] 4462 l.append(4) 4463 return l 4464 """ 4465 f = 
self.find_code(self.compile(codestr)) 4466 self.assertInBytecode(f, "LIST_APPEND", 1) 4467 with self.in_module(codestr) as mod: 4468 self.assertEqual(mod["f"](), [1, 2, 3, 4]) 4469 4470 def test_unknown_type_unary(self): 4471 codestr = """ 4472 def x(y): 4473 z = -y 4474 """ 4475 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4476 self.assertInBytecode(f, "UNARY_NEGATIVE") 4477 4478 def test_unknown_type_binary(self): 4479 codestr = """ 4480 def x(a, b): 4481 z = a + b 4482 """ 4483 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4484 self.assertInBytecode(f, "BINARY_ADD") 4485 4486 def test_unknown_type_compare(self): 4487 codestr = """ 4488 def x(a, b): 4489 z = a > b 4490 """ 4491 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4492 self.assertInBytecode(f, "COMPARE_OP") 4493 4494 def test_async_func_ret_type(self): 4495 codestr = """ 4496 async def x(a) -> int: 4497 return a 4498 """ 4499 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4500 self.assertInBytecode(f, "CAST") 4501 4502 def test_async_func_arg_types(self): 4503 codestr = """ 4504 async def f(x: int): 4505 pass 4506 """ 4507 f = self.find_code(self.compile(codestr)) 4508 self.assertInBytecode(f, "CHECK_ARGS", (0, ("builtins", "int"))) 4509 4510 def test_assign_prim_to_class(self): 4511 codestr = """ 4512 from __static__ import int64 4513 class C: pass 4514 4515 def f(): 4516 x: C = C() 4517 y: int64 = 42 4518 x = y 4519 """ 4520 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("int64", "foo.C")): 4521 self.compile(codestr, StaticCodeGenerator, modname="foo") 4522 4523 def test_field_refcount(self): 4524 codestr = """ 4525 class C: 4526 def __init__(self): 4527 self.x = None 4528 4529 def set_x(self, x): 4530 self.x = x 4531 """ 4532 count = 0 4533 with self.in_module(codestr) as mod: 4534 C = mod["C"] 4535 4536 class X: 4537 def __init__(self): 4538 nonlocal count 4539 
count += 1 4540 4541 def __del__(self): 4542 nonlocal count 4543 count -= 1 4544 4545 c = C() 4546 c.set_x(X()) 4547 c.set_x(X()) 4548 self.assertEqual(count, 1) 4549 del c 4550 self.assertEqual(count, 0) 4551 4552 def test_typed_field_del(self): 4553 codestr = """ 4554 class D: 4555 def __init__(self, counter): 4556 self.counter = counter 4557 self.counter[0] += 1 4558 4559 def __del__(self): 4560 self.counter[0] -= 1 4561 4562 class C: 4563 def __init__(self, value: D): 4564 self.x: D = value 4565 4566 def __del__(self): 4567 del self.x 4568 """ 4569 count = 0 4570 with self.in_module(codestr) as mod: 4571 D = mod["D"] 4572 counter = [0] 4573 d = D(counter) 4574 4575 C = mod["C"] 4576 a = C(d) 4577 del d 4578 self.assertEqual(counter[0], 1) 4579 del a 4580 self.assertEqual(counter[0], 0) 4581 4582 def test_typed_field_deleted_attr(self): 4583 codestr = """ 4584 class C: 4585 def __init__(self, value: str): 4586 self.x: str = value 4587 """ 4588 count = 0 4589 with self.in_module(codestr) as mod: 4590 C = mod["C"] 4591 a = C("abc") 4592 del a.x 4593 with self.assertRaises(AttributeError): 4594 a.x 4595 4596 def test_generic_method_ret_type(self): 4597 codestr = """ 4598 from __static__ import CheckedDict 4599 4600 from typing import Optional 4601 MAP: CheckedDict[str, Optional[str]] = CheckedDict[str, Optional[str]]({'abc': 'foo', 'bar': None}) 4602 def f(x: str) -> Optional[str]: 4603 return MAP.get(x) 4604 """ 4605 4606 with self.in_module(codestr) as mod: 4607 f = mod["f"] 4608 self.assertInBytecode( 4609 f, 4610 "INVOKE_METHOD", 4611 ( 4612 ( 4613 "__static__", 4614 "chkdict", 4615 (("builtins", "str"), ("builtins", "str", "?")), 4616 "get", 4617 ), 4618 2, 4619 ), 4620 ) 4621 self.assertEqual(f("abc"), "foo") 4622 self.assertEqual(f("bar"), None) 4623 4624 4625 def test_attr_generic_optional(self): 4626 codestr = """ 4627 from typing import Optional 4628 def f(x: Optional): 4629 return x.foo 4630 """ 4631 4632 with self.assertRaisesRegex( 4633 
TypedSyntaxError, "cannot access attribute from unbound Union" 4634 ): 4635 self.compile(codestr, StaticCodeGenerator, modname="foo") 4636 4637 def test_assign_generic_optional(self): 4638 codestr = """ 4639 from typing import Optional 4640 def f(): 4641 x: Optional = 42 4642 """ 4643 4644 with self.assertRaisesRegex( 4645 TypedSyntaxError, type_mismatch("Exact[int]", "Optional[T]") 4646 ): 4647 self.compile(codestr, StaticCodeGenerator, modname="foo") 4648 4649 def test_assign_generic_optional_2(self): 4650 codestr = """ 4651 from typing import Optional 4652 def f(): 4653 x: Optional = 42 + 1 4654 """ 4655 4656 with self.assertRaises(TypedSyntaxError): 4657 self.compile(codestr, StaticCodeGenerator, modname="foo") 4658 4659 def test_assign_from_generic_optional(self): 4660 codestr = """ 4661 from typing import Optional 4662 class C: pass 4663 4664 def f(x: Optional): 4665 y: Optional[C] = x 4666 """ 4667 4668 with self.assertRaisesRegex( 4669 TypedSyntaxError, type_mismatch("Optional[T]", optional("foo.C")) 4670 ): 4671 self.compile(codestr, StaticCodeGenerator, modname="foo") 4672 4673 def test_list_of_dynamic(self): 4674 codestr = """ 4675 from threading import Thread 4676 from typing import List 4677 4678 def f(threads: List[Thread]) -> int: 4679 return len(threads) 4680 """ 4681 f = self.find_code(self.compile(codestr), "f") 4682 self.assertInBytecode(f, "FAST_LEN") 4683 4684 def test_int_swap(self): 4685 codestr = """ 4686 from __static__ import int64, box 4687 4688 def test(): 4689 x: int64 = 42 4690 y: int64 = 100 4691 x, y = y, x 4692 return box(x), box(y) 4693 """ 4694 4695 with self.assertRaisesRegex( 4696 TypedSyntaxError, type_mismatch("int64", "dynamic") 4697 ): 4698 self.compile(codestr, StaticCodeGenerator, modname="foo") 4699 4700 def test_typed_swap(self): 4701 codestr = """ 4702 def test(a): 4703 x: int 4704 y: str 4705 x, y = 1, a 4706 """ 4707 4708 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4709 
self.assertInBytecode(f, "CAST", ("builtins", "str")) 4710 self.assertNotInBytecode(f, "CAST", ("builtins", "int")) 4711 4712 def test_typed_swap_2(self): 4713 codestr = """ 4714 def test(a): 4715 x: int 4716 y: str 4717 x, y = a, 'abc' 4718 4719 """ 4720 4721 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4722 self.assertInBytecode(f, "CAST", ("builtins", "int")) 4723 self.assertNotInBytecode(f, "CAST", ("builtins", "str")) 4724 4725 def test_typed_swap_member(self): 4726 codestr = """ 4727 class C: 4728 def __init__(self): 4729 self.x: int = 42 4730 4731 def test(a): 4732 x: int 4733 y: str 4734 C().x, y = a, 'abc' 4735 4736 """ 4737 4738 f = self.find_code( 4739 self.compile(codestr, StaticCodeGenerator, modname="foo"), "test" 4740 ) 4741 self.assertInBytecode(f, "CAST", ("builtins", "int")) 4742 self.assertNotInBytecode(f, "CAST", ("builtins", "str")) 4743 4744 def test_typed_swap_list(self): 4745 codestr = """ 4746 def test(a): 4747 x: int 4748 y: str 4749 [x, y] = a, 'abc' 4750 """ 4751 4752 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4753 self.assertInBytecode(f, "CAST", ("builtins", "int")) 4754 self.assertNotInBytecode(f, "CAST", ("builtins", "str")) 4755 4756 def test_typed_swap_nested(self): 4757 codestr = """ 4758 def test(a): 4759 x: int 4760 y: str 4761 z: str 4762 ((x, y), z) = (a, 'abc'), 'foo' 4763 """ 4764 4765 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4766 self.assertInBytecode(f, "CAST", ("builtins", "int")) 4767 self.assertNotInBytecode(f, "CAST", ("builtins", "str")) 4768 4769 def test_typed_swap_nested_2(self): 4770 codestr = """ 4771 def test(a): 4772 x: int 4773 y: str 4774 z: str 4775 ((x, y), z) = (1, a), 'foo' 4776 4777 """ 4778 4779 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4780 self.assertInBytecode(f, "CAST", ("builtins", "str")) 4781 self.assertNotInBytecode(f, "CAST", ("builtins", "int")) 4782 4783 
def test_typed_swap_nested_3(self): 4784 codestr = """ 4785 def test(a): 4786 x: int 4787 y: int 4788 z: str 4789 ((x, y), z) = (1, 2), a 4790 4791 """ 4792 4793 f = self.find_code(self.compile(codestr, StaticCodeGenerator, modname="foo")) 4794 self.assertInBytecode(f, "CAST", ("builtins", "str")) 4795 # Currently because the tuple gets turned into a constant this is less than 4796 # ideal: 4797 self.assertInBytecode(f, "CAST", ("builtins", "int")) 4798 4799 def test_if_optional(self): 4800 codestr = """ 4801 from typing import Optional 4802 class C: 4803 def __init__(self): 4804 self.field = 42 4805 4806 def f(x: Optional[C]): 4807 if x is not None: 4808 return x.field 4809 4810 return None 4811 """ 4812 4813 self.compile(codestr, StaticCodeGenerator, modname="foo") 4814 4815 def test_return_outside_func(self): 4816 codestr = """ 4817 return 42 4818 """ 4819 with self.assertRaisesRegex(SyntaxError, "'return' outside function"): 4820 self.compile(codestr, StaticCodeGenerator, modname="foo") 4821 4822 def test_double_decl(self): 4823 codestr = """ 4824 def f(): 4825 x: int 4826 x: str 4827 """ 4828 with self.assertRaisesRegex( 4829 TypedSyntaxError, "Cannot redefine local variable x" 4830 ): 4831 self.compile(codestr, StaticCodeGenerator, modname="foo") 4832 4833 codestr = """ 4834 def f(): 4835 x = 42 4836 x: str 4837 """ 4838 with self.assertRaisesRegex( 4839 TypedSyntaxError, "Cannot redefine local variable x" 4840 ): 4841 self.compile(codestr, StaticCodeGenerator, modname="foo") 4842 4843 codestr = """ 4844 def f(): 4845 x = 42 4846 x: int 4847 """ 4848 with self.assertRaisesRegex( 4849 TypedSyntaxError, "Cannot redefine local variable x" 4850 ): 4851 self.compile(codestr, StaticCodeGenerator, modname="foo") 4852 4853 def test_if_else_optional(self): 4854 codestr = """ 4855 from typing import Optional 4856 class C: 4857 def __init__(self): 4858 self.field = self 4859 4860 def g(x: C): 4861 pass 4862 4863 def f(x: Optional[C], y: Optional[C]): 4864 if x is None: 
4865 x = y 4866 if x is None: 4867 return None 4868 else: 4869 return g(x) 4870 else: 4871 return g(x) 4872 4873 4874 return None 4875 """ 4876 4877 self.compile(codestr, StaticCodeGenerator, modname="foo") 4878 4879 def test_if_else_optional_return(self): 4880 codestr = """ 4881 from typing import Optional 4882 class C: 4883 def __init__(self): 4884 self.field = self 4885 4886 def f(x: Optional[C]): 4887 if x is None: 4888 return 0 4889 return x.field 4890 """ 4891 4892 self.compile(codestr, StaticCodeGenerator, modname="foo") 4893 4894 def test_if_else_optional_return_two_branches(self): 4895 codestr = """ 4896 from typing import Optional 4897 class C: 4898 def __init__(self): 4899 self.field = self 4900 4901 def f(x: Optional[C]): 4902 if x is None: 4903 if a: 4904 return 0 4905 else: 4906 return 2 4907 return x.field 4908 """ 4909 4910 self.compile(codestr, StaticCodeGenerator, modname="foo") 4911 4912 def test_if_else_optional_return_in_else(self): 4913 codestr = """ 4914 from typing import Optional 4915 4916 def f(x: Optional[int]) -> int: 4917 if x is not None: 4918 pass 4919 else: 4920 return 2 4921 return x 4922 """ 4923 4924 self.compile(codestr, StaticCodeGenerator, modname="foo") 4925 4926 def test_if_else_optional_return_in_else_assignment_in_if(self): 4927 codestr = """ 4928 from typing import Optional 4929 4930 def f(x: Optional[int]) -> int: 4931 if x is None: 4932 x = 1 4933 else: 4934 return 2 4935 return x 4936 """ 4937 4938 self.compile(codestr, StaticCodeGenerator, modname="foo") 4939 4940 def test_if_else_optional_return_in_if_assignment_in_else(self): 4941 codestr = """ 4942 from typing import Optional 4943 4944 def f(x: Optional[int]) -> int: 4945 if x is not None: 4946 return 2 4947 else: 4948 x = 1 4949 return x 4950 """ 4951 4952 self.compile(codestr, StaticCodeGenerator, modname="foo") 4953 4954 def test_narrow_conditional(self): 4955 codestr = """ 4956 class B: 4957 def f(self): 4958 return 42 4959 class D(B): 4960 def f(self): 4961 
return 'abc' 4962 4963 def testfunc(abc): 4964 x = B() 4965 if abc: 4966 x = D() 4967 return x.f() 4968 """ 4969 4970 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 4971 f = self.find_code(code, "testfunc") 4972 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0)) 4973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4974 test = mod["testfunc"] 4975 self.assertEqual(test(True), "abc") 4976 self.assertEqual(test(False), None) 4977 4978 def test_no_narrow_to_dynamic(self): 4979 codestr = """ 4980 def f(): 4981 return 42 4982 4983 def g(): 4984 x: int = 100 4985 x = f() 4986 return x.bit_length() 4987 """ 4988 4989 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 4990 g = mod["g"] 4991 self.assertInBytecode(g, "CAST", ("builtins", "int")) 4992 self.assertInBytecode( 4993 g, "INVOKE_METHOD", (("builtins", "int", "bit_length"), 0) 4994 ) 4995 self.assertEqual(g(), 6) 4996 4997 def test_global_uses_decl_type(self): 4998 codestr = """ 4999 # even though we can locally infer G must be None, 5000 # it's not Final so nested scopes can't assume it 5001 # remains None 5002 G: int | None = None 5003 5004 def f() -> int: 5005 global G 5006 # if we use the local_type for G's type, 5007 # x would have a local type of None 5008 x: int | None = G 5009 if x is None: 5010 x = G = 1 5011 return x 5012 """ 5013 with self.in_strict_module(codestr) as mod: 5014 self.assertEqual(mod.f(), 1) 5015 5016 def test_module_level_type_narrow(self): 5017 codestr = """ 5018 def a() -> int | None: 5019 return 1 5020 5021 G = a() 5022 if G is not None: 5023 G += 1 5024 5025 def f() -> int: 5026 if G is None: 5027 return 0 5028 reveal_type(G) 5029 """ 5030 with self.assertRaisesRegex(TypedSyntaxError, r"Optional\[int\]"): 5031 self.compile(codestr) 5032 5033 def test_narrow_conditional_widened(self): 5034 codestr = """ 5035 class B: 5036 def f(self): 5037 return 42 5038 class D(B): 5039 def f(self): 5040 return 'abc' 5041 5042 def 
testfunc(abc): 5043 x = B() 5044 if abc: 5045 x = D() 5046 return x.f() 5047 """ 5048 5049 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5050 f = self.find_code(code, "testfunc") 5051 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5052 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5053 test = mod["testfunc"] 5054 self.assertEqual(test(True), "abc") 5055 self.assertEqual(test(False), 42) 5056 5057 def test_widen_to_dynamic(self): 5058 self.assertReturns( 5059 """ 5060 def f(x, flag): 5061 if flag: 5062 x = 3 5063 return x 5064 """, 5065 "dynamic", 5066 ) 5067 5068 def test_assign_conditional_both_sides(self): 5069 codestr = """ 5070 class B: 5071 def f(self): 5072 return 42 5073 class D(B): 5074 def f(self): 5075 return 'abc' 5076 5077 def testfunc(abc): 5078 x = B() 5079 if abc: 5080 x = D() 5081 else: 5082 x = D() 5083 return x.f() 5084 """ 5085 5086 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5087 f = self.find_code(code, "testfunc") 5088 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0)) 5089 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5090 test = mod["testfunc"] 5091 self.assertEqual(test(True), "abc") 5092 5093 def test_assign_conditional_invoke_in_else(self): 5094 codestr = """ 5095 class B: 5096 def f(self): 5097 return 42 5098 class D(B): 5099 def f(self): 5100 return 'abc' 5101 5102 def testfunc(abc): 5103 x = B() 5104 if abc: 5105 x = D() 5106 else: 5107 return x.f() 5108 """ 5109 5110 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5111 f = self.find_code(code, "testfunc") 5112 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5113 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5114 test = mod["testfunc"] 5115 self.assertEqual(test(True), None) 5116 5117 def test_assign_else_only(self): 5118 codestr = """ 5119 class B: 5120 def f(self): 5121 return 42 5122 class D(B): 5123 def f(self): 
5124 return 'abc' 5125 5126 def testfunc(abc): 5127 if abc: 5128 pass 5129 else: 5130 x = B() 5131 return x.f() 5132 """ 5133 5134 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5135 f = self.find_code(code, "testfunc") 5136 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5137 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5138 test = mod["testfunc"] 5139 self.assertEqual(test(False), 42) 5140 5141 def test_assign_test_var(self): 5142 codestr = """ 5143 from typing import Optional 5144 5145 def f(x: Optional[int]) -> int: 5146 if x is None: 5147 x = 1 5148 return x 5149 """ 5150 5151 self.compile(codestr, StaticCodeGenerator, modname="foo") 5152 5153 def test_assign_while(self): 5154 codestr = """ 5155 class B: 5156 def f(self): 5157 return 42 5158 class D(B): 5159 def f(self): 5160 return 'abc' 5161 5162 def testfunc(abc): 5163 x = B() 5164 while abc: 5165 x = D() 5166 return x.f() 5167 """ 5168 5169 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5170 f = self.find_code(code, "testfunc") 5171 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5172 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5173 test = mod["testfunc"] 5174 self.assertEqual(test(False), 42) 5175 5176 def test_assign_while_test_var(self): 5177 codestr = """ 5178 from typing import Optional 5179 5180 def f(x: Optional[int]) -> int: 5181 while x is None: 5182 x = 1 5183 return x 5184 """ 5185 self.compile(codestr, StaticCodeGenerator, modname="foo") 5186 5187 def test_assign_while_returns(self): 5188 codestr = """ 5189 from typing import Optional 5190 5191 def f(x: Optional[int]) -> int: 5192 while x is None: 5193 return 1 5194 return x 5195 """ 5196 self.compile(codestr, StaticCodeGenerator, modname="foo") 5197 5198 def test_assign_while_returns_but_assigns_first(self): 5199 codestr = """ 5200 from typing import Optional 5201 5202 def f(x: Optional[int]) -> int: 5203 y: Optional[int] = 1 
5204 while x is None: 5205 y = None 5206 return 1 5207 return y 5208 """ 5209 self.compile(codestr, StaticCodeGenerator, modname="foo") 5210 5211 def test_while_else_reverses_condition(self): 5212 codestr = """ 5213 from typing import Optional 5214 5215 def f(x: Optional[int]) -> int: 5216 while x is None: 5217 pass 5218 else: 5219 return x 5220 return 1 5221 """ 5222 self.compile(codestr, StaticCodeGenerator, modname="foo") 5223 5224 def test_continue_condition(self): 5225 codestr = """ 5226 from typing import Optional 5227 5228 def f(x: Optional[str]) -> str: 5229 while True: 5230 if x is None: 5231 continue 5232 return x 5233 """ 5234 self.compile(codestr, StaticCodeGenerator, modname="foo") 5235 5236 def test_break_condition(self): 5237 codestr = """ 5238 from typing import Optional 5239 5240 def f(x: Optional[str]) -> str: 5241 while True: 5242 if x is None: 5243 break 5244 return x 5245 """ 5246 self.compile(codestr, StaticCodeGenerator, modname="foo") 5247 5248 def test_assign_but_annotated(self): 5249 codestr = """ 5250 class B: 5251 def f(self): 5252 return 42 5253 class D(B): 5254 def f(self): 5255 return 'abc' 5256 5257 def testfunc(abc): 5258 x: B = D() 5259 return x.f() 5260 """ 5261 5262 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5263 f = self.find_code(code, "testfunc") 5264 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "D", "f"), 0)) 5265 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5266 test = mod["testfunc"] 5267 self.assertEqual(test(False), "abc") 5268 5269 def test_assign_while_2(self): 5270 codestr = """ 5271 class B: 5272 def f(self): 5273 return 42 5274 class D(B): 5275 def f(self): 5276 return 'abc' 5277 5278 def testfunc(abc): 5279 x: B = D() 5280 while abc: 5281 x = B() 5282 return x.f() 5283 """ 5284 5285 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5286 f = self.find_code(code, "testfunc") 5287 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5288 with 
self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5289 test = mod["testfunc"] 5290 self.assertEqual(test(False), "abc") 5291 5292 def test_assign_while_else(self): 5293 codestr = """ 5294 class B: 5295 def f(self): 5296 return 42 5297 class D(B): 5298 def f(self): 5299 return 'abc' 5300 5301 def testfunc(abc): 5302 x = B() 5303 while abc: 5304 pass 5305 else: 5306 x = D() 5307 return x.f() 5308 """ 5309 5310 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5311 f = self.find_code(code, "testfunc") 5312 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5313 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5314 test = mod["testfunc"] 5315 self.assertEqual(test(False), "abc") 5316 5317 def test_assign_while_else_2(self): 5318 codestr = """ 5319 class B: 5320 def f(self): 5321 return 42 5322 class D(B): 5323 def f(self): 5324 return 'abc' 5325 5326 def testfunc(abc): 5327 x: B = D() 5328 while abc: 5329 pass 5330 else: 5331 x = B() 5332 return x.f() 5333 """ 5334 5335 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5336 f = self.find_code(code, "testfunc") 5337 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5338 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5339 test = mod["testfunc"] 5340 self.assertEqual(test(False), 42) 5341 5342 def test_assign_try_except_no_initial(self): 5343 codestr = """ 5344 class B: 5345 def f(self): 5346 return 42 5347 class D(B): 5348 def f(self): 5349 return 'abc' 5350 5351 def testfunc(): 5352 try: 5353 x: B = D() 5354 except: 5355 x = B() 5356 return x.f() 5357 """ 5358 5359 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5360 f = self.find_code(code, "testfunc") 5361 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5362 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5363 test = mod["testfunc"] 5364 self.assertEqual(test(), "abc") 5365 5366 def test_narrow_or(self): 
5367 codestr = """ 5368 def f(x: int | None) -> int: 5369 if x is None or x > 1: 5370 x = 1 5371 return x 5372 """ 5373 self.compile(codestr) 5374 5375 def test_type_of_or(self): 5376 codestr = """ 5377 def f(x: int, y: str) -> int | str: 5378 return x or y 5379 """ 5380 self.compile(codestr) 5381 5382 def test_none_annotation(self): 5383 codestr = """ 5384 from typing import Optional 5385 5386 def f(x: Optional[int]) -> None: 5387 return x 5388 """ 5389 with self.assertRaisesRegex( 5390 TypedSyntaxError, 5391 type_mismatch("Optional[int]", "None"), 5392 ): 5393 self.compile(codestr, StaticCodeGenerator, modname="foo") 5394 5395 def test_none_compare(self): 5396 codestr = """ 5397 def f(x: int | None): 5398 if x > 1: 5399 x = 1 5400 return x 5401 """ 5402 with self.assertRaisesRegex( 5403 TypedSyntaxError, 5404 r"'>' not supported between 'Optional\[int\]' and 'Exact\[int\]'", 5405 ): 5406 self.compile(codestr) 5407 5408 def test_none_compare_reverse(self): 5409 codestr = """ 5410 def f(x: int | None): 5411 if 1 > x: 5412 x = 1 5413 return x 5414 """ 5415 with self.assertRaisesRegex( 5416 TypedSyntaxError, 5417 r"'>' not supported between 'Exact\[int\]' and 'Optional\[int\]'", 5418 ): 5419 self.compile(codestr) 5420 5421 def test_union_compare(self): 5422 codestr = """ 5423 def f(x: int | float) -> bool: 5424 return x > 0 5425 """ 5426 with self.in_strict_module(codestr) as mod: 5427 self.assertEqual(mod.f(3), True) 5428 self.assertEqual(mod.f(3.1), True) 5429 self.assertEqual(mod.f(-3), False) 5430 self.assertEqual(mod.f(-3.1), False) 5431 5432 def test_global_int(self): 5433 codestr = """ 5434 X: int = 60 * 60 * 24 5435 """ 5436 5437 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5438 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5439 X = mod["X"] 5440 self.assertEqual(X, 60 * 60 * 24) 5441 5442 def test_with_traceback(self): 5443 codestr = """ 5444 def f(): 5445 x = Exception() 5446 return x.with_traceback(None) 5447 """ 
5448 5449 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5450 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5451 f = mod["f"] 5452 self.assertEqual(type(f()), Exception) 5453 self.assertInBytecode( 5454 f, "INVOKE_METHOD", (("builtins", "BaseException", "with_traceback"), 1) 5455 ) 5456 5457 def test_assign_num_to_object(self): 5458 codestr = """ 5459 def f(): 5460 x: object = 42 5461 """ 5462 5463 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5464 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5465 f = mod["f"] 5466 self.assertNotInBytecode(f, "CAST", ("builtins", "object")) 5467 5468 def test_assign_num_to_dynamic(self): 5469 codestr = """ 5470 def f(): 5471 x: foo = 42 5472 """ 5473 5474 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5475 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5476 f = mod["f"] 5477 self.assertNotInBytecode(f, "CAST", ("builtins", "object")) 5478 5479 def test_assign_dynamic_to_object(self): 5480 codestr = """ 5481 def f(C): 5482 x: object = C() 5483 """ 5484 5485 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5486 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5487 f = mod["f"] 5488 self.assertNotInBytecode(f, "CAST", ("builtins", "object")) 5489 5490 def test_assign_dynamic_to_dynamic(self): 5491 codestr = """ 5492 def f(C): 5493 x: unknown = C() 5494 """ 5495 5496 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5497 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5498 f = mod["f"] 5499 self.assertNotInBytecode(f, "CAST", ("builtins", "object")) 5500 5501 def test_assign_constant_to_object(self): 5502 codestr = """ 5503 def f(): 5504 x: object = 42 + 1 5505 """ 5506 5507 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5508 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5509 f = mod["f"] 5510 
self.assertNotInBytecode(f, "CAST", ("builtins", "object")) 5511 5512 def test_assign_try_except_typing(self): 5513 codestr = """ 5514 def testfunc(): 5515 try: 5516 pass 5517 except Exception as e: 5518 pass 5519 return 42 5520 """ 5521 5522 # We don't do anything special w/ Exception type yet, but it should compile 5523 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5524 5525 def test_assign_try_except_typing_predeclared(self): 5526 codestr = """ 5527 def testfunc(): 5528 e: Exception 5529 try: 5530 pass 5531 except Exception as e: 5532 pass 5533 return 42 5534 """ 5535 # We don't do anything special w/ Exception type yet, but it should compile 5536 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5537 5538 def test_assign_try_except_typing_narrowed(self): 5539 codestr = """ 5540 class E(Exception): 5541 pass 5542 5543 def testfunc(): 5544 e: Exception 5545 try: 5546 pass 5547 except E as e: 5548 pass 5549 return 42 5550 """ 5551 # We don't do anything special w/ Exception type yet, but it should compile 5552 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5553 5554 def test_assign_try_except_typing_redeclared_after(self): 5555 codestr = """ 5556 def testfunc(): 5557 try: 5558 pass 5559 except Exception as e: 5560 pass 5561 e: int = 42 5562 return 42 5563 """ 5564 # We don't do anything special w/ Exception type yet, but it should compile 5565 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5566 5567 def test_assign_try_except_redeclare(self): 5568 codestr = """ 5569 def testfunc(): 5570 e: int 5571 try: 5572 pass 5573 except Exception as e: 5574 pass 5575 return 42 5576 """ 5577 5578 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5579 5580 def test_assign_try_except_redeclare_unknown_type(self): 5581 codestr = """ 5582 def testfunc(): 5583 e: int 5584 try: 5585 pass 5586 except UnknownException as e: 5587 pass 5588 return 42 5589 """ 5590 5591 code = self.compile(codestr, 
StaticCodeGenerator, modname="foo") 5592 5593 def test_assign_try_assign_in_except(self): 5594 codestr = """ 5595 class B: 5596 def f(self): 5597 return 42 5598 class D(B): 5599 def f(self): 5600 return 'abc' 5601 5602 def testfunc(): 5603 x: B = D() 5604 try: 5605 pass 5606 except: 5607 x = B() 5608 return x.f() 5609 """ 5610 5611 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5612 f = self.find_code(code, "testfunc") 5613 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5614 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5615 test = mod["testfunc"] 5616 self.assertEqual(test(), "abc") 5617 5618 def test_assign_try_assign_in_second_except(self): 5619 codestr = """ 5620 class B: 5621 def f(self): 5622 return 42 5623 class D(B): 5624 def f(self): 5625 return 'abc' 5626 5627 def testfunc(): 5628 x: B = D() 5629 try: 5630 pass 5631 except TypeError: 5632 pass 5633 except: 5634 x = B() 5635 return x.f() 5636 """ 5637 5638 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5639 f = self.find_code(code, "testfunc") 5640 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5641 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5642 test = mod["testfunc"] 5643 self.assertEqual(test(), "abc") 5644 5645 def test_assign_try_assign_in_except_with_var(self): 5646 codestr = """ 5647 class B: 5648 def f(self): 5649 return 42 5650 class D(B): 5651 def f(self): 5652 return 'abc' 5653 5654 def testfunc(): 5655 x: B = D() 5656 try: 5657 pass 5658 except TypeError as e: 5659 x = B() 5660 return x.f() 5661 """ 5662 5663 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5664 f = self.find_code(code, "testfunc") 5665 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5666 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5667 test = mod["testfunc"] 5668 self.assertEqual(test(), "abc") 5669 5670 def test_try_except_finally(self): 5671 codestr = 
""" 5672 class B: 5673 def f(self): 5674 return 42 5675 class D(B): 5676 def f(self): 5677 return 'abc' 5678 5679 def testfunc(): 5680 x: B = D() 5681 try: 5682 pass 5683 except TypeError: 5684 pass 5685 finally: 5686 x = B() 5687 return x.f() 5688 """ 5689 5690 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5691 f = self.find_code(code, "testfunc") 5692 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5693 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5694 test = mod["testfunc"] 5695 self.assertEqual(test(), 42) 5696 5697 def test_assign_try_assign_in_try(self): 5698 codestr = """ 5699 class B: 5700 def f(self): 5701 return 42 5702 class D(B): 5703 def f(self): 5704 return 'abc' 5705 5706 def testfunc(): 5707 x: B = D() 5708 try: 5709 x = B() 5710 except: 5711 pass 5712 return x.f() 5713 """ 5714 5715 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5716 f = self.find_code(code, "testfunc") 5717 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5718 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5719 test = mod["testfunc"] 5720 self.assertEqual(test(), 42) 5721 5722 def test_assign_try_assign_in_finally(self): 5723 codestr = """ 5724 class B: 5725 def f(self): 5726 return 42 5727 class D(B): 5728 def f(self): 5729 return 'abc' 5730 5731 def testfunc(): 5732 x: B = D() 5733 try: 5734 pass 5735 finally: 5736 x = B() 5737 return x.f() 5738 """ 5739 5740 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5741 f = self.find_code(code, "testfunc") 5742 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5743 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5744 test = mod["testfunc"] 5745 self.assertEqual(test(), 42) 5746 5747 def test_assign_try_assign_in_else(self): 5748 codestr = """ 5749 class B: 5750 def f(self): 5751 return 42 5752 class D(B): 5753 def f(self): 5754 return 'abc' 5755 5756 def testfunc(): 5757 x: 
B = D() 5758 try: 5759 pass 5760 except: 5761 pass 5762 else: 5763 x = B() 5764 return x.f() 5765 """ 5766 5767 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5768 f = self.find_code(code, "testfunc") 5769 self.assertInBytecode(f, "INVOKE_METHOD", (("foo", "B", "f"), 0)) 5770 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5771 test = mod["testfunc"] 5772 self.assertEqual(test(), 42) 5773 5774 def test_if_optional_reassign(self): 5775 codestr = """ 5776 class C: pass 5777 5778 def testfunc(abc: Optional[C]): 5779 if abc is not None: 5780 abc = None 5781 """ 5782 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5783 5784 def test_widening_assign(self): 5785 codestr = """ 5786 from __static__ import int8, int16, box 5787 5788 def testfunc(): 5789 x: int16 5790 y: int8 5791 x = y = 42 5792 return box(x), box(y) 5793 """ 5794 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5795 test = mod["testfunc"] 5796 self.assertEqual(test(), (42, 42)) 5797 5798 def test_unknown_imported_annotation(self): 5799 codestr = """ 5800 from unknown_mod import foo 5801 5802 def testfunc(): 5803 x: foo = 42 5804 return x 5805 """ 5806 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 5807 5808 def test_widening_assign_reassign(self): 5809 codestr = """ 5810 from __static__ import int8, int16, box 5811 5812 def testfunc(): 5813 x: int16 5814 y: int8 5815 x = y = 42 5816 x = 257 5817 return box(x), box(y) 5818 """ 5819 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5820 test = mod["testfunc"] 5821 self.assertEqual(test(), (257, 42)) 5822 5823 def test_widening_assign_reassign_error(self): 5824 codestr = """ 5825 from __static__ import int8, int16, box 5826 5827 def testfunc(): 5828 x: int16 5829 y: int8 5830 x = y = 42 5831 y = 128 5832 return box(x), box(y) 5833 """ 5834 with self.assertRaisesRegex( 5835 TypedSyntaxError, 5836 "constant 128 is outside of the range -128 to 127 for int8", 
5837 ): 5838 self.compile(codestr, StaticCodeGenerator, modname="foo") 5839 5840 def test_narrowing_assign_literal(self): 5841 codestr = """ 5842 from __static__ import int8, int16, box 5843 5844 def testfunc(): 5845 x: int8 5846 y: int16 5847 x = y = 42 5848 return box(x), box(y) 5849 """ 5850 self.compile(codestr, StaticCodeGenerator, modname="foo") 5851 5852 def test_narrowing_assign_out_of_range(self): 5853 codestr = """ 5854 from __static__ import int8, int16, box 5855 5856 def testfunc(): 5857 x: int8 5858 y: int16 5859 x = y = 300 5860 return box(x), box(y) 5861 """ 5862 with self.assertRaisesRegex( 5863 TypedSyntaxError, 5864 "constant 300 is outside of the range -128 to 127 for int8", 5865 ): 5866 self.compile(codestr, StaticCodeGenerator, modname="foo") 5867 5868 def test_module_primitive(self): 5869 codestr = """ 5870 from __static__ import int8 5871 x: int8 5872 """ 5873 with self.assertRaisesRegex( 5874 TypedSyntaxError, "cannot use primitives in global or closure scope" 5875 ): 5876 self.compile(codestr, StaticCodeGenerator, modname="foo") 5877 5878 def test_implicit_module_primitive(self): 5879 codestr = """ 5880 from __static__ import int8 5881 x = y = int8(0) 5882 """ 5883 with self.assertRaisesRegex( 5884 TypedSyntaxError, "cannot use primitives in global or closure scope" 5885 ): 5886 self.compile(codestr, StaticCodeGenerator, modname="foo") 5887 5888 def test_chained_primitive_to_non_primitive(self): 5889 codestr = """ 5890 from __static__ import int8 5891 def f(): 5892 x: object 5893 y: int8 = 42 5894 x = y = 42 5895 """ 5896 with self.assertRaisesRegex( 5897 TypedSyntaxError, "int8 cannot be assigned to object" 5898 ): 5899 self.compile(codestr, StaticCodeGenerator, modname="foo") 5900 5901 def test_closure_primitive(self): 5902 codestr = """ 5903 from __static__ import int8 5904 def f(): 5905 x: int8 = 0 5906 def g(): 5907 return x 5908 """ 5909 with self.assertRaisesRegex( 5910 TypedSyntaxError, "cannot use primitives in global or closure 
scope" 5911 ): 5912 self.compile(codestr, StaticCodeGenerator, modname="foo") 5913 5914 def test_nonlocal_primitive(self): 5915 codestr = """ 5916 from __static__ import int8 5917 def f(): 5918 x: int8 = 0 5919 def g(): 5920 nonlocal x 5921 x = 1 5922 """ 5923 with self.assertRaisesRegex( 5924 TypedSyntaxError, "cannot use primitives in global or closure scope" 5925 ): 5926 self.compile(codestr, StaticCodeGenerator, modname="foo") 5927 5928 def test_dynamic_chained_assign_param(self): 5929 codestr = """ 5930 from __static__ import int16 5931 def testfunc(y): 5932 x: int16 5933 x = y = 42 5934 return box(x) 5935 """ 5936 with self.assertRaisesRegex( 5937 TypedSyntaxError, type_mismatch("Exact[int]", "int16") 5938 ): 5939 self.compile(codestr, StaticCodeGenerator, modname="foo") 5940 5941 def test_dynamic_chained_assign_param_2(self): 5942 codestr = """ 5943 from __static__ import int16 5944 def testfunc(y): 5945 x: int16 5946 y = x = 42 5947 """ 5948 with self.assertRaisesRegex( 5949 TypedSyntaxError, type_mismatch("int16", "dynamic") 5950 ): 5951 self.compile(codestr, StaticCodeGenerator, modname="foo") 5952 5953 def test_dynamic_chained_assign_1(self): 5954 codestr = """ 5955 from __static__ import int16, box 5956 def testfunc(): 5957 x: int16 5958 x = y = 42 5959 return box(x) 5960 """ 5961 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5962 test = mod["testfunc"] 5963 self.assertEqual(test(), 42) 5964 5965 def test_dynamic_chained_assign_2(self): 5966 codestr = """ 5967 from __static__ import int16, box 5968 def testfunc(): 5969 x: int16 5970 y = x = 42 5971 return box(y) 5972 """ 5973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 5974 test = mod["testfunc"] 5975 self.assertEqual(test(), 42) 5976 5977 def test_tuple_assign_list(self): 5978 codestr = """ 5979 from __static__ import int16, box 5980 def testfunc(a: int, b: int): 5981 x: int 5982 y: str 5983 x, y = [a, b] 5984 """ 5985 with 
self.assertRaisesRegex(TypedSyntaxError, "int cannot be assigned to str"): 5986 self.compile(codestr, StaticCodeGenerator, modname="foo") 5987 5988 def test_tuple_assign_tuple(self): 5989 codestr = """ 5990 from __static__ import int16, box 5991 def testfunc(a: int, b: int): 5992 x: int 5993 y: str 5994 x, y = a, b 5995 """ 5996 with self.assertRaisesRegex(TypedSyntaxError, "int cannot be assigned to str"): 5997 self.compile(codestr, StaticCodeGenerator, modname="foo") 5998 5999 def test_tuple_assign_constant(self): 6000 codestr = """ 6001 from __static__ import int16, box 6002 def testfunc(): 6003 x: int 6004 y: str 6005 x, y = 1, 1 6006 """ 6007 with self.assertRaisesRegex( 6008 TypedSyntaxError, 6009 r"type mismatch: Exact\[int\] cannot be assigned to str", 6010 ): 6011 self.compile(codestr, StaticCodeGenerator, modname="foo") 6012 6013 def test_if_optional_cond(self): 6014 codestr = """ 6015 from typing import Optional 6016 class C: 6017 def __init__(self): 6018 self.field = 42 6019 6020 def f(x: Optional[C]): 6021 return x.field if x is not None else None 6022 """ 6023 6024 self.compile(codestr, StaticCodeGenerator, modname="foo") 6025 6026 def test_while_optional_cond(self): 6027 codestr = """ 6028 from typing import Optional 6029 class C: 6030 def __init__(self): 6031 self.field: Optional["C"] = self 6032 6033 def f(x: Optional[C]): 6034 while x is not None: 6035 val: Optional[C] = x.field 6036 if val is not None: 6037 x = val 6038 """ 6039 6040 self.compile(codestr, StaticCodeGenerator, modname="foo") 6041 6042 def test_if_optional_dependent_conditions(self): 6043 codestr = """ 6044 from typing import Optional 6045 class C: 6046 def __init__(self): 6047 self.field: Optional[C] = None 6048 6049 def f(x: Optional[C]) -> C: 6050 if x is not None and x.field is not None: 6051 return x 6052 6053 if x is None: 6054 return C() 6055 6056 return x 6057 """ 6058 6059 self.compile(codestr, StaticCodeGenerator, modname="foo") 6060 6061 def 
test_none_attribute_error(self): 6062 codestr = """ 6063 def f(): 6064 x = None 6065 return x.foo 6066 """ 6067 6068 with self.assertRaisesRegex( 6069 TypedSyntaxError, "'NoneType' object has no attribute 'foo'" 6070 ): 6071 self.compile(codestr, StaticCodeGenerator, modname="foo") 6072 6073 def test_none_call(self): 6074 codestr = """ 6075 def f(): 6076 x = None 6077 return x() 6078 """ 6079 6080 with self.assertRaisesRegex( 6081 TypedSyntaxError, "'NoneType' object is not callable" 6082 ): 6083 self.compile(codestr, StaticCodeGenerator, modname="foo") 6084 6085 def test_none_subscript(self): 6086 codestr = """ 6087 def f(): 6088 x = None 6089 return x[0] 6090 """ 6091 6092 with self.assertRaisesRegex( 6093 TypedSyntaxError, "'NoneType' object is not subscriptable" 6094 ): 6095 self.compile(codestr, StaticCodeGenerator, modname="foo") 6096 6097 def test_none_unaryop(self): 6098 codestr = """ 6099 def f(): 6100 x = None 6101 return -x 6102 """ 6103 6104 with self.assertRaisesRegex( 6105 TypedSyntaxError, "bad operand type for unary -: 'NoneType'" 6106 ): 6107 self.compile(codestr, StaticCodeGenerator, modname="foo") 6108 6109 def test_vector_import(self): 6110 codestr = """ 6111 from __static__ import int64, Vector 6112 6113 def test() -> Vector[int64]: 6114 x: Vector[int64] = Vector[int64]() 6115 x.append(1) 6116 return x 6117 """ 6118 6119 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 6120 test = mod["test"] 6121 self.assertEqual(test(), array("L", [1])) 6122 6123 def test_vector_assign_non_primitive(self): 6124 codestr = """ 6125 from __static__ import int64, Vector 6126 6127 def test(abc) -> Vector[int64]: 6128 x: Vector[int64] = Vector[int64](2) 6129 i: int64 = 0 6130 x[i] = abc 6131 """ 6132 6133 with self.assertRaisesRegex( 6134 TypedSyntaxError, "Cannot assign a dynamic to int64" 6135 ): 6136 self.compile(codestr) 6137 6138 def test_vector_sizes(self): 6139 for signed in ["int", "uint"]: 6140 for size in ["8", "16", "32", "64"]: 6141 
with self.subTest(size=size, signed=signed): 6142 int_type = f"{signed}{size}" 6143 codestr = f""" 6144 from __static__ import {int_type}, Vector 6145 6146 def test() -> Vector[{int_type}]: 6147 x: Vector[{int_type}] = Vector[{int_type}]() 6148 y: {int_type} = 1 6149 x.append(y) 6150 return x 6151 """ 6152 6153 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 6154 test = mod["test"] 6155 res = test() 6156 self.assertEqual(list(res), [1]) 6157 6158 def test_vector_invalid_literal(self): 6159 codestr = f""" 6160 from __static__ import int8, Vector 6161 6162 def test() -> Vector[int8]: 6163 x: Vector[int8] = Vector[int8]() 6164 x.append(128) 6165 return x 6166 """ 6167 with self.assertRaisesRegex( 6168 TypedSyntaxError, 6169 "type mismatch: int positional argument type mismatch int8", 6170 ): 6171 self.compile(codestr) 6172 6173 def test_vector_wrong_size(self): 6174 codestr = f""" 6175 from __static__ import int8, int16, Vector 6176 6177 def test() -> Vector[int8]: 6178 y: int16 = 1 6179 x: Vector[int8] = Vector[int8]() 6180 x.append(y) 6181 return x 6182 """ 6183 6184 with self.assertRaisesRegex( 6185 TypedSyntaxError, 6186 "type mismatch: int16 positional argument type mismatch int8", 6187 ): 6188 self.compile(codestr) 6189 6190 def test_vector_presized(self): 6191 codestr = f""" 6192 from __static__ import int8, Vector 6193 6194 def test() -> Vector[int8]: 6195 x: Vector[int8] = Vector[int8](4) 6196 x[1] = 1 6197 return x 6198 """ 6199 6200 with self.in_module(codestr) as mod: 6201 f = mod["test"] 6202 self.assertEqual(f(), array("b", [0, 1, 0, 0])) 6203 6204 def test_array_import(self): 6205 codestr = """ 6206 from __static__ import int64, Array 6207 6208 def test() -> Array[int64]: 6209 x: Array[int64] = Array[int64](1) 6210 x[0] = 1 6211 return x 6212 """ 6213 6214 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 6215 test = mod["test"] 6216 self.assertEqual(test(), array("L", [1])) 6217 6218 def test_array_create(self): 6219 
codestr = """ 6220 from __static__ import int64, Array 6221 6222 def test() -> Array[int64]: 6223 x: Array[int64] = Array[int64]([1, 3, 5]) 6224 return x 6225 """ 6226 6227 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 6228 test = mod["test"] 6229 self.assertEqual(test(), array("l", [1, 3, 5])) 6230 6231 def test_array_create_failure(self): 6232 # todo - in the future we're going to support this, but for now fail it. 6233 codestr = """ 6234 from __static__ import int64, Array 6235 6236 class C: pass 6237 6238 def test() -> Array[C]: 6239 return Array[C]([1, 3, 5]) 6240 """ 6241 with self.assertRaisesRegex( 6242 TypedSyntaxError, "Invalid Array element type: foo.C" 6243 ): 6244 self.compile(codestr, StaticCodeGenerator, modname="foo") 6245 6246 def test_array_call_unbound(self): 6247 codestr = """ 6248 from __static__ import Array 6249 6250 def f() -> Array: 6251 return Array([1, 2, 3]) 6252 """ 6253 with self.assertRaisesRegex( 6254 TypedSyntaxError, 6255 r"create instances of a generic Type\[Exact\[Array\[T\]\]\]", 6256 ): 6257 self.compile(codestr, StaticCodeGenerator, modname="foo") 6258 6259 def test_array_assign_wrong_type(self): 6260 codestr = """ 6261 from __static__ import int64, char, Array 6262 6263 def test() -> None: 6264 x: Array[int64] = Array[char]([48]) 6265 """ 6266 with self.assertRaisesRegex( 6267 TypedSyntaxError, 6268 type_mismatch( 6269 "Exact[Array[char]]", 6270 "Array[int64]", 6271 ), 6272 ): 6273 self.compile(codestr, StaticCodeGenerator, modname="foo") 6274 6275 def test_array_subclass_assign(self): 6276 codestr = """ 6277 from __static__ import int64, Array 6278 6279 class MyArray(Array): 6280 pass 6281 6282 def y(inexact: Array[int64]): 6283 exact = Array[int64]([1]) 6284 exact = inexact 6285 """ 6286 with self.assertRaisesRegex( 6287 TypedSyntaxError, 6288 type_mismatch( 6289 "Array[int64]", 6290 "Exact[Array[int64]]", 6291 ), 6292 ): 6293 self.compile(codestr, StaticCodeGenerator, modname="foo") 6294 6295 def 
test_array_types(self): 6296 codestr = """ 6297 from __static__ import ( 6298 int8, 6299 int16, 6300 int32, 6301 int64, 6302 uint8, 6303 uint16, 6304 uint32, 6305 uint64, 6306 char, 6307 double, 6308 Array 6309 ) 6310 from typing import Tuple 6311 6312 def test() -> Tuple[Array[int64], Array[char], Array[double]]: 6313 x1: Array[int8] = Array[int8]([1, 3, -5]) 6314 x2: Array[uint8] = Array[uint8]([1, 3, 5]) 6315 x3: Array[int16] = Array[int16]([1, -3, 5]) 6316 x4: Array[uint16] = Array[uint16]([1, 3, 5]) 6317 x5: Array[int32] = Array[int32]([1, 3, 5]) 6318 x6: Array[uint32] = Array[uint32]([1, 3, 5]) 6319 x7: Array[int64] = Array[int64]([1, 3, 5]) 6320 x8: Array[uint64] = Array[uint64]([1, 3, 5]) 6321 x9: Array[char] = Array[char]([ord('a')]) 6322 x10: Array[double] = Array[double]([1.1, 3.3, 5.5]) 6323 x11: Array[float] = Array[float]([1.1, 3.3, 5.5]) 6324 arrays = [ 6325 x1, 6326 x2, 6327 x3, 6328 x4, 6329 x5, 6330 x6, 6331 x7, 6332 x8, 6333 x9, 6334 x10, 6335 x11, 6336 ] 6337 first_elements = [] 6338 for ar in arrays: 6339 first_elements.append(ar[0]) 6340 return (arrays, first_elements) 6341 """ 6342 6343 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 6344 test = mod["test"] 6345 arrays, first_elements = test() 6346 exp_arrays = [ 6347 array(*args) 6348 for args in [ 6349 ("h", [1, 3, -5]), 6350 ("H", [1, 3, 5]), 6351 ("i", [1, -3, 5]), 6352 ("I", [1, 3, 5]), 6353 ("l", [1, 3, 5]), 6354 ("L", [1, 3, 5]), 6355 ("q", [1, 3, 5]), 6356 ("Q", [1, 3, 5]), 6357 ("b", [ord("a")]), 6358 ("d", [1.1, 3.3, 5.5]), 6359 ("f", [1.1, 3.3, 5.5]), 6360 ] 6361 ] 6362 exp_first_elements = [ar[0] for ar in exp_arrays] 6363 for result, expectation in zip(arrays, exp_arrays): 6364 self.assertEqual(result, expectation) 6365 for result, expectation in zip(first_elements, exp_first_elements): 6366 self.assertEqual(result, expectation) 6367 6368 def test_assign_type_propagation(self): 6369 codestr = """ 6370 def test() -> int: 6371 x = 5 6372 return x 6373 """ 6374 
self.compile(codestr, StaticCodeGenerator, modname="foo") 6375 6376 def test_assign_subtype_handling(self): 6377 codestr = """ 6378 class B: pass 6379 class D(B): pass 6380 6381 def f(): 6382 b = B() 6383 b = D() 6384 b = B() 6385 """ 6386 self.compile(codestr, StaticCodeGenerator, modname="foo") 6387 6388 def test_assign_subtype_handling_fail(self): 6389 codestr = """ 6390 class B: pass 6391 class D(B): pass 6392 6393 def f(): 6394 d = D() 6395 d = B() 6396 """ 6397 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("foo.B", "foo.D")): 6398 self.compile(codestr, StaticCodeGenerator, modname="foo") 6399 6400 def test_assign_chained(self): 6401 codestr = """ 6402 def test() -> str: 6403 x: str = "hi" 6404 y = x = "hello" 6405 return y 6406 """ 6407 self.compile(codestr, StaticCodeGenerator, modname="foo") 6408 6409 def test_assign_chained_failure_wrong_target_type(self): 6410 codestr = """ 6411 def test() -> str: 6412 x = 1 6413 y = x = "hello" 6414 return y 6415 """ 6416 with self.assertRaisesRegex( 6417 TypedSyntaxError, type_mismatch("Exact[str]", "int") 6418 ): 6419 self.compile(codestr, StaticCodeGenerator, modname="foo") 6420 6421 def test_chained_assign_type_propagation(self): 6422 codestr = """ 6423 from __static__ import int64, char, Array 6424 6425 def test2() -> Array[char]: 6426 x = y = Array[char]([48]) 6427 return y 6428 """ 6429 self.compile(codestr, StaticCodeGenerator, modname="foo") 6430 6431 def test_chained_assign_type_propagation_failure_redefine(self): 6432 codestr = """ 6433 from __static__ import int64, char, Array 6434 6435 def test2() -> Array[char]: 6436 x = Array[int64]([54]) 6437 x = y = Array[char]([48]) 6438 return y 6439 """ 6440 with self.assertRaisesRegex( 6441 TypedSyntaxError, 6442 type_mismatch( 6443 "Exact[Array[char]]", 6444 "Exact[Array[int64]]", 6445 ), 6446 ): 6447 self.compile(codestr, StaticCodeGenerator, modname="foo") 6448 6449 def test_chained_assign_type_propagation_failure_redefine_2(self): 6450 codestr = """ 
6451 from __static__ import int64, char, Array 6452 6453 def test2() -> Array[char]: 6454 x = Array[int64]([54]) 6455 y = x = Array[char]([48]) 6456 return y 6457 """ 6458 with self.assertRaisesRegex( 6459 TypedSyntaxError, 6460 type_mismatch( 6461 "Exact[Array[char]]", 6462 "Exact[Array[int64]]", 6463 ), 6464 ): 6465 self.compile(codestr, StaticCodeGenerator, modname="foo") 6466 6467 def test_chained_assign_type_inference(self): 6468 codestr = """ 6469 from __static__ import int64, char, Array 6470 6471 def test2(): 6472 y = x = 4 6473 x = "hello" 6474 return y 6475 """ 6476 with self.assertRaisesRegex( 6477 TypedSyntaxError, type_mismatch("Exact[str]", "int") 6478 ): 6479 self.compile(codestr, StaticCodeGenerator, modname="foo") 6480 6481 def test_chained_assign_type_inference_2(self): 6482 codestr = """ 6483 from __static__ import int64, char, Array 6484 6485 def test2(): 6486 y = x = 4 6487 y = "hello" 6488 return x 6489 """ 6490 with self.assertRaisesRegex( 6491 TypedSyntaxError, type_mismatch("Exact[str]", "int") 6492 ): 6493 self.compile(codestr, StaticCodeGenerator, modname="foo") 6494 6495 def test_array_inplace_assign(self): 6496 codestr = """ 6497 from __static__ import Array, int8 6498 6499 def m() -> Array[int8]: 6500 a = Array[int8]([1, 3, -5, -1, 7, 22]) 6501 a[0] += 1 6502 return a 6503 """ 6504 with self.in_module(codestr) as mod: 6505 m = mod["m"] 6506 self.assertEqual(m()[0], 2) 6507 6508 def test_array_subscripting_slice(self): 6509 codestr = """ 6510 from __static__ import Array, int8 6511 6512 def m() -> Array[int8]: 6513 a = Array[int8]([1, 3, -5, -1, 7, 22]) 6514 return a[1:3] 6515 """ 6516 self.compile(codestr, StaticCodeGenerator, modname="foo") 6517 6518 @skipIf(cinderjit is not None, "can't report error from JIT") 6519 def test_load_uninit_module(self): 6520 """verify we don't crash if we receive a module w/o a dictionary""" 6521 codestr = """ 6522 class C: 6523 def __init__(self): 6524 self.x: Optional[C] = None 6525 6526 """ 6527 with 
self.in_module(codestr) as mod: 6528 C = mod["C"] 6529 6530 class UninitModule(ModuleType): 6531 def __init__(self): 6532 # don't call super init 6533 pass 6534 6535 sys.modules[mod["__name__"]] = UninitModule() 6536 with self.assertRaisesRegex( 6537 TypeError, 6538 r"bad name provided for class loader: \('" 6539 + mod["__name__"] 6540 + r"', 'C'\), not a class", 6541 ): 6542 C() 6543 6544 def test_module_subclass(self): 6545 codestr = """ 6546 class C: 6547 def __init__(self): 6548 self.x: Optional[C] = None 6549 6550 """ 6551 with self.in_module(codestr) as mod: 6552 C = mod["C"] 6553 6554 class CustomModule(ModuleType): 6555 def __getattr__(self, name): 6556 if name == "C": 6557 return C 6558 6559 sys.modules[mod["__name__"]] = CustomModule(mod["__name__"]) 6560 c = C() 6561 self.assertEqual(c.x, None) 6562 6563 def test_invoke_and_raise_noframe_strictmod(self): 6564 codestr = """ 6565 from __static__.compiler_flags import noframe 6566 6567 def x(): 6568 raise TypeError() 6569 6570 def y(): 6571 return x() 6572 """ 6573 with self.in_strict_module(codestr) as mod: 6574 y = mod.y 6575 x = mod.x 6576 with self.assertRaises(TypeError): 6577 y() 6578 self.assert_jitted(x) 6579 self.assertInBytecode( 6580 y, 6581 "INVOKE_FUNCTION", 6582 ((mod.__name__, "x"), 0), 6583 ) 6584 6585 def test_override_okay(self): 6586 codestr = """ 6587 class B: 6588 def f(self) -> "B": 6589 return self 6590 6591 def f(x: B): 6592 return x.f() 6593 """ 6594 with self.in_module(codestr) as mod: 6595 B = mod["B"] 6596 f = mod["f"] 6597 6598 class D(B): 6599 def f(self): 6600 return self 6601 6602 x = f(D()) 6603 6604 def test_override_override_inherited(self): 6605 codestr = """ 6606 from typing import Optional 6607 class B: 6608 def f(self) -> "Optional[B]": 6609 return self 6610 6611 class D(B): 6612 pass 6613 6614 def f(x: B): 6615 return x.f() 6616 """ 6617 with self.in_module(codestr) as mod: 6618 B = mod["B"] 6619 D = mod["D"] 6620 f = mod["f"] 6621 6622 b = B() 6623 d = D() 6624 
self.assertEqual(f(b), b) 6625 self.assertEqual(f(d), d) 6626 6627 D.f = lambda self: None 6628 self.assertEqual(f(b), b) 6629 self.assertEqual(f(d), None) 6630 6631 def test_override_bad_ret(self): 6632 codestr = """ 6633 class B: 6634 def f(self) -> "B": 6635 return self 6636 6637 def f(x: B): 6638 return x.f() 6639 """ 6640 with self.in_module(codestr) as mod: 6641 B = mod["B"] 6642 f = mod["f"] 6643 6644 class D(B): 6645 def f(self): 6646 return 42 6647 6648 with self.assertRaisesRegex( 6649 TypeError, "unexpected return type from D.f, expected B, got int" 6650 ): 6651 f(D()) 6652 6653 def test_dynamic_base(self): 6654 nonstatic_code = """ 6655 class Foo: 6656 pass 6657 """ 6658 6659 with self.in_module( 6660 nonstatic_code, code_gen=PythonCodeGenerator, name="nonstatic" 6661 ): 6662 codestr = """ 6663 from nonstatic import Foo 6664 6665 class A(Foo): 6666 def __init__(self): 6667 self.x = 1 6668 6669 def f(self) -> int: 6670 return self.x 6671 6672 def f(x: A) -> int: 6673 return x.f() 6674 """ 6675 with self.in_module(codestr) as mod: 6676 f = mod["f"] 6677 self.assertInBytecode(f, "INVOKE_METHOD") 6678 a = mod["A"]() 6679 self.assertEqual(f(a), 1) 6680 # x is a data descriptor, it takes precedence 6681 a.__dict__["x"] = 100 6682 self.assertEqual(f(a), 1) 6683 # but methods are normal descriptors, instance 6684 # attributes should take precedence 6685 a.__dict__["f"] = lambda: 42 6686 self.assertEqual(f(a), 42) 6687 6688 def test_invoke_type_modified(self): 6689 codestr = """ 6690 class C: 6691 def f(self): 6692 return 1 6693 6694 def x(c: C): 6695 x = c.f() 6696 x += c.f() 6697 return x 6698 """ 6699 6700 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 6701 x = self.find_code(code, "x") 6702 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 6703 6704 with self.in_module(codestr) as mod: 6705 x, C = mod["x"], mod["C"] 6706 self.assertEqual(x(C()), 2) 6707 C.f = lambda self: 42 6708 self.assertEqual(x(C()), 84) 6709 6710 def 
test_invoke_type_modified_pre_invoke(self): 6711 codestr = """ 6712 class C: 6713 def f(self): 6714 return 1 6715 6716 def x(c: C): 6717 x = c.f() 6718 x += c.f() 6719 return x 6720 """ 6721 6722 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 6723 x = self.find_code(code, "x") 6724 self.assertInBytecode(x, "INVOKE_METHOD", (("foo", "C", "f"), 0)) 6725 6726 with self.in_module(codestr) as mod: 6727 x, C = mod["x"], mod["C"] 6728 C.f = lambda self: 42 6729 self.assertEqual(x(C()), 84) 6730 6731 def test_override_modified_base_class(self): 6732 codestr = """ 6733 class B: 6734 def f(self): 6735 return 1 6736 6737 def f(x: B): 6738 return x.f() 6739 """ 6740 with self.in_module(codestr) as mod: 6741 B = mod["B"] 6742 f = mod["f"] 6743 B.f = lambda self: 2 6744 6745 class D(B): 6746 def f(self): 6747 return 3 6748 6749 d = D() 6750 self.assertEqual(f(d), 3) 6751 6752 def test_override_remove_base_method(self): 6753 codestr = """ 6754 from typing import Optional 6755 class B: 6756 def f(self) -> "B": 6757 return self 6758 6759 class D(B): pass 6760 6761 def f(x: B): 6762 return x.f() 6763 """ 6764 with self.in_module(codestr) as mod: 6765 B = mod["B"] 6766 D = mod["D"] 6767 f = mod["f"] 6768 b = B() 6769 d = D() 6770 self.assertEqual(f(b), b) 6771 self.assertEqual(f(d), d) 6772 del B.f 6773 6774 with self.assertRaises(AttributeError): 6775 f(b) 6776 with self.assertRaises(AttributeError): 6777 f(d) 6778 6779 def test_override_remove_derived_method(self): 6780 codestr = """ 6781 from typing import Optional 6782 class B: 6783 def f(self) -> "Optional[B]": 6784 return self 6785 6786 class D(B): 6787 def f(self) -> Optional["B"]: 6788 return None 6789 6790 def f(x: B): 6791 return x.f() 6792 """ 6793 with self.in_module(codestr) as mod: 6794 B = mod["B"] 6795 D = mod["D"] 6796 f = mod["f"] 6797 b = B() 6798 d = D() 6799 self.assertEqual(f(b), b) 6800 self.assertEqual(f(d), None) 6801 del D.f 6802 6803 self.assertEqual(f(b), b) 6804 self.assertEqual(f(d), d) 
6805 6806 def test_override_remove_method(self): 6807 codestr = """ 6808 from typing import Optional 6809 class B: 6810 def f(self) -> "Optional[B]": 6811 return self 6812 6813 def f(x: B): 6814 return x.f() 6815 """ 6816 with self.in_module(codestr) as mod: 6817 B = mod["B"] 6818 f = mod["f"] 6819 b = B() 6820 self.assertEqual(f(b), b) 6821 del B.f 6822 6823 with self.assertRaises(AttributeError): 6824 f(b) 6825 6826 def test_override_remove_method_add_type_check(self): 6827 codestr = """ 6828 from typing import Optional 6829 class B: 6830 def f(self) -> "B": 6831 return self 6832 6833 def f(x: B): 6834 return x.f() 6835 """ 6836 with self.in_module(codestr) as mod: 6837 B = mod["B"] 6838 f = mod["f"] 6839 b = B() 6840 self.assertEqual(f(b), b) 6841 del B.f 6842 6843 with self.assertRaises(AttributeError): 6844 f(b) 6845 6846 B.f = lambda self: None 6847 with self.assertRaises(TypeError): 6848 f(b) 6849 6850 def test_override_update_derived(self): 6851 codestr = """ 6852 from typing import Optional 6853 class B: 6854 def f(self) -> "Optional[B]": 6855 return self 6856 6857 class D(B): 6858 pass 6859 6860 def f(x: B): 6861 return x.f() 6862 """ 6863 with self.in_module(codestr) as mod: 6864 B = mod["B"] 6865 D = mod["D"] 6866 f = mod["f"] 6867 6868 b = B() 6869 d = D() 6870 self.assertEqual(f(b), b) 6871 self.assertEqual(f(d), d) 6872 6873 B.f = lambda self: None 6874 self.assertEqual(f(b), None) 6875 self.assertEqual(f(d), None) 6876 6877 def test_override_update_derived_2(self): 6878 codestr = """ 6879 from typing import Optional 6880 class B: 6881 def f(self) -> "Optional[B]": 6882 return self 6883 6884 class D1(B): pass 6885 6886 class D(D1): 6887 pass 6888 6889 def f(x: B): 6890 return x.f() 6891 """ 6892 with self.in_module(codestr) as mod: 6893 B = mod["B"] 6894 D = mod["D"] 6895 f = mod["f"] 6896 6897 b = B() 6898 d = D() 6899 self.assertEqual(f(b), b) 6900 self.assertEqual(f(d), d) 6901 6902 B.f = lambda self: None 6903 self.assertEqual(f(b), None) 6904 
self.assertEqual(f(d), None) 6905 6906 def test_method_prologue(self): 6907 codestr = """ 6908 def f(x: str): 6909 return 42 6910 """ 6911 with self.in_module(codestr) as mod: 6912 f = mod["f"] 6913 self.assertInBytecode(f, "CHECK_ARGS", (0, ("builtins", "str"))) 6914 with self.assertRaisesRegex( 6915 TypeError, ".*expected 'str' for argument x, got 'int'" 6916 ): 6917 f(42) 6918 6919 def test_method_prologue_2(self): 6920 codestr = """ 6921 def f(x, y: str): 6922 return 42 6923 """ 6924 with self.in_module(codestr) as mod: 6925 f = mod["f"] 6926 self.assertInBytecode(f, "CHECK_ARGS", (1, ("builtins", "str"))) 6927 with self.assertRaisesRegex( 6928 TypeError, ".*expected 'str' for argument y, got 'int'" 6929 ): 6930 f("abc", 42) 6931 6932 def test_method_prologue_3(self): 6933 codestr = """ 6934 def f(x: int, y: str): 6935 return 42 6936 """ 6937 with self.in_module(codestr) as mod: 6938 f = mod["f"] 6939 self.assertInBytecode( 6940 f, "CHECK_ARGS", (0, ("builtins", "int"), 1, ("builtins", "str")) 6941 ) 6942 with self.assertRaisesRegex( 6943 TypeError, ".*expected 'str' for argument y, got 'int'" 6944 ): 6945 f(42, 42) 6946 6947 def test_method_prologue_posonly(self): 6948 codestr = """ 6949 def f(x: int, /, y: str): 6950 return 42 6951 """ 6952 with self.in_module(codestr) as mod: 6953 f = mod["f"] 6954 self.assertInBytecode( 6955 f, "CHECK_ARGS", (0, ("builtins", "int"), 1, ("builtins", "str")) 6956 ) 6957 with self.assertRaisesRegex( 6958 TypeError, ".*expected 'str' for argument y, got 'int'" 6959 ): 6960 f(42, 42) 6961 6962 def test_method_prologue_shadowcode(self): 6963 codestr = """ 6964 def f(x, y: str): 6965 return 42 6966 """ 6967 with self.in_module(codestr) as mod: 6968 f = mod["f"] 6969 self.assertInBytecode(f, "CHECK_ARGS", (1, ("builtins", "str"))) 6970 for i in range(100): 6971 self.assertEqual(f("abc", "abc"), 42) 6972 with self.assertRaisesRegex( 6973 TypeError, ".*expected 'str' for argument y, got 'int'" 6974 ): 6975 f("abc", 42) 6976 6977 def 
test_method_prologue_shadowcode_2(self): 6978 codestr = """ 6979 def f(x: str): 6980 return 42 6981 """ 6982 with self.in_module(codestr) as mod: 6983 f = mod["f"] 6984 self.assertInBytecode(f, "CHECK_ARGS", (0, ("builtins", "str"))) 6985 for i in range(100): 6986 self.assertEqual(f("abc"), 42) 6987 with self.assertRaisesRegex( 6988 TypeError, ".*expected 'str' for argument x, got 'int'" 6989 ): 6990 f(42) 6991 6992 def test_method_prologue_no_annotation(self): 6993 codestr = """ 6994 def f(x): 6995 return 42 6996 """ 6997 with self.in_module(codestr) as mod: 6998 f = mod["f"] 6999 self.assertInBytecode(f, "CHECK_ARGS", ()) 7000 self.assertEqual(f("abc"), 42) 7001 7002 def test_method_prologue_kwonly(self): 7003 codestr = """ 7004 def f(*, x: str): 7005 return 42 7006 """ 7007 with self.in_module(codestr) as mod: 7008 f = mod["f"] 7009 self.assertInBytecode(f, "CHECK_ARGS", (0, ("builtins", "str"))) 7010 with self.assertRaisesRegex( 7011 TypeError, "f expected 'str' for argument x, got 'int'" 7012 ): 7013 f(x=42) 7014 7015 def test_method_prologue_kwonly_2(self): 7016 codestr = """ 7017 def f(x, *, y: str): 7018 return 42 7019 """ 7020 with self.in_module(codestr) as mod: 7021 f = mod["f"] 7022 self.assertInBytecode(f, "CHECK_ARGS", (1, ("builtins", "str"))) 7023 with self.assertRaisesRegex( 7024 TypeError, "f expected 'str' for argument y, got 'object'" 7025 ): 7026 f(1, y=object()) 7027 7028 def test_method_prologue_kwonly_3(self): 7029 codestr = """ 7030 def f(x, *, y: str, z=1): 7031 return 42 7032 """ 7033 with self.in_module(codestr) as mod: 7034 f = mod["f"] 7035 self.assertInBytecode(f, "CHECK_ARGS", (1, ("builtins", "str"))) 7036 with self.assertRaisesRegex( 7037 TypeError, "f expected 'str' for argument y, got 'object'" 7038 ): 7039 f(1, y=object()) 7040 7041 def test_method_prologue_kwonly_4(self): 7042 codestr = """ 7043 def f(x, *, y: str, **rest): 7044 return 42 7045 """ 7046 with self.in_module(codestr) as mod: 7047 f = mod["f"] 7048 
self.assertInBytecode(f, "CHECK_ARGS", (1, ("builtins", "str"))) 7049 with self.assertRaisesRegex( 7050 TypeError, "f expected 'str' for argument y, got 'object'" 7051 ): 7052 f(1, y=object(), z=2) 7053 7054 def test_method_prologue_kwonly_no_annotation(self): 7055 codestr = """ 7056 def f(*, x): 7057 return 42 7058 """ 7059 with self.in_module(codestr) as mod: 7060 f = mod["f"] 7061 self.assertInBytecode(f, "CHECK_ARGS", ()) 7062 f(x=42) 7063 7064 def test_package_no_parent(self): 7065 codestr = """ 7066 class C: 7067 def f(self): 7068 return 42 7069 """ 7070 with self.in_module( 7071 codestr, code_gen=StaticCodeGenerator, name="package_no_parent.child" 7072 ) as mod: 7073 C = mod["C"] 7074 self.assertInBytecode( 7075 C.f, "CHECK_ARGS", (0, ("package_no_parent.child", "C")) 7076 ) 7077 self.assertEqual(C().f(), 42) 7078 7079 def test_direct_super_init(self): 7080 value = 42 7081 expected = value 7082 codestr = f""" 7083 class Obj: 7084 pass 7085 7086 class C: 7087 def __init__(self, x: Obj): 7088 pass 7089 7090 class D: 7091 def __init__(self): 7092 C.__init__(None) 7093 """ 7094 with self.assertRaisesRegex( 7095 TypedSyntaxError, 7096 "type mismatch: None positional argument type mismatch foo.C", 7097 ): 7098 self.compile(codestr, StaticCodeGenerator, modname="foo") 7099 7100 def test_class_unknown_attr(self): 7101 value = 42 7102 expected = value 7103 codestr = f""" 7104 class C: 7105 pass 7106 7107 def f(): 7108 return C.foo 7109 """ 7110 with self.in_module(codestr) as mod: 7111 f = mod["f"] 7112 self.assertInBytecode(f, "LOAD_ATTR", "foo") 7113 7114 def test_descriptor_access(self): 7115 value = 42 7116 expected = value 7117 codestr = f""" 7118 class Obj: 7119 abc: int 7120 7121 class C: 7122 x: Obj 7123 7124 def f(): 7125 return C.x.abc 7126 """ 7127 with self.in_module(codestr) as mod: 7128 f = mod["f"] 7129 self.assertInBytecode(f, "LOAD_ATTR", "abc") 7130 self.assertNotInBytecode(f, "LOAD_FIELD") 7131 7132 @skipIf(not path.exists(RICHARDS_PATH), "richards 
not found") 7133 def test_richards(self): 7134 with open(RICHARDS_PATH) as f: 7135 codestr = f.read() 7136 7137 with self.in_module(codestr) as mod: 7138 Richards = mod["Richards"] 7139 self.assertTrue(Richards().run(1)) 7140 7141 def test_unknown_isinstance_bool_ret(self): 7142 codestr = """ 7143 from typing import Any 7144 7145 class C: 7146 def __init__(self, x: str): 7147 self.x: str = x 7148 7149 def __eq__(self, other: Any) -> bool: 7150 return isinstance(other, C) 7151 7152 """ 7153 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7154 C = mod["C"] 7155 x = C("abc") 7156 y = C("foo") 7157 self.assertTrue(x == y) 7158 7159 def test_unknown_issubclass_bool_ret(self): 7160 codestr = """ 7161 from typing import Any 7162 7163 class C: 7164 def __init__(self, x: str): 7165 self.x: str = x 7166 7167 def __eq__(self, other: Any) -> bool: 7168 return issubclass(type(other), C) 7169 7170 """ 7171 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7172 C = mod["C"] 7173 x = C("abc") 7174 y = C("foo") 7175 self.assertTrue(x == y) 7176 7177 def test_unknown_isinstance_narrows(self): 7178 codestr = """ 7179 from typing import Any 7180 7181 class C: 7182 def __init__(self, x: str): 7183 self.x: str = x 7184 7185 def testfunc(x): 7186 if isinstance(x, C): 7187 return x.x 7188 """ 7189 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7190 testfunc = mod["testfunc"] 7191 self.assertInBytecode(testfunc, "LOAD_FIELD", (mod["__name__"], "C", "x")) 7192 7193 def test_unknown_isinstance_narrows_class_attr(self): 7194 codestr = """ 7195 from typing import Any 7196 7197 class C: 7198 def __init__(self, x: str): 7199 self.x: str = x 7200 7201 def f(self, other) -> str: 7202 if isinstance(other, self.__class__): 7203 return other.x 7204 return '' 7205 """ 7206 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7207 C = mod["C"] 7208 self.assertInBytecode( 7209 C.f, 7210 "LOAD_FIELD", 7211 (mod["__name__"], "C", "x"), 
7212 ) 7213 7214 def test_unknown_isinstance_narrows_class_attr_dynamic(self): 7215 codestr = """ 7216 from typing import Any 7217 7218 class C: 7219 def __init__(self, x: str): 7220 self.x: str = x 7221 7222 def f(self, other, unknown): 7223 if isinstance(other, unknown.__class__): 7224 return other.x 7225 return '' 7226 """ 7227 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7228 C = mod["C"] 7229 self.assertInBytecode(C.f, "LOAD_ATTR", "x") 7230 7231 def test_unknown_isinstance_narrows_else_correct(self): 7232 codestr = """ 7233 from typing import Any 7234 7235 class C: 7236 def __init__(self, x: str): 7237 self.x: str = x 7238 7239 def testfunc(x): 7240 if isinstance(x, C): 7241 pass 7242 else: 7243 return x.x 7244 """ 7245 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7246 testfunc = mod["testfunc"] 7247 self.assertNotInBytecode( 7248 testfunc, "LOAD_FIELD", (mod["__name__"], "C", "x") 7249 ) 7250 7251 def test_narrow_while_break(self): 7252 codestr = """ 7253 from typing import Optional 7254 def f(x: Optional[int]) -> int: 7255 while x is None: 7256 break 7257 return x 7258 """ 7259 with self.assertRaisesRegex( 7260 TypedSyntaxError, 7261 type_mismatch("Optional[int]", "int"), 7262 ): 7263 self.compile(codestr) 7264 7265 def test_narrow_while_if_break_else_return(self): 7266 codestr = """ 7267 from typing import Optional 7268 def f(x: Optional[int], y: int) -> int: 7269 while x is None: 7270 if y > 0: 7271 break 7272 else: 7273 return 42 7274 return x 7275 """ 7276 with self.assertRaisesRegex( 7277 TypedSyntaxError, 7278 type_mismatch("Optional[int]", "int"), 7279 ): 7280 self.compile(codestr) 7281 7282 def test_narrow_while_break_if(self): 7283 codestr = """ 7284 from typing import Optional 7285 def f(x: Optional[int]) -> int: 7286 while True: 7287 if x is None: 7288 break 7289 return x 7290 """ 7291 self.compile(codestr) 7292 7293 def test_narrow_while_continue_if(self): 7294 codestr = """ 7295 from typing import 
Optional 7296 def f(x: Optional[int]) -> int: 7297 while True: 7298 if x is None: 7299 continue 7300 return x 7301 """ 7302 self.compile(codestr) 7303 7304 def test_unknown_param_ann(self): 7305 codestr = """ 7306 from typing import Any 7307 7308 class C: 7309 def __init__(self, x: str): 7310 self.x: str = x 7311 7312 def __eq__(self, other: Any) -> bool: 7313 return False 7314 7315 """ 7316 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7317 C = mod["C"] 7318 x = C("abc") 7319 self.assertInBytecode(C.__eq__, "CHECK_ARGS", (0, (mod["__name__"], "C"))) 7320 self.assertNotEqual(x, x) 7321 7322 def test_class_init_kw(self): 7323 codestr = """ 7324 class C: 7325 def __init__(self, x: str): 7326 self.x: str = x 7327 7328 def f(): 7329 x = C(x='abc') 7330 7331 """ 7332 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7333 f = mod["f"] 7334 self.assertInBytecode(f, "CALL_FUNCTION_KW", 1) 7335 7336 def test_ret_type_cast(self): 7337 codestr = """ 7338 from typing import Any 7339 7340 def testfunc(x: str, y: str) -> bool: 7341 return x == y 7342 """ 7343 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7344 f = mod["testfunc"] 7345 self.assertEqual(f("abc", "abc"), True) 7346 self.assertInBytecode(f, "CAST", ("builtins", "bool")) 7347 7348 def test_bind_boolop_type(self): 7349 codestr = """ 7350 from typing import Any 7351 7352 class C: 7353 def f(self) -> bool: 7354 return True 7355 7356 def g(self) -> bool: 7357 return False 7358 7359 def x(self) -> bool: 7360 return self.f() and self.g() 7361 7362 def y(self) -> bool: 7363 return self.f() or self.g() 7364 """ 7365 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7366 C = mod["C"] 7367 c = C() 7368 self.assertEqual(c.x(), False) 7369 self.assertEqual(c.y(), True) 7370 7371 def test_decorated_function_ignored_class(self): 7372 codestr = """ 7373 class C: 7374 @property 7375 def x(self): 7376 return lambda: 42 7377 7378 def y(self): 7379 return 
self.x() 7380 7381 """ 7382 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7383 C = mod["C"] 7384 self.assertNotInBytecode(C.y, "INVOKE_METHOD") 7385 self.assertEqual(C().y(), 42) 7386 7387 def test_decorated_function_ignored(self): 7388 codestr = """ 7389 class C: pass 7390 7391 def mydecorator(x): 7392 return C 7393 7394 @mydecorator 7395 def f(): 7396 return 42 7397 7398 def g(): 7399 return f() 7400 7401 """ 7402 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7403 C = mod["C"] 7404 g = mod["g"] 7405 self.assertNotInBytecode(g, "INVOKE_FUNCTION") 7406 self.assertEqual(type(g()), C) 7407 7408 def test_static_function_invoke(self): 7409 codestr = """ 7410 class C: 7411 @staticmethod 7412 def f(): 7413 return 42 7414 7415 def f(): 7416 return C.f() 7417 """ 7418 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7419 f = mod["f"] 7420 self.assertInBytecode( 7421 f, "INVOKE_FUNCTION", ((mod["__name__"], "C", "f"), 0) 7422 ) 7423 self.assertEqual(f(), 42) 7424 7425 def test_static_function_invoke_on_instance(self): 7426 codestr = """ 7427 class C: 7428 @staticmethod 7429 def f(): 7430 return 42 7431 7432 def f(): 7433 return C().f() 7434 """ 7435 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7436 f = mod["f"] 7437 self.assertInBytecode( 7438 f, 7439 "INVOKE_FUNCTION", 7440 ((mod["__name__"], "C", "f"), 0), 7441 ) 7442 self.assertEqual(f(), 42) 7443 7444 def test_spamobj_no_params(self): 7445 codestr = """ 7446 from xxclassloader import spamobj 7447 7448 def f(): 7449 x = spamobj() 7450 """ 7451 with self.assertRaisesRegex( 7452 TypedSyntaxError, 7453 r"cannot create instances of a generic Type\[xxclassloader.spamobj\[T\]\]", 7454 ): 7455 self.compile(codestr, StaticCodeGenerator, modname="foo") 7456 7457 def test_spamobj_error(self): 7458 codestr = """ 7459 from xxclassloader import spamobj 7460 7461 def f(): 7462 x = spamobj[int]() 7463 return x.error(1) 7464 """ 7465 with 
self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7466 f = mod["f"] 7467 with self.assertRaisesRegex(TypeError, "no way!"): 7468 f() 7469 7470 def test_spamobj_no_error(self): 7471 codestr = """ 7472 from xxclassloader import spamobj 7473 7474 def testfunc(): 7475 x = spamobj[int]() 7476 return x.error(0) 7477 """ 7478 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7479 f = mod["testfunc"] 7480 self.assertEqual(f(), None) 7481 7482 def test_generic_type_box_box(self): 7483 codestr = """ 7484 from xxclassloader import spamobj 7485 7486 def testfunc(): 7487 x = spamobj[str]() 7488 return (x.getint(), ) 7489 """ 7490 7491 with self.assertRaisesRegex( 7492 TypedSyntaxError, 7493 "type mismatch: int64 is an invalid return type, expected dynamic", 7494 ): 7495 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7496 7497 def test_generic_type(self): 7498 codestr = """ 7499 from xxclassloader import spamobj 7500 from __static__ import box 7501 7502 def testfunc(): 7503 x = spamobj[str]() 7504 x.setstate('abc') 7505 x.setint(42) 7506 return (x.getstate(), box(x.getint())) 7507 """ 7508 7509 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7510 f = self.find_code(code, "testfunc") 7511 self.assertInBytecode( 7512 f, 7513 "INVOKE_METHOD", 7514 ((("xxclassloader", "spamobj", (("builtins", "str"),), "setstate"), 1)), 7515 ) 7516 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7517 test = mod["testfunc"] 7518 self.assertEqual(test(), ("abc", 42)) 7519 7520 def test_ret_void(self): 7521 codestr = """ 7522 from xxclassloader import spamobj 7523 from __static__ import box 7524 7525 def testfunc(): 7526 x = spamobj[str]() 7527 y = x.setstate('abc') 7528 return y 7529 """ 7530 7531 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7532 f = self.find_code(code, "testfunc") 7533 self.assertInBytecode( 7534 f, 7535 "INVOKE_METHOD", 7536 ((("xxclassloader", "spamobj", (("builtins", "str"),), 
"setstate"), 1)), 7537 ) 7538 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7539 test = mod["testfunc"] 7540 self.assertEqual(test(), None) 7541 7542 def test_user_enumerate_list(self): 7543 codestr = """ 7544 from __static__ import int64, box, clen 7545 7546 def f(x: list): 7547 i: int64 = 0 7548 res = [] 7549 while i < clen(x): 7550 elem = x[i] 7551 res.append((box(i), elem)) 7552 i += 1 7553 return res 7554 """ 7555 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7556 f = mod["f"] 7557 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT) 7558 res = f([1, 2, 3]) 7559 self.assertEqual(res, [(0, 1), (1, 2), (2, 3)]) 7560 7561 def test_user_enumerate_list_nooverride(self): 7562 class mylist(list): 7563 pass 7564 7565 codestr = """ 7566 from __static__ import int64, box, clen 7567 7568 def f(x: list): 7569 i: int64 = 0 7570 res = [] 7571 while i < clen(x): 7572 elem = x[i] 7573 res.append((box(i), elem)) 7574 i += 1 7575 return res 7576 """ 7577 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7578 f = mod["f"] 7579 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT) 7580 res = f(mylist([1, 2, 3])) 7581 self.assertEqual(res, [(0, 1), (1, 2), (2, 3)]) 7582 7583 def test_user_enumerate_list_subclass(self): 7584 class mylist(list): 7585 def __getitem__(self, idx): 7586 return list.__getitem__(self, idx) + 1 7587 7588 codestr = """ 7589 from __static__ import int64, box, clen 7590 7591 def f(x: list): 7592 i: int64 = 0 7593 res = [] 7594 while i < clen(x): 7595 elem = x[i] 7596 res.append((box(i), elem)) 7597 i += 1 7598 return res 7599 """ 7600 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7601 f = mod["f"] 7602 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT) 7603 res = f(mylist([1, 2, 3])) 7604 self.assertEqual(res, [(0, 2), (1, 3), (2, 4)]) 7605 7606 def test_list_assign_subclass(self): 7607 class mylist(list): 7608 def __setitem__(self, idx, value): 7609 return 
list.__setitem__(self, idx, value + 1) 7610 7611 codestr = """ 7612 from __static__ import int64, box, clen 7613 7614 def f(x: list): 7615 i: int64 = 0 7616 x[i] = 42 7617 """ 7618 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7619 f = mod["f"] 7620 self.assertInBytecode(f, "SEQUENCE_SET", SEQ_LIST_INEXACT) 7621 l = mylist([0]) 7622 f(l) 7623 self.assertEqual(l[0], 43) 7624 7625 def test_inexact_list_negative(self): 7626 codestr = """ 7627 from __static__ import int64, box, clen 7628 7629 def f(x: list): 7630 i: int64 = 1 7631 return x[-i] 7632 """ 7633 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7634 f = mod["f"] 7635 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST_INEXACT) 7636 res = f([1, 2, 3]) 7637 self.assertEqual(res, 3) 7638 7639 def test_inexact_list_negative_small_int(self): 7640 codestr = """ 7641 from __static__ import int64, box, clen 7642 7643 def f(x: list): 7644 i: int8 = 1 7645 return x[-i] 7646 """ 7647 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7648 f = mod["f"] 7649 res = f([1, 2, 3]) 7650 self.assertEqual(res, 3) 7651 7652 def test_inexact_list_large_unsigned(self): 7653 codestr = """ 7654 from __static__ import uint64 7655 def f(x: list): 7656 i: uint64 = 0xffffffffffffffff 7657 return x[i] 7658 """ 7659 with self.assertRaisesRegex( 7660 TypedSyntaxError, "type mismatch: uint64 cannot be assigned to dynamic" 7661 ): 7662 self.compile(codestr) 7663 7664 def test_named_tuple(self): 7665 codestr = """ 7666 from typing import NamedTuple 7667 7668 class C(NamedTuple): 7669 x: int 7670 y: str 7671 7672 def myfunc(x: C): 7673 return x.x 7674 """ 7675 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7676 f = mod["myfunc"] 7677 self.assertNotInBytecode(f, "LOAD_FIELD") 7678 7679 def test_generic_type_error(self): 7680 codestr = """ 7681 from xxclassloader import spamobj 7682 7683 def testfunc(): 7684 x = spamobj[str]() 7685 x.setstate(42) 7686 """ 7687 7688 with 
self.assertRaisesRegex( 7689 TypedSyntaxError, 7690 "type mismatch: Exact\\[int\\] positional argument type mismatch str", 7691 ): 7692 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7693 7694 def test_generic_optional_type_param(self): 7695 codestr = """ 7696 from xxclassloader import spamobj 7697 7698 def testfunc(): 7699 x = spamobj[str]() 7700 x.setstateoptional(None) 7701 """ 7702 7703 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7704 7705 def test_generic_optional_type_param_2(self): 7706 codestr = """ 7707 from xxclassloader import spamobj 7708 7709 def testfunc(): 7710 x = spamobj[str]() 7711 x.setstateoptional('abc') 7712 """ 7713 7714 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7715 7716 def test_generic_optional_type_param_error(self): 7717 codestr = """ 7718 from xxclassloader import spamobj 7719 7720 def testfunc(): 7721 x = spamobj[str]() 7722 x.setstateoptional(42) 7723 """ 7724 7725 with self.assertRaisesRegex( 7726 TypedSyntaxError, 7727 "type mismatch: Exact\\[int\\] positional argument type mismatch Optional\\[str\\]", 7728 ): 7729 code = self.compile(codestr, StaticCodeGenerator, modname="foo") 7730 7731 def test_compile_nested_dict(self): 7732 codestr = """ 7733 from __static__ import CheckedDict 7734 7735 class B: pass 7736 class D(B): pass 7737 7738 def testfunc(): 7739 x = CheckedDict[B, int]({B():42, D():42}) 7740 y = CheckedDict[int, CheckedDict[B, int]]({42: x}) 7741 return y 7742 """ 7743 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7744 test = mod["testfunc"] 7745 B = mod["B"] 7746 self.assertEqual(type(test()), chkdict[int, chkdict[B, int]]) 7747 7748 def test_compile_dict_setdefault(self): 7749 codestr = """ 7750 from __static__ import CheckedDict 7751 def testfunc(): 7752 x = CheckedDict[int, str]({42: 'abc', }) 7753 x.setdefault(100, 43) 7754 """ 7755 with self.assertRaisesRegex( 7756 TypedSyntaxError, 7757 "type mismatch: Exact\\[int\\] positional 
argument type mismatch Optional\\[str\\]", 7758 ): 7759 self.compile(codestr, StaticCodeGenerator, modname="foo") 7760 7761 def test_compile_dict_get(self): 7762 codestr = """ 7763 from __static__ import CheckedDict 7764 def testfunc(): 7765 x = CheckedDict[int, str]({42: 'abc', }) 7766 x.get(42, 42) 7767 """ 7768 with self.assertRaisesRegex( 7769 TypedSyntaxError, 7770 "type mismatch: Exact\\[int\\] positional argument type mismatch Optional\\[str\\]", 7771 ): 7772 self.compile(codestr, StaticCodeGenerator, modname="foo") 7773 7774 codestr = """ 7775 from __static__ import CheckedDict 7776 7777 class B: pass 7778 class D(B): pass 7779 7780 def testfunc(): 7781 x = CheckedDict[B, int]({B():42, D():42}) 7782 return x 7783 """ 7784 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7785 test = mod["testfunc"] 7786 B = mod["B"] 7787 self.assertEqual(type(test()), chkdict[B, int]) 7788 7789 def test_compile_dict_setitem(self): 7790 codestr = """ 7791 from __static__ import CheckedDict 7792 7793 def testfunc(): 7794 x = CheckedDict[int, str]({1:'abc'}) 7795 x.__setitem__(2, 'def') 7796 return x 7797 """ 7798 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7799 test = mod["testfunc"] 7800 x = test() 7801 self.assertInBytecode( 7802 test, 7803 "INVOKE_FUNCTION", 7804 ( 7805 ( 7806 "__static__", 7807 "chkdict", 7808 (("builtins", "int"), ("builtins", "str")), 7809 "__setitem__", 7810 ), 7811 3, 7812 ), 7813 ) 7814 self.assertEqual(x, {1: "abc", 2: "def"}) 7815 7816 def test_compile_dict_setitem_subscr(self): 7817 codestr = """ 7818 from __static__ import CheckedDict 7819 7820 def testfunc(): 7821 x = CheckedDict[int, str]({1:'abc'}) 7822 x[2] = 'def' 7823 return x 7824 """ 7825 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7826 test = mod["testfunc"] 7827 x = test() 7828 self.assertInBytecode( 7829 test, 7830 "INVOKE_METHOD", 7831 ( 7832 ( 7833 "__static__", 7834 "chkdict", 7835 (("builtins", "int"), ("builtins", 
"str")), 7836 "__setitem__", 7837 ), 7838 2, 7839 ), 7840 ) 7841 self.assertEqual(x, {1: "abc", 2: "def"}) 7842 7843 def test_compile_generic_dict_getitem_bad_type(self): 7844 codestr = """ 7845 from __static__ import CheckedDict 7846 7847 def testfunc(): 7848 x = CheckedDict[str, int]({"abc": 42}) 7849 return x[42] 7850 """ 7851 with self.assertRaisesRegex( 7852 TypedSyntaxError, 7853 type_mismatch("Exact[int]", "str"), 7854 ): 7855 self.compile(codestr, StaticCodeGenerator, modname="foo") 7856 7857 def test_compile_generic_dict_setitem_bad_type(self): 7858 codestr = """ 7859 from __static__ import CheckedDict 7860 7861 def testfunc(): 7862 x = CheckedDict[str, int]({"abc": 42}) 7863 x[42] = 42 7864 """ 7865 with self.assertRaisesRegex( 7866 TypedSyntaxError, 7867 type_mismatch("Exact[int]", "str"), 7868 ): 7869 self.compile(codestr, StaticCodeGenerator, modname="foo") 7870 7871 def test_compile_generic_dict_setitem_bad_type_2(self): 7872 codestr = """ 7873 from __static__ import CheckedDict 7874 7875 def testfunc(): 7876 x = CheckedDict[str, int]({"abc": 42}) 7877 x["foo"] = "abc" 7878 """ 7879 with self.assertRaisesRegex( 7880 TypedSyntaxError, 7881 type_mismatch("Exact[str]", "int"), 7882 ): 7883 self.compile(codestr, StaticCodeGenerator, modname="foo") 7884 7885 def test_compile_checked_dict_shadowcode(self): 7886 codestr = """ 7887 from __static__ import CheckedDict 7888 7889 class B: pass 7890 class D(B): pass 7891 7892 def testfunc(): 7893 x = CheckedDict[B, int]({B():42, D():42}) 7894 return x 7895 """ 7896 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7897 test = mod["testfunc"] 7898 B = mod["B"] 7899 for i in range(200): 7900 self.assertEqual(type(test()), chkdict[B, int]) 7901 7902 def test_compile_checked_dict_optional(self): 7903 codestr = """ 7904 from __static__ import CheckedDict 7905 from typing import Optional 7906 7907 def testfunc(): 7908 x = CheckedDict[str, str | None]({ 7909 'x': None, 7910 'y': 'z' 7911 }) 7912 return x 
7913 """ 7914 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7915 f = mod["testfunc"] 7916 x = f() 7917 x["z"] = None 7918 self.assertEqual(type(x), chkdict[str, str | None]) 7919 7920 def test_compile_checked_dict_bad_annotation(self): 7921 codestr = """ 7922 from __static__ import CheckedDict 7923 7924 def testfunc(): 7925 x: 42 = CheckedDict[str, str]({'abc':'abc'}) 7926 return x 7927 """ 7928 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7929 test = mod["testfunc"] 7930 self.assertEqual(type(test()), chkdict[str, str]) 7931 7932 def test_compile_checked_dict_ann_differs(self): 7933 codestr = """ 7934 from __static__ import CheckedDict 7935 7936 def testfunc(): 7937 x: CheckedDict[int, int] = CheckedDict[str, str]({'abc':'abc'}) 7938 return x 7939 """ 7940 with self.assertRaisesRegex( 7941 TypedSyntaxError, 7942 type_mismatch( 7943 "Exact[chkdict[str, str]]", 7944 "chkdict[int, int]", 7945 ), 7946 ): 7947 self.compile(codestr, StaticCodeGenerator, modname="foo") 7948 7949 def test_compile_checked_dict_ann_differs_2(self): 7950 codestr = """ 7951 from __static__ import CheckedDict 7952 7953 def testfunc(): 7954 x: int = CheckedDict[str, str]({'abc':'abc'}) 7955 return x 7956 """ 7957 with self.assertRaisesRegex( 7958 TypedSyntaxError, 7959 type_mismatch("Exact[chkdict[str, str]]", "int"), 7960 ): 7961 self.compile(codestr, StaticCodeGenerator, modname="foo") 7962 7963 def test_compile_checked_dict_opt_out(self): 7964 codestr = """ 7965 from __static__.compiler_flags import nonchecked_dicts 7966 class B: pass 7967 class D(B): pass 7968 7969 def testfunc(): 7970 x = {B():42, D():42} 7971 return x 7972 """ 7973 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7974 test = mod["testfunc"] 7975 B = mod["B"] 7976 self.assertEqual(type(test()), dict) 7977 7978 def test_compile_checked_dict_explicit_dict(self): 7979 codestr = """ 7980 from __static__ import pydict 7981 class B: pass 7982 class D(B): pass 7983 7984 
def testfunc(): 7985 x: pydict = {B():42, D():42} 7986 return x 7987 """ 7988 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 7989 test = mod["testfunc"] 7990 self.assertEqual(type(test()), dict) 7991 7992 def test_compile_checked_dict_reversed(self): 7993 codestr = """ 7994 from __static__ import CheckedDict 7995 7996 class B: pass 7997 class D(B): pass 7998 7999 def testfunc(): 8000 x = CheckedDict[B, int]({D():42, B():42}) 8001 return x 8002 """ 8003 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8004 test = mod["testfunc"] 8005 B = mod["B"] 8006 self.assertEqual(type(test()), chkdict[B, int]) 8007 8008 def test_compile_checked_dict_type_specified(self): 8009 codestr = """ 8010 from __static__ import CheckedDict 8011 8012 class B: pass 8013 class D(B): pass 8014 8015 def testfunc(): 8016 x: CheckedDict[B, int] = CheckedDict[B, int]({D():42}) 8017 return x 8018 """ 8019 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8020 test = mod["testfunc"] 8021 B = mod["B"] 8022 self.assertEqual(type(test()), chkdict[B, int]) 8023 8024 def test_compile_checked_dict_with_annotation_comprehension(self): 8025 codestr = """ 8026 from __static__ import CheckedDict 8027 8028 def testfunc(): 8029 x: CheckedDict[int, object] = {int(i): object() for i in range(1, 5)} 8030 return x 8031 """ 8032 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8033 test = mod["testfunc"] 8034 self.assertEqual(type(test()), chkdict[int, object]) 8035 8036 def test_compile_checked_dict_with_annotation(self): 8037 codestr = """ 8038 from __static__ import CheckedDict 8039 8040 class B: pass 8041 8042 def testfunc(): 8043 x: CheckedDict[B, int] = {B():42} 8044 return x 8045 """ 8046 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8047 test = mod["testfunc"] 8048 B = mod["B"] 8049 self.assertEqual(type(test()), chkdict[B, int]) 8050 8051 def test_compile_checked_dict_with_annotation_wrong_value_type(self): 8052 
codestr = """ 8053 from __static__ import CheckedDict 8054 8055 class B: pass 8056 8057 def testfunc(): 8058 x: CheckedDict[B, int] = {B():'hi'} 8059 return x 8060 """ 8061 with self.assertRaisesRegex( 8062 TypedSyntaxError, 8063 type_mismatch( 8064 "Exact[chkdict[foo.B, Exact[str]]]", 8065 "chkdict[foo.B, int]", 8066 ), 8067 ): 8068 self.compile(codestr, modname="foo") 8069 8070 def test_compile_checked_dict_with_annotation_wrong_key_type(self): 8071 codestr = """ 8072 from __static__ import CheckedDict 8073 8074 class B: pass 8075 8076 def testfunc(): 8077 x: CheckedDict[B, int] = {object():42} 8078 return x 8079 """ 8080 with self.assertRaisesRegex( 8081 TypedSyntaxError, 8082 type_mismatch( 8083 "Exact[chkdict[object, Exact[int]]]", 8084 "chkdict[foo.B, int]", 8085 ), 8086 ): 8087 self.compile(codestr, modname="foo") 8088 8089 def test_compile_checked_dict_wrong_unknown_type(self): 8090 codestr = """ 8091 def f(x: int): 8092 return x 8093 8094 def testfunc(iter): 8095 return f({x:42 for x in iter}) 8096 8097 """ 8098 with self.assertRaisesRegex( 8099 TypedSyntaxError, "positional argument type mismatch" 8100 ): 8101 self.compile(codestr, StaticCodeGenerator, modname="foo") 8102 8103 def test_compile_checked_dict_opt_out_dict_call(self): 8104 codestr = """ 8105 from __static__.compiler_flags import nonchecked_dicts 8106 8107 def testfunc(): 8108 x = dict(x=42) 8109 return x 8110 """ 8111 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8112 test = mod["testfunc"] 8113 self.assertEqual(type(test()), dict) 8114 8115 def test_compile_checked_dict_explicit_dict_as_dict(self): 8116 codestr = """ 8117 from __static__ import pydict as dict 8118 class B: pass 8119 class D(B): pass 8120 8121 def testfunc(): 8122 x: dict = {B():42, D():42} 8123 return x 8124 """ 8125 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8126 test = mod["testfunc"] 8127 self.assertEqual(type(test()), dict) 8128 8129 def 
test_compile_checked_dict_from_dict_call(self): 8130 codestr = """ 8131 def testfunc(): 8132 x = dict(x=42) 8133 return x 8134 """ 8135 with self.assertRaisesRegex( 8136 TypeError, "cannot create 'dict\\[K, V\\]' instances" 8137 ): 8138 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8139 test = mod["testfunc"] 8140 test() 8141 8142 def test_compile_checked_dict_from_dict_call_2(self): 8143 codestr = """ 8144 def testfunc(): 8145 x = dict[str, int](x=42) 8146 return x 8147 """ 8148 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8149 test = mod["testfunc"] 8150 self.assertEqual(type(test()), chkdict[str, int]) 8151 8152 def test_compile_checked_dict_from_dict_call_3(self): 8153 # we emit the chkdict import first before future annotations, but that 8154 # should be fine as we're the compiler. 8155 codestr = """ 8156 from __future__ import annotations 8157 8158 def testfunc(): 8159 x = dict[str, int](x=42) 8160 return x 8161 """ 8162 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8163 test = mod["testfunc"] 8164 self.assertEqual(type(test()), chkdict[str, int]) 8165 8166 def test_patch_function(self): 8167 codestr = """ 8168 def f(): 8169 return 42 8170 8171 def g(): 8172 return f() 8173 """ 8174 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8175 g = mod["g"] 8176 for i in range(100): 8177 g() 8178 with patch(f"{mod['__name__']}.f", autospec=True, return_value=100) as p: 8179 self.assertEqual(g(), 100) 8180 8181 def test_patch_async_function(self): 8182 codestr = """ 8183 class C: 8184 async def f(self) -> int: 8185 return 42 8186 8187 def g(self): 8188 return self.f() 8189 """ 8190 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8191 C = mod["C"] 8192 c = C() 8193 for i in range(100): 8194 try: 8195 c.g().send(None) 8196 except StopIteration as e: 8197 self.assertEqual(e.args[0], 42) 8198 8199 with patch(f"{mod['__name__']}.C.f", autospec=True, return_value=100) as p: 
8200 try: 8201 c.g().send(None) 8202 except StopIteration as e: 8203 self.assertEqual(e.args[0], 100) 8204 8205 def test_patch_parentclass_slot(self): 8206 codestr = """ 8207 class A: 8208 def f(self) -> int: 8209 return 3 8210 8211 class B(A): 8212 pass 8213 8214 def a_f_invoker() -> int: 8215 return A().f() 8216 8217 def b_f_invoker() -> int: 8218 return B().f() 8219 """ 8220 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8221 A = mod["A"] 8222 a_f_invoker = mod["a_f_invoker"] 8223 b_f_invoker = mod["b_f_invoker"] 8224 setattr(A, "f", lambda _: 7) 8225 8226 self.assertEqual(a_f_invoker(), 7) 8227 self.assertEqual(b_f_invoker(), 7) 8228 8229 def test_self_patching_function(self): 8230 codestr = """ 8231 def x(d, d2=1): pass 8232 def removeit(d): 8233 global f 8234 f = x 8235 8236 def f(d): 8237 if d: 8238 removeit(d) 8239 return 42 8240 8241 def g(d): 8242 return f(d) 8243 """ 8244 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8245 g = mod["g"] 8246 f = mod["f"] 8247 import weakref 8248 8249 wr = weakref.ref(f, lambda *args: self.assertEqual(i, -1)) 8250 del f 8251 for i in range(100): 8252 g(False) 8253 i = -1 8254 self.assertEqual(g(True), 42) 8255 i = 0 8256 self.assertEqual(g(True), None) 8257 8258 def test_patch_function_unwatchable_dict(self): 8259 codestr = """ 8260 def f(): 8261 return 42 8262 8263 def g(): 8264 return f() 8265 """ 8266 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8267 g = mod["g"] 8268 for i in range(100): 8269 g() 8270 with patch( 8271 f"{mod['__name__']}.f", 8272 autospec=True, 8273 return_value=100, 8274 ) as p: 8275 mod[42] = 1 8276 self.assertEqual(g(), 100) 8277 8278 def test_patch_function_deleted_func(self): 8279 codestr = """ 8280 def f(): 8281 return 42 8282 8283 def g(): 8284 return f() 8285 """ 8286 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8287 g = mod["g"] 8288 for i in range(100): 8289 g() 8290 del mod["f"] 8291 with 
self.assertRaisesRegex( 8292 TypeError, 8293 re.escape(f"unknown function ('{mod['__name__']}', 'f')"), 8294 ): 8295 g() 8296 8297 def test_patch_static_function(self): 8298 codestr = """ 8299 class C: 8300 @staticmethod 8301 def f(): 8302 return 42 8303 8304 def g(): 8305 return C.f() 8306 """ 8307 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8308 g = mod["g"] 8309 for i in range(100): 8310 self.assertEqual(g(), 42) 8311 with patch(f"{mod['__name__']}.C.f", autospec=True, return_value=100) as p: 8312 self.assertEqual(g(), 100) 8313 8314 def test_patch_static_function_non_autospec(self): 8315 codestr = """ 8316 class C: 8317 @staticmethod 8318 def f(): 8319 return 42 8320 8321 def g(): 8322 return C.f() 8323 """ 8324 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8325 g = mod["g"] 8326 for i in range(100): 8327 g() 8328 with patch(f"{mod['__name__']}.C.f", return_value=100) as p: 8329 self.assertEqual(g(), 100) 8330 8331 def test_patch_primitive_ret_type(self): 8332 for type_name, value, patched in [ 8333 ("cbool", True, False), 8334 ("cbool", False, True), 8335 ("int8", 0, 1), 8336 ("int16", 0, 1), 8337 ("int32", 0, 1), 8338 ("int64", 0, 1), 8339 ("uint8", 0, 1), 8340 ("uint16", 0, 1), 8341 ("uint32", 0, 1), 8342 ("uint64", 0, 1), 8343 ]: 8344 with self.subTest(type_name=type, value=value, patched=patched): 8345 codestr = f""" 8346 from __static__ import {type_name}, box 8347 class C: 8348 def f(self) -> {type_name}: 8349 return {value!r} 8350 8351 def g(): 8352 return box(C().f()) 8353 """ 8354 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8355 g = mod["g"] 8356 for i in range(100): 8357 self.assertEqual(g(), value) 8358 with patch(f"{mod['__name__']}.C.f", return_value=patched) as p: 8359 self.assertEqual(g(), patched) 8360 8361 def test_patch_primitive_ret_type_overflow(self): 8362 codestr = f""" 8363 from __static__ import int8, box 8364 class C: 8365 def f(self) -> int8: 8366 return 1 8367 8368 
def g(): 8369 return box(C().f()) 8370 """ 8371 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8372 g = mod["g"] 8373 for i in range(100): 8374 self.assertEqual(g(), 1) 8375 with patch(f"{mod['__name__']}.C.f", return_value=256) as p: 8376 with self.assertRaisesRegex( 8377 OverflowError, 8378 "unexpected return type from C.f, expected " 8379 "int8, got out-of-range int \\(256\\)", 8380 ): 8381 g() 8382 8383 def test_invoke_frozen_type(self): 8384 codestr = """ 8385 from cinder import freeze_type 8386 8387 @freeze_type 8388 class C: 8389 @staticmethod 8390 def f(): 8391 return 42 8392 8393 def g(): 8394 return C.f() 8395 """ 8396 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8397 g = mod["g"] 8398 for i in range(100): 8399 self.assertEqual(g(), 42) 8400 8401 def test_invoke_strict_module(self): 8402 codestr = """ 8403 def f(): 8404 return 42 8405 8406 def g(): 8407 return f() 8408 """ 8409 with self.in_strict_module(codestr) as mod: 8410 g = mod.g 8411 for i in range(100): 8412 self.assertEqual(g(), 42) 8413 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0)) 8414 8415 def test_invoke_with_cell(self): 8416 codestr = """ 8417 def f(l: list): 8418 x = 2 8419 return [x + y for y in l] 8420 8421 def g(): 8422 return f([1,2,3]) 8423 """ 8424 with self.in_strict_module(codestr) as mod: 8425 g = mod.g 8426 self.assertEqual(g(), [3, 4, 5]) 8427 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 1)) 8428 8429 def test_invoke_with_cell_arg(self): 8430 codestr = """ 8431 def f(l: list, x: int): 8432 return [x + y for y in l] 8433 8434 def g(): 8435 return f([1,2,3], 2) 8436 """ 8437 with self.in_strict_module(codestr) as mod: 8438 g = mod.g 8439 self.assertEqual(g(), [3, 4, 5]) 8440 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 2)) 8441 8442 def test_invoke_all_reg_args(self): 8443 codestr = """ 8444 def target(a, b, c, d, e, f): 8445 return a * 2 + b * 3 + c * 4 + d * 5 + e * 6 + f * 
7 8446 8447 def testfunc(): 8448 return target(1,2,3,4,5,6) 8449 """ 8450 with self.in_strict_module(codestr) as mod: 8451 f = mod.testfunc 8452 self.assertInBytecode( 8453 f, 8454 "INVOKE_FUNCTION", 8455 ((mod.__name__, "target"), 6), 8456 ) 8457 self.assertEqual(f(), 112) 8458 8459 def test_invoke_all_extra_args(self): 8460 codestr = """ 8461 def target(a, b, c, d, e, f, g): 8462 return a * 2 + b * 3 + c * 4 + d * 5 + e * 6 + f * 7 + g 8463 8464 def testfunc(): 8465 return target(1,2,3,4,5,6,7) 8466 """ 8467 with self.in_strict_module(codestr) as mod: 8468 f = mod.testfunc 8469 self.assertInBytecode( 8470 f, 8471 "INVOKE_FUNCTION", 8472 ((mod.__name__, "target"), 7), 8473 ) 8474 self.assertEqual(f(), 119) 8475 8476 def test_invoke_strict_module_deep(self): 8477 codestr = """ 8478 def f0(): return 42 8479 def f1(): return f0() 8480 def f2(): return f1() 8481 def f3(): return f2() 8482 def f4(): return f3() 8483 def f5(): return f4() 8484 def f6(): return f5() 8485 def f7(): return f6() 8486 def f8(): return f7() 8487 def f9(): return f8() 8488 def f10(): return f9() 8489 def f11(): return f10() 8490 8491 def g(): 8492 return f11() 8493 """ 8494 with self.in_strict_module(codestr) as mod: 8495 g = mod.g 8496 self.assertEqual(g(), 42) 8497 self.assertEqual(g(), 42) 8498 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f11"), 0)) 8499 8500 def test_invoke_strict_module_deep_unjitable(self): 8501 codestr = """ 8502 def f0(): return 42 8503 def f1(): 8504 from sys import * 8505 return f0() 8506 def f2(): return f1() 8507 def f3(): return f2() 8508 def f4(): return f3() 8509 def f5(): return f4() 8510 def f6(): return f5() 8511 def f7(): return f6() 8512 def f8(): return f7() 8513 def f9(): return f8() 8514 def f10(): return f9() 8515 def f11(): return f10() 8516 8517 def g(x): 8518 if x: return 0 8519 8520 return f11() 8521 """ 8522 with self.in_strict_module(codestr) as mod: 8523 g = mod.g 8524 self.assertEqual(g(True), 0) 8525 # we should have done some 
level of pre-jitting 8526 self.assert_not_jitted(mod.f2) 8527 self.assert_not_jitted(mod.f1) 8528 self.assert_not_jitted(mod.f0) 8529 [self.assert_jitted(getattr(mod, f"f{i}")) for i in range(3, 12)] 8530 self.assertEqual(g(False), 42) 8531 self.assertInBytecode( 8532 g, 8533 "INVOKE_FUNCTION", 8534 ((mod.__name__, "f11"), 0), 8535 ) 8536 8537 def test_invoke_strict_module_deep_unjitable_many_args(self): 8538 codestr = """ 8539 def f0(): return 42 8540 def f1(a, b, c, d, e, f, g, h): 8541 from sys import * 8542 return f0() - a + b - c + d - e + f - g + h - 4 8543 8544 def f2(): return f1(1,2,3,4,5,6,7,8) 8545 def f3(): return f2() 8546 def f4(): return f3() 8547 def f5(): return f4() 8548 def f6(): return f5() 8549 def f7(): return f6() 8550 def f8(): return f7() 8551 def f9(): return f8() 8552 def f10(): return f9() 8553 def f11(): return f10() 8554 8555 def g(): 8556 return f11() 8557 """ 8558 with self.in_strict_module(codestr) as mod: 8559 g = mod.g 8560 f1 = mod.f1 8561 self.assertEqual(g(), 42) 8562 self.assertEqual(g(), 42) 8563 self.assertInBytecode( 8564 g, 8565 "INVOKE_FUNCTION", 8566 ((mod.__name__, "f11"), 0), 8567 ) 8568 self.assert_not_jitted(f1) 8569 8570 def test_invoke_strict_module_recursive(self): 8571 codestr = """ 8572 def fib(number): 8573 if number <= 1: 8574 return number 8575 return(fib(number-1) + fib(number-2)) 8576 """ 8577 with self.in_strict_module(codestr) as mod: 8578 fib = mod.fib 8579 self.assertInBytecode( 8580 fib, 8581 "INVOKE_FUNCTION", 8582 ((mod.__name__, "fib"), 1), 8583 ) 8584 self.assertEqual(fib(4), 3) 8585 8586 def test_invoke_strict_module_mutual_recursive(self): 8587 codestr = """ 8588 def fib1(number): 8589 if number <= 1: 8590 return number 8591 return(fib(number-1) + fib(number-2)) 8592 8593 def fib(number): 8594 if number <= 1: 8595 return number 8596 return(fib1(number-1) + fib1(number-2)) 8597 """ 8598 with self.in_strict_module(codestr) as mod: 8599 fib = mod.fib 8600 fib1 = mod.fib1 8601 self.assertInBytecode( 
8602 fib, 8603 "INVOKE_FUNCTION", 8604 ((mod.__name__, "fib1"), 1), 8605 ) 8606 self.assertInBytecode( 8607 fib1, 8608 "INVOKE_FUNCTION", 8609 ((mod.__name__, "fib"), 1), 8610 ) 8611 self.assertEqual(fib(0), 0) 8612 self.assert_jitted(fib1) 8613 self.assertEqual(fib(4), 3) 8614 8615 def test_invoke_strict_module_pre_invoked(self): 8616 codestr = """ 8617 def f(): 8618 return 42 8619 8620 def g(): 8621 return f() 8622 """ 8623 with self.in_strict_module(codestr) as mod: 8624 self.assertEqual(mod.f(), 42) 8625 self.assert_jitted(mod.f) 8626 g = mod.g 8627 self.assertEqual(g(), 42) 8628 self.assertInBytecode( 8629 g, 8630 "INVOKE_FUNCTION", 8631 ((mod.__name__, "f"), 0), 8632 ) 8633 8634 def test_invoke_strict_module_patching(self): 8635 codestr = """ 8636 def f(): 8637 return 42 8638 8639 def g(): 8640 return f() 8641 """ 8642 with self.in_strict_module(codestr, enable_patching=True) as mod: 8643 g = mod.g 8644 for i in range(100): 8645 self.assertEqual(g(), 42) 8646 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0)) 8647 mod.patch("f", lambda: 100) 8648 self.assertEqual(g(), 100) 8649 8650 def test_invoke_patch_non_vectorcall(self): 8651 codestr = """ 8652 def f(): 8653 return 42 8654 8655 def g(): 8656 return f() 8657 """ 8658 with self.in_strict_module(codestr, enable_patching=True) as mod: 8659 g = mod.g 8660 self.assertInBytecode(g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0)) 8661 self.assertEqual(g(), 42) 8662 mod.patch("f", Mock(return_value=100)) 8663 self.assertEqual(g(), 100) 8664 8665 def test_patch_method(self): 8666 codestr = """ 8667 class C: 8668 def f(self): 8669 pass 8670 8671 def g(): 8672 return C().f() 8673 """ 8674 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8675 g = mod["g"] 8676 C = mod["C"] 8677 orig = C.f 8678 C.f = lambda *args: args 8679 for i in range(100): 8680 v = g() 8681 self.assertEqual(type(v), tuple) 8682 self.assertEqual(type(v[0]), C) 8683 C.f = orig 8684 self.assertEqual(g(), None) 8685 
8686 def test_patch_method_ret_none_error(self): 8687 codestr = """ 8688 class C: 8689 def f(self) -> None: 8690 pass 8691 8692 def g(): 8693 return C().f() 8694 """ 8695 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8696 g = mod["g"] 8697 C = mod["C"] 8698 C.f = lambda *args: args 8699 with self.assertRaisesRegex( 8700 TypeError, 8701 "unexpected return type from C.f, expected NoneType, got tuple", 8702 ): 8703 v = g() 8704 8705 def test_patch_method_ret_none(self): 8706 codestr = """ 8707 class C: 8708 def f(self) -> None: 8709 pass 8710 8711 def g(): 8712 return C().f() 8713 """ 8714 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8715 g = mod["g"] 8716 C = mod["C"] 8717 C.f = lambda *args: None 8718 self.assertEqual(g(), None) 8719 8720 def test_patch_method_bad_ret(self): 8721 codestr = """ 8722 class C: 8723 def f(self) -> int: 8724 return 42 8725 8726 def g(): 8727 return C().f() 8728 """ 8729 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 8730 g = mod["g"] 8731 C = mod["C"] 8732 C.f = lambda *args: "abc" 8733 with self.assertRaisesRegex( 8734 TypeError, "unexpected return type from C.f, expected int, got str" 8735 ): 8736 v = g() 8737 8738 def test_primitive_args_funcdef(self): 8739 codestr = """ 8740 from __static__ import int8, box 8741 8742 def n(val: int8): 8743 return box(val) 8744 8745 def x(): 8746 y: int8 = 42 8747 return n(y) 8748 """ 8749 with self.in_strict_module(codestr) as mod: 8750 n = mod.n 8751 x = mod.x 8752 self.assertEqual(x(), 42) 8753 self.assertEqual(mod.n(-128), -128) 8754 self.assertEqual(mod.n(127), 127) 8755 with self.assertRaises(OverflowError): 8756 print(mod.n(-129)) 8757 with self.assertRaises(OverflowError): 8758 print(mod.n(128)) 8759 8760 def test_primitive_args_funcdef_unjitable(self): 8761 codestr = """ 8762 from __static__ import int8, box 8763 8764 def n(val: int8): 8765 try: 8766 from sys import * 8767 except: 8768 pass 8769 return box(val) 8770 8771 def x(): 
8772 y: int8 = 42 8773 return n(y) 8774 """ 8775 with self.in_strict_module(codestr) as mod: 8776 n = mod.n 8777 x = mod.x 8778 self.assertEqual(x(), 42) 8779 self.assertEqual(mod.n(-128), -128) 8780 self.assertEqual(mod.n(127), 127) 8781 with self.assertRaises(OverflowError): 8782 print(mod.n(-129)) 8783 with self.assertRaises(OverflowError): 8784 print(mod.n(128)) 8785 8786 def test_primitive_args_funcdef_too_many_args(self): 8787 codestr = """ 8788 from __static__ import int8, box 8789 8790 def n(x: int8): 8791 return box(x) 8792 """ 8793 with self.in_strict_module(codestr) as mod: 8794 n = mod.n 8795 with self.assertRaises(TypeError): 8796 print(mod.n(-128, x=2)) 8797 with self.assertRaises(TypeError): 8798 print(mod.n(-128, 2)) 8799 8800 def test_primitive_args_funcdef_missing_starargs(self): 8801 codestr = """ 8802 from __static__ import int8, box 8803 8804 def x(val: int8, *foo): 8805 return box(val), foo 8806 def y(val: int8, **foo): 8807 return box(val), foo 8808 """ 8809 with self.in_strict_module(codestr) as mod: 8810 self.assertEqual(mod.x(-128), (-128, ())) 8811 self.assertEqual(mod.y(-128), (-128, {})) 8812 8813 def test_primitive_args_many_args(self): 8814 codestr = """ 8815 from __static__ import int8, int16, int32, int64, uint8, uint16, uint32, uint64, box 8816 8817 def x(i8: int8, i16: int16, i32: int32, i64: int64, u8: uint8, u16: uint16, u32: uint32, u64: uint64): 8818 return box(i8), box(i16), box(i32), box(i64), box(u8), box(u16), box(u32), box(u64) 8819 8820 def y(): 8821 return x(1,2,3,4,5,6,7,8) 8822 """ 8823 with self.in_strict_module(codestr) as mod: 8824 self.assertInBytecode(mod.y, "INVOKE_FUNCTION", ((mod.__name__, "x"), 8)) 8825 self.assertEqual(mod.y(), (1, 2, 3, 4, 5, 6, 7, 8)) 8826 self.assertEqual(mod.x(1, 2, 3, 4, 5, 6, 7, 8), (1, 2, 3, 4, 5, 6, 7, 8)) 8827 8828 def test_primitive_args_sizes(self): 8829 cases = [ 8830 ("cbool", True, False), 8831 ("cbool", False, False), 8832 ("int8", (1 << 7), True), 8833 ("int8", (-1 << 7) - 1, 
True), 8834 ("int8", -1 << 7, False), 8835 ("int8", (1 << 7) - 1, False), 8836 ("int16", (1 << 15), True), 8837 ("int16", (-1 << 15) - 1, True), 8838 ("int16", -1 << 15, False), 8839 ("int16", (1 << 15) - 1, False), 8840 ("int32", (1 << 31), True), 8841 ("int32", (-1 << 31) - 1, True), 8842 ("int32", -1 << 31, False), 8843 ("int32", (1 << 31) - 1, False), 8844 ("int64", (1 << 63), True), 8845 ("int64", (-1 << 63) - 1, True), 8846 ("int64", -1 << 63, False), 8847 ("int64", (1 << 63) - 1, False), 8848 ("uint8", (1 << 8), True), 8849 ("uint8", -1, True), 8850 ("uint8", (1 << 8) - 1, False), 8851 ("uint8", 0, False), 8852 ("uint16", (1 << 16), True), 8853 ("uint16", -1, True), 8854 ("uint16", (1 << 16) - 1, False), 8855 ("uint16", 0, False), 8856 ("uint32", (1 << 32), True), 8857 ("uint32", -1, True), 8858 ("uint32", (1 << 32) - 1, False), 8859 ("uint32", 0, False), 8860 ("uint64", (1 << 64), True), 8861 ("uint64", -1, True), 8862 ("uint64", (1 << 64) - 1, False), 8863 ("uint64", 0, False), 8864 ] 8865 for type, val, overflows in cases: 8866 codestr = f""" 8867 from __static__ import {type}, box 8868 8869 def x(val: {type}): 8870 return box(val) 8871 """ 8872 with self.subTest(type=type, val=val, overflows=overflows): 8873 with self.in_strict_module(codestr) as mod: 8874 if overflows: 8875 with self.assertRaises(OverflowError): 8876 mod.x(val) 8877 else: 8878 self.assertEqual(mod.x(val), val) 8879 8880 def test_primitive_args_funcdef_missing_kw_call(self): 8881 codestr = """ 8882 from __static__ import int8, box 8883 8884 def testfunc(x: int8, foo): 8885 return box(x), foo 8886 """ 8887 with self.in_strict_module(codestr) as mod: 8888 self.assertEqual(mod.testfunc(-128, foo=42), (-128, 42)) 8889 8890 def test_primitive_args_funccall(self): 8891 codestr = """ 8892 from __static__ import int8 8893 8894 def f(foo): 8895 pass 8896 8897 def n() -> int: 8898 x: int8 = 3 8899 return f(x) 8900 """ 8901 with self.assertRaisesRegex( 8902 TypedSyntaxError, 8903 "type mismatch: 
int8 positional argument type mismatch dynamic", 8904 ): 8905 self.compile(codestr, StaticCodeGenerator, modname="foo.py") 8906 8907 def test_primitive_args_funccall_int(self): 8908 codestr = """ 8909 from __static__ import int8 8910 8911 def f(foo: int): 8912 pass 8913 8914 def n() -> int: 8915 x: int8 = 3 8916 return f(x) 8917 """ 8918 with self.assertRaisesRegex( 8919 TypedSyntaxError, 8920 "type mismatch: int8 positional argument type mismatch int", 8921 ): 8922 self.compile(codestr, StaticCodeGenerator, modname="foo.py") 8923 8924 def test_primitive_args_typecall(self): 8925 codestr = """ 8926 from __static__ import int8 8927 8928 def n() -> int: 8929 x: int8 = 3 8930 return int(x) 8931 """ 8932 with self.assertRaisesRegex( 8933 TypedSyntaxError, "Call argument cannot be a primitive" 8934 ): 8935 self.compile(codestr, StaticCodeGenerator, modname="foo.py") 8936 8937 def test_primitive_args_typecall_kwarg(self): 8938 codestr = """ 8939 from __static__ import int8 8940 8941 def n() -> int: 8942 x: int8 = 3 8943 return dict(a=x) 8944 """ 8945 with self.assertRaisesRegex( 8946 TypedSyntaxError, "Call argument cannot be a primitive" 8947 ): 8948 self.compile(codestr, StaticCodeGenerator, modname="foo.py") 8949 8950 def test_primitive_args_nonstrict(self): 8951 codestr = """ 8952 from __static__ import int8, int16, box 8953 8954 def f(x: int8, y: int16) -> int16: 8955 return x + y 8956 8957 def g() -> int: 8958 return box(f(1, 300)) 8959 """ 8960 with self.in_module(codestr) as mod: 8961 self.assertEqual(mod["g"](), 301) 8962 8963 def test_primitive_args_and_return(self): 8964 cases = [ 8965 ("cbool", 1), 8966 ("cbool", 0), 8967 ("int8", -1 << 7), 8968 ("int8", (1 << 7) - 1), 8969 ("int16", -1 << 15), 8970 ("int16", (1 << 15) - 1), 8971 ("int32", -1 << 31), 8972 ("int32", (1 << 31) - 1), 8973 ("int64", -1 << 63), 8974 ("int64", (1 << 63) - 1), 8975 ("uint8", (1 << 8) - 1), 8976 ("uint8", 0), 8977 ("uint16", (1 << 16) - 1), 8978 ("uint16", 0), 8979 ("uint32", (1 << 
32) - 1), 8980 ("uint32", 0), 8981 ("uint64", (1 << 64) - 1), 8982 ("uint64", 0), 8983 ] 8984 for typ, val in cases: 8985 if typ == "cbool": 8986 op = "or" 8987 expected = True 8988 other = "cbool(True)" 8989 boxed = "bool" 8990 else: 8991 op = "+" if val <= 0 else "-" 8992 expected = val + (1 if op == "+" else -1) 8993 other = "1" 8994 boxed = "int" 8995 with self.subTest(typ=typ, val=val, op=op, expected=expected): 8996 codestr = f""" 8997 from __static__ import {typ}, box 8998 8999 def f(x: {typ}, y: {typ}) -> {typ}: 9000 return x {op} y 9001 9002 def g() -> {boxed}: 9003 return box(f({val}, {other})) 9004 """ 9005 with self.in_strict_module(codestr) as mod: 9006 self.assertEqual(mod.g(), expected) 9007 9008 def test_primitive_return(self): 9009 cases = [ 9010 ("cbool", True), 9011 ("cbool", False), 9012 ("int8", -1 << 7), 9013 ("int8", (1 << 7) - 1), 9014 ("int16", -1 << 15), 9015 ("int16", (1 << 15) - 1), 9016 ("int32", -1 << 31), 9017 ("int32", (1 << 31) - 1), 9018 ("int64", -1 << 63), 9019 ("int64", (1 << 63) - 1), 9020 ("uint8", (1 << 8) - 1), 9021 ("uint8", 0), 9022 ("uint16", (1 << 16) - 1), 9023 ("uint16", 0), 9024 ("uint32", (1 << 32) - 1), 9025 ("uint32", 0), 9026 ("uint64", (1 << 64) - 1), 9027 ("uint64", 0), 9028 ] 9029 tf = [True, False] 9030 for (type, val), box, strict, error, unjitable in itertools.product( 9031 cases, [False], [True], [False], [True] 9032 ): 9033 if type == "cbool": 9034 op = "or" 9035 other = "False" 9036 boxed = "bool" 9037 else: 9038 op = "*" 9039 other = "1" 9040 boxed = "int" 9041 unjitable_code = "from sys import *" if unjitable else "" 9042 codestr = f""" 9043 from __static__ import {type}, box 9044 9045 def f(error: bool) -> {type}: 9046 {unjitable_code} 9047 if error: 9048 raise RuntimeError("boom") 9049 return {val} 9050 """ 9051 if box: 9052 codestr += f""" 9053 9054 def g() -> {boxed}: 9055 return box(f({error}) {op} {type}({other})) 9056 """ 9057 else: 9058 codestr += f""" 9059 9060 def g() -> {type}: 9061 return 
f({error}) {op} {type}({other}) 9062 """ 9063 ctx = self.in_strict_module if strict else self.in_module 9064 oparg = PRIM_NAME_TO_TYPE[type] 9065 with self.subTest( 9066 type=type, 9067 val=val, 9068 strict=strict, 9069 box=box, 9070 error=error, 9071 unjitable=unjitable, 9072 ): 9073 with ctx(codestr) as mod: 9074 f = mod.f if strict else mod["f"] 9075 g = mod.g if strict else mod["g"] 9076 self.assertInBytecode(f, "RETURN_INT", oparg) 9077 if box: 9078 self.assertNotInBytecode(g, "RETURN_INT") 9079 else: 9080 self.assertInBytecode(g, "RETURN_INT", oparg) 9081 if error: 9082 with self.assertRaisesRegex(RuntimeError, "boom"): 9083 g() 9084 else: 9085 self.assertEqual(g(), val) 9086 self.assert_jitted(g) 9087 if unjitable: 9088 self.assert_not_jitted(f) 9089 else: 9090 self.assert_jitted(f) 9091 9092 def test_primitive_return_recursive(self): 9093 codestr = """ 9094 from __static__ import int32 9095 9096 def fib(n: int32) -> int32: 9097 if n <= 1: 9098 return n 9099 return fib(n-1) + fib(n-2) 9100 """ 9101 with self.in_strict_module(codestr) as mod: 9102 self.assertInBytecode( 9103 mod.fib, 9104 "INVOKE_FUNCTION", 9105 ((mod.__name__, "fib"), 1), 9106 ) 9107 self.assertEqual(mod.fib(2), 1) 9108 self.assert_jitted(mod.fib) 9109 9110 def test_primitive_return_unannotated(self): 9111 codestr = """ 9112 from __static__ import int32 9113 9114 def f(): 9115 x: int32 = 1 9116 return x 9117 """ 9118 with self.assertRaisesRegex( 9119 TypedSyntaxError, "type mismatch: int32 cannot be assigned to dynamic" 9120 ): 9121 self.compile(codestr) 9122 9123 def test_module_level_final_decl(self): 9124 codestr = """ 9125 from typing import Final 9126 9127 x: Final 9128 """ 9129 with self.assertRaisesRegex( 9130 TypedSyntaxError, "Must assign a value when declaring a Final" 9131 ): 9132 self.compile(codestr, StaticCodeGenerator, modname="foo") 9133 9134 def test_int_compare_to_cbool(self): 9135 codestr = """ 9136 from __static__ import int64, cbool 9137 def foo(i: int64) -> cbool: 9138 
return i == 0 9139 """ 9140 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 9141 foo = mod["foo"] 9142 self.assertEqual(foo(0), True) 9143 self.assertEqual(foo(1), False) 9144 9145 def test_int_compare_to_cbool_reversed(self): 9146 codestr = """ 9147 from __static__ import int64, cbool 9148 def foo(i: int64) -> cbool: 9149 return 0 == i 9150 """ 9151 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 9152 foo = mod["foo"] 9153 self.assertEqual(foo(0), True) 9154 self.assertEqual(foo(1), False) 9155 9156 def test_inline_primitive(self): 9157 codestr = """ 9158 from __static__ import int64, cbool, inline 9159 9160 @inline 9161 def x(i: int64) -> cbool: 9162 return i == 1 9163 9164 def foo(i: int64) -> cbool: 9165 return i >0 and x(i) 9166 """ 9167 with self.in_module(codestr, optimize=2) as mod: 9168 foo = mod["foo"] 9169 self.assertEqual(foo(0), False) 9170 self.assertEqual(foo(1), True) 9171 self.assertEqual(foo(2), False) 9172 self.assertNotInBytecode(foo, "STORE_FAST") 9173 self.assertInBytecode(foo, "STORE_LOCAL") 9174 9175 def test_final_multiple_typeargs(self): 9176 codestr = """ 9177 from typing import Final 9178 from something import hello 9179 9180 x: Final[int, str] = hello() 9181 """ 9182 with self.assertRaisesRegex( 9183 TypedSyntaxError, "Final types can only have a single type arg" 9184 ): 9185 self.compile(codestr, StaticCodeGenerator, modname="foo") 9186 9187 def test_final_annotation_nesting(self): 9188 with self.assertRaisesRegex( 9189 TypedSyntaxError, "Final annotation is only valid in initial declaration" 9190 ): 9191 self.compile( 9192 """ 9193 from typing import Final, List 9194 9195 x: List[Final[str]] = [] 9196 """, 9197 StaticCodeGenerator, 9198 modname="foo", 9199 ) 9200 9201 with self.assertRaisesRegex( 9202 TypedSyntaxError, "Final annotation is only valid in initial declaration" 9203 ): 9204 self.compile( 9205 """ 9206 from typing import Final, List 9207 x: List[int | Final] = [] 9208 """, 9209 
StaticCodeGenerator, 9210 modname="foo", 9211 ) 9212 9213 def test_final(self): 9214 codestr = """ 9215 from typing import Final 9216 9217 x: Final = 0xdeadbeef 9218 """ 9219 self.compile(codestr, StaticCodeGenerator, modname="foo") 9220 9221 def test_final_generic(self): 9222 codestr = """ 9223 from typing import Final 9224 9225 x: Final[int] = 0xdeadbeef 9226 """ 9227 self.compile(codestr, StaticCodeGenerator, modname="foo") 9228 9229 def test_final_generic_types(self): 9230 codestr = """ 9231 from typing import Final 9232 9233 def g(i: int) -> int: 9234 return i 9235 9236 def f() -> int: 9237 x: Final[int] = 0xdeadbeef 9238 return g(x) 9239 """ 9240 self.compile(codestr, StaticCodeGenerator, modname="foo") 9241 9242 def test_final_uninitialized(self): 9243 codestr = """ 9244 from typing import Final 9245 9246 x: Final 9247 """ 9248 with self.assertRaisesRegex( 9249 TypedSyntaxError, "Must assign a value when declaring a Final" 9250 ): 9251 self.compile(codestr, StaticCodeGenerator, modname="foo") 9252 9253 def test_final_reassign(self): 9254 codestr = """ 9255 from typing import Final 9256 9257 x: Final = 0xdeadbeef 9258 x = "something" 9259 """ 9260 with self.assertRaisesRegex( 9261 TypedSyntaxError, "Cannot assign to a Final variable" 9262 ): 9263 self.compile(codestr, StaticCodeGenerator, modname="foo") 9264 9265 def test_final_reassign_explicit_global(self): 9266 codestr = """ 9267 from typing import Final 9268 9269 a: Final = 1337 9270 9271 def fn(): 9272 def fn2(): 9273 global a 9274 a = 0 9275 """ 9276 with self.assertRaisesRegex( 9277 TypedSyntaxError, "Cannot assign to a Final variable" 9278 ): 9279 self.compile(codestr, StaticCodeGenerator, modname="foo") 9280 9281 def test_final_reassign_explicit_global_shadowed(self): 9282 codestr = """ 9283 from typing import Final 9284 9285 a: Final = 1337 9286 9287 def fn(): 9288 a = 2 9289 def fn2(): 9290 global a 9291 a = 0 9292 """ 9293 with self.assertRaisesRegex( 9294 TypedSyntaxError, "Cannot assign to a 
Final variable" 9295 ): 9296 self.compile(codestr, StaticCodeGenerator, modname="foo") 9297 9298 def test_final_reassign_nonlocal(self): 9299 codestr = """ 9300 from typing import Final 9301 9302 a: Final = 1337 9303 9304 def fn(): 9305 def fn2(): 9306 nonlocal a 9307 a = 0 9308 """ 9309 with self.assertRaisesRegex( 9310 TypedSyntaxError, "Cannot assign to a Final variable" 9311 ): 9312 self.compile(codestr, StaticCodeGenerator, modname="foo") 9313 9314 def test_final_reassign_nonlocal_shadowed(self): 9315 codestr = """ 9316 from typing import Final 9317 9318 a: Final = 1337 9319 9320 def fn(): 9321 a = 3 9322 def fn2(): 9323 nonlocal a 9324 # should be allowed, we're assigning to the shadowed 9325 # value 9326 a = 0 9327 """ 9328 self.compile(codestr, StaticCodeGenerator, modname="foo") 9329 9330 def test_final_reassigned_in_tuple(self): 9331 codestr = """ 9332 from typing import Final 9333 9334 x: Final = 0xdeadbeef 9335 y = 3 9336 x, y = 4, 5 9337 """ 9338 with self.assertRaisesRegex( 9339 TypedSyntaxError, "Cannot assign to a Final variable" 9340 ): 9341 self.compile(codestr, StaticCodeGenerator, modname="foo") 9342 9343 def test_final_reassigned_in_loop(self): 9344 codestr = """ 9345 from typing import Final 9346 9347 x: Final = 0xdeadbeef 9348 9349 for x in [1, 3, 5]: 9350 pass 9351 """ 9352 with self.assertRaisesRegex( 9353 TypedSyntaxError, "Cannot assign to a Final variable" 9354 ): 9355 self.compile(codestr, StaticCodeGenerator, modname="foo") 9356 9357 def test_final_reassigned_in_except(self): 9358 codestr = """ 9359 from typing import Final 9360 9361 def f(): 9362 e: Final = 3 9363 try: 9364 x = 1 + "2" 9365 except Exception as e: 9366 pass 9367 """ 9368 with self.assertRaisesRegex( 9369 TypedSyntaxError, "Cannot assign to a Final variable" 9370 ): 9371 self.compile(codestr, StaticCodeGenerator, modname="foo") 9372 9373 def test_final_reassigned_in_loop_target_tuple(self): 9374 codestr = """ 9375 from typing import Final 9376 9377 x: Final = 0xdeadbeef 
9378 9379 for x, y in [(1, 2)]: 9380 pass 9381 """ 9382 with self.assertRaisesRegex( 9383 TypedSyntaxError, "Cannot assign to a Final variable" 9384 ): 9385 self.compile(codestr, StaticCodeGenerator, modname="foo") 9386 9387 def test_final_reassigned_in_ctxmgr(self): 9388 codestr = """ 9389 from typing import Final 9390 9391 x: Final = 0xdeadbeef 9392 9393 with open("lol") as x: 9394 pass 9395 """ 9396 with self.assertRaisesRegex( 9397 TypedSyntaxError, "Cannot assign to a Final variable" 9398 ): 9399 self.compile(codestr, StaticCodeGenerator, modname="foo") 9400 9401 def test_final_generic_reassign(self): 9402 codestr = """ 9403 from typing import Final 9404 9405 x: Final[int] = 0xdeadbeef 9406 x = 0x5ca1ab1e 9407 """ 9408 with self.assertRaisesRegex( 9409 TypedSyntaxError, "Cannot assign to a Final variable" 9410 ): 9411 self.compile(codestr, StaticCodeGenerator, modname="foo") 9412 9413 def test_class_level_final_decl(self): 9414 codestr = """ 9415 from typing import Final 9416 9417 class C: 9418 x: Final[int] 9419 9420 def __init__(self): 9421 self.x = 3 9422 """ 9423 self.compile(codestr, StaticCodeGenerator, modname="foo") 9424 9425 def test_class_level_final_decl_uninitialized(self): 9426 codestr = """ 9427 from typing import Final 9428 9429 class C: 9430 x: Final 9431 """ 9432 with self.assertRaisesRegex( 9433 TypedSyntaxError, 9434 re.escape("Final attribute not initialized: foo.C:x"), 9435 ): 9436 self.compile(codestr, StaticCodeGenerator, modname="foo") 9437 9438 def test_class_level_final_reinitialized(self): 9439 codestr = """ 9440 from typing import Final 9441 9442 class C: 9443 x: Final = 3 9444 x = 4 9445 """ 9446 with self.assertRaisesRegex( 9447 TypedSyntaxError, "Cannot assign to a Final variable" 9448 ): 9449 self.compile(codestr, StaticCodeGenerator, modname="foo") 9450 9451 def test_class_level_final_reinitialized_directly(self): 9452 codestr = """ 9453 from typing import Final 9454 9455 class C: 9456 x: Final = 3 9457 9458 C.x = 4 9459 """ 
9460 # Note - this will raise even without the Final, we don't allow assignments to slots 9461 with self.assertRaisesRegex( 9462 TypedSyntaxError, 9463 type_mismatch("Exact[int]", "types.MemberDescriptorType"), 9464 ): 9465 self.compile(codestr, StaticCodeGenerator, modname="foo") 9466 9467 def test_class_level_final_reinitialized_in_instance(self): 9468 codestr = """ 9469 from typing import Final 9470 9471 class C: 9472 x: Final = 3 9473 9474 C().x = 4 9475 """ 9476 with self.assertRaisesRegex( 9477 TypedSyntaxError, 9478 "Cannot assign to a Final attribute of foo.C:x", 9479 ): 9480 self.compile(codestr, StaticCodeGenerator, modname="foo") 9481 9482 def test_class_level_final_reinitialized_in_method(self): 9483 codestr = """ 9484 from typing import Final 9485 9486 class C: 9487 x: Final = 3 9488 9489 def something(self) -> None: 9490 self.x = 4 9491 """ 9492 with self.assertRaisesRegex( 9493 TypedSyntaxError, "Cannot assign to a Final attribute of foo.C:x" 9494 ): 9495 self.compile(codestr, StaticCodeGenerator, modname="foo") 9496 9497 def test_class_level_final_reinitialized_in_subclass_without_annotation(self): 9498 codestr = """ 9499 from typing import Final 9500 9501 class C: 9502 x: Final = 3 9503 9504 class D(C): 9505 x = 4 9506 """ 9507 with self.assertRaisesRegex( 9508 TypedSyntaxError, 9509 "Cannot assign to a Final attribute of foo.D:x", 9510 ): 9511 self.compile(codestr, StaticCodeGenerator, modname="foo") 9512 9513 def test_class_level_final_reinitialized_in_subclass_with_annotation(self): 9514 codestr = """ 9515 from typing import Final 9516 9517 class C: 9518 x: Final = 3 9519 9520 class D(C): 9521 x: Final[int] = 4 9522 """ 9523 with self.assertRaisesRegex( 9524 TypedSyntaxError, 9525 "Cannot assign to a Final attribute of foo.D:x", 9526 ): 9527 self.compile(codestr, StaticCodeGenerator, modname="foo") 9528 9529 def test_class_level_final_reinitialized_in_subclass_init(self): 9530 codestr = """ 9531 from typing import Final 9532 9533 class C: 9534 
x: Final = 3 9535 9536 class D(C): 9537 def __init__(self): 9538 self.x = 4 9539 """ 9540 with self.assertRaisesRegex( 9541 TypedSyntaxError, 9542 "Cannot assign to a Final attribute of foo.D:x", 9543 ): 9544 self.compile(codestr, StaticCodeGenerator, modname="foo") 9545 9546 def test_class_level_final_reinitialized_in_subclass_init_with_annotation(self): 9547 codestr = """ 9548 from typing import Final 9549 9550 class C: 9551 x: Final[int] = 3 9552 9553 class D(C): 9554 def __init__(self): 9555 self.x: Final[int] = 4 9556 """ 9557 with self.assertRaisesRegex( 9558 TypedSyntaxError, 9559 "Cannot assign to a Final attribute of foo.D:x", 9560 ): 9561 self.compile(codestr, StaticCodeGenerator, modname="foo") 9562 9563 def test_class_level_final_decl_in_init_reinitialized_in_subclass_init(self): 9564 codestr = """ 9565 from typing import Final 9566 9567 class C: 9568 x: Final[int] 9569 9570 def __init__(self) -> None: 9571 self.x = 3 9572 9573 class D(C): 9574 def __init__(self) -> None: 9575 self.x = 4 9576 """ 9577 with self.assertRaisesRegex( 9578 TypedSyntaxError, 9579 "Cannot assign to a Final attribute of foo.D:x", 9580 ): 9581 self.compile(codestr, StaticCodeGenerator, modname="foo") 9582 9583 def test_final_in_args(self): 9584 codestr = """ 9585 from typing import Final 9586 9587 def f(a: Final) -> None: 9588 pass 9589 """ 9590 with self.assertRaisesRegex( 9591 TypedSyntaxError, 9592 "Final annotation is only valid in initial declaration", 9593 ): 9594 self.compile(codestr, StaticCodeGenerator, modname="foo") 9595 9596 def test_final_returns(self): 9597 codestr = """ 9598 from typing import Final 9599 9600 def f() -> Final[int]: 9601 return 1 9602 """ 9603 with self.assertRaisesRegex( 9604 TypedSyntaxError, 9605 "Final annotation is only valid in initial declaration", 9606 ): 9607 self.compile(codestr, StaticCodeGenerator, modname="foo") 9608 9609 def test_final_decorator(self): 9610 codestr = """ 9611 from typing import final 9612 9613 class C: 9614 @final 
9615 def f(): 9616 pass 9617 """ 9618 self.compile(codestr, StaticCodeGenerator, modname="foo") 9619 9620 def test_final_decorator_override(self): 9621 codestr = """ 9622 from typing import final 9623 9624 class C: 9625 @final 9626 def f(): 9627 pass 9628 9629 class D(C): 9630 def f(): 9631 pass 9632 """ 9633 with self.assertRaisesRegex( 9634 TypedSyntaxError, "Cannot assign to a Final attribute of foo.D:f" 9635 ): 9636 self.compile(codestr, StaticCodeGenerator, modname="foo") 9637 9638 def test_final_decorator_override_with_assignment(self): 9639 codestr = """ 9640 from typing import final 9641 9642 class C: 9643 @final 9644 def f(): 9645 pass 9646 9647 class D(C): 9648 f = print 9649 """ 9650 with self.assertRaisesRegex( 9651 TypedSyntaxError, "Cannot assign to a Final attribute of foo.D:f" 9652 ): 9653 self.compile(codestr, StaticCodeGenerator, modname="foo") 9654 9655 def test_final_decorator_override_transitivity(self): 9656 codestr = """ 9657 from typing import final 9658 9659 class C: 9660 @final 9661 def f(): 9662 pass 9663 9664 class D(C): 9665 pass 9666 9667 class E(D): 9668 def f(): 9669 pass 9670 """ 9671 with self.assertRaisesRegex( 9672 TypedSyntaxError, "Cannot assign to a Final attribute of foo.E:f" 9673 ): 9674 self.compile(codestr, StaticCodeGenerator, modname="foo") 9675 9676 def test_final_decorator_class(self): 9677 codestr = """ 9678 from typing import final 9679 9680 @final 9681 class C: 9682 def f(self): 9683 pass 9684 9685 def f(): 9686 return C().f() 9687 """ 9688 c = self.compile(codestr, StaticCodeGenerator, modname="foo") 9689 f = self.find_code(c, "f") 9690 self.assertInBytecode(f, "INVOKE_FUNCTION") 9691 9692 def test_final_decorator_class_inheritance(self): 9693 codestr = """ 9694 from typing import final 9695 9696 @final 9697 class C: 9698 pass 9699 9700 class D(C): 9701 pass 9702 """ 9703 with self.assertRaisesRegex( 9704 TypedSyntaxError, "Class `foo.D` cannot subclass a Final class: `foo.C`" 9705 ): 9706 self.compile(codestr, 
StaticCodeGenerator, modname="foo") 9707 9708 def test_final_decorator_class_dynamic(self): 9709 """We should never mark DYNAMIC_TYPE as final.""" 9710 codestr = """ 9711 from typing import final, Generic, NamedTuple 9712 9713 @final 9714 class NT(NamedTuple): 9715 x: int 9716 9717 class C(Generic): 9718 pass 9719 """ 9720 # No TypedSyntaxError "cannot inherit from Final class 'dynamic'" 9721 self.compile(codestr) 9722 9723 def test_slotification_decorated(self): 9724 codestr = """ 9725 class _Inner(): 9726 pass 9727 9728 def something(klass): 9729 return _Inner 9730 9731 @something 9732 class C: 9733 def f(self): 9734 pass 9735 9736 def f(): 9737 return C().f() 9738 """ 9739 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 9740 f = mod["f"] 9741 self.assertNotInBytecode(f, "INVOKE_FUNCTION") 9742 self.assertNotInBytecode(f, "INVOKE_METHOD") 9743 9744 def test_inline_func(self): 9745 codestr = """ 9746 from __static__ import inline 9747 9748 @inline 9749 def f(x, y): 9750 return x + y 9751 9752 def g(): 9753 return f(1,2) 9754 """ 9755 # we only inline at opt level 2 to avoid test patching problems 9756 # TODO longer term we might need something better here (e.g. 
emit both 9757 # inlined code and call and a guard to choose); assuming 9758 # non-patchability at opt 2 works for IG but isn't generally valid 9759 for opt in [0, 1, 2]: 9760 with self.subTest(opt=opt): 9761 with self.in_module(codestr, optimize=opt) as mod: 9762 g = mod["g"] 9763 if opt == 2: 9764 self.assertInBytecode(g, "LOAD_CONST", 3) 9765 else: 9766 self.assertInBytecode( 9767 g, "INVOKE_FUNCTION", ((mod["__name__"], "f"), 2) 9768 ) 9769 self.assertEqual(g(), 3) 9770 9771 def test_inline_kwarg(self): 9772 codestr = """ 9773 from __static__ import inline 9774 9775 @inline 9776 def f(x, y): 9777 return x + y 9778 9779 def g(): 9780 return f(x=1,y=2) 9781 """ 9782 with self.in_module(codestr, optimize=2) as mod: 9783 g = mod["g"] 9784 self.assertInBytecode(g, "LOAD_CONST", 3) 9785 self.assertEqual(g(), 3) 9786 9787 def test_inline_bare_return(self): 9788 codestr = """ 9789 from __static__ import inline 9790 9791 @inline 9792 def f(x, y): 9793 return 9794 9795 def g(): 9796 return f(x=1,y=2) 9797 """ 9798 with self.in_module(codestr, optimize=2) as mod: 9799 g = mod["g"] 9800 self.assertInBytecode(g, "LOAD_CONST", None) 9801 self.assertEqual(g(), None) 9802 9803 def test_inline_final(self): 9804 codestr = """ 9805 from __static__ import inline 9806 from typing import Final 9807 9808 Y: Final[int] = 42 9809 @inline 9810 def f(x): 9811 return x + Y 9812 9813 def g(): 9814 return f(1) 9815 """ 9816 with self.in_module(codestr, optimize=2) as mod: 9817 g = mod["g"] 9818 # We don't currently inline math with finals 9819 self.assertInBytecode(g, "LOAD_CONST", 42) 9820 self.assertEqual(g(), 43) 9821 9822 def test_inline_nested(self): 9823 codestr = """ 9824 from __static__ import inline 9825 9826 @inline 9827 def e(x, y): 9828 return x + y 9829 9830 @inline 9831 def f(x, y): 9832 return e(x, 3) 9833 9834 def g(): 9835 return f(1,2) 9836 """ 9837 with self.in_module(codestr, optimize=2) as mod: 9838 g = mod["g"] 9839 self.assertInBytecode(g, "LOAD_CONST", 4) 9840 
self.assertEqual(g(), 4) 9841 9842 def test_inline_nested_arg(self): 9843 codestr = """ 9844 from __static__ import inline 9845 9846 @inline 9847 def e(x, y): 9848 return x + y 9849 9850 @inline 9851 def f(x, y): 9852 return e(x, 3) 9853 9854 def g(a,b): 9855 return f(a,b) 9856 """ 9857 with self.in_module(codestr, optimize=2) as mod: 9858 g = mod["g"] 9859 self.assertInBytecode(g, "LOAD_CONST", 3) 9860 self.assertInBytecode(g, "BINARY_ADD") 9861 self.assertEqual(g(1, 2), 4) 9862 9863 def test_inline_recursive(self): 9864 codestr = """ 9865 from __static__ import inline 9866 9867 @inline 9868 def f(x, y): 9869 return f(x, y) 9870 9871 def g(): 9872 return f(1,2) 9873 """ 9874 with self.in_module(codestr, optimize=2) as mod: 9875 g = mod["g"] 9876 self.assertInBytecode(g, "INVOKE_FUNCTION", (((mod["__name__"], "f"), 2))) 9877 9878 def test_inline_func_default(self): 9879 codestr = """ 9880 from __static__ import inline 9881 9882 @inline 9883 def f(x, y = 2): 9884 return x + y 9885 9886 def g(): 9887 return f(1) 9888 """ 9889 with self.in_module(codestr, optimize=2) as mod: 9890 g = mod["g"] 9891 self.assertInBytecode(g, "LOAD_CONST", 3) 9892 9893 self.assertEqual(g(), 3) 9894 9895 def test_inline_arg_type(self): 9896 codestr = """ 9897 from __static__ import box, inline, int64, int32 9898 9899 @inline 9900 def f(x: int64) -> int: 9901 return box(x) 9902 9903 def g(arg: int) -> int: 9904 return f(int64(arg)) 9905 """ 9906 with self.in_module(codestr, optimize=2) as mod: 9907 g = mod["g"] 9908 self.assertInBytecode(g, "PRIMITIVE_BOX") 9909 self.assertEqual(g(3), 3) 9910 9911 def test_augassign_primitive_int(self): 9912 codestr = """ 9913 from __static__ import int8, box, unbox 9914 9915 def a(i: int) -> int: 9916 j: int8 = unbox(i) 9917 j += 2 9918 return box(j) 9919 """ 9920 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 9921 a = mod["a"] 9922 self.assertInBytecode(a, "PRIMITIVE_BINARY_OP", 0) 9923 self.assertEqual(a(3), 5) 9924 9925 def 
test_primitive_compare_immediate_no_branch_on_result(self): 9926 for rev in [True, False]: 9927 compare = "0 == xp" if rev else "xp == 0" 9928 codestr = f""" 9929 from __static__ import box, int64, int32 9930 9931 def f(x: int) -> bool: 9932 xp = int64(x) 9933 y = {compare} 9934 return box(y) 9935 """ 9936 with self.subTest(rev=rev): 9937 with self.in_module(codestr) as mod: 9938 f = mod["f"] 9939 self.assertEqual(f(3), 0) 9940 self.assertEqual(f(0), 1) 9941 self.assertIs(f(0), True) 9942 9943 def test_type_type_final(self): 9944 codestr = """ 9945 class A(type): 9946 pass 9947 """ 9948 self.compile(codestr) 9949 9950 9951class StaticRuntimeTests(StaticTestBase): 9952 def test_bad_slots_qualname_conflict(self): 9953 with self.assertRaises(ValueError): 9954 9955 class C: 9956 __slots__ = ("x",) 9957 __slot_types__ = {"x": ("__static__", "int32")} 9958 x = 42 9959 9960 def test_typed_slots_bad_inst(self): 9961 class C: 9962 __slots__ = ("a",) 9963 __slot_types__ = {"a": ("__static__", "int32")} 9964 9965 class D: 9966 pass 9967 9968 with self.assertRaises(TypeError): 9969 C.a.__get__(D(), D) 9970 9971 def test_typed_slots_bad_slots(self): 9972 with self.assertRaises(TypeError): 9973 9974 class C: 9975 __slots__ = ("a",) 9976 __slot_types__ = None 9977 9978 def test_typed_slots_bad_slot_dict(self): 9979 with self.assertRaises(TypeError): 9980 9981 class C: 9982 __slots__ = ("__dict__",) 9983 __slot_types__ = {"__dict__": "object"} 9984 9985 def test_typed_slots_bad_slot_weakerf(self): 9986 with self.assertRaises(TypeError): 9987 9988 class C: 9989 __slots__ = ("__weakref__",) 9990 __slot_types__ = {"__weakref__": "object"} 9991 9992 def test_typed_slots_object(self): 9993 codestr = """ 9994 class C: 9995 __slots__ = ('a', ) 9996 __slot_types__ = {'a': (__name__, 'C')} 9997 9998 inst = C() 9999 """ 10000 10001 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10002 inst, C = mod["inst"], mod["C"] 10003 self.assertEqual(C.a.__class__.__name__, 
"typed_descriptor") 10004 with self.assertRaises(TypeError): 10005 # type is checked 10006 inst.a = 42 10007 with self.assertRaises(TypeError): 10008 inst.a = None 10009 with self.assertRaises(AttributeError): 10010 # is initially unassigned 10011 inst.a 10012 10013 # can assign correct type 10014 inst.a = inst 10015 10016 # __sizeof__ doesn't include GC header size 10017 self.assertEqual(inst.__sizeof__(), self.base_size + self.ptr_size) 10018 # size is +2 words for GC header, one word for reference 10019 self.assertEqual(sys.getsizeof(inst), self.base_size + (self.ptr_size * 3)) 10020 10021 # subclasses are okay 10022 class D(C): 10023 pass 10024 10025 inst.a = D() 10026 10027 def test_allow_weakrefs(self): 10028 codestr = """ 10029 from __static__ import allow_weakrefs 10030 import weakref 10031 10032 @allow_weakrefs 10033 class C: 10034 pass 10035 10036 def f(c: C): 10037 return weakref.ref(c) 10038 """ 10039 with self.in_module(codestr) as mod: 10040 C = mod["C"] 10041 c = C() 10042 ref = mod["f"](c) 10043 self.assertIs(ref(), c) 10044 del c 10045 self.assertIs(ref(), None) 10046 self.assertEqual(C.__slots__, ("__weakref__",)) 10047 10048 def test_dynamic_return(self): 10049 codestr = """ 10050 from __future__ import annotations 10051 from __static__ import allow_weakrefs, dynamic_return 10052 import weakref 10053 10054 singletons = [] 10055 10056 @allow_weakrefs 10057 class C: 10058 @dynamic_return 10059 @staticmethod 10060 def make() -> C: 10061 return weakref.proxy(singletons[0]) 10062 10063 def g(self) -> int: 10064 return 1 10065 10066 singletons.append(C()) 10067 10068 def f() -> int: 10069 c = C.make() 10070 return c.g() 10071 """ 10072 with self.in_strict_module(codestr) as mod: 10073 # We don't try to cast the return type of make 10074 self.assertNotInBytecode(mod.C.make, "CAST") 10075 # We can statically invoke make 10076 self.assertInBytecode( 10077 mod.f, "INVOKE_FUNCTION", ((mod.__name__, "C", "make"), 0) 10078 ) 10079 # But we can't statically 
invoke a method on the returned instance 10080 self.assertNotInBytecode(mod.f, "INVOKE_METHOD") 10081 self.assertEqual(mod.f(), 1) 10082 # We don't mess with __annotations__ 10083 self.assertEqual(mod.C.make.__annotations__, {"return": "C"}) 10084 10085 def test_generic_type_def_no_create(self): 10086 from xxclassloader import spamobj 10087 10088 with self.assertRaises(TypeError): 10089 spamobj() 10090 10091 def test_generic_type_def_bad_args(self): 10092 from xxclassloader import spamobj 10093 10094 with self.assertRaises(TypeError): 10095 spamobj[str, int] 10096 10097 def test_generic_type_def_non_type(self): 10098 from xxclassloader import spamobj 10099 10100 with self.assertRaises(TypeError): 10101 spamobj[42] 10102 10103 def test_generic_type_inst_okay(self): 10104 from xxclassloader import spamobj 10105 10106 o = spamobj[str]() 10107 o.setstate("abc") 10108 10109 def test_generic_type_inst_optional_okay(self): 10110 from xxclassloader import spamobj 10111 10112 o = spamobj[Optional[str]]() 10113 o.setstate("abc") 10114 o.setstate(None) 10115 10116 def test_generic_type_inst_non_optional_error(self): 10117 from xxclassloader import spamobj 10118 10119 o = spamobj[str]() 10120 with self.assertRaises(TypeError): 10121 o.setstate(None) 10122 10123 def test_generic_type_inst_bad_type(self): 10124 from xxclassloader import spamobj 10125 10126 o = spamobj[str]() 10127 with self.assertRaises(TypeError): 10128 o.setstate(42) 10129 10130 def test_generic_type_inst_name(self): 10131 from xxclassloader import spamobj 10132 10133 self.assertEqual(spamobj[str].__name__, "spamobj[str]") 10134 10135 def test_generic_type_inst_name_optional(self): 10136 from xxclassloader import spamobj 10137 10138 self.assertEqual(spamobj[Optional[str]].__name__, "spamobj[Optional[str]]") 10139 10140 def test_generic_type_inst_okay_func(self): 10141 from xxclassloader import spamobj 10142 10143 o = spamobj[str]() 10144 f = o.setstate 10145 f("abc") 10146 10147 def 
test_generic_type_inst_optional_okay_func(self): 10148 from xxclassloader import spamobj 10149 10150 o = spamobj[Optional[str]]() 10151 f = o.setstate 10152 f("abc") 10153 f(None) 10154 10155 def test_generic_type_inst_non_optional_error_func(self): 10156 from xxclassloader import spamobj 10157 10158 o = spamobj[str]() 10159 f = o.setstate 10160 with self.assertRaises(TypeError): 10161 f(None) 10162 10163 def test_generic_type_inst_bad_type_func(self): 10164 from xxclassloader import spamobj 10165 10166 o = spamobj[str]() 10167 f = o.setstate 10168 with self.assertRaises(TypeError): 10169 f(42) 10170 10171 def test_generic_int_funcs(self): 10172 from xxclassloader import spamobj 10173 10174 o = spamobj[str]() 10175 o.setint(42) 10176 self.assertEqual(o.getint8(), 42) 10177 self.assertEqual(o.getint16(), 42) 10178 self.assertEqual(o.getint32(), 42) 10179 10180 def test_generic_uint_funcs(self): 10181 from xxclassloader import spamobj 10182 10183 o = spamobj[str]() 10184 o.setuint64(42) 10185 self.assertEqual(o.getuint8(), 42) 10186 self.assertEqual(o.getuint16(), 42) 10187 self.assertEqual(o.getuint32(), 42) 10188 self.assertEqual(o.getuint64(), 42) 10189 10190 def test_generic_int_funcs_overflow(self): 10191 from xxclassloader import spamobj 10192 10193 o = spamobj[str]() 10194 o.setuint64(42) 10195 for i, f in enumerate([o.setint8, o.setint16, o.setint32, o.setint]): 10196 with self.assertRaises(OverflowError): 10197 x = -(1 << ((8 << i) - 1)) - 1 10198 f(x) 10199 with self.assertRaises(OverflowError): 10200 x = 1 << ((8 << i) - 1) 10201 f(x) 10202 10203 def test_generic_uint_funcs_overflow(self): 10204 from xxclassloader import spamobj 10205 10206 o = spamobj[str]() 10207 o.setuint64(42) 10208 for f in [o.setuint8, o.setuint16, o.setuint32, o.setuint64]: 10209 with self.assertRaises(OverflowError): 10210 f(-1) 10211 for i, f in enumerate([o.setuint8, o.setuint16, o.setuint32, o.setuint64]): 10212 with self.assertRaises(OverflowError): 10213 x = (1 << (8 << i)) + 
1 10214 f(x) 10215 10216 def test_generic_type_int_func(self): 10217 from xxclassloader import spamobj 10218 10219 o = spamobj[str]() 10220 o.setint(42) 10221 self.assertEqual(o.getint(), 42) 10222 with self.assertRaises(TypeError): 10223 o.setint("abc") 10224 10225 def test_generic_type_str_func(self): 10226 from xxclassloader import spamobj 10227 10228 o = spamobj[str]() 10229 o.setstr("abc") 10230 self.assertEqual(o.getstr(), "abc") 10231 with self.assertRaises(TypeError): 10232 o.setstr(42) 10233 10234 def test_generic_type_bad_arg_cnt(self): 10235 from xxclassloader import spamobj 10236 10237 o = spamobj[str]() 10238 with self.assertRaises(TypeError): 10239 o.setstr() 10240 with self.assertRaises(TypeError): 10241 o.setstr("abc", "abc") 10242 10243 def test_generic_type_bad_arg_cnt(self): 10244 from xxclassloader import spamobj 10245 10246 o = spamobj[str]() 10247 self.assertEqual(o.twoargs(1, 2), 3) 10248 10249 def test_typed_slots_one_missing(self): 10250 codestr = """ 10251 class C: 10252 __slots__ = ('a', 'b') 10253 __slot_types__ = {'a': (__name__, 'C')} 10254 10255 inst = C() 10256 """ 10257 10258 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10259 inst, C = mod["inst"], mod["C"] 10260 self.assertEqual(C.a.__class__.__name__, "typed_descriptor") 10261 with self.assertRaises(TypeError): 10262 # type is checked 10263 inst.a = 42 10264 10265 def test_typed_slots_optional_object(self): 10266 codestr = """ 10267 class C: 10268 __slots__ = ('a', ) 10269 __slot_types__ = {'a': (__name__, 'C', '?')} 10270 10271 inst = C() 10272 """ 10273 10274 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10275 inst, C = mod["inst"], mod["C"] 10276 inst.a = None 10277 self.assertEqual(inst.a, None) 10278 10279 def test_typed_slots_private(self): 10280 codestr = """ 10281 class C: 10282 __slots__ = ('__a', ) 10283 __slot_types__ = {'__a': (__name__, 'C', '?')} 10284 def __init__(self): 10285 self.__a = None 10286 10287 inst = C() 10288 
""" 10289 10290 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10291 inst, C = mod["inst"], mod["C"] 10292 self.assertEqual(inst._C__a, None) 10293 inst._C__a = inst 10294 self.assertEqual(inst._C__a, inst) 10295 inst._C__a = None 10296 self.assertEqual(inst._C__a, None) 10297 10298 def test_typed_slots_optional_not_defined(self): 10299 codestr = """ 10300 class C: 10301 __slots__ = ('a', ) 10302 __slot_types__ = {'a': (__name__, 'D', '?')} 10303 10304 def __init__(self): 10305 self.a = None 10306 10307 inst = C() 10308 10309 class D: 10310 pass 10311 """ 10312 10313 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10314 inst, C = mod["inst"], mod["C"] 10315 inst.a = None 10316 self.assertEqual(inst.a, None) 10317 10318 def test_typed_slots_alignment(self): 10319 return 10320 codestr = """ 10321 class C: 10322 __slots__ = ('a', 'b') 10323 __slot_types__ {'a': ('__static__', 'int16')} 10324 10325 inst = C() 10326 """ 10327 10328 with self.in_module(codestr, code_gen=PythonCodeGenerator) as mod: 10329 inst, C = mod["inst"], mod["C"] 10330 inst.a = None 10331 self.assertEqual(inst.a, None) 10332 10333 def test_typed_slots_primitives(self): 10334 slot_types = [ 10335 # signed 10336 ( 10337 ("__static__", "byte"), 10338 0, 10339 1, 10340 [(1 << 7) - 1, -(1 << 7)], 10341 [1 << 8], 10342 ["abc"], 10343 ), 10344 ( 10345 ("__static__", "int8"), 10346 0, 10347 1, 10348 [(1 << 7) - 1, -(1 << 7)], 10349 [1 << 8], 10350 ["abc"], 10351 ), 10352 ( 10353 ("__static__", "int16"), 10354 0, 10355 2, 10356 [(1 << 15) - 1, -(1 << 15)], 10357 [1 << 15, -(1 << 15) - 1], 10358 ["abc"], 10359 ), 10360 ( 10361 ("__static__", "int32"), 10362 0, 10363 4, 10364 [(1 << 31) - 1, -(1 << 31)], 10365 [1 << 31, -(1 << 31) - 1], 10366 ["abc"], 10367 ), 10368 (("__static__", "int64"), 0, 8, [(1 << 63) - 1, -(1 << 63)], [], [1 << 63]), 10369 # unsigned 10370 (("__static__", "uint8"), 0, 1, [(1 << 8) - 1, 0], [1 << 8, -1], ["abc"]), 10371 ( 10372 ("__static__", 
"uint16"), 10373 0, 10374 2, 10375 [(1 << 16) - 1, 0], 10376 [1 << 16, -1], 10377 ["abc"], 10378 ), 10379 ( 10380 ("__static__", "uint32"), 10381 0, 10382 4, 10383 [(1 << 32) - 1, 0], 10384 [1 << 32, -1], 10385 ["abc"], 10386 ), 10387 (("__static__", "uint64"), 0, 8, [(1 << 64) - 1, 0], [], [1 << 64]), 10388 # pointer 10389 ( 10390 ("__static__", "ssize_t"), 10391 0, 10392 self.ptr_size, 10393 [1, sys.maxsize, -sys.maxsize - 1], 10394 [], 10395 [sys.maxsize + 1, -sys.maxsize - 2], 10396 ), 10397 # floating point 10398 (("__static__", "single"), 0.0, 4, [1.0], [], ["abc"]), 10399 (("__static__", "double"), 0.0, 8, [1.0], [], ["abc"]), 10400 # misc 10401 (("__static__", "char"), "\x00", 1, ["a"], [], ["abc"]), 10402 (("__static__", "cbool"), False, 1, [True], [], ["abc", 1]), 10403 ] 10404 10405 base_size = self.base_size 10406 for type_spec, default, size, test_vals, warn_vals, err_vals in slot_types: 10407 with self.subTest( 10408 type_spec=type_spec, 10409 default=default, 10410 size=size, 10411 test_vals=test_vals, 10412 warn_vals=warn_vals, 10413 err_vals=err_vals, 10414 ): 10415 10416 class C: 10417 __slots__ = ("a",) 10418 __slot_types__ = {"a": type_spec} 10419 10420 a = C() 10421 self.assertEqual(sys.getsizeof(a), base_size + size, type) 10422 self.assertEqual(a.a, default) 10423 self.assertEqual(type(a.a), type(default)) 10424 for val in test_vals: 10425 a.a = val 10426 self.assertEqual(a.a, val) 10427 10428 with warnings.catch_warnings(): 10429 warnings.simplefilter("error", category=RuntimeWarning) 10430 for val in warn_vals: 10431 with self.assertRaises(RuntimeWarning): 10432 a.a = val 10433 10434 for val in err_vals: 10435 with self.assertRaises((TypeError, OverflowError)): 10436 a.a = val 10437 10438 def test_invoke_function(self): 10439 my_int = "12345" 10440 codestr = f""" 10441 def x(a: str, b: int) -> str: 10442 return a + str(b) 10443 10444 def test() -> str: 10445 return x("hello", {my_int}) 10446 """ 10447 c = self.compile(codestr, 
StaticCodeGenerator, modname="foo.py") 10448 test = self.find_code(c, "test") 10449 self.assertInBytecode(test, "INVOKE_FUNCTION", (("foo.py", "x"), 2)) 10450 with self.in_module(codestr) as mod: 10451 test_callable = mod["test"] 10452 self.assertEqual(test_callable(), "hello" + my_int) 10453 10454 def test_awaited_invoke_function(self): 10455 codestr = """ 10456 async def f() -> int: 10457 return 1 10458 10459 async def g() -> int: 10460 return await f() 10461 """ 10462 with self.in_strict_module(codestr) as mod: 10463 self.assertInBytecode(mod.g, "INVOKE_FUNCTION", ((mod.__name__, "f"), 0)) 10464 self.assertEqual(asyncio.run(mod.g()), 1) 10465 10466 def test_awaited_invoke_function_unjitable(self): 10467 codestr = """ 10468 async def f() -> int: 10469 from os.path import * 10470 return 1 10471 10472 async def g() -> int: 10473 return await f() 10474 """ 10475 with self.in_strict_module(codestr) as mod: 10476 self.assertInBytecode( 10477 mod.g, 10478 "INVOKE_FUNCTION", 10479 ((mod.__name__, "f"), 0), 10480 ) 10481 self.assertEqual(asyncio.run(mod.g()), 1) 10482 10483 def test_awaited_invoke_function_with_args(self): 10484 codestr = """ 10485 async def f(a: int, b: int) -> int: 10486 return a + b 10487 10488 async def g() -> int: 10489 return await f(1, 2) 10490 """ 10491 with self.in_strict_module(codestr) as mod: 10492 self.assertInBytecode( 10493 mod.g, 10494 "INVOKE_FUNCTION", 10495 ((mod.__name__, "f"), 2), 10496 ) 10497 self.assertEqual(asyncio.run(mod.g()), 3) 10498 10499 # exercise shadowcode, INVOKE_FUNCTION_CACHED 10500 self.make_async_func_hot(mod.g) 10501 self.assertEqual(asyncio.run(mod.g()), 3) 10502 10503 def test_awaited_invoke_function_indirect_with_args(self): 10504 codestr = """ 10505 async def f(a: int, b: int) -> int: 10506 return a + b 10507 10508 async def g() -> int: 10509 return await f(1, 2) 10510 """ 10511 with self.in_module(codestr) as mod: 10512 g = mod["g"] 10513 self.assertInBytecode( 10514 g, 10515 "INVOKE_FUNCTION", 10516 
((mod["__name__"], "f"), 2), 10517 ) 10518 self.assertEqual(asyncio.run(g()), 3) 10519 10520 # exercise shadowcode, INVOKE_FUNCTION_INDIRECT_CACHED 10521 self.make_async_func_hot(g) 10522 self.assertEqual(asyncio.run(g()), 3) 10523 10524 def test_awaited_invoke_function_future(self): 10525 codestr = """ 10526 from asyncio import ensure_future 10527 10528 async def h() -> int: 10529 return 1 10530 10531 async def g() -> None: 10532 await ensure_future(h()) 10533 10534 async def f(): 10535 await g() 10536 """ 10537 with self.in_strict_module(codestr) as mod: 10538 self.assertInBytecode( 10539 mod.f, 10540 "INVOKE_FUNCTION", 10541 ((mod.__name__, "g"), 0), 10542 ) 10543 asyncio.run(mod.f()) 10544 10545 # exercise shadowcode 10546 self.make_async_func_hot(mod.f) 10547 asyncio.run(mod.f()) 10548 10549 def test_awaited_invoke_method(self): 10550 codestr = """ 10551 class C: 10552 async def f(self) -> int: 10553 return 1 10554 10555 async def g(self) -> int: 10556 return await self.f() 10557 """ 10558 with self.in_strict_module(codestr) as mod: 10559 self.assertInBytecode( 10560 mod.C.g, "INVOKE_METHOD", ((mod.__name__, "C", "f"), 0) 10561 ) 10562 self.assertEqual(asyncio.run(mod.C().g()), 1) 10563 10564 def test_awaited_invoke_method_with_args(self): 10565 codestr = """ 10566 class C: 10567 async def f(self, a: int, b: int) -> int: 10568 return a + b 10569 10570 async def g(self) -> int: 10571 return await self.f(1, 2) 10572 """ 10573 with self.in_strict_module(codestr) as mod: 10574 self.assertInBytecode( 10575 mod.C.g, 10576 "INVOKE_METHOD", 10577 ((mod.__name__, "C", "f"), 2), 10578 ) 10579 self.assertEqual(asyncio.run(mod.C().g()), 3) 10580 10581 # exercise shadowcode, INVOKE_METHOD_CACHED 10582 async def make_hot(): 10583 c = mod.C() 10584 for i in range(50): 10585 await c.g() 10586 10587 asyncio.run(make_hot()) 10588 self.assertEqual(asyncio.run(mod.C().g()), 3) 10589 10590 def test_awaited_invoke_method_future(self): 10591 codestr = """ 10592 from asyncio import 
ensure_future 10593 10594 async def h() -> int: 10595 return 1 10596 10597 class C: 10598 async def g(self) -> None: 10599 await ensure_future(h()) 10600 10601 async def f(): 10602 c = C() 10603 await c.g() 10604 """ 10605 with self.in_strict_module(codestr) as mod: 10606 self.assertInBytecode( 10607 mod.f, 10608 "INVOKE_METHOD", 10609 ((mod.__name__, "C", "g"), 0), 10610 ) 10611 asyncio.run(mod.f()) 10612 10613 # exercise shadowcode, INVOKE_METHOD_CACHED 10614 self.make_async_func_hot(mod.f) 10615 asyncio.run(mod.f()) 10616 10617 def test_load_iterable_arg(self): 10618 codestr = """ 10619 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10620 return 7 10621 10622 def y() -> int: 10623 p = ("hi", 0.1, 0.2) 10624 return x(1, 3, *p) 10625 """ 10626 y = self.find_code( 10627 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10628 ) 10629 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0) 10630 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1) 10631 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2) 10632 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10633 with self.in_module(codestr) as mod: 10634 y_callable = mod["y"] 10635 self.assertEqual(y_callable(), 7) 10636 10637 def test_load_iterable_arg_default_overridden(self): 10638 codestr = """ 10639 def x(a: int, b: int, c: str, d: float = 10.1, e: float = 20.1) -> bool: 10640 return bool( 10641 a == 1 10642 and b == 3 10643 and c == "hi" 10644 and d == 0.1 10645 and e == 0.2 10646 ) 10647 10648 def y() -> bool: 10649 p = ("hi", 0.1, 0.2) 10650 return x(1, 3, *p) 10651 """ 10652 y = self.find_code( 10653 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10654 ) 10655 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10656 self.assertNotInBytecode(y, "LOAD_MAPPING_ARG", 3) 10657 with self.in_module(codestr) as mod: 10658 y_callable = mod["y"] 10659 self.assertTrue(y_callable()) 10660 10661 def test_load_iterable_arg_multi_star(self): 10662 codestr = """ 10663 def x(a: int, 
b: int, c: str, d: float, e: float) -> int: 10664 return 7 10665 10666 def y() -> int: 10667 p = (1, 3) 10668 q = ("hi", 0.1, 0.2) 10669 return x(*p, *q) 10670 """ 10671 y = self.find_code( 10672 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10673 ) 10674 # we should fallback to the normal Python compiler for this 10675 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG") 10676 with self.in_module(codestr) as mod: 10677 y_callable = mod["y"] 10678 self.assertEqual(y_callable(), 7) 10679 10680 def test_load_iterable_arg_star_not_last(self): 10681 codestr = """ 10682 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10683 return 7 10684 10685 def y() -> int: 10686 p = (1, 3, 'abc', 0.1) 10687 return x(*p, 1.0) 10688 """ 10689 y = self.find_code( 10690 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10691 ) 10692 # we should fallback to the normal Python compiler for this 10693 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG") 10694 with self.in_module(codestr) as mod: 10695 y_callable = mod["y"] 10696 self.assertEqual(y_callable(), 7) 10697 10698 def test_load_iterable_arg_failure(self): 10699 codestr = """ 10700 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10701 return 7 10702 10703 def y() -> int: 10704 p = ("hi", 0.1) 10705 return x(1, 3, *p) 10706 """ 10707 y = self.find_code( 10708 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10709 ) 10710 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0) 10711 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1) 10712 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2) 10713 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10714 with self.in_module(codestr) as mod: 10715 y_callable = mod["y"] 10716 with self.assertRaises(IndexError): 10717 y_callable() 10718 10719 def test_load_iterable_arg_sequence(self): 10720 codestr = """ 10721 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10722 return 7 10723 10724 def y() -> int: 10725 p = ["hi", 0.1, 
0.2] 10726 return x(1, 3, *p) 10727 """ 10728 y = self.find_code( 10729 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10730 ) 10731 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0) 10732 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1) 10733 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2) 10734 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10735 with self.in_module(codestr) as mod: 10736 y_callable = mod["y"] 10737 self.assertEqual(y_callable(), 7) 10738 10739 def test_load_iterable_arg_sequence_1(self): 10740 codestr = """ 10741 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10742 return 7 10743 10744 def gen(): 10745 for i in ["hi", 0.05, 0.2]: 10746 yield i 10747 10748 def y() -> int: 10749 g = gen() 10750 return x(1, 3, *g) 10751 """ 10752 y = self.find_code( 10753 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10754 ) 10755 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0) 10756 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1) 10757 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2) 10758 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10759 with self.in_module(codestr) as mod: 10760 y_callable = mod["y"] 10761 self.assertEqual(y_callable(), 7) 10762 10763 def test_load_iterable_arg_sequence_failure(self): 10764 codestr = """ 10765 def x(a: int, b: int, c: str, d: float, e: float) -> int: 10766 return 7 10767 10768 def y() -> int: 10769 p = ["hi", 0.1] 10770 return x(1, 3, *p) 10771 """ 10772 y = self.find_code( 10773 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10774 ) 10775 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 0) 10776 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 1) 10777 self.assertInBytecode(y, "LOAD_ITERABLE_ARG", 2) 10778 self.assertNotInBytecode(y, "LOAD_ITERABLE_ARG", 3) 10779 with self.in_module(codestr) as mod: 10780 y_callable = mod["y"] 10781 with self.assertRaises(IndexError): 10782 y_callable() 10783 10784 def test_load_mapping_arg(self): 10785 codestr 
= """ 10786 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool: 10787 return bool(f == "yo" and d == 1.0 and e == 1.1) 10788 10789 def y() -> bool: 10790 d = {"d": 1.0} 10791 return x(1, 3, "hi", f="yo", **d) 10792 """ 10793 y = self.find_code( 10794 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10795 ) 10796 self.assertInBytecode(y, "LOAD_MAPPING_ARG", 3) 10797 with self.in_module(codestr) as mod: 10798 y_callable = mod["y"] 10799 self.assertTrue(y_callable()) 10800 10801 def test_load_mapping_and_iterable_args_failure_1(self): 10802 """ 10803 Fails because we don't supply enough positional args 10804 """ 10805 10806 codestr = """ 10807 def x(a: int, b: int, c: str, d: float=2.2, e: float=1.1, f: str="something") -> bool: 10808 return bool(a == 1 and b == 3 and f == "yo" and d == 2.2 and e == 1.1) 10809 10810 def y() -> bool: 10811 return x(1, 3, f="yo") 10812 """ 10813 with self.assertRaisesRegex( 10814 SyntaxError, "Function foo.x expects a value for argument c" 10815 ): 10816 self.compile(codestr, StaticCodeGenerator, modname="foo") 10817 10818 def test_load_mapping_arg_failure(self): 10819 """ 10820 Fails because we supply an extra kwarg 10821 """ 10822 codestr = """ 10823 def x(a: int, b: int, c: str, d: float=2.2, e: float=1.1, f: str="something") -> bool: 10824 return bool(a == 1 and b == 3 and f == "yo" and d == 2.2 and e == 1.1) 10825 10826 def y() -> bool: 10827 return x(1, 3, "hi", f="yo", g="lol") 10828 """ 10829 with self.assertRaisesRegex( 10830 TypedSyntaxError, 10831 "Given argument g does not exist in the definition of foo.x", 10832 ): 10833 self.compile(codestr, StaticCodeGenerator, modname="foo") 10834 10835 def test_load_mapping_arg_custom_class(self): 10836 """ 10837 Fails because we supply a custom class for the mapped args, instead of a dict 10838 """ 10839 codestr = """ 10840 def x(a: int, b: int, c: str="hello") -> bool: 10841 return bool(a == 1 and b == 3 and c == "hello") 10842 
10843 class C: 10844 def __getitem__(self, key: str) -> str: 10845 if key == "c": 10846 return "hi" 10847 10848 def keys(self): 10849 return ["c"] 10850 10851 def y() -> bool: 10852 return x(1, 3, **C()) 10853 """ 10854 with self.in_module(codestr) as mod: 10855 y_callable = mod["y"] 10856 with self.assertRaisesRegex( 10857 TypeError, r"argument after \*\* must be a dict, not C" 10858 ): 10859 self.assertTrue(y_callable()) 10860 10861 def test_load_mapping_arg_use_defaults(self): 10862 codestr = """ 10863 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool: 10864 return bool(f == "yo" and d == -0.1 and e == 1.1) 10865 10866 def y() -> bool: 10867 d = {"d": 1.0} 10868 return x(1, 3, "hi", f="yo") 10869 """ 10870 y = self.find_code( 10871 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10872 ) 10873 self.assertInBytecode(y, "LOAD_CONST", 1.1) 10874 with self.in_module(codestr) as mod: 10875 y_callable = mod["y"] 10876 self.assertTrue(y_callable()) 10877 10878 def test_default_arg_non_const(self): 10879 codestr = """ 10880 class C: pass 10881 def x(val=C()) -> C: 10882 return val 10883 10884 def f() -> C: 10885 return x() 10886 """ 10887 with self.in_module(codestr) as mod: 10888 f = mod["f"] 10889 self.assertInBytecode(f, "CALL_FUNCTION") 10890 10891 def test_default_arg_non_const_kw_provided(self): 10892 codestr = """ 10893 class C: pass 10894 def x(val:object=C()): 10895 return val 10896 10897 def f(): 10898 return x(val=42) 10899 """ 10900 10901 with self.in_module(codestr) as mod: 10902 f = mod["f"] 10903 self.assertEqual(f(), 42) 10904 10905 def test_load_mapping_arg_order(self): 10906 codestr = """ 10907 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something") -> bool: 10908 return bool( 10909 a == 1 10910 and b == 3 10911 and c == "hi" 10912 and d == 1.1 10913 and e == 3.3 10914 and f == "hmm" 10915 ) 10916 10917 stuff = [] 10918 def q() -> float: 10919 stuff.append("q") 10920 return 
1.1 10921 10922 def r() -> float: 10923 stuff.append("r") 10924 return 3.3 10925 10926 def s() -> str: 10927 stuff.append("s") 10928 return "hmm" 10929 10930 def y() -> bool: 10931 return x(1, 3, "hi", f=s(), d=q(), e=r()) 10932 """ 10933 y = self.find_code( 10934 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10935 ) 10936 self.assertInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d") 10937 self.assertInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d") 10938 with self.in_module(codestr) as mod: 10939 y_callable = mod["y"] 10940 self.assertTrue(y_callable()) 10941 self.assertEqual(["s", "q", "r"], mod["stuff"]) 10942 10943 def test_load_mapping_arg_order_with_variadic_kw_args(self): 10944 codestr = """ 10945 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something", g: str="look-here") -> bool: 10946 return bool( 10947 a == 1 10948 and b == 3 10949 and c == "hi" 10950 and d == 1.1 10951 and e == 3.3 10952 and f == "hmm" 10953 and g == "overridden" 10954 ) 10955 10956 stuff = [] 10957 def q() -> float: 10958 stuff.append("q") 10959 return 1.1 10960 10961 def r() -> float: 10962 stuff.append("r") 10963 return 3.3 10964 10965 def s() -> str: 10966 stuff.append("s") 10967 return "hmm" 10968 10969 def y() -> bool: 10970 kw = {"g": "overridden"} 10971 return x(1, 3, "hi", f=s(), **kw, d=q(), e=r()) 10972 """ 10973 y = self.find_code( 10974 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 10975 ) 10976 self.assertInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d") 10977 self.assertInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d") 10978 with self.in_module(codestr) as mod: 10979 y_callable = mod["y"] 10980 self.assertTrue(y_callable()) 10981 self.assertEqual(["s", "q", "r"], mod["stuff"]) 10982 10983 def test_load_mapping_arg_order_with_variadic_kw_args_one_positional(self): 10984 codestr = """ 10985 def x(a: int, b: int, c: str, d: float=-0.1, e: float=1.1, f: str="something", g: str="look-here") -> bool: 10986 
return bool( 10987 a == 1 10988 and b == 3 10989 and c == "hi" 10990 and d == 1.1 10991 and e == 3.3 10992 and f == "hmm" 10993 and g == "overridden" 10994 ) 10995 10996 stuff = [] 10997 def q() -> float: 10998 stuff.append("q") 10999 return 1.1 11000 11001 def r() -> float: 11002 stuff.append("r") 11003 return 3.3 11004 11005 def s() -> str: 11006 stuff.append("s") 11007 return "hmm" 11008 11009 11010 def y() -> bool: 11011 kw = {"g": "overridden"} 11012 return x(1, 3, "hi", 1.1, f=s(), **kw, e=r()) 11013 """ 11014 y = self.find_code( 11015 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 11016 ) 11017 self.assertNotInBytecode(y, "STORE_FAST", "_pystatic_.0._tmp__d") 11018 self.assertNotInBytecode(y, "LOAD_FAST", "_pystatic_.0._tmp__d") 11019 with self.in_module(codestr) as mod: 11020 y_callable = mod["y"] 11021 self.assertTrue(y_callable()) 11022 self.assertEqual(["s", "r"], mod["stuff"]) 11023 11024 def test_vector_generics(self): 11025 T = TypeVar("T") 11026 VT = Vector[T] 11027 VT2 = VT[int64] 11028 a = VT2() 11029 a.append(42) 11030 with self.assertRaisesRegex(TypeError, "Cannot create plain Vector"): 11031 VT() 11032 11033 def test_vector_invalid_type(self): 11034 class C: 11035 pass 11036 11037 with self.assertRaisesRegex( 11038 TypeError, "Invalid type for ArrayElement: C when instantiating Vector" 11039 ): 11040 Vector[C] 11041 11042 def test_vector_wrong_arg_count(self): 11043 class C: 11044 pass 11045 11046 with self.assertRaisesRegex( 11047 TypeError, "Incorrect number of type arguments for Vector" 11048 ): 11049 Vector[int64, int64] 11050 11051 def test_generic_type_args(self): 11052 T = TypeVar("T") 11053 U = TypeVar("U") 11054 11055 class C(StaticGeneric[T, U]): 11056 pass 11057 11058 c_t = make_generic_type(C, (T, int)) 11059 self.assertEqual(c_t.__parameters__, (T,)) 11060 c_t_s = make_generic_type(c_t, (str,)) 11061 self.assertEqual(c_t_s.__name__, "C[str, int]") 11062 c_u = make_generic_type(C, (int, U)) 11063 
self.assertEqual(c_u.__parameters__, (U,)) 11064 c_u_t = make_generic_type(c_u, (str,)) 11065 self.assertEqual(c_u_t.__name__, "C[int, str]") 11066 self.assertFalse(hasattr(c_u_t, "__parameters__")) 11067 11068 c_u_t_1 = make_generic_type(c_u, (int,)) 11069 c_u_t_2 = make_generic_type(c_t, (int,)) 11070 self.assertEqual(c_u_t_1.__name__, "C[int, int]") 11071 self.assertIs(c_u_t_1, c_u_t_2) 11072 11073 def test_array_slice(self): 11074 v = Array[int64]([1, 2, 3, 4]) 11075 self.assertEqual(v[1:3], Array[int64]([2, 3])) 11076 self.assertEqual(type(v[1:2]), Array[int64]) 11077 11078 def test_vector_slice(self): 11079 v = Vector[int64]([1, 2, 3, 4]) 11080 self.assertEqual(v[1:3], Vector[int64]([2, 3])) 11081 self.assertEqual(type(v[1:2]), Vector[int64]) 11082 11083 def test_array_deepcopy(self): 11084 v = Array[int64]([1, 2, 3, 4]) 11085 self.assertEqual(v, deepcopy(v)) 11086 self.assertIsNot(v, deepcopy(v)) 11087 self.assertEqual(type(v), type(deepcopy(v))) 11088 11089 def test_vector_deepcopy(self): 11090 v = Vector[int64]([1, 2, 3, 4]) 11091 self.assertEqual(v, deepcopy(v)) 11092 self.assertIsNot(v, deepcopy(v)) 11093 self.assertEqual(type(v), type(deepcopy(v))) 11094 11095 def test_nested_generic(self): 11096 S = TypeVar("S") 11097 T = TypeVar("T") 11098 U = TypeVar("U") 11099 11100 class F(StaticGeneric[U]): 11101 pass 11102 11103 class C(StaticGeneric[T]): 11104 pass 11105 11106 A = F[S] 11107 self.assertEqual(A.__parameters__, (S,)) 11108 X = C[F[T]] 11109 self.assertEqual(X.__parameters__, (T,)) 11110 11111 def test_array_len(self): 11112 codestr = """ 11113 from __static__ import int64, char, double, Array 11114 from array import array 11115 11116 def y(): 11117 return len(Array[int64]([1, 3, 5])) 11118 """ 11119 y = self.find_code( 11120 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 11121 ) 11122 self.assertInBytecode(y, "FAST_LEN", FAST_LEN_ARRAY) 11123 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11124 y = mod["y"] 
11125 self.assertEqual(y(), 3) 11126 11127 def test_array_isinstance(self): 11128 x = Array[int64](0) 11129 self.assertTrue(isinstance(x, Array[int64])) 11130 self.assertFalse(isinstance(x, Array[int32])) 11131 self.assertTrue(issubclass(Array[int64], Array[int64])) 11132 self.assertFalse(issubclass(Array[int64], Array[int32])) 11133 11134 def test_array_weird_type_constrution(self): 11135 self.assertIs( 11136 Array[int64], 11137 Array[ 11138 int64, 11139 ], 11140 ) 11141 11142 def test_array_not_subclassable(self): 11143 11144 with self.assertRaises(TypeError): 11145 11146 class C(Array[int64]): 11147 pass 11148 11149 with self.assertRaises(TypeError): 11150 11151 class C(Array): 11152 11153 pass 11154 11155 def test_array_enum(self): 11156 codestr = """ 11157 from __static__ import Array, clen, int64, box 11158 11159 def f(x: Array[int64]): 11160 i: int64 = 0 11161 j: int64 = 0 11162 while i < clen(x): 11163 j += x[i] 11164 i+=1 11165 return box(j) 11166 """ 11167 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11168 f = mod["f"] 11169 a = Array[int64]([1, 2, 3, 4]) 11170 self.assertEqual(f(a), 10) 11171 with self.assertRaises(TypeError): 11172 f(None) 11173 11174 def test_optional_array_enum(self): 11175 codestr = """ 11176 from __static__ import Array, clen, int64, box 11177 from typing import Optional 11178 11179 def f(x: Optional[Array[int64]]): 11180 if x is None: 11181 return 42 11182 11183 i: int64 = 0 11184 j: int64 = 0 11185 while i < clen(x): 11186 j += x[i] 11187 i+=1 11188 return box(j) 11189 """ 11190 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11191 f = mod["f"] 11192 a = Array[int64]([1, 2, 3, 4]) 11193 self.assertEqual(f(a), 10) 11194 self.assertEqual(f(None), 42) 11195 11196 def test_array_len_subclass(self): 11197 codestr = """ 11198 from __static__ import int64, Array 11199 11200 def y(ar: Array[int64]): 11201 return len(ar) 11202 """ 11203 y = self.find_code( 11204 self.compile(codestr, 
StaticCodeGenerator, modname="foo"), name="y" 11205 ) 11206 self.assertInBytecode(y, "FAST_LEN", FAST_LEN_ARRAY | FAST_LEN_INEXACT) 11207 11208 # TODO the below requires Array to be a generic type in C, or else 11209 # support for generic annotations for not-generic-in-C types. For now 11210 # it's sufficient to validate we emitted FAST_LEN_INEXACT flag. 11211 11212 # class MyArray(Array): 11213 # def __len__(self): 11214 # return 123 11215 11216 # with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11217 # y = mod["y"] 11218 # self.assertEqual(y(MyArray[int64]([1])), 123) 11219 11220 def test_nonarray_len(self): 11221 codestr = """ 11222 class Lol: 11223 def __len__(self): 11224 return 421 11225 11226 def y(): 11227 return len(Lol()) 11228 """ 11229 y = self.find_code( 11230 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 11231 ) 11232 self.assertNotInBytecode(y, "FAST_LEN") 11233 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11234 y = mod["y"] 11235 self.assertEqual(y(), 421) 11236 11237 def test_clen(self): 11238 codestr = """ 11239 from __static__ import box, clen, int64 11240 from typing import List 11241 11242 def f(l: List[int]): 11243 x: int64 = clen(l) 11244 return box(x) 11245 """ 11246 with self.in_module(codestr) as mod: 11247 f = mod["f"] 11248 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT) 11249 self.assertEqual(f([1, 2, 3]), 3) 11250 11251 class MyList(list): 11252 def __len__(self): 11253 return 99 11254 11255 self.assertEqual(f(MyList([1, 2])), 99) 11256 11257 def test_clen_bad_arg(self): 11258 codestr = """ 11259 from __static__ import clen 11260 11261 def f(l): 11262 clen(l) 11263 """ 11264 with self.assertRaisesRegex( 11265 TypedSyntaxError, "bad argument type 'dynamic' for clen()" 11266 ): 11267 self.compile(codestr) 11268 11269 def test_seq_repeat_list(self): 11270 codestr = """ 11271 def f(): 11272 l = [1, 2] 11273 return l * 2 11274 """ 11275 f = 
self.find_code(self.compile(codestr)) 11276 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_LIST) 11277 with self.in_module(codestr) as mod: 11278 self.assertEqual(mod["f"](), [1, 2, 1, 2]) 11279 11280 def test_seq_repeat_list_reversed(self): 11281 codestr = """ 11282 def f(): 11283 l = [1, 2] 11284 return 2 * l 11285 """ 11286 f = self.find_code(self.compile(codestr)) 11287 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_LIST | SEQ_REPEAT_REVERSED) 11288 with self.in_module(codestr) as mod: 11289 self.assertEqual(mod["f"](), [1, 2, 1, 2]) 11290 11291 def test_seq_repeat_primitive(self): 11292 codestr = """ 11293 from __static__ import int64 11294 11295 def f(): 11296 x: int64 = 2 11297 l = [1, 2] 11298 return l * x 11299 """ 11300 f = self.find_code(self.compile(codestr)) 11301 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_LIST | SEQ_REPEAT_PRIMITIVE_NUM) 11302 with self.in_module(codestr) as mod: 11303 self.assertEqual(mod["f"](), [1, 2, 1, 2]) 11304 11305 def test_seq_repeat_primitive_reversed(self): 11306 codestr = """ 11307 from __static__ import int64 11308 11309 def f(): 11310 x: int64 = 2 11311 l = [1, 2] 11312 return x * l 11313 """ 11314 f = self.find_code(self.compile(codestr)) 11315 self.assertInBytecode( 11316 f, 11317 "SEQUENCE_REPEAT", 11318 SEQ_LIST | SEQ_REPEAT_REVERSED | SEQ_REPEAT_PRIMITIVE_NUM, 11319 ) 11320 with self.in_module(codestr) as mod: 11321 self.assertEqual(mod["f"](), [1, 2, 1, 2]) 11322 11323 def test_seq_repeat_tuple(self): 11324 codestr = """ 11325 def f(): 11326 t = (1, 2) 11327 return t * 2 11328 """ 11329 f = self.find_code(self.compile(codestr)) 11330 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_TUPLE) 11331 with self.in_module(codestr) as mod: 11332 self.assertEqual(mod["f"](), (1, 2, 1, 2)) 11333 11334 def test_seq_repeat_tuple_reversed(self): 11335 codestr = """ 11336 def f(): 11337 t = (1, 2) 11338 return 2 * t 11339 """ 11340 f = self.find_code(self.compile(codestr)) 11341 self.assertInBytecode(f, "SEQUENCE_REPEAT", 
SEQ_TUPLE | SEQ_REPEAT_REVERSED) 11342 with self.in_module(codestr) as mod: 11343 self.assertEqual(mod["f"](), (1, 2, 1, 2)) 11344 11345 def test_seq_repeat_inexact_list(self): 11346 codestr = """ 11347 from typing import List 11348 11349 def f(l: List[int]): 11350 return l * 2 11351 """ 11352 f = self.find_code(self.compile(codestr)) 11353 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_LIST | SEQ_REPEAT_INEXACT_SEQ) 11354 with self.in_module(codestr) as mod: 11355 self.assertEqual(mod["f"]([1, 2]), [1, 2, 1, 2]) 11356 11357 class MyList(list): 11358 def __mul__(self, other): 11359 return "RESULT" 11360 11361 self.assertEqual(mod["f"](MyList([1, 2])), "RESULT") 11362 11363 def test_seq_repeat_inexact_tuple(self): 11364 11365 codestr = """ 11366 from typing import Tuple 11367 11368 def f(t: Tuple[int]): 11369 return t * 2 11370 """ 11371 f = self.find_code(self.compile(codestr)) 11372 self.assertInBytecode(f, "SEQUENCE_REPEAT", SEQ_TUPLE | SEQ_REPEAT_INEXACT_SEQ) 11373 with self.in_module(codestr) as mod: 11374 self.assertEqual(mod["f"]((1, 2)), (1, 2, 1, 2)) 11375 11376 class MyTuple(tuple): 11377 def __mul__(self, other): 11378 return "RESULT" 11379 11380 self.assertEqual(mod["f"](MyTuple((1, 2))), "RESULT") 11381 11382 def test_seq_repeat_inexact_num(self): 11383 codestr = """ 11384 def f(num: int): 11385 11386 return num * [1, 2] 11387 """ 11388 f = self.find_code(self.compile(codestr)) 11389 self.assertInBytecode( 11390 f, 11391 "SEQUENCE_REPEAT", 11392 SEQ_LIST | SEQ_REPEAT_INEXACT_NUM | SEQ_REPEAT_REVERSED, 11393 ) 11394 with self.in_module(codestr) as mod: 11395 self.assertEqual(mod["f"](2), [1, 2, 1, 2]) 11396 11397 class MyInt(int): 11398 def __mul__(self, other): 11399 return "RESULT" 11400 11401 self.assertEqual(mod["f"](MyInt(2)), "RESULT") 11402 11403 def test_load_int_const_sizes(self): 11404 cases = [ 11405 ("int8", (1 << 7), True), 11406 ("int8", (-1 << 7) - 1, True), 11407 ("int8", -1 << 7, False), 11408 ("int8", (1 << 7) - 1, False), 11409 
("int16", (1 << 15), True), 11410 ("int16", (-1 << 15) - 1, True), 11411 ("int16", -1 << 15, False), 11412 ("int16", (1 << 15) - 1, False), 11413 ("int32", (1 << 31), True), 11414 ("int32", (-1 << 31) - 1, True), 11415 ("int32", -1 << 31, False), 11416 ("int32", (1 << 31) - 1, False), 11417 ("int64", (1 << 63), True), 11418 ("int64", (-1 << 63) - 1, True), 11419 ("int64", -1 << 63, False), 11420 ("int64", (1 << 63) - 1, False), 11421 ("uint8", (1 << 8), True), 11422 ("uint8", -1, True), 11423 ("uint8", (1 << 8) - 1, False), 11424 ("uint8", 0, False), 11425 ("uint16", (1 << 16), True), 11426 ("uint16", -1, True), 11427 ("uint16", (1 << 16) - 1, False), 11428 ("uint16", 0, False), 11429 ("uint32", (1 << 32), True), 11430 ("uint32", -1, True), 11431 ("uint32", (1 << 32) - 1, False), 11432 ("uint32", 0, False), 11433 ("uint64", (1 << 64), True), 11434 ("uint64", -1, True), 11435 ("uint64", (1 << 64) - 1, False), 11436 ("uint64", 0, False), 11437 ] 11438 for type, val, overflows in cases: 11439 codestr = f""" 11440 from __static__ import {type}, box 11441 11442 def f() -> int: 11443 x: {type} = {val} 11444 return box(x) 11445 """ 11446 with self.subTest(type=type, val=val, overflows=overflows): 11447 if overflows: 11448 with self.assertRaisesRegex( 11449 TypedSyntaxError, "outside of the range" 11450 ): 11451 self.compile(codestr) 11452 else: 11453 with self.in_strict_module(codestr) as mod: 11454 self.assertEqual(mod.f(), val) 11455 11456 def test_load_int_const_signed(self): 11457 int_types = [ 11458 "int8", 11459 "int16", 11460 "int32", 11461 "int64", 11462 ] 11463 signs = ["-", ""] 11464 values = [12] 11465 11466 for type, sign, value in itertools.product(int_types, signs, values): 11467 expected = value if sign == "" else -value 11468 11469 codestr = f""" 11470 from __static__ import {type}, box 11471 11472 def y() -> int: 11473 x: {type} = {sign}{value} 11474 return box(x) 11475 """ 11476 with self.subTest(type=type, sign=sign, value=value): 11477 y = 
self.find_code( 11478 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 11479 ) 11480 self.assertInBytecode(y, "PRIMITIVE_LOAD_CONST") 11481 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11482 y = mod["y"] 11483 self.assertEqual(y(), expected) 11484 11485 def test_load_int_const_unsigned(self): 11486 int_types = [ 11487 "uint8", 11488 "uint16", 11489 "uint32", 11490 "uint64", 11491 ] 11492 values = [12] 11493 11494 for type, value in itertools.product(int_types, values): 11495 expected = value 11496 codestr = f""" 11497 from __static__ import {type}, box 11498 11499 def y() -> int: 11500 return box({type}({value})) 11501 """ 11502 with self.subTest(type=type, value=value): 11503 y = self.find_code( 11504 self.compile(codestr, StaticCodeGenerator, modname="foo"), name="y" 11505 ) 11506 self.assertInBytecode(y, "PRIMITIVE_LOAD_CONST") 11507 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11508 y = mod["y"] 11509 self.assertEqual(y(), expected) 11510 11511 def test_primitive_out_of_range(self): 11512 codestr = f""" 11513 from __static__ import int8, box 11514 11515 def f() -> int: 11516 x = int8(255) 11517 return box(x) 11518 """ 11519 with self.assertRaisesRegex( 11520 TypedSyntaxError, 11521 "constant 255 is outside of the range -128 to 127 for int8", 11522 ): 11523 self.compile(codestr) 11524 11525 def test_primitive_conversions(self): 11526 cases = [ 11527 ("int8", "int8", 5, 5), 11528 ("int8", "int16", 5, 5), 11529 ("int8", "int32", 5, 5), 11530 ("int8", "int64", 5, 5), 11531 ("int8", "uint8", -1, 255), 11532 ("int8", "uint8", 12, 12), 11533 ("int8", "uint16", -1, 65535), 11534 ("int8", "uint16", 12, 12), 11535 ("int8", "uint32", -1, 4294967295), 11536 ("int8", "uint32", 12, 12), 11537 ("int8", "uint64", -1, 18446744073709551615), 11538 ("int8", "uint64", 12, 12), 11539 ("int16", "int8", 5, 5), 11540 ("int16", "int8", -1, -1), 11541 ("int16", "int8", 32767, -1), 11542 ("int16", "int16", 5, 5), 11543 
("int16", "int32", -5, -5), 11544 ("int16", "int64", -6, -6), 11545 ("int16", "uint8", 32767, 255), 11546 ("int16", "uint8", -1, 255), 11547 ("int16", "uint16", 32767, 32767), 11548 ("int16", "uint16", -1, 65535), 11549 ("int16", "uint32", 1000, 1000), 11550 ("int16", "uint32", -1, 4294967295), 11551 ("int16", "uint64", 1414, 1414), 11552 ("int16", "uint64", -1, 18446744073709551615), 11553 ("int32", "int8", 5, 5), 11554 ("int32", "int8", -1, -1), 11555 ("int32", "int8", 2147483647, -1), 11556 ("int32", "int16", 5, 5), 11557 ("int32", "int16", -1, -1), 11558 ("int32", "int16", 2147483647, -1), 11559 ("int32", "int32", 5, 5), 11560 ("int32", "int64", 5, 5), 11561 ("int32", "uint8", 5, 5), 11562 ("int32", "uint8", 65535, 255), 11563 ("int32", "uint8", -1, 255), 11564 ("int32", "uint16", 5, 5), 11565 ("int32", "uint16", 2147483647, 65535), 11566 ("int32", "uint16", -1, 65535), 11567 ("int32", "uint32", 5, 5), 11568 ("int32", "uint32", -1, 4294967295), 11569 ("int32", "uint64", 5, 5), 11570 ("int32", "uint64", -1, 18446744073709551615), 11571 ("int64", "int8", 5, 5), 11572 ("int64", "int8", -1, -1), 11573 ("int64", "int8", 65535, -1), 11574 ("int64", "int16", 5, 5), 11575 ("int64", "int16", -1, -1), 11576 ("int64", "int16", 4294967295, -1), 11577 ("int64", "int32", 5, 5), 11578 ("int64", "int32", -1, -1), 11579 ("int64", "int32", 9223372036854775807, -1), 11580 ("int64", "int64", 5, 5), 11581 ("int64", "uint8", 5, 5), 11582 ("int64", "uint8", 65535, 255), 11583 ("int64", "uint8", -1, 255), 11584 ("int64", "uint16", 5, 5), 11585 ("int64", "uint16", 4294967295, 65535), 11586 ("int64", "uint16", -1, 65535), 11587 ("int64", "uint32", 5, 5), 11588 ("int64", "uint32", 9223372036854775807, 4294967295), 11589 ("int64", "uint32", -1, 4294967295), 11590 ("int64", "uint64", 5, 5), 11591 ("int64", "uint64", -1, 18446744073709551615), 11592 ("uint8", "int8", 5, 5), 11593 ("uint8", "int8", 255, -1), 11594 ("uint8", "int16", 255, 255), 11595 ("uint8", "int32", 255, 255), 11596 
("uint8", "int64", 255, 255), 11597 ("uint8", "uint8", 5, 5), 11598 ("uint8", "uint16", 255, 255), 11599 ("uint8", "uint32", 255, 255), 11600 ("uint8", "uint64", 255, 255), 11601 ("uint16", "int8", 5, 5), 11602 ("uint16", "int8", 65535, -1), 11603 ("uint16", "int16", 5, 5), 11604 ("uint16", "int16", 65535, -1), 11605 ("uint16", "int32", 65535, 65535), 11606 ("uint16", "int64", 65535, 65535), 11607 ("uint16", "uint8", 65535, 255), 11608 ("uint16", "uint16", 65535, 65535), 11609 ("uint16", "uint32", 65535, 65535), 11610 ("uint16", "uint64", 65535, 65535), 11611 ("uint32", "int8", 4, 4), 11612 ("uint32", "int8", 4294967295, -1), 11613 ("uint32", "int16", 5, 5), 11614 ("uint32", "int16", 4294967295, -1), 11615 ("uint32", "int32", 65535, 65535), 11616 ("uint32", "int32", 4294967295, -1), 11617 ("uint32", "int64", 4294967295, 4294967295), 11618 ("uint32", "uint8", 4, 4), 11619 ("uint32", "uint8", 65535, 255), 11620 ("uint32", "uint16", 4294967295, 65535), 11621 ("uint32", "uint32", 5, 5), 11622 ("uint32", "uint64", 4294967295, 4294967295), 11623 ("uint64", "int8", 4, 4), 11624 ("uint64", "int8", 18446744073709551615, -1), 11625 ("uint64", "int16", 4, 4), 11626 ("uint64", "int16", 18446744073709551615, -1), 11627 ("uint64", "int32", 4, 4), 11628 ("uint64", "int32", 18446744073709551615, -1), 11629 ("uint64", "int64", 4, 4), 11630 ("uint64", "int64", 18446744073709551615, -1), 11631 ("uint64", "uint8", 5, 5), 11632 ("uint64", "uint8", 65535, 255), 11633 ("uint64", "uint16", 4294967295, 65535), 11634 ("uint64", "uint32", 18446744073709551615, 4294967295), 11635 ("uint64", "uint64", 5, 5), 11636 ] 11637 11638 for src, dest, val, expected in cases: 11639 codestr = f""" 11640 from __static__ import {src}, {dest}, box 11641 11642 def y() -> int: 11643 x = {dest}({src}({val})) 11644 return box(x) 11645 """ 11646 with self.subTest(src=src, dest=dest, val=val, expected=expected): 11647 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11648 y = mod["y"] 11649 
actual = y() 11650 self.assertEqual( 11651 actual, 11652 expected, 11653 f"failing case: {[src, dest, val, actual, expected]}", 11654 ) 11655 11656 def test_no_cast_after_box(self): 11657 codestr = """ 11658 from __static__ import int64, box 11659 11660 def f(x: int) -> int: 11661 y = int64(x) + 1 11662 return box(y) 11663 """ 11664 with self.in_module(codestr) as mod: 11665 f = mod["f"] 11666 self.assertNotInBytecode(f, "CAST") 11667 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST", (1, TYPED_INT64)) 11668 self.assertEqual(f(3), 4) 11669 11670 def test_rand(self): 11671 codestr = """ 11672 from __static__ import rand, RAND_MAX, box, int64 11673 11674 def test(): 11675 x: int64 = rand() 11676 return box(x) 11677 """ 11678 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11679 test = mod["test"] 11680 self.assertEqual(type(test()), int) 11681 11682 def test_rand_max_inlined(self): 11683 codestr = """ 11684 from __static__ import rand, RAND_MAX, box, int64 11685 11686 def f() -> int: 11687 x: int64 = rand() // int64(RAND_MAX) 11688 return box(x) 11689 """ 11690 with self.in_module(codestr) as mod: 11691 f = mod["f"] 11692 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST") 11693 self.assertIsInstance(f(), int) 11694 11695 def test_array_get_primitive_idx(self): 11696 codestr = """ 11697 from __static__ import Array, int8, box 11698 11699 def m() -> int: 11700 content = list(range(121)) 11701 a = Array[int8](content) 11702 return box(a[int8(111)]) 11703 """ 11704 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11705 m = self.find_code(c, "m") 11706 self.assertInBytecode(m, "PRIMITIVE_LOAD_CONST", (111, TYPED_INT8)) 11707 self.assertInBytecode(m, "SEQUENCE_GET", SEQ_ARRAY_INT8) 11708 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11709 m = mod["m"] 11710 actual = m() 11711 self.assertEqual(actual, 111) 11712 11713 def test_array_get_nonprimitive_idx(self): 11714 codestr = """ 11715 from __static__ import Array, int8, 
box 11716 11717 def m() -> int: 11718 content = list(range(121)) 11719 a = Array[int8](content) 11720 return box(a[111]) 11721 """ 11722 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11723 m = self.find_code(c, "m") 11724 self.assertInBytecode(m, "LOAD_CONST", 111) 11725 self.assertNotInBytecode(m, "PRIMITIVE_LOAD_CONST") 11726 self.assertInBytecode(m, "PRIMITIVE_UNBOX") 11727 self.assertInBytecode(m, "SEQUENCE_GET", SEQ_ARRAY_INT8) 11728 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11729 m = mod["m"] 11730 actual = m() 11731 self.assertEqual(actual, 111) 11732 11733 def test_array_get_failure(self): 11734 codestr = """ 11735 from __static__ import Array, int8, box 11736 11737 def m() -> int: 11738 a = Array[int8]([1, 3, -5]) 11739 return box(a[20]) 11740 """ 11741 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11742 m = mod["m"] 11743 with self.assertRaisesRegex(IndexError, "index out of range"): 11744 m() 11745 11746 def test_array_get_negative_idx(self): 11747 codestr = """ 11748 from __static__ import Array, int8, box 11749 11750 def m() -> int: 11751 a = Array[int8]([1, 3, -5]) 11752 return box(a[-1]) 11753 """ 11754 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11755 m = mod["m"] 11756 self.assertEqual(m(), -5) 11757 11758 def test_array_set_signed(self): 11759 int_types = [ 11760 "int8", 11761 "int16", 11762 "int32", 11763 "int64", 11764 ] 11765 seq_types = { 11766 "int8": SEQ_ARRAY_INT8, 11767 "int16": SEQ_ARRAY_INT16, 11768 "int32": SEQ_ARRAY_INT32, 11769 "int64": SEQ_ARRAY_INT64, 11770 } 11771 signs = ["-", ""] 11772 value = 77 11773 11774 for type, sign in itertools.product(int_types, signs): 11775 codestr = f""" 11776 from __static__ import Array, {type} 11777 11778 def m() -> Array[{type}]: 11779 a = Array[{type}]([1, 3, -5]) 11780 a[1] = {sign}{value} 11781 return a 11782 """ 11783 with self.subTest(type=type, sign=sign): 11784 val = -value if sign else value 11785 c = 
self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11786 m = self.find_code(c, "m") 11787 self.assertInBytecode( 11788 m, "PRIMITIVE_LOAD_CONST", (val, prim_name_to_type[type]) 11789 ) 11790 self.assertInBytecode(m, "LOAD_CONST", 1) 11791 self.assertInBytecode(m, "SEQUENCE_SET", seq_types[type]) 11792 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11793 m = mod["m"] 11794 if sign: 11795 expected = -value 11796 else: 11797 expected = value 11798 result = m() 11799 self.assertEqual( 11800 result, 11801 array("q", [1, expected, -5]), 11802 f"Failing case: {type}, {sign}", 11803 ) 11804 11805 def test_array_set_unsigned(self): 11806 uint_types = [ 11807 "uint8", 11808 "uint16", 11809 "uint32", 11810 "uint64", 11811 ] 11812 value = 77 11813 seq_types = { 11814 "uint8": SEQ_ARRAY_UINT8, 11815 "uint16": SEQ_ARRAY_UINT16, 11816 "uint32": SEQ_ARRAY_UINT32, 11817 "uint64": SEQ_ARRAY_UINT64, 11818 } 11819 for type in uint_types: 11820 codestr = f""" 11821 from __static__ import Array, {type} 11822 11823 def m() -> Array[{type}]: 11824 a = Array[{type}]([1, 3, 5]) 11825 a[1] = {value} 11826 return a 11827 """ 11828 with self.subTest(type=type): 11829 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11830 m = self.find_code(c, "m") 11831 self.assertInBytecode( 11832 m, "PRIMITIVE_LOAD_CONST", (value, prim_name_to_type[type]) 11833 ) 11834 self.assertInBytecode(m, "LOAD_CONST", 1) 11835 self.assertInBytecode(m, "SEQUENCE_SET", seq_types[type]) 11836 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11837 m = mod["m"] 11838 expected = value 11839 result = m() 11840 self.assertEqual( 11841 result, array("q", [1, expected, 5]), f"Failing case: {type}" 11842 ) 11843 11844 def test_array_set_negative_idx(self): 11845 codestr = """ 11846 from __static__ import Array, int8 11847 11848 def m() -> Array[int8]: 11849 a = Array[int8]([1, 3, -5]) 11850 a[-2] = 7 11851 return a 11852 """ 11853 c = self.compile(codestr, 
StaticCodeGenerator, modname="foo.py") 11854 m = self.find_code(c, "m") 11855 self.assertInBytecode(m, "PRIMITIVE_LOAD_CONST", (7, TYPED_INT8)) 11856 self.assertInBytecode(m, "LOAD_CONST", -2) 11857 self.assertInBytecode(m, "SEQUENCE_SET", SEQ_ARRAY_INT8) 11858 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11859 m = mod["m"] 11860 self.assertEqual(m(), array("h", [1, 7, -5])) 11861 11862 def test_array_set_failure(self) -> object: 11863 codestr = """ 11864 from __static__ import Array, int8 11865 11866 def m() -> Array[int8]: 11867 a = Array[int8]([1, 3, -5]) 11868 a[-100] = 7 11869 return a 11870 """ 11871 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11872 m = self.find_code(c, "m") 11873 self.assertInBytecode(m, "PRIMITIVE_LOAD_CONST", (7, TYPED_INT8)) 11874 self.assertInBytecode(m, "SEQUENCE_SET", SEQ_ARRAY_INT8) 11875 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11876 m = mod["m"] 11877 with self.assertRaisesRegex(IndexError, "index out of range"): 11878 m() 11879 11880 def test_array_set_failure_invalid_subscript(self): 11881 codestr = """ 11882 from __static__ import Array, int8 11883 11884 def x(): 11885 return object() 11886 11887 def m() -> Array[int8]: 11888 a = Array[int8]([1, 3, -5]) 11889 a[x()] = 7 11890 return a 11891 """ 11892 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11893 m = self.find_code(c, "m") 11894 self.assertInBytecode(m, "PRIMITIVE_LOAD_CONST", (7, TYPED_INT8)) 11895 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11896 m = mod["m"] 11897 with self.assertRaisesRegex(TypeError, "array indices must be integers"): 11898 m() 11899 11900 def test_fast_len_list(self): 11901 codestr = """ 11902 def f(): 11903 l = [1, 2, 3, 4, 5, 6, 7] 11904 return len(l) 11905 """ 11906 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11907 f = self.find_code(c, "f") 11908 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST) 11909 with 
self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11910 f = mod["f"] 11911 self.assertEqual(f(), 7) 11912 11913 def test_fast_len_str(self): 11914 codestr = """ 11915 def f(): 11916 l = "my str!" 11917 return len(l) 11918 """ 11919 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11920 f = self.find_code(c, "f") 11921 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR) 11922 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11923 f = mod["f"] 11924 self.assertEqual(f(), 7) 11925 11926 def test_fast_len_str_unicode_chars(self): 11927 codestr = """ 11928 def f(): 11929 l = "\U0001F923" # ROFL emoji 11930 return len(l) 11931 """ 11932 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11933 f = self.find_code(c, "f") 11934 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR) 11935 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11936 f = mod["f"] 11937 self.assertEqual(f(), 1) 11938 11939 def test_fast_len_tuple(self): 11940 codestr = """ 11941 def f(a, b): 11942 l = (a, b) 11943 return len(l) 11944 """ 11945 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11946 f = self.find_code(c, "f") 11947 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE) 11948 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11949 f = mod["f"] 11950 self.assertEqual(f("a", "b"), 2) 11951 11952 def test_fast_len_set(self): 11953 codestr = """ 11954 def f(a, b): 11955 l = {a, b} 11956 return len(l) 11957 """ 11958 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11959 f = self.find_code(c, "f") 11960 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET) 11961 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11962 f = mod["f"] 11963 self.assertEqual(f("a", "b"), 2) 11964 11965 def test_fast_len_dict(self): 11966 codestr = """ 11967 def f(): 11968 l = {1: 'a', 2: 'b', 3: 'c', 4: 'd'} 11969 return len(l) 11970 """ 11971 c = self.compile(codestr, 
StaticCodeGenerator, modname="foo.py") 11972 f = self.find_code(c, "f") 11973 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_DICT) 11974 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11975 f = mod["f"] 11976 self.assertEqual(f(), 4) 11977 11978 def test_fast_len_conditional_list(self): 11979 codestr = """ 11980 def f(n: int) -> bool: 11981 l = [i for i in range(n)] 11982 if l: 11983 return True 11984 return False 11985 """ 11986 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 11987 f = self.find_code(c, "f") 11988 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST) 11989 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 11990 f = mod["f"] 11991 for length in [0, 7]: 11992 self.assertEqual(f(length), length > 0) 11993 11994 def test_fast_len_conditional_str(self): 11995 codestr = """ 11996 def f(n: int) -> bool: 11997 l = f"{'a' * n}" 11998 if l: 11999 return True 12000 return False 12001 """ 12002 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12003 f = self.find_code(c, "f") 12004 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR) 12005 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12006 f = mod["f"] 12007 for length in [0, 7]: 12008 self.assertEqual(f(length), length > 0) 12009 12010 def test_fast_len_loop_conditional_list(self): 12011 codestr = """ 12012 def f(n: int) -> bool: 12013 l = [i for i in range(n)] 12014 while l: 12015 return True 12016 return False 12017 """ 12018 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12019 f = self.find_code(c, "f") 12020 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST) 12021 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12022 f = mod["f"] 12023 for length in [0, 7]: 12024 self.assertEqual(f(length), length > 0) 12025 12026 def test_fast_len_loop_conditional_str(self): 12027 codestr = """ 12028 def f(n: int) -> bool: 12029 l = f"{'a' * n}" 12030 while l: 12031 return True 12032 return False 
12033 """ 12034 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12035 f = self.find_code(c, "f") 12036 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR) 12037 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12038 f = mod["f"] 12039 for length in [0, 7]: 12040 self.assertEqual(f(length), length > 0) 12041 12042 def test_fast_len_loop_conditional_tuple(self): 12043 codestr = """ 12044 def f(n: int) -> bool: 12045 l = tuple(i for i in range(n)) 12046 while l: 12047 return True 12048 return False 12049 """ 12050 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12051 f = self.find_code(c, "f") 12052 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE) 12053 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12054 f = mod["f"] 12055 for length in [0, 7]: 12056 self.assertEqual(f(length), length > 0) 12057 12058 def test_fast_len_loop_conditional_set(self): 12059 codestr = """ 12060 def f(n: int) -> bool: 12061 l = {i for i in range(n)} 12062 while l: 12063 return True 12064 return False 12065 """ 12066 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12067 f = self.find_code(c, "f") 12068 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET) 12069 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12070 f = mod["f"] 12071 for length in [0, 7]: 12072 self.assertEqual(f(length), length > 0) 12073 12074 def test_fast_len_conditional_tuple(self): 12075 codestr = """ 12076 def f(n: int) -> bool: 12077 l = tuple(i for i in range(n)) 12078 if l: 12079 return True 12080 return False 12081 """ 12082 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12083 f = self.find_code(c, "f") 12084 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE) 12085 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12086 f = mod["f"] 12087 for length in [0, 7]: 12088 self.assertEqual(f(length), length > 0) 12089 12090 def test_fast_len_conditional_set(self): 12091 codestr = 
""" 12092 def f(n: int) -> bool: 12093 l = {i for i in range(n)} 12094 if l: 12095 return True 12096 return False 12097 """ 12098 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12099 f = self.find_code(c, "f") 12100 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET) 12101 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12102 f = mod["f"] 12103 for length in [0, 7]: 12104 self.assertEqual(f(length), length > 0) 12105 12106 def test_fast_len_conditional_dict(self): 12107 codestr = """ 12108 def f(n: int) -> bool: 12109 l = {i: i for i in range(n)} 12110 if l: 12111 return True 12112 return False 12113 """ 12114 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12115 f = self.find_code(c, "f") 12116 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_DICT) 12117 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12118 f = mod["f"] 12119 for length in [0, 7]: 12120 self.assertEqual(f(length), length > 0) 12121 12122 def test_fast_len_conditional_list_subclass(self): 12123 codestr = """ 12124 from typing import List 12125 12126 class MyList(list): 12127 def __len__(self): 12128 return 1729 12129 12130 def f(n: int, flag: bool) -> bool: 12131 x: List[int] = [i for i in range(n)] 12132 if flag: 12133 x = MyList([i for i in range(n)]) 12134 if x: 12135 return True 12136 return False 12137 """ 12138 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12139 f = self.find_code(c, "f") 12140 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT) 12141 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12142 f = mod["f"] 12143 for boolean, length in itertools.product((True, False), [0, 7]): 12144 self.assertEqual( 12145 f(length, boolean), 12146 length > 0 or boolean, 12147 f"length={length}, flag={boolean}", 12148 ) 12149 12150 def test_fast_len_conditional_str_subclass(self): 12151 codestr = """ 12152 class MyStr(str): 12153 def __len__(self): 12154 return 1729 12155 
12156 def f(n: int, flag: bool) -> bool: 12157 x: str = f"{'a' * n}" 12158 if flag: 12159 x = MyStr(f"{'a' * n}") 12160 if x: 12161 return True 12162 return False 12163 """ 12164 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12165 f = self.find_code(c, "f") 12166 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR | FAST_LEN_INEXACT) 12167 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12168 f = mod["f"] 12169 for boolean, length in itertools.product((True, False), [0, 7]): 12170 self.assertEqual( 12171 f(length, boolean), 12172 length > 0 or boolean, 12173 f"length={length}, flag={boolean}", 12174 ) 12175 12176 def test_fast_len_conditional_tuple_subclass(self): 12177 codestr = """ 12178 class Mytuple(tuple): 12179 def __len__(self): 12180 return 1729 12181 12182 def f(n: int, flag: bool) -> bool: 12183 x = tuple(i for i in range(n)) 12184 if flag: 12185 x = Mytuple([i for i in range(n)]) 12186 if x: 12187 return True 12188 return False 12189 """ 12190 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12191 f = self.find_code(c, "f") 12192 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE | FAST_LEN_INEXACT) 12193 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12194 f = mod["f"] 12195 for boolean, length in itertools.product((True, False), [0, 7]): 12196 self.assertEqual( 12197 f(length, boolean), 12198 length > 0 or boolean, 12199 f"length={length}, flag={boolean}", 12200 ) 12201 12202 def test_fast_len_conditional_set_subclass(self): 12203 codestr = """ 12204 class Myset(set): 12205 def __len__(self): 12206 return 1729 12207 12208 def f(n: int, flag: bool) -> bool: 12209 x = set(i for i in range(n)) 12210 if flag: 12211 x = Myset([i for i in range(n)]) 12212 if x: 12213 return True 12214 return False 12215 """ 12216 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12217 f = self.find_code(c, "f") 12218 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET | FAST_LEN_INEXACT) 
12219 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12220 f = mod["f"] 12221 for boolean, length in itertools.product((True, False), [0, 7]): 12222 self.assertEqual( 12223 f(length, boolean), 12224 length > 0 or boolean, 12225 f"length={length}, flag={boolean}", 12226 ) 12227 12228 def test_fast_len_conditional_list_funcarg(self): 12229 codestr = """ 12230 def z(b: object) -> bool: 12231 return bool(b) 12232 12233 def f(n: int) -> bool: 12234 l = [i for i in range(n)] 12235 if z(l): 12236 return True 12237 return False 12238 """ 12239 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12240 f = self.find_code(c, "f") 12241 # Since the list is given to z(), do not optimize the check 12242 # with FAST_LEN 12243 self.assertNotInBytecode(f, "FAST_LEN") 12244 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12245 f = mod["f"] 12246 for length in [0, 7]: 12247 self.assertEqual(f(length), length > 0) 12248 12249 def test_fast_len_conditional_str_funcarg(self): 12250 codestr = """ 12251 def z(b: object) -> bool: 12252 return bool(b) 12253 12254 def f(n: int) -> bool: 12255 l = f"{'a' * n}" 12256 if z(l): 12257 return True 12258 return False 12259 """ 12260 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12261 f = self.find_code(c, "f") 12262 # Since the list is given to z(), do not optimize the check 12263 # with FAST_LEN 12264 self.assertNotInBytecode(f, "FAST_LEN") 12265 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12266 f = mod["f"] 12267 for length in [0, 7]: 12268 self.assertEqual(f(length), length > 0) 12269 12270 def test_fast_len_conditional_tuple_funcarg(self): 12271 codestr = """ 12272 def z(b: object) -> bool: 12273 return bool(b) 12274 12275 def f(n: int) -> bool: 12276 l = tuple(i for i in range(n)) 12277 if z(l): 12278 return True 12279 return False 12280 """ 12281 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12282 f = self.find_code(c, "f") 12283 # 
Since the tuple is given to z(), do not optimize the check 12284 # with FAST_LEN 12285 self.assertNotInBytecode(f, "FAST_LEN") 12286 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12287 f = mod["f"] 12288 for length in [0, 7]: 12289 self.assertEqual(f(length), length > 0) 12290 12291 def test_fast_len_conditional_set_funcarg(self): 12292 codestr = """ 12293 def z(b: object) -> bool: 12294 return bool(b) 12295 12296 def f(n: int) -> bool: 12297 l = set(i for i in range(n)) 12298 if z(l): 12299 return True 12300 return False 12301 """ 12302 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12303 f = self.find_code(c, "f") 12304 # Since the set is given to z(), do not optimize the check 12305 # with FAST_LEN 12306 self.assertNotInBytecode(f, "FAST_LEN") 12307 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12308 f = mod["f"] 12309 for length in [0, 7]: 12310 self.assertEqual(f(length), length > 0) 12311 12312 def test_fast_len_conditional_dict_funcarg(self): 12313 codestr = """ 12314 def z(b) -> bool: 12315 return bool(b) 12316 12317 def f(n: int) -> bool: 12318 l = {i: i for i in range(n)} 12319 if z(l): 12320 return True 12321 return False 12322 """ 12323 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12324 f = self.find_code(c, "f") 12325 # Since the dict is given to z(), do not optimize the check 12326 # with FAST_LEN 12327 self.assertNotInBytecode(f, "FAST_LEN") 12328 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12329 f = mod["f"] 12330 for length in [0, 7]: 12331 self.assertEqual(f(length), length > 0) 12332 12333 def test_fast_len_list_subclass(self): 12334 codestr = """ 12335 class mylist(list): 12336 def __len__(self): 12337 return 1111 12338 12339 def f(): 12340 l = mylist([1, 2]) 12341 return len(l) 12342 """ 12343 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12344 f = self.find_code(c, "f") 12345 self.assertNotInBytecode(f, "FAST_LEN") 12346 
with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12347 f = mod["f"] 12348 self.assertEqual(f(), 1111) 12349 12350 def test_fast_len_str_subclass(self): 12351 codestr = """ 12352 class mystr(str): 12353 def __len__(self): 12354 return 1111 12355 12356 def f(): 12357 s = mystr("a") 12358 return len(s) 12359 """ 12360 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12361 f = self.find_code(c, "f") 12362 self.assertNotInBytecode(f, "FAST_LEN") 12363 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12364 f = mod["f"] 12365 self.assertEqual(f(), 1111) 12366 12367 def test_fast_len_tuple_subclass(self): 12368 codestr = """ 12369 class mytuple(tuple): 12370 def __len__(self): 12371 return 1111 12372 12373 def f(): 12374 l = mytuple([1, 2]) 12375 return len(l) 12376 """ 12377 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12378 f = self.find_code(c, "f") 12379 self.assertNotInBytecode(f, "FAST_LEN") 12380 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12381 f = mod["f"] 12382 self.assertEqual(f(), 1111) 12383 12384 def test_fast_len_set_subclass(self): 12385 codestr = """ 12386 class myset(set): 12387 def __len__(self): 12388 return 1111 12389 12390 def f(): 12391 l = myset([1, 2]) 12392 return len(l) 12393 """ 12394 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12395 f = self.find_code(c, "f") 12396 self.assertNotInBytecode(f, "FAST_LEN") 12397 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12398 f = mod["f"] 12399 self.assertEqual(f(), 1111) 12400 12401 def test_fast_len_dict_subclass(self): 12402 codestr = """ 12403 from typing import Dict 12404 12405 class mydict(Dict[str, int]): 12406 def __len__(self): 12407 return 1111 12408 12409 def f(): 12410 l = mydict(a=1, b=2) 12411 return len(l) 12412 """ 12413 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12414 f = self.find_code(c, "f") 12415 self.assertNotInBytecode(f, 
"FAST_LEN") 12416 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12417 f = mod["f"] 12418 self.assertEqual(f(), 1111) 12419 12420 def test_fast_len_list_subclass_2(self): 12421 codestr = """ 12422 class mylist(list): 12423 def __len__(self): 12424 return 1111 12425 12426 def f(x): 12427 l = [1, 2] 12428 if x: 12429 l = mylist([1, 2]) 12430 return len(l) 12431 """ 12432 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12433 f = self.find_code(c, "f") 12434 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_LIST | FAST_LEN_INEXACT) 12435 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12436 f = mod["f"] 12437 self.assertEqual(f(True), 1111) 12438 12439 def test_fast_len_str_subclass_2(self): 12440 codestr = """ 12441 class mystr(str): 12442 def __len__(self): 12443 return 1111 12444 12445 def f(x): 12446 s = "abc" 12447 if x: 12448 s = mystr("pqr") 12449 return len(s) 12450 """ 12451 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12452 f = self.find_code(c, "f") 12453 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_STR | FAST_LEN_INEXACT) 12454 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12455 f = mod["f"] 12456 self.assertEqual(f(True), 1111) 12457 self.assertEqual(f(False), 3) 12458 12459 def test_fast_len_tuple_subclass_2(self): 12460 codestr = """ 12461 class mytuple(tuple): 12462 def __len__(self): 12463 return 1111 12464 12465 def f(x, a, b): 12466 l = (a, b) 12467 if x: 12468 l = mytuple([a, b]) 12469 return len(l) 12470 """ 12471 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12472 f = self.find_code(c, "f") 12473 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_TUPLE | FAST_LEN_INEXACT) 12474 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12475 f = mod["f"] 12476 self.assertEqual(f(True, 1, 2), 1111) 12477 12478 def test_fast_len_dict_subclass_2(self): 12479 codestr = """ 12480 from typing import Dict 12481 12482 class mydict(Dict[str, 
int]): 12483 def __len__(self): 12484 return 1111 12485 12486 def f(x, a, b): 12487 l: Dict[str, int] = {'c': 3} 12488 if x: 12489 l = mydict(a=1, b=2) 12490 return len(l) 12491 """ 12492 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12493 f = self.find_code(c, "f") 12494 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_DICT | FAST_LEN_INEXACT) 12495 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12496 f = mod["f"] 12497 self.assertEqual(f(True, 1, 2), 1111) 12498 12499 def test_fast_len_set_subclass_2(self): 12500 codestr = """ 12501 class myset(set): 12502 def __len__(self): 12503 return 1111 12504 12505 def f(x, a, b): 12506 l = {a, b} 12507 if x: 12508 l = myset([a, b]) 12509 return len(l) 12510 """ 12511 c = self.compile(codestr, StaticCodeGenerator, modname="foo.py") 12512 f = self.find_code(c, "f") 12513 self.assertInBytecode(f, "FAST_LEN", FAST_LEN_SET | FAST_LEN_INEXACT) 12514 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12515 f = mod["f"] 12516 self.assertEqual(f(True, 1, 2), 1111) 12517 12518 def test_dynamic_type_param(self): 12519 """DYNAMIC as type param of generic doesn't render the whole type DYNAMIC.""" 12520 codestr = """ 12521 from __static__ import int64, clen 12522 from nonstatic import Foo 12523 from typing import Dict 12524 12525 def f(d: Dict[Foo, int]): 12526 x: int64 = clen(d) 12527 """ 12528 self.compile(codestr) 12529 12530 def test_checked_dict(self): 12531 x = chkdict[str, str]() 12532 x["abc"] = "foo" 12533 self.assertEqual(repr(x), "{'abc': 'foo'}") 12534 x = chkdict[str, int]() 12535 x["abc"] = 42 12536 x = chkdict[int, str]() 12537 x[42] = "abc" 12538 12539 def test_checked_dict_optional(self): 12540 x = chkdict[str, Optional[str]]() 12541 x["abc"] = None 12542 x = chkdict[Optional[str], str]() 12543 x[None] = "abc" 12544 12545 def test_checked_dict_nonoptional(self): 12546 x = chkdict[str, Optional[str]]() 12547 with self.assertRaises(TypeError): 12548 x[None] = "abc" 12549 
x = chkdict[Optional[str], str]() 12550 with self.assertRaises(TypeError): 12551 x["abc"] = None 12552 12553 def test_checked_dict_types_enforced(self): 12554 x = chkdict[str, str]() 12555 with self.assertRaises(TypeError): 12556 x[42] = "abc" 12557 self.assertEqual(x, {}) 12558 with self.assertRaises(TypeError): 12559 x["abc"] = 42 12560 self.assertEqual(x, {}) 12561 12562 x = chkdict[str, int]() 12563 with self.assertRaises(TypeError): 12564 x[42] = 42 12565 self.assertEqual(x, {}) 12566 with self.assertRaises(TypeError): 12567 x["abc"] = "abc" 12568 self.assertEqual(x, {}) 12569 12570 def test_checked_dict_ctor(self): 12571 self.assertEqual(chkdict[str, str](x="abc"), {"x": "abc"}) 12572 self.assertEqual(chkdict[str, int](x=42), {"x": 42}) 12573 self.assertEqual(chkdict[str, str]({"x": "abc"}), {"x": "abc"}) 12574 self.assertEqual(chkdict[str, str]([("a", "b")]), {"a": "b"}) 12575 self.assertEqual(chkdict[str, str]([("a", "b")]), {"a": "b"}) 12576 self.assertEqual(chkdict[str, str](chkdict[str, str](x="abc")), {"x": "abc"}) 12577 self.assertEqual(chkdict[str, str](chkdict[str, object](x="abc")), {"x": "abc"}) 12578 self.assertEqual(chkdict[str, str](UserDict(x="abc")), {"x": "abc"}) 12579 self.assertEqual(chkdict[str, str](UserDict(x="abc"), x="foo"), {"x": "foo"}) 12580 12581 def test_checked_dict_bad_ctor(self): 12582 with self.assertRaises(TypeError): 12583 chkdict[str, str](None) 12584 12585 def test_checked_dict_setdefault(self): 12586 x = chkdict[str, str]() 12587 x.setdefault("abc", "foo") 12588 self.assertEqual(x, {"abc": "foo"}) 12589 12590 def test_checked_dict___module__(self): 12591 class Lol: 12592 pass 12593 12594 x = chkdict[int, Lol]() 12595 self.assertEqual(type(x).__module__, "__static__") 12596 12597 def test_checked_dict_setdefault_bad_values(self): 12598 x = chkdict[str, int]() 12599 with self.assertRaises(TypeError): 12600 x.setdefault("abc", "abc") 12601 self.assertEqual(x, {}) 12602 with self.assertRaises(TypeError): 12603 
x.setdefault(42, 42) 12604 self.assertEqual(x, {}) 12605 12606 def test_checked_dict_fromkeys(self): 12607 x = chkdict[str, int].fromkeys("abc", 42) 12608 self.assertEqual(x, {"a": 42, "b": 42, "c": 42}) 12609 12610 def test_checked_dict_fromkeys_optional(self): 12611 x = chkdict[Optional[str], int].fromkeys(["a", "b", "c", None], 42) 12612 self.assertEqual(x, {"a": 42, "b": 42, "c": 42, None: 42}) 12613 12614 x = chkdict[str, Optional[int]].fromkeys("abc", None) 12615 self.assertEqual(x, {"a": None, "b": None, "c": None}) 12616 12617 def test_checked_dict_fromkeys_bad_types(self): 12618 with self.assertRaises(TypeError): 12619 chkdict[str, int].fromkeys([2], 42) 12620 12621 with self.assertRaises(TypeError): 12622 chkdict[str, int].fromkeys("abc", object()) 12623 12624 with self.assertRaises(TypeError): 12625 chkdict[str, int].fromkeys("abc") 12626 12627 def test_checked_dict_copy(self): 12628 x = chkdict[str, str](x="abc") 12629 self.assertEqual(type(x), chkdict[str, str]) 12630 self.assertEqual(x, {"x": "abc"}) 12631 12632 def test_checked_dict_clear(self): 12633 x = chkdict[str, str](x="abc") 12634 x.clear() 12635 self.assertEqual(x, {}) 12636 12637 def test_checked_dict_update(self): 12638 x = chkdict[str, str](x="abc") 12639 x.update(y="foo") 12640 self.assertEqual(x, {"x": "abc", "y": "foo"}) 12641 x.update({"z": "bar"}) 12642 self.assertEqual(x, {"x": "abc", "y": "foo", "z": "bar"}) 12643 12644 def test_checked_dict_update_bad_type(self): 12645 x = chkdict[str, int]() 12646 with self.assertRaises(TypeError): 12647 x.update(x="abc") 12648 self.assertEqual(x, {}) 12649 with self.assertRaises(TypeError): 12650 x.update({"x": "abc"}) 12651 with self.assertRaises(TypeError): 12652 x.update({24: 42}) 12653 self.assertEqual(x, {}) 12654 12655 def test_checked_dict_keys(self): 12656 x = chkdict[str, int](x=2) 12657 self.assertEqual(list(x.keys()), ["x"]) 12658 x = chkdict[str, int](x=2, y=3) 12659 self.assertEqual(list(x.keys()), ["x", "y"]) 12660 12661 def 
test_checked_dict_values(self): 12662 x = chkdict[str, int](x=2, y=3) 12663 self.assertEqual(list(x.values()), [2, 3]) 12664 12665 def test_checked_dict_items(self): 12666 x = chkdict[str, int](x=2) 12667 self.assertEqual( 12668 list(x.items()), 12669 [ 12670 ("x", 2), 12671 ], 12672 ) 12673 x = chkdict[str, int](x=2, y=3) 12674 self.assertEqual(list(x.items()), [("x", 2), ("y", 3)]) 12675 12676 def test_checked_dict_pop(self): 12677 x = chkdict[str, int](x=2) 12678 y = x.pop("x") 12679 self.assertEqual(y, 2) 12680 with self.assertRaises(KeyError): 12681 x.pop("z") 12682 12683 def test_checked_dict_popitem(self): 12684 x = chkdict[str, int](x=2) 12685 y = x.popitem() 12686 self.assertEqual(y, ("x", 2)) 12687 with self.assertRaises(KeyError): 12688 x.popitem() 12689 12690 def test_checked_dict_get(self): 12691 x = chkdict[str, int](x=2) 12692 self.assertEqual(x.get("x"), 2) 12693 self.assertEqual(x.get("y", 100), 100) 12694 12695 def test_checked_dict_errors(self): 12696 x = chkdict[str, int](x=2) 12697 with self.assertRaises(TypeError): 12698 x.get(100) 12699 with self.assertRaises(TypeError): 12700 x.get("x", "abc") 12701 12702 def test_checked_dict_sizeof(self): 12703 x = chkdict[str, int](x=2).__sizeof__() 12704 self.assertEqual(type(x), int) 12705 12706 def test_checked_dict_getitem(self): 12707 x = chkdict[str, int](x=2) 12708 self.assertEqual(x.__getitem__("x"), 2) 12709 12710 def test_checked_dict_free_list(self): 12711 t1 = chkdict[str, int] 12712 t2 = chkdict[str, str] 12713 x = t1() 12714 x_id1 = id(x) 12715 del x 12716 x = t2() 12717 x_id2 = id(x) 12718 self.assertEqual(x_id1, x_id2) 12719 12720 def test_check_args(self): 12721 """ 12722 Tests whether CHECK_ARGS can handle variables which are in a Cell, 12723 and are a positional arg at index 0. 
12724 """ 12725 12726 codestr = """ 12727 def use(i: object) -> object: 12728 return i 12729 12730 def outer(x: int) -> object: 12731 12732 def inner() -> None: 12733 use(x) 12734 12735 return use(x) 12736 """ 12737 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12738 outer = mod["outer"] 12739 self.assertEqual(outer(1), 1) 12740 12741 def test_check_args_2(self): 12742 """ 12743 Tests whether CHECK_ARGS can handle multiple variables which are in a Cell, 12744 and are positional args. 12745 """ 12746 12747 codestr = """ 12748 def use(i: object) -> object: 12749 return i 12750 12751 def outer(x: int, y: str) -> object: 12752 12753 def inner() -> None: 12754 use(x) 12755 use(y) 12756 12757 use(x) 12758 return use(y) 12759 """ 12760 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12761 outer = mod["outer"] 12762 self.assertEqual(outer(1, "yo"), "yo") 12763 # Force JIT-compiled code to go through argument checks after 12764 # keyword arg binding 12765 self.assertEqual(outer(1, y="yo"), "yo") 12766 12767 def test_check_args_3(self): 12768 """ 12769 Tests whether CHECK_ARGS can handle variables which are in a Cell, 12770 and are a positional arg at index > 0. 12771 """ 12772 12773 codestr = """ 12774 def use(i: object) -> object: 12775 return i 12776 12777 def outer(x: int, y: str) -> object: 12778 12779 def inner() -> None: 12780 use(y) 12781 12782 use(x) 12783 return use(y) 12784 """ 12785 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12786 outer = mod["outer"] 12787 self.assertEqual(outer(1, "yo"), "yo") 12788 # Force JIT-compiled code to go through argument checks after 12789 # keyword arg binding 12790 self.assertEqual(outer(1, y="yo"), "yo") 12791 12792 def test_check_args_4(self): 12793 """ 12794 Tests whether CHECK_ARGS can handle variables which are in a Cell, 12795 and are a kwarg at index 0. 
12796 """ 12797 12798 codestr = """ 12799 def use(i: object) -> object: 12800 return i 12801 12802 def outer(x: int = 0) -> object: 12803 12804 def inner() -> None: 12805 use(x) 12806 12807 return use(x) 12808 """ 12809 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12810 outer = mod["outer"] 12811 self.assertEqual(outer(1), 1) 12812 12813 def test_check_args_5(self): 12814 """ 12815 Tests whether CHECK_ARGS can handle variables which are in a Cell, 12816 and are a kw-only arg. 12817 """ 12818 codestr = """ 12819 def use(i: object) -> object: 12820 return i 12821 12822 def outer(x: int, *, y: str = "lol") -> object: 12823 12824 def inner() -> None: 12825 use(y) 12826 12827 return use(y) 12828 12829 """ 12830 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12831 outer = mod["outer"] 12832 self.assertEqual(outer(1, y="hi"), "hi") 12833 12834 def test_check_args_6(self): 12835 """ 12836 Tests whether CHECK_ARGS can handle variables which are in a Cell, 12837 and are a pos-only arg. 12838 """ 12839 codestr = """ 12840 def use(i: object) -> object: 12841 return i 12842 12843 def outer(x: int, /, y: str) -> object: 12844 12845 def inner() -> None: 12846 use(y) 12847 12848 return use(y) 12849 12850 """ 12851 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12852 outer = mod["outer"] 12853 self.assertEqual(outer(1, "hi"), "hi") 12854 12855 def test_check_args_7(self): 12856 """ 12857 Tests whether CHECK_ARGS can handle multiple variables which are in a Cell, 12858 and are a mix of positional, pos-only and kw-only args. 
12859 """ 12860 12861 codestr = """ 12862 def use(i: object) -> object: 12863 return i 12864 12865 def outer(x: int, /, y: int, *, z: str = "lol") -> object: 12866 12867 def inner() -> None: 12868 use(x) 12869 use(y) 12870 use(z) 12871 12872 return use(x), use(y), use(z) 12873 """ 12874 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12875 outer = mod["outer"] 12876 self.assertEqual(outer(3, 2, z="hi"), (3, 2, "hi")) 12877 12878 def test_str_split(self): 12879 codestr = """ 12880 def get_str() -> str: 12881 return "something here" 12882 12883 def test() -> str: 12884 a, b = get_str().split(None, 1) 12885 return b 12886 """ 12887 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12888 test = mod["test"] 12889 self.assertEqual(test(), "here") 12890 12891 def test_vtable_shadow_builtin_subclass_after_init(self): 12892 """Shadowing methods on subclass of list after vtables are inited.""" 12893 12894 class MyList(list): 12895 pass 12896 12897 def myreverse(self): 12898 return 1 12899 12900 codestr = """ 12901 def f(l: list): 12902 l.reverse() 12903 return l 12904 """ 12905 f = self.find_code(self.compile(codestr), "f") 12906 self.assertInBytecode( 12907 f, "INVOKE_METHOD", ((("builtins", "list", "reverse"), 0)) 12908 ) 12909 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12910 # Now cause vtables to be inited 12911 self.assertEqual(mod["f"]([1, 2]), [2, 1]) 12912 12913 # And now patch 12914 MyList.reverse = myreverse 12915 12916 self.assertEqual(MyList().reverse(), 1) 12917 12918 def test_vtable_shadow_builtin_subclass_before_init(self): 12919 """Shadowing methods on subclass of list before vtables are inited.""" 12920 # Create a subclass of list... 12921 class MyList(list): 12922 pass 12923 12924 def myreverse(self): 12925 return 1 12926 12927 # ... 
and override a slot from list with a non-static func 12928 MyList.reverse = myreverse 12929 12930 codestr = """ 12931 def f(l: list): 12932 l.reverse() 12933 return l 12934 """ 12935 f = self.find_code(self.compile(codestr), "f") 12936 self.assertInBytecode( 12937 f, "INVOKE_METHOD", ((("builtins", "list", "reverse"), 0)) 12938 ) 12939 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12940 # Now cause vtables to be inited 12941 self.assertEqual(mod["f"]([1, 2]), [2, 1]) 12942 12943 # ... and this should not blow up when we remove the override. 12944 del MyList.reverse 12945 12946 self.assertEqual(MyList().reverse(), None) 12947 12948 def test_vtable_shadow_static_subclass(self): 12949 """Shadowing methods of a static type before its inited should not bypass typechecks.""" 12950 # Define a static type and shadow a subtype method before invoking. 12951 codestr = """ 12952 class StaticType: 12953 def foo(self) -> int: 12954 return 1 12955 12956 class SubType(StaticType): 12957 pass 12958 12959 def goodfoo(self): 12960 return 2 12961 12962 SubType.foo = goodfoo 12963 12964 def f(x: StaticType) -> int: 12965 return x.foo() 12966 """ 12967 f = self.find_code(self.compile(codestr), "f") 12968 self.assertInBytecode( 12969 f, "INVOKE_METHOD", ((("<module>", "StaticType", "foo"), 0)) 12970 ) 12971 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 12972 SubType = mod["SubType"] 12973 # Now invoke: 12974 self.assertEqual(mod["f"](SubType()), 2) 12975 12976 # And replace the function again, forcing us to find the right slot type: 12977 def badfoo(self): 12978 return "foo" 12979 12980 SubType.foo = badfoo 12981 12982 with self.assertRaisesRegex(TypeError, "expected int, got str"): 12983 mod["f"](SubType()) 12984 12985 def test_vtable_shadow_static_subclass_nonstatic_patch(self): 12986 """Shadowing methods of a static type before its inited should not bypass typechecks.""" 12987 code1 = """ 12988 def nonstaticfoo(self): 12989 return 2 12990 """ 
12991 with self.in_module( 12992 code1, code_gen=PythonCodeGenerator, name="nonstatic" 12993 ) as mod1: 12994 # Define a static type and shadow a subtype method with a non-static func before invoking. 12995 codestr = """ 12996 from nonstatic import nonstaticfoo 12997 12998 class StaticType: 12999 def foo(self) -> int: 13000 return 1 13001 13002 class SubType(StaticType): 13003 pass 13004 13005 SubType.foo = nonstaticfoo 13006 13007 def f(x: StaticType) -> int: 13008 return x.foo() 13009 13010 def badfoo(self): 13011 return "foo" 13012 """ 13013 code = self.compile(codestr) 13014 f = self.find_code(code, "f") 13015 self.assertInBytecode( 13016 f, "INVOKE_METHOD", ((("<module>", "StaticType", "foo"), 0)) 13017 ) 13018 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13019 SubType = mod["SubType"] 13020 badfoo = mod["badfoo"] 13021 13022 # And replace the function again, forcing us to find the right slot type: 13023 SubType.foo = badfoo 13024 13025 with self.assertRaisesRegex(TypeError, "expected int, got str"): 13026 mod["f"](SubType()) 13027 13028 def test_vtable_shadow_grandparent(self): 13029 codestr = """ 13030 class Base: 13031 def foo(self) -> int: 13032 return 1 13033 13034 class Sub(Base): 13035 pass 13036 13037 class Grand(Sub): 13038 pass 13039 13040 def f(x: Base) -> int: 13041 return x.foo() 13042 13043 def grandfoo(self): 13044 return "foo" 13045 """ 13046 f = self.find_code(self.compile(codestr), "f") 13047 self.assertInBytecode(f, "INVOKE_METHOD", ((("<module>", "Base", "foo"), 0))) 13048 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13049 Grand = mod["Grand"] 13050 grandfoo = mod["grandfoo"] 13051 f = mod["f"] 13052 13053 # init vtables 13054 self.assertEqual(f(Grand()), 1) 13055 13056 # patch in an override of the grandparent method 13057 Grand.foo = grandfoo 13058 13059 with self.assertRaisesRegex(TypeError, "expected int, got str"): 13060 f(Grand()) 13061 13062 def test_for_iter_list(self): 13063 codestr = """ 
13064 from typing import List 13065 13066 def f(n: int) -> List: 13067 acc = [] 13068 l = [i for i in range(n)] 13069 for i in l: 13070 acc.append(i + 1) 13071 return acc 13072 """ 13073 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13074 f = mod["f"] 13075 self.assertNotInBytecode(f, "FOR_ITER") 13076 self.assertEqual(f(4), [i + 1 for i in range(4)]) 13077 13078 def test_for_iter_tuple(self): 13079 codestr = """ 13080 from typing import List 13081 13082 def f(n: int) -> List: 13083 acc = [] 13084 l = tuple([i for i in range(n)]) 13085 for i in l: 13086 acc.append(i + 1) 13087 return acc 13088 """ 13089 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13090 f = mod["f"] 13091 self.assertNotInBytecode(f, "FOR_ITER") 13092 self.assertEqual(f(4), [i + 1 for i in range(4)]) 13093 13094 def test_for_iter_sequence_orelse(self): 13095 codestr = """ 13096 from typing import List 13097 13098 def f(n: int) -> List: 13099 acc = [] 13100 l = [i for i in range(n)] 13101 for i in l: 13102 acc.append(i + 1) 13103 else: 13104 acc.append(999) 13105 return acc 13106 """ 13107 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13108 f = mod["f"] 13109 self.assertNotInBytecode(f, "FOR_ITER") 13110 self.assertEqual(f(4), [i + 1 for i in range(4)] + [999]) 13111 13112 def test_for_iter_sequence_break(self): 13113 codestr = """ 13114 from typing import List 13115 13116 def f(n: int) -> List: 13117 acc = [] 13118 l = [i for i in range(n)] 13119 for i in l: 13120 if i == 3: 13121 break 13122 acc.append(i + 1) 13123 return acc 13124 """ 13125 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13126 f = mod["f"] 13127 self.assertNotInBytecode(f, "FOR_ITER") 13128 self.assertEqual(f(5), [1, 2, 3]) 13129 13130 def test_for_iter_sequence_orelse_break(self): 13131 codestr = """ 13132 from typing import List 13133 13134 def f(n: int) -> List: 13135 acc = [] 13136 l = [i for i in range(n)] 13137 for i in l: 13138 if i == 2: 
13139 break 13140 acc.append(i + 1) 13141 else: 13142 acc.append(999) 13143 return acc 13144 """ 13145 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13146 f = mod["f"] 13147 self.assertNotInBytecode(f, "FOR_ITER") 13148 self.assertEqual(f(4), [1, 2]) 13149 13150 def test_for_iter_sequence_return(self): 13151 codestr = """ 13152 from typing import List 13153 13154 def f(n: int) -> List: 13155 acc = [] 13156 l = [i for i in range(n)] 13157 for i in l: 13158 if i == 3: 13159 return acc 13160 acc.append(i + 1) 13161 return acc 13162 """ 13163 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13164 f = mod["f"] 13165 self.assertNotInBytecode(f, "FOR_ITER") 13166 self.assertEqual(f(6), [1, 2, 3]) 13167 13168 def test_nested_for_iter_sequence(self): 13169 codestr = """ 13170 from typing import List 13171 13172 def f(n: int) -> List: 13173 acc = [] 13174 l = [i for i in range(n)] 13175 for i in l: 13176 for j in l: 13177 acc.append(i + j) 13178 return acc 13179 """ 13180 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13181 f = mod["f"] 13182 self.assertNotInBytecode(f, "FOR_ITER") 13183 self.assertEqual(f(3), [0, 1, 2, 1, 2, 3, 2, 3, 4]) 13184 13185 def test_nested_for_iter_sequence_break(self): 13186 codestr = """ 13187 from typing import List 13188 13189 def f(n: int) -> List: 13190 acc = [] 13191 l = [i for i in range(n)] 13192 for i in l: 13193 for j in l: 13194 if j == 2: 13195 break 13196 acc.append(i + j) 13197 return acc 13198 """ 13199 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13200 f = mod["f"] 13201 self.assertNotInBytecode(f, "FOR_ITER") 13202 self.assertEqual(f(3), [0, 1, 1, 2, 2, 3]) 13203 13204 def test_nested_for_iter_sequence_return(self): 13205 codestr = """ 13206 from typing import List 13207 13208 def f(n: int) -> List: 13209 acc = [] 13210 l = [i for i in range(n)] 13211 for i in l: 13212 for j in l: 13213 if j == 1: 13214 return acc 13215 acc.append(i + j) 13216 return 
acc 13217 """ 13218 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13219 f = mod["f"] 13220 self.assertNotInBytecode(f, "FOR_ITER") 13221 self.assertEqual(f(3), [0]) 13222 13223 def test_for_iter_unchecked_get(self): 13224 """We don't need to check sequence bounds when we've just compared with the list size.""" 13225 codestr = """ 13226 def f(): 13227 l = [1, 2, 3] 13228 acc = [] 13229 for x in l: 13230 acc.append(x) 13231 return acc 13232 """ 13233 with self.in_module(codestr) as mod: 13234 f = mod["f"] 13235 self.assertInBytecode(f, "SEQUENCE_GET", SEQ_LIST | SEQ_SUBSCR_UNCHECKED) 13236 self.assertEqual(f(), [1, 2, 3]) 13237 13238 def test_for_iter_list_modified(self): 13239 codestr = """ 13240 def f(): 13241 l = [1, 2, 3, 4, 5] 13242 acc = [] 13243 for x in l: 13244 acc.append(x) 13245 l[2:] = [] 13246 return acc 13247 """ 13248 with self.in_module(codestr) as mod: 13249 f = mod["f"] 13250 self.assertNotInBytecode(f, "FOR_ITER") 13251 self.assertEqual(f(), [1, 2]) 13252 13253 def test_sorted(self): 13254 """sorted() builtin returns an Exact[List].""" 13255 codestr = """ 13256 from typing import Iterable 13257 13258 def f(l: Iterable[int]): 13259 for x in sorted(l): 13260 pass 13261 """ 13262 with self.in_module(codestr) as mod: 13263 f = mod["f"] 13264 self.assertNotInBytecode(f, "FOR_ITER") 13265 self.assertInBytecode(f, "REFINE_TYPE", ("builtins", "list")) 13266 13267 def test_min(self): 13268 codestr = """ 13269 def f(a: int, b: int) -> int: 13270 return min(a, b) 13271 """ 13272 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13273 f = mod["f"] 13274 self.assertInBytecode(f, "COMPARE_OP", "<=") 13275 self.assertInBytecode(f, "POP_JUMP_IF_FALSE") 13276 self.assertEqual(f(1, 3), 1) 13277 self.assertEqual(f(3, 1), 1) 13278 13279 def test_min_stability(self): 13280 codestr = """ 13281 def f(a: int, b: int) -> int: 13282 return min(a, b) 13283 """ 13284 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13285 
f = mod["f"] 13286 self.assertInBytecode(f, "COMPARE_OP", "<=") 13287 self.assertInBytecode(f, "POP_JUMP_IF_FALSE") 13288 # p & q should be different objects, but with same value 13289 p = int("11334455667") 13290 q = int("11334455667") 13291 self.assertNotEqual(id(p), id(q)) 13292 # Since p and q are equal, the returned value should be the first arg 13293 self.assertEqual(id(f(p, q)), id(p)) 13294 self.assertEqual(id(f(q, p)), id(q)) 13295 13296 def test_max(self): 13297 codestr = """ 13298 def f(a: int, b: int) -> int: 13299 return max(a, b) 13300 """ 13301 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13302 f = mod["f"] 13303 self.assertInBytecode(f, "COMPARE_OP", ">=") 13304 self.assertInBytecode(f, "POP_JUMP_IF_FALSE") 13305 self.assertEqual(f(1, 3), 3) 13306 self.assertEqual(f(3, 1), 3) 13307 13308 def test_max_stability(self): 13309 codestr = """ 13310 def f(a: int, b: int) -> int: 13311 return max(a, b) 13312 """ 13313 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13314 f = mod["f"] 13315 self.assertInBytecode(f, "COMPARE_OP", ">=") 13316 self.assertInBytecode(f, "POP_JUMP_IF_FALSE") 13317 # p & q should be different objects, but with same value 13318 p = int("11334455667") 13319 q = int("11334455667") 13320 self.assertNotEqual(id(p), id(q)) 13321 # Since p and q are equal, the returned value should be the first arg 13322 self.assertEqual(id(f(p, q)), id(p)) 13323 self.assertEqual(id(f(q, p)), id(q)) 13324 13325 def test_extremum_primitive(self): 13326 codestr = """ 13327 from __static__ import int8 13328 13329 def f() -> None: 13330 a: int8 = 4 13331 b: int8 = 5 13332 min(a, b) 13333 """ 13334 with self.assertRaisesRegex( 13335 TypedSyntaxError, "Call argument cannot be a primitive" 13336 ): 13337 self.compile(codestr, StaticCodeGenerator, modname="foo.py") 13338 13339 def test_extremum_non_specialization_kwarg(self): 13340 codestr = """ 13341 def f() -> None: 13342 a = "4" 13343 b = "5" 13344 min(a, b, key=int) 
13345 """ 13346 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13347 f = mod["f"] 13348 self.assertNotInBytecode(f, "COMPARE_OP") 13349 self.assertNotInBytecode(f, "POP_JUMP_IF_FALSE") 13350 13351 def test_extremum_non_specialization_stararg(self): 13352 codestr = """ 13353 def f() -> None: 13354 a = [3, 4] 13355 min(*a) 13356 """ 13357 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13358 f = mod["f"] 13359 self.assertNotInBytecode(f, "COMPARE_OP") 13360 self.assertNotInBytecode(f, "POP_JUMP_IF_FALSE") 13361 13362 def test_extremum_non_specialization_dstararg(self): 13363 codestr = """ 13364 def f() -> None: 13365 k = { 13366 "default": 5 13367 } 13368 min(3, 4, **k) 13369 """ 13370 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13371 f = mod["f"] 13372 self.assertNotInBytecode(f, "COMPARE_OP") 13373 self.assertNotInBytecode(f, "POP_JUMP_IF_FALSE") 13374 13375 def test_try_return_finally(self): 13376 codestr = """ 13377 from typing import List 13378 13379 def f1(x: List): 13380 try: 13381 return 13382 finally: 13383 x.append("hi") 13384 """ 13385 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13386 f1 = mod["f1"] 13387 l = [] 13388 f1(l) 13389 self.assertEqual(l, ["hi"]) 13390 13391 def test_cbool(self): 13392 for b in ("True", "False"): 13393 codestr = f""" 13394 from __static__ import cbool 13395 13396 def f() -> int: 13397 x: cbool = {b} 13398 if x: 13399 return 1 13400 else: 13401 return 2 13402 """ 13403 with self.subTest(b=b): 13404 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13405 f = mod["f"] 13406 self.assertInBytecode(f, "PRIMITIVE_LOAD_CONST") 13407 self.assertInBytecode( 13408 f, "STORE_LOCAL", (0, ("__static__", "cbool")) 13409 ) 13410 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 13411 self.assertEqual(f(), 1 if b == "True" else 2) 13412 13413 def test_cbool_field(self): 13414 codestr = """ 13415 from __static__ import cbool 13416 13417 class C: 13418 
def __init__(self, x: cbool) -> None: 13419 self.x: cbool = x 13420 13421 def f(c: C): 13422 if c.x: 13423 return True 13424 return False 13425 """ 13426 with self.in_module(codestr) as mod: 13427 f, C = mod["f"], mod["C"] 13428 self.assertInBytecode(f, "LOAD_FIELD", (mod["__name__"], "C", "x")) 13429 self.assertInBytecode(f, "POP_JUMP_IF_ZERO") 13430 self.assertIs(C(True).x, True) 13431 self.assertIs(C(False).x, False) 13432 self.assertIs(f(C(True)), True) 13433 self.assertIs(f(C(False)), False) 13434 13435 def test_cbool_cast(self): 13436 codestr = """ 13437 from __static__ import cbool 13438 13439 def f(y: bool) -> int: 13440 x: cbool = y 13441 if x: 13442 return 1 13443 else: 13444 return 2 13445 """ 13446 with self.assertRaisesRegex(TypedSyntaxError, type_mismatch("bool", "cbool")): 13447 self.compile(codestr, StaticCodeGenerator, modname="foo") 13448 13449 def test_primitive_compare_returns_cbool(self): 13450 codestr = """ 13451 from __static__ import cbool, int64 13452 13453 def f(x: int64, y: int64) -> cbool: 13454 return x == y 13455 """ 13456 with self.in_module(codestr) as mod: 13457 f = mod["f"] 13458 self.assertIs(f(1, 1), True) 13459 self.assertIs(f(1, 2), False) 13460 13461 def test_no_cbool_math(self): 13462 codestr = """ 13463 from __static__ import cbool 13464 13465 def f(x: cbool, y: cbool) -> cbool: 13466 return x + y 13467 """ 13468 with self.assertRaisesRegex( 13469 TypedSyntaxError, "cbool is not a valid operand type for add" 13470 ): 13471 self.compile(codestr) 13472 13473 def test_chkdict_del(self): 13474 codestr = """ 13475 def f(): 13476 x = {} 13477 x[1] = "a" 13478 x[2] = "b" 13479 del x[1] 13480 return x 13481 """ 13482 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13483 f = mod["f"] 13484 ret = f() 13485 self.assertNotIn(1, ret) 13486 self.assertIn(2, ret) 13487 13488 def test_final_constant_folding_int(self): 13489 codestr = """ 13490 from typing import Final 13491 13492 X: Final[int] = 1337 13493 13494 def 
plus_1337(i: int) -> int: 13495 return i + X 13496 """ 13497 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13498 plus_1337 = mod["plus_1337"] 13499 self.assertInBytecode(plus_1337, "LOAD_CONST", 1337) 13500 self.assertNotInBytecode(plus_1337, "LOAD_GLOBAL") 13501 self.assertEqual(plus_1337(3), 1340) 13502 13503 def test_final_constant_folding_bool(self): 13504 codestr = """ 13505 from typing import Final 13506 13507 X: Final[bool] = True 13508 13509 def f() -> bool: 13510 return not X 13511 """ 13512 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13513 f = mod["f"] 13514 self.assertInBytecode(f, "LOAD_CONST", True) 13515 self.assertNotInBytecode(f, "LOAD_GLOBAL") 13516 self.assertFalse(f()) 13517 13518 def test_final_constant_folding_str(self): 13519 codestr = """ 13520 from typing import Final 13521 13522 X: Final[str] = "omg" 13523 13524 def f() -> str: 13525 return X[1] 13526 """ 13527 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13528 f = mod["f"] 13529 self.assertInBytecode(f, "LOAD_CONST", "omg") 13530 self.assertNotInBytecode(f, "LOAD_GLOBAL") 13531 self.assertEqual(f(), "m") 13532 13533 def test_final_constant_folding_disabled_on_nonfinals(self): 13534 codestr = """ 13535 from typing import Final 13536 13537 X: str = "omg" 13538 13539 def f() -> str: 13540 return X[1] 13541 """ 13542 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13543 f = mod["f"] 13544 self.assertNotInBytecode(f, "LOAD_CONST", "omg") 13545 self.assertInBytecode(f, "LOAD_GLOBAL", "X") 13546 self.assertEqual(f(), "m") 13547 13548 def test_final_constant_folding_disabled_on_nonconstant_finals(self): 13549 codestr = """ 13550 from typing import Final 13551 13552 def p() -> str: 13553 return "omg" 13554 13555 X: Final[str] = p() 13556 13557 def f() -> str: 13558 return X[1] 13559 """ 13560 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13561 f = mod["f"] 13562 self.assertNotInBytecode(f, 
"LOAD_CONST", "omg") 13563 self.assertInBytecode(f, "LOAD_GLOBAL", "X") 13564 self.assertEqual(f(), "m") 13565 13566 def test_final_constant_folding_shadowing(self): 13567 codestr = """ 13568 from typing import Final 13569 13570 X: Final[str] = "omg" 13571 13572 def f() -> str: 13573 X = "lol" 13574 return X[1] 13575 """ 13576 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13577 f = mod["f"] 13578 self.assertInBytecode(f, "LOAD_CONST", "lol") 13579 self.assertNotInBytecode(f, "LOAD_GLOBAL", "omg") 13580 self.assertEqual(f(), "o") 13581 13582 def test_final_constant_folding_in_module_scope(self): 13583 codestr = """ 13584 from typing import Final 13585 13586 X: Final[int] = 21 13587 y = X + 3 13588 """ 13589 c = self.compile(codestr, generator=StaticCodeGenerator, modname="foo.py") 13590 self.assertNotInBytecode(c, "LOAD_NAME", "X") 13591 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13592 self.assertEqual(mod["y"], 24) 13593 13594 def test_final_constant_in_module_scope(self): 13595 codestr = """ 13596 from typing import Final 13597 13598 X: Final[int] = 21 13599 """ 13600 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13601 self.assertEqual(mod["__final_constants__"], ("X",)) 13602 13603 def test_final_nonconstant_in_module_scope(self): 13604 codestr = """ 13605 from typing import Final 13606 13607 def p() -> str: 13608 return "omg" 13609 13610 X: Final[str] = p() 13611 """ 13612 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13613 self.assertEqual(mod["__final_constants__"], ()) 13614 13615 def test_double_load_const(self): 13616 codestr = """ 13617 from __static__ import double 13618 13619 def t(): 13620 pi: double = 3.14159 13621 """ 13622 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13623 t = mod["t"] 13624 self.assertInBytecode(t, "PRIMITIVE_LOAD_CONST", (3.14159, TYPED_DOUBLE)) 13625 t() 13626 self.assert_jitted(t) 13627 13628 def test_double_box(self): 
13629 codestr = """ 13630 from __static__ import double, box 13631 13632 def t() -> float: 13633 pi: double = 3.14159 13634 return box(pi) 13635 """ 13636 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13637 t = mod["t"] 13638 self.assertInBytecode(t, "PRIMITIVE_LOAD_CONST", (3.14159, TYPED_DOUBLE)) 13639 self.assertNotInBytecode(t, "CAST") 13640 self.assertEqual(t(), 3.14159) 13641 self.assert_jitted(t) 13642 13643 def test_none_not(self): 13644 codestr = """ 13645 def t() -> bool: 13646 x = None 13647 if not x: 13648 return True 13649 else: 13650 return False 13651 """ 13652 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13653 t = mod["t"] 13654 self.assertInBytecode(t, "POP_JUMP_IF_TRUE") 13655 self.assertTrue(t()) 13656 13657 def test_unbox_final_inlined(self): 13658 for func in ["unbox", "int64"]: 13659 with self.subTest(func=func): 13660 codestr = f""" 13661 from typing import Final 13662 from __static__ import int64, unbox 13663 13664 MY_FINAL: Final[int] = 111 13665 13666 def t() -> bool: 13667 i: int64 = 64 13668 if i < {func}(MY_FINAL): 13669 return True 13670 else: 13671 return False 13672 """ 13673 with self.in_module(codestr) as mod: 13674 t = mod["t"] 13675 self.assertInBytecode(t, "PRIMITIVE_LOAD_CONST", (111, TYPED_INT64)) 13676 self.assertEqual(t(), True) 13677 13678 def test_augassign_inexact(self): 13679 codestr = """ 13680 def something(): 13681 return 3 13682 13683 def t(): 13684 a: int = something() 13685 13686 b = 0 13687 b += a 13688 return b 13689 """ 13690 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13691 t = mod["t"] 13692 self.assertInBytecode(t, "INPLACE_ADD") 13693 self.assertEqual(t(), 3) 13694 13695 def test_qualname(self): 13696 codestr = """ 13697 def f(): 13698 pass 13699 13700 13701 class C: 13702 def x(self): 13703 pass 13704 13705 @staticmethod 13706 def sm(): 13707 pass 13708 13709 @classmethod 13710 def cm(): 13711 pass 13712 13713 def f(self): 13714 class G: 
13715 def y(self): 13716 pass 13717 return G.y 13718 """ 13719 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13720 f = mod["f"] 13721 C = mod["C"] 13722 self.assertEqual(cinder._get_qualname(f.__code__), "f") 13723 13724 self.assertEqual(cinder._get_qualname(C.x.__code__), "C.x") 13725 self.assertEqual(cinder._get_qualname(C.sm.__code__), "C.sm") 13726 self.assertEqual(cinder._get_qualname(C.cm.__code__), "C.cm") 13727 13728 self.assertEqual(cinder._get_qualname(C().f().__code__), "C.f.<locals>.G.y") 13729 13730 def test_refine_optional_name(self): 13731 codestr = """ 13732 from typing import Optional 13733 13734 def f(s: Optional[str]) -> bytes: 13735 return s.encode("utf-8") if s else b"" 13736 """ 13737 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13738 f = mod["f"] 13739 self.assertEqual(f("A"), b"A") 13740 self.assertEqual(f(None), b"") 13741 13742 def test_donotcompile_fn(self): 13743 codestr = """ 13744 from __static__ import _donotcompile 13745 13746 def a() -> int: 13747 return 1 13748 13749 @_donotcompile 13750 def fn() -> None: 13751 a() + 2 13752 13753 def fn2() -> None: 13754 a() + 2 13755 """ 13756 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13757 fn = mod["fn"] 13758 self.assertInBytecode(fn, "CALL_FUNCTION") 13759 self.assertNotInBytecode(fn, "INVOKE_FUNCTION") 13760 self.assertFalse(fn.__code__.co_flags & CO_STATICALLY_COMPILED) 13761 self.assertEqual(fn(), None) 13762 13763 fn2 = mod["fn2"] 13764 self.assertNotInBytecode(fn2, "CALL_FUNCTION") 13765 self.assertInBytecode(fn2, "INVOKE_FUNCTION") 13766 self.assertTrue(fn2.__code__.co_flags & CO_STATICALLY_COMPILED) 13767 self.assertEqual(fn2(), None) 13768 13769 def test_donotcompile_method(self): 13770 codestr = """ 13771 from __static__ import _donotcompile 13772 13773 def a() -> int: 13774 return 1 13775 13776 class C: 13777 @_donotcompile 13778 def fn() -> None: 13779 a() + 2 13780 13781 def fn2() -> None: 13782 a() + 2 13783 13784 
c = C() 13785 """ 13786 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13787 C = mod["C"] 13788 13789 fn2 = C.fn2 13790 self.assertNotInBytecode(fn2, "CALL_FUNCTION") 13791 self.assertInBytecode(fn2, "INVOKE_FUNCTION") 13792 self.assertTrue(fn2.__code__.co_flags & CO_STATICALLY_COMPILED) 13793 self.assertEqual(fn2(), None) 13794 13795 def test_donotcompile_class(self): 13796 codestr = """ 13797 from __static__ import _donotcompile 13798 13799 def a() -> int: 13800 return 1 13801 13802 @_donotcompile 13803 class C: 13804 def fn() -> None: 13805 a() + 2 13806 13807 @_donotcompile 13808 class D: 13809 a() 13810 13811 c = C() 13812 """ 13813 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13814 C = mod["C"] 13815 fn = C.fn 13816 self.assertInBytecode(fn, "CALL_FUNCTION") 13817 self.assertNotInBytecode(fn, "INVOKE_FUNCTION") 13818 self.assertFalse(fn.__code__.co_flags & CO_STATICALLY_COMPILED) 13819 self.assertEqual(fn(), None) 13820 13821 D = mod["D"] 13822 13823 def test_donotcompile_lambda(self): 13824 codestr = """ 13825 from __static__ import _donotcompile 13826 13827 def a() -> int: 13828 return 1 13829 13830 class C: 13831 @_donotcompile 13832 def fn() -> None: 13833 z = lambda: a() + 2 13834 z() 13835 13836 def fn2() -> None: 13837 z = lambda: a() + 2 13838 z() 13839 13840 c = C() 13841 """ 13842 with self.in_module(codestr, code_gen=StaticCodeGenerator) as mod: 13843 C = mod["C"] 13844 fn = C.fn 13845 lambda_code = self.find_code(fn.__code__) 13846 self.assertNotInBytecode(lambda_code, "INVOKE_FUNCTION") 13847 self.assertFalse(lambda_code.co_flags & CO_STATICALLY_COMPILED) 13848 self.assertEqual(fn(), None) 13849 13850 fn2 = C.fn2 13851 lambda_code2 = self.find_code(fn2.__code__) 13852 self.assertInBytecode(lambda_code2, "INVOKE_FUNCTION") 13853 self.assertTrue(lambda_code2.co_flags & CO_STATICALLY_COMPILED) 13854 self.assertEqual(fn2(), None) 13855 13856 def test_double_binop(self): 13857 tests = [ 13858 (1.732, 2.0, "+", 
3.732), 13859 (1.732, 2.0, "-", -0.268), 13860 (1.732, 2.0, "/", 0.866), 13861 (1.732, 2.0, "*", 3.464), 13862 (1.732, 2, "+", 3.732), 13863 ] 13864 13865 if cinderjit is not None: 13866 # test for division by zero 13867 tests.append((1.732, 0.0, "/", float("inf"))) 13868 13869 for x, y, op, res in tests: 13870 codestr = f""" 13871 from __static__ import double, box 13872 def testfunc(tst): 13873 x: double = {x} 13874 y: double = {y} 13875 13876 z: double = x {op} y 13877 return box(z) 13878 """ 13879 with self.subTest(type=type, x=x, y=y, op=op, res=res): 13880 with self.in_module(codestr) as mod: 13881 f = mod["testfunc"] 13882 self.assertEqual(f(False), res, f"{type} {x} {op} {y} {res}") 13883 13884 def test_primitive_stack_spill(self): 13885 # Create enough locals that some must get spilled to stack, to test 13886 # shuffling stack-spilled values across basic block transitions, and 13887 # field reads/writes with stack-spilled values. These can create 13888 # mem->mem moves that otherwise wouldn't exist, and trigger issues 13889 # like push/pop not supporting 8 or 32 bits on x64. 
13890 varnames = string.ascii_lowercase[:20] 13891 sizes = ["uint8", "int16", "int32", "int64"] 13892 for size in sizes: 13893 indent = " " * 20 13894 attrs = f"\n{indent}".join(f"{var}: {size}" for var in varnames) 13895 inits = f"\n{indent}".join( 13896 f"{var}: {size} = {val}" for val, var in enumerate(varnames) 13897 ) 13898 assigns = f"\n{indent}".join(f"val.{var} = {var}" for var in varnames) 13899 reads = f"\n{indent}".join(f"{var} = val.{var}" for var in varnames) 13900 indent = " " * 24 13901 incrs = f"\n{indent}".join(f"{var} += 1" for var in varnames) 13902 codestr = f""" 13903 from __static__ import {size} 13904 13905 class C: 13906 {attrs} 13907 13908 def f(val: C, flag: {size}) -> {size}: 13909 {inits} 13910 if flag: 13911 {incrs} 13912 {assigns} 13913 {reads} 13914 return {' + '.join(varnames)} 13915 """ 13916 with self.subTest(size=size): 13917 with self.in_module(codestr) as mod: 13918 f, C = mod["f"], mod["C"] 13919 c = C() 13920 self.assertEqual(f(c, 0), sum(range(len(varnames)))) 13921 for val, var in enumerate(varnames): 13922 self.assertEqual(getattr(c, var), val) 13923 13924 c = C() 13925 self.assertEqual(f(c, 1), sum(range(len(varnames) + 1))) 13926 for val, var in enumerate(varnames): 13927 self.assertEqual(getattr(c, var), val + 1) 13928 13929 def test_class_static_tpflag(self): 13930 codestr = """ 13931 class A: 13932 pass 13933 """ 13934 with self.in_module(codestr) as mod: 13935 A = mod["A"] 13936 self.assertTrue(is_type_static(A)) 13937 13938 class B: 13939 pass 13940 13941 self.assertFalse(is_type_static(B)) 13942 13943 13944if __name__ == "__main__": 13945 unittest.main()