this repo has no description
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2# pyre-strict
3from __future__ import annotations
4
5import array
6import random
7from types import FunctionType, Union as typesUnion
8from typing import (
9 _GenericAlias,
10 Dict,
11 Iterable,
12 Mapping,
13 Type,
14 TypeVar,
15 Tuple,
16 Union,
17 _tp_cache,
18)
19from weakref import WeakValueDictionary
20
# Bind the native `_static` extension when it is available; otherwise provide
# pure-Python stand-ins so that every name exported by the native branch is
# also importable from the fallback branch.
try:
    import _static
except ImportError:

    def is_type_static(_t):
        """Fallback: without `_static`, no type is reported as static."""
        return False

    def set_type_static(_t):
        """Fallback: accept the type and do nothing."""
        return None

    def set_type_code(_t, _code):
        """Fallback: accept a type and a type code and do nothing.

        Mirrors the two-argument call made by `type_code`, so the name exists
        in both branches of this import guard.
        """
        return None

    _static = None
    chkdict = dict

else:
    chkdict = _static.chkdict
    set_type_code = _static.set_type_code
    is_type_static = _static.is_type_static
    set_type_static = _static.set_type_static
39
# Primitive type codes, RAND_MAX and rand() come from the native `_static`
# module when it can be imported.
try:
    from _static import (
        RAND_MAX,
        TYPED_BOOL,
        TYPED_CHAR,
        TYPED_DOUBLE,
        TYPED_INT8,
        TYPED_INT16,
        TYPED_INT32,
        TYPED_INT64,
        TYPED_SINGLE,
        TYPED_UINT8,
        TYPED_UINT16,
        TYPED_UINT32,
        TYPED_UINT64,
        rand,
    )
except ImportError:
    # Without `_static` every type code falls back to the same value, 0.
    TYPED_INT8 = TYPED_INT16 = TYPED_INT32 = TYPED_INT64 = 0
    TYPED_UINT8 = TYPED_UINT16 = TYPED_UINT32 = TYPED_UINT64 = 0
    TYPED_DOUBLE = TYPED_SINGLE = 0
    TYPED_BOOL = TYPED_CHAR = 0
    RAND_MAX = (1 << 31) - 1

    def rand():
        """Return a random integer in the inclusive range [0, RAND_MAX]."""
        return random.randint(0, RAND_MAX)
74
75
# `cinder` is optional: when the import fails the name is bound to None so
# later code can test `cinder is not None` before calling into it.
try:
    import cinder
except ImportError:
    cinder = None
80
81
def type_code(code: int):
    """Build a class decorator that associates `code` with the decorated class.

    The returned decorator calls `_static.set_type_code(cls, code)` when the
    `_static` module is available and always returns the class it was given.
    """

    def register(klass):
        if _static is None:
            # No native module: nothing to record, hand the class back as-is.
            return klass
        _static.set_type_code(klass, code)
        return klass

    return register
89
90
# Aliases for the builtin `dict`, `typing.Dict` and the builtin `len`.
pydict = dict
PyDict = Dict

clen = len
95
96
@type_code(TYPED_UINT64)
class size_t(int):
    """`int` subclass tagged with the TYPED_UINT64 type code via `type_code`."""

    pass
100
101
@type_code(TYPED_INT64)
class ssize_t(int):
    """`int` subclass tagged with the TYPED_INT64 type code via `type_code`."""

    pass
105
106
@type_code(TYPED_INT8)
class int8(int):
    """`int` subclass tagged with the TYPED_INT8 type code via `type_code`."""

    pass


# `byte` is another name for `int8`.
byte = int8
113
114
@type_code(TYPED_INT16)
class int16(int):
    """`int` subclass tagged with the TYPED_INT16 type code via `type_code`."""

    pass
118
119
@type_code(TYPED_INT32)
class int32(int):
    """`int` subclass tagged with the TYPED_INT32 type code via `type_code`."""

    pass
123
124
@type_code(TYPED_INT64)
class int64(int):
    """`int` subclass tagged with the TYPED_INT64 type code via `type_code`."""

    pass
128
129
@type_code(TYPED_UINT8)
class uint8(int):
    """`int` subclass tagged with the TYPED_UINT8 type code via `type_code`."""

    pass
133
134
@type_code(TYPED_UINT16)
class uint16(int):
    """`int` subclass tagged with the TYPED_UINT16 type code via `type_code`."""

    pass
138
139
@type_code(TYPED_UINT32)
class uint32(int):
    """`int` subclass tagged with the TYPED_UINT32 type code via `type_code`."""

    pass
