this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2import ast
3from ast import AST
4from compiler import compile as compiler_compile
5from compiler.consts import CO_VARARGS, CO_VARKEYWORDS
6from compiler.optimizer import BIN_OPS, is_const, get_const_value
7from compiler.py38.optimizer import AstOptimizer38
8from compiler.pyassem import PyFlowGraph38, Instruction
9from compiler.pycodegen import Python38CodeGenerator
10from compiler.symbols import SymbolVisitor
11from compiler.visitor import ASTVisitor, walk
12
13import _compiler_opcode as opcodepyro
14
15
def should_rewrite_printf(node):
    """Return True when the BinOp ``node`` has the shape ``<str literal> % <expr>``.

    Only such expressions are candidates for the printf-style rewrite done
    by ``AstOptimizerPyro.rewrite_str_mod``.
    """
    lhs = node.left
    # Equivalent to ``isinstance(lhs, ast.Str)``: a Constant holding a str.
    lhs_is_str_literal = isinstance(lhs, ast.Constant) and isinstance(
        getattr(lhs, "value", None), str
    )
    if not lhs_is_str_literal:
        return False
    return isinstance(node.op, ast.Mod)
18
19
def create_conversion_call(name, value):
    """Build the AST for the call expression ``''.<name>(value)``.

    Calling a method on the empty-string literal reaches a helper on the
    ``str`` type without depending on any name in the surrounding
    environment.
    """
    empty_string = ast.Constant("")
    bound_method = ast.Attribute(value=empty_string, attr=name, ctx=ast.Load())
    return ast.Call(bound_method, args=[value], keywords=[])
23
24
def try_constant_fold_mod(format_string, right):
    """Evaluate ``format_string % right`` at compile time.

    ``right`` must be a constant AST node. Returns the formatted result as a
    string-constant node. Any exception raised by the formatting propagates;
    ``AstOptimizerPyro.rewrite_str_mod`` calls this inside a ``try`` block.
    """
    operand = get_const_value(right)
    folded = format_string.__mod__(operand)
    return ast.Constant(folded)
