this repo has no description
1# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2# pyre-unsafe
3import ast
4import operator
5import sys
6from ast import Bytes, Constant, Ellipsis, NameConstant, Num, Str, cmpop, copy_location
7from typing import Dict, Iterable, Optional, Type
8
9from .peephole import safe_lshift, safe_mod, safe_multiply, safe_power
10from .visitor import ASTRewriter
11
12
def is_const(node):
    """Return True when *node* is one of the constant-literal AST node types.

    Both the unified ``Constant`` node and the older per-kind literal nodes
    (``Num``, ``Str``, ``Bytes``, ``Ellipsis``, ``NameConstant``) count.
    """
    constant_node_types = (Constant, Num, Str, Bytes, Ellipsis, NameConstant)
    return isinstance(node, constant_node_types)
15
16
def get_const_value(node):
    """Return the Python value carried by a constant AST node.

    The value is read from ``.value`` for ``Constant``/``NameConstant``,
    ``.n`` for ``Num`` and ``.s`` for ``Str``/``Bytes``; an ``Ellipsis`` node
    yields the ``...`` singleton.

    Raises:
        TypeError: if *node* is none of the recognised constant node types.
    """
    # Checked in order; the first matching group decides which attribute
    # holds the literal value.
    attribute_by_node_type = (
        ((Constant, NameConstant), "value"),
        ((Num,), "n"),
        ((Str, Bytes), "s"),
    )
    for node_types, attribute in attribute_by_node_type:
        if isinstance(node, node_types):
            return getattr(node, attribute)

    if isinstance(node, Ellipsis):
        return ...

    raise TypeError("Bad constant value")
28
29
class Py37Limits:
    """Size caps passed to the ``safe_*`` folding helpers.

    NOTE(review): the name suggests these mirror the constant-folding limits
    of CPython 3.7; the exact units of each cap are defined by the helpers in
    ``.peephole`` -- confirm there.
    """

    MAX_INT_SIZE: int = 128
    MAX_COLLECTION_SIZE: int = 256
    MAX_STR_SIZE: int = 4096
    MAX_TOTAL_ITEMS: int = 1024
35
36
# Unary operator node type -> callable that evaluates it on a constant value.
UNARY_OPS = {
    ast.Invert: operator.invert,
    ast.Not: operator.not_,
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}

# Comparison operator node type -> its logical negation, used to rewrite
# ``not (a is b)`` / ``not (a in b)`` as a single comparison.
INVERSE_OPS: Dict[Type[cmpop], Type[cmpop]] = {
    ast.Is: ast.IsNot,
    ast.IsNot: ast.Is,
    ast.In: ast.NotIn,
    ast.NotIn: ast.In,
}

# Binary operator node type -> callable that evaluates it on two constant
# values.  Multiply, modulo, power and left shift go through the ``safe_*``
# helpers together with ``Py37Limits``; the others use the plain operator.
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: lambda lhs, rhs: safe_multiply(lhs, rhs, Py37Limits),
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: lambda lhs, rhs: safe_mod(lhs, rhs, Py37Limits),
    ast.Pow: lambda lhs, rhs: safe_power(lhs, rhs, Py37Limits),
    ast.LShift: lambda lhs, rhs: safe_lshift(lhs, rhs, Py37Limits),
    ast.RShift: operator.rshift,
    ast.BitOr: operator.or_,
    ast.BitXor: operator.xor,
    ast.BitAnd: operator.and_,
}