143
144
@type_code(TYPED_UINT64)
class uint64(int):
    """`int` subclass tagged with the TYPED_UINT64 type code via `type_code`."""

    pass
148
149
@type_code(TYPED_SINGLE)
class single(float):
    """`float` subclass tagged with the TYPED_SINGLE type code via `type_code`."""

    pass
153
154
@type_code(TYPED_DOUBLE)
class double(float):
    """`float` subclass tagged with the TYPED_DOUBLE type code via `type_code`."""

    pass
158
159
@type_code(TYPED_CHAR)
class char(int):
    """`int` subclass tagged with the TYPED_CHAR type code via `type_code`."""

    pass
163
164
@type_code(TYPED_BOOL)
class cbool(int8):
    """`int8` subclass tagged with the TYPED_BOOL type code via `type_code`."""

    pass
168
169
# Constrained type variable used as the element parameter of Array and Vector.
# `_subs_tvars` rejects substitutions that are not subclasses of one of these
# constraints.
ArrayElement = TypeVar(
    "ArrayElement",
    int8,
    int16,
    int32,
    int64,
    uint8,
    uint16,
    uint32,
    uint64,
    char,
    float,
    double,
)
184
# Item size, as reported by the array module, for every type code it supports.
_TYPE_SIZES = {tc: array.array(tc).itemsize for tc in array.typecodes}

# Maps each element type to the array-module type code used to store it.
# These should be in sync with the array module.
_TYPE_CODES = {
    int8: "b",
    uint8: "B",
    int16: "h",
    uint16: "H",
    # Use "i"/"I" when their item size is 4, otherwise fall back to "l"/"L".
    # Apparently, l is equivalent to q for us, but that may not be true
    # everywhere.
    int32: "i" if _TYPE_SIZES["i"] == 4 else "l",
    uint32: "I" if _TYPE_SIZES["I"] == 4 else "L",
    int64: "q",
    uint64: "Q",
    float: "f",
    double: "d",
    char: "B",
}
202
# Annotation alias: either a type variable or a concrete type.
TVarOrType = Union[TypeVar, Type[object]]
204
205
206def _subs_tvars(
207 tp: Tuple[TVarOrType, ...],
208 tvars: Tuple[TVarOrType, ...],
209 subs: Tuple[TVarOrType, ...],
210) -> Type[object]:
211 """Substitute type variables 'tvars' with substitutions 'subs'.
212 These two must have the same length.
213 """
214 if not hasattr(tp, "__args__"):
215 return tp
216
217 new_args = list(tp.__args__)
218 for a, arg in enumerate(tp.__args__):
219 if isinstance(arg, TypeVar):
220 for i, tvar in enumerate(tvars):
221 if arg == tvar:
222 if (
223 tvar.__constraints__
224 and not isinstance(subs[i], TypeVar)
225 and not issubclass(subs[i], tvar.__constraints__)
226 ):
227 raise TypeError(
228 f"Invalid type for {tvar.__name__}: {subs[i].__name__} when instantiating {tp.__name__}"
229 )
230
231 new_args[a] = subs[i]
232 else:
233 new_args[a] = _subs_tvars(arg, tvars, subs)
234
235 return _replace_types(tp, tuple(new_args))
236
237
238def _collect_type_vars(types: Tuple[TVarOrType, ...]) -> Tuple[TypeVar, ...]:
239 """Collect all type variable contained in types in order of
240 first appearance (lexicographic order). For example::
241
242 _collect_type_vars((T, List[S, T])) == (T, S)
243 """
244 tvars = []
245 for t in types:
246 if isinstance(t, TypeVar) and t not in tvars:
247 tvars.append(t)
248 if hasattr(t, "__parameters__"):
249 tvars.extend([t for t in t.__parameters__ if t not in tvars])
250 return tuple(tvars)
251
252
def make_generic_type(
    gen_type: Type[object], params: Tuple[Type[object], ...]
) -> Type[object]:
    """Instantiate `gen_type` with the type arguments `params`.

    Raises:
        TypeError: if the number of `params` differs from the number of
            entries in `gen_type.__parameters__`.
    """
    expected = gen_type.__parameters__
    if len(params) != len(expected):
        raise TypeError(f"Incorrect number of type arguments for {gen_type.__name__}")

    # Substitute params into __args__, replacing instances of __parameters__.
    return _subs_tvars(gen_type, gen_type.__parameters__, params)
