# this repo has no description
1#!/usr/bin/env python3.10
2from __future__ import annotations
3import argparse
4import base64
5import code
6import copy
7import dataclasses
8import enum
9import functools
10import json
11import logging
12import os
13import struct
14import sys
15import typing
16import urllib.request
17from dataclasses import dataclass
18from enum import auto
19from types import ModuleType
20from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, Set, Tuple, Union
21
22readline: Optional[ModuleType]
23try:
24 import readline
25except ImportError:
26 readline = None
27
28
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
30
31
def is_identifier_char(c: str) -> bool:
    """Return True if *c* may appear in an identifier (alphanumeric, $, ', or _)."""
    return c.isalnum() or c in {"$", "'", "_"}
34
35
@dataclass(eq=True, order=True, unsafe_hash=True)
class SourceLocation:
    """A point in the source text: 1-based line/column plus absolute byte offset.

    -1 means "unknown". Ordering compares (lineno, colno, byteno) lexicographically.
    """

    lineno: int = -1
    colno: int = -1
    byteno: int = -1
41
42
@dataclass(eq=True, unsafe_hash=True)
class SourceExtent:
    """An inclusive [start, end] span of source locations."""

    start: SourceLocation = dataclasses.field(default_factory=SourceLocation)
    end: SourceLocation = dataclasses.field(default_factory=SourceLocation)

    def coalesce(self, other: Optional[SourceExtent]) -> Optional[SourceExtent]:
        """Return the smallest extent covering both self and *other*, or None if *other* is None."""
        if not other:
            return None
        lo = min(self.start, other.start)
        hi = max(self.end, other.end)
        return SourceExtent(lo, hi)
50
51
def join_source_extents(
    source_extent_one: Optional[SourceExtent], source_extent_two: Optional[SourceExtent]
) -> Optional[SourceExtent]:
    """Coalesce two optional extents; None if the first (or, via coalesce, the second) is None."""
    if not source_extent_one:
        return None
    return source_extent_one.coalesce(source_extent_two)
56
57
@dataclass(eq=True)
class Token:
    """Base class for lexer tokens.

    ``source_extent`` records where in the input the token came from; it is
    excluded from equality comparisons and cannot be set via ``__init__``.
    """

    source_extent: SourceExtent = dataclasses.field(default_factory=SourceExtent, init=False, compare=False)

    def with_source(self, source_extent: SourceExtent) -> Token:
        """Attach *source_extent* to this token and return self (fluent helper)."""
        self.source_extent = source_extent
        return self
65
66
@dataclass(eq=True)
class IntLit(Token):
    """Integer literal token."""

    value: int
70
71
@dataclass(eq=True)
class FloatLit(Token):
    """Floating-point literal token."""

    value: float
75
76
@dataclass(eq=True)
class StringLit(Token):
    """Double-quoted string literal token (value holds the unquoted text)."""

    value: str
80
81
@dataclass(eq=True)
class BytesLit(Token):
    """Bytes literal token (``~~`` syntax): still-encoded text plus its base (default 64)."""

    value: str
    base: int
86
87
@dataclass(eq=True)
class Operator(Token):
    """Operator token; value is one of the spellings in the PS table."""

    value: str
91
92
@dataclass(eq=True)
class Name(Token):
    """Identifier token."""

    value: str
96
97
@dataclass(eq=True)
class LeftParen(Token):
    """The ``(`` token."""

    pass
102
103
@dataclass(eq=True)
class RightParen(Token):
    """The ``)`` token."""

    pass
108
109
@dataclass(eq=True)
class LeftBrace(Token):
    """The ``{`` token."""

    pass
114
115
@dataclass(eq=True)
class RightBrace(Token):
    """The ``}`` token."""

    pass
120
121
@dataclass(eq=True)
class LeftBracket(Token):
    """The ``[`` token."""

    pass
126
127
@dataclass(eq=True)
class RightBracket(Token):
    """The ``]`` token."""

    pass
132
133
@dataclass(eq=True)
class Hash(Token):
    """The ``#`` token (introduces a variant)."""

    pass
138
139
@dataclass(eq=True)
class EOF(Token):
    """Sentinel token produced when the input is exhausted."""

    pass
143
144
def num_bytes_as_utf8(s: str) -> int:
    """Return how many bytes *s* occupies when encoded as UTF-8."""
    encoded = s.encode(encoding="UTF-8")
    return len(encoded)
147
148
class Lexer:
    """Convert source text into a stream of Tokens.

    Tracks line number, column, and byte offset as characters are consumed
    so that every token can be annotated with a SourceExtent.
    """

    def __init__(self, text: str):
        self.text: str = text
        self.idx: int = 0
        self._lineno: int = 1
        self._colno: int = 1
        self.line: str = ""
        self._byteno: int = 0
        # Mutable scratch extent for the token currently being read; a deep
        # copy is taken per token in make_token().
        self.current_token_source_extent: SourceExtent = SourceExtent(
            start=SourceLocation(
                lineno=self._lineno,
                colno=self._colno,
                byteno=self._byteno,
            ),
            end=SourceLocation(
                lineno=self._lineno,
                colno=self._colno,
                byteno=self._byteno,
            ),
        )
        self.token_start_idx: int = self.idx
        self.token_end_idx: int = self.token_start_idx

    @property
    def lineno(self) -> int:
        return self._lineno

    @property
    def colno(self) -> int:
        return self._colno

    @property
    def byteno(self) -> int:
        return self._byteno

    def mark_token_start(self) -> None:
        # Snapshot the current position as the start of the token being read.
        self.current_token_source_extent.start.lineno = self._lineno
        self.current_token_source_extent.start.colno = self._colno
        self.current_token_source_extent.start.byteno = self._byteno
        self.token_start_idx = self.idx

    def mark_token_end(self) -> None:
        # Snapshot the current position as the end of the token being read.
        self.current_token_source_extent.end.lineno = self._lineno
        self.current_token_source_extent.end.colno = self._colno
        self.current_token_source_extent.end.byteno = self._byteno
        self.token_end_idx = self.idx

    def has_input(self) -> bool:
        return self.idx < len(self.text)

    def read_char(self) -> str:
        """Consume one character, updating line/column/byte bookkeeping.

        Raises UnexpectedEOFError (via peek_char) if no input remains.
        """
        # The end is marked *before* consuming, so the extent's end points at
        # the position of the last character of the token, not one past it.
        self.mark_token_end()
        c = self.peek_char()
        if c == "\n":
            self._lineno += 1
            self._colno = 1
            self.line = ""
        else:
            self.line += c
            self._colno += 1
        self.idx += 1
        self._byteno += num_bytes_as_utf8(c)
        return c

    def peek_char(self) -> str:
        """Return the next character without consuming it."""
        if not self.has_input():
            raise UnexpectedEOFError("while reading token")
        return self.text[self.idx]

    def make_token(self, cls: type, *args: Any) -> Token:
        """Build cls(*args) annotated with a snapshot of the current extent."""
        result: Token = cls(*args)
        return result.with_source(copy.deepcopy(self.current_token_source_extent))

    def read_tokens(self) -> Generator[Token, None, None]:
        """Yield tokens until exhaustion (the EOF token itself is not yielded)."""
        while (token := self.read_token()) and not isinstance(token, EOF):
            yield token

    def read_token(self) -> Token:
        """Read and return the next token (EOF when input runs out)."""
        # Consume all whitespace
        while self.has_input():
            # Keep updating the token start location until we exhaust all whitespace
            self.mark_token_start()
            c = self.read_char()
            if not c.isspace():
                break
        else:
            # Loop fell through without break: only whitespace (or nothing) left.
            return self.make_token(EOF)
        if c == '"':
            return self.read_string()
        if c == "-":
            # "--" starts a comment that runs to end of line; a lone "-" is an operator.
            if self.has_input() and self.peek_char() == "-":
                self.read_comment()
                # Need to start reading a new token
                return self.read_token()
            return self.read_op(c)
        if c == "#":
            return self.make_token(Hash)
        if c == "~":
            # Bytes literals are introduced by "~~".
            if self.has_input() and self.peek_char() == "~":
                self.read_char()
                return self.read_bytes()
            raise ParseError(f"unexpected token {c!r}")
        if c.isdigit():
            return self.read_number(c)
        if c in "()[]{}":
            custom = {
                "(": LeftParen,
                ")": RightParen,
                "{": LeftBrace,
                "}": RightBrace,
                "[": LeftBracket,
                "]": RightBracket,
            }
            return self.make_token(custom[c])
        if c in OPER_CHARS:
            return self.read_op(c)
        if is_identifier_char(c):
            return self.read_var(c)
        raise InvalidTokenError(
            SourceExtent(
                start=SourceLocation(
                    lineno=self.current_token_source_extent.start.lineno,
                    colno=self.current_token_source_extent.start.colno,
                    byteno=self.current_token_source_extent.start.byteno,
                ),
                end=SourceLocation(
                    lineno=self.current_token_source_extent.end.lineno,
                    colno=self.current_token_source_extent.end.colno,
                    byteno=self.current_token_source_extent.end.byteno,
                ),
            )
        )

    def read_string(self) -> Token:
        """Read the remainder of a double-quoted string (no escape sequences)."""
        buf = ""
        while self.has_input():
            if (c := self.read_char()) == '"':
                break
            buf += c
        else:
            raise UnexpectedEOFError("while reading string")
        return self.make_token(StringLit, buf)

    def read_comment(self) -> None:
        # Discard characters up to and including the next newline.
        while self.has_input() and self.read_char() != "\n":
            pass

    def read_number(self, first_digit: str) -> Token:
        """Read an int or float literal that starts with *first_digit*."""
        # TODO: Support floating point numbers with no integer part
        buf = first_digit
        has_decimal = False
        while self.has_input():
            c = self.peek_char()
            if c == ".":
                # A second "." in the same literal is an error.
                if has_decimal:
                    raise ParseError(f"unexpected token {c!r}")
                has_decimal = True
            elif not c.isdigit():
                break
            self.read_char()
            buf += c

        if has_decimal:
            return self.make_token(FloatLit, float(buf))
        return self.make_token(IntLit, int(buf))

    def _starts_operator(self, buf: str) -> bool:
        # True if *buf* is a prefix of at least one known operator spelling.
        # TODO(max): Rewrite using trie
        return any(op.startswith(buf) for op in PS.keys())

    def read_op(self, first_char: str) -> Token:
        """Greedily read the longest operator starting with *first_char*."""
        buf = first_char
        while self.has_input():
            c = self.peek_char()
            if not self._starts_operator(buf + c):
                break
            self.read_char()
            buf += c
        if buf in PS.keys():
            return self.make_token(Operator, buf)
        raise ParseError(f"unexpected token {buf!r}")

    def read_var(self, first_char: str) -> Token:
        """Read an identifier starting with *first_char*."""
        buf = first_char
        while self.has_input() and is_identifier_char(c := self.peek_char()):
            self.read_char()
            buf += c
        return self.make_token(Name, buf)

    def read_bytes(self) -> Token:
        """Read a ``~~`` bytes literal up to the next whitespace.

        An optional ``<base>'`` prefix selects the encoding base; without one
        the base defaults to 64.
        """
        buf = ""
        while self.has_input():
            if self.peek_char().isspace():
                break
            buf += self.read_char()
        base, _, value = buf.rpartition("'")
        return self.make_token(BytesLit, value, int(base) if base else 64)
346
347
# Sentinel marking "no cached lookahead" in Peekable (distinct from any real item).
PEEK_EMPTY = object()
349
350
class Peekable:
    """Wrap an iterator with one item of lookahead.

    ``peek`` returns the next item without consuming it; a subsequent
    ``next`` returns the cached item before pulling from the iterator again.
    """

    def __init__(self, iterator: Iterator[Any]) -> None:
        self.iterator = iterator
        self.cache = PEEK_EMPTY

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        cached = self.cache
        if cached is PEEK_EMPTY:
            return next(self.iterator)
        self.cache = PEEK_EMPTY
        return cached

    def peek(self) -> Any:
        # Pull one item (honoring any existing cache) and stash it back.
        self.cache = next(self)
        return self.cache
368 return result
369
370
def tokenize(x: str) -> Peekable:
    """Lex *x* and return a peekable stream of its tokens."""
    return Peekable(Lexer(x).read_tokens())
374
375
@dataclass(frozen=True)
class Prec:
    """Binding powers for an operator: pl for its left side, pr for its right."""

    pl: float
    pr: float
380
381
def lp(n: float) -> Prec:
    """Left-associative precedence: the right side binds slightly weaker."""
    # TODO(max): Rewrite
    return Prec(pl=n, pr=n - 0.1)
385
386
def rp(n: float) -> Prec:
    """Right-associative precedence: the right side binds slightly tighter."""
    # TODO(max): Rewrite
    return Prec(pl=n, pr=n + 0.1)
390
391
def np(n: float) -> Prec:
    """Non-associative precedence: both sides bind equally."""
    # TODO(max): Rewrite
    return Prec(pl=n, pr=n)
395
396
def xp(n: float) -> Prec:
    """Terminator precedence: nothing binds on the right."""
    # TODO(max): Rewrite
    return Prec(pl=n, pr=0)
