this repo has no description
at trunk 492 lines 18 kB view raw
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
import ast
import dis
import math
import sys
import unittest
from compiler.consts import (
    CO_NEWLOCALS,
    CO_NOFREE,
    CO_OPTIMIZED,
)
from compiler.optimizer import AstOptimizer
from compiler.pyassem import PyFlowGraph
from compiler.pycodegen import CodeGenerator, Python37CodeGenerator
from compiler.unparse import to_expr
from unittest import skipIf

from .common import CompilerTest


@unittest.skipIf(sys.version_info < (3, 7), "AST optimizer introduced in 3.7")
class AstOptimizerTests(CompilerTest):
    """Tests comparing code compiled with and without the AST optimizer."""

    class _Comparer:
        """Pair of flow graphs for one snippet: AST optimizer on and off.

        ``opt`` comes from ``test.to_graph(code)`` and ``notopt`` from
        ``test.to_graph_no_opt(code)``.  Every ``assert_*`` helper forwards
        its ``*args`` (an opcode name, optionally followed by an argument
        value) to the owning test case's ``assertInGraph`` /
        ``assertNotInGraph``.
        """

        def __init__(self, code, test):
            self.code = code
            self.test = test
            self.opt = self.test.to_graph(code)
            self.notopt = self.test.to_graph_no_opt(code)

        def assert_both(self, *args):
            """Assert the instruction is present in both graphs."""
            self.test.assertInGraph(self.notopt, *args)  # present w/o optimizer
            self.test.assertInGraph(self.opt, *args)  # present w/ optimizer

        def assert_neither(self, *args):
            """Assert the instruction is absent from both graphs."""
            self.test.assertNotInGraph(self.notopt, *args)  # absent w/o optimizer
            self.test.assertNotInGraph(self.opt, *args)  # absent w/ optimizer

        def assert_removed(self, *args):
            """Assert the instruction is only in the unoptimized graph."""
            self.test.assertInGraph(self.notopt, *args)  # present w/o optimizer
            self.test.assertNotInGraph(self.opt, *args)  # removed w/ optimizer

        def assert_added(self, *args):
            """Assert the instruction is only in the optimized graph."""
            self.test.assertInGraph(self.opt, *args)  # added w/ optimizer
            self.test.assertNotInGraph(self.notopt, *args)  # absent w/o optimizer

        def get_instructions(self, graph):
            """Return all instructions of ``graph``, block by block, as a list."""
            return [
                instr
                for block in graph.getBlocks()
                for instr in block.getInstructions()
            ]

        def assert_all_removed(self, *args):
            """Assert that opcodes starting with any prefix in ``args`` were removed.

            No instruction of the optimized graph may have an opname starting
            with one of the prefixes, and at least one instruction of the
            unoptimized graph must (otherwise nothing was actually removed and
            the test fails with a dump of the unoptimized graph).
            """
            # str.startswith accepts a tuple of prefixes; ``args`` already is one.
            for instr in self.get_instructions(self.opt):
                self.test.assertFalse(
                    instr.opname.startswith(args),
                    instr.opname + " is still present in the optimized graph",
                )
            if any(
                instr.opname.startswith(args)
                for instr in self.get_instructions(self.notopt)
            ):
                return
            disassembly = self.test.dump_graph(self.notopt)
            self.test.fail(
                "no args were present: " + ", ".join(args) + "\n" + disassembly
            )

        def assert_in_opt(self, *args):
            """Assert the instruction is present in the optimized graph."""
            self.test.assertInGraph(self.opt, *args)

        def assert_not_in_opt(self, *args):
            """Assert the instruction is absent from the optimized graph."""
            self.test.assertNotInGraph(self.opt, *args)

        def assert_instr_count(self, opcode, before, after):
            """Assert how many ``opcode`` instructions each graph contains.

            ``before`` is the expected count in the unoptimized graph and
            ``after`` the expected count in the optimized graph.
            """

            def count(graph):
                return sum(
                    1
                    for instr in self.get_instructions(graph)
                    if instr.opname == opcode
                )

            self.test.assertEqual(count(self.notopt), before)
            self.test.assertEqual(count(self.opt), after)

    def compare_graph(self, code):
        """Compile ``code`` both ways and return a ``_Comparer`` for the pair."""
        return AstOptimizerTests._Comparer(code, self)

    def to_graph_no_opt(self, code):
        """Build the graph for ``code`` with ``ast_optimizer_enabled=False``."""
        return self.to_graph(code, ast_optimizer_enabled=False)

    def _nested_code_graphs(self, outer_graph):
        """Return the graphs of code generators loaded as constants.

        Scans ``outer_graph`` for ``LOAD_CONST`` instructions whose oparg is a
        ``CodeGenerator`` and returns each one's ``graph`` attribute.
        """
        return [
            instr.oparg.graph
            for instr in self.graph_to_instrs(outer_graph)
            if instr.opname == "LOAD_CONST"
            and isinstance(instr.oparg, CodeGenerator)
        ]

    def test_compile_opt_enabled(self):
        # "-1" is folded to a constant only when the optimizer is enabled.
        graph = self.to_graph("x = -1")
        self.assertNotInGraph(graph, "UNARY_NEGATIVE")

        graph = self.to_graph_no_opt("x = -1")
        self.assertInGraph(graph, "UNARY_NEGATIVE")

    def test_opt_debug(self):
        # The store under "if not __debug__" disappears only when optimizing.
        graph = self.to_graph("if not __debug__:\n x = 42")
        self.assertNotInGraph(graph, "STORE_NAME")

        graph = self.to_graph_no_opt("if not __debug__:\n x = 42")
        self.assertInGraph(graph, "STORE_NAME")

    def test_opt_debug_del(self):
        code = "def f(): del __debug__"

        graphs = self._nested_code_graphs(self.to_graph(code))
        # Without this check the loop below could run zero times and the
        # test would pass without verifying anything.
        self.assertTrue(graphs, "no nested code object found in optimized graph")
        for graph in graphs:
            self.assertInGraph(graph, "LOAD_CONST", True)
            self.assertNotInGraph(graph, "DELETE_FAST", "__debug__")

        graphs = self._nested_code_graphs(self.to_graph_no_opt(code))
        self.assertTrue(graphs, "no nested code object found in unoptimized graph")
        for graph in graphs:
            self.assertNotInGraph(graph, "LOAD_CONST", True)
            self.assertInGraph(graph, "DELETE_FAST", "__debug__")

    def test_const_fold(self):
        # 0.0 and -0.0 compare equal, so also check the sign of each constant.
        code = self.compile("x = 0.0\ny=-0.0")
        self.assertEqual(code.co_consts, (0.0, -0.0, None))
        self.assertEqual(math.copysign(1, code.co_consts[0]), 1)
        self.assertEqual(math.copysign(1, code.co_consts[1]), -1)

    def test_const_fold_tuple(self):
        # Same signed-zero check, with the zeros wrapped in constant tuples.
        code = self.compile("x = (0.0, )\ny=(-0.0, )")
        self.assertEqual(code.co_consts, ((0.0,), (-0.0,), None))
        self.assertEqual(math.copysign(1, code.co_consts[0][0]), 1)
        self.assertEqual(math.copysign(1, code.co_consts[1][0]), -1)

    def test_ast_optimizer(self):
        # (input expression, expected unparsed expression after optimization)
        cases = [
            ("+1", "1"),
            ("--1", "1"),
            ("~1", "-2"),
            ("not 1", "False"),
            ("not x is y", "x is not y"),
            ("not x is not y", "x is y"),
            ("not x in y", "x not in y"),
            ("~1.1", "~1.1"),
            ("+'str'", "+'str'"),
            ("1 + 2", "3"),
            ("1 + 3", "4"),
            ("'abc' + 'def'", "'abcdef'"),
            ("b'abc' + b'def'", "b'abcdef'"),
            ("b'abc' + 'def'", "b'abc' + 'def'"),
            ("b'abc' + --2", "b'abc' + 2"),
            ("--2 + 'abc'", "2 + 'abc'"),
            ("5 - 3", "2"),
            ("6 - 3", "3"),
            ("2 * 2", "4"),
            ("2 * 3", "6"),
            ("'abc' * 2", "'abcabc'"),
            ("b'abc' * 2", "b'abcabc'"),
            ("1 / 2", "0.5"),
            ("6 / 2", "3.0"),
            ("6 // 2", "3"),
            ("5 // 2", "2"),
            ("2 >> 1", "1"),
            ("6 >> 1", "3"),
            ("1 | 2", "3"),
            ("1 | 1", "1"),
            ("1 ^ 3", "2"),
            ("1 ^ 1", "0"),
            ("1 & 2", "0"),
            ("1 & 3", "1"),
            ("'abc' + 1", "'abc' + 1"),
            ("1 / 0", "1 / 0"),
            ("1 + None", "1 + None"),
            ("True + None", "True + None"),
            ("True + 1", "2"),
            ("(1, 2)", "(1, 2)"),
            ("(1, 2) * 2", "(1, 2, 1, 2)"),
            ("(1, --2, abc)", "(1, 2, abc)"),
            ("(1, 2)[0]", "1"),
            ("1[0]", "1[0]"),
            ("x[+1]", "x[1]"),
            ("(+1)[x]", "1[x]"),
            ("[x for x in [1,2,3]]", "[x for x in (1, 2, 3)]"),
            ("(x for x in [1,2,3])", "(x for x in (1, 2, 3))"),
            ("{x for x in [1,2,3]}", "{x for x in (1, 2, 3)}"),
            ("{x for x in [--1,2,3]}", "{x for x in (1, 2, 3)}"),
            ("{--1 for x in [1,2,3]}", "{1 for x in (1, 2, 3)}"),
            ("x in [1,2,3]", "x in (1, 2, 3)"),
            ("x in x in [1,2,3]", "x in x in (1, 2, 3)"),
            ("x in [1,2,3] in x", "x in [1, 2, 3] in x"),
        ]
        for inp, expected in cases:
            optimizer = AstOptimizer()
            tree = ast.parse(inp)
            optimized = to_expr(optimizer.visit(tree).body[0].value)
            self.assertEqual(expected, optimized, "Input was: " + inp)

    def test_ast_optimizer_for(self):
        # A constant list used as a for-loop iterable becomes a tuple.
        optimizer = AstOptimizer()
        tree = ast.parse("for x in [1,2,3]: pass")
        optimized = optimizer.visit(tree).body[0]
        self.assertEqual(to_expr(optimized.iter), "(1, 2, 3)")

    @skipIf(sys.version_info < (3, 8), "This optimization is only for Python 3.8+")
    def test_fold_nonconst_list_to_tuple_in_comparisons(self):
        optimizer = AstOptimizer()
        tree = ast.parse("[a for a in b if a.c in [e, f]]")
        optimized = optimizer.visit(tree)
        self.assertEqual(
            to_expr(optimized.body[0].value.generators[0].ifs[0].comparators[0]),
            "(e, f)",
        )

    def test_assert_statements(self):
        optimizer = AstOptimizer(optimize=True)
        non_optimizer = AstOptimizer(optimize=False)
        code = """def f(a, b): assert a == b, 'lol'"""
        tree = ast.parse(code)
        optimized = optimizer.visit(tree)
        # Function body should be empty
        self.assertListEqual(optimized.body[0].body, [])

        unoptimized = non_optimizer.visit(tree)
        # Function body should contain the assert
        self.assertIsInstance(unoptimized.body[0].body[0], ast.Assert)

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_folding_of_tuples_of_constants(self):
        for line, elem in (
            ("a = 1,2,3", (1, 2, 3)),
            ('a = ("a","b","c")', ("a", "b", "c")),
            ("a,b,c = 1,2,3", (1, 2, 3)),
            ("a = (None, 1, None)", (None, 1, None)),
            ("a = ((1, 2), 3, 4)", ((1, 2), 3, 4)),
        ):
            code = self.compare_graph(line)
            code.assert_added("LOAD_CONST", elem)
            code.assert_removed("BUILD_TUPLE")

        # Long tuples should be folded too.
        code = self.compare_graph("x=" + repr(tuple(range(10000))))
        code.assert_removed("BUILD_TUPLE")
        # One LOAD_CONST for the tuple, one for the None return value
        code.assert_instr_count("LOAD_CONST", 10001, 2)

        # Bug 1053819: Tuple of constants misidentified when presented with:
        # . . . opcode_with_arg 100 unary_opcode BUILD_TUPLE 1 . . .
        # The following would segfault upon compilation
        def crater():
            (
                ~[
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                ],
            )

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_folding_of_lists_of_constants(self):
        for line, elem in (
            # in/not in constants with BUILD_LIST should be folded to a tuple:
            ("a in [1,2,3]", (1, 2, 3)),
            ('a not in ["a","b","c"]', ("a", "b", "c")),
            ("a in [None, 1, None]", (None, 1, None)),
            ("a not in [(1, 2), 3, 4]", ((1, 2), 3, 4)),
        ):
            code = self.compare_graph(line)
            code.assert_added("LOAD_CONST", elem)
            code.assert_removed("BUILD_LIST")

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_folding_of_sets_of_constants(self):
        for line, elem in (
            # in/not in constants with BUILD_SET should be folded to a frozenset:
            ("a in {1,2,3}", frozenset({1, 2, 3})),
            ('a not in {"a","b","c"}', frozenset({"a", "c", "b"})),
            ("a in {None, 1, None}", frozenset({1, None})),
            ("a not in {(1, 2), 3, 4}", frozenset({(1, 2), 3, 4})),
            ("a in {1, 2, 3, 3, 2, 1}", frozenset({1, 2, 3})),
        ):
            code = self.compare_graph(line)
            code.assert_removed("BUILD_SET")
            code.assert_added("LOAD_CONST", elem)

        # Ensure that the resulting code actually works.
        # The snippet is kept flush-left so it is valid source whether or not
        # run_code dedents its argument.
        d = self.run_code(
            """
def f(a):
    return a in {1, 2, 3}

def g(a):
    return a not in {1, 2, 3}"""
        )
        f, g = d["f"], d["g"]
        self.assertTrue(f(3))
        self.assertFalse(f(4))

        self.assertFalse(g(3))
        self.assertTrue(g(4))

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_folding_of_binops_on_constants(self):
        for line, elem in (
            ("a = 2+3+4", 9),  # chained fold
            ('a = "@"*4', "@@@@"),  # check string ops
            ('a="abc" + "def"', "abcdef"),  # check string ops
            ("a = 3**4", 81),  # binary power
            ("a = 3*4", 12),  # binary multiply
            ("a = 13//4", 3),  # binary floor divide
            ("a = 14%4", 2),  # binary modulo
            ("a = 2+3", 5),  # binary add
            ("a = 13-4", 9),  # binary subtract
            # ('a = (12,13)[1]', 13), # binary subscr
            ("a = 13 << 2", 52),  # binary lshift
            ("a = 13 >> 2", 3),  # binary rshift
            ("a = 13 & 7", 5),  # binary and
            ("a = 13 ^ 7", 10),  # binary xor
            ("a = 13 | 7", 15),  # binary or
            ("a = 2 ** -14", 6.103515625e-05),  # binary power neg rhs
        ):
            code = self.compare_graph(line)
            code.assert_added("LOAD_CONST", elem)
            code.assert_all_removed("BINARY_")

        # Verify that unfoldables are skipped
        code = self.compare_graph('a=2+"b"')
        code.assert_both("LOAD_CONST", 2)
        code.assert_both("LOAD_CONST", "b")

        # Verify that large sequences do not result from folding.
        # The constants are re-read from each freshly built graph; reusing
        # the tuple from the first snippet would check the wrong graph.
        code = self.compare_graph('a="x"*10000')
        code.assert_both("LOAD_CONST", 10000)
        self.assertNotIn("x" * 10000, code.opt.getConsts())
        code = self.compare_graph("a=1<<1000")
        code.assert_both("LOAD_CONST", 1000)
        self.assertNotIn(1 << 1000, code.opt.getConsts())
        code = self.compare_graph("a=2**1000")
        code.assert_both("LOAD_CONST", 1000)
        self.assertNotIn(2 ** 1000, code.opt.getConsts())

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_binary_subscr_on_unicode(self):
        # valid code get optimized
        code = self.compare_graph('x = "foo"[0]')
        code.assert_added("LOAD_CONST", "f")
        code.assert_removed("BINARY_SUBSCR")
        code = self.compare_graph('x = "\u0061\uffff"[1]')
        code.assert_added("LOAD_CONST", "\uffff")
        code.assert_removed("BINARY_SUBSCR")

        # With PEP 393, non-BMP char get optimized
        # (the one-character string equals its own [0], hence assert_both)
        code = self.compare_graph('x = "\U00012345"[0]')
        code.assert_both("LOAD_CONST", "\U00012345")
        code.assert_removed("BINARY_SUBSCR")

        # invalid code doesn't get optimized
        # out of range
        code = self.compare_graph('x = "fuu"[10]')
        code.assert_both("BINARY_SUBSCR")

    @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
    def test_folding_of_unaryops_on_constants(self):
        for line, elem in (
            ("x = -0.5", -0.5),  # unary negative
            ("x = -0.0", -0.0),  # -0.0
            ("x = -(1.0-1.0)", -0.0),  # -0.0 after folding
            ("x = -0", 0),  # -0
            ("x = ~-2", 1),  # unary invert
            ("x = +1", 1),  # unary positive
        ):
            code = self.compare_graph(line)
            # can't assert added here because -0/0 compares equal
            code.assert_in_opt("LOAD_CONST", elem)
            code.assert_all_removed("UNARY_")

        # Verify that unfoldables are skipped
        for line, elem, opname in (
            ('-"abc"', "abc", "UNARY_NEGATIVE"),
            ('~"abc"', "abc", "UNARY_INVERT"),
        ):
            code = self.compare_graph(line)
            code.assert_both("LOAD_CONST", elem)
            code.assert_both(opname)