this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2import ast
3import dis
4import math
5import sys
6import unittest
7from compiler.consts import (
8 CO_NEWLOCALS,
9 CO_NOFREE,
10 CO_OPTIMIZED,
11)
12from compiler.optimizer import AstOptimizer
13from compiler.pyassem import PyFlowGraph
14from compiler.pycodegen import CodeGenerator, Python37CodeGenerator
15from compiler.unparse import to_expr
16from unittest import skipIf
17
18from .common import CompilerTest
19
20
21@unittest.skipIf(sys.version_info < (3, 7), "AST optimizer introduced in 3.7")
22class AstOptimizerTests(CompilerTest):
23 class _Comparer:
24 def __init__(self, code, test):
25 self.code = code
26 self.test = test
27 self.opt = self.test.to_graph(code)
28 self.notopt = self.test.to_graph_no_opt(code)
29
30 def assert_both(self, *args):
31 self.test.assertInGraph(
32 self.notopt, *args
33 ) # should be present w/o peephole
34 self.test.assertInGraph(self.opt, *args) # should be present w/ peephole
35
36 def assert_neither(self, *args):
37 self.test.assertNotInGraph(
38 self.notopt, *args
39 ) # should be absent w/o peephole
40 self.test.assertNotInGraph(self.opt, *args) # should be absent w/ peephole
41
42 def assert_removed(self, *args):
43 self.test.assertInGraph(
44 self.notopt, *args
45 ) # should be present w/o peephole
46 self.test.assertNotInGraph(self.opt, *args) # should be removed w/ peephole
47
48 def assert_added(self, *args):
49 self.test.assertInGraph(self.opt, *args) # should be added w/ peephole
50 self.test.assertNotInGraph(
51 self.notopt, *args
52 ) # should be absent w/o peephole
53
54 def get_instructions(self, graph):
55 return [
56 instr
57 for block in graph.getBlocks()
58 for instr in block.getInstructions()
59 ]
60
61 def assert_all_removed(self, *args):
62 for instr in self.get_instructions(self.opt):
63 for arg in args:
64 self.test.assertFalse(instr.opname.startswith(arg))
65 for instr in self.get_instructions(self.notopt):
66 for arg in args:
67 if instr.opname.startswith(arg):
68 return
69 disassembly = self.test.dump_graph(self.notopt)
70 self.test.fail(
71 "no args were present: " + ", ".join(args) + "\n" + disassembly
72 )
73
74 def assert_in_opt(self, *args):
75 self.test.assertInGraph(self.opt, *args)
76
77 def assert_not_in_opt(self, *args):
78 self.test.assertNotInGraph(self.opt, *args)
79
80 def assert_instr_count(self, opcode, before, after):
81 before_instrs = [
82 instr
83 for instr in self.get_instructions(self.notopt)
84 if instr.opname == opcode
85 ]
86 self.test.assertEqual(len(before_instrs), before)
87 after_instrs = [
88 instr
89 for instr in self.get_instructions(self.opt)
90 if instr.opname == opcode
91 ]
92 self.test.assertEqual(len(after_instrs), after)
93
94 def compare_graph(self, code):
95 return AstOptimizerTests._Comparer(code, self)
96
97 def to_graph_no_opt(self, code):
98 return self.to_graph(code, ast_optimizer_enabled=False)
99
100 def test_compile_opt_enabled(self):
101 graph = self.to_graph("x = -1")
102 self.assertNotInGraph(graph, "UNARY_NEGATIVE")
103
104 graph = self.to_graph_no_opt("x = -1")
105 self.assertInGraph(graph, "UNARY_NEGATIVE")
106
107 def test_opt_debug(self):
108 graph = self.to_graph("if not __debug__:\n x = 42")
109 self.assertNotInGraph(graph, "STORE_NAME")
110
111 graph = self.to_graph_no_opt("if not __debug__:\n x = 42")
112 self.assertInGraph(graph, "STORE_NAME")
113
114 def test_opt_debug_del(self):
115 code = "def f(): del __debug__"
116 outer_graph = self.to_graph(code)
117 for outer_instr in self.graph_to_instrs(outer_graph):
118 if outer_instr.opname == "LOAD_CONST" and isinstance(
119 outer_instr.oparg, CodeGenerator
120 ):
121 graph = outer_instr.oparg.graph
122 self.assertInGraph(graph, "LOAD_CONST", True)
123 self.assertNotInGraph(graph, "DELETE_FAST", "__debug__")
124
125 outer_graph = self.to_graph_no_opt(code)
126 for outer_instr in self.graph_to_instrs(outer_graph):
127 if outer_instr.opname == "LOAD_CONST" and isinstance(
128 outer_instr.oparg, CodeGenerator
129 ):
130 graph = outer_instr.oparg.graph
131 self.assertNotInGraph(graph, "LOAD_CONST", True)
132 self.assertInGraph(graph, "DELETE_FAST", "__debug__")
133
134 def test_const_fold(self):
135 code = self.compile("x = 0.0\ny=-0.0")
136 self.assertEqual(code.co_consts, (0.0, -0.0, None))
137 self.assertEqual(math.copysign(1, code.co_consts[0]), 1)
138 self.assertEqual(math.copysign(1, code.co_consts[1]), -1)
139
140 def test_const_fold_tuple(self):
141 code = self.compile("x = (0.0, )\ny=(-0.0, )")
142 self.assertEqual(code.co_consts, ((0.0,), (-0.0,), None))
143 self.assertEqual(math.copysign(1, code.co_consts[0][0]), 1)
144 self.assertEqual(math.copysign(1, code.co_consts[1][0]), -1)
145
146 def test_ast_optimizer(self):
147 cases = [
148 ("+1", "1"),
149 ("--1", "1"),
150 ("~1", "-2"),
151 ("not 1", "False"),
152 ("not x is y", "x is not y"),
153 ("not x is not y", "x is y"),
154 ("not x in y", "x not in y"),
155 ("~1.1", "~1.1"),
156 ("+'str'", "+'str'"),
157 ("1 + 2", "3"),
158 ("1 + 3", "4"),
159 ("'abc' + 'def'", "'abcdef'"),
160 ("b'abc' + b'def'", "b'abcdef'"),
161 ("b'abc' + 'def'", "b'abc' + 'def'"),
162 ("b'abc' + --2", "b'abc' + 2"),
163 ("--2 + 'abc'", "2 + 'abc'"),
164 ("5 - 3", "2"),
165 ("6 - 3", "3"),
166 ("2 * 2", "4"),
167 ("2 * 3", "6"),
168 ("'abc' * 2", "'abcabc'"),
169 ("b'abc' * 2", "b'abcabc'"),
170 ("1 / 2", "0.5"),
171 ("6 / 2", "3.0"),
172 ("6 // 2", "3"),
173 ("5 // 2", "2"),
174 ("2 >> 1", "1"),
175 ("6 >> 1", "3"),
176 ("1 | 2", "3"),
177 ("1 | 1", "1"),
178 ("1 ^ 3", "2"),
179 ("1 ^ 1", "0"),
180 ("1 & 2", "0"),
181 ("1 & 3", "1"),
182 ("'abc' + 1", "'abc' + 1"),
183 ("1 / 0", "1 / 0"),
184 ("1 + None", "1 + None"),
185 ("True + None", "True + None"),
186 ("True + 1", "2"),
187 ("(1, 2)", "(1, 2)"),
188 ("(1, 2) * 2", "(1, 2, 1, 2)"),
189 ("(1, --2, abc)", "(1, 2, abc)"),
190 ("(1, 2)[0]", "1"),
191 ("1[0]", "1[0]"),
192 ("x[+1]", "x[1]"),
193 ("(+1)[x]", "1[x]"),
194 ("[x for x in [1,2,3]]", "[x for x in (1, 2, 3)]"),
195 ("(x for x in [1,2,3])", "(x for x in (1, 2, 3))"),
196 ("{x for x in [1,2,3]}", "{x for x in (1, 2, 3)}"),
197 ("{x for x in [--1,2,3]}", "{x for x in (1, 2, 3)}"),
198 ("{--1 for x in [1,2,3]}", "{1 for x in (1, 2, 3)}"),
199 ("x in [1,2,3]", "x in (1, 2, 3)"),
200 ("x in x in [1,2,3]", "x in x in (1, 2, 3)"),
201 ("x in [1,2,3] in x", "x in [1, 2, 3] in x"),
202 ]
203 for inp, expected in cases:
204 optimizer = AstOptimizer()
205 tree = ast.parse(inp)
206 optimized = to_expr(optimizer.visit(tree).body[0].value)
207 self.assertEqual(expected, optimized, "Input was: " + inp)
208
209 def test_ast_optimizer_for(self):
210 optimizer = AstOptimizer()
211 tree = ast.parse("for x in [1,2,3]: pass")
212 optimized = optimizer.visit(tree).body[0]
213 self.assertEqual(to_expr(optimized.iter), "(1, 2, 3)")
214
215 @skipIf(sys.version_info < (3, 8), "This optimization is only for Python 3.8+")
216 def test_fold_nonconst_list_to_tuple_in_comparisons(self):
217 optimizer = AstOptimizer()
218 tree = ast.parse("[a for a in b if a.c in [e, f]]")
219 optimized = optimizer.visit(tree)
220 self.assertEqual(
221 to_expr(optimized.body[0].value.generators[0].ifs[0].comparators[0]),
222 "(e, f)",
223 )
224
225 def test_assert_statements(self):
226 optimizer = AstOptimizer(optimize=True)
227 non_optimizer = AstOptimizer(optimize=False)
228 code = """def f(a, b): assert a == b, 'lol'"""
229 tree = ast.parse(code)
230 optimized = optimizer.visit(tree)
231 # Function body should be empty
232 self.assertListEqual(optimized.body[0].body, [])
233
234 unoptimized = non_optimizer.visit(tree)
235 # Function body should contain the assert
236 self.assertIsInstance(unoptimized.body[0].body[0], ast.Assert)
237
238 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
239 def test_folding_of_tuples_of_constants(self):
240 for line, elem in (
241 ("a = 1,2,3", (1, 2, 3)),
242 ('a = ("a","b","c")', ("a", "b", "c")),
243 ("a,b,c = 1,2,3", (1, 2, 3)),
244 ("a = (None, 1, None)", (None, 1, None)),
245 ("a = ((1, 2), 3, 4)", ((1, 2), 3, 4)),
246 ):
247 code = self.compare_graph(line)
248 code.assert_added("LOAD_CONST", elem)
249 code.assert_removed("BUILD_TUPLE")
250
251 # Long tuples should be folded too.
252 code = self.compare_graph("x=" + repr(tuple(range(10000))))
253 code.assert_removed("BUILD_TUPLE")
254 # One LOAD_CONST for the tuple, one for the None return value
255 code.assert_instr_count("LOAD_CONST", 10001, 2)
256
257 # Bug 1053819: Tuple of constants misidentified when presented with:
258 # . . . opcode_with_arg 100 unary_opcode BUILD_TUPLE 1 . . .
259 # The following would segfault upon compilation
260 def crater():
261 (
262 ~[
263 0,
264 1,
265 2,
266 3,
267 4,
268 5,
269 6,
270 7,
271 8,
272 9,
273 0,
274 1,
275 2,
276 3,
277 4,
278 5,
279 6,
280 7,
281 8,
282 9,
283 0,
284 1,
285 2,
286 3,
287 4,
288 5,
289 6,
290 7,
291 8,
292 9,
293 0,
294 1,
295 2,
296 3,
297 4,
298 5,
299 6,
300 7,
301 8,
302 9,
303 0,
304 1,
305 2,
306 3,
307 4,
308 5,
309 6,
310 7,
311 8,
312 9,
313 0,
314 1,
315 2,
316 3,
317 4,
318 5,
319 6,
320 7,
321 8,
322 9,
323 0,
324 1,
325 2,
326 3,
327 4,
328 5,
329 6,
330 7,
331 8,
332 9,
333 0,
334 1,
335 2,
336 3,
337 4,
338 5,
339 6,
340 7,
341 8,
342 9,
343 0,
344 1,
345 2,
346 3,
347 4,
348 5,
349 6,
350 7,
351 8,
352 9,
353 0,
354 1,
355 2,
356 3,
357 4,
358 5,
359 6,
360 7,
361 8,
362 9,
363 ],
364 )
365
366 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
367 def test_folding_of_lists_of_constants(self):
368 for line, elem in (
369 # in/not in constants with BUILD_LIST should be folded to a tuple:
370 ("a in [1,2,3]", (1, 2, 3)),
371 ('a not in ["a","b","c"]', ("a", "b", "c")),
372 ("a in [None, 1, None]", (None, 1, None)),
373 ("a not in [(1, 2), 3, 4]", ((1, 2), 3, 4)),
374 ):
375 code = self.compare_graph(line)
376 code.assert_added("LOAD_CONST", elem)
377 code.assert_removed("BUILD_LIST")
378
379 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
380 def test_folding_of_sets_of_constants(self):
381 for line, elem in (
382 # in/not in constants with BUILD_SET should be folded to a frozenset:
383 ("a in {1,2,3}", frozenset({1, 2, 3})),
384 ('a not in {"a","b","c"}', frozenset({"a", "c", "b"})),
385 ("a in {None, 1, None}", frozenset({1, None})),
386 ("a not in {(1, 2), 3, 4}", frozenset({(1, 2), 3, 4})),
387 ("a in {1, 2, 3, 3, 2, 1}", frozenset({1, 2, 3})),
388 ):
389 code = self.compare_graph(line)
390 code.assert_removed("BUILD_SET")
391 code.assert_added("LOAD_CONST", elem)
392
393 # Ensure that the resulting code actually works:
394 d = self.run_code(
395 """
396 def f(a):
397 return a in {1, 2, 3}
398
399 def g(a):
400 return a not in {1, 2, 3}"""
401 )
402 f, g = d["f"], d["g"]
403 self.assertTrue(f(3))
404 self.assertTrue(not f(4))
405
406 self.assertTrue(not g(3))
407 self.assertTrue(g(4))
408
409 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
410 def test_folding_of_binops_on_constants(self):
411 for line, elem in (
412 ("a = 2+3+4", 9), # chained fold
413 ('a = "@"*4', "@@@@"), # check string ops
414 ('a="abc" + "def"', "abcdef"), # check string ops
415 ("a = 3**4", 81), # binary power
416 ("a = 3*4", 12), # binary multiply
417 ("a = 13//4", 3), # binary floor divide
418 ("a = 14%4", 2), # binary modulo
419 ("a = 2+3", 5), # binary add
420 ("a = 13-4", 9), # binary subtract
421 # ('a = (12,13)[1]', 13), # binary subscr
422 ("a = 13 << 2", 52), # binary lshift
423 ("a = 13 >> 2", 3), # binary rshift
424 ("a = 13 & 7", 5), # binary and
425 ("a = 13 ^ 7", 10), # binary xor
426 ("a = 13 | 7", 15), # binary or
427 ("a = 2 ** -14", 6.103515625e-05), # binary power neg rhs
428 ):
429 code = self.compare_graph(line)
430 code.assert_added("LOAD_CONST", elem)
431 code.assert_all_removed("BINARY_")
432
433 # Verify that unfoldables are skipped
434 code = self.compare_graph('a=2+"b"')
435 code.assert_both("LOAD_CONST", 2)
436 code.assert_both("LOAD_CONST", "b")
437
438 # Verify that large sequences do not result from folding
439 code = self.compare_graph('a="x"*10000')
440 code.assert_both("LOAD_CONST", 10000)
441 consts = code.opt.getConsts()
442 self.assertNotIn("x" * 10000, consts)
443 code = self.compare_graph("a=1<<1000")
444 code.assert_both("LOAD_CONST", 1000)
445 self.assertNotIn(1 << 1000, consts)
446 code = self.compare_graph("a=2**1000")
447 code.assert_both("LOAD_CONST", 1000)
448 self.assertNotIn(2 ** 1000, consts)
449
450 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
451 def test_binary_subscr_on_unicode(self):
452 # valid code get optimized
453 code = self.compare_graph('x = "foo"[0]')
454 code.assert_added("LOAD_CONST", "f")
455 code.assert_removed("BINARY_SUBSCR")
456 code = self.compare_graph('x = "\u0061\uffff"[1]')
457 code.assert_added("LOAD_CONST", "\uffff")
458 code.assert_removed("BINARY_SUBSCR")
459
460 # With PEP 393, non-BMP char get optimized
461 code = self.compare_graph('x = "\U00012345"[0]')
462 code.assert_both("LOAD_CONST", "\U00012345")
463 code.assert_removed("BINARY_SUBSCR")
464
465 # invalid code doesn't get optimized
466 # out of range
467 code = self.compare_graph('x = "fuu"[10]')
468 code.assert_both("BINARY_SUBSCR")
469
470 @unittest.skipIf(sys.version_info < (3, 7), "3.6 does this in peephole")
471 def test_folding_of_unaryops_on_constants(self):
472 for line, elem in (
473 ("x = -0.5", -0.5), # unary negative
474 ("x = -0.0", -0.0), # -0.0
475 ("x = -(1.0-1.0)", -0.0), # -0.0 after folding
476 ("x = -0", 0), # -0
477 ("x = ~-2", 1), # unary invert
478 ("x = +1", 1), # unary positive
479 ):
480 code = self.compare_graph(line)
481 # can't assert added here because -0/0 compares equal
482 code.assert_in_opt("LOAD_CONST", elem)
483 code.assert_all_removed("UNARY_")
484
485 # Verify that unfoldables are skipped
486 for line, elem, opname in (
487 ('-"abc"', "abc", "UNARY_NEGATIVE"),
488 ('~"abc"', "abc", "UNARY_INVERT"),
489 ):
490 code = self.compare_graph(line)
491 code.assert_both("LOAD_CONST", elem)
492 code.assert_both(opname)