400
401
# Operator table: maps operator spelling to its binding powers.
# The "" entry is juxtaposition (function application).
PS = {
    "::": lp(2000),
    "@": rp(1001),
    "": rp(1000),
    ">>": lp(14),
    "<<": lp(14),
    "^": rp(13),
    "*": rp(12),
    "/": rp(12),
    "//": lp(12),
    "%": lp(12),
    "+": lp(11),
    "-": lp(11),
    ">*": rp(10),
    "++": rp(10),
    ">+": lp(10),
    "+<": rp(10),
    "==": np(9),
    "/=": np(9),
    "<": np(9),
    ">": np(9),
    "<=": np(9),
    ">=": np(9),
    "&&": rp(8),
    "||": rp(7),
    "|>": rp(6),
    "<|": lp(6),
    "#": lp(5.5),
    "->": lp(5),
    "|": rp(4.5),
    ":": lp(4.5),
    "=": rp(4),
    "!": lp(3),
    ".": rp(3),
    "?": rp(3),
    ",": xp(1),
    # TODO: Fix precedence for spread
    "...": xp(0),
}
441
442
# Highest binding power in the table; unary minus parses tighter than this.
HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values())


# Every character that can appear in an operator spelling.
OPER_CHARS = set("".join(PS.keys()))
assert " " not in OPER_CHARS
448
449
class SyntacticError(Exception):
    """Base class for errors raised while lexing or parsing."""

    pass
452
453
class ParseError(SyntacticError):
    """Raised for malformed input discovered during lexing or parsing."""

    pass
456
457
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class UnexpectedTokenError(ParseError):
    """Raised by the parser when a token appears in an invalid position."""

    unexpected_token: Token
461
462
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class InvalidTokenError(ParseError):
    """Raised by the lexer for a character that cannot start any token."""

    # Extent of the offending input; excluded from equality comparisons.
    unexpected_token: SourceExtent = dataclasses.field(default_factory=SourceExtent, compare=False)
466
467
# TODO(max): Replace with EOFError?
class UnexpectedEOFError(ParseError):
    """Raised when input ends in the middle of a token or expression."""

    pass
471
472
def parse_assign(tokens: Peekable, p: float = 0) -> "Assign":
    """Parse one record entry: a ``name = value`` assignment or a spread."""
    parsed = parse_binary(tokens, p)
    if isinstance(parsed, Spread):
        # A spread inside a record is keyed by a placeholder variable.
        return Assign(RECORD_SPREAD_KEY_PLACEHOLDER, parsed)
    if isinstance(parsed, Assign):
        return parsed
    raise ParseError("failed to parse variable assignment in record constructor")
480
481
def gensym() -> str:
    """Return a fresh compiler-generated variable name ($v0, $v1, ...)."""
    counter = gensym.counter + 1  # type: ignore
    gensym.counter = counter  # type: ignore
    return f"$v{counter}"
485
486
def gensym_reset() -> None:
    """Reset the gensym counter so generated names start again at $v0."""
    gensym.counter = -1  # type: ignore
489
490
# Initialize the gensym counter at module import time.
gensym_reset()
492
493
def make_source_annotated_object(cls: type, source_extent: Optional[SourceExtent], *args: Any) -> Object:
    """Construct ``cls(*args)`` and stamp it with *source_extent*.

    Uses object.__setattr__ because Object subclasses are frozen dataclasses.
    """
    obj: Object = cls(*args)
    object.__setattr__(obj, "source_extent", source_extent)
    return obj
498
499
def parse_unary(tokens: Peekable, p: float) -> "Object":
    """Parse one prefix/atomic expression.

    Handles literals, variables, variants (#tag payload), bytes literals,
    spreads, match functions (| pat -> body ...), parenthesized expressions,
    lists, records, and unary minus.
    """
    token = next(tokens)
    l: Object
    if isinstance(token, IntLit):
        return make_source_annotated_object(Int, token.source_extent, token.value)
    elif isinstance(token, FloatLit):
        return make_source_annotated_object(Float, token.source_extent, token.value)
    elif isinstance(token, Name):
        # TODO: Handle kebab case vars
        return make_source_annotated_object(Var, token.source_extent, token.value)
    elif isinstance(token, Hash):
        hash_source_extent = token.source_extent
        if isinstance(variant_tag := next(tokens), Name):
            # It needs to be higher than the precedence of the -> operator so that
            # we can match variants in MatchFunction
            # It needs to be higher than the precedence of the && operator so that
            # we can use #true() and #false() in boolean expressions
            # It needs to be higher than the precedence of juxtaposition so that
            # f #true() #false() is parsed as f(TRUE)(FALSE)
            variant_payload = parse_binary(tokens, PS[""].pr + 1)
            return make_source_annotated_object(
                Variant,
                hash_source_extent.coalesce(variant_payload.source_extent),
                variant_tag.value,
                variant_payload,
            )
        else:
            raise UnexpectedTokenError(variant_tag)
    elif isinstance(token, BytesLit):
        # Decode the literal according to its declared base (85/64/32/16).
        base = token.base
        if base == 85:
            l = Bytes(base64.b85decode(token.value))
        elif base == 64:
            l = Bytes(base64.b64decode(token.value))
        elif base == 32:
            l = Bytes(base64.b32decode(token.value))
        elif base == 16:
            l = Bytes(base64.b16decode(token.value))
        else:
            raise ParseError(f"unexpected base {base!r} in {token!r}")
        object.__setattr__(l, "source_extent", token.source_extent)
        return l
    elif isinstance(token, StringLit):
        return make_source_annotated_object(String, token.source_extent, token.value)
    elif token == Operator("..."):
        # A spread may optionally bind a name: "...rest" vs bare "...".
        try:
            if isinstance(tokens.peek(), Name):
                spread_variable = next(tokens)
                return make_source_annotated_object(
                    Spread, token.source_extent.coalesce(spread_variable.source_extent), spread_variable.value
                )
            else:
                return make_source_annotated_object(Spread, token.source_extent)
        except StopIteration:
            return Spread()
    elif token == Operator("|"):
        # Match function: one or more "| pattern -> body" cases.
        pipe_source_extent = token.source_extent
        expr = parse_binary(tokens, PS["|"].pr)  # TODO: make this work for larger arities
        if not isinstance(expr, Function):
            raise ParseError(f"expected function in match expression {expr!r}")
        match_case = make_source_annotated_object(
            MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
        )
        cases = [match_case]
        match_function_source_extent = match_case.source_extent
        while True:
            try:
                if tokens.peek() != Operator("|"):
                    break
            except StopIteration:
                break
            pipe_source_extent = next(tokens).source_extent
            expr = parse_binary(tokens, PS["|"].pr)  # TODO: make this work for larger arities
            if not isinstance(expr, Function):
                raise ParseError(f"expected function in match expression {expr!r}")
            match_case = make_source_annotated_object(
                MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
            )
            cases.append(match_case)
            match_function_source_extent = join_source_extents(match_function_source_extent, match_case.source_extent)
        return make_source_annotated_object(
            MatchFunction,
            match_function_source_extent,
            cases,
        )
    elif isinstance(token, LeftParen):
        # "()" is the Hole; otherwise a parenthesized sub-expression.
        left_paren_source_extent = token.source_extent
        if isinstance(tokens.peek(), RightParen):
            l = make_source_annotated_object(Hole, left_paren_source_extent.coalesce(next(tokens).source_extent))
        else:
            l = parse(tokens)
            object.__setattr__(l, "source_extent", left_paren_source_extent.coalesce(next(tokens).source_extent))
        return l
    elif isinstance(token, LeftBracket):
        list_start_source_extent = token.source_extent
        l = List([])
        token = tokens.peek()
        if isinstance(token, RightBracket):
            list_end_source_extent = next(tokens).source_extent
        else:
            # Items parse at precedence 2, above "," (xp(1)), so each item
            # stops at the separator; next(tokens) consumes "," or "]".
            l.items.append(parse_binary(tokens, 2))
            while not isinstance(token := next(tokens), RightBracket):
                if isinstance(l.items[-1], Spread):
                    raise ParseError("spread must come at end of list match")
                # TODO: Implement .. operator
                l.items.append(parse_binary(tokens, 2))
            list_end_source_extent = token.source_extent
        object.__setattr__(l, "source_extent", list_start_source_extent.coalesce(list_end_source_extent))
        return l
    elif isinstance(token, LeftBrace):
        record_start_source_extent = token.source_extent
        l = Record({})
        token = tokens.peek()
        if isinstance(token, RightBrace):
            record_end_source_extent = next(tokens).source_extent
        else:
            assign = parse_assign(tokens, 2)
            l.data[assign.name.name] = assign.value
            while not isinstance(token := next(tokens), RightBrace):
                if isinstance(assign.value, Spread):
                    raise ParseError("spread must come at end of record match")
                # TODO: Implement .. operator
                assign = parse_assign(tokens, 2)
                l.data[assign.name.name] = assign.value
            record_end_source_extent = token.source_extent
        object.__setattr__(l, "source_extent", record_start_source_extent.coalesce(record_end_source_extent))
        return l
    elif token == Operator("-"):
        # Unary minus
        # Precedence was chosen to be higher than binary ops so that -a op
        # b is (-a) op b and not -(a op b).
        # Precedence was chosen to be higher than function application so that
        # -a b is (-a) b and not -(a b).
        r = parse_binary(tokens, HIGHEST_PREC + 1)
        source_extent = token.source_extent.coalesce(r.source_extent)
        if isinstance(r, Int):
            assert r.value >= 0, "Tokens should never have negative values"
            return make_source_annotated_object(Int, source_extent, -r.value)
        if isinstance(r, Float):
            assert r.value >= 0, "Tokens should never have negative values"
            return make_source_annotated_object(Float, source_extent, -r.value)
        return make_source_annotated_object(Binop, source_extent, BinopKind.SUB, Int(0), r)
    else:
        raise UnexpectedTokenError(token)
644
645
def parse_binary(tokens: Peekable, p: float) -> "Object":
    """Precedence-climbing parser: combine a unary expression with any
    following operators whose left binding power is at least *p*.

    A non-operator token in operator position means juxtaposition
    (function application), which uses the "" entry of PS.
    """
    l: Object = parse_unary(tokens, p)
    while True:
        op: Token
        try:
            op = tokens.peek()
        except StopIteration:
            break
        # Closing delimiters end the current expression; the caller consumes them.
        if isinstance(op, (RightParen, RightBracket, RightBrace)):
            break
        if not isinstance(op, Operator):
            # Juxtaposition: treat "l arg" as application of l to arg.
            prec = PS[""]
            pl, pr = prec.pl, prec.pr
            if pl < p:
                break
            arg = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(l.source_extent, arg.source_extent), l, arg)
            continue
        prec = PS[op.value]
        pl, pr = prec.pl, prec.pr
        if pl < p:
            break
        next(tokens)
        if op == Operator("="):
            if not isinstance(l, Var):
                raise ParseError(f"expected variable in assignment {l!r}")
            value = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Assign, join_source_extents(l.source_extent, value.source_extent), l, value
            )
        elif op == Operator("->"):
            body = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Function, join_source_extents(l.source_extent, body.source_extent), l, body
            )
        elif op == Operator("|>"):
            # l |> f  desugars to  f l
            func = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(func.source_extent, l.source_extent), func, l)
        elif op == Operator("<|"):
            arg = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(l.source_extent, arg.source_extent), l, arg)
        elif op == Operator(">>"):
            # Forward composition: l >> r  becomes  \x -> r (l x)
            r = parse_binary(tokens, pr)
            varname = gensym()
            l = make_source_annotated_object(
                Function,
                join_source_extents(l.source_extent, r.source_extent),
                Var(varname),
                Apply(r, Apply(l, Var(varname))),
            )
        elif op == Operator("<<"):
            # Backward composition: l << r  becomes  \x -> l (r x)
            r = parse_binary(tokens, pr)
            varname = gensym()
            l = make_source_annotated_object(
                Function,
                join_source_extents(l.source_extent, r.source_extent),
                Var(varname),
                Apply(l, Apply(r, Var(varname))),
            )
        elif op == Operator("."):
            binding = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Where, join_source_extents(l.source_extent, binding.source_extent), l, binding
            )
        elif op == Operator("?"):
            cond = parse_binary(tokens, pr)
            l = make_source_annotated_object(Assert, join_source_extents(l.source_extent, cond.source_extent), l, cond)
        elif op == Operator("@"):
            # TODO: revisit whether to use @ or . for field access
            at = parse_binary(tokens, pr)
            l = make_source_annotated_object(Access, join_source_extents(l.source_extent, at.source_extent), l, at)
        else:
            # All remaining operators map directly to a Binop node.
            assert isinstance(op, Operator)
            right = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Binop,
                join_source_extents(l.source_extent, right.source_extent),
                BinopKind.from_str(op.value),
                l,
                right,
            )
    return l
728
729
def parse(tokens: Peekable) -> "Object":
    """Parse a full expression from *tokens*.

    Raises:
        UnexpectedEOFError: if the token stream ends mid-expression.
    """
    try:
        return parse_binary(tokens, 0)
    except StopIteration:
        # `from None` suppresses the chained StopIteration context: the
        # exhausted token stream is an implementation detail that otherwise
        # clutters the traceback shown to users.
        raise UnexpectedEOFError("unexpected end of input") from None
735
736
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Object:
    """Base class for all AST/runtime values.

    ``source_extent`` is optional metadata excluded from equality, repr,
    and __init__; it is attached after construction via object.__setattr__.
    """

    source_extent: Optional[SourceExtent] = dataclasses.field(default=None, compare=False, init=False, repr=False)

    def __str__(self) -> str:
        # Delegates rendering to the module-level pretty() (defined elsewhere in this file).
        return pretty(self)
743
744
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Int(Object):
    """An integer value."""

    value: int
