this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2import ast
3import dis
4import sys
5import unittest
6from compiler.pycodegen import CodeGenerator
7from dis import opmap, opname
8from unittest import TestCase
9
10from .common import CompilerTest
11
12
class Block:
    """Expected-shape node used in graph assertions.

    Each node holds a ``label`` and an optional ``next`` node, so a chain of
    these describes the expected linear order of basic blocks.
    """

    # ``next`` shadows the builtin, but it is kept as the parameter name
    # because callers may pass it by keyword.
    def __init__(self, label, next=None):
        self.label = label
        self.next = next

    def __repr__(self):
        # Same textual form GraphTests.format_graph produces for compiled
        # graphs, so expected and actual chains read identically in failures.
        if self.next is None:
            return f"Block({self.label!r})"
        return f"Block({self.label!r}, {self.next!r})"
17
18
class GraphTests(CompilerTest):
    """Unit tests for the control-flow graph the compiler produces, checking
    that basic blocks are linked together in the expected order."""

    def format_graph(self, graph):
        """Render the ``next`` chain starting at *graph* as nested ``Block(...)`` text.

        Only ``label`` and ``next`` are read, so this works for compiled
        blocks as well as the expected ``Block`` chains built in these tests.
        """
        # Compare against None explicitly: a compiled block type may define
        # __len__/__bool__, and a present-but-empty block must still be shown.
        if graph.next is not None:
            return f"Block({graph.label!r}, {self.format_graph(graph.next)})"
        return f"Block({graph.label!r})"

    def assert_graph_equal(self, graph, expected):
        """Assert that *graph*'s block chain matches the *expected* ``Block`` chain.

        The walk starts at ``graph.ordered_blocks[0]`` and follows ``next``
        links. On mismatch the AssertionError is re-raised with the actual
        chain appended, so the failure message shows the whole compiled graph.
        """
        first_block = graph.ordered_blocks[0]
        try:
            self.assert_graph_equal_worker(first_block, expected)
        except AssertionError as e:
            # str(e) is safe even if the error carries no (or non-str) args.
            raise AssertionError(
                f"{e}\nGraph was: {self.format_graph(first_block)}"
            ) from None

    def assert_graph_equal_worker(self, compiled, expected):
        """Recursively compare labels and ``next`` links of *compiled* vs *expected*."""
        self.assertEqual(compiled.label, expected.label)
        if expected.next is not None:
            self.assertIsNotNone(compiled.next)
            self.assert_graph_equal_worker(compiled.next, expected.next)
        else:
            self.assertIsNone(compiled.next)

    def get_child_graph(self, graph, name):
        """Return the graph of the nested code object called *name* inside *graph*.

        Scans every ``LOAD_CONST`` instruction whose argument is a
        ``CodeGenerator`` (e.g. the body of a nested ``def``) and returns that
        generator's ``graph``.

        Fails the test if no match exists, instead of returning None and
        letting the caller crash later with an unrelated AttributeError.
        """
        for block in graph.ordered_blocks:
            for instr in block.getInstructions():
                if (
                    instr.opname == "LOAD_CONST"
                    and isinstance(instr.oparg, CodeGenerator)
                    and instr.oparg.name == name
                ):
                    return instr.oparg.graph
        self.fail(f"no child code object named {name!r} found in graph")

    def test_future_no_longer_relevant(self):
        """A plain ``while`` loop links header, body, else and after blocks."""
        graph = self.to_graph(
            """
        while x:
            pass"""
        )
        expected = Block(
            "entry",
            Block(
                "while_loop",
                Block("while_body", Block("while_else", Block("while_after"))),
            ),
        )
        self.assert_graph_equal(graph, expected)

    def test_if(self):
        """An ``if``/``else`` produces an unlabeled block, then else and end blocks."""
        graph = self.to_graph(
            """
        if foo:
            pass
        else:
            pass"""
        )
        expected = Block("entry", Block("", Block("if_else", Block("if_end"))))
        self.assert_graph_equal(graph, expected)

    def test_try_except(self):
        """A bare ``try``/``except``; the block layout differs before 3.8."""
        graph = self.to_graph(
            """
        try:
            pass
        except:
            pass"""
        )

        if sys.version_info >= (3, 8):
            expected = Block(
                "entry",
                Block(
                    "try_body",
                    Block(
                        "try_handlers",
                        Block(
                            "try_cleanup_body0",
                            Block("try_except_0", Block("try_else", Block("try_end"))),
                        ),
                    ),
                ),
            )
        else:
            expected = Block(
                "entry",
                Block(
                    "try_body",
                    Block("try_handlers", Block("handler_end", Block("try_end"))),
                ),
            )
        self.assert_graph_equal(graph, expected)

    def test_chained_comparison(self):
        """``a < b < c`` needs a cleanup block for the short-circuit path."""
        graph = self.to_graph("a < b < c")
        expected = Block(
            "entry", Block("compare_or_cleanup", Block("cleanup", Block("end")))
        )
        self.assert_graph_equal(graph, expected)

    def test_async_for(self):
        """``async for`` inside an ``async def``; layout differs before 3.8."""
        graph = self.to_graph(
            """
        async def f():
            async for x in foo:
                pass"""
        )
        # Get the graph for f so we can check the async for.
        graph = self.get_child_graph(graph, "f")
        if sys.version_info >= (3, 8):
            expected = Block(
                "entry",
                Block("async_for_try", Block("except", Block("end", Block("exit")))),
            )
        else:
            expected = Block(
                "entry",
                Block(
                    "async_for_try",
                    Block(
                        "except",
                        Block(
                            "after_try",
                            Block(
                                "try_cleanup",
                                Block("after_loop_else", Block("end", Block("exit"))),
                            ),
                        ),
                    ),
                ),
            )
        self.assert_graph_equal(graph, expected)
153
154
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's command-line
    # runner, which runs the test cases defined in this module.
    unittest.main()