this repo has no description
at trunk 454 lines 16 kB view raw
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
import ast
from ast import AST
from compiler import compile as compiler_compile
from compiler.consts import CO_VARARGS, CO_VARKEYWORDS
from compiler.optimizer import BIN_OPS, is_const, get_const_value
from compiler.py38.optimizer import AstOptimizer38
from compiler.pyassem import PyFlowGraph38, Instruction
from compiler.pycodegen import Python38CodeGenerator
from compiler.symbols import SymbolVisitor
from compiler.visitor import ASTVisitor, walk

import _compiler_opcode as opcodepyro


def should_rewrite_printf(node):
    """Return True if `node` is a `<string literal> % <anything>` BinOp."""
    return isinstance(node.left, ast.Str) and isinstance(node.op, ast.Mod)


def create_conversion_call(name, value):
    """Build the AST for the call `"".<name>(<value>)`.

    Calling a method on the empty string literal gives access to a
    well-known helper regardless of the names visible in the surrounding
    code.
    """
    method = ast.Attribute(ast.Str(""), name, ast.Load())
    return ast.Call(method, args=[value], keywords=[])


def try_constant_fold_mod(format_string, right):
    """Evaluate `format_string % <constant right>` and wrap it in `ast.Str`.

    Any exception raised by the formatting propagates to the caller.
    """
    r = get_const_value(right)
    return ast.Str(format_string.__mod__(r))