748
749
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Float(Object):
    """A floating-point value."""

    value: float
753
754
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class String(Object):
    """A string value."""

    value: str
758
759
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bytes(Object):
    """A bytes value (decoded from a ~~ literal)."""

    value: bytes
763
764
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Var(Object):
    """A variable reference."""

    name: str
768
769
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Hole(Object):
    """The empty value, written ``()``."""

    pass
773
774
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Spread(Object):
    """A ``...`` in a list/record pattern; optionally binds the rest to *name*."""

    name: Optional[str] = None
778
779
# Placeholder key under which a record-level spread entry is stored (see parse_assign).
RECORD_SPREAD_KEY_PLACEHOLDER = Var("...")
781
782
# An evaluation environment: maps variable names to values.
Env = Mapping[str, Object]
784
785
786# TODO(max): Add source extents for BinopKind?
# TODO(max): Add source extents for BinopKind?
class BinopKind(enum.Enum):
    """The kinds of binary operators, with conversions to/from their spellings."""

    ADD = auto()
    SUB = auto()
    MUL = auto()
    DIV = auto()
    FLOOR_DIV = auto()
    EXP = auto()
    MOD = auto()
    EQUAL = auto()
    NOT_EQUAL = auto()
    LESS = auto()
    GREATER = auto()
    LESS_EQUAL = auto()
    GREATER_EQUAL = auto()
    BOOL_AND = auto()
    BOOL_OR = auto()
    STRING_CONCAT = auto()
    LIST_CONS = auto()
    LIST_APPEND = auto()
    RIGHT_EVAL = auto()
    HASTYPE = auto()
    PIPE = auto()
    REVERSE_PIPE = auto()

    @classmethod
    def from_str(cls, x: str) -> "BinopKind":
        """Map an operator spelling to its kind; raises KeyError for unknown spellings."""
        return {
            "+": cls.ADD,
            "-": cls.SUB,
            "*": cls.MUL,
            "/": cls.DIV,
            "//": cls.FLOOR_DIV,
            "^": cls.EXP,
            "%": cls.MOD,
            "==": cls.EQUAL,
            "/=": cls.NOT_EQUAL,
            "<": cls.LESS,
            ">": cls.GREATER,
            "<=": cls.LESS_EQUAL,
            ">=": cls.GREATER_EQUAL,
            "&&": cls.BOOL_AND,
            "||": cls.BOOL_OR,
            "++": cls.STRING_CONCAT,
            ">+": cls.LIST_CONS,
            "+<": cls.LIST_APPEND,
            "!": cls.RIGHT_EVAL,
            ":": cls.HASTYPE,
            "|>": cls.PIPE,
            "<|": cls.REVERSE_PIPE,
        }[x]

    @classmethod
    def to_str(cls, binop_kind: "BinopKind") -> str:
        """Inverse of from_str; every kind produced by from_str maps back to its spelling."""
        return {
            cls.ADD: "+",
            cls.SUB: "-",
            cls.MUL: "*",
            cls.DIV: "/",
            # Bug fix: FLOOR_DIV was missing here although from_str accepts "//",
            # so round-tripping (e.g. in the serializer) raised KeyError.
            cls.FLOOR_DIV: "//",
            cls.EXP: "^",
            cls.MOD: "%",
            cls.EQUAL: "==",
            cls.NOT_EQUAL: "/=",
            cls.LESS: "<",
            cls.GREATER: ">",
            cls.LESS_EQUAL: "<=",
            cls.GREATER_EQUAL: ">=",
            cls.BOOL_AND: "&&",
            cls.BOOL_OR: "||",
            cls.STRING_CONCAT: "++",
            cls.LIST_CONS: ">+",
            cls.LIST_APPEND: "+<",
            cls.RIGHT_EVAL: "!",
            cls.HASTYPE: ":",
            cls.PIPE: "|>",
            cls.REVERSE_PIPE: "<|",
        }[binop_kind]
862 }[binop_kind]
863
864
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Binop(Object):
    """A binary operation: ``left op right``."""

    op: BinopKind
    left: Object
    right: Object
870
871
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class List(Object):
    """A list value/pattern (the items list itself is mutated during parsing)."""

    items: typing.List[Object]
875
876
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Assign(Object):
    """A binding: ``name = value``."""

    name: Var
    value: Object
881
882
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Function(Object):
    """A single-argument function: ``arg -> body``."""

    arg: Object
    body: Object
887
888
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Apply(Object):
    """Function application: ``func arg``."""

    func: Object
    arg: Object
893
894
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Where(Object):
    """The ``.`` operator: an expression paired with a trailing binding."""

    body: Object
    binding: Object
899
900
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Assert(Object):
    """The ``?`` operator: a value guarded by a condition."""

    value: Object
    cond: Object
905
906
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class EnvObject(Object):
    """Wraps an environment as a first-class value."""

    env: Env

    def __str__(self) -> str:
        # Show only the bound names, not every value.
        return f"EnvObject(keys={self.env.keys()})"
913
914
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class MatchCase(Object):
    """One ``| pattern -> body`` arm of a match function."""

    pattern: Object
    body: Object
919
920
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class MatchFunction(Object):
    """A function defined as a sequence of match cases."""

    cases: typing.List[MatchCase]
924
925
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Relocation(Object):
    """A named placeholder object.

    NOTE(review): presumably resolved to a concrete object by name after
    deserialization — confirm against the deserializer's callers.
    """

    name: str
929
930
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class NativeFunctionRelocation(Relocation):
    """A Relocation that names a NativeFunction."""

    pass
934
935
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class NativeFunction(Object):
    """A named function implemented by a host Python callable."""

    name: str
    func: Callable[[Object], Object]
940
941
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Closure(Object):
    """A (match) function paired with its captured environment."""

    env: Env
    func: Union[Function, MatchFunction]
946
947
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Record(Object):
    """A record value/pattern: field name -> value."""

    data: Dict[str, Object]
951
952
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Access(Object):
    """Field/element access, written ``obj @ at``."""

    obj: Object
    at: Object
957
958
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Variant(Object):
    """A tagged value, written ``#tag payload``."""

    tag: str
    value: Object
963
964
# One-byte type tags for the serialization format. The walrus assignments
# both populate the list (validated below) and bind module-level names.
tags = [
    TYPE_SHORT := b"i", # fits in 64 bits
    TYPE_LONG := b"l", # bignum
    TYPE_FLOAT := b"d",
    TYPE_STRING := b"s",
    TYPE_REF := b"r",
    TYPE_LIST := b"[",
    TYPE_RECORD := b"{",
    TYPE_VARIANT := b"#",
    TYPE_VAR := b"v",
    TYPE_FUNCTION := b"f",
    TYPE_MATCH_FUNCTION := b"m",
    TYPE_CLOSURE := b"c",
    TYPE_BYTES := b"b",
    TYPE_HOLE := b"(",
    TYPE_ASSIGN := b"=",
    TYPE_BINOP := b"+",
    TYPE_APPLY := b" ",
    TYPE_WHERE := b".",
    TYPE_ACCESS := b"@",
    TYPE_SPREAD := b"S",
    TYPE_NAMED_SPREAD := b"R",
    TYPE_TRUE := b"T",
    TYPE_FALSE := b"F",
]
# High bit on a tag marks an object that is entered into the ref table.
FLAG_REF = 0x80
991
992
# Bignum serialization parameters: bignums are stored as sequences of
# 64-bit little-endian digits (see Serializer._long / Deserializer._long).
BITS_PER_BYTE = 8
BYTES_PER_DIGIT = 8
BITS_PER_DIGIT = BYTES_PER_DIGIT * BITS_PER_BYTE
DIGIT_MASK = (1 << BITS_PER_DIGIT) - 1
997
998
def ref(tag: bytes) -> bytes:
    """Return *tag* with the ref flag bit set."""
    flagged = tag[0] | FLAG_REF
    return flagged.to_bytes(1, "little")
1001
1002
# Extend the tag set with the ref-flagged variant of every tag, then sanity-check.
tags = tags + [ref(v) for v in tags]
assert len(tags) == len(set(tags)), "Duplicate tags"
assert all(len(v) == 1 for v in tags), "Tags must be 1 byte"
assert all(isinstance(v, bytes) for v in tags)
1007
1008
def zigzag_encode(val: int) -> int:
    """Map a signed int to a non-negative one: 0,-1,1,-2,2,... -> 0,1,2,3,4,..."""
    return 2 * val if val >= 0 else -2 * val - 1
1013
1014
def zigzag_decode(val: int) -> int:
    """Inverse of zigzag_encode: odd values map to negatives, even to non-negatives."""
    if val % 2 == 0:
        return val // 2
    return -val // 2
1019
1020
@dataclass
class Serializer:
    """Flatten an Object graph into a compact byte string (``output``).

    Objects that may be shared or cyclic (lists, closures) are entered into
    ``refs`` the first time they are serialized and written as back-references
    on later encounters, so shared structure is preserved and cycles terminate.
    """

    refs: typing.List[Object] = dataclasses.field(default_factory=list)
    output: bytearray = dataclasses.field(default_factory=bytearray)

    def ref(self, obj: Object) -> Optional[int]:
        """Return the ref-table index of *obj* (compared by identity), or None."""
        for idx, ref in enumerate(self.refs):
            if ref is obj:
                return idx
        return None

    def add_ref(self, ty: bytes, obj: Object) -> int:
        """Emit the ref-flagged tag for *ty* and register *obj* in the ref table."""
        assert len(ty) == 1
        assert self.ref(obj) is None
        self.emit(ref(ty))
        result = len(self.refs)
        self.refs.append(obj)
        return result

    def emit(self, obj: bytes) -> None:
        # Append raw bytes to the output buffer.
        self.output.extend(obj)

    def _fits_in_nbits(self, obj: int, nbits: int) -> bool:
        # Signed two's-complement range check.
        return -(1 << (nbits - 1)) <= obj < (1 << (nbits - 1))

    def _short(self, number: int) -> bytes:
        """Encode an int as a zigzagged varint (7 bits per byte, high bit = continue)."""
        # From Peter Ruibal, https://github.com/fmoo/python-varint
        number = zigzag_encode(number)
        buf = bytearray()
        while True:
            towrite = number & 0x7F
            number >>= 7
            if number:
                buf.append(towrite | 0x80)
            else:
                buf.append(towrite)
                break
        return bytes(buf)

    def _long(self, number: int) -> bytes:
        """Encode a bignum: varint digit count, then 64-bit little-endian digits."""
        digits = []
        number = zigzag_encode(number)
        while number:
            digits.append(number & DIGIT_MASK)
            number >>= BITS_PER_DIGIT
        buf = bytearray(self._short(len(digits)))
        for digit in digits:
            buf.extend(digit.to_bytes(BYTES_PER_DIGIT, "little"))
        return bytes(buf)

    def _string(self, obj: str) -> bytes:
        """Encode a string as a varint byte length followed by UTF-8 data."""
        encoded = obj.encode("utf-8")
        return self._short(len(encoded)) + encoded

    def serialize(self, obj: Object) -> None:
        """Append the encoding of *obj* (recursively) to ``output``.

        Raises NotImplementedError for object kinds without an encoding
        (e.g. NativeFunction).
        """
        assert isinstance(obj, Object), type(obj)
        if (ref := self.ref(obj)) is not None:
            # Already serialized: emit a back-reference instead of re-encoding.
            return self.emit(TYPE_REF + self._short(ref))
        if isinstance(obj, Int):
            # Small ints use the varint form; anything wider goes as a bignum.
            if self._fits_in_nbits(obj.value, 64):
                self.emit(TYPE_SHORT)
                self.emit(self._short(obj.value))
                return
            self.emit(TYPE_LONG)
            self.emit(self._long(obj.value))
            return
        if isinstance(obj, String):
            return self.emit(TYPE_STRING + self._string(obj.value))
        if isinstance(obj, List):
            self.add_ref(TYPE_LIST, obj)
            self.emit(self._short(len(obj.items)))
            for item in obj.items:
                self.serialize(item)
            return
        if isinstance(obj, Variant):
            # #true () and #false () get dedicated one-byte encodings.
            if obj.tag == "true" and isinstance(obj.value, Hole):
                return self.emit(TYPE_TRUE)
            if obj.tag == "false" and isinstance(obj.value, Hole):
                return self.emit(TYPE_FALSE)
            # TODO(max): Determine if this should be a ref
            self.emit(TYPE_VARIANT)
            # TODO(max): String pool (via refs) for strings longer than some length?
            self.emit(self._string(obj.tag))
            return self.serialize(obj.value)
        if isinstance(obj, Record):
            # TODO(max): Determine if this should be a ref
            self.emit(TYPE_RECORD)
            self.emit(self._short(len(obj.data)))
            for key, value in obj.data.items():
                self.emit(self._string(key))
                self.serialize(value)
            return
        if isinstance(obj, Var):
            return self.emit(TYPE_VAR + self._string(obj.name))
        if isinstance(obj, Function):
            self.emit(TYPE_FUNCTION)
            self.serialize(obj.arg)
            return self.serialize(obj.body)
        if isinstance(obj, MatchFunction):
            self.emit(TYPE_MATCH_FUNCTION)
            self.emit(self._short(len(obj.cases)))
            for case in obj.cases:
                self.serialize(case.pattern)
                self.serialize(case.body)
            return
        if isinstance(obj, Closure):
            self.add_ref(TYPE_CLOSURE, obj)
            self.serialize(obj.func)
            self.emit(self._short(len(obj.env)))
            for key, value in obj.env.items():
                self.emit(self._string(key))
                self.serialize(value)
            return
        if isinstance(obj, Bytes):
            self.emit(TYPE_BYTES)
            self.emit(self._short(len(obj.value)))
            self.emit(obj.value)
            return
        if isinstance(obj, Float):
            # IEEE-754 double, little-endian.
            self.emit(TYPE_FLOAT)
            self.emit(struct.pack("<d", obj.value))
            return
        if isinstance(obj, Hole):
            self.emit(TYPE_HOLE)
            return
        if isinstance(obj, Assign):
            self.emit(TYPE_ASSIGN)
            self.serialize(obj.name)
            self.serialize(obj.value)
            return
        if isinstance(obj, Binop):
            self.emit(TYPE_BINOP)
            self.emit(self._string(BinopKind.to_str(obj.op)))
            self.serialize(obj.left)
            self.serialize(obj.right)
            return
        if isinstance(obj, Apply):
            self.emit(TYPE_APPLY)
            self.serialize(obj.func)
            self.serialize(obj.arg)
            return
        if isinstance(obj, Where):
            self.emit(TYPE_WHERE)
            self.serialize(obj.body)
            self.serialize(obj.binding)
            return
        if isinstance(obj, Access):
            self.emit(TYPE_ACCESS)
            self.serialize(obj.obj)
            self.serialize(obj.at)
            return
        if isinstance(obj, Spread):
            if obj.name is not None:
                self.emit(TYPE_NAMED_SPREAD)
                self.emit(self._string(obj.name))
                return
            self.emit(TYPE_SPREAD)
            return
        raise NotImplementedError(type(obj))