28
29
class AstOptimizerPyro(AstOptimizer38):
    """AST optimizer that also rewrites printf-style ``"<literal>" % rhs``.

    On top of the optimizations inherited from ``AstOptimizer38``,
    ``visitBinOp`` folds binary operations on two constants and turns
    ``"<literal>" % rhs`` into a constant string or an f-string
    (``ast.JoinedStr``) when the format string only uses supported
    specifiers.
    """

    def rewrite_str_mod(self, left, right):  # noqa: C901
        """Rewrite ``left % right`` where ``left`` is a string literal.

        Returns the replacement expression, or ``None`` when the rewrite is not
        supported and the original ``BinOp`` must be kept.

        * If ``right`` is a constant, or a tuple display of constants, the
          expression is formatted now and returned as a string constant. If
          formatting raises, no constant is produced and the f-string rewrite
          below is attempted instead.
        * Otherwise an ``ast.JoinedStr`` is built. Supported: ``%%`` and the
          conversions ``s r a d i u x X o``, each with optional ``0`` flags and
          a decimal width. Mapping keys (``%(name)s``), other flags, precision
          and other conversions make this return ``None``.
        * If ``right`` is not a tuple display, the format string must contain
          exactly one value-consuming specifier; the value is then taken as
          ``''._mod_check_single_arg(right)[0]``.
        """
        format_string = left.s
        try:
            if is_const(right):
                return try_constant_fold_mod(format_string, right)
            # Only a tuple display supplies several `%` arguments. List and
            # set displays also have `.elts` but are a single argument, so
            # they must not be folded as if they were tuples.
            if isinstance(right, ast.Tuple):
                const_tuple = self.makeConstTuple(right.elts)
                if const_tuple:
                    return ast.Str(format_string.__mod__(const_tuple.value))
        except Exception:
            # Deliberate best-effort: a formatting error must not abort
            # compilation; fall through to the non-folding rewrite below.
            pass

        # First pass: validate the format string and count the specifiers
        # that consume a value (everything except `%%`).
        n_specifiers = 0
        i = 0
        length = len(format_string)
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            i += 1
            if i >= length:
                # Invalid format string ending in a single percent
                return None
            ch = format_string[i]
            i += 1
            if ch == "%":
                # `%%` is a literal percent sign and consumes no value.
                continue
            if ch == "(":
                # We don't support dict lookups and may get confused from
                # inner '%' chars
                return None
            n_specifiers += 1

        if isinstance(right, ast.Tuple):
            rhs_values = right.elts
            if any(isinstance(elt, ast.Starred) for elt in rhs_values):
                # `(*args, ...)` has a length unknown at compile time.
                return None
            num_values = len(rhs_values)
        else:
            # If RHS is not a tuple display, then we only support the
            # situation with a single format specifier in the string, by
            # normalizing `right` to a one-element tuple:
            # `_mod_check_single_arg(right)[0]`
            rhs_values = None
            if n_specifiers != 1:
                return None
            num_values = 1

        # Second pass: split the format string into literal segments and
        # formatted values.
        i = 0
        value_idx = 0
        segment_begin = 0
        strings = []
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            # Emit the literal text preceding this '%'.
            if i > segment_begin:
                strings.append(ast.Str(format_string[segment_begin:i]))
            i += 1
            if i >= length:
                return None
            ch = format_string[i]
            i += 1

            # Parse flags. Only '0' is supported.
            # TODO(matthiasb): Support ' ', '+', '#', etc
            # They mostly have the same meaning. However they can
            # appear in any order here but must follow stricter
            # conventions in f-strings.
            spec_begin = i - 1
            while ch == "0":
                if i >= length:
                    return None
                ch = format_string[i]
                i += 1
            # Parse an optional decimal width.
            width_begin = i - 1
            if "1" <= ch <= "9":
                if i >= length:
                    return None
                ch = format_string[i]
                i += 1
                while "0" <= ch <= "9":
                    if i >= length:
                        return None
                    ch = format_string[i]
                    i += 1
            # `ch` is the conversion character, located at index i - 1.
            spec_str = format_string[spec_begin : i - 1]
            width_str = format_string[width_begin : i - 1]

            if ch == "%":
                # Handle '%%': the second '%' starts the next literal segment.
                segment_begin = i - 1
                continue

            # Handle remaining supported cases that use a value from RHS
            if rhs_values is not None:
                if value_idx >= num_values:
                    return None
                value = rhs_values[value_idx]
            else:
                # We have a situation like `"%s" % x` without tuple on RHS.
                # Transform to: f"{''._mod_check_single_arg(x)[0]}"
                converted = create_conversion_call("_mod_check_single_arg", right)
                value = ast.Subscript(converted, ast.Index(ast.Num(0)), ast.Load())
            value_idx += 1

            if ch in "sra":
                # Rewrite "%s" % (x,) to f"{x!s}". %-formatting ignores the
                # '0' flag for string conversions ("%05s" % "a" == "    a"),
                # so only the width is carried over. The alignment is made
                # explicit because `%5s` aligns right, while `f"{x:5}"`
                # aligns strings left.
                format_spec = ast.Str(">" + width_str) if width_str else None
                strings.append(ast.FormattedValue(value, ord(ch), format_spec))
            elif ch in "diu":
                # Rewrite "%d" % (x,) to f"{''._mod_convert_number_int(x)}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment. The '0' flag and width keep their numeric
                # meaning in the format spec.
                converted = create_conversion_call("_mod_convert_number_int", value)
                format_spec = ast.Str(spec_str) if spec_str else None
                strings.append(ast.FormattedValue(converted, -1, format_spec))
            elif ch in "xXo":
                # Rewrite "%x" % (v,) to f"{''._mod_convert_number_index(v):x}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment.
                converted = create_conversion_call("_mod_convert_number_index", value)
                format_spec = ast.Str(spec_str + ch)
                strings.append(ast.FormattedValue(converted, -1, format_spec))
            else:
                return None
            # Begin next segment after specifier
            segment_begin = i

        if value_idx != num_values:
            return None

        if length > segment_begin:
            strings.append(ast.Str(format_string[segment_begin:]))

        return ast.JoinedStr(strings)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Optimize a binary operation.

        Both operands are optimized first. If both are constants and
        ``BIN_OPS`` has a handler for the operator, the result is folded
        unless the handler raises. Otherwise ``"<literal>" % rhs`` is offered
        to ``rewrite_str_mod``; a successful rewrite is optimized again and
        takes over ``node``'s source location.
        """
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            handler = BIN_OPS.get(type(node.op))
            if handler is not None:
                lval = get_const_value(left)
                rval = get_const_value(right)
                try:
                    return ast.copy_location(ast.Constant(handler(lval, rval)), node)
                except Exception:
                    # Leave failing operations (e.g. division by zero) to
                    # raise at runtime.
                    pass

        if should_rewrite_printf(node):
            result = self.rewrite_str_mod(left, right)
            if result is not None:
                # Keep the original expression's position, as the
                # constant-folding path above does.
                return ast.copy_location(self.visit(result), node)

        return self.update_node(node, left=left, right=right)
209
210
class PyroFlowGraph(PyFlowGraph38):
    """Flow graph that rewrites fast-local instructions for Pyro.

    ``getCode`` runs ``optimizeLoadFast`` before delegating to the base
    class, so the assembled code uses the opcodes chosen there.
    """

    # Pyro's opcode table (from the `_compiler_opcode` module).
    opcode = opcodepyro.opcode

    def optimizeLoadFast(self):
        """Choose local-variable opcodes using a definite-assignment analysis.

        Each local in ``self.varnames`` is one bit of an int. A set bit means
        the local is assigned on every path reaching that point; the meet over
        predecessors is bitwise AND. Once the analysis reaches a fixed point,
        instructions are rewritten in place:

        * ``LOAD_FAST`` of a definitely-assigned local becomes
          ``LOAD_FAST_REVERSE_UNCHECKED``; other ``LOAD_FAST``s are kept.
        * Every ``STORE_FAST`` becomes ``STORE_FAST_REVERSE``.
        * For each non-parameter local read by a kept ``LOAD_FAST``, a
          ``DELETE_FAST_REVERSE_UNCHECKED`` is prepended to the entry block.

        The ``*_REVERSE*`` forms take an index counted back from the end of
        the locals + cellvars + freevars area (see ``reverse_local_idx``).
        ``DELETE_FAST`` clears a bit during analysis but is not rewritten.

        NOTE(review): not idempotent. A second call would analyze
        already-rewritten instructions (no ``STORE_FAST`` remains).
        """
        blocks = self.getBlocksInOrder()
        # preds[bid] = ids of the blocks that list block `bid` as a child.
        preds = tuple(set() for i in range(self.block_count))
        for block in blocks:
            for child in block.get_children():
                if child is not None:
                    # TODO(emacs): Tail-duplicate finally blocks or upgrade to
                    # 3.10, which does this already. This avoids except blocks
                    # falling through into else blocks and mucking up
                    # performance.
                    preds[child.bid].add(block.bid)

        num_locals = len(self.varnames)
        # Lattice top (every local assigned): the optimistic initial state
        # that the fixed-point loop below refines.
        Top = 2**num_locals - 1
        # map of block id -> assignment state in lattice
        assigned_out = [Top] * self.block_count
        # Names of non-parameter locals read by a LOAD_FAST that could not be
        # proven assigned.
        conditionally_assigned = set()
        # Number of parameter slots, counting *args and **kwargs when present.
        # The masks below assume parameters occupy the first `argcount`
        # local indices.
        argcount = (
            len(self.args)
            + len(self.kwonlyargs)
            + bool(self.flags & CO_VARARGS)
            + bool(self.flags & CO_VARKEYWORDS)
        )
        # Size of the whole area that reverse indices count back from.
        total_locals = num_locals + len(self.cellvars) + len(self.freevars)
        # State used for blocks without predecessors (presumably the entry
        # block): only the parameters are assigned.
        ArgsAssigned = 2**argcount - 1

        # Map a forward local index to its index from the end of the area.
        def reverse_local_idx(idx):
            return total_locals - idx - 1

        # Intersect (bitwise AND) predecessor states, starting from Top.
        def meet(args):
            result = Top
            for arg in args:
                result &= arg
            return result

        # Apply the transfer function of one block. Returns True if the
        # block's out-state changed. With modify=True, also rewrite its
        # instructions, using the state in effect before each instruction.
        def process_one_block(block, modify=False):
            bid = block.bid
            if len(preds[bid]) == 0:
                # No preds; all parameters are assigned
                assigned = ArgsAssigned
            else:
                # Meet the assigned sets of all predecessors
                assigned = meet(assigned_out[pred] for pred in preds[bid])
            for instr in block.getInstructions():
                if modify and instr.opname == "LOAD_FAST":
                    if assigned & (1 << instr.ioparg):
                        instr.opname = "LOAD_FAST_REVERSE_UNCHECKED"
                        instr.ioparg = reverse_local_idx(instr.ioparg)
                    elif instr.ioparg >= argcount:
                        # Exclude arguments because they come into the function
                        # body assigned. The only thing that can undefine them
                        # is DELETE_FAST.
                        conditionally_assigned.add(instr.oparg)
                elif instr.opname == "STORE_FAST":
                    # Record the assignment using the forward index before
                    # ioparg is reversed below.
                    assigned |= 1 << instr.ioparg
                    if modify:
                        instr.opname = "STORE_FAST_REVERSE"
                        instr.ioparg = reverse_local_idx(instr.ioparg)
                elif instr.opname == "DELETE_FAST":
                    assigned &= ~(1 << instr.ioparg)
            if assigned == assigned_out[bid]:
                return False
            assigned_out[bid] = assigned
            return True

        # Analysis only: iterate until no block's out-state changes.
        changed = True
        while changed:
            changed = False
            for block in blocks:
                changed |= process_one_block(block)

        # States are now stable, so one rewriting pass suffices.
        for block in blocks:
            process_one_block(block, modify=True)

        if conditionally_assigned:
            # Prepend one unchecked delete per conditionally assigned local,
            # in sorted name order.
            # NOTE(review): presumably this makes those slots start unbound so
            # the checked LOAD_FAST reports unassigned locals correctly;
            # confirm against the runtime's frame setup.
            deletes = [
                Instruction(
                    "DELETE_FAST_REVERSE_UNCHECKED",
                    name,
                    reverse_local_idx(self.varnames.index(name)),
                )
                for name in sorted(conditionally_assigned)
            ]
            self.entry.insts = deletes + self.entry.insts

    def getCode(self):
        # Rewrite local-variable instructions, then assemble as usual.
        self.optimizeLoadFast()
        return super().getCode()
302
303
class ComprehensionRenamer(ASTVisitor):
    """Rename the iteration variables of one comprehension being inlined.

    The caller sets ``is_target`` while visiting a generator's target. Each
    ``Name`` in Store/Del context seen then is given ``prefix``, and the
    mapping is remembered in ``new_names``. Any other ``Name`` or ``arg``
    (e.g. a lambda parameter) whose identifier has a mapping is rewritten to
    the new name.
    """

    def __init__(self, scope):
        super().__init__()
        # The counter is stored on the scope, so every renaming round within
        # one scope gets a distinct prefix: "_gen$", "_gen1$", "_gen2$", ...
        # '$' cannot appear in a Python identifier, so the new names cannot
        # clash with user-written ones.
        previous = getattr(scope, "last_comprehension_rename_index", -1)
        index = previous + 1
        scope.last_comprehension_rename_index = index
        suffix = str(index) if index > 0 else ""
        self.prefix = "_gen" + suffix + "$"
        # original identifier -> prefixed identifier
        self.new_names = {}
        # True only while the caller visits a comprehension target.
        self.is_target = False

    def visitName(self, node):
        original = node.id
        if self.is_target and isinstance(node.ctx, (ast.Store, ast.Del)):
            replacement = self.prefix + original
            self.new_names[original] = replacement
        else:
            replacement = self.new_names.get(original)
            if replacement is None:
                return
        node.id = replacement

    def visitarg(self, node):
        replacement = self.new_names.get(node.arg)
        if replacement is None:
            return
        node.arg = replacement
329
330
class CollectNames(ASTVisitor):
    """Collect every identifier used as a ``Name`` or as an ``arg``."""

    def __init__(self):
        super().__init__()
        # Identifiers seen so far, regardless of load/store context.
        self.names = set()

    def _record(self, identifier):
        self.names.add(identifier)

    def visitName(self, node):
        self._record(node.id)

    def visitarg(self, node):
        self._record(node.arg)
341
342
# Names whose use inside a comprehension can observe the comprehension's own
# local variables: `locals()`, the no-argument forms of `vars()` and `dir()`,
# and `eval`/`exec` with their default namespaces. Inlining renames those
# variables, so comprehensions mentioning any of these are not inlined.
_FRAME_INTROSPECTING_NAMES = frozenset({"locals", "vars", "dir", "eval", "exec"})


def _can_inline_comprehension(node):
    """Return whether comprehension ``node`` may be inlined into its scope.

    The answer is computed on the first call and cached in
    ``node.can_inline``. Later calls (e.g. from the code generator, after
    names have been rewritten) therefore get the same answer.
    """
    can_inline = getattr(node, "can_inline", None)
    if can_inline is None:
        # Bad heuristic: this matches bare names, so harmless uses such as
        # `dir(obj)` or a shadowed `eval` also prevent inlining.
        # That is conservative but always safe.
        visitor = CollectNames()
        visitor.visit(node)
        can_inline = visitor.names.isdisjoint(_FRAME_INTROSPECTING_NAMES)
        node.can_inline = can_inline
    return can_inline
353
354
class PyroSymbolVisitor(SymbolVisitor):
    """Symbol visitor that inlines list, set and dict comprehensions.

    The iteration variables of an eligible comprehension are renamed with a
    per-scope unique prefix (``ComprehensionRenamer``). Its symbols are then
    recorded directly in the enclosing ``scope`` instead of a new one.
    Comprehensions rejected by ``_can_inline_comprehension`` are handled by
    the base class's generator-expression logic.
    """

    def visitDictCompListCompSetComp(self, node, scope):
        if not _can_inline_comprehension(node):
            return super().visitGeneratorExp(node, scope)

        # Check for unexpected assignments: the outermost iterable is visited
        # first with `comp_iter_expr` raised.
        # NOTE(review): presumably the base visitor uses that counter to
        # reject constructs such as `:=` there; confirm in SymbolVisitor.
        outermost_iter = node.generators[0].iter
        scope.comp_iter_expr += 1
        self.visit(outermost_iter, scope)
        scope.comp_iter_expr -= 1

        renamer = ComprehensionRenamer(scope)
        for position, generator in enumerate(node.generators):
            # The iterable is renamed before this generator's target is
            # registered, so it only sees earlier targets. The filters are
            # renamed afterwards and also see this target.
            renamer.visit(generator.iter)
            renamer.is_target = True
            renamer.visit(generator.target)
            renamer.is_target = False
            for condition in generator.ifs:
                renamer.visit(condition)
            self.visitcomprehension(generator, scope, position == 0)

        if isinstance(node, ast.DictComp):
            produced = (node.value, node.key)
        else:
            produced = (node.elt,)
        for expr in produced:
            renamer.visit(expr)
        for expr in produced:
            self.visit(expr, scope)

    visitDictComp = visitListComp = visitSetComp = visitDictCompListCompSetComp
390
391
class PyroCodeGenerator(Python38CodeGenerator):
    """Python 3.8 code generator producing bytecode for Pyro.

    Compared with ``Python38CodeGenerator`` it:

    * optimizes the AST with ``AstOptimizerPyro`` and resolves scopes with
      ``PyroSymbolVisitor``;
    * assembles code with ``PyroFlowGraph``;
    * emits ``COMPARE_IS`` / ``COMPARE_IS_NOT`` for identity comparisons;
    * compiles list/set/dict comprehensions accepted by
      ``_can_inline_comprehension`` directly into the current code. Other
      comprehensions are left to the base class.
    """

    flow_graph = PyroFlowGraph

    @classmethod
    def make_code_gen(
        cls,
        name: str,
        tree: AST,
        filename: str,
        flags: int,
        optimize: int,
        peephole_enabled: bool = True,
        ast_optimizer_enabled: bool = True,
    ):
        """Create a code generator for ``tree`` and run it over the tree.

        The tree is optimized first unless ``ast_optimizer_enabled`` is
        false, then scoped with ``PyroSymbolVisitor``. The returned
        generator has already walked the (possibly optimized) tree.
        """
        if ast_optimizer_enabled:
            tree = cls.optimize_tree(optimize, tree)

        symbols = PyroSymbolVisitor()
        walk(tree, symbols)
        top_scope = symbols.scopes[tree]

        graph = cls.flow_graph(
            name, filename, top_scope, peephole_enabled=peephole_enabled
        )
        code_gen = cls(None, tree, symbols, graph, flags, optimize)
        walk(tree, code_gen)
        return code_gen

    @classmethod
    def optimize_tree(cls, optimize: int, tree: ast.AST):
        """Return ``tree`` rewritten by ``AstOptimizerPyro``.

        Any positive ``optimize`` level enables the optimizer's optimize mode.
        """
        optimizer = AstOptimizerPyro(optimize=optimize > 0)
        return optimizer.visit(tree)

    def defaultEmitCompare(self, op):
        # Identity comparisons have dedicated opcodes; everything else goes
        # through COMPARE_OP.
        if isinstance(op, ast.Is):
            self.emit("COMPARE_IS")
            return
        if isinstance(op, ast.IsNot):
            self.emit("COMPARE_IS_NOT")
            return
        self.emit("COMPARE_OP", self._cmp_opcode[type(op)])

    def _emit_inlined_comprehension(self, node, build_opname, elt, value):
        # Create the empty container, then compile the loop that fills it.
        self.emit(build_opname)
        self.compile_comprehension_body(node.generators, 0, elt, value, type(node))

    def visitListComp(self, node):
        if _can_inline_comprehension(node):
            self._emit_inlined_comprehension(node, "BUILD_LIST", node.elt, None)
            return None
        return super().visitListComp(node)

    def visitSetComp(self, node):
        if _can_inline_comprehension(node):
            self._emit_inlined_comprehension(node, "BUILD_SET", node.elt, None)
            return None
        return super().visitSetComp(node)

    def visitDictComp(self, node):
        if _can_inline_comprehension(node):
            self._emit_inlined_comprehension(node, "BUILD_MAP", node.key, node.value)
            return None
        return super().visitDictComp(node)
449
450
def compile(source, filename, mode, flags, dont_inherit, optimize):
    """Compile ``source`` using ``PyroCodeGenerator``.

    The positional parameters mirror the builtin ``compile``. ``dont_inherit``
    is accepted but not forwarded: ``None`` is always passed in its place.
    """
    code_generator = PyroCodeGenerator
    ignored_dont_inherit = None
    return compiler_compile(
        source, filename, mode, flags, ignored_dont_inherit, optimize, code_generator
    )