class AstOptimizerPyro(AstOptimizer38):
    """AST optimizer that additionally rewrites printf-style `str % args`."""

    def rewrite_str_mod(self, left, right):  # noqa: C901
        """Rewrite `"fmt" % right` into a cheaper equivalent expression.

        Returns an `ast.Str` when the whole expression can be evaluated at
        compile time, an `ast.JoinedStr` (f-string) when every specifier is
        one of the supported conversions (`s r a d i u x X o`, optionally
        with `0` flags and a decimal width), and `None` when the format
        string uses anything else, so that the caller keeps the original
        node.
        """
        format_string = left.s
        try:
            if is_const(right):
                return try_constant_fold_mod(format_string, right)
            # Try and collapse the whole expression into a string. Only a
            # tuple display supplies one value per specifier; a list or set
            # display also has `.elts` but is a *single* argument at
            # runtime, so it must not be folded as if it were a tuple.
            if isinstance(right, ast.Tuple):
                const_tuple = self.makeConstTuple(right.elts)
                if const_tuple:
                    return ast.Str(format_string.__mod__(const_tuple.value))
        except Exception:
            # Folding is best-effort: if the formatting itself fails here,
            # fall through and let the general rewrite (or the runtime)
            # deal with the expression.
            pass

        # First pass: count specifiers and bail out on unsupported syntax.
        n_specifiers = 0
        i = 0
        length = len(format_string)
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            # Step over the '%' itself.
            i += 1

            if i >= length:
                # Invalid format string ending in a single percent
                return None
            ch = format_string[i]
            i += 1
            if ch == "%":
                # Break the string apart at '%'
                continue
            elif ch == "(":
                # We don't support dict lookups and may get confused from
                # inner '%' chars
                return None
            n_specifiers += 1

        rhs = right
        if isinstance(right, ast.Tuple):
            rhs_values = rhs.elts
            num_values = len(rhs_values)
        else:
            # If RHS is not a tuple constructor, then we only support the
            # situation with a single format specifier in the string, by
            # normalizing `rhs` to a one-element tuple:
            # `_mod_check_single_arg(rhs)[0]`
            rhs_values = None
            if n_specifiers != 1:
                return None
            num_values = 1

        # Second pass: split the format string into literal segments and
        # formatted values.
        i = 0
        value_idx = 0
        segment_begin = 0
        strings = []
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            # Step over the '%' itself.
            i += 1

            # Emit the literal text between the previous specifier and
            # this '%'.
            segment_end = i - 1
            if segment_end - segment_begin > 0:
                substr = format_string[segment_begin:segment_end]
                strings.append(ast.Str(substr))

            if i >= length:
                return None
            ch = format_string[i]
            i += 1

            # Parse flags and width
            spec_begin = i - 1
            have_width = False
            while True:
                if ch == "0":
                    # TODO(matthiasb): Support ' ', '+', '#', etc
                    # They mostly have the same meaning. However they can
                    # appear in any order here but must follow stricter
                    # conventions in f-strings.
                    if i >= length:
                        return None
                    ch = format_string[i]
                    i += 1
                    continue
                break
            if "1" <= ch <= "9":
                have_width = True
                if i >= length:
                    return None
                ch = format_string[i]
                i += 1
                while "0" <= ch <= "9":
                    if i >= length:
                        return None
                    ch = format_string[i]
                    i += 1
            # Everything between the '%' and the conversion character.
            spec_str = ""
            if i - 1 - spec_begin > 0:
                spec_str = format_string[spec_begin : i - 1]

            if ch == "%":
                # Handle '%%': the second '%' becomes the first character
                # of the next literal segment.
                segment_begin = i - 1
                continue

            # Handle remaining supported cases that use a value from RHS
            if rhs_values is not None:
                if value_idx >= num_values:
                    return None
                value = rhs_values[value_idx]
            else:
                # We have a situation like `"%s" % x` without tuple on RHS.
                # Transform to: f"{''._mod_check_single_arg(x)[0]}"
                converted = create_conversion_call("_mod_check_single_arg", rhs)
                value = ast.Subscript(converted, ast.Index(ast.Num(0)), ast.Load())
            value_idx += 1

            if ch in "sra":
                # Rewrite "%s" % (x,) to f"{x!s}"
                if have_width:
                    # Need to explicitly specify alignment because `%5s`
                    # aligns right, while `f"{x:5}"` aligns left.
                    spec_str = ">" + spec_str
                format_spec = ast.Str(spec_str) if spec_str else None
                # The conversion field is the character code of s/r/a.
                formatted = ast.FormattedValue(value, ord(ch), format_spec)
                strings.append(formatted)
            elif ch in "diu":
                # Rewrite "%d" % (x,) to f"{''._mod_convert_number_int(x)}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment.
                converted = create_conversion_call("_mod_convert_number_int", value)
                format_spec = ast.Str(spec_str) if spec_str else None
                formatted = ast.FormattedValue(converted, -1, format_spec)
                strings.append(formatted)
            elif ch in "xXo":
                # Rewrite "%x" % (v,) to f"{''._mod_convert_number_index(v):x}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment.
                converted = create_conversion_call("_mod_convert_number_index", value)
                format_spec = ast.Str(spec_str + ch)
                formatted = ast.FormattedValue(converted, -1, format_spec)
                strings.append(formatted)
            else:
                return None
            # Begin next segment after specifier
            segment_begin = i

        # The number of consumed values must match the RHS exactly.
        if value_idx != num_values:
            return None

        # Trailing literal text after the last specifier.
        segment_end = length
        if segment_end - segment_begin > 0:
            substr = format_string[segment_begin:segment_end]
            strings.append(ast.Str(substr))

        return ast.JoinedStr(strings)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Fold constant binary operations and rewrite `str % args`."""
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            handler = BIN_OPS.get(type(node.op))
            if handler is not None:
                lval = get_const_value(left)
                rval = get_const_value(right)
                try:
                    return ast.copy_location(ast.Constant(handler(lval, rval)), node)
                except Exception:
                    # Leave expressions that fail to evaluate untouched.
                    pass

        if should_rewrite_printf(node):
            result = self.rewrite_str_mod(left, right)
            if result is not None:
                # Visit the replacement so it gets optimized as well.
                return self.visit(result)

        return self.update_node(node, left=left, right=right)


class PyroFlowGraph(PyFlowGraph38):
    """Flow graph that rewrites local-variable accesses before assembly."""

    opcode = opcodepyro.opcode

    def optimizeLoadFast(self):
        """Rewrite LOAD_FAST/STORE_FAST into their `*_REVERSE` variants.

        Runs a forward dataflow analysis over the blocks in which each
        state is an integer bit mask: bit `i` is set when local `i` is
        assigned on every path reaching that point. A LOAD_FAST of a local
        that is assigned becomes LOAD_FAST_REVERSE_UNCHECKED; every
        STORE_FAST becomes STORE_FAST_REVERSE. Non-argument locals whose
        LOAD_FAST could not be rewritten get a
        DELETE_FAST_REVERSE_UNCHECKED prepended to the entry block.
        """
        blocks = self.getBlocksInOrder()
        # preds[bid] is the set of block ids with an edge into block `bid`.
        preds = tuple(set() for _ in range(self.block_count))
        for block in blocks:
            for child in block.get_children():
                if child is not None:
                    # TODO(emacs): Tail-duplicate finally blocks or upgrade to
                    # 3.10, which does this already. This avoids except blocks
                    # falling through into else blocks and mucking up
                    # performance.
                    preds[child.bid].add(block.bid)

        num_locals = len(self.varnames)
        # Mask with the bit of every local set.
        Top = 2**num_locals - 1
        # map of block id -> assignment state in lattice
        assigned_out = [Top] * self.block_count
        # Names of locals that were loaded while possibly unassigned.
        conditionally_assigned = set()
        argcount = (
            len(self.args)
            + len(self.kwonlyargs)
            + bool(self.flags & CO_VARARGS)
            + bool(self.flags & CO_VARKEYWORDS)
        )
        total_locals = num_locals + len(self.cellvars) + len(self.freevars)
        # Mask with the bits of the first `argcount` locals set.
        ArgsAssigned = 2**argcount - 1

        def reverse_local_idx(idx):
            # Operand of the *_REVERSE opcodes: index counted from the end
            # of locals + cellvars + freevars.
            return total_locals - idx - 1

        def meet(args):
            # Intersection: a local counts as assigned only if it is
            # assigned in every incoming state.
            result = Top
            for arg in args:
                result &= arg
            return result

        def process_one_block(block, modify=False):
            # Recompute the out-state of `block`; return True if it changed.
            # With `modify=True` the instructions are rewritten as well.
            bid = block.bid
            if len(preds[bid]) == 0:
                # No preds; all parameters are assigned
                assigned = ArgsAssigned
            else:
                # Meet the assigned sets of all predecessors
                assigned = meet(assigned_out[pred] for pred in preds[bid])
            for instr in block.getInstructions():
                if modify and instr.opname == "LOAD_FAST":
                    if assigned & (1 << instr.ioparg):
                        instr.opname = "LOAD_FAST_REVERSE_UNCHECKED"
                        instr.ioparg = reverse_local_idx(instr.ioparg)
                    elif instr.ioparg >= argcount:
                        # Exclude arguments because they come into the function
                        # body assigned. The only thing that can undefine them
                        # is DELETE_FAST.
                        conditionally_assigned.add(instr.oparg)
                elif instr.opname == "STORE_FAST":
                    assigned |= 1 << instr.ioparg
                    if modify:
                        instr.opname = "STORE_FAST_REVERSE"
                        instr.ioparg = reverse_local_idx(instr.ioparg)
                elif instr.opname == "DELETE_FAST":
                    assigned &= ~(1 << instr.ioparg)
            if assigned == assigned_out[bid]:
                return False
            assigned_out[bid] = assigned
            return True

        # Iterate until no block's out-state changes any more.
        changed = True
        while changed:
            changed = False
            for block in blocks:
                changed |= process_one_block(block)

        # Final pass: rewrite the instructions using the stable states.
        for block in blocks:
            process_one_block(block, modify=True)

        if conditionally_assigned:
            deletes = [
                Instruction(
                    "DELETE_FAST_REVERSE_UNCHECKED",
                    name,
                    reverse_local_idx(self.varnames.index(name)),
                )
                for name in sorted(conditionally_assigned)
            ]
            self.entry.insts = deletes + self.entry.insts

    def getCode(self):
        """Run `optimizeLoadFast` and then produce the code object."""
        self.optimizeLoadFast()
        return super().getCode()


class ComprehensionRenamer(ASTVisitor):
    """Rename comprehension target variables with a per-scope prefix.

    Names stored or deleted while `is_target` is set are renamed and
    recorded in `new_names`; every later `Name`/`arg` using a recorded
    name is rewritten to the new name.
    """

    def __init__(self, scope):
        super().__init__()
        # We need a prefix that is unique per-scope for each renaming round.
        index = getattr(scope, "last_comprehension_rename_index", -1) + 1
        scope.last_comprehension_rename_index = index
        # The first round uses "_gen$", later rounds "_gen1$", "_gen2$", ...
        self.prefix = f"_gen{str(index) if index > 0 else ''}$"
        # Maps original name -> renamed name.
        self.new_names = {}
        # Set by the caller while visiting a comprehension target.
        self.is_target = False

    def visitName(self, node):
        if self.is_target and isinstance(node.ctx, (ast.Store, ast.Del)):
            name = node.id
            new_name = self.prefix + name
            self.new_names[name] = new_name
            node.id = new_name
        else:
            new_name = self.new_names.get(node.id)
            if new_name is not None:
                node.id = new_name

    def visitarg(self, node):
        new_name = self.new_names.get(node.arg)
        if new_name is not None:
            node.arg = new_name


class CollectNames(ASTVisitor):
    """Collect the identifiers of all `Name` and `arg` nodes in `names`."""

    def __init__(self):
        super().__init__()
        self.names = set()

    def visitName(self, node):
        self.names.add(node.id)

    def visitarg(self, node):
        self.names.add(node.arg)


def _can_inline_comprehension(node):
    """Return whether comprehension `node` may be inlined.

    The answer is computed once and cached on the node as `can_inline`.
    """
    can_inline = getattr(node, "can_inline", None)
    # Bad heuristic: Stop inlining comprehensions when "locals" is used.
    if can_inline is None:
        # Do not rename if "locals" is used.
        visitor = CollectNames()
        visitor.visit(node)
        can_inline = "locals" not in visitor.names
        node.can_inline = can_inline
    return can_inline


class PyroSymbolVisitor(SymbolVisitor):
    """Symbol visitor that treats inlinable comprehensions as part of the
    enclosing scope, after renaming their target variables."""

    def visitDictCompListCompSetComp(self, node, scope):
        """Visit a list/set/dict comprehension in the enclosing `scope`."""
        if not _can_inline_comprehension(node):
            return super().visitGeneratorExp(node, scope)

        # Check for unexpected assignments.
        scope.comp_iter_expr += 1
        self.visit(node.generators[0].iter, scope)
        scope.comp_iter_expr -= 1

        renamer = ComprehensionRenamer(scope)
        is_outer = True
        for gen in node.generators:
            # The iterable is visited before the target is renamed, so it
            # only picks up renames made by earlier generators.
            renamer.visit(gen.iter)
            renamer.is_target = True
            renamer.visit(gen.target)
            renamer.is_target = False
            for if_node in gen.ifs:
                renamer.visit(if_node)

            self.visitcomprehension(gen, scope, is_outer)
            is_outer = False

        if isinstance(node, ast.DictComp):
            renamer.visit(node.value)
            renamer.visit(node.key)
            self.visit(node.value, scope)
            self.visit(node.key, scope)
        else:
            renamer.visit(node.elt)
            self.visit(node.elt, scope)

    visitDictComp = visitDictCompListCompSetComp
    visitListComp = visitDictCompListCompSetComp
    visitSetComp = visitDictCompListCompSetComp


class PyroCodeGenerator(Python38CodeGenerator):
    """Code generator using the Pyro optimizer, symbol visitor and flow
    graph."""

    flow_graph = PyroFlowGraph

    @classmethod
    def make_code_gen(
        cls,
        name: str,
        tree: AST,
        filename: str,
        flags: int,
        optimize: int,
        peephole_enabled: bool = True,
        ast_optimizer_enabled: bool = True,
    ):
        """Optimize `tree`, collect its symbols and run code generation.

        Returns the code generator instance after it has walked the tree.
        """
        if ast_optimizer_enabled:
            tree = cls.optimize_tree(optimize, tree)
        s = PyroSymbolVisitor()
        walk(tree, s)

        graph = cls.flow_graph(
            name, filename, s.scopes[tree], peephole_enabled=peephole_enabled
        )
        code_gen = cls(None, tree, s, graph, flags, optimize)
        walk(tree, code_gen)
        return code_gen

    @classmethod
    def optimize_tree(cls, optimize: int, tree: ast.AST):
        """Run `AstOptimizerPyro` over `tree` and return the result."""
        return AstOptimizerPyro(optimize=optimize > 0).visit(tree)

    def defaultEmitCompare(self, op):
        """Emit dedicated opcodes for `is` / `is not`, COMPARE_OP otherwise."""
        if isinstance(op, ast.Is):
            self.emit("COMPARE_IS")
        elif isinstance(op, ast.IsNot):
            self.emit("COMPARE_IS_NOT")
        else:
            self.emit("COMPARE_OP", self._cmp_opcode[type(op)])

    def visitListComp(self, node):
        """Compile an inlinable list comprehension in the current scope."""
        if not _can_inline_comprehension(node):
            return super().visitListComp(node)
        self.emit("BUILD_LIST")
        self.compile_comprehension_body(node.generators, 0, node.elt, None, type(node))

    def visitSetComp(self, node):
        """Compile an inlinable set comprehension in the current scope."""
        if not _can_inline_comprehension(node):
            return super().visitSetComp(node)
        self.emit("BUILD_SET")
        self.compile_comprehension_body(node.generators, 0, node.elt, None, type(node))

    def visitDictComp(self, node):
        """Compile an inlinable dict comprehension in the current scope."""
        if not _can_inline_comprehension(node):
            return super().visitDictComp(node)
        self.emit("BUILD_MAP")
        self.compile_comprehension_body(
            node.generators, 0, node.key, node.value, type(node)
        )


def compile(source, filename, mode, flags, dont_inherit, optimize):
    """Compile `source` with `PyroCodeGenerator`.

    NOTE(review): `dont_inherit` is accepted but not forwarded; `None` is
    passed in its position. Presumably intentional -- confirm against the
    signature of `compiler.compile`.
    """
    return compiler_compile(
        source, filename, mode, flags, None, optimize, PyroCodeGenerator
    )