this repo has no description
at trunk 156 lines 4.8 kB view raw
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
import ast
import dis
import sys
import unittest
from compiler.pycodegen import CodeGenerator
from dis import opmap, opname
from unittest import TestCase

from .common import CompilerTest


class Block:
    """Expected-shape node describing a linear chain of basic blocks.

    Exposes the same ``label``/``next`` attributes that the tests read from
    the compiler's basic blocks, so an expected layout can be written as
    nested ``Block(label, Block(...))`` calls.
    """

    def __init__(self, label, next=None):
        # `next` deliberately matches the attribute name read from compiled
        # blocks, even though it shadows the builtin inside this method.
        self.label = label
        self.next = next

    def __repr__(self):
        if self.next is not None:
            return f"Block({self.label!r}, {self.next!r})"
        return f"Block({self.label!r})"


class GraphTests(CompilerTest):
    """Performs various unit tests on the flow control graph that gets produced
    to make sure that we're linking all of our basic blocks together properly."""

    def format_graph(self, graph):
        """Render the ``next`` chain starting at ``graph`` as ``Block(...)`` text."""
        if graph.next:
            return f"Block({graph.label!r}, {self.format_graph(graph.next)})"
        return f"Block({graph.label!r})"

    def assert_graph_equal(self, graph, expected):
        """Assert that the block chain of ``graph`` matches ``expected``.

        The walk starts at ``graph.ordered_blocks[0]`` and follows ``next``
        links. On mismatch, the failure message is extended with the actual
        chain as rendered by ``format_graph``.
        """
        first_block = graph.ordered_blocks[0]
        try:
            self.assert_graph_equal_worker(first_block, expected)
        except AssertionError as e:
            # str(e) instead of e.args[0]: still works if the error has no args.
            raise AssertionError(
                f"{e}\nGraph was: {self.format_graph(first_block)}"
            ) from None

    def assert_graph_equal_worker(self, compiled, expected):
        """Compare labels pairwise along both chains; both must end together."""
        self.assertEqual(compiled.label, expected.label)
        if expected.next:
            self.assertIsNotNone(compiled.next)
            self.assert_graph_equal_worker(compiled.next, expected.next)
        else:
            self.assertIsNone(compiled.next)

    def get_child_graph(self, graph, name):
        """Return the graph of the nested code object called ``name``.

        Scans every ``LOAD_CONST`` instruction in ``graph`` for a
        ``CodeGenerator`` operand whose ``name`` matches. Fails the test with
        a descriptive message if no such operand exists, rather than returning
        None and letting the caller crash later.
        """
        for block in graph.ordered_blocks:
            for instr in block.getInstructions():
                if (
                    instr.opname == "LOAD_CONST"
                    and isinstance(instr.oparg, CodeGenerator)
                    and instr.oparg.name == name
                ):
                    return instr.oparg.graph
        self.fail(f"no child code object named {name!r} found in graph")

    def test_future_no_longer_relevant(self):
        """Check the block chain produced for a simple ``while`` loop."""
        graph = self.to_graph(
            """
        while x:
            pass"""
        )
        expected = Block(
            "entry",
            Block(
                "while_loop",
                Block("while_body", Block("while_else", Block("while_after"))),
            ),
        )
        self.assert_graph_equal(graph, expected)

    def test_if(self):
        graph = self.to_graph(
            """
        if foo:
            pass
        else:
            pass"""
        )
        expected = Block("entry", Block("", Block("if_else", Block("if_end"))))
        self.assert_graph_equal(graph, expected)

    def test_try_except(self):
        graph = self.to_graph(
            """
        try:
            pass
        except:
            pass"""
        )

        if sys.version_info >= (3, 8):
            expected = Block(
                "entry",
                Block(
                    "try_body",
                    Block(
                        "try_handlers",
                        Block(
                            "try_cleanup_body0",
                            Block("try_except_0", Block("try_else", Block("try_end"))),
                        ),
                    ),
                ),
            )
        else:
            expected = Block(
                "entry",
                Block(
                    "try_body",
                    Block("try_handlers", Block("handler_end", Block("try_end"))),
                ),
            )
        self.assert_graph_equal(graph, expected)

    def test_chained_comparison(self):
        graph = self.to_graph("a < b < c")
        expected = Block(
            "entry", Block("compare_or_cleanup", Block("cleanup", Block("end")))
        )
        self.assert_graph_equal(graph, expected)

    def test_async_for(self):
        graph = self.to_graph(
            """
        async def f():
            async for x in foo:
                pass"""
        )
        # get the graph for f so we can check the async for
        graph = self.get_child_graph(graph, "f")
        if sys.version_info >= (3, 8):
            expected = Block(
                "entry",
                Block("async_for_try", Block("except", Block("end", Block("exit")))),
            )
        else:
            expected = Block(
                "entry",
                Block(
                    "async_for_try",
                    Block(
                        "except",
                        Block(
                            "after_try",
                            Block(
                                "try_cleanup",
                                Block("after_loop_else", Block("end", Block("exit"))),
                            ),
                        ),
                    ),
                ),
            )
        self.assert_graph_equal(graph, expected)


if __name__ == "__main__":
    unittest.main()