this repo has no description
at trunk 530 lines 23 kB view raw
#!/usr/bin/env python3
import dataclasses
import io
import itertools
import json
import os
import typing

from typing import Dict, Optional, Tuple

from scrapscript import (
    Access,
    Apply,
    Assign,
    Binop,
    BinopKind,
    Function,
    Hole,
    Int,
    List,
    MatchFunction,
    Object,
    Record,
    Spread,
    String,
    Var,
    Variant,
    Where,
    free_in,
    type_of,
    IntType,
    StringType,
    parse,  # needed for /compilerepl
    tokenize,  # needed for /compilerepl
)

# Maps a source-level variable name to the C expression that holds its value.
Env = Dict[str, str]


def _c_string_literal(value: str) -> str:
    """Return `value` as a double-quoted C string literal of its UTF-8 bytes.

    Printable ASCII is emitted verbatim; `"`, `\\` and `?` are backslash-escaped
    (`?` so that no trigraph can form); every other byte becomes a three-digit
    octal escape. Three digits is the maximum C reads for an octal escape, so a
    following digit character can never be absorbed into the escape. The number
    of bytes in the resulting C string is exactly `len(value.encode("utf-8"))`.
    """
    pieces = []
    for byte in value.encode("utf-8"):
        char = chr(byte)
        if char in '"\\?':
            pieces.append("\\" + char)
        elif 0x20 <= byte <= 0x7E:
            pieces.append(char)
        else:
            pieces.append(f"\\{byte:03o}")
    return '"' + "".join(pieces) + '"'


@dataclasses.dataclass
class CompiledFunction:
    """A C function being generated: its name, parameters and body lines.

    `fields` lists the free variables captured by the function, in the order
    they are stored in its closure. `code` holds the C statements of the body,
    one per entry.
    """

    name: str
    params: typing.List[str]
    fields: typing.List[str] = dataclasses.field(default_factory=list)
    code: typing.List[str] = dataclasses.field(default_factory=list)

    def __post_init__(self) -> None:
        """Emit the function prologue: a handle scope plus one GC root per parameter."""
        self.code.append("HANDLES();")
        for param in self.params:
            # The parameters are raw pointers and must be updated on GC
            self.code.append(f"GC_PROTECT({param});")

    def decl(self) -> str:
        """Return the C prototype (without trailing `;` or body)."""
        args = ", ".join(f"struct object* {arg}" for arg in self.params)
        return f"struct object* {self.name}({args})"