1180
1181
@dataclass
class Deserializer:
    """Decode the flat binary format produced by the Serializer back into an
    Object graph.

    `idx` is the read cursor into `flat`; `refs` collects objects that were
    serialized with the ref flag so back-references (TYPE_REF) and cyclic
    structures (lists, closures) can be resolved.
    """

    flat: Union[bytes, memoryview]
    idx: int = 0
    refs: typing.List[Object] = dataclasses.field(default_factory=list)

    def __post_init__(self) -> None:
        # Normalize to a memoryview so read() can slice without copying.
        if isinstance(self.flat, bytes):
            self.flat = memoryview(self.flat)

    def read(self, size: int) -> memoryview:
        """Consume and return the next `size` bytes as a zero-copy view."""
        result = memoryview(self.flat[self.idx : self.idx + size])
        self.idx += size
        return result

    def read_tag(self) -> Tuple[bytes, bool]:
        """Read one type tag; return (tag with ref flag cleared, had ref flag)."""
        tag = self.read(1)[0]
        is_ref = bool(tag & FLAG_REF)
        return (tag & ~FLAG_REF).to_bytes(1, "little"), is_ref

    def _string(self) -> str:
        # Strings are a varint length prefix followed by UTF-8 bytes.
        length = self._short()
        encoded = self.read(length)
        return str(encoded, "utf-8")

    def _short(self) -> int:
        """Read a zigzag-encoded LEB128-style varint."""
        # From Peter Ruibal, https://github.com/fmoo/python-varint
        shift = 0
        result = 0
        while True:
            i = self.read(1)[0]
            result |= (i & 0x7F) << shift
            shift += 7
            if not (i & 0x80):
                break
        return zigzag_decode(result)

    def _long(self) -> int:
        """Read a zigzag-encoded bignum: a varint digit count, then that many
        little-endian fixed-width digits, least significant digit first."""
        num_digits = self._short()
        digits = []
        for _ in range(num_digits):
            digit = int.from_bytes(self.read(BYTES_PER_DIGIT), "little")
            digits.append(digit)
        result = 0
        # Digits were emitted least-significant first; fold them back in
        # most-significant-first order.
        for digit in reversed(digits):
            result <<= BITS_PER_DIGIT
            result |= digit
        return zigzag_decode(result)

    def parse(self) -> Object:
        """Decode and return the next Object at the cursor.

        Containers that may participate in cycles (lists, closures) register
        themselves in `refs` *before* parsing their children so that nested
        TYPE_REF tags can point back at the partially built object.
        """
        ty, is_ref = self.read_tag()
        if ty == TYPE_REF:
            # Back-reference to a previously decoded ref-flagged object.
            idx = self._short()
            return self.refs[idx]
        if ty == TYPE_SHORT:
            assert not is_ref
            return Int(self._short())
        if ty == TYPE_LONG:
            assert not is_ref
            return Int(self._long())
        if ty == TYPE_STRING:
            assert not is_ref
            return String(self._string())
        if ty == TYPE_LIST:
            length = self._short()
            result_list = List([])
            assert is_ref
            # Register before parsing items so cyclic lists resolve.
            self.refs.append(result_list)
            for i in range(length):
                result_list.items.append(self.parse())
            return result_list
        if ty == TYPE_RECORD:
            assert not is_ref
            length = self._short()
            result_rec = Record({})
            for i in range(length):
                key = self._string()
                value = self.parse()
                result_rec.data[key] = value
            return result_rec
        if ty == TYPE_VARIANT:
            assert not is_ref
            tag = self._string()
            value = self.parse()
            return Variant(tag, value)
        if ty == TYPE_VAR:
            assert not is_ref
            return Var(self._string())
        if ty == TYPE_FUNCTION:
            assert not is_ref
            arg = self.parse()
            body = self.parse()
            return Function(arg, body)
        if ty == TYPE_MATCH_FUNCTION:
            assert not is_ref
            length = self._short()
            result_matchfun = MatchFunction([])
            for i in range(length):
                pattern = self.parse()
                body = self.parse()
                result_matchfun.cases.append(MatchCase(pattern, body))
            return result_matchfun
        if ty == TYPE_CLOSURE:
            func = self.parse()
            length = self._short()
            assert isinstance(func, (Function, MatchFunction))
            result_closure = Closure({}, func)
            assert is_ref
            # Register before parsing the env so a closure can close over
            # itself (the letrec self-binding made in eval_exp).
            self.refs.append(result_closure)
            for i in range(length):
                key = self._string()
                value = self.parse()
                assert isinstance(result_closure.env, dict)  # For mypy
                result_closure.env[key] = value
            return result_closure
        if ty == TYPE_BYTES:
            assert not is_ref
            length = self._short()
            return Bytes(self.read(length))
        if ty == TYPE_FLOAT:
            assert not is_ref
            return Float(struct.unpack("<d", self.read(8))[0])
        if ty == TYPE_HOLE:
            assert not is_ref
            return Hole()
        if ty == TYPE_ASSIGN:
            assert not is_ref
            name = self.parse()
            value = self.parse()
            assert isinstance(name, Var)
            return Assign(name, value)
        if ty == TYPE_BINOP:
            assert not is_ref
            op = BinopKind.from_str(self._string())
            left = self.parse()
            right = self.parse()
            return Binop(op, left, right)
        if ty == TYPE_APPLY:
            assert not is_ref
            func = self.parse()
            arg = self.parse()
            return Apply(func, arg)
        if ty == TYPE_WHERE:
            assert not is_ref
            body = self.parse()
            binding = self.parse()
            return Where(body, binding)
        if ty == TYPE_ACCESS:
            assert not is_ref
            obj = self.parse()
            at = self.parse()
            return Access(obj, at)
        if ty == TYPE_SPREAD:
            return Spread()
        if ty == TYPE_NAMED_SPREAD:
            return Spread(self._string())
        if ty == TYPE_TRUE:
            # Booleans are compact single-tag encodings of #true/#false.
            assert not is_ref
            return Variant("true", Hole())
        if ty == TYPE_FALSE:
            assert not is_ref
            return Variant("false", Hole())
        raise NotImplementedError(bytes(ty))
1345
1346
# Canonical singletons for the language's booleans, which are represented as
# the variants #true () and #false ().
TRUE = Variant("true", Hole())


FALSE = Variant("false", Hole())
1351
1352
def unpack_number(obj: Object) -> Union[int, float]:
    """Extract the underlying Python number from an Int or Float object.

    Raises TypeError for any other kind of Object.
    """
    if isinstance(obj, (Int, Float)):
        return obj.value
    raise TypeError(f"expected Int or Float, got {type(obj).__name__}")
1357
1358
def eval_number(env: Env, exp: Object) -> Union[int, float]:
    """Evaluate `exp` and return its numeric value (int or float)."""
    return unpack_number(eval_exp(env, exp))
1362
1363
def eval_str(env: Env, exp: Object) -> str:
    """Evaluate `exp`, requiring the result to be a String; return its text."""
    evaluated = eval_exp(env, exp)
    if isinstance(evaluated, String):
        return evaluated.value
    raise TypeError(f"expected String, got {type(evaluated).__name__}")
1369
1370
def eval_bool(env: Env, exp: Object) -> bool:
    """Evaluate `exp`, requiring the result to be the #true or #false variant.

    Raises:
        TypeError: if the result is not a Variant, or is a Variant whose tag
            is neither "true" nor "false".
    """
    result = eval_exp(env, exp)
    if not isinstance(result, Variant):
        raise TypeError(f"expected #true or #false, got {type(result).__name__}")
    if result.tag not in ("true", "false"):
        # Report the offending tag; the type name alone ("Variant") would not
        # tell the user what was actually wrong with their value.
        raise TypeError(f"expected #true or #false, got #{result.tag}")
    return result.tag == "true"
1378
1379
def eval_list(env: Env, exp: Object) -> typing.List[Object]:
    """Evaluate `exp`, requiring the result to be a List; return its items."""
    evaluated = eval_exp(env, exp)
    if isinstance(evaluated, List):
        return evaluated.items
    raise TypeError(f"expected List, got {type(evaluated).__name__}")
1385
1386
def make_bool(x: bool) -> Object:
    """Convert a Python bool into the interpreter's #true/#false singletons."""
    if x:
        return TRUE
    return FALSE
1389
1390
def wrap_inferred_number_type(x: Union[int, float]) -> Object:
    """Box a Python number back into the interpreter's Int or Float."""
    # TODO: Since this is intended to be a reference implementation
    # we should avoid relying heavily on Python's implementation of
    # arithmetic operations, type inference, and multiple dispatch.
    # Update this to make the interpreter more language agnostic.
    return Int(x) if isinstance(x, int) else Float(x)
1399
1400
# Evaluator for each binary operator. Handlers receive the environment and the
# two *unevaluated* operand expressions, so each handler controls evaluation:
# BOOL_AND/BOOL_OR short-circuit through Python's and/or, and RIGHT_EVAL never
# evaluates its left operand at all.
BINOP_HANDLERS: Dict[BinopKind, Callable[[Env, Object, Object], Object]] = {
    BinopKind.ADD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) + eval_number(env, y)),
    BinopKind.SUB: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) - eval_number(env, y)),
    BinopKind.MUL: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) * eval_number(env, y)),
    BinopKind.DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) / eval_number(env, y)),
    BinopKind.FLOOR_DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) // eval_number(env, y)),
    BinopKind.EXP: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) ** eval_number(env, y)),
    BinopKind.MOD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) % eval_number(env, y)),
    BinopKind.EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) == eval_exp(env, y)),
    BinopKind.NOT_EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) != eval_exp(env, y)),
    BinopKind.LESS: lambda env, x, y: make_bool(eval_number(env, x) < eval_number(env, y)),
    BinopKind.GREATER: lambda env, x, y: make_bool(eval_number(env, x) > eval_number(env, y)),
    BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) <= eval_number(env, y)),
    BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) >= eval_number(env, y)),
    BinopKind.BOOL_AND: lambda env, x, y: make_bool(eval_bool(env, x) and eval_bool(env, y)),
    BinopKind.BOOL_OR: lambda env, x, y: make_bool(eval_bool(env, x) or eval_bool(env, y)),
    BinopKind.STRING_CONCAT: lambda env, x, y: String(eval_str(env, x) + eval_str(env, y)),
    BinopKind.LIST_CONS: lambda env, x, y: List([eval_exp(env, x)] + eval_list(env, y)),
    BinopKind.LIST_APPEND: lambda env, x, y: List(eval_list(env, x) + [eval_exp(env, y)]),
    BinopKind.RIGHT_EVAL: lambda env, x, y: eval_exp(env, y),
}
1422
1423
class MatchError(Exception):
    """Raised when pattern matching fails at runtime: no case of a match
    function applied, or an unsupported pattern (e.g. a Float) was used."""

    pass