265
266
def _replace_types(
    gen_type: Type[object], subs: Tuple[Type[object], ...]
) -> Type[object]:
    """Return the instantiation of `gen_type` whose type arguments are `subs`.

    Instantiations are cached in `gen_type.__origin__.__insts__`, keyed by
    `subs`; a cached entry is returned when one exists. Otherwise a new type
    is built from a copy of `gen_type.__dict__`, stored in the cache and
    returned.
    """
    existing_inst = gen_type.__origin__.__insts__.get(subs)

    if existing_inst is not None:
        return existing_inst

    # Check if we have a full instantiation: any TypeVar, or any argument that
    # still has __parameters__, means the result remains generic.
    new_dict = dict(gen_type.__dict__)
    has_params = False
    for sub in subs:
        if isinstance(sub, TypeVar) or hasattr(sub, "__parameters__"):
            has_params = True
            continue

    # Remove the existing StaticGeneric base...
    bases = tuple(
        base for base in gen_type.__orig_bases__ if not isinstance(base, StaticGeneric)
    )

    new_dict["__args__"] = subs
    if not has_params:
        # Instantiated types don't have generic parameters anymore.
        del new_dict["__parameters__"]
    else:
        # Still generic: attach a fresh StaticGeneric marker carrying the
        # remaining type variables and keep StaticGeneric as a real base.
        new_vars = _collect_type_vars(subs)
        new_gen = StaticGeneric()
        new_gen.__parameters__ = new_vars
        new_dict["__orig_bases__"] = bases + (new_gen,)
        bases += (StaticGeneric,)
        new_dict["__parameters__"] = new_vars

    # Eventually we'll want to have some processing of the members here to
    # bind the generics through. That may be an actual process which creates
    # new objects with the generics bound, or a virtual process. For now
    # we just propagate the members to the new type.
    param_names = ", ".join(param.__name__ for param in subs)

    res = type(f"{gen_type.__origin__.__name__}[{param_names}]", bases, new_dict)
    res.__origin__ = gen_type

    if not has_params:
        # Specialize the type: functions marked with __runtime_impl__ are
        # replaced by whatever _static.specialize_function returns for them.
        for name, value in new_dict.items():
            if isinstance(value, FunctionType):
                if hasattr(value, "__runtime_impl__"):
                    setattr(
                        res,
                        name,
                        _static.specialize_function(res, value.__qualname__, subs),
                    )

    # NOTE(review): assumed to apply to every newly created type, not only to
    # full instantiations -- confirm against the original indentation.
    if cinder is not None:
        cinder.freeze_type(res)

    gen_type.__origin__.__insts__[subs] = res
    return res
325
326
327def _runtime_impl(f):
328 """marks a generic function as being runtime-implemented"""
329 f.__runtime_impl__ = True
330 return f
331
332
class StaticGeneric:
    """Base type used to mark static-generic classes.

    Each instantiation of such a class is a distinct type, and its generic
    type arguments can be accessed via ``__args__``.
    """

    @_tp_cache
    def __class_getitem__(
        cls, elem_type: Tuple[Union[TypeVar, Type[object]]]
    ) -> Union[StaticGeneric, Type[object]]:
        """Handle ``cls[...]`` subscription.

        Subscribing `StaticGeneric` itself returns a marker instance that
        carries the given parameters; subscribing a subclass returns the
        result of `make_generic_type` for those arguments.
        """
        if not isinstance(elem_type, tuple):
            # we specifically recurse to hit the type cache
            return cls[
                elem_type,
            ]

        if cls is StaticGeneric:
            res = StaticGeneric()
            res.__parameters__ = elem_type
            return res

        return make_generic_type(cls, elem_type)

    def __init_subclass__(cls) -> None:
        """Initialise the generic bookkeeping attributes on a new subclass."""
        type_vars = _collect_type_vars(cls.__orig_bases__)
        cls.__origin__ = cls
        cls.__parameters__ = type_vars
        if not hasattr(cls, "__args__"):
            cls.__args__ = type_vars
        # NOTE(review): assumed to run for every subclass, not only when
        # __args__ is missing -- confirm against the original indentation.
        cls.__insts__ = WeakValueDictionary()

    def __mro_entries__(self, bases) -> Tuple[Type[object], ...]:
        """Substitute the `StaticGeneric` class for a marker instance used as a base."""
        return (StaticGeneric,)

    def __repr__(self) -> str:
        """Render the marker as ``<StaticGeneric: P1, P2>`` from its parameter names."""
        return (
            "<StaticGeneric: "
            + ", ".join([param.__name__ for param in self.__parameters__])
            + ">"
        )