class Compiler:
    """Translate a scrapscript AST into C source, one `CompiledFunction` at a time.

    Statements are appended to `self.function`, the function currently being
    generated; methods that compile a nested function temporarily swap it out
    and restore it afterwards.
    """

    def __init__(self, main_fn: CompiledFunction) -> None:
        self.gensym_counter: int = 0
        self.functions: typing.List[CompiledFunction] = [main_fn]
        self.function: CompiledFunction = main_fn
        # Record key name -> index in the generated `Record_*` enum
        self.record_keys: Dict[str, int] = {}
        self.record_builders: Dict[Tuple[str, ...], CompiledFunction] = {}
        # Variant tag name -> index in the generated `Tag_*` enum
        self.variant_tags: Dict[str, int] = {}
        self.debug: bool = False
        # C definitions of statically allocated constant objects
        self.const_heap: typing.List[str] = []

    def record_key(self, key: str) -> str:
        """Register `key` (if new) and return the name of its C enum constant."""
        if key not in self.record_keys:
            self.record_keys[key] = len(self.record_keys)
        return f"Record_{key}"

    def record_builder(self, keys: Tuple[str, ...]) -> CompiledFunction:
        """Return the C helper that builds a record with exactly `keys`.

        The helper takes one parameter per key (named after the key) and is
        generated once per distinct key tuple, then cached.
        """
        builder = self.record_builders.get(keys)
        if builder is not None:
            return builder

        # NOTE(review): joining with "_" means ("a_b",) and ("a", "b") produce
        # the same C function name -- confirm whether that can occur in practice.
        builder = CompiledFunction(f"Record_builder_{'_'.join(keys)}", list(keys))
        self.functions.append(builder)
        cur = self.function
        self.function = builder

        result = self._mktemp(f"mkrecord(heap, {len(keys)})")
        for i, key in enumerate(keys):
            key_idx = self.record_key(key)
            self._emit(f"record_set({result}, /*index=*/{i}, (struct record_field){{.key={key_idx}, .value={key}}});")
        self._debug("collect(heap);")
        self._emit(f"return {result};")

        self.function = cur
        self.record_builders[keys] = builder
        return builder

    def variant_tag(self, key: str) -> int:
        """Register variant tag `key` (if new) and return its numeric index."""
        result = self.variant_tags.get(key)
        if result is not None:
            return result
        result = self.variant_tags[key] = len(self.variant_tags)
        return result

    def gensym(self, stem: str = "tmp") -> str:
        """Return a fresh identifier of the form `<stem>_<n>`."""
        self.gensym_counter += 1
        return f"{stem}_{self.gensym_counter-1}"

    def _emit(self, line: str) -> None:
        """Append one line of C to the function currently being generated."""
        self.function.code.append(line)

    def _debug(self, line: str) -> None:
        """Emit `line` wrapped in `#ifndef NDEBUG`, but only when `self.debug` is set."""
        if not self.debug:
            return
        self._emit("#ifndef NDEBUG")
        self._emit(line)
        self._emit("#endif")

    def _handle(self, name: str, exp: str) -> str:
        """Declare `name` as an object handle initialised to `exp`; return `name`."""
        # TODO(max): Liveness analysis to avoid unnecessary handles
        self._emit(f"OBJECT_HANDLE({name}, {exp});")
        return name

    def _guard(self, cond: str, msg: Optional[str] = None) -> None:
        """Emit a runtime check that prints `msg` and aborts when `cond` is false."""
        if msg is None:
            msg = f"assertion {cond!s} failed"
        self._emit(f"if (!({cond})) {{")
        self._emit(f'fprintf(stderr, "{msg}\\n");')
        self._emit("abort();")
        self._emit("}")

    def _guard_int(self, exp: Object, c_name: str) -> None:
        """Emit an `is_num` check unless `exp` is already typed as an int."""
        if type_of(exp) != IntType:
            self._guard(f"is_num({c_name})")

    def _guard_str(self, exp: Object, c_name: str) -> None:
        """Emit an `is_string` check unless `exp` is already typed as a string."""
        if type_of(exp) != StringType:
            self._guard(f"is_string({c_name})")

    def _mktemp(self, exp: str) -> str:
        """Bind the C expression `exp` to a fresh handle and return its name."""
        temp = self.gensym()
        return self._handle(temp, exp)

    def compile_assign(self, env: Env, exp: Assign) -> Env:
        """Compile a binding and return `env` extended with the bound name.

        Functions and match functions are compiled with their own name so that
        they can refer to themselves.
        """
        assert isinstance(exp.name, Var)
        name = exp.name.name
        if isinstance(exp.value, Function):
            # Named function
            value = self.compile_function(env, exp.value, name)
            return {**env, name: value}
        if isinstance(exp.value, MatchFunction):
            # Named match function
            value = self.compile_match_function(env, exp.value, name)
            return {**env, name: value}
        value = self.compile(env, exp.value)
        return {**env, name: value}

    def make_compiled_function(self, arg: str, exp: Object, name: Optional[str]) -> CompiledFunction:
        """Create the `CompiledFunction` shell for `exp`.

        The function's own `name` is excluded from its captured fields because
        it is reachable through `this` instead.
        """
        assert isinstance(exp, (Function, MatchFunction))
        free = free_in(exp)
        if name is not None and name in free:
            free.remove(name)
        fields = sorted(free)
        fn_name = self.gensym(name if name else "fn")  # must be globally unique
        return CompiledFunction(fn_name, params=["this", arg], fields=fields)

    def compile_function_env(self, fn: CompiledFunction, name: Optional[str]) -> Env:
        """Build the environment visible inside `fn`'s body.

        Emits one `closure_get` per captured field into the current function,
        so `self.function` must already be `fn` when this is called.
        """
        result = {param: param for param in fn.params}
        if name is not None:
            result[name] = "this"
        for i, field in enumerate(fn.fields):
            result[field] = self._mktemp(f"closure_get(this, /*{field}=*/{i})")
        return result

    def compile_function(self, env: Env, exp: Function, name: Optional[str]) -> str:
        """Compile a single-argument function; return a C expression for its closure."""
        assert isinstance(exp.arg, Var)
        fn = self.make_compiled_function(exp.arg.name, exp, name)
        self.functions.append(fn)
        cur = self.function
        self.function = fn
        funcenv = self.compile_function_env(fn, name)
        val = self.compile(funcenv, exp.body)
        fn.code.append(f"return {val};")
        self.function = cur
        if not fn.fields:
            # TODO(max): Closure over freevars but only consts
            return self._const_closure(fn)
        return self.make_closure(env, fn)

    def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> Env:
        """Emit C that tests `arg` against `pattern`, jumping to `fallthrough` on failure.

        Returns the variable bindings introduced by the pattern.

        Raises:
            NotImplementedError: for named record spreads and unsupported
                pattern node types.
        """
        # TODO(max): Give `arg` an AST node so we can track its inferred type
        # and make use of that in pattern matching
        if isinstance(pattern, Int):
            self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}")
            return {}
        if isinstance(pattern, Hole):
            self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}")
            return {}
        if isinstance(pattern, Variant):
            self.variant_tag(pattern.tag)  # register it for the big enum
            if isinstance(pattern.value, Hole):
                # This is an optimization for immediate variants but it's not
                # necessary; the non-Hole case would work just fine.
                self._emit(f"if ({arg} != mk_immediate_variant(Tag_{pattern.tag})) {{ goto {fallthrough}; }}")
                return {}
            self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}")
            self._emit(f"if (variant_tag({arg}) != Tag_{pattern.tag}) {{ goto {fallthrough}; }}")
            return self.try_match(env, self._mktemp(f"variant_value({arg})"), pattern.value, fallthrough)

        if isinstance(pattern, String):
            value = pattern.value
            literal = _c_string_literal(value)
            # Lengths handed to C count UTF-8 bytes, not Python code points;
            # this must agree with _emit_small_string's 8-byte limit.
            num_bytes = len(value.encode("utf-8"))
            if num_bytes < 8:
                self._emit(f"if ({arg} != mksmallstring({literal}, {num_bytes})) {{ goto {fallthrough}; }}")
                return {}
            self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}")
            self._emit(
                f"if (!string_equal_cstr_len({arg}, {literal}, {num_bytes})) {{ goto {fallthrough}; }}"
            )
            return {}
        if isinstance(pattern, Var):
            return {pattern.name: arg}
        if isinstance(pattern, List):
            self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}")
            updates = {}
            the_list = arg
            for pattern_item in pattern.items:
                if isinstance(pattern_item, Spread):
                    # A spread swallows the remainder of the list, so no
                    # length check follows.
                    if pattern_item.name:
                        updates[pattern_item.name] = the_list
                    return updates
                # Not enough elements
                self._emit(f"if (is_empty_list({the_list})) {{ goto {fallthrough}; }}")
                list_item = self._mktemp(f"list_first({the_list})")
                updates.update(self.try_match(env, list_item, pattern_item, fallthrough))
                the_list = self._mktemp(f"list_rest({the_list})")
            # Too many elements
            self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}")
            return updates
        if isinstance(pattern, Record):
            self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}")
            updates = {}
            for key, pattern_value in pattern.data.items():
                if isinstance(pattern_value, Spread):
                    if pattern_value.name:
                        raise NotImplementedError("named record spread not yet supported")
                    # Anonymous spread: extra fields are allowed, so skip the
                    # field-count check below.
                    return updates
                key_idx = self.record_key(key)
                record_value = self._mktemp(f"record_get({arg}, {key_idx})")
                # TODO(max): If the key is present in the type, don't emit this
                # check
                self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}")
                updates.update(self.try_match(env, record_value, pattern_value, fallthrough))
            self._emit(f"if (record_num_fields({arg}) != {len(pattern.data)}) {{ goto {fallthrough}; }}")
            return updates
        raise NotImplementedError("try_match", pattern)

    def compile_match_function(self, env: Env, exp: MatchFunction, name: Optional[str]) -> str:
        """Compile a match function; return a C expression for its closure.

        Each case is emitted in order. A failed pattern jumps to the label that
        precedes the next case; the last case falls through to `no_match`,
        which prints an error and aborts.
        """
        arg = self.gensym()
        fn = self.make_compiled_function(arg, exp, name)
        self.functions.append(fn)
        cur = self.function
        self.function = fn
        funcenv = self.compile_function_env(fn, name)
        for i, case in enumerate(exp.cases):
            fallthrough = f"case_{i+1}" if i < len(exp.cases) - 1 else "no_match"
            env_updates = self.try_match(funcenv, arg, case.pattern, fallthrough)
            case_result = self.compile({**funcenv, **env_updates}, case.body)
            self._emit(f"return {case_result};")
            self._emit(f"{fallthrough}:;")
        self._emit(r'fprintf(stderr, "no matching cases\n");')
        self._emit("abort();")
        # Pacify the C compiler
        self._emit("return NULL;")
        self.function = cur
        if not fn.fields:
            # TODO(max): Closure over freevars but only consts
            return self._const_closure(fn)
        return self.make_closure(env, fn)

    def make_closure(self, env: Env, fn: CompiledFunction) -> str:
        """Emit code that allocates a closure for `fn` and fills its fields from `env`."""
        name = self._mktemp(f"mkclosure(heap, {fn.name}, {len(fn.fields)})")
        for i, field in enumerate(fn.fields):
            self._emit(f"closure_set({name}, /*{field}=*/{i}, {env[field]});")
        self._debug("collect(heap);")
        return name

    def _is_const(self, exp: Object) -> bool:
        """Return True when `exp` can be emitted as a compile-time constant."""
        if isinstance(exp, Int):
            return True
        if isinstance(exp, String):
            return True
        if isinstance(exp, Variant):
            return self._is_const(exp.value)
        if isinstance(exp, Record):
            return all(self._is_const(value) for value in exp.data.values())
        if isinstance(exp, List):
            return all(self._is_const(item) for item in exp.items)
        if isinstance(exp, Hole):
            return True
        if isinstance(exp, Function) and len(free_in(exp)) == 0:
            return True
        return False

    def _const_obj(self, type: str, tag: str, contents: str) -> str:
        """Add a statically initialised `struct <type>` to the const heap.

        Returns the `ptrto(...)` expression that refers to it.
        """
        result = self.gensym(f"const_{type}")
        self.const_heap.append(f"CONST_HEAP struct {type} {result} = {{.HEAD.tag={tag}, {contents} }};")
        return f"ptrto({result})"

    def _const_cons(self, first: str, rest: str) -> str:
        """Add a constant list cell to the const heap and return a reference to it."""
        return self._const_obj("list", "TAG_LIST", f".first={first}, .rest={rest}")

    def _const_closure(self, fn: CompiledFunction) -> str:
        """Add a constant, field-less closure for `fn` to the const heap."""
        assert len(fn.fields) == 0
        return self._const_obj("closure", "TAG_CLOSURE", f".fn={fn.name}, .size=0")

    def _emit_small_string(self, value_str: str) -> str:
        """Return the C expression for an immediate (pointer-encoded) small string.

        `value_str` must encode to fewer than 8 UTF-8 bytes.
        """
        value = value_str.encode("utf-8")
        length = len(value)
        assert length < 8, "small string must be less than 8 bytes"
        value_int = int.from_bytes(value, "little")
        # The repr is only a readability aid inside a C comment; break up any
        # "*/" so the string's contents cannot terminate the comment early.
        comment = repr(value_str).replace("*/", "* /")
        return f"(struct object*)(({hex(value_int)}ULL << kBitsPerByte) | ({length}ULL << kImmediateTagBits) | (uword)kSmallStringTag /* {comment} */)"

    def _emit_const(self, exp: Object) -> str:
        """Return a C expression for the constant `exp`.

        Immediates are returned inline; everything else is placed on the const
        heap. `exp` must satisfy `_is_const`.
        """
        assert self._is_const(exp), f"not a constant {exp}"
        if isinstance(exp, Hole):
            return "hole()"
        if isinstance(exp, Int):
            # TODO(max): Bignum
            return f"_mksmallint({exp.value})"
        if isinstance(exp, List):
            items = [self._emit_const(item) for item in exp.items]
            result = "empty_list()"
            # Build back to front so each cell can point at the already-built tail
            for item in reversed(items):
                result = self._const_cons(item, result)
            return result
        if isinstance(exp, String):
            # Decide on the UTF-8 byte length: that is what _emit_small_string
            # asserts on and what the C side stores in `.size`.
            num_bytes = len(exp.value.encode("utf-8"))
            if num_bytes < 8:
                return self._emit_small_string(exp.value)
            return self._const_obj(
                "heap_string", "TAG_STRING", f".size={num_bytes}, .data={_c_string_literal(exp.value)}"
            )
        if isinstance(exp, Variant):
            self.variant_tag(exp.tag)
            if isinstance(exp.value, Hole):
                return f"mk_immediate_variant(Tag_{exp.tag})"
            value = self._emit_const(exp.value)
            return self._const_obj("variant", "TAG_VARIANT", f".tag=Tag_{exp.tag}, .value={value}")
        if isinstance(exp, Record):
            values = {self.record_key(key): self._emit_const(value) for key, value in exp.data.items()}
            fields = ",\n".join(f"{{.key={key}, .value={value} }}" for key, value in values.items())
            return self._const_obj("record", "TAG_RECORD", f".size={len(values)}, .fields={{ {fields} }}")
        if isinstance(exp, Function):
            assert len(free_in(exp)) == 0, "only constant functions can be constified"
            return self.compile_function({}, exp, name=None)
        raise NotImplementedError(f"const {exp}")

    def compile(self, env: Env, exp: Object) -> str:
        """Compile `exp` into the current function; return a C expression for its value.

        Raises:
            NameError: when a variable is not bound in `env`.
            NotImplementedError: for unsupported operators or node types.
        """
        if self._is_const(exp):
            return self._emit_const(exp)
        if isinstance(exp, Variant):
            assert not isinstance(exp.value, Hole), "immediate variant should be handled in _emit_const"
            assert not self._is_const(exp.value), "const heap variant should be handled in _emit_const"
            self._debug("collect(heap);")
            self.variant_tag(exp.tag)
            value = self.compile(env, exp.value)
            result = self._mktemp(f"mkvariant(heap, Tag_{exp.tag})")
            self._emit(f"variant_set({result}, {value});")
            return result
        if isinstance(exp, String):
            # Not reachable while _is_const treats every String as constant;
            # kept for when strings stop being unconditionally constant.
            num_bytes = len(exp.value.encode("utf-8"))
            assert num_bytes >= 8, "small string should be handled in _emit_const"
            self._debug("collect(heap);")
            string_repr = _c_string_literal(exp.value)
            return self._mktemp(f"mkstring(heap, {string_repr}, {num_bytes})")
        if isinstance(exp, Binop):
            left = self.compile(env, exp.left)
            right = self.compile(env, exp.right)
            if exp.op == BinopKind.ADD:
                self._debug("collect(heap);")
                self._guard_int(exp.left, left)
                self._guard_int(exp.right, right)
                return self._mktemp(f"num_add({left}, {right})")
            if exp.op == BinopKind.MUL:
                self._debug("collect(heap);")
                self._guard_int(exp.left, left)
                self._guard_int(exp.right, right)
                return self._mktemp(f"num_mul({left}, {right})")
            if exp.op == BinopKind.SUB:
                self._debug("collect(heap);")
                self._guard_int(exp.left, left)
                self._guard_int(exp.right, right)
                return self._mktemp(f"num_sub({left}, {right})")
            if exp.op == BinopKind.LIST_CONS:
                self._debug("collect(heap);")
                return self._mktemp(f"list_cons({left}, {right})")
            if exp.op == BinopKind.STRING_CONCAT:
                self._debug("collect(heap);")
                self._guard_str(exp.left, left)
                self._guard_str(exp.right, right)
                return self._mktemp(f"string_concat({left}, {right})")
            raise NotImplementedError(f"binop {exp.op}")
        if isinstance(exp, Where):
            assert isinstance(exp.binding, Assign)
            res_env = self.compile_assign(env, exp.binding)
            new_env = {**env, **res_env}
            return self.compile(new_env, exp.body)
        if isinstance(exp, Var):
            var_value = env.get(exp.name)
            if var_value is None:
                raise NameError(f"name '{exp.name}' is not defined")
            return var_value
        if isinstance(exp, Apply):
            callee = self.compile(env, exp.func)
            arg = self.compile(env, exp.arg)
            return self._mktemp(f"closure_call({callee}, {arg})")
        if isinstance(exp, List):
            items = [self.compile(env, item) for item in exp.items]
            result = self._mktemp("empty_list()")
            # Cons from the last element backwards so the list ends up in source order
            for item in reversed(items):
                result = self._mktemp(f"list_cons({item}, {result})")
            self._debug("collect(heap);")
            return result
        if isinstance(exp, Record):
            values: Dict[str, str] = {}
            for key, value_exp in exp.data.items():
                values[key] = self.compile(env, value_exp)
            # Sort so records with the same key set share one builder,
            # whatever order the fields were written in.
            keys = tuple(sorted(exp.data.keys()))
            builder = self.record_builder(keys)
            return self._mktemp(f"{builder.name}({', '.join(values[key] for key in keys)})")
        if isinstance(exp, Access):
            assert isinstance(exp.at, Var), f"only Var access is supported, got {type(exp.at)}"
            record = self.compile(env, exp.obj)
            key_idx = self.record_key(exp.at.name)
            # Check if the record is a record
            self._guard(f"is_record({record})", "not a record")
            value = self._mktemp(f"record_get({record}, {key_idx})")
            self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}")
            return value
        if isinstance(exp, Function):
            # Anonymous function
            return self.compile_function(env, exp, name=None)
        if isinstance(exp, MatchFunction):
            # Anonymous match function
            return self.compile_match_function(env, exp, name=None)
        raise NotImplementedError(f"exp {type(exp)} {exp}")