1426
1427
def match(obj: Object, pattern: Object) -> Optional[Env]:
    """Match `obj` against `pattern`.

    Returns an environment of variable bindings on success (empty for literal
    patterns) or None on mismatch. Raises MatchError for Float patterns and
    NotImplementedError for pattern kinds with no match logic.
    """
    if isinstance(pattern, Hole):
        return {} if isinstance(obj, Hole) else None
    if isinstance(pattern, Int):
        return {} if isinstance(obj, Int) and obj.value == pattern.value else None
    if isinstance(pattern, Float):
        raise MatchError("pattern matching is not supported for Floats")
    if isinstance(pattern, String):
        return {} if isinstance(obj, String) and obj.value == pattern.value else None
    if isinstance(pattern, Var):
        # A bare variable matches anything and binds it.
        return {pattern.name: obj}
    if isinstance(pattern, Variant):
        if not isinstance(obj, Variant):
            return None
        if obj.tag != pattern.tag:
            return None
        return match(obj.value, pattern.value)
    if isinstance(pattern, Record):
        if not isinstance(obj, Record):
            return None
        result: Env = {}
        seen_keys: set[str] = set()
        for key, pattern_item in pattern.data.items():
            if isinstance(pattern_item, Spread):
                # A spread ends matching early; a *named* spread additionally
                # binds the not-yet-matched fields as a fresh record.
                if pattern_item.name is not None:
                    assert isinstance(result, dict)  # for .update()
                    rest_keys = set(obj.data.keys()) - seen_keys
                    result.update({pattern_item.name: Record({key: obj.data[key] for key in rest_keys})})
                return result
            seen_keys.add(key)
            obj_item = obj.data.get(key)
            if obj_item is None:
                return None
            part = match(obj_item, pattern_item)
            if part is None:
                return None
            assert isinstance(result, dict)  # for .update()
            result.update(part)
        # Without a spread the match is exact: extra fields in obj fail it.
        if len(pattern.data) != len(obj.data):
            return None
        return result
    if isinstance(pattern, List):
        if not isinstance(obj, List):
            return None
        result: Env = {}  # type: ignore
        for i, pattern_item in enumerate(pattern.items):
            if isinstance(pattern_item, Spread):
                # A spread matches the remaining items; a named spread binds
                # them as a new list.
                if pattern_item.name is not None:
                    assert isinstance(result, dict)  # for .update()
                    result.update({pattern_item.name: List(obj.items[i:])})
                return result
            if i >= len(obj.items):
                return None
            obj_item = obj.items[i]
            part = match(obj_item, pattern_item)
            if part is None:
                return None
            assert isinstance(result, dict)  # for .update()
            result.update(part)
        # Without a spread the lengths must agree exactly.
        if len(pattern.items) != len(obj.items):
            return None
        return result
    raise NotImplementedError(f"match not implemented for {type(pattern).__name__}")
1491
1492
def free_in(exp: Object) -> Set[str]:
    """Return the set of variable names occurring free in `exp`.

    When applied to a pattern, the returned names are the ones the pattern
    *binds*; MatchCase subtracts them from the free names of its body.
    """
    if isinstance(exp, (Int, Float, String, Bytes, Hole, NativeFunction)):
        return set()
    if isinstance(exp, Variant):
        return free_in(exp.value)
    if isinstance(exp, Var):
        return {exp.name}
    if isinstance(exp, Spread):
        if exp.name is not None:
            return {exp.name}
        return set()
    if isinstance(exp, Binop):
        return free_in(exp.left) | free_in(exp.right)
    if isinstance(exp, List):
        if not exp.items:
            return set()
        return set.union(*(free_in(item) for item in exp.items))
    if isinstance(exp, Record):
        if not exp.data:
            return set()
        return set.union(*(free_in(value) for key, value in exp.data.items()))
    if isinstance(exp, Function):
        assert isinstance(exp.arg, Var)
        # The parameter is bound in the body.
        return free_in(exp.body) - {exp.arg.name}
    if isinstance(exp, MatchFunction):
        if not exp.cases:
            return set()
        return set.union(*(free_in(case) for case in exp.cases))
    if isinstance(exp, MatchCase):
        # Names bound by the pattern are not free in the body.
        return free_in(exp.body) - free_in(exp.pattern)
    if isinstance(exp, Apply):
        return free_in(exp.func) | free_in(exp.arg)
    if isinstance(exp, Access):
        # For records, y is not free in x@y; it is a field name.
        # For lists, y *is* free in x@y; it is an index expression (could be a
        # var).
        # For now, we'll assume it might be an expression and mark it as a
        # (possibly extra) freevar.
        return free_in(exp.obj) | free_in(exp.at)
    if isinstance(exp, Where):
        assert isinstance(exp.binding, Assign)
        # The bound name is visible in the body but not counted as free there.
        return (free_in(exp.body) - {exp.binding.name.name}) | free_in(exp.binding)
    if isinstance(exp, Assign):
        return free_in(exp.value)
    if isinstance(exp, Closure):
        # TODO(max): Should this remove the set of keys in the closure env?
        return free_in(exp.func)
    raise NotImplementedError(("free_in", type(exp)))
1541
1542
def improve_closure(closure: Closure) -> Closure:
    """Shrink a closure's environment to just its function's free variables."""
    needed = free_in(closure.func)
    trimmed = {name: value for name, value in closure.env.items() if name in needed}
    return Closure(trimmed, closure.func)
1547
1548
def eval_exp(env: Env, exp: Object) -> Object:
    """Evaluate `exp` in environment `env`, returning the resulting Object.

    Assignments evaluate to an EnvObject carrying the extended environment;
    Where splices that environment into the evaluation of its body.
    """
    logger.debug(exp)
    if isinstance(exp, (Int, Float, String, Bytes, Hole, Closure, NativeFunction)):
        # Literals and already-evaluated values are self-evaluating.
        return exp
    if isinstance(exp, Variant):
        return Variant(exp.tag, eval_exp(env, exp.value))
    if isinstance(exp, Var):
        value = env.get(exp.name)
        if value is None:
            raise NameError(f"name '{exp.name}' is not defined")
        return value
    if isinstance(exp, Binop):
        handler = BINOP_HANDLERS.get(exp.op)
        if handler is None:
            raise NotImplementedError(f"no handler for {exp.op}")
        return handler(env, exp.left, exp.right)
    if isinstance(exp, List):
        return List([eval_exp(env, item) for item in exp.items])
    if isinstance(exp, Record):
        return Record({k: eval_exp(env, exp.data[k]) for k in exp.data})
    if isinstance(exp, Assign):
        # TODO(max): Rework this. There's something about matching that we need
        # to figure out and implement.
        assert isinstance(exp.name, Var)
        value = eval_exp(env, exp.value)
        if isinstance(value, Closure):
            # We want functions to be able to call themselves without using the
            # Y combinator or similar, so we bind functions (and only
            # functions) using a letrec-like strategy. We augment their
            # captured environment with a binding to themselves.
            assert isinstance(value.env, dict)
            value.env[exp.name.name] = value
            # We still improve_closure here even though we also did it on
            # Closure creation because the Closure might not need a binding for
            # itself (it might not be recursive).
            value = improve_closure(value)
        return EnvObject({**env, exp.name.name: value})
    if isinstance(exp, Where):
        assert isinstance(exp.binding, Assign)
        res_env = eval_exp(env, exp.binding)
        assert isinstance(res_env, EnvObject)
        new_env = {**env, **res_env.env}
        return eval_exp(new_env, exp.body)
    if isinstance(exp, Assert):
        cond = eval_exp(env, exp.cond)
        if cond != TRUE:
            raise AssertionError(f"condition {exp.cond} failed")
        return eval_exp(env, exp.value)
    if isinstance(exp, Function):
        if not isinstance(exp.arg, Var):
            raise RuntimeError(f"expected variable in function definition {exp.arg}")
        value = Closure(env, exp)
        # Trim the captured environment down to the function's free variables.
        value = improve_closure(value)
        return value
    if isinstance(exp, MatchFunction):
        value = Closure(env, exp)
        value = improve_closure(value)
        return value
    if isinstance(exp, Apply):
        if isinstance(exp.func, Var) and exp.func.name == "$$quote":
            # $$quote returns its argument unevaluated.
            return exp.arg
        callee = eval_exp(env, exp.func)
        arg = eval_exp(env, exp.arg)
        if isinstance(callee, NativeFunction):
            return callee.func(arg)
        if not isinstance(callee, Closure):
            raise TypeError(f"attempted to apply a non-closure of type {type(callee).__name__}")
        if isinstance(callee.func, Function):
            assert isinstance(callee.func.arg, Var)
            new_env = {**callee.env, callee.func.arg.name: arg}
            return eval_exp(new_env, callee.func.body)
        elif isinstance(callee.func, MatchFunction):
            # Try each case in order; the first matching pattern wins.
            for case in callee.func.cases:
                m = match(arg, case.pattern)
                if m is None:
                    continue
                return eval_exp({**callee.env, **m}, case.body)
            raise MatchError("no matching cases")
        else:
            raise TypeError(f"attempted to apply a non-function of type {type(callee.func).__name__}")
    if isinstance(exp, Access):
        obj = eval_exp(env, exp.obj)
        if isinstance(obj, Record):
            # Record access: the accessor must be a literal field name.
            if not isinstance(exp.at, Var):
                raise TypeError(f"cannot access record field using {type(exp.at).__name__}, expected a field name")
            if exp.at.name not in obj.data:
                raise NameError(f"no assignment to {exp.at.name} found in record")
            return obj.data[exp.at.name]
        elif isinstance(obj, List):
            # List access: the accessor is an arbitrary index expression.
            access_at = eval_exp(env, exp.at)
            if not isinstance(access_at, Int):
                raise TypeError(f"cannot index into list using type {type(access_at).__name__}, expected integer")
            if access_at.value < 0 or access_at.value >= len(obj.items):
                raise ValueError(f"index {access_at.value} out of bounds for list")
            return obj.items[access_at.value]
        raise TypeError(f"attempted to access from type {type(obj).__name__}")
    elif isinstance(exp, Spread):
        raise RuntimeError("cannot evaluate a spread")
    raise NotImplementedError(f"eval_exp not implemented for {exp}")
1648
1649
class ScrapMonad:
    """Threads an environment through a sequence of evaluations.

    Each bind() evaluates one expression and yields a successor monad whose
    environment reflects any assignments; non-environment results are bound
    to the name "_".
    """

    def __init__(self, env: Env) -> None:
        assert isinstance(env, dict)  # for .copy()
        self.env: Env = env.copy()

    def bind(self, exp: Object) -> Tuple[Object, "ScrapMonad"]:
        result = eval_exp(self.env, exp)
        if isinstance(result, EnvObject):
            additions: Env = result.env
        else:
            additions = {"_": result}
        return result, ScrapMonad({**self.env, **additions})
1661
1662
class InferenceError(Exception):
    """Raised when type inference fails: unification mismatch, occurs-check
    failure, unbound variable, or a malformed type."""

    pass
1665
1666
@dataclasses.dataclass
class MonoType:
    """Base class for monomorphic types.

    `find` resolves union-find forwarding; every type other than an unbound
    TyVar is its own representative.
    """

    def find(self) -> MonoType:
        return self
1671
1672
@dataclasses.dataclass
class TyVar(MonoType):
    """A unifiable type variable, resolved union-find style via `forwarded`."""

    forwarded: MonoType | None = dataclasses.field(init=False, default=None)
    name: str

    def find(self) -> MonoType:
        """Chase the forwarding chain to its representative."""
        node: MonoType = self
        while isinstance(node, TyVar) and node.forwarded is not None:
            node = node.forwarded
        return node

    def __str__(self) -> str:
        return "'" + self.name

    def make_equal_to(self, other: MonoType) -> None:
        """Link this variable's representative to `other`."""
        rep = self.find()
        if not isinstance(rep, TyVar):
            raise InferenceError(f"{self} is already resolved to {rep}")
        rep.forwarded = other

    def is_unbound(self) -> bool:
        return self.forwarded is None
1698
1699
@dataclasses.dataclass
class TyCon(MonoType):
    """A type constructor application, e.g. a nullary `int`, a unary
    `(a list)`, or the binary arrow `(a->b)`."""

    name: str
    args: list[MonoType]

    def __str__(self) -> str:
        # TODO(max): Precedence pretty-print type constructors
        arity = len(self.args)
        if arity == 0:
            return self.name
        if arity == 1:
            return f"({self.args[0]} {self.name})"
        # Multi-arg constructors print infix, with the name as separator.
        return "(" + self.name.join(str(arg) for arg in self.args) + ")"
1712
1713
@dataclasses.dataclass
class TyEmptyRow(MonoType):
    """Terminator of a closed row type: a record with no (more) fields."""

    def __str__(self) -> str:
        return "{}"
1718
1719
@dataclasses.dataclass
class TyRow(MonoType):
    """A record type as a row: known `fields` plus a `rest`, which is either
    an unbound TyVar (open row: more fields may exist) or TyEmptyRow (closed
    row: exactly these fields)."""

    fields: dict[str, MonoType]
    rest: TyVar | TyEmptyRow = dataclasses.field(default_factory=TyEmptyRow)

    def __post_init__(self) -> None:
        # A row with no fields and a closed rest carries no information; the
        # empty record is represented by TyEmptyRow itself.
        if not self.fields and isinstance(self.rest, TyEmptyRow):
            raise InferenceError("Empty row must have a rest type")

    def __str__(self) -> str:
        flat, rest = row_flatten(self)
        # sort to make tests deterministic
        result = [f"{key}={val}" for key, val in sorted(flat.items())]
        if isinstance(rest, TyVar):
            result.append(f"...{rest}")
        else:
            assert isinstance(rest, TyEmptyRow)
        return "{" + ", ".join(result) + "}"
1738
1739
def row_flatten(rec: MonoType) -> tuple[dict[str, MonoType], TyVar | TyEmptyRow]:
    """Collapse a (possibly nested) row type into (fields, rest).

    `rest` is an unbound TyVar for an open row or TyEmptyRow for a closed one.
    """
    if isinstance(rec, TyVar):
        resolved = rec.find()
        if isinstance(resolved, TyVar):
            # Still unbound: an open row with no known fields.
            return {}, resolved
        rec = resolved
    if isinstance(rec, TyEmptyRow):
        return {}, rec
    if isinstance(rec, TyRow):
        # Inner fields first, then overlay this level's fields.
        fields, rest = row_flatten(rec.rest)
        fields.update(rec.fields)
        return fields, rest
    raise InferenceError(f"Expected record type, got {type(rec)}")