# True when running on Python 3.8 or newer.
IS_PY38_ABOVE = sys.version_info[:2] >= (3, 8)
66
67
class AstOptimizer(ASTRewriter):
    """AST rewriter that folds constant expressions.

    Folding is best-effort: when evaluating an operation on constants raises,
    the (already visited) node is kept unchanged so that the error surfaces
    when the code actually runs.

    Args:
        optimize: When true, ``__debug__`` folds to ``False`` and
            ``visitAssert`` returns ``None`` for ``assert`` statements
            (presumably the ``ASTRewriter`` base drops such statements --
            confirm in ``.visitor``).
    """

    def __init__(self, optimize: bool = False):
        super().__init__()
        self.optimize = optimize

    def visitUnaryOp(self, node: ast.UnaryOp) -> ast.expr:
        """Fold a unary operator on a constant, or merge ``not`` into a comparison.

        ``not (a is b)``, ``not (a is not b)``, ``not (a in b)`` and
        ``not (a not in b)`` become the single inverse comparison.
        """
        operand = self.visit(node.operand)
        if is_const(operand):
            fold = UNARY_OPS[type(node.op)]
            value = get_const_value(operand)
            try:
                return copy_location(Constant(fold(value)), node)
            except Exception:
                # Deliberate: e.g. ``-"a"`` must fail at run time, not here.
                pass
        elif (
            isinstance(node.op, ast.Not)
            and isinstance(operand, ast.Compare)
            and len(operand.ops) == 1
        ):
            inverse = INVERSE_OPS.get(type(operand.ops[0]))
            if inverse is not None:
                return self.update_node(operand, ops=[inverse()])

        return self.update_node(node, operand=operand)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Fold a binary operator whose operands are both constants."""
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            fold = BIN_OPS.get(type(node.op))
            if fold is not None:
                left_value = get_const_value(left)
                right_value = get_const_value(right)
                try:
                    return copy_location(
                        Constant(fold(left_value, right_value)), node
                    )
                except Exception:
                    # Deliberate: e.g. ``1 / 0`` must fail at run time.
                    pass

        return self.update_node(node, left=left, right=right)

    def makeConstTuple(self, elts: Iterable[ast.expr]) -> Optional[Constant]:
        """Return a tuple ``Constant`` if every element is constant, else None.

        *elts* is consumed in a single pass so that one-shot iterators, which
        the ``Iterable`` annotation permits, are handled correctly.
        """
        values = []
        for elt in elts:
            if not is_const(elt):
                return None
            values.append(get_const_value(elt))
        return Constant(tuple(values))

    def visitTuple(self, node: ast.Tuple) -> ast.expr:
        """Fold a tuple display of constants (load context only)."""
        elts = self.walk_list(node.elts)

        if isinstance(node.ctx, ast.Load):
            folded = self.makeConstTuple(elts)
            if folded is not None:
                return copy_location(folded, node)

        return self.update_node(node, elts=elts)

    def visitSubscript(self, node: ast.Subscript) -> ast.expr:
        """Fold ``constant[constant]`` in load context."""
        value = self.visit(node.value)
        index = self.visit(node.slice)

        # Before Python 3.9 a simple index is wrapped in ``ast.Index``; from
        # 3.9 on the slice field holds the index expression directly.
        index_expr = index.value if isinstance(index, ast.Index) else index

        if (
            isinstance(node.ctx, ast.Load)
            and is_const(value)
            and is_const(index_expr)
        ):
            try:
                item = get_const_value(value)[get_const_value(index_expr)]
                return copy_location(Constant(item), node)
            except Exception:
                # Deliberate: out-of-range or invalid subscripts must fail at
                # run time.
                pass

        return self.update_node(node, value=value, slice=index)

    def _visitIter(self, node: ast.expr) -> ast.expr:
        """Rewrite a list/set display that is only iterated or tested with ``in``.

        A list of constants becomes a tuple constant and a set of constants a
        frozenset constant.  On Python 3.8+ a non-constant list without
        starred elements becomes a tuple display.
        """
        if isinstance(node, ast.List):
            elts = self.walk_list(node.elts)
            folded = self.makeConstTuple(elts)
            if folded is not None:
                return copy_location(folded, node)
            if IS_PY38_ABOVE and not any(isinstance(e, ast.Starred) for e in elts):
                # The replacement is a brand-new node, so give it the source
                # position of the list it stands in for.
                replacement = copy_location(
                    ast.Tuple(elts=elts, ctx=node.ctx), node
                )
                return self.update_node(replacement)
            return self.update_node(node, elts=elts)

        if isinstance(node, ast.Set):
            elts = self.walk_list(node.elts)
            folded = self.makeConstTuple(elts)
            if folded is not None:
                return copy_location(Constant(frozenset(folded.value)), node)
            return self.update_node(node, elts=elts)

        return self.generic_visit(node)

    def visitcomprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Visit a comprehension clause and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        ifs = self.walk_list(node.ifs)
        iterable = self._visitIter(iterable)

        return self.update_node(node, target=target, iter=iterable, ifs=ifs)

    def visitFor(self, node: ast.For) -> ast.For:
        """Visit a ``for`` statement and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        body = self.walk_list(node.body)
        orelse = self.walk_list(node.orelse)

        iterable = self._visitIter(iterable)
        return self.update_node(
            node, target=target, iter=iterable, body=body, orelse=orelse
        )

    def visitCompare(self, node: ast.Compare) -> ast.expr:
        """Visit a comparison; optimize the container of a trailing ``in``/``not in``."""
        left = self.visit(node.left)
        comparators = self.walk_list(node.comparators)

        if isinstance(node.ops[-1], (ast.In, ast.NotIn)):
            container = self._visitIter(comparators[-1])
            if container is not None and container is not comparators[-1]:
                # Copy before mutating so the visited list is left untouched.
                comparators = list(comparators)
                comparators[-1] = container

        return self.update_node(node, left=left, comparators=comparators)

    def visitName(self, node: ast.Name):
        """Fold a load of ``__debug__`` to ``not self.optimize``."""
        # Only loads are folded: swapping a Store/Del target for a Constant
        # would leave a constant as an assignment or deletion target.
        if node.id == "__debug__" and isinstance(node.ctx, ast.Load):
            return copy_location(Constant(not self.optimize), node)

        return self.generic_visit(node)

    def visitAssert(self, node: ast.Assert):
        """Return ``None`` for asserts when optimizing, else visit normally."""
        if self.optimize:
            # Skip asserts if we're optimizing
            return None
        return self.generic_visit(node)
205 return self.generic_visit(node)