this repo has no description
at trunk 205 lines 6.7 kB view raw
# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe
import ast
import operator
import sys
from ast import Bytes, Constant, Ellipsis, NameConstant, Num, Str, cmpop, copy_location
from typing import Dict, Iterable, Optional, Type

from .peephole import safe_lshift, safe_mod, safe_multiply, safe_power
from .visitor import ASTRewriter


def is_const(node):
    """Return True if *node* is a literal-constant AST node.

    The legacy node classes (Num, Str, Bytes, Ellipsis, NameConstant) are
    accepted alongside Constant so that either spelling is recognized.
    """
    return isinstance(node, (Constant, Num, Str, Bytes, Ellipsis, NameConstant))


def get_const_value(node):
    """Return the Python value carried by a constant node.

    Raises:
        TypeError: if *node* is not one of the kinds accepted by ``is_const``.
    """
    if isinstance(node, (Constant, NameConstant)):
        return node.value
    elif isinstance(node, Num):
        return node.n
    elif isinstance(node, (Str, Bytes)):
        return node.s
    elif isinstance(node, Ellipsis):
        return ...

    raise TypeError("Bad constant value")


class Py37Limits:
    # Limits handed to the `.peephole` safe_* helpers used by BIN_OPS.
    # NOTE(review): presumably these cap the size of folded results the same
    # way CPython 3.7's AST optimizer does -- confirm against `.peephole`.
    MAX_INT_SIZE = 128
    MAX_COLLECTION_SIZE = 256
    MAX_STR_SIZE = 4096
    MAX_TOTAL_ITEMS = 1024


# Unary operator node type -> function applied to a constant operand.
UNARY_OPS = {
    ast.Invert: operator.invert,
    ast.Not: operator.not_,
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Comparison operator -> its logical negation; used to rewrite
# `not (a is b)` as `a is not b` (and likewise for `in`).
INVERSE_OPS: Dict[Type[cmpop], Type[cmpop]] = {
    ast.Is: ast.IsNot,
    ast.IsNot: ast.Is,
    ast.In: ast.NotIn,
    ast.NotIn: ast.In,
}

# Binary operator node type -> function applied to two constant operands.
# The operators that can blow up result size go through the size-limited
# safe_* helpers instead of the raw `operator` functions.
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: lambda lhs, rhs: safe_multiply(lhs, rhs, Py37Limits),
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: lambda lhs, rhs: safe_mod(lhs, rhs, Py37Limits),
    ast.Pow: lambda lhs, rhs: safe_power(lhs, rhs, Py37Limits),
    ast.LShift: lambda lhs, rhs: safe_lshift(lhs, rhs, Py37Limits),
    ast.RShift: operator.rshift,
    ast.BitOr: operator.or_,
    ast.BitXor: operator.xor,
    ast.BitAnd: operator.and_,
}

# Gates the 3.8+ behaviour of turning non-constant list iterables into tuples.
IS_PY38_ABOVE = sys.version_info >= (3, 8)