1752
1753
@dataclasses.dataclass
class Forall:
    """A polymorphic type scheme: `ty` quantified over `tyvars`."""

    tyvars: list[TyVar]
    ty: MonoType

    def __str__(self) -> str:
        return f"(forall {', '.join(map(str, self.tyvars))}. {self.ty})"
1761
1762
def func_type(*args: MonoType) -> TyCon:
    """Build a right-associative function type:
    func_type(a, b, c) == a -> (b -> c)."""
    assert len(args) >= 2
    # Fold from the right: start with the innermost arrow and wrap outward.
    result = TyCon("->", [args[-2], args[-1]])
    for arg in reversed(args[:-2]):
        result = TyCon("->", [arg, result])
    return result
1768
1769
def list_type(arg: MonoType) -> TyCon:
    """Construct the type of a homogeneous list of `arg`."""
    return TyCon("list", [arg])
1772
1773
def unify_fail(ty1: MonoType, ty2: MonoType) -> None:
    """Report a unification mismatch; always raises InferenceError."""
    raise InferenceError(f"Unification failed for {ty1} and {ty2}")
1776
1777
1778def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:
1779 if isinstance(ty, TyVar):
1780 return tyvar == ty
1781 if isinstance(ty, TyCon):
1782 return any(occurs_in(tyvar, arg) for arg in ty.args)
1783 if isinstance(ty, TyEmptyRow):
1784 return False
1785 if isinstance(ty, TyRow):
1786 return any(occurs_in(tyvar, val) for val in ty.fields.values()) or occurs_in(tyvar, ty.rest)
1787 raise InferenceError(f"Unknown type: {ty}")
1788
1789
def unify_type(ty1: MonoType, ty2: MonoType) -> None:
    """Destructively unify `ty1` and `ty2` (side effects on type variables via
    make_equal_to), raising InferenceError if they cannot be made equal."""
    ty1 = ty1.find()
    ty2 = ty2.find()
    if isinstance(ty1, TyVar):
        if occurs_in(ty1, ty2):
            raise InferenceError(f"Occurs check failed for {ty1} and {ty2}")
        ty1.make_equal_to(ty2)
        return
    if isinstance(ty2, TyVar):  # Mirror
        return unify_type(ty2, ty1)
    if isinstance(ty1, TyCon) and isinstance(ty2, TyCon):
        # Constructors unify pointwise: same name, same arity, unified args.
        if ty1.name != ty2.name:
            unify_fail(ty1, ty2)
            return
        if len(ty1.args) != len(ty2.args):
            unify_fail(ty1, ty2)
            return
        for l, r in zip(ty1.args, ty2.args):
            unify_type(l, r)
        return
    if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyEmptyRow):
        return
    if isinstance(ty1, TyRow) and isinstance(ty2, TyRow):
        ty1_fields, ty1_rest = row_flatten(ty1)
        ty2_fields, ty2_rest = row_flatten(ty2)
        ty1_missing = {}
        ty2_missing = {}
        all_field_names = set(ty1_fields.keys()) | set(ty2_fields.keys())
        for key in sorted(all_field_names):  # Sort for deterministic error messages
            ty1_val = ty1_fields.get(key)
            ty2_val = ty2_fields.get(key)
            if ty1_val is not None and ty2_val is not None:
                unify_type(ty1_val, ty2_val)
            elif ty1_val is None:
                assert ty2_val is not None
                ty1_missing[key] = ty2_val
            elif ty2_val is None:
                assert ty1_val is not None
                ty2_missing[key] = ty1_val
        # In general, we want to:
        # 1) Add missing fields from one row to the other row
        # 2) "Keep the rows unified" by linking each row's rest to the other
        #    row's rest
        if not ty1_missing and not ty2_missing:
            # The rests are either both empty (rows were closed) or both
            # unbound type variables (rows were open); unify the rest variables
            unify_type(ty1_rest, ty2_rest)
            return
        if not ty1_missing:
            # The first row has fields that the second row doesn't have; add
            # them to the second row
            unify_type(ty2_rest, TyRow(ty2_missing, ty1_rest))
            return
        if not ty2_missing:
            # The second row has fields that the first row doesn't have; add
            # them to the first row
            unify_type(ty1_rest, TyRow(ty1_missing, ty2_rest))
            return
        # They each have fields the other lacks; create new rows sharing a rest
        # and add the missing fields to each row
        rest = fresh_tyvar()
        unify_type(ty1_rest, TyRow(ty1_missing, rest))
        unify_type(ty2_rest, TyRow(ty2_missing, rest))
        return
    if isinstance(ty1, TyRow) and isinstance(ty2, TyEmptyRow):
        raise InferenceError(f"Unifying row {ty1} with empty row")
    if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyRow):
        raise InferenceError(f"Unifying empty row with row {ty2}")
    raise InferenceError(f"Cannot unify {ty1} and {ty2}")
1859
1860
# A typing context maps program variable names to their type schemes.
Context = typing.Mapping[str, Forall]


# Monotonically increasing counter backing fresh_tyvar(); reset via
# reset_tyvar_counter().
fresh_var_counter = 0
1865
1866
def fresh_tyvar(prefix: str = "t") -> TyVar:
    """Mint a new type variable with a globally unique name ('t0, 't1, ...)."""
    global fresh_var_counter
    name = f"{prefix}{fresh_var_counter}"
    fresh_var_counter += 1
    return TyVar(name)
1872
1873
def reset_tyvar_counter() -> None:
    """Reset the fresh-tyvar counter (useful for deterministic tests)."""
    global fresh_var_counter
    fresh_var_counter = 0
1877
1878
# Nullary type constructors for the built-in literal types.
IntType = TyCon("int", [])
StringType = TyCon("string", [])
FloatType = TyCon("float", [])
BytesType = TyCon("bytes", [])
HoleType = TyCon("hole", [])


# A substitution maps type-variable names to the types replacing them.
Subst = typing.Mapping[str, MonoType]
1887
1888
def apply_ty(ty: MonoType, subst: Subst) -> MonoType:
    """Rebuild `ty` with every type variable replaced per `subst` (by name);
    variables absent from `subst` are kept as-is."""
    ty = ty.find()
    if isinstance(ty, TyVar):
        return subst.get(ty.name, ty)
    if isinstance(ty, TyCon):
        new_args = [apply_ty(arg, subst) for arg in ty.args]
        return TyCon(ty.name, new_args)
    if isinstance(ty, TyEmptyRow):
        return ty
    if isinstance(ty, TyRow):
        new_rest = apply_ty(ty.rest, subst)
        assert isinstance(new_rest, (TyVar, TyEmptyRow))
        new_fields = {name: apply_ty(val, subst) for name, val in ty.fields.items()}
        return TyRow(new_fields, new_rest)
    raise InferenceError(f"Unknown type: {ty}")
1902
1903
def instantiate(scheme: Forall) -> MonoType:
    """Replace a scheme's quantified variables with fresh ones, yielding a
    monotype ready for unification."""
    fresh = {tyvar.name: fresh_tyvar() for tyvar in scheme.tyvars}
    return apply_ty(scheme.ty, fresh)
1907
1908
def ftv_ty(ty: MonoType) -> set[str]:
    """Return the names of the free type variables occurring in `ty`."""
    ty = ty.find()
    if isinstance(ty, TyVar):
        return {ty.name}
    if isinstance(ty, TyCon):
        return {name for arg in ty.args for name in ftv_ty(arg)}
    if isinstance(ty, TyEmptyRow):
        return set()
    if isinstance(ty, TyRow):
        field_vars = {name for val in ty.fields.values() for name in ftv_ty(val)}
        return field_vars | ftv_ty(ty.rest)
    raise InferenceError(f"Unknown type: {ty}")
1920
1921
def generalize(ty: MonoType, ctx: Context) -> Forall:
    """Quantify over the variables free in `ty` but not free in `ctx`."""

    def scheme_ftv(scheme: Forall) -> set[str]:
        # Free variables of a scheme exclude its own quantified variables.
        bound = {tyvar.name for tyvar in scheme.tyvars}
        return ftv_ty(scheme.ty) - bound

    ctx_ftv: set[str] = set()
    for scheme in ctx.values():
        ctx_ftv |= scheme_ftv(scheme)
    # TODO(max): Freshen?
    quantified = sorted(ftv_ty(ty) - ctx_ftv)
    return Forall([TyVar(name) for name in quantified], ty)
1932
1933
def type_of(expr: Object) -> MonoType:
    """Return the type previously attached to `expr`, resolving forwarding;
    attach and return a fresh type variable if none is set yet."""
    existing = getattr(expr, "inferred_type", None)
    if existing is None:
        return set_type(expr, fresh_tyvar())
    assert isinstance(existing, MonoType)
    return existing.find()
1940
1941
def set_type(expr: Object, ty: MonoType) -> MonoType:
    """Attach `ty` to `expr` as its `inferred_type` and return `ty`.

    # NOTE(review): object.__setattr__ bypasses any custom __setattr__ —
    # presumably because Object subclasses are frozen dataclasses; confirm.
    """
    object.__setattr__(expr, "inferred_type", ty)
    return ty
1945
1946
def infer_common(expr: Object) -> MonoType:
    """Assign the monomorphic type of a literal; shared by expression and
    pattern inference."""
    literal_types = (
        (Int, IntType),
        (Float, FloatType),
        (Bytes, BytesType),
        (Hole, HoleType),
        (String, StringType),
    )
    for cls, ty in literal_types:
        if isinstance(expr, cls):
            return set_type(expr, ty)
    raise InferenceError(f"{type(expr)} can't be simply inferred")
1958 raise InferenceError(f"{type(expr)} can't be simply inferred")
1959
1960
1961def infer_pattern_type(pattern: Object, ctx: Context) -> MonoType:
1962 assert isinstance(ctx, dict)
1963 if isinstance(pattern, (Int, Float, Bytes, Hole, String)):
1964 return infer_common(pattern)
1965 if isinstance(pattern, Var):
1966 result = fresh_tyvar()
1967 ctx[pattern.name] = Forall([], result)
1968 return set_type(pattern, result)
1969 if isinstance(pattern, List):
1970 list_item_ty = fresh_tyvar()
1971 result_ty = list_type(list_item_ty)
1972 for item in pattern.items:
1973 if isinstance(item, Spread):
1974 if item.name is not None:
1975 ctx[item.name] = Forall([], result_ty)
1976 break
1977 item_ty = infer_pattern_type(item, ctx)
1978 unify_type(list_item_ty, item_ty)
1979 return set_type(pattern, result_ty)
1980 if isinstance(pattern, Record):
1981 fields = {}
1982 rest: TyVar | TyEmptyRow = TyEmptyRow() # Default closed row
1983 for key, value in pattern.data.items():
1984 if isinstance(value, Spread):
1985 # Open row
1986 rest = fresh_tyvar()
1987 if value.name is not None:
1988 ctx[value.name] = Forall([], rest)
1989 break
1990 fields[key] = infer_pattern_type(value, ctx)
1991 return set_type(pattern, TyRow(fields, rest))
1992 raise InferenceError(f"{type(pattern)} isn't allowed in a pattern")
1993
1994
1995def infer_type(expr: Object, ctx: Context) -> MonoType:
1996 if isinstance(expr, (Int, Float, Bytes, Hole, String)):
1997 return infer_common(expr)
1998 if isinstance(expr, Var):
1999 scheme = ctx.get(expr.name)
2000 if scheme is None:
2001 raise InferenceError(f"Unbound variable {expr.name}")
2002 return set_type(expr, instantiate(scheme))
2003 if isinstance(expr, Function):
2004 arg_tyvar = fresh_tyvar()
2005 assert isinstance(expr.arg, Var)
2006 body_ctx = {**ctx, expr.arg.name: Forall([], arg_tyvar)}
2007 body_ty = infer_type(expr.body, body_ctx)
2008 return set_type(expr, func_type(arg_tyvar, body_ty))
2009 if isinstance(expr, Binop):
2010 left, right = expr.left, expr.right
2011 op = Var(BinopKind.to_str(expr.op))
2012 return set_type(expr, infer_type(Apply(Apply(op, left), right), ctx))
2013 if isinstance(expr, Where):
2014 assert isinstance(expr.binding, Assign)
2015 name, value, body = expr.binding.name.name, expr.binding.value, expr.body
2016 if isinstance(value, (Function, MatchFunction)):
2017 # Letrec
2018 func_ty: MonoType = fresh_tyvar()
2019 value_ty = infer_type(value, {**ctx, name: Forall([], func_ty)})
2020 else:
2021 # Let
2022 value_ty = infer_type(value, ctx)
2023 value_scheme = generalize(value_ty, ctx)
2024 body_ty = infer_type(body, {**ctx, name: value_scheme})
2025 return set_type(expr, body_ty)
2026 if isinstance(expr, List):
2027 list_item_ty = fresh_tyvar()
2028 for item in expr.items:
2029 assert not isinstance(item, Spread), "Spread can only occur in list match (for now)"
2030 item_ty = infer_type(item, ctx)
2031 unify_type(list_item_ty, item_ty)
2032 return set_type(expr, list_type(list_item_ty))
2033 if isinstance(expr, MatchCase):
2034 pattern_ctx: Context = {}
2035 pattern_ty = infer_pattern_type(expr.pattern, pattern_ctx)
2036 body_ty = infer_type(expr.body, {**ctx, **pattern_ctx})
2037 return set_type(expr, func_type(pattern_ty, body_ty))
2038 if isinstance(expr, Apply):
2039 func_ty = infer_type(expr.func, ctx)
2040 arg_ty = infer_type(expr.arg, ctx)
2041 result = fresh_tyvar()
2042 unify_type(func_ty, func_type(arg_ty, result))
2043 return set_type(expr, result)
2044 if isinstance(expr, MatchFunction):
2045 result = fresh_tyvar()
2046 for case in expr.cases:
2047 case_ty = infer_type(case, ctx)
2048 unify_type(result, case_ty)
2049 return set_type(expr, result)
2050 if isinstance(expr, Record):
2051 fields = {}
2052 rest: TyVar | TyEmptyRow = TyEmptyRow()
2053 for key, value in expr.data.items():
2054 assert not isinstance(value, Spread), "Spread can only occur in record match (for now)"
2055 fields[key] = infer_type(value, ctx)
2056 return set_type(expr, TyRow(fields, rest))
2057 if isinstance(expr, Access):
2058 obj_ty = infer_type(expr.obj, ctx)
2059 value_ty = fresh_tyvar()
2060 assert isinstance(expr.at, Var)
2061 # "has field" constraint in the form of an open row
2062 unify_type(obj_ty, TyRow({expr.at.name: value_ty}, fresh_tyvar()))
2063 return value_ty
2064 raise InferenceError(f"Unexpected type {type(expr)}")
2065
2066
def minimize(ty: MonoType) -> MonoType:
    """Rename the free type variables of `ty` to 'a, 'b, ... for display.

    Names are assigned in sorted order of the original variable names.
    After 'z the sequence continues 'aa, 'ab, ... so types with more than
    26 free variables no longer raise StopIteration.
    """

    def _name(index: int) -> str:
        # Bijective base-26: 0 -> "a", 25 -> "z", 26 -> "aa", 27 -> "ab", ...
        chars = ""
        while True:
            chars = chr(ord("a") + index % 26) + chars
            index = index // 26 - 1
            if index < 0:
                return chars

    free = ftv_ty(ty)
    subst = {ftv: TyVar(_name(i)) for i, ftv in enumerate(sorted(free))}
    return apply_ty(ty, subst)