def compile_to_string(program: Object, debug: bool) -> str:
    """Compile `program` and return the complete generated C source as a string.

    The output is, in order: tag/size `#define`s, the contents of `runtime.c`
    (read from the directory containing this file), the record-key and
    variant-tag tables, function prototypes, the const heap, and finally the
    function definitions. `debug` turns on the `_debug` instrumentation.
    """
    main_fn = CompiledFunction("scrap_main", params=[])
    compiler = Compiler(main_fn)
    compiler.debug = debug
    result = compiler.compile({}, program)
    main_fn.code.append(f"return {result};")

    f = io.StringIO()
    constants = [
        ("uword", "kKiB", 1024),
        ("uword", "kMiB", "kKiB * kKiB"),
        ("uword", "kGiB", "kKiB * kKiB * kKiB"),
        ("uword", "kPageSize", "4 * kKiB"),
        ("uword", "kSmallIntTagBits", 1),
        ("uword", "kPrimaryTagBits", 3),
        ("uword", "kObjectAlignmentLog2", 3),  # bits
        ("uword", "kObjectAlignment", "1ULL << kObjectAlignmentLog2"),
        ("uword", "kImmediateTagBits", 5),
        ("uword", "kSmallIntTagMask", "(1ULL << kSmallIntTagBits) - 1"),
        ("uword", "kPrimaryTagMask", "(1ULL << kPrimaryTagBits) - 1"),
        ("uword", "kImmediateTagMask", "(1ULL << kImmediateTagBits) - 1"),
        ("uword", "kWordSize", "sizeof(word)"),
        ("uword", "kMaxSmallStringLength", "kWordSize - 1"),
        ("uword", "kBitsPerByte", 8),
        # Up to the five least significant bits are used to tag the object's layout.
        # The three low bits make up a primary tag, used to differentiate gc_obj
        # from immediate objects. All even tags map to SmallInt, which is
        # optimized by checking only the lowest bit for parity.
        ("uword", "kSmallIntTag", 0),  # 0b****0
        ("uword", "kHeapObjectTag", 1),  # 0b**001
        ("uword", "kEmptyListTag", 5),  # 0b00101
        ("uword", "kHoleTag", 7),  # 0b00111
        ("uword", "kSmallStringTag", 13),  # 0b01101
        ("uword", "kVariantTag", 15),  # 0b01111
        # TODO(max): Fill in 21
        # TODO(max): Fill in 23
        # TODO(max): Fill in 29
        # TODO(max): Fill in 31
        ("uword", "kBitsPerPointer", "kBitsPerByte * kWordSize"),
        ("word", "kSmallIntBits", "kBitsPerPointer - kSmallIntTagBits"),
        ("word", "kSmallIntMinValue", "-(((word)1) << (kSmallIntBits - 1))"),
        ("word", "kSmallIntMaxValue", "(((word)1) << (kSmallIntBits - 1)) - 1"),
    ]
    for type_, name, value in constants:
        print(f"#define {name} ({type_})({value})", file=f)
    # The runtime is in the same directory as this file
    dirname = os.path.dirname(__file__)
    # Read with an explicit encoding so the result does not depend on the locale
    with open(os.path.join(dirname, "runtime.c"), "r", encoding="utf-8") as runtime:
        print(runtime.read(), file=f)
    print("#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp)", file=f)
    if compiler.record_keys:
        print("const char* record_keys[] = {", file=f)
        for key in compiler.record_keys:
            print(f'"{key}",', file=f)
        print("};", file=f)
        print("enum {", file=f)
        for key, idx in compiler.record_keys.items():
            print(f"Record_{key} = {idx},", file=f)
        print("};", file=f)
    else:
        # Pacify the C compiler
        print("const char* record_keys[] = { NULL };", file=f)
    if compiler.variant_tags:
        print("const char* variant_names[] = {", file=f)
        for key in compiler.variant_tags:
            print(f'"{key}",', file=f)
        print("};", file=f)
        print("enum {", file=f)
        for key, idx in compiler.variant_tags.items():
            print(f"Tag_{key} = {idx},", file=f)
        print("};", file=f)
    else:
        # Pacify the C compiler
        print("const char* variant_names[] = { NULL };", file=f)
    # Declare all functions
    for function in compiler.functions:
        print(function.decl() + ";", file=f)
    # Emit the const heap
    print("#define ptrto(obj) ((struct object*)((uword)&(obj) + 1))", file=f)
    for line in compiler.const_heap:
        print(line, file=f)
    for function in compiler.functions:
        print(f"{function.decl()} {{", file=f)
        for line in function.code:
            print(line, file=f)
        print("}", file=f)
    return f.getvalue()