class AstOptimizer(ASTRewriter):
    """Constant-folding AST rewriter.

    Folds unary/binary operations, tuples and subscripts whose operands are
    constants; rewrites ``not (a <op> b)`` for is/in operators into the
    inverse comparison; turns list/set literals used as an iteration source
    (``for``, comprehensions, trailing ``in``/``not in``) into tuple/frozenset
    constants (or, on 3.8+, plain tuples); and replaces ``__debug__`` with a
    constant.

    Args:
        optimize: when true, ``__debug__`` folds to ``False`` and ``assert``
            statements are skipped (their visitor returns ``None``).
    """

    def __init__(self, optimize: bool = False):
        super().__init__()
        self.optimize = optimize

    def visitUnaryOp(self, node: ast.UnaryOp) -> ast.expr:
        """Fold a unary op on a constant, or invert ``not (a is/in b)``."""
        operand = self.visit(node.operand)
        if is_const(operand):
            conv = UNARY_OPS[type(node.op)]
            val = get_const_value(operand)
            try:
                return copy_location(Constant(conv(val)), node)
            except Exception:
                # Best effort: if evaluation fails (e.g. `~1.5`), keep the
                # expression unfolded so the error surfaces at run time.
                pass
        elif (
            isinstance(node.op, ast.Not)
            and isinstance(operand, ast.Compare)
            and len(operand.ops) == 1
        ):
            # Only single comparisons are safe to invert: `not (a < b < c)`
            # is not equivalent to inverting either operator.
            new_op = INVERSE_OPS.get(type(operand.ops[0]))
            if new_op is not None:
                return self.update_node(operand, ops=[new_op()])

        return self.update_node(node, operand=operand)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Fold a binary op whose operands are both constants."""
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            handler = BIN_OPS.get(type(node.op))
            if handler is not None:
                lval = get_const_value(left)
                rval = get_const_value(right)
                try:
                    return copy_location(Constant(handler(lval, rval)), node)
                except Exception:
                    # Best effort: failures (e.g. ZeroDivisionError, or a
                    # safe_* helper refusing an oversized result) leave the
                    # expression unfolded.
                    pass

        return self.update_node(node, left=left, right=right)

    def makeConstTuple(self, elts: Iterable[ast.expr]) -> Optional[Constant]:
        """Return a tuple ``Constant`` if every element is constant, else None.

        The result carries no source location; callers attach one.
        """
        # Materialize first: *elts* is walked twice below, which would silently
        # produce a wrong result for a one-shot iterator.
        elts = list(elts)
        if all(is_const(elt) for elt in elts):
            return Constant(tuple(get_const_value(elt) for elt in elts))

        return None

    def visitTuple(self, node: ast.Tuple) -> ast.expr:
        """Fold a load-context tuple of constants into a tuple ``Constant``."""
        elts = self.walk_list(node.elts)

        # Store/Del tuples are assignment targets and must stay tuples.
        if isinstance(node.ctx, ast.Load):
            res = self.makeConstTuple(elts)
            if res is not None:
                return copy_location(res, node)

        return self.update_node(node, elts=elts)

    def visitSubscript(self, node: ast.Subscript) -> ast.expr:
        """Fold ``const[const]`` in load context."""
        value = self.visit(node.value)
        index = self.visit(node.slice)

        # NOTE(review): since Python 3.9 the parser no longer wraps simple
        # subscripts in `ast.Index`, so this fold appears to fire only on
        # <= 3.8 ASTs -- confirm that is intended.
        if (
            isinstance(node.ctx, ast.Load)
            and is_const(value)
            and isinstance(index, ast.Index)
            and is_const(index.value)
        ):
            try:
                return copy_location(
                    Constant(get_const_value(value)[get_const_value(index.value)]),
                    node,
                )
            except Exception:
                # Best effort: e.g. an out-of-range index stays a run-time error.
                pass

        return self.update_node(node, value=value, slice=index)

    def _visitIter(self, node: ast.expr) -> ast.expr:
        """Optimize an *already visited* expression used as an iteration source.

        - A list of constants becomes a tuple ``Constant``.
        - On 3.8+, any other list without starred elements becomes a tuple.
        - A set of constants becomes a ``frozenset`` ``Constant``.
        Anything else is returned unchanged.

        Every caller has already visited *node*, so its children are not
        walked again here. Re-walking them did no extra folding and made
        nested comprehensions in iteration position cost exponentially more.
        """
        if isinstance(node, ast.List):
            elts = node.elts
            res = self.makeConstTuple(elts)
            if res is not None:
                return copy_location(res, node)
            if IS_PY38_ABOVE and not any(isinstance(e, ast.Starred) for e in elts):
                # Give the synthesized tuple the list's source location, as
                # every other node this optimizer creates gets.
                return copy_location(
                    self.update_node(ast.Tuple(elts=elts, ctx=node.ctx)), node
                )
            return node
        if isinstance(node, ast.Set):
            res = self.makeConstTuple(node.elts)
            if res is not None:
                return copy_location(Constant(frozenset(res.value)), node)
        return node

    def visitcomprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Visit a comprehension clause and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        ifs = self.walk_list(node.ifs)
        iterable = self._visitIter(iterable)

        return self.update_node(node, target=target, iter=iterable, ifs=ifs)

    def visitFor(self, node: ast.For) -> ast.For:
        """Visit a ``for`` statement and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        body = self.walk_list(node.body)
        orelse = self.walk_list(node.orelse)

        iterable = self._visitIter(iterable)
        return self.update_node(
            node, target=target, iter=iterable, body=body, orelse=orelse
        )

    def visitCompare(self, node: ast.Compare) -> ast.expr:
        """Visit a comparison; optimize the container of a trailing in/not in."""
        left = self.visit(node.left)
        comparators = self.walk_list(node.comparators)

        if isinstance(node.ops[-1], (ast.In, ast.NotIn)):
            new_iter = self._visitIter(comparators[-1])
            if new_iter is not comparators[-1]:
                # Copy before replacing so the list returned by walk_list
                # is not mutated in place.
                comparators = list(comparators)
                comparators[-1] = new_iter

        return self.update_node(node, left=left, comparators=comparators)

    def visitName(self, node: ast.Name):
        """Replace ``__debug__`` with ``Constant(not self.optimize)``."""
        if node.id == "__debug__":
            return copy_location(Constant(not self.optimize), node)

        return self.generic_visit(node)

    def visitAssert(self, node: ast.Assert):
        """Return None for asserts when optimizing; otherwise visit normally."""
        if self.optimize:
            # Skip asserts if we're optimizing
            return None
        return self.generic_visit(node)