2072
2073
# Precedence values passed around by the pretty-printer (`Repr`/`pretty`).
Number = typing.Union[int, float]
2075
2076
class Repr(typing.Protocol):
    """Structural type for printer callables: take an Object and an optional
    enclosing-operator precedence, return the rendered string."""

    def __call__(self, obj: Object, prec: Number = 0) -> str: ...
2079
2080
2081# Can't use reprlib.recursive_repr because it doesn't work if the print
2082# function has more than one argument (for example, prec)
2083def handle_recursion(func: Repr) -> Repr:
2084 cache: typing.List[Object] = []
2085
2086 @functools.wraps(func)
2087 def wrapper(obj: Object, prec: Number = 0) -> str:
2088 for cached in cache:
2089 if obj is cached:
2090 return "..."
2091 cache.append(obj)
2092 result = func(obj, prec)
2093 cache.remove(obj)
2094 return result
2095
2096 return wrapper
2097
2098
@handle_recursion
def pretty(obj: Object, prec: Number = 0) -> str:
    """Render *obj* back into scrapscript surface syntax.

    *prec* is the binding power of the enclosing operator; operator forms
    that bind no tighter than their context get parenthesized by the shared
    check at the bottom. Recursion through self-referential objects prints
    "..." via the handle_recursion decorator.
    """
    # Atomic forms: rendered directly and returned early, never parenthesized.
    if isinstance(obj, Int):
        return str(obj.value)
    if isinstance(obj, Float):
        return str(obj.value)
    if isinstance(obj, String):
        # json.dumps produces a double-quoted, escaped string literal.
        return json.dumps(obj.value)
    if isinstance(obj, Bytes):
        return f"~~{base64.b64encode(obj.value).decode()}"
    if isinstance(obj, Var):
        return obj.name
    if isinstance(obj, Hole):
        return "()"
    if isinstance(obj, Spread):
        return f"...{obj.name}" if obj.name else "..."
    if isinstance(obj, List):
        return f"[{', '.join(pretty(item) for item in obj.items)}]"
    if isinstance(obj, Record):
        return f"{{{', '.join(f'{key} = {pretty(value)}' for key, value in obj.data.items())}}}"
    if isinstance(obj, Closure):
        keys = list(obj.env.keys())
        return f"Closure({keys}, {pretty(obj.func)})"
    if isinstance(obj, EnvObject):
        return f"EnvObject({repr(obj.env)})"
    if isinstance(obj, NativeFunction):
        return f"NativeFunction(name={obj.name})"
    if isinstance(obj, Relocation):
        return f"Relocation(name={repr(obj.name)})"
    # Operator forms: exactly one branch below sets op_prec and result, which
    # feed the parenthesization check at the end. NOTE(review): an object
    # matching none of the branches in this function would leave op_prec
    # unbound and raise UnboundLocalError at the final check.
    if isinstance(obj, Variant):
        op_prec = PS["#"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"#{obj.tag} {pretty(obj.value, right_prec)}"
    if isinstance(obj, Assign):
        op_prec = PS["="]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.name, left_prec)} = {pretty(obj.value, right_prec)}"
    if isinstance(obj, Binop):
        op_prec = PS[BinopKind.to_str(obj.op)]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.left, left_prec)} {BinopKind.to_str(obj.op)} {pretty(obj.right, right_prec)}"
    if isinstance(obj, Function):
        op_prec = PS["->"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        assert isinstance(obj.arg, Var)
        result = f"{obj.arg.name} -> {pretty(obj.body, right_prec)}"
    if isinstance(obj, MatchFunction):
        op_prec = PS["|"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = "\n".join(
            f"| {pretty(case.pattern, left_prec)} -> {pretty(case.body, right_prec)}" for case in obj.cases
        )
    if isinstance(obj, Where):
        op_prec = PS["."]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.body, left_prec)} . {pretty(obj.binding, right_prec)}"
    if isinstance(obj, Assert):
        op_prec = PS["!"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.value, left_prec)} ! {pretty(obj.cond, right_prec)}"
    if isinstance(obj, Apply):
        # Application is the "empty" operator in the precedence table.
        op_prec = PS[""]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.func, left_prec)} {pretty(obj.arg, right_prec)}"
    if isinstance(obj, Access):
        op_prec = PS["@"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.obj, left_prec)} @ {pretty(obj.at, right_prec)}"
    # Parenthesize when the surrounding context binds at least as tightly as
    # this operator does.
    if prec >= op_prec.pl:
        return f"({result})"
    return result
2170
2171
def fetch(url: Object) -> Object:
    """Download *url* (a scrapscript String) and return the response body as a
    UTF-8-decoded String.

    Raises TypeError if *url* is not a String.
    """
    if not isinstance(url, String):
        raise TypeError(f"fetch expected String, but got {type(url).__name__}")
    with urllib.request.urlopen(url.value) as response:
        body = response.read()
    return String(body.decode("utf-8"))
2177
2178
def make_object(pyobj: object) -> Object:
    """Convert a Python value (as produced by json.loads) into a scrapscript
    Object.

    Supports int, float, str, list, and dict (string keys). Raises
    NotImplementedError for anything else (e.g. None, bools aside -- see note).
    """
    assert not isinstance(pyobj, Object)
    # NOTE(review): bool is a subclass of int, so True/False fall into the int
    # branch and become Int(True)/Int(False) -- confirm that is the intended
    # mapping for JSON booleans.
    if isinstance(pyobj, int):
        return Int(pyobj)
    if isinstance(pyobj, float):
        # JSON numbers with a fractional part decode to float; map them to
        # Float so jsondecode doesn't crash on e.g. "1.5".
        return Float(pyobj)
    if isinstance(pyobj, str):
        return String(pyobj)
    if isinstance(pyobj, list):
        return List([make_object(o) for o in pyobj])
    if isinstance(pyobj, dict):
        # Assumed to only be called with JSON, so string keys.
        return Record({key: make_object(value) for key, value in pyobj.items()})
    raise NotImplementedError(type(pyobj))
2191
2192
def jsondecode(obj: Object) -> Object:
    """Parse a scrapscript String containing JSON into a scrapscript Object.

    Raises TypeError if *obj* is not a String.
    """
    if not isinstance(obj, String):
        raise TypeError(f"jsondecode expected String, but got {type(obj).__name__}")
    return make_object(json.loads(obj.value))
2198
2199
def listlength(obj: Object) -> Object:
    """Return the length of a scrapscript List as an Int.

    Raises TypeError if *obj* is not a List.
    """
    # TODO(max): Implement in scrapscript once list pattern matching is
    # implemented.
    if isinstance(obj, List):
        return Int(len(obj.items))
    raise TypeError(f"listlength expected List, but got {type(obj).__name__}")
2206
2207
def serialize(obj: Object) -> bytes:
    """Flatten *obj* into its binary wire format."""
    writer = Serializer()
    writer.serialize(obj)
    return bytes(writer.output)
2212
2213
def deserialize(data: bytes) -> Object:
    """Parse an Object back out of its binary wire format."""
    return Deserializer(data).parse()
2217
2218
def deserialize_object(obj: Object) -> Object:
    """Native-function wrapper for deserialize: unwrap a Bytes and decode it.

    Raises TypeError if *obj* is not Bytes.
    """
    if not isinstance(obj, Bytes):
        # Raise like the other native functions (fetch, jsondecode,
        # listlength) instead of asserting: asserts are stripped under
        # `python -O` and an assert here would be validating user input.
        raise TypeError(f"deserialize expected Bytes, but got {type(obj).__name__}")
    return deserialize(obj.value)
2222
2223
# Native bindings available to every program. Names carry the "$$" prefix by
# convention; values are either Closures over scrapscript AST or
# NativeFunction wrappers around the Python helpers above.
STDLIB = {
    "$$add": Closure({}, Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, Var("x"), Var("y"))))),
    "$$fetch": NativeFunction("$$fetch", fetch),
    "$$jsondecode": NativeFunction("$$jsondecode", jsondecode),
    "$$serialize": NativeFunction("$$serialize", lambda obj: Bytes(serialize(obj))),
    "$$deserialize": NativeFunction("$$deserialize", deserialize_object),
    "$$listlength": NativeFunction("$$listlength", listlength),
}
2232
2233
# Scrapscript source for the standard prelude: a chain of `.` (where) bindings
# evaluated once at startup by boot_env(). Defines list helpers (quicksort,
# filter, concat, map, range, foldr, take, all, any). The string is runtime
# input to the parser and must not be reformatted.
PRELUDE = """
id = x -> x

. quicksort =
  | [] -> []
  | [p, ...xs] -> (concat ((quicksort (ltp xs p)) +< p) (quicksort (gtp xs p))
    . gtp = xs -> p -> filter (x -> x >= p) xs
    . ltp = xs -> p -> filter (x -> x < p) xs)

. filter = f ->
  | [] -> []
  | [x, ...xs] -> f x |> | #true () -> x >+ filter f xs
                         | #false () -> filter f xs

. concat = xs ->
  | [] -> xs
  | [y, ...ys] -> concat (xs +< y) ys

. map = f ->
  | [] -> []
  | [x, ...xs] -> f x >+ map f xs

. range =
  | 0 -> []
  | i -> range (i - 1) +< (i - 1)

. foldr = f -> a ->
  | [] -> a
  | [x, ...xs] -> f x (foldr f a xs)

. take =
  | 0 -> xs -> []
  | n ->
    | [] -> []
    | [x, ...xs] -> x >+ take (n - 1) xs

. all = f ->
  | [] -> #true ()
  | [x, ...xs] -> f x && all f xs

. any = f ->
  | [] -> #false ()
  | [x, ...xs] -> f x || any f xs
"""
2278
2279
def boot_env() -> Env:
    """Evaluate the PRELUDE against STDLIB and return the resulting environment."""
    result = eval_exp(STDLIB, parse(tokenize(PRELUDE)))
    assert isinstance(result, EnvObject)
    return result.env
2284
2285
class Completer:
    """Readline tab-completion over the names bound in a REPL environment.

    Implements the readline completer protocol: complete(text, state) is
    called with state = 0, 1, 2, ... and returns successive matches, then
    None when exhausted.
    """

    def __init__(self, env: Env) -> None:
        self.env: Env = env
        # Matches computed on the state == 0 call and replayed for state > 0.
        self.matches: typing.List[str] = []

    def complete(self, text: str, state: int) -> Optional[str]:
        assert "@" not in text, "TODO: handle attr/index access"
        if state == 0:
            # First call for this prefix: (re)compute the candidate list.
            names = sorted(self.env.keys())
            self.matches = [name for name in names if name.startswith(text)] if text else names
        try:
            return self.matches[state]
        except IndexError:
            return None