372
373
class Array(array.array, StaticGeneric[ArrayElement]):
    """array.array subclass parameterized by an element type.

    Only instantiated forms (classes without `__parameters__`) can be created.
    """

    def __new__(cls, initializer: int | Iterable[ArrayElement]):
        if hasattr(cls, "__parameters__"):
            raise TypeError("Cannot create plain Array")

        code = _TYPE_CODES[cls.__args__[0]]
        if not isinstance(initializer, int):
            return array.array.__new__(cls, code, initializer)

        # An int initializer is a length: repeat a single zero that many times.
        zeros = array.array.__new__(cls, code, [0])
        zeros *= initializer
        return zeros

    def __init_subclass__(cls):
        raise TypeError("Cannot subclass Array")

    def __getitem__(self, index):
        item = array.array.__getitem__(self, index)
        # Slices are re-wrapped so they keep this instantiated type.
        return type(self)(item) if isinstance(index, slice) else item

    def __deepcopy__(self, memo):
        duplicate = type(self)(self)
        return duplicate
398
399
class Vector(array.array, StaticGeneric[ArrayElement]):
    """Vector is a resizable array of primitive elements"""

    def __new__(cls, initializer: int | Iterable[ArrayElement] | None = None):
        if hasattr(cls, "__parameters__"):
            raise TypeError("Cannot create plain Vector")

        code = _TYPE_CODES[cls.__args__[0]]
        if initializer is None:
            # No initializer: start empty.
            return array.array.__new__(cls, code)

        if isinstance(initializer, int):
            # An int initializer is a size: repeat a single zero that many times.
            zeros = array.array.__new__(cls, code, [0])
            zeros *= initializer
            return zeros

        return array.array.__new__(cls, code, initializer)

    if _static is not None:

        @_runtime_impl
        def append(self, value: ArrayElement) -> None:
            super().append(value)

    def __init_subclass__(cls):
        raise TypeError("Cannot subclass Vector")

    def __getitem__(self, index):
        item = array.array.__getitem__(self, index)
        # Slices are re-wrapped so they keep this instantiated type.
        return type(self)(item) if isinstance(index, slice) else item

    def __deepcopy__(self, memo):
        duplicate = type(self)(self)
        return duplicate
435
436
def box(o):
    """Return `o` unchanged."""
    return o
439
440
def unbox(o):
    """Return `o` unchanged."""
    return o
443
444
def allow_weakrefs(klass):
    """Class decorator that returns `klass` unchanged."""
    return klass
447
448
def dynamic_return(func):
    """Function decorator that returns `func` unchanged."""
    return func
451
452
def inline(func):
    """Function decorator that returns `func` unchanged."""
    return func
455
456
def _donotcompile(func):
    """Function decorator that returns `func` unchanged."""
    return func
459
460
def cast(typ, val):
    """Return `val` unchanged after checking it against `typ` at runtime.

    `typ` may be a plain class, a parameterized generic (checked against its
    origin class), or an optional type written as `Optional[T]`,
    `Union[T, None]` or `T | None`.

    Args:
        typ: the expected type.
        val: the value to check.

    Returns:
        `val` itself, or None when `typ` is optional and `val` is None.

    Raises:
        ValueError: if `typ` is a union other than `Optional[T]`.
        TypeError: if `typ` is not found in the MRO of `type(val)`.
    """
    # Local import: these helpers are not among the module-level typing imports.
    from typing import get_args, get_origin

    # get_origin recognizes every alias flavour (including the dedicated
    # union alias class used by newer Python versions), which an exact
    # `type(typ) is _GenericAlias` comparison does not.
    origin = get_origin(typ)
    union_args = None
    if origin is not None:
        if origin is Union or origin is typesUnion:
            union_args = get_args(typ)
        else:
            # Parameterized generic such as List[int]: check the runtime class.
            typ = origin

    if union_args:
        typ = None
        if len(union_args) == 2:
            first, second = union_args
            if first is type(None):
                typ = second
            elif second is type(None):
                typ = first
        if typ is None:
            raise ValueError("cast expects type or Optional[T]")
        if val is None:
            return None

    if typ not in type(val).__mro__:
        raise TypeError(f"expected {typ.__name__}, got {type(val).__name__}")

    return val
486
487
# Public alias for `chkdict` (`_static.chkdict`, or the builtin dict when
# `_static` is unavailable).
CheckedDict = chkdict