This repository has no description.
1#!/usr/bin/env python3
2import dataclasses
3import io
4import itertools
5import json
6import os
7import typing
8
9from typing import Dict, Optional, Tuple
10
11from scrapscript import (
12 Access,
13 Apply,
14 Assign,
15 Binop,
16 BinopKind,
17 Function,
18 Hole,
19 Int,
20 List,
21 MatchFunction,
22 Object,
23 Record,
24 Spread,
25 String,
26 Var,
27 Variant,
28 Where,
29 free_in,
30 type_of,
31 IntType,
32 StringType,
33 parse, # needed for /compilerepl
34 tokenize, # needed for /compilerepl
35)
36
# Maps a Scrapscript variable name to the C expression that holds its value.
Env = typing.Dict[str, str]
38
39
40@dataclasses.dataclass
41class CompiledFunction:
42 name: str
43 params: typing.List[str]
44 fields: typing.List[str] = dataclasses.field(default_factory=list)
45 code: typing.List[str] = dataclasses.field(default_factory=list)
46
47 def __post_init__(self) -> None:
48 self.code.append("HANDLES();")
49 for param in self.params:
50 # The parameters are raw pointers and must be updated on GC
51 self.code.append(f"GC_PROTECT({param});")
52
53 def decl(self) -> str:
54 args = ", ".join(f"struct object* {arg}" for arg in self.params)
55 return f"struct object* {self.name}({args})"
56
57
class Compiler:
    """Translate Scrapscript AST nodes into C source code.

    C statements are appended to ``self.function``, the function currently
    being generated; ``compile`` returns a C expression (usually the name of
    a temporary) holding the resulting value.  Every generated function is
    kept in ``self.functions``, and the record keys, variant tags, and
    constant-heap definitions encountered along the way are accumulated so
    that ``compile_to_string`` can emit the matching declarations.
    """

    def __init__(self, main_fn: CompiledFunction) -> None:
        self.gensym_counter: int = 0
        self.functions: typing.List[CompiledFunction] = [main_fn]
        self.function: CompiledFunction = main_fn
        # Record key -> index in the generated Record_* enum (first-use order).
        self.record_keys: Dict[str, int] = {}
        # Key tuple -> generated C function that allocates and fills a record.
        self.record_builders: Dict[Tuple[str, ...], CompiledFunction] = {}
        # Variant tag -> index in the generated Tag_* enum (first-use order).
        self.variant_tags: Dict[str, int] = {}
        # When True, lines passed to _debug() are emitted inside #ifndef NDEBUG.
        self.debug: bool = False
        # Static C definitions of constant objects (appended by _const_obj).
        self.const_heap: typing.List[str] = []

    def record_key(self, key: str) -> str:
        """Register ``key`` on first use and return its C enum name."""
        if key not in self.record_keys:
            self.record_keys[key] = len(self.record_keys)
        return f"Record_{key}"

    def record_builder(self, keys: Tuple[str, ...]) -> CompiledFunction:
        """Return (creating on first use) a C function that builds a record.

        The generated function takes one parameter per key, named after the
        key, and stores them at indices following the order of ``keys``.
        Builders are cached per key tuple.
        """
        builder = self.record_builders.get(keys)
        if builder is not None:
            return builder

        # Joining the keys alone is ambiguous -- ("a_b",) and ("a", "b") would
        # both yield Record_builder_a_b and redefine one C function -- so
        # gensym appends a counter that is unique within this compiler.
        builder = CompiledFunction(self.gensym(f"Record_builder_{'_'.join(keys)}"), list(keys))
        self.functions.append(builder)
        cur = self.function
        self.function = builder

        result = self._mktemp(f"mkrecord(heap, {len(keys)})")
        for i, key in enumerate(keys):
            key_idx = self.record_key(key)
            self._emit(f"record_set({result}, /*index=*/{i}, (struct record_field){{.key={key_idx}, .value={key}}});")
        self._debug("collect(heap);")
        self._emit(f"return {result};")

        self.function = cur
        self.record_builders[keys] = builder
        return builder

    def variant_tag(self, key: str) -> int:
        """Return the enum index of variant tag ``key``, registering it on first use."""
        result = self.variant_tags.get(key)
        if result is not None:
            return result
        result = self.variant_tags[key] = len(self.variant_tags)
        return result

    def gensym(self, stem: str = "tmp") -> str:
        """Return a fresh name ``{stem}_{n}`` where n is unique per Compiler."""
        self.gensym_counter += 1
        return f"{stem}_{self.gensym_counter-1}"

    def _emit(self, line: str) -> None:
        """Append one line of C to the function currently being generated."""
        self.function.code.append(line)

    def _debug(self, line: str) -> None:
        """Emit ``line`` wrapped in ``#ifndef NDEBUG``, only in debug mode."""
        if not self.debug:
            return
        self._emit("#ifndef NDEBUG")
        self._emit(line)
        self._emit("#endif")

    def _handle(self, name: str, exp: str) -> str:
        """Emit an OBJECT_HANDLE declaring ``name`` initialized to ``exp``; return ``name``."""
        # TODO(max): Liveness analysis to avoid unnecessary handles
        self._emit(f"OBJECT_HANDLE({name}, {exp});")
        return name

    def _guard(self, cond: str, msg: Optional[str] = None) -> None:
        """Emit C that prints ``msg`` to stderr and aborts unless ``cond`` holds."""
        if msg is None:
            msg = f"assertion {cond!s} failed"
        self._emit(f"if (!({cond})) {{")
        self._emit(f'fprintf(stderr, "{msg}\\n");')
        self._emit("abort();")
        self._emit("}")

    def _guard_int(self, exp: Object, c_name: str) -> None:
        """Emit an is_num check on ``c_name`` unless ``exp`` is statically an Int."""
        if type_of(exp) != IntType:
            self._guard(f"is_num({c_name})")

    def _guard_str(self, exp: Object, c_name: str) -> None:
        """Emit an is_string check on ``c_name`` unless ``exp`` is statically a String."""
        if type_of(exp) != StringType:
            self._guard(f"is_string({c_name})")

    def _mktemp(self, exp: str) -> str:
        """Bind ``exp`` to a fresh GC handle and return the handle's name."""
        temp = self.gensym()
        return self._handle(temp, exp)

    def compile_assign(self, env: Env, exp: Assign) -> Env:
        """Compile ``name = value`` and return ``env`` extended with the binding.

        Function values receive the binding name so their bodies can refer
        to themselves (as ``this``).
        """
        assert isinstance(exp.name, Var)
        name = exp.name.name
        if isinstance(exp.value, Function):
            # Named function
            value = self.compile_function(env, exp.value, name)
            return {**env, name: value}
        if isinstance(exp.value, MatchFunction):
            # Named match function
            value = self.compile_match_function(env, exp.value, name)
            return {**env, name: value}
        value = self.compile(env, exp.value)
        return {**env, name: value}

    def make_compiled_function(self, arg: str, exp: Object, name: Optional[str]) -> CompiledFunction:
        """Create (without registering) the C function for a closure taking ``arg``.

        Its fields are the free variables of ``exp`` in sorted order, minus
        ``name`` itself, which the body reaches through ``this``.
        """
        assert isinstance(exp, (Function, MatchFunction))
        free = free_in(exp)
        if name is not None and name in free:
            free.remove(name)
        fields = sorted(free)
        fn_name = self.gensym(name if name else "fn")  # must be globally unique
        return CompiledFunction(fn_name, params=["this", arg], fields=fields)

    def compile_function_env(self, fn: CompiledFunction, name: Optional[str]) -> Env:
        """Build the environment for ``fn``'s body.

        Parameters map to themselves, ``name`` maps to ``this``, and each
        captured field is loaded from the closure into a temporary.
        """
        result = {param: param for param in fn.params}
        if name is not None:
            result[name] = "this"
        for i, field in enumerate(fn.fields):
            result[field] = self._mktemp(f"closure_get(this, /*{field}=*/{i})")
        return result

    def compile_function(self, env: Env, exp: Function, name: Optional[str]) -> str:
        """Compile a lambda into its own C function; return an expression for the closure.

        Closures that capture nothing are placed in the constant heap.
        """
        assert isinstance(exp.arg, Var)
        fn = self.make_compiled_function(exp.arg.name, exp, name)
        self.functions.append(fn)
        cur = self.function
        self.function = fn
        funcenv = self.compile_function_env(fn, name)
        val = self.compile(funcenv, exp.body)
        fn.code.append(f"return {val};")
        self.function = cur
        if not fn.fields:
            # TODO(max): Closure over freevars but only consts
            return self._const_closure(fn)
        return self.make_closure(env, fn)

    def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> Env:
        """Emit checks that ``goto fallthrough`` unless ``arg`` matches ``pattern``.

        Returns the variables bound by the pattern, mapped to C expressions.
        Raises NotImplementedError for unsupported patterns, including named
        record spreads.
        """
        # TODO(max): Give `arg` an AST node so we can track its inferred type
        # and make use of that in pattern matching
        if isinstance(pattern, Int):
            self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}")
            return {}
        if isinstance(pattern, Hole):
            self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}")
            return {}
        if isinstance(pattern, Variant):
            self.variant_tag(pattern.tag)  # register it for the big enum
            if isinstance(pattern.value, Hole):
                # This is an optimization for immediate variants but it's not
                # necessary; the non-Hole case would work just fine.
                self._emit(f"if ({arg} != mk_immediate_variant(Tag_{pattern.tag})) {{ goto {fallthrough}; }}")
                return {}
            self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}")
            self._emit(f"if (variant_tag({arg}) != Tag_{pattern.tag}) {{ goto {fallthrough}; }}")
            return self.try_match(env, self._mktemp(f"variant_value({arg})"), pattern.value, fallthrough)

        if isinstance(pattern, String):
            value = pattern.value
            literal = self._c_string_literal(value)
            # Lengths are UTF-8 byte counts, matching _emit_small_string; a
            # character count misclassifies non-ASCII strings.
            length = len(value.encode("utf-8"))
            if length < 8:
                self._emit(f"if ({arg} != mksmallstring({literal}, {length})) {{ goto {fallthrough}; }}")
                return {}
            self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}")
            self._emit(f"if (!string_equal_cstr_len({arg}, {literal}, {length})) {{ goto {fallthrough}; }}")
            return {}
        if isinstance(pattern, Var):
            return {pattern.name: arg}
        if isinstance(pattern, List):
            self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}")
            updates = {}
            the_list = arg
            for pattern_item in pattern.items:
                if isinstance(pattern_item, Spread):
                    if pattern_item.name:
                        updates[pattern_item.name] = the_list
                    return updates
                # Not enough elements
                self._emit(f"if (is_empty_list({the_list})) {{ goto {fallthrough}; }}")
                list_item = self._mktemp(f"list_first({the_list})")
                updates.update(self.try_match(env, list_item, pattern_item, fallthrough))
                the_list = self._mktemp(f"list_rest({the_list})")
            # Too many elements
            self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}")
            return updates
        if isinstance(pattern, Record):
            self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}")
            updates = {}
            for key, pattern_value in pattern.data.items():
                if isinstance(pattern_value, Spread):
                    if pattern_value.name:
                        raise NotImplementedError("named record spread not yet supported")
                    return updates
                key_idx = self.record_key(key)
                record_value = self._mktemp(f"record_get({arg}, {key_idx})")
                # TODO(max): If the key is present in the type, don't emit this
                # check
                self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}")
                updates.update(self.try_match(env, record_value, pattern_value, fallthrough))
            self._emit(f"if (record_num_fields({arg}) != {len(pattern.data)}) {{ goto {fallthrough}; }}")
            return updates
        raise NotImplementedError("try_match", pattern)

    def compile_match_function(self, env: Env, exp: MatchFunction, name: Optional[str]) -> str:
        """Compile a match function; return an expression for the closure.

        Cases are tried in order: each failed check jumps to the next case's
        label, and if no case matches the program prints "no matching cases"
        and aborts.
        """
        arg = self.gensym()
        fn = self.make_compiled_function(arg, exp, name)
        self.functions.append(fn)
        cur = self.function
        self.function = fn
        funcenv = self.compile_function_env(fn, name)
        for i, case in enumerate(exp.cases):
            fallthrough = f"case_{i+1}" if i < len(exp.cases) - 1 else "no_match"
            env_updates = self.try_match(funcenv, arg, case.pattern, fallthrough)
            case_result = self.compile({**funcenv, **env_updates}, case.body)
            self._emit(f"return {case_result};")
            self._emit(f"{fallthrough}:;")
        self._emit(r'fprintf(stderr, "no matching cases\n");')
        self._emit("abort();")
        # Pacify the C compiler
        self._emit("return NULL;")
        self.function = cur
        if not fn.fields:
            # TODO(max): Closure over freevars but only consts
            return self._const_closure(fn)
        return self.make_closure(env, fn)

    def make_closure(self, env: Env, fn: CompiledFunction) -> str:
        """Allocate a closure for ``fn`` and copy each captured variable from ``env`` into it."""
        name = self._mktemp(f"mkclosure(heap, {fn.name}, {len(fn.fields)})")
        for i, field in enumerate(fn.fields):
            self._emit(f"closure_set({name}, /*{field}=*/{i}, {env[field]});")
        self._debug("collect(heap);")
        return name

    def _is_const(self, exp: Object) -> bool:
        """Return True if ``_emit_const`` can emit ``exp`` without any runtime environment."""
        if isinstance(exp, Int):
            return True
        if isinstance(exp, String):
            return True
        if isinstance(exp, Variant):
            return self._is_const(exp.value)
        if isinstance(exp, Record):
            return all(self._is_const(value) for value in exp.data.values())
        if isinstance(exp, List):
            return all(self._is_const(item) for item in exp.items)
        if isinstance(exp, Hole):
            return True
        if isinstance(exp, Function) and len(free_in(exp)) == 0:
            return True
        return False

    def _const_obj(self, type: str, tag: str, contents: str) -> str:
        """Add a static ``struct {type}`` definition to the const heap; return a ptrto() to it."""
        result = self.gensym(f"const_{type}")
        self.const_heap.append(f"CONST_HEAP struct {type} {result} = {{.HEAD.tag={tag}, {contents} }};")
        return f"ptrto({result})"

    def _const_cons(self, first: str, rest: str) -> str:
        """Return a constant-heap list cell holding ``first`` followed by ``rest``."""
        return self._const_obj("list", "TAG_LIST", f".first={first}, .rest={rest}")

    def _const_closure(self, fn: CompiledFunction) -> str:
        """Return a constant-heap closure for ``fn``, which must capture nothing."""
        assert len(fn.fields) == 0
        return self._const_obj("closure", "TAG_CLOSURE", f".fn={fn.name}, .size=0")

    @staticmethod
    def _c_string_literal(value: str) -> str:
        """Return a C string literal containing the UTF-8 encoding of ``value``.

        ``json.dumps`` is not a C escaper: it writes non-ASCII and control
        characters as backslash-u escapes, which C rejects below U+00A0 and
        for the surrogate pairs used for non-BMP characters, and which do not
        line up with the UTF-8 byte lengths emitted next to the literal.
        Printable ASCII and the common backslash escapes are written exactly
        as ``json.dumps`` would; every other byte becomes a three-digit octal
        escape (always three digits, so a following digit is never absorbed).
        """
        escapes = {
            ord('"'): '\\"',
            ord("\\"): "\\\\",
            ord("\b"): "\\b",
            ord("\f"): "\\f",
            ord("\n"): "\\n",
            ord("\r"): "\\r",
            ord("\t"): "\\t",
        }
        pieces = []
        for byte in value.encode("utf-8"):
            if byte in escapes:
                pieces.append(escapes[byte])
            elif 0x20 <= byte < 0x7F:
                pieces.append(chr(byte))
            else:
                pieces.append(f"\\{byte:03o}")
        return '"' + "".join(pieces) + '"'

    def _emit_small_string(self, value_str: str) -> str:
        """Return a C expression for an immediate small string (fewer than 8 UTF-8 bytes).

        The bytes are packed little-endian starting at bit kBitsPerByte, the
        byte length is shifted by kImmediateTagBits, and the low bits hold
        kSmallStringTag.
        """
        value = value_str.encode("utf-8")
        length = len(value)
        assert length < 8, "small string must be less than 8 bytes"
        value_int = int.from_bytes(value, "little")
        # The repr is kept in a C comment for readability; break up "*/" so a
        # string containing it cannot terminate the comment early.
        comment = repr(value_str).replace("*/", "*\\/")
        return f"(struct object*)(({hex(value_int)}ULL << kBitsPerByte) | ({length}ULL << kImmediateTagBits) | (uword)kSmallStringTag /* {comment} */)"

    def _emit_const(self, exp: Object) -> str:
        """Return a C expression for constant ``exp`` (see ``_is_const``), built statically."""
        assert self._is_const(exp), f"not a constant {exp}"
        if isinstance(exp, Hole):
            return "hole()"
        if isinstance(exp, Int):
            # TODO(max): Bignum
            return f"_mksmallint({exp.value})"
        if isinstance(exp, List):
            items = [self._emit_const(item) for item in exp.items]
            result = "empty_list()"
            for item in reversed(items):
                result = self._const_cons(item, result)
            return result
        if isinstance(exp, String):
            # Classify by UTF-8 byte length, the same measure
            # _emit_small_string asserts on.
            length = len(exp.value.encode("utf-8"))
            if length < 8:
                return self._emit_small_string(exp.value)
            return self._const_obj(
                "heap_string", "TAG_STRING", f".size={length}, .data={self._c_string_literal(exp.value)}"
            )
        if isinstance(exp, Variant):
            self.variant_tag(exp.tag)
            if isinstance(exp.value, Hole):
                return f"mk_immediate_variant(Tag_{exp.tag})"
            value = self._emit_const(exp.value)
            return self._const_obj("variant", "TAG_VARIANT", f".tag=Tag_{exp.tag}, .value={value}")
        if isinstance(exp, Record):
            values = {self.record_key(key): self._emit_const(value) for key, value in exp.data.items()}
            fields = ",\n".join(f"{{.key={key}, .value={value} }}" for key, value in values.items())
            return self._const_obj("record", "TAG_RECORD", f".size={len(values)}, .fields={{ {fields} }}")
        if isinstance(exp, Function):
            assert len(free_in(exp)) == 0, "only constant functions can be constified"
            return self.compile_function({}, exp, name=None)
        raise NotImplementedError(f"const {exp}")

    def compile(self, env: Env, exp: Object) -> str:
        """Emit code evaluating ``exp`` in ``env``; return a C expression for its value.

        Raises NameError for unbound variables and NotImplementedError for
        unsupported expressions or operators.
        """
        if self._is_const(exp):
            return self._emit_const(exp)
        if isinstance(exp, Variant):
            assert not isinstance(exp.value, Hole), "immediate variant should be handled in _emit_const"
            assert not self._is_const(exp.value), "const heap variant should be handled in _emit_const"
            self._debug("collect(heap);")
            self.variant_tag(exp.tag)
            value = self.compile(env, exp.value)
            result = self._mktemp(f"mkvariant(heap, Tag_{exp.tag})")
            self._emit(f"variant_set({result}, {value});")
            return result
        if isinstance(exp, String):
            # NOTE: _is_const currently accepts every String, so this branch is
            # only reached if that changes.
            length = len(exp.value.encode("utf-8"))
            assert length >= 8, "small string should be handled in _emit_const"
            self._debug("collect(heap);")
            # No trailing ';' here: the expression is an OBJECT_HANDLE argument.
            return self._mktemp(f"mkstring(heap, {self._c_string_literal(exp.value)}, {length})")
        if isinstance(exp, Binop):
            left = self.compile(env, exp.left)
            right = self.compile(env, exp.right)
            # Integer arithmetic shares one shape: guard both operands (unless
            # statically Int), then call the runtime helper.
            arithmetic = {BinopKind.ADD: "num_add", BinopKind.MUL: "num_mul", BinopKind.SUB: "num_sub"}
            if exp.op in arithmetic:
                self._debug("collect(heap);")
                self._guard_int(exp.left, left)
                self._guard_int(exp.right, right)
                return self._mktemp(f"{arithmetic[exp.op]}({left}, {right})")
            if exp.op == BinopKind.LIST_CONS:
                self._debug("collect(heap);")
                return self._mktemp(f"list_cons({left}, {right})")
            if exp.op == BinopKind.STRING_CONCAT:
                self._debug("collect(heap);")
                self._guard_str(exp.left, left)
                self._guard_str(exp.right, right)
                return self._mktemp(f"string_concat({left}, {right})")
            raise NotImplementedError(f"binop {exp.op}")
        if isinstance(exp, Where):
            assert isinstance(exp.binding, Assign)
            # compile_assign already returns env extended with the binding.
            return self.compile(self.compile_assign(env, exp.binding), exp.body)
        if isinstance(exp, Var):
            var_value = env.get(exp.name)
            if var_value is None:
                raise NameError(f"name '{exp.name}' is not defined")
            return var_value
        if isinstance(exp, Apply):
            callee = self.compile(env, exp.func)
            arg = self.compile(env, exp.arg)
            return self._mktemp(f"closure_call({callee}, {arg})")
        if isinstance(exp, List):
            items = [self.compile(env, item) for item in exp.items]
            result = self._mktemp("empty_list()")
            for item in reversed(items):
                result = self._mktemp(f"list_cons({item}, {result})")
            self._debug("collect(heap);")
            return result
        if isinstance(exp, Record):
            values: Dict[str, str] = {}
            for key, value_exp in exp.data.items():
                values[key] = self.compile(env, value_exp)
            keys = tuple(sorted(exp.data.keys()))
            builder = self.record_builder(keys)
            return self._mktemp(f"{builder.name}({', '.join(values[key] for key in keys)})")
        if isinstance(exp, Access):
            assert isinstance(exp.at, Var), f"only Var access is supported, got {type(exp.at)}"
            record = self.compile(env, exp.obj)
            key_idx = self.record_key(exp.at.name)
            # The accessed object must be a record at runtime
            self._guard(f"is_record({record})", "not a record")
            value = self._mktemp(f"record_get({record}, {key_idx})")
            self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}")
            return value
        if isinstance(exp, Function):
            # Anonymous function
            return self.compile_function(env, exp, name=None)
        if isinstance(exp, MatchFunction):
            # Anonymous match function
            return self.compile_match_function(env, exp, name=None)
        raise NotImplementedError(f"exp {type(exp)} {exp}")