2303
2304
# Persistent REPL history location. expanduser only expands a leading "~", so
# the path must start with "~/" to resolve into the user's home directory; a
# bare ".scrap-history" would be a no-op and scatter history files across
# whatever directory the REPL was launched from.
REPL_HISTFILE = os.path.expanduser("~/.scrap-history")
2306
2307
class ScrapRepl(code.InteractiveConsole):
    """Interactive scrapscript console built on code.InteractiveConsole.

    Evaluation state lives in self.env: it starts as the boot (prelude)
    environment and grows as the user binds names. Results of non-binding
    expressions are stashed under "_".
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Mutable evaluation environment shared across REPL inputs.
        self.env: Env = boot_env()

    def enable_readline(self) -> None:
        """Load saved history and install tab completion (requires readline)."""
        assert readline, "Can't enable readline without readline module"
        if os.path.exists(REPL_HISTFILE):
            readline.read_history_file(REPL_HISTFILE)
        # what determines the end of a word; need to set so $ can be part of a
        # variable name
        readline.set_completer_delims(" \t\n;")
        # TODO(max): Add completion per scope, not just for global environment.
        readline.set_completer(Completer(self.env).complete)
        readline.parse_and_bind("set show-all-if-ambiguous on")
        readline.parse_and_bind("tab: menu-complete")

    def finish_readline(self) -> None:
        """Truncate and persist the session history back to REPL_HISTFILE."""
        assert readline, "Can't finish readline without readline module"
        histfile_size = 1000
        readline.set_history_length(histfile_size)
        readline.write_history_file(REPL_HISTFILE)

    def runsource(self, source: str, filename: str = "<input>", symbol: str = "single") -> bool:
        """Tokenize, parse, and evaluate one REPL input.

        Returns True to signal "input incomplete, keep reading" (the
        InteractiveConsole contract) and False once the input has been
        consumed, whether it succeeded or an error was reported.
        """
        try:
            tokens = tokenize(source)
            logger.debug("Tokens: %s", tokens)
            ast = parse(tokens)
            if isinstance(ast, MatchFunction) and not source.endswith("\n"):
                # User might be in the middle of typing a multi-line match...
                # wait for them to hit Enter once after the last case
                return True
            logger.debug("AST: %s", ast)
            result = eval_exp(self.env, ast)
            assert isinstance(self.env, dict) # for .update()/__setitem__
            if isinstance(result, EnvObject):
                # A binding result extends the REPL environment in place.
                self.env.update(result.env)
            else:
                # Anything else is printed and remembered as "_".
                self.env["_"] = result
                print(pretty(result))
        except UnexpectedEOFError:
            # Need to read more text
            return True
        except ParseError as e:
            print(f"Parse error: {e}", file=sys.stderr)
        except Exception as e:
            # Broad catch keeps the REPL alive on arbitrary evaluation errors.
            print(f"Error: {e}", file=sys.stderr)
        return False
2356
2357
def eval_command(args: argparse.Namespace) -> None:
    """CLI `eval`: evaluate a program read from a file and pretty-print it."""
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    source = args.program_file.read()
    tokens = tokenize(source)
    logger.debug("Tokens: %s", tokens)
    ast = parse(tokens)
    logger.debug("AST: %s", ast)
    print(pretty(eval_exp(boot_env(), ast)))
2369
2370
def check_command(args: argparse.Namespace) -> None:
    """CLI `check`: type-infer a program and print its minimized type."""
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    source = args.program_file.read()
    tokens = tokenize(source)
    logger.debug("Tokens: %s", tokens)
    ast = parse(tokens)
    logger.debug("AST: %s", ast)
    # Minimize renames type variables to a, b, c, ... for stable output.
    print(minimize(infer_type(ast, OP_ENV)))
2383
2384
def apply_command(args: argparse.Namespace) -> None:
    """CLI `apply`: evaluate a program given directly on the command line."""
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    tokens = tokenize(args.program)
    logger.debug("Tokens: %s", tokens)
    ast = parse(tokens)
    logger.debug("AST: %s", ast)
    print(pretty(eval_exp(boot_env(), ast)))
2395
2396
def repl_command(args: argparse.Namespace) -> None:
    """CLI `repl`: run the interactive console, with readline if available."""
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    repl = ScrapRepl()
    use_readline = readline is not None
    if use_readline:
        repl.enable_readline()
    repl.interact(banner="")
    if use_readline:
        repl.finish_readline()
2407
2408
def env_get_split(key: str, default: Optional[typing.List[str]] = None) -> typing.List[str]:
    """Read environment variable *key* and shell-split its value.

    Falls back to *default* when the variable is unset or empty, and to an
    empty list when *default* is also falsy.
    """
    import shlex

    value = os.environ.get(key)
    if value:
        return shlex.split(value)
    return default if default else []
2418
2419
def discover_cflags(cc: typing.List[str], debug: bool = True) -> typing.List[str]:
    """Compute default C compiler flags for *cc*, honoring a CFLAGS override.

    The CFLAGS environment variable, when set, replaces the computed defaults
    entirely (via env_get_split).
    """
    # -fno-strict-aliasing is needed because we do pointer casting a bunch
    # -Wno-unused-function is needed because we have a bunch of unused
    # functions depending on what code is compiled
    flags = ["-Wall", "-Wextra", "-fno-strict-aliasing", "-Wno-unused-function"]
    flags += ["-O0", "-ggdb"] if debug else ["-O2", "-DNDEBUG"]
    compiler = cc[0]
    if "cosmo" not in compiler:
        # cosmocc does not support LTO
        flags.append("-flto")
    if "mingw" in compiler:
        # Windows does not support mmap
        flags.append("-DSTATIC_HEAP")
    return env_get_split("CFLAGS", flags)
2436
2437
# Type schemes for the built-in binary operators, used as the initial typing
# context by `check`/`--check`. Forall([], ...) is a monomorphic scheme; the
# list cons/snoc operators are polymorphic in the element type "a".
OP_ENV = {
    "+": Forall([], func_type(IntType, IntType, IntType)),
    "-": Forall([], func_type(IntType, IntType, IntType)),
    "*": Forall([], func_type(IntType, IntType, IntType)),
    "/": Forall([], func_type(IntType, IntType, FloatType)),
    "++": Forall([], func_type(StringType, StringType, StringType)),
    ">+": Forall([TyVar("a")], func_type(TyVar("a"), list_type(TyVar("a")), list_type(TyVar("a")))),
    "+<": Forall([TyVar("a")], func_type(list_type(TyVar("a")), TyVar("a"), list_type(TyVar("a")))),
}
2447
2448
def compile_command(args: argparse.Namespace) -> None:
    """CLI `compile`: compile a scrapscript file to C and optionally build/run it.

    --run implies --compile. --check runs type inference first. The emitted C
    file is the compiled program followed by the platform shim, optionally
    clang-formatted, then built with $CC (default clang) and executed.
    """
    if args.run:
        args.compile = True
    from compiler import compile_to_string

    with open(args.file, "r") as f:
        source = f.read()

    program = parse(tokenize(source))
    if args.check:
        # Fail early on type errors before generating any C.
        infer_type(program, OP_ENV)
    c_program = compile_to_string(program, args.debug)

    with open(args.platform, "r") as f:
        platform = f.read()

    # Output file = generated program + platform shim, concatenated.
    with open(args.output, "w") as f:
        f.write(c_program)
        f.write(platform)

    if args.format:
        import subprocess

        subprocess.run(["clang-format-15", "-i", args.output], check=True)

    if args.compile:
        import subprocess

        cc = env_get_split("CC", ["clang"])
        cflags = discover_cflags(cc, args.debug)
        if args.memory:
            cflags += [f"-DMEMORY_SIZE={args.memory}"]
        if args.handle_stack_size:
            cflags += [f"-DHANDLE_STACK_SIZE={args.handle_stack_size}"]
        ldflags = env_get_split("LDFLAGS")
        subprocess.run([*cc, "-o", "a.out", *cflags, args.output, *ldflags], check=True)

    if args.run:
        import subprocess

        subprocess.run(["sh", "-c", "./a.out"], check=True)
2490
2491
def flat_command(args: argparse.Namespace) -> None:
    """CLI `flat`: read a program from stdin and write its serialized form to
    stdout (as binary).

    *args* is accepted for signature parity with the other subcommand
    handlers but is unused.
    """
    prog = parse(tokenize(sys.stdin.read()))
    # Reuse the serialize() helper instead of duplicating the Serializer dance.
    sys.stdout.buffer.write(serialize(prog))
2497
2498
def server_command(args: argparse.Namespace) -> None:
    """CLI `yard server`: serve pre-serialized scraps from a directory over HTTP.

    Every parseable file under args.directory is serialized once at startup
    and exposed under both its extension-less relative path and a
    "$<sha256>" content-hash path. Files whose names start with "$" are
    skipped so they cannot shadow hash-addressed entries. A QUIT request
    stops the server.
    """
    import http.server
    import socketserver
    import hashlib

    # NOTE(review): `dir` shadows the builtin of the same name (local scope only).
    dir = os.path.abspath(args.directory)
    if not os.path.isdir(dir):
        print(f"Error: {dir} is not a valid directory")
        sys.exit(1)

    # Request path -> serialized program bytes.
    scraps = {}
    for root, _, files in os.walk(dir):
        for file in files:
            file_path = os.path.join(root, file)
            rel_path = os.path.relpath(file_path, dir)
            if file.startswith("$"):
                logger.debug(f"Skipping {rel_path}")
                continue
            rel_path_without_ext = os.path.splitext(rel_path)[0]
            with open(file_path, "r") as f:
                try:
                    program = parse(tokenize(f.read()))
                    serializer = Serializer()
                    serializer.serialize(program)
                    serialized = bytes(serializer.output)
                    scraps[rel_path_without_ext] = serialized
                    logger.debug(f"Loaded {rel_path_without_ext}")
                    # Also register under the content hash for hash-addressed fetches.
                    file_hash = hashlib.sha256(serialized).hexdigest()
                    scraps[f"${file_hash}"] = serialized
                    logger.debug(f"Loaded {rel_path_without_ext} as ${file_hash}")
                except Exception as e:
                    # Best-effort loading: a bad file is logged and skipped.
                    logger.error(f"Error processing {file_path}: {e}")

    # Flipped to False by do_QUIT to end the serve loop below.
    keep_serving = True

    class ScrapHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
        def do_QUIT(self) -> None:
            # Out-of-band shutdown: acknowledge, then stop the outer loop.
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b"Quitting")
            nonlocal keep_serving
            keep_serving = False

        def do_GET(self) -> None:
            path = self.path.lstrip("/")
            scrap = scraps.get(path)
            if scrap is not None:
                self.send_response(200)
                self.send_header("Content-Type", "application/scrap; charset=binary")
                self.send_header("Content-Disposition", f"attachment; filename={json.dumps(f'{path}.scrap')}")
                self.send_header("Content-Length", str(len(scrap)))
                self.end_headers()
                self.wfile.write(scrap)
            else:
                self.send_response(404)
                self.send_header("Content-Type", "text/plain")
                self.end_headers()
                self.wfile.write(b"File not found")

    handler = ScrapHTTPRequestHandler
    with socketserver.TCPServer((args.host, args.port), handler) as httpd:
        logger.info(f"Serving {dir} at http://{args.host}:{args.port}")
        # handle_request (not serve_forever) so do_QUIT can break the loop
        # between requests.
        while keep_serving:
            httpd.handle_request()
2563
2564
def main() -> None:
    """Parse command-line arguments and dispatch to the chosen subcommand.

    With no subcommand, drops straight into the REPL. Each subparser stores
    its handler in `func`, which is called with the parsed namespace.
    """
    parser = argparse.ArgumentParser(prog="scrapscript")
    subparsers = parser.add_subparsers(dest="command")

    repl = subparsers.add_parser("repl")
    repl.set_defaults(func=repl_command)
    repl.add_argument("--debug", action="store_true")

    eval_ = subparsers.add_parser("eval")
    eval_.set_defaults(func=eval_command)
    eval_.add_argument("program_file", type=argparse.FileType("r"))
    eval_.add_argument("--debug", action="store_true")

    check = subparsers.add_parser("check")
    check.set_defaults(func=check_command)
    check.add_argument("program_file", type=argparse.FileType("r"))
    check.add_argument("--debug", action="store_true")

    apply = subparsers.add_parser("apply")
    apply.set_defaults(func=apply_command)
    apply.add_argument("program")
    apply.add_argument("--debug", action="store_true")

    comp = subparsers.add_parser("compile")
    comp.set_defaults(func=compile_command)
    comp.add_argument("file")
    comp.add_argument("-o", "--output", default="output.c")
    comp.add_argument("--format", action="store_true")
    comp.add_argument("--compile", action="store_true")
    comp.add_argument("--memory", type=int)
    comp.add_argument("--handle-stack-size", type=int)
    comp.add_argument("--run", action="store_true")
    comp.add_argument("--debug", action="store_true", default=False)
    comp.add_argument("--check", action="store_true", default=False)
    # The platform is in the same directory as this file
    comp.add_argument("--platform", default=os.path.join(os.path.dirname(__file__), "cli.c"))

    flat = subparsers.add_parser("flat")
    flat.set_defaults(func=flat_command)

    # "yard" with no sub-subcommand just prints its own help.
    yard = subparsers.add_parser("yard")
    yard.set_defaults(func=lambda _: yard.print_help())
    yard_subparsers = yard.add_subparsers(dest="yard_command")

    yard_server = yard_subparsers.add_parser("server")
    yard_server.set_defaults(func=server_command)
    yard_server.add_argument("directory", type=str, nargs="?", default=".", help="Directory to serve")
    yard_server.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind to")
    yard_server.add_argument("--port", type=int, default=8080, help="Port to listen on")

    args = parser.parse_args()
    if not args.command:
        # repl_command reads args.debug, which only subparsers define.
        args.debug = False
        repl_command(args)
    else:
        args.func(args)
2621
2622
if __name__ == "__main__":
    # This is so that we can use scrapscript.py as a main but also import
    # things from `scrapscript` and not have that be a separate module.
    sys.modules["scrapscript"] = sys.modules[__name__]
    main()