this repo has no description
at trunk 488 lines 12 kB view raw
1# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) 2# pyre-strict 3from __future__ import annotations 4 5import array 6import random 7from types import FunctionType, Union as typesUnion 8from typing import ( 9 _GenericAlias, 10 Dict, 11 Iterable, 12 Mapping, 13 Type, 14 TypeVar, 15 Tuple, 16 Union, 17 _tp_cache, 18) 19from weakref import WeakValueDictionary 20 21try: 22 import _static 23except ImportError: 24 25 def is_type_static(_t): 26 return False 27 28 def set_type_static(_t): 29 return None 30 31 _static = None 32 chkdict = dict 33 34else: 35 chkdict = _static.chkdict 36 set_type_code = _static.set_type_code 37 is_type_static = _static.is_type_static 38 set_type_static = _static.set_type_static 39 40try: 41 from _static import ( 42 TYPED_INT8, 43 TYPED_INT16, 44 TYPED_INT32, 45 TYPED_INT64, 46 TYPED_UINT8, 47 TYPED_UINT16, 48 TYPED_UINT32, 49 TYPED_UINT64, 50 TYPED_DOUBLE, 51 TYPED_SINGLE, 52 TYPED_BOOL, 53 TYPED_CHAR, 54 RAND_MAX, 55 rand, 56 ) 57except ImportError: 58 TYPED_INT8 = 0 59 TYPED_INT16 = 0 60 TYPED_INT32 = 0 61 TYPED_INT64 = 0 62 TYPED_UINT8 = 0 63 TYPED_UINT16 = 0 64 TYPED_UINT32 = 0 65 TYPED_UINT64 = 0 66 TYPED_DOUBLE = 0 67 TYPED_SINGLE = 0 68 TYPED_BOOL = 0 69 TYPED_CHAR = 0 70 RAND_MAX = (1 << 31) - 1 71 72 def rand(): 73 return random.randint(0, RAND_MAX) 74 75 76try: 77 import cinder 78except ImportError: 79 cinder = None 80 81 82def type_code(code: int): 83 def inner(c): 84 if _static is not None: 85 _static.set_type_code(c, code) 86 return c 87 88 return inner 89 90 91pydict = dict 92PyDict = Dict 93 94clen = len 95 96 97@type_code(TYPED_UINT64) 98class size_t(int): 99 pass 100 101 102@type_code(TYPED_INT64) 103class ssize_t(int): 104 pass 105 106 107@type_code(TYPED_INT8) 108class int8(int): 109 pass 110 111 112byte = int8 113 114 115@type_code(TYPED_INT16) 116class int16(int): 117 pass 118 119 120@type_code(TYPED_INT32) 121class int32(int): 122 pass 123 124 125@type_code(TYPED_INT64) 126class int64(int): 127 pass 128 129 130@type_code(TYPED_UINT8) 131class uint8(int): 132 pass 133 134 135@type_code(TYPED_UINT16) 136class uint16(int): 137 pass 138 139 140@type_code(TYPED_UINT32) 141class uint32(int): 142 pass 143 144 145@type_code(TYPED_UINT64) 146class uint64(int): 147 pass 148 149 150@type_code(TYPED_SINGLE) 151class single(float): 152 pass 153 154 155@type_code(TYPED_DOUBLE) 156class double(float): 157 pass 158 159 160@type_code(TYPED_CHAR) 161class char(int): 162 pass 163 164 165@type_code(TYPED_BOOL) 166class cbool(int8): 167 pass 168 169 170ArrayElement = TypeVar( 171 "ArrayElement", 172 int8, 173 int16, 174 int32, 175 int64, 176 uint8, 177 uint16, 178 uint32, 179 uint64, 180 char, 181 float, 182 double, 183) 184 185_TYPE_SIZES = {tc: array.array(tc).itemsize for tc in array.typecodes} 186 187# These should be in sync with the array module 188_TYPE_CODES = { 189 int8: "b", 190 uint8: "B", 191 int16: "h", 192 uint16: "H", 193 # apparently, l is equivalent to q for us, but that may not be true everywhere. 194 int32: "i" if _TYPE_SIZES["i"] == 4 else "l", 195 uint32: "I" if _TYPE_SIZES["I"] == 4 else "L", 196 int64: "q", 197 uint64: "Q", 198 float: "f", 199 double: "d", 200 char: "B", 201} 202 203TVarOrType = Union[TypeVar, Type[object]] 204 205 206def _subs_tvars( 207 tp: Tuple[TVarOrType, ...], 208 tvars: Tuple[TVarOrType, ...], 209 subs: Tuple[TVarOrType, ...], 210) -> Type[object]: 211 """Substitute type variables 'tvars' with substitutions 'subs'. 212 These two must have the same length. 213 """ 214 if not hasattr(tp, "__args__"): 215 return tp 216 217 new_args = list(tp.__args__) 218 for a, arg in enumerate(tp.__args__): 219 if isinstance(arg, TypeVar): 220 for i, tvar in enumerate(tvars): 221 if arg == tvar: 222 if ( 223 tvar.__constraints__ 224 and not isinstance(subs[i], TypeVar) 225 and not issubclass(subs[i], tvar.__constraints__) 226 ): 227 raise TypeError( 228 f"Invalid type for {tvar.__name__}: {subs[i].__name__} when instantiating {tp.__name__}" 229 ) 230 231 new_args[a] = subs[i] 232 else: 233 new_args[a] = _subs_tvars(arg, tvars, subs) 234 235 return _replace_types(tp, tuple(new_args)) 236 237 238def _collect_type_vars(types: Tuple[TVarOrType, ...]) -> Tuple[TypeVar, ...]: 239 """Collect all type variable contained in types in order of 240 first appearance (lexicographic order). For example:: 241 242 _collect_type_vars((T, List[S, T])) == (T, S) 243 """ 244 tvars = [] 245 for t in types: 246 if isinstance(t, TypeVar) and t not in tvars: 247 tvars.append(t) 248 if hasattr(t, "__parameters__"): 249 tvars.extend([t for t in t.__parameters__ if t not in tvars]) 250 return tuple(tvars) 251 252 253def make_generic_type( 254 gen_type: Type[object], params: Tuple[Type[object], ...] 255) -> Type[object]: 256 if len(params) != len(gen_type.__parameters__): 257 raise TypeError(f"Incorrect number of type arguments for {gen_type.__name__}") 258 259 # Substitute params into __args__ replacing instances of __parameters__ 260 return _subs_tvars( 261 gen_type, 262 gen_type.__parameters__, 263 params, 264 ) 265 266 267def _replace_types( 268 gen_type: Type[object], subs: Tuple[Type[object], ...] 269) -> Type[object]: 270 existing_inst = gen_type.__origin__.__insts__.get(subs) 271 272 if existing_inst is not None: 273 return existing_inst 274 275 # Check if we have a full instantation, and verify the constraints 276 new_dict = dict(gen_type.__dict__) 277 has_params = False 278 for sub in subs: 279 if isinstance(sub, TypeVar) or hasattr(sub, "__parameters__"): 280 has_params = True 281 continue 282 283 # Remove the existing StaticGeneric base... 284 bases = tuple( 285 base for base in gen_type.__orig_bases__ if not isinstance(base, StaticGeneric) 286 ) 287 288 new_dict["__args__"] = subs 289 if not has_params: 290 # Instantiated types don't have generic parameters anymore. 291 del new_dict["__parameters__"] 292 else: 293 new_vars = _collect_type_vars(subs) 294 new_gen = StaticGeneric() 295 new_gen.__parameters__ = new_vars 296 new_dict["__orig_bases__"] = bases + (new_gen,) 297 bases += (StaticGeneric,) 298 new_dict["__parameters__"] = new_vars 299 300 # Eventually we'll want to have some processing of the members here to 301 # bind the generics through. That may be an actual process which creates 302 # new objects with the generics bound, or a virtual process. For now 303 # we just propagate the members to the new type. 304 param_names = ", ".join(param.__name__ for param in subs) 305 306 res = type(f"{gen_type.__origin__.__name__}[{param_names}]", bases, new_dict) 307 res.__origin__ = gen_type 308 309 if not has_params: 310 # specialize the type 311 for name, value in new_dict.items(): 312 if isinstance(value, FunctionType): 313 if hasattr(value, "__runtime_impl__"): 314 setattr( 315 res, 316 name, 317 _static.specialize_function(res, value.__qualname__, subs), 318 ) 319 320 if cinder is not None: 321 cinder.freeze_type(res) 322 323 gen_type.__origin__.__insts__[subs] = res 324 return res 325 326 327def _runtime_impl(f): 328 """marks a generic function as being runtime-implemented""" 329 f.__runtime_impl__ = True 330 return f 331 332 333class StaticGeneric: 334 """Base type used to mark static-Generic classes. Instantations of these 335 classes share different generic types and the generic type arguments can 336 be accessed via __args___""" 337 338 @_tp_cache 339 def __class_getitem__( 340 cls, elem_type: Tuple[Union[TypeVar, Type[object]]] 341 ) -> Union[StaticGeneric, Type[object]]: 342 if not isinstance(elem_type, tuple): 343 # we specifically recurse to hit the type cache 344 return cls[ 345 elem_type, 346 ] 347 348 if cls is StaticGeneric: 349 res = StaticGeneric() 350 res.__parameters__ = elem_type 351 return res 352 353 return make_generic_type(cls, elem_type) 354 355 def __init_subclass__(cls) -> None: 356 type_vars = _collect_type_vars(cls.__orig_bases__) 357 cls.__origin__ = cls 358 cls.__parameters__ = type_vars 359 if not hasattr(cls, "__args__"): 360 cls.__args__ = type_vars 361 cls.__insts__ = WeakValueDictionary() 362 363 def __mro_entries__(self, bases) -> Tuple[Type[object, ...]]: 364 return (StaticGeneric,) 365 366 def __repr__(self) -> str: 367 return ( 368 "<StaticGeneric: " 369 + ", ".join([param.__name__ for param in self.__parameters__]) 370 + ">" 371 ) 372 373 374class Array(array.array, StaticGeneric[ArrayElement]): 375 def __new__(cls, initializer: int | Iterable[ArrayElement]): 376 if hasattr(cls, "__parameters__"): 377 raise TypeError("Cannot create plain Array") 378 379 typecode = _TYPE_CODES[cls.__args__[0]] 380 if isinstance(initializer, int): 381 res = array.array.__new__(cls, typecode, [0]) 382 res *= initializer 383 return res 384 else: 385 return array.array.__new__(cls, typecode, initializer) 386 387 def __init_subclass__(cls): 388 raise TypeError("Cannot subclass Array") 389 390 def __getitem__(self, index): 391 if isinstance(index, slice): 392 return type(self)(array.array.__getitem__(self, index)) 393 394 return array.array.__getitem__(self, index) 395 396 def __deepcopy__(self, memo): 397 return type(self)(self) 398 399 400class Vector(array.array, StaticGeneric[ArrayElement]): 401 """Vector is a resizable array of primitive elements""" 402 403 def __new__(cls, initializer: int | Iterable[ArrayElement] | None = None): 404 if hasattr(cls, "__parameters__"): 405 raise TypeError("Cannot create plain Vector") 406 407 typecode = _TYPE_CODES[cls.__args__[0]] 408 if isinstance(initializer, int): 409 # specifing size 410 res = array.array.__new__(cls, typecode, [0]) 411 res *= initializer 412 return res 413 elif initializer is not None: 414 return array.array.__new__(cls, typecode, initializer) 415 else: 416 return array.array.__new__(cls, typecode) 417 418 if _static is not None: 419 420 @_runtime_impl 421 def append(self, value: ArrayElement) -> None: 422 super().append(value) 423 424 def __init_subclass__(cls): 425 raise TypeError("Cannot subclass Vector") 426 427 def __getitem__(self, index): 428 if isinstance(index, slice): 429 return type(self)(array.array.__getitem__(self, index)) 430 431 return array.array.__getitem__(self, index) 432 433 def __deepcopy__(self, memo): 434 return type(self)(self) 435 436 437def box(o): 438 return o 439 440 441def unbox(o): 442 return o 443 444 445def allow_weakrefs(klass): 446 return klass 447 448 449def dynamic_return(func): 450 return func 451 452 453def inline(func): 454 return func 455 456 457def _donotcompile(func): 458 return func 459 460 461def cast(typ, val): 462 union_args = None 463 if type(typ) is _GenericAlias: 464 typ, args = typ.__origin__, typ.__args__ 465 if typ is Union: 466 union_args = args 467 elif type(typ) is typesUnion: 468 union_args = typ.__args__ 469 if union_args: 470 typ = None 471 if len(union_args) == 2: 472 if union_args[0] is type(None): 473 typ = union_args[1] 474 elif union_args[1] is type(None): 475 typ = union_args[0] 476 if typ is None: 477 raise ValueError("cast expects type or Optional[T]") 478 if val is None: 479 return None 480 481 inst_type = type(val) 482 if typ not in inst_type.__mro__: 483 raise TypeError(f"expected {typ.__name__}, got {type(val).__name__}") 484 485 return val 486 487 488CheckedDict = chkdict