442
443
def compile_to_string(program: Object, debug: bool) -> str:
    """Compile ``program`` into a single self-contained C translation unit.

    The output contains, in order: the layout/tagging ``#define`` constants,
    the contents of ``runtime.c`` (read from this file's directory), the
    record-key and variant-tag name tables and enums, prototypes for every
    generated function, the constant heap, and the function bodies.
    ``scrap_main`` returns the program's value.

    When ``debug`` is true, the compiler also emits the lines it marks as
    debug-only (``collect(heap);`` calls), wrapped in ``#ifndef NDEBUG``.
    """
    main_fn = CompiledFunction("scrap_main", params=[])
    compiler = Compiler(main_fn)
    compiler.debug = debug
    result = compiler.compile({}, program)
    main_fn.code.append(f"return {result};")

    f = io.StringIO()
    constants = [
        ("uword", "kKiB", 1024),
        ("uword", "kMiB", "kKiB * kKiB"),
        ("uword", "kGiB", "kKiB * kKiB * kKiB"),
        ("uword", "kPageSize", "4 * kKiB"),
        ("uword", "kSmallIntTagBits", 1),
        ("uword", "kPrimaryTagBits", 3),
        ("uword", "kObjectAlignmentLog2", 3),  # bits
        ("uword", "kObjectAlignment", "1ULL << kObjectAlignmentLog2"),
        ("uword", "kImmediateTagBits", 5),
        ("uword", "kSmallIntTagMask", "(1ULL << kSmallIntTagBits) - 1"),
        ("uword", "kPrimaryTagMask", "(1ULL << kPrimaryTagBits) - 1"),
        ("uword", "kImmediateTagMask", "(1ULL << kImmediateTagBits) - 1"),
        ("uword", "kWordSize", "sizeof(word)"),
        ("uword", "kMaxSmallStringLength", "kWordSize - 1"),
        ("uword", "kBitsPerByte", 8),
        # Up to the five least significant bits are used to tag the object's layout.
        # The three low bits make up a primary tag, used to differentiate gc_obj
        # from immediate objects. All even tags map to SmallInt, which is
        # optimized by checking only the lowest bit for parity.
        ("uword", "kSmallIntTag", 0),  # 0b****0
        ("uword", "kHeapObjectTag", 1),  # 0b**001
        ("uword", "kEmptyListTag", 5),  # 0b00101
        ("uword", "kHoleTag", 7),  # 0b00111
        ("uword", "kSmallStringTag", 13),  # 0b01101
        ("uword", "kVariantTag", 15),  # 0b01111
        # TODO(max): Fill in 21
        # TODO(max): Fill in 23
        # TODO(max): Fill in 29
        # TODO(max): Fill in 31
        ("uword", "kBitsPerPointer", "kBitsPerByte * kWordSize"),
        ("word", "kSmallIntBits", "kBitsPerPointer - kSmallIntTagBits"),
        ("word", "kSmallIntMinValue", "-(((word)1) << (kSmallIntBits - 1))"),
        ("word", "kSmallIntMaxValue", "(((word)1) << (kSmallIntBits - 1)) - 1"),
    ]
    for type_, name, value in constants:
        print(f"#define {name} ({type_})({value})", file=f)
    # The runtime is in the same directory as this file. Read it as UTF-8
    # explicitly rather than with the platform-default locale encoding.
    dirname = os.path.dirname(__file__)
    with open(os.path.join(dirname, "runtime.c"), "r", encoding="utf-8") as runtime:
        print(runtime.read(), file=f)
    print("#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp)", file=f)
    if compiler.record_keys:
        print("const char* record_keys[] = {", file=f)
        for key in compiler.record_keys:
            print(f'"{key}",', file=f)
        print("};", file=f)
        print("enum {", file=f)
        for key, idx in compiler.record_keys.items():
            print(f"Record_{key} = {idx},", file=f)
        print("};", file=f)
    else:
        # Pacify the C compiler
        print("const char* record_keys[] = { NULL };", file=f)
    if compiler.variant_tags:
        print("const char* variant_names[] = {", file=f)
        for key in compiler.variant_tags:
            print(f'"{key}",', file=f)
        print("};", file=f)
        print("enum {", file=f)
        for key, idx in compiler.variant_tags.items():
            print(f"Tag_{key} = {idx},", file=f)
        print("};", file=f)
    else:
        # Pacify the C compiler
        print("const char* variant_names[] = { NULL };", file=f)
    # Declare all functions
    for function in compiler.functions:
        print(function.decl() + ";", file=f)
    # Emit the const heap. ptrto() adds 1 to the address, i.e. kHeapObjectTag.
    print("#define ptrto(obj) ((struct object*)((uword)&(obj) + 1))", file=f)
    for line in compiler.const_heap:
        print(line, file=f)
    for function in compiler.functions:
        print(f"{function.decl()} {{", file=f)
        for line in function.code:
            print(line, file=f)
        print("}", file=f)
    return f.getvalue()
530 return f.getvalue()