this repo has no description
at trunk 2627 lines 88 kB view raw
1#!/usr/bin/env python3.10 2from __future__ import annotations 3import argparse 4import base64 5import code 6import copy 7import dataclasses 8import enum 9import functools 10import json 11import logging 12import os 13import struct 14import sys 15import typing 16import urllib.request 17from dataclasses import dataclass 18from enum import auto 19from types import ModuleType 20from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, Set, Tuple, Union 21 22readline: Optional[ModuleType] 23try: 24 import readline 25except ImportError: 26 readline = None 27 28 29logger = logging.getLogger(__name__) 30 31 32def is_identifier_char(c: str) -> bool: 33 return c.isalnum() or c in ("$", "'", "_") 34 35 36@dataclass(eq=True, order=True, unsafe_hash=True) 37class SourceLocation: 38 lineno: int = dataclasses.field(default=-1) 39 colno: int = dataclasses.field(default=-1) 40 byteno: int = dataclasses.field(default=-1) 41 42 43@dataclass(eq=True, unsafe_hash=True) 44class SourceExtent: 45 start: SourceLocation = dataclasses.field(default_factory=SourceLocation) 46 end: SourceLocation = dataclasses.field(default_factory=SourceLocation) 47 48 def coalesce(self, other: Optional[SourceExtent]) -> Optional[SourceExtent]: 49 return SourceExtent(min(self.start, other.start), max(self.end, other.end)) if other else None 50 51 52def join_source_extents( 53 source_extent_one: Optional[SourceExtent], source_extent_two: Optional[SourceExtent] 54) -> Optional[SourceExtent]: 55 return source_extent_one.coalesce(source_extent_two) if source_extent_one else None 56 57 58@dataclass(eq=True) 59class Token: 60 source_extent: SourceExtent = dataclasses.field(default_factory=SourceExtent, init=False, compare=False) 61 62 def with_source(self, source_extent: SourceExtent) -> Token: 63 self.source_extent = source_extent 64 return self 65 66 67@dataclass(eq=True) 68class IntLit(Token): 69 value: int 70 71 72@dataclass(eq=True) 73class FloatLit(Token): 74 value: float 75 76 
@dataclass(eq=True)
class StringLit(Token):
    value: str


@dataclass(eq=True)
class BytesLit(Token):
    # Still-encoded payload text; `base` selects the base16/32/64/85 decoder in parse_unary.
    value: str
    base: int


@dataclass(eq=True)
class Operator(Token):
    value: str


@dataclass(eq=True)
class Name(Token):
    value: str


@dataclass(eq=True)
class LeftParen(Token):
    # (
    pass


@dataclass(eq=True)
class RightParen(Token):
    # )
    pass


@dataclass(eq=True)
class LeftBrace(Token):
    # {
    pass


@dataclass(eq=True)
class RightBrace(Token):
    # }
    pass


@dataclass(eq=True)
class LeftBracket(Token):
    # [
    pass


@dataclass(eq=True)
class RightBracket(Token):
    # ]
    pass


@dataclass(eq=True)
class Hash(Token):
    # #
    pass


@dataclass(eq=True)
class EOF(Token):
    pass


def num_bytes_as_utf8(s: str) -> int:
    """Return the byte length of `s` when UTF-8 encoded (for byte offsets)."""
    return len(s.encode(encoding="UTF-8"))


class Lexer:
    """Hand-written scanner that turns program text into Token objects.

    Tracks line/column/byte positions as it consumes characters so every
    token can be tagged with an accurate SourceExtent.
    """

    def __init__(self, text: str):
        self.text: str = text
        self.idx: int = 0
        self._lineno: int = 1
        self._colno: int = 1
        self.line: str = ""  # text of the current line consumed so far
        self._byteno: int = 0
        # Mutable scratch extent for the token being scanned; make_token
        # deep-copies it so later scanning cannot corrupt emitted tokens.
        self.current_token_source_extent: SourceExtent = SourceExtent(
            start=SourceLocation(
                lineno=self._lineno,
                colno=self._colno,
                byteno=self._byteno,
            ),
            end=SourceLocation(
                lineno=self._lineno,
                colno=self._colno,
                byteno=self._byteno,
            ),
        )
        self.token_start_idx: int = self.idx
        self.token_end_idx: int = self.token_start_idx

    @property
    def lineno(self) -> int:
        return self._lineno

    @property
    def colno(self) -> int:
        return self._colno

    @property
    def byteno(self) -> int:
        return self._byteno

    def mark_token_start(self) -> None:
        # Snapshot the current position as the start of the token in progress.
        self.current_token_source_extent.start.lineno = self._lineno
        self.current_token_source_extent.start.colno = self._colno
        self.current_token_source_extent.start.byteno = self._byteno
        self.token_start_idx = self.idx

    def mark_token_end(self) -> None:
        # Snapshot the current position as the (inclusive) end of the token in progress.
        self.current_token_source_extent.end.lineno = self._lineno
        self.current_token_source_extent.end.colno = self._colno
        self.current_token_source_extent.end.byteno = self._byteno
        self.token_end_idx = self.idx

    def has_input(self) -> bool:
        return self.idx < len(self.text)

    def read_char(self) -> str:
        """Consume and return one character, updating position bookkeeping.

        Raises UnexpectedEOFError (via peek_char) at end of input.
        """
        self.mark_token_end()
        c = self.peek_char()
        if c == "\n":
            self._lineno += 1
            self._colno = 1
            self.line = ""
        else:
            self.line += c
            self._colno += 1
        self.idx += 1
        self._byteno += num_bytes_as_utf8(c)
        return c

    def peek_char(self) -> str:
        """Return the next character without consuming it."""
        if not self.has_input():
            raise UnexpectedEOFError("while reading token")
        return self.text[self.idx]

    def make_token(self, cls: type, *args: Any) -> Token:
        """Construct a `cls` token tagged with a copy of the current extent."""
        result: Token = cls(*args)
        return result.with_source(copy.deepcopy(self.current_token_source_extent))

    def read_tokens(self) -> Generator[Token, None, None]:
        # Yield tokens until EOF; the EOF sentinel itself is not yielded.
        while (token := self.read_token()) and not isinstance(token, EOF):
            yield token

    def read_token(self) -> Token:
        """Scan and return the next token, dispatching on its first character."""
        # Consume all whitespace
        while self.has_input():
            # Keep updating the token start location until we exhaust all whitespace
            self.mark_token_start()
            c = self.read_char()
            if not c.isspace():
                break
        else:
            return self.make_token(EOF)
        if c == '"':
            return self.read_string()
        if c == "-":
            # "--" begins a line comment; a lone "-" is an operator.
            if self.has_input() and self.peek_char() == "-":
                self.read_comment()
                # Need to start reading a new token
                return self.read_token()
            return self.read_op(c)
        if c == "#":
            return self.make_token(Hash)
        if c == "~":
            # "~~" introduces a bytes literal; a single "~" is invalid.
            if self.has_input() and self.peek_char() == "~":
                self.read_char()
                return self.read_bytes()
            raise ParseError(f"unexpected token {c!r}")
        if c.isdigit():
            return self.read_number(c)
        if c in "()[]{}":
            custom = {
                "(": LeftParen,
                ")": RightParen,
                "{": LeftBrace,
                "}": RightBrace,
                "[": LeftBracket,
                "]": RightBracket,
            }
            return self.make_token(custom[c])
        if c in OPER_CHARS:
            return self.read_op(c)
        if is_identifier_char(c):
            return self.read_var(c)
        raise InvalidTokenError(
            SourceExtent(
                start=SourceLocation(
                    lineno=self.current_token_source_extent.start.lineno,
                    colno=self.current_token_source_extent.start.colno,
                    byteno=self.current_token_source_extent.start.byteno,
                ),
                end=SourceLocation(
                    lineno=self.current_token_source_extent.end.lineno,
                    colno=self.current_token_source_extent.end.colno,
                    byteno=self.current_token_source_extent.end.byteno,
                ),
            )
        )

    def read_string(self) -> Token:
        """Scan a double-quoted string literal (no escape sequences)."""
        buf = ""
        while self.has_input():
            if (c := self.read_char()) == '"':
                break
            buf += c
        else:
            # Ran out of input before the closing quote.
            raise UnexpectedEOFError("while reading string")
        return self.make_token(StringLit, buf)

    def read_comment(self) -> None:
        # Discard everything through the end of the current line.
        while self.has_input() and self.read_char() != "\n":
            pass

    def read_number(self, first_digit: str) -> Token:
        """Scan an int or float literal; at most one decimal point is allowed."""
        # TODO: Support floating point numbers with no integer part
        buf = first_digit
        has_decimal = False
        while self.has_input():
            c = self.peek_char()
            if c == ".":
                if has_decimal:
                    raise ParseError(f"unexpected token {c!r}")
                has_decimal = True
            elif not c.isdigit():
                break
            self.read_char()
            buf += c

        if has_decimal:
            return self.make_token(FloatLit, float(buf))
        return self.make_token(IntLit, int(buf))

    def _starts_operator(self, buf: str) -> bool:
        # True if `buf` is a prefix of any known operator (maximal-munch scanning).
        # TODO(max): Rewrite using trie
        return any(op.startswith(buf) for op in PS.keys())

    def read_op(self, first_char: str) -> Token:
        """Scan the longest operator starting with `first_char`."""
        buf = first_char
        while self.has_input():
            c = self.peek_char()
            if not self._starts_operator(buf + c):
                break
            self.read_char()
            buf += c
        if buf in PS.keys():
            return self.make_token(Operator, buf)
        # A prefix of an operator that is not itself an operator (e.g. "&").
        raise ParseError(f"unexpected token {buf!r}")

    def read_var(self, first_char: str) -> Token:
        """Scan an identifier."""
        buf = first_char
        while self.has_input() and is_identifier_char(c := self.peek_char()):
            self.read_char()
            buf += c
        return self.make_token(Name, buf)

    def read_bytes(self) -> Token:
        """Scan a bytes literal body: optional "<base>'" prefix, default base 64."""
        buf = ""
        while self.has_input():
            if self.peek_char().isspace():
                break
            buf += self.read_char()
        base, _, value = buf.rpartition("'")
        return self.make_token(BytesLit, value, int(base) if base else 64)


# Sentinel meaning "no cached value" in Peekable (distinct from any real token).
PEEK_EMPTY = object()


class Peekable:
    """Iterator wrapper with one-token lookahead via a single-slot cache."""

    def __init__(self, iterator: Iterator[Any]) -> None:
        self.iterator = iterator
        self.cache = PEEK_EMPTY

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        # Serve (and clear) the cached peeked value before pulling a new one.
        if self.cache is not PEEK_EMPTY:
            result = self.cache
            self.cache = PEEK_EMPTY
            return result
        return next(self.iterator)

    def peek(self) -> Any:
        # Pull the next value but stash it so the following __next__ re-serves it.
        # Propagates StopIteration to the caller at end of input.
        result = self.cache = next(self)
        return result


def tokenize(x: str) -> Peekable:
    """Lex `x` into a peekable stream of tokens (the parser's input)."""
    lexer = Lexer(x)
    return Peekable(lexer.read_tokens())


@dataclass(frozen=True)
class Prec:
    # Left and right binding power for Pratt parsing; their relation
    # (pl vs pr) encodes associativity — see lp/rp/np/xp below.
    pl: float
    pr: float


def lp(n: float) -> Prec:
    # Left-associative: right binding power slightly lower.
    # TODO(max): Rewrite
    return Prec(n, n - 0.1)


def rp(n: float) -> Prec:
    # Right-associative: right binding power slightly higher.
    # TODO(max): Rewrite
    return Prec(n, n + 0.1)


def np(n: float) -> Prec:
    # Non-associative: equal binding powers.
    # TODO(max): Rewrite
    return Prec(n, n)


def xp(n: float) -> Prec:
    # Terminator-like: right side binds nothing.
    # TODO(max): Rewrite
    return Prec(n, 0)


# Operator precedence table. The "" entry is juxtaposition (function
# application), which is why parse_binary consults PS[""] when the next
# token is not an Operator.
PS = {
    "::": lp(2000),
    "@": rp(1001),
    "": rp(1000),
    ">>": lp(14),
    "<<": lp(14),
    "^": rp(13),
    "*": rp(12),
    "/": rp(12),
    "//": lp(12),
    "%": lp(12),
    "+": lp(11),
    "-": lp(11),
    ">*": rp(10),
    "++": rp(10),
    ">+": lp(10),
    "+<": rp(10),
    "==": np(9),
    "/=": np(9),
    "<": np(9),
    ">": np(9),
    "<=": np(9),
    ">=": np(9),
    "&&": rp(8),
    "||": rp(7),
    "|>": rp(6),
    "<|": lp(6),
    "#": lp(5.5),
    "->": lp(5),
    "|": rp(4.5),
    ":": lp(4.5),
    "=": rp(4),
    "!": lp(3),
    ".": rp(3),
    "?": rp(3),
    ",": xp(1),
    # TODO: Fix precedence for spread
    "...": xp(0),
}


HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values())


# Every character that can begin/continue an operator; used by read_token dispatch.
OPER_CHARS = set("".join(PS.keys()))
assert " " not in OPER_CHARS


class SyntacticError(Exception):
    pass


class ParseError(SyntacticError):
    pass


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class UnexpectedTokenError(ParseError):
    unexpected_token: Token


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class InvalidTokenError(ParseError):
    # Extent of the offending text (the character had no token class).
    unexpected_token: SourceExtent = dataclasses.field(default_factory=SourceExtent, compare=False)


# TODO(max): Replace with EOFError?
class UnexpectedEOFError(ParseError):
    pass


def parse_assign(tokens: Peekable, p: float = 0) -> "Assign":
    """Parse one `name = value` entry of a record constructor.

    A bare spread (`...` / `...rest`) is wrapped in an Assign under a
    placeholder key so records can store it uniformly.
    """
    assign = parse_binary(tokens, p)
    if isinstance(assign, Spread):
        return Assign(RECORD_SPREAD_KEY_PLACEHOLDER, assign)
    if not isinstance(assign, Assign):
        raise ParseError("failed to parse variable assignment in record constructor")
    return assign


def gensym() -> str:
    """Return a fresh "$vN" variable name (state lives on the function object)."""
    gensym.counter += 1  # type: ignore
    return f"$v{gensym.counter}"  # type: ignore


def gensym_reset() -> None:
    gensym.counter = -1  # type: ignore


gensym_reset()


def make_source_annotated_object(cls: type, source_extent: Optional[SourceExtent], *args: Any) -> Object:
    """Build a `cls(*args)` AST node and stamp its source_extent.

    object.__setattr__ is required because Object subclasses are frozen
    dataclasses.
    """
    result: Object = cls(*args)
    object.__setattr__(result, "source_extent", source_extent)
    return result


def parse_unary(tokens: Peekable, p: float) -> "Object":
    """Parse a prefix/atomic expression: literals, variables, variants,
    spreads, match functions, parenthesized expressions, lists, records,
    and unary minus. Raises UnexpectedTokenError otherwise."""
    token = next(tokens)
    l: Object
    if isinstance(token, IntLit):
        return make_source_annotated_object(Int, token.source_extent, token.value)
    elif isinstance(token, FloatLit):
        return make_source_annotated_object(Float, token.source_extent, token.value)
    elif isinstance(token, Name):
        # TODO: Handle kebab case vars
        return make_source_annotated_object(Var, token.source_extent, token.value)
    elif isinstance(token, Hash):
        # "#tag payload" — a variant constructor.
        hash_source_extent = token.source_extent
        if isinstance(variant_tag := next(tokens), Name):
            # It needs to be higher than the precedence of the -> operator so that
            # we can match variants in MatchFunction
            # It needs to be higher than the precedence of the && operator so that
            # we can use #true() and #false() in boolean expressions
            # It needs to be higher than the precedence of juxtaposition so that
            # f #true() #false() is parsed as f(TRUE)(FALSE)
            variant_payload = parse_binary(tokens, PS[""].pr + 1)
            return make_source_annotated_object(
                Variant,
                hash_source_extent.coalesce(variant_payload.source_extent),
                variant_tag.value,
                variant_payload,
            )
        else:
            raise UnexpectedTokenError(variant_tag)
    elif isinstance(token, BytesLit):
        # Decode the literal here so the AST holds real bytes.
        base = token.base
        if base == 85:
            l = Bytes(base64.b85decode(token.value))
        elif base == 64:
            l = Bytes(base64.b64decode(token.value))
        elif base == 32:
            l = Bytes(base64.b32decode(token.value))
        elif base == 16:
            l = Bytes(base64.b16decode(token.value))
        else:
            raise ParseError(f"unexpected base {base!r} in {token!r}")
        object.__setattr__(l, "source_extent", token.source_extent)
        return l
    elif isinstance(token, StringLit):
        return make_source_annotated_object(String, token.source_extent, token.value)
    elif token == Operator("..."):
        # "..." optionally followed by a name: anonymous or named spread.
        try:
            if isinstance(tokens.peek(), Name):
                spread_variable = next(tokens)
                return make_source_annotated_object(
                    Spread, token.source_extent.coalesce(spread_variable.source_extent), spread_variable.value
                )
            else:
                return make_source_annotated_object(Spread, token.source_extent)
        except StopIteration:
            # "..." at end of input: spread with no extent information.
            return Spread()
    elif token == Operator("|"):
        # "| pat -> body | pat -> body ..." — a MatchFunction of MatchCases.
        pipe_source_extent = token.source_extent
        expr = parse_binary(tokens, PS["|"].pr)  # TODO: make this work for larger arities
        if not isinstance(expr, Function):
            raise ParseError(f"expected function in match expression {expr!r}")
        match_case = make_source_annotated_object(
            MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
        )
        cases = [match_case]
        match_function_source_extent = match_case.source_extent
        while True:
            try:
                if tokens.peek() != Operator("|"):
                    break
            except StopIteration:
                break
            pipe_source_extent = next(tokens).source_extent
            expr = parse_binary(tokens, PS["|"].pr)  # TODO: make this work for larger arities
            if not isinstance(expr, Function):
                raise ParseError(f"expected function in match expression {expr!r}")
            match_case = make_source_annotated_object(
                MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
            )
            cases.append(match_case)
            match_function_source_extent = join_source_extents(match_function_source_extent, match_case.source_extent)
        return make_source_annotated_object(
            MatchFunction,
            match_function_source_extent,
            cases,
        )
    elif isinstance(token, LeftParen):
        # "()" is a Hole; otherwise a parenthesized expression (the extra
        # next(tokens) consumes the closing paren).
        left_paren_source_extent = token.source_extent
        if isinstance(tokens.peek(), RightParen):
            l = make_source_annotated_object(Hole, left_paren_source_extent.coalesce(next(tokens).source_extent))
        else:
            l = parse(tokens)
            object.__setattr__(l, "source_extent", left_paren_source_extent.coalesce(next(tokens).source_extent))
        return l
    elif isinstance(token, LeftBracket):
        # List literal; items parsed at precedence 2 so "," (xp(1)) separates them.
        list_start_source_extent = token.source_extent
        l = List([])
        token = tokens.peek()
        if isinstance(token, RightBracket):
            list_end_source_extent = next(tokens).source_extent
        else:
            l.items.append(parse_binary(tokens, 2))
            while not isinstance(token := next(tokens), RightBracket):
                if isinstance(l.items[-1], Spread):
                    raise ParseError("spread must come at end of list match")
                # TODO: Implement .. operator
                l.items.append(parse_binary(tokens, 2))
            list_end_source_extent = token.source_extent
        object.__setattr__(l, "source_extent", list_start_source_extent.coalesce(list_end_source_extent))
        return l
    elif isinstance(token, LeftBrace):
        # Record literal of `name = value` entries (see parse_assign).
        record_start_source_extent = token.source_extent
        l = Record({})
        token = tokens.peek()
        if isinstance(token, RightBrace):
            record_end_source_extent = next(tokens).source_extent
        else:
            assign = parse_assign(tokens, 2)
            l.data[assign.name.name] = assign.value
            while not isinstance(token := next(tokens), RightBrace):
                if isinstance(assign.value, Spread):
                    raise ParseError("spread must come at end of record match")
                # TODO: Implement .. operator
                assign = parse_assign(tokens, 2)
                l.data[assign.name.name] = assign.value
            record_end_source_extent = token.source_extent
        object.__setattr__(l, "source_extent", record_start_source_extent.coalesce(record_end_source_extent))
        return l
    elif token == Operator("-"):
        # Unary minus
        # Precedence was chosen to be higher than binary ops so that -a op
        # b is (-a) op b and not -(a op b).
        # Precedence was chosen to be higher than function application so that
        # -a b is (-a) b and not -(a b).
        r = parse_binary(tokens, HIGHEST_PREC + 1)
        source_extent = token.source_extent.coalesce(r.source_extent)
        if isinstance(r, Int):
            assert r.value >= 0, "Tokens should never have negative values"
            return make_source_annotated_object(Int, source_extent, -r.value)
        if isinstance(r, Float):
            assert r.value >= 0, "Tokens should never have negative values"
            return make_source_annotated_object(Float, source_extent, -r.value)
        # Non-literal operand: desugar to (0 - r).
        return make_source_annotated_object(Binop, source_extent, BinopKind.SUB, Int(0), r)
    else:
        raise UnexpectedTokenError(token)


def parse_binary(tokens: Peekable, p: float) -> "Object":
    """Pratt-parse an expression with minimum binding power `p`.

    Two adjacent non-operator tokens are treated as function application
    (the "" juxtaposition entry of PS). Special operators (=, ->, |>, <|,
    >>, <<, ., ?, @) build dedicated AST nodes; everything else becomes a
    Binop via BinopKind.from_str.
    """
    l: Object = parse_unary(tokens, p)
    while True:
        op: Token
        try:
            op = tokens.peek()
        except StopIteration:
            break
        if isinstance(op, (RightParen, RightBracket, RightBrace)):
            # Closing delimiter: let the enclosing construct consume it.
            break
        if not isinstance(op, Operator):
            # Juxtaposition => function application at PS[""] precedence.
            prec = PS[""]
            pl, pr = prec.pl, prec.pr
            if pl < p:
                break
            arg = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(l.source_extent, arg.source_extent), l, arg)
            continue
        prec = PS[op.value]
        pl, pr = prec.pl, prec.pr
        if pl < p:
            break
        next(tokens)
        if op == Operator("="):
            if not isinstance(l, Var):
                raise ParseError(f"expected variable in assignment {l!r}")
            value = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Assign, join_source_extents(l.source_extent, value.source_extent), l, value
            )
        elif op == Operator("->"):
            body = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Function, join_source_extents(l.source_extent, body.source_extent), l, body
            )
        elif op == Operator("|>"):
            # a |> f  ==>  f(a)
            func = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(func.source_extent, l.source_extent), func, l)
        elif op == Operator("<|"):
            # f <| a  ==>  f(a)
            arg = parse_binary(tokens, pr)
            l = make_source_annotated_object(Apply, join_source_extents(l.source_extent, arg.source_extent), l, arg)
        elif op == Operator(">>"):
            # f >> g  ==>  fresh -> g(f(fresh))   (left-to-right composition)
            r = parse_binary(tokens, pr)
            varname = gensym()
            l = make_source_annotated_object(
                Function,
                join_source_extents(l.source_extent, r.source_extent),
                Var(varname),
                Apply(r, Apply(l, Var(varname))),
            )
        elif op == Operator("<<"):
            # f << g  ==>  fresh -> f(g(fresh))   (right-to-left composition)
            r = parse_binary(tokens, pr)
            varname = gensym()
            l = make_source_annotated_object(
                Function,
                join_source_extents(l.source_extent, r.source_extent),
                Var(varname),
                Apply(l, Apply(r, Var(varname))),
            )
        elif op == Operator("."):
            # body . binding  ==>  Where(body, binding)
            binding = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Where, join_source_extents(l.source_extent, binding.source_extent), l, binding
            )
        elif op == Operator("?"):
            cond = parse_binary(tokens, pr)
            l = make_source_annotated_object(Assert, join_source_extents(l.source_extent, cond.source_extent), l, cond)
        elif op == Operator("@"):
            # TODO: revisit whether to use @ or . for field access
            at = parse_binary(tokens, pr)
            l = make_source_annotated_object(Access, join_source_extents(l.source_extent, at.source_extent), l, at)
        else:
            assert isinstance(op, Operator)
            right = parse_binary(tokens, pr)
            l = make_source_annotated_object(
                Binop,
                join_source_extents(l.source_extent, right.source_extent),
                BinopKind.from_str(op.value),
                l,
                right,
            )
    return l


def parse(tokens: Peekable) -> "Object":
    """Parse a full expression; translate raw StopIteration into a ParseError."""
    try:
        return parse_binary(tokens, 0)
    except StopIteration:
        raise UnexpectedEOFError("unexpected end of input")


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Object:
    """Base class for all AST/runtime values (immutable dataclasses).

    source_extent is excluded from ==/hash/repr and set after construction
    via object.__setattr__ (see make_source_annotated_object).
    """

    source_extent: Optional[SourceExtent] = dataclasses.field(default=None, compare=False, init=False, repr=False)

    def __str__(self) -> str:
        # NOTE(review): `pretty` is defined elsewhere in this file (outside
        # this chunk) — presumably a pretty-printer for Objects.
        return pretty(self)


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Int(Object):
    value: int


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Float(Object):
    value: float


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class String(Object):
    value: str


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Bytes(Object):
    value: bytes


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Var(Object):
    name: str


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Hole(Object):
    # The unit/empty value, written "()".
    pass


@dataclass(eq=True, frozen=True, unsafe_hash=True)
class Spread(Object):
    # "..." (name=None) or "...rest" (name="rest").
    name: Optional[str] = None


# Dictionary key under which parse_assign stores a record spread.
RECORD_SPREAD_KEY_PLACEHOLDER = Var("...")


Env = Mapping[str, Object]


# TODO(max): Add source extents for BinopKind?
787class BinopKind(enum.Enum): 788 ADD = auto() 789 SUB = auto() 790 MUL = auto() 791 DIV = auto() 792 FLOOR_DIV = auto() 793 EXP = auto() 794 MOD = auto() 795 EQUAL = auto() 796 NOT_EQUAL = auto() 797 LESS = auto() 798 GREATER = auto() 799 LESS_EQUAL = auto() 800 GREATER_EQUAL = auto() 801 BOOL_AND = auto() 802 BOOL_OR = auto() 803 STRING_CONCAT = auto() 804 LIST_CONS = auto() 805 LIST_APPEND = auto() 806 RIGHT_EVAL = auto() 807 HASTYPE = auto() 808 PIPE = auto() 809 REVERSE_PIPE = auto() 810 811 @classmethod 812 def from_str(cls, x: str) -> "BinopKind": 813 return { 814 "+": cls.ADD, 815 "-": cls.SUB, 816 "*": cls.MUL, 817 "/": cls.DIV, 818 "//": cls.FLOOR_DIV, 819 "^": cls.EXP, 820 "%": cls.MOD, 821 "==": cls.EQUAL, 822 "/=": cls.NOT_EQUAL, 823 "<": cls.LESS, 824 ">": cls.GREATER, 825 "<=": cls.LESS_EQUAL, 826 ">=": cls.GREATER_EQUAL, 827 "&&": cls.BOOL_AND, 828 "||": cls.BOOL_OR, 829 "++": cls.STRING_CONCAT, 830 ">+": cls.LIST_CONS, 831 "+<": cls.LIST_APPEND, 832 "!": cls.RIGHT_EVAL, 833 ":": cls.HASTYPE, 834 "|>": cls.PIPE, 835 "<|": cls.REVERSE_PIPE, 836 }[x] 837 838 @classmethod 839 def to_str(cls, binop_kind: "BinopKind") -> str: 840 return { 841 cls.ADD: "+", 842 cls.SUB: "-", 843 cls.MUL: "*", 844 cls.DIV: "/", 845 cls.EXP: "^", 846 cls.MOD: "%", 847 cls.EQUAL: "==", 848 cls.NOT_EQUAL: "/=", 849 cls.LESS: "<", 850 cls.GREATER: ">", 851 cls.LESS_EQUAL: "<=", 852 cls.GREATER_EQUAL: ">=", 853 cls.BOOL_AND: "&&", 854 cls.BOOL_OR: "||", 855 cls.STRING_CONCAT: "++", 856 cls.LIST_CONS: ">+", 857 cls.LIST_APPEND: "+<", 858 cls.RIGHT_EVAL: "!", 859 cls.HASTYPE: ":", 860 cls.PIPE: "|>", 861 cls.REVERSE_PIPE: "<|", 862 }[binop_kind] 863 864 865@dataclass(eq=True, frozen=True, unsafe_hash=True) 866class Binop(Object): 867 op: BinopKind 868 left: Object 869 right: Object 870 871 872@dataclass(eq=True, frozen=True, unsafe_hash=True) 873class List(Object): 874 items: typing.List[Object] 875 876 877@dataclass(eq=True, frozen=True, unsafe_hash=True) 878class 
Assign(Object): 879 name: Var 880 value: Object 881 882 883@dataclass(eq=True, frozen=True, unsafe_hash=True) 884class Function(Object): 885 arg: Object 886 body: Object 887 888 889@dataclass(eq=True, frozen=True, unsafe_hash=True) 890class Apply(Object): 891 func: Object 892 arg: Object 893 894 895@dataclass(eq=True, frozen=True, unsafe_hash=True) 896class Where(Object): 897 body: Object 898 binding: Object 899 900 901@dataclass(eq=True, frozen=True, unsafe_hash=True) 902class Assert(Object): 903 value: Object 904 cond: Object 905 906 907@dataclass(eq=True, frozen=True, unsafe_hash=True) 908class EnvObject(Object): 909 env: Env 910 911 def __str__(self) -> str: 912 return f"EnvObject(keys={self.env.keys()})" 913 914 915@dataclass(eq=True, frozen=True, unsafe_hash=True) 916class MatchCase(Object): 917 pattern: Object 918 body: Object 919 920 921@dataclass(eq=True, frozen=True, unsafe_hash=True) 922class MatchFunction(Object): 923 cases: typing.List[MatchCase] 924 925 926@dataclass(eq=True, frozen=True, unsafe_hash=True) 927class Relocation(Object): 928 name: str 929 930 931@dataclass(eq=True, frozen=True, unsafe_hash=True) 932class NativeFunctionRelocation(Relocation): 933 pass 934 935 936@dataclass(eq=True, frozen=True, unsafe_hash=True) 937class NativeFunction(Object): 938 name: str 939 func: Callable[[Object], Object] 940 941 942@dataclass(eq=True, frozen=True, unsafe_hash=True) 943class Closure(Object): 944 env: Env 945 func: Union[Function, MatchFunction] 946 947 948@dataclass(eq=True, frozen=True, unsafe_hash=True) 949class Record(Object): 950 data: Dict[str, Object] 951 952 953@dataclass(eq=True, frozen=True, unsafe_hash=True) 954class Access(Object): 955 obj: Object 956 at: Object 957 958 959@dataclass(eq=True, frozen=True, unsafe_hash=True) 960class Variant(Object): 961 tag: str 962 value: Object 963 964 965tags = [ 966 TYPE_SHORT := b"i", # fits in 64 bits 967 TYPE_LONG := b"l", # bignum 968 TYPE_FLOAT := b"d", 969 TYPE_STRING := b"s", 970 TYPE_REF := 
b"r", 971 TYPE_LIST := b"[", 972 TYPE_RECORD := b"{", 973 TYPE_VARIANT := b"#", 974 TYPE_VAR := b"v", 975 TYPE_FUNCTION := b"f", 976 TYPE_MATCH_FUNCTION := b"m", 977 TYPE_CLOSURE := b"c", 978 TYPE_BYTES := b"b", 979 TYPE_HOLE := b"(", 980 TYPE_ASSIGN := b"=", 981 TYPE_BINOP := b"+", 982 TYPE_APPLY := b" ", 983 TYPE_WHERE := b".", 984 TYPE_ACCESS := b"@", 985 TYPE_SPREAD := b"S", 986 TYPE_NAMED_SPREAD := b"R", 987 TYPE_TRUE := b"T", 988 TYPE_FALSE := b"F", 989] 990FLAG_REF = 0x80 991 992 993BITS_PER_BYTE = 8 994BYTES_PER_DIGIT = 8 995BITS_PER_DIGIT = BYTES_PER_DIGIT * BITS_PER_BYTE 996DIGIT_MASK = (1 << BITS_PER_DIGIT) - 1 997 998 999def ref(tag: bytes) -> bytes: 1000 return (tag[0] | FLAG_REF).to_bytes(1, "little") 1001 1002 1003tags = tags + [ref(v) for v in tags] 1004assert len(tags) == len(set(tags)), "Duplicate tags" 1005assert all(len(v) == 1 for v in tags), "Tags must be 1 byte" 1006assert all(isinstance(v, bytes) for v in tags) 1007 1008 1009def zigzag_encode(val: int) -> int: 1010 if val < 0: 1011 return -2 * val - 1 1012 return 2 * val 1013 1014 1015def zigzag_decode(val: int) -> int: 1016 if val & 1 == 1: 1017 return -val // 2 1018 return val // 2 1019 1020 1021@dataclass 1022class Serializer: 1023 refs: typing.List[Object] = dataclasses.field(default_factory=list) 1024 output: bytearray = dataclasses.field(default_factory=bytearray) 1025 1026 def ref(self, obj: Object) -> Optional[int]: 1027 for idx, ref in enumerate(self.refs): 1028 if ref is obj: 1029 return idx 1030 return None 1031 1032 def add_ref(self, ty: bytes, obj: Object) -> int: 1033 assert len(ty) == 1 1034 assert self.ref(obj) is None 1035 self.emit(ref(ty)) 1036 result = len(self.refs) 1037 self.refs.append(obj) 1038 return result 1039 1040 def emit(self, obj: bytes) -> None: 1041 self.output.extend(obj) 1042 1043 def _fits_in_nbits(self, obj: int, nbits: int) -> bool: 1044 return -(1 << (nbits - 1)) <= obj < (1 << (nbits - 1)) 1045 1046 def _short(self, number: int) -> bytes: 1047 # From 
Peter Ruibal, https://github.com/fmoo/python-varint 1048 number = zigzag_encode(number) 1049 buf = bytearray() 1050 while True: 1051 towrite = number & 0x7F 1052 number >>= 7 1053 if number: 1054 buf.append(towrite | 0x80) 1055 else: 1056 buf.append(towrite) 1057 break 1058 return bytes(buf) 1059 1060 def _long(self, number: int) -> bytes: 1061 digits = [] 1062 number = zigzag_encode(number) 1063 while number: 1064 digits.append(number & DIGIT_MASK) 1065 number >>= BITS_PER_DIGIT 1066 buf = bytearray(self._short(len(digits))) 1067 for digit in digits: 1068 buf.extend(digit.to_bytes(BYTES_PER_DIGIT, "little")) 1069 return bytes(buf) 1070 1071 def _string(self, obj: str) -> bytes: 1072 encoded = obj.encode("utf-8") 1073 return self._short(len(encoded)) + encoded 1074 1075 def serialize(self, obj: Object) -> None: 1076 assert isinstance(obj, Object), type(obj) 1077 if (ref := self.ref(obj)) is not None: 1078 return self.emit(TYPE_REF + self._short(ref)) 1079 if isinstance(obj, Int): 1080 if self._fits_in_nbits(obj.value, 64): 1081 self.emit(TYPE_SHORT) 1082 self.emit(self._short(obj.value)) 1083 return 1084 self.emit(TYPE_LONG) 1085 self.emit(self._long(obj.value)) 1086 return 1087 if isinstance(obj, String): 1088 return self.emit(TYPE_STRING + self._string(obj.value)) 1089 if isinstance(obj, List): 1090 self.add_ref(TYPE_LIST, obj) 1091 self.emit(self._short(len(obj.items))) 1092 for item in obj.items: 1093 self.serialize(item) 1094 return 1095 if isinstance(obj, Variant): 1096 if obj.tag == "true" and isinstance(obj.value, Hole): 1097 return self.emit(TYPE_TRUE) 1098 if obj.tag == "false" and isinstance(obj.value, Hole): 1099 return self.emit(TYPE_FALSE) 1100 # TODO(max): Determine if this should be a ref 1101 self.emit(TYPE_VARIANT) 1102 # TODO(max): String pool (via refs) for strings longer than some length? 
1103 self.emit(self._string(obj.tag)) 1104 return self.serialize(obj.value) 1105 if isinstance(obj, Record): 1106 # TODO(max): Determine if this should be a ref 1107 self.emit(TYPE_RECORD) 1108 self.emit(self._short(len(obj.data))) 1109 for key, value in obj.data.items(): 1110 self.emit(self._string(key)) 1111 self.serialize(value) 1112 return 1113 if isinstance(obj, Var): 1114 return self.emit(TYPE_VAR + self._string(obj.name)) 1115 if isinstance(obj, Function): 1116 self.emit(TYPE_FUNCTION) 1117 self.serialize(obj.arg) 1118 return self.serialize(obj.body) 1119 if isinstance(obj, MatchFunction): 1120 self.emit(TYPE_MATCH_FUNCTION) 1121 self.emit(self._short(len(obj.cases))) 1122 for case in obj.cases: 1123 self.serialize(case.pattern) 1124 self.serialize(case.body) 1125 return 1126 if isinstance(obj, Closure): 1127 self.add_ref(TYPE_CLOSURE, obj) 1128 self.serialize(obj.func) 1129 self.emit(self._short(len(obj.env))) 1130 for key, value in obj.env.items(): 1131 self.emit(self._string(key)) 1132 self.serialize(value) 1133 return 1134 if isinstance(obj, Bytes): 1135 self.emit(TYPE_BYTES) 1136 self.emit(self._short(len(obj.value))) 1137 self.emit(obj.value) 1138 return 1139 if isinstance(obj, Float): 1140 self.emit(TYPE_FLOAT) 1141 self.emit(struct.pack("<d", obj.value)) 1142 return 1143 if isinstance(obj, Hole): 1144 self.emit(TYPE_HOLE) 1145 return 1146 if isinstance(obj, Assign): 1147 self.emit(TYPE_ASSIGN) 1148 self.serialize(obj.name) 1149 self.serialize(obj.value) 1150 return 1151 if isinstance(obj, Binop): 1152 self.emit(TYPE_BINOP) 1153 self.emit(self._string(BinopKind.to_str(obj.op))) 1154 self.serialize(obj.left) 1155 self.serialize(obj.right) 1156 return 1157 if isinstance(obj, Apply): 1158 self.emit(TYPE_APPLY) 1159 self.serialize(obj.func) 1160 self.serialize(obj.arg) 1161 return 1162 if isinstance(obj, Where): 1163 self.emit(TYPE_WHERE) 1164 self.serialize(obj.body) 1165 self.serialize(obj.binding) 1166 return 1167 if isinstance(obj, Access): 1168 
            # --- continuation of Serializer.serialize (Access branch) ---
            self.emit(TYPE_ACCESS)
            self.serialize(obj.obj)
            self.serialize(obj.at)
            return
        if isinstance(obj, Spread):
            # Named spreads carry their binding name; bare spreads do not.
            if obj.name is not None:
                self.emit(TYPE_NAMED_SPREAD)
                self.emit(self._string(obj.name))
                return
            self.emit(TYPE_SPREAD)
            return
        raise NotImplementedError(type(obj))


@dataclass
class Deserializer:
    """Decodes the flat byte format produced by Serializer back into Objects.

    Keeps a list of previously decoded ref-able objects (`refs`) so that
    self-referential structures (lists, closures) can be reconstructed.
    """

    flat: Union[bytes, memoryview]  # input buffer
    idx: int = 0  # current read offset into `flat`
    refs: typing.List[Object] = dataclasses.field(default_factory=list)

    def __post_init__(self) -> None:
        # Normalize to a memoryview so `read` can slice without copying.
        if isinstance(self.flat, bytes):
            self.flat = memoryview(self.flat)

    def read(self, size: int) -> memoryview:
        """Consume and return the next `size` bytes as a zero-copy view."""
        result = memoryview(self.flat[self.idx : self.idx + size])
        self.idx += size
        return result

    def read_tag(self) -> Tuple[bytes, bool]:
        """Read a one-byte type tag; return (tag with ref bit cleared, is_ref)."""
        tag = self.read(1)[0]
        is_ref = bool(tag & FLAG_REF)
        return (tag & ~FLAG_REF).to_bytes(1, "little"), is_ref

    def _string(self) -> str:
        # Length-prefixed UTF-8 string.
        length = self._short()
        encoded = self.read(length)
        return str(encoded, "utf-8")

    def _short(self) -> int:
        # ZigZag-encoded LEB128-style varint: 7 payload bits per byte,
        # high bit set on all but the final byte.
        # From Peter Ruibal, https://github.com/fmoo/python-varint
        shift = 0
        result = 0
        while True:
            i = self.read(1)[0]
            result |= (i & 0x7F) << shift
            shift += 7
            if not (i & 0x80):
                break
        return zigzag_decode(result)

    def _long(self) -> int:
        # Arbitrary-precision integer: varint digit count, then fixed-width
        # little-endian digits, least significant digit first.
        num_digits = self._short()
        digits = []
        for _ in range(num_digits):
            digit = int.from_bytes(self.read(BYTES_PER_DIGIT), "little")
            digits.append(digit)
        result = 0
        for digit in reversed(digits):
            result <<= BITS_PER_DIGIT
            result |= digit
        return zigzag_decode(result)

    def parse(self) -> Object:
        """Decode one Object from the stream (recursing for containers)."""
        ty, is_ref = self.read_tag()
        if ty == TYPE_REF:
            # Back-reference to an already-decoded object.
            idx = self._short()
            return self.refs[idx]
        if ty == TYPE_SHORT:
            assert not is_ref
            return Int(self._short())
        if ty == TYPE_LONG:
            assert not is_ref
            return Int(self._long())
        if ty == TYPE_STRING:
            assert not is_ref
            return String(self._string())
        if ty == TYPE_LIST:
            length = self._short()
            result_list = List([])
            # Lists are always ref-able: register before parsing items so a
            # cyclic list can refer back to itself mid-parse.
            assert is_ref
            self.refs.append(result_list)
            for i in range(length):
                result_list.items.append(self.parse())
            return result_list
        if ty == TYPE_RECORD:
            assert not is_ref
            length = self._short()
            result_rec = Record({})
            for i in range(length):
                key = self._string()
                value = self.parse()
                result_rec.data[key] = value
            return result_rec
        if ty == TYPE_VARIANT:
            assert not is_ref
            tag = self._string()
            value = self.parse()
            return Variant(tag, value)
        if ty == TYPE_VAR:
            assert not is_ref
            return Var(self._string())
        if ty == TYPE_FUNCTION:
            assert not is_ref
            arg = self.parse()
            body = self.parse()
            return Function(arg, body)
        if ty == TYPE_MATCH_FUNCTION:
            assert not is_ref
            length = self._short()
            result_matchfun = MatchFunction([])
            for i in range(length):
                pattern = self.parse()
                body = self.parse()
                result_matchfun.cases.append(MatchCase(pattern, body))
            return result_matchfun
        if ty == TYPE_CLOSURE:
            func = self.parse()
            length = self._short()
            assert isinstance(func, (Function, MatchFunction))
            result_closure = Closure({}, func)
            # Like lists, closures may be self-referential via their env.
            assert is_ref
            self.refs.append(result_closure)
            for i in range(length):
                key = self._string()
                value = self.parse()
                assert isinstance(result_closure.env, dict)  # For mypy
                result_closure.env[key] = value
            return result_closure
        if ty == TYPE_BYTES:
            assert not is_ref
            length = self._short()
            return Bytes(self.read(length))
        if ty == TYPE_FLOAT:
            assert not is_ref
            # Floats are fixed-width little-endian IEEE-754 doubles.
            return Float(struct.unpack("<d", self.read(8))[0])
        if ty == TYPE_HOLE:
            assert not is_ref
            return Hole()
        if ty == TYPE_ASSIGN:
            assert not is_ref
            name = self.parse()
            value = self.parse()
            assert isinstance(name, Var)
            return Assign(name, value)
        if ty == TYPE_BINOP:
            assert not is_ref
            op = BinopKind.from_str(self._string())
            left = self.parse()
            right = self.parse()
            return Binop(op, left, right)
        if ty == TYPE_APPLY:
            assert not is_ref
            func = self.parse()
            arg = self.parse()
            return Apply(func, arg)
        if ty == TYPE_WHERE:
            assert not is_ref
            body = self.parse()
            binding = self.parse()
            return Where(body, binding)
        if ty == TYPE_ACCESS:
            assert not is_ref
            obj = self.parse()
            at = self.parse()
            return Access(obj, at)
        if ty == TYPE_SPREAD:
            return Spread()
        if ty == TYPE_NAMED_SPREAD:
            return Spread(self._string())
        if ty == TYPE_TRUE:
            assert not is_ref
            return Variant("true", Hole())
        if ty == TYPE_FALSE:
            assert not is_ref
            return Variant("false", Hole())
        raise NotImplementedError(bytes(ty))


# Canonical booleans: modeled as variants over Hole, not a separate type.
TRUE = Variant("true", Hole())


FALSE = Variant("false", Hole())


def unpack_number(obj: Object) -> Union[int, float]:
    """Extract the Python number from an Int/Float; TypeError otherwise."""
    if not isinstance(obj, (Int, Float)):
        raise TypeError(f"expected Int or Float, got {type(obj).__name__}")
    return obj.value


def eval_number(env: Env, exp: Object) -> Union[int, float]:
    # Evaluate `exp` and require a numeric result.
    result = eval_exp(env, exp)
    return unpack_number(result)


def eval_str(env: Env, exp: Object) -> str:
    # Evaluate `exp` and require a String result.
    result = eval_exp(env, exp)
    if not isinstance(result, String):
        raise TypeError(f"expected String, got {type(result).__name__}")
    return result.value


def eval_bool(env: Env, exp: Object) -> bool:
    # Evaluate `exp` and require a #true/#false variant result.
    result = eval_exp(env, exp)
    if not isinstance(result, Variant):
        raise TypeError(f"expected #true or #false, got {type(result).__name__}")
    if result.tag not in ("true", "false"):
        raise TypeError(f"expected #true or #false, got {type(result).__name__}")
    return result.tag == "true"


def eval_list(env: Env, exp: Object) -> typing.List[Object]:
    # Evaluate `exp` and require a List result; returns its items.
    result = eval_exp(env, exp)
    if not isinstance(result, List):
        raise TypeError(f"expected List, got {type(result).__name__}")
    return result.items


def make_bool(x: bool) -> Object:
    """Convert a Python bool into the language's #true/#false variant."""
    return TRUE if x else FALSE


def wrap_inferred_number_type(x: Union[int, float]) -> Object:
    # TODO: Since this is intended to be a reference implementation
    # we should avoid relying heavily on Python's implementation of
    # arithmetic operations, type inference, and multiple dispatch.
    # Update this to make the interpreter more language agnostic.
    if isinstance(x, int):
        return Int(x)
    return Float(x)


# Dispatch table mapping each binary operator to its evaluator. Operands
# arrive unevaluated, so BOOL_AND/BOOL_OR short-circuit via Python's own
# `and`/`or`, and RIGHT_EVAL never evaluates its left operand.
BINOP_HANDLERS: Dict[BinopKind, Callable[[Env, Object, Object], Object]] = {
    BinopKind.ADD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) + eval_number(env, y)),
    BinopKind.SUB: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) - eval_number(env, y)),
    BinopKind.MUL: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) * eval_number(env, y)),
    BinopKind.DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) / eval_number(env, y)),
    BinopKind.FLOOR_DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) // eval_number(env, y)),
    BinopKind.EXP: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) ** eval_number(env, y)),
    BinopKind.MOD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) % eval_number(env, y)),
    BinopKind.EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) == eval_exp(env, y)),
    BinopKind.NOT_EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) != eval_exp(env, y)),
    BinopKind.LESS: lambda env, x, y: make_bool(eval_number(env, x) < eval_number(env, y)),
    BinopKind.GREATER: lambda env, x, y: make_bool(eval_number(env, x)
    # (continuation of the GREATER entry begun on the previous line)
    > eval_number(env, y)),
    BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) <= eval_number(env, y)),
    BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) >= eval_number(env, y)),
    BinopKind.BOOL_AND: lambda env, x, y: make_bool(eval_bool(env, x) and eval_bool(env, y)),
    BinopKind.BOOL_OR: lambda env, x, y: make_bool(eval_bool(env, x) or eval_bool(env, y)),
    BinopKind.STRING_CONCAT: lambda env, x, y: String(eval_str(env, x) + eval_str(env, y)),
    BinopKind.LIST_CONS: lambda env, x, y: List([eval_exp(env, x)] + eval_list(env, y)),
    BinopKind.LIST_APPEND: lambda env, x, y: List(eval_list(env, x) + [eval_exp(env, y)]),
    BinopKind.RIGHT_EVAL: lambda env, x, y: eval_exp(env, y),
}


class MatchError(Exception):
    # Raised when a match-function has no case matching its argument.
    pass


def match(obj: Object, pattern: Object) -> Optional[Env]:
    """Try to match `obj` against `pattern`.

    Returns the environment of bindings introduced by the pattern on
    success (possibly empty), or None when the match fails.
    """
    if isinstance(pattern, Hole):
        return {} if isinstance(obj, Hole) else None
    if isinstance(pattern, Int):
        return {} if isinstance(obj, Int) and obj.value == pattern.value else None
    if isinstance(pattern, Float):
        raise MatchError("pattern matching is not supported for Floats")
    if isinstance(pattern, String):
        return {} if isinstance(obj, String) and obj.value == pattern.value else None
    if isinstance(pattern, Var):
        # A bare variable matches anything and binds it.
        return {pattern.name: obj}
    if isinstance(pattern, Variant):
        if not isinstance(obj, Variant):
            return None
        if obj.tag != pattern.tag:
            return None
        return match(obj.value, pattern.value)
    if isinstance(pattern, Record):
        if not isinstance(obj, Record):
            return None
        result: Env = {}
        seen_keys: set[str] = set()
        for key, pattern_item in pattern.data.items():
            if isinstance(pattern_item, Spread):
                # A spread ends the pattern early; a named spread captures
                # the remaining (not yet matched) fields as a new record.
                if pattern_item.name is not None:
                    assert isinstance(result, dict)  # for .update()
                    rest_keys = set(obj.data.keys()) - seen_keys
                    result.update({pattern_item.name: Record({key: obj.data[key] for key in rest_keys})})
                return result
            seen_keys.add(key)
            obj_item = obj.data.get(key)
            if obj_item is None:
                return None
            part = match(obj_item, pattern_item)
            if part is None:
                return None
            assert isinstance(result, dict)  # for .update()
            result.update(part)
        # Without a spread, the record pattern is exact: no extra fields.
        if len(pattern.data) != len(obj.data):
            return None
        return result
    if isinstance(pattern, List):
        if not isinstance(obj, List):
            return None
        result: Env = {}  # type: ignore
        for i, pattern_item in enumerate(pattern.items):
            if isinstance(pattern_item, Spread):
                # A spread matches the rest of the list; a named spread
                # binds the remaining items.
                if pattern_item.name is not None:
                    assert isinstance(result, dict)  # for .update()
                    result.update({pattern_item.name: List(obj.items[i:])})
                return result
            if i >= len(obj.items):
                return None
            obj_item = obj.items[i]
            part = match(obj_item, pattern_item)
            if part is None:
                return None
            assert isinstance(result, dict)  # for .update()
            result.update(part)
        # Without a spread, the list pattern must consume every item.
        if len(pattern.items) != len(obj.items):
            return None
        return result
    raise NotImplementedError(f"match not implemented for {type(pattern).__name__}")


def free_in(exp: Object) -> Set[str]:
    """Return the set of free variable names in `exp`."""
    if isinstance(exp, (Int, Float, String, Bytes, Hole, NativeFunction)):
        return set()
    if isinstance(exp, Variant):
        return free_in(exp.value)
    if isinstance(exp, Var):
        return {exp.name}
    if isinstance(exp, Spread):
        if exp.name is not None:
            return {exp.name}
        return set()
    if isinstance(exp, Binop):
        return free_in(exp.left) | free_in(exp.right)
    if isinstance(exp, List):
        if not exp.items:
            return set()
        return set.union(*(free_in(item) for item in exp.items))
    if isinstance(exp, Record):
        if not exp.data:
            return set()
        return set.union(*(free_in(value) for key, value in exp.data.items()))
    if isinstance(exp, Function):
        assert isinstance(exp.arg, Var)
        # The parameter is a binder: not free in the body.
        return free_in(exp.body) - {exp.arg.name}
    if isinstance(exp, MatchFunction):
        if not exp.cases:
            return set()
        return set.union(*(free_in(case) for case in exp.cases))
    if isinstance(exp, MatchCase):
        # Pattern variables are binders; they are not free in the body.
        return free_in(exp.body) - free_in(exp.pattern)
    if isinstance(exp, Apply):
        return free_in(exp.func) | free_in(exp.arg)
    if isinstance(exp, Access):
        # For records, y is not free in x@y; it is a field name.
        # For lists, y *is* free in x@y; it is an index expression (could be a
        # var).
        # For now, we'll assume it might be an expression and mark it as a
        # (possibly extra) freevar.
        return free_in(exp.obj) | free_in(exp.at)
    if isinstance(exp, Where):
        assert isinstance(exp.binding, Assign)
        return (free_in(exp.body) - {exp.binding.name.name}) | free_in(exp.binding)
    if isinstance(exp, Assign):
        return free_in(exp.value)
    if isinstance(exp, Closure):
        # TODO(max): Should this remove the set of keys in the closure env?
        return free_in(exp.func)
    raise NotImplementedError(("free_in", type(exp)))


def improve_closure(closure: Closure) -> Closure:
    """Shrink a closure's captured env to just the function's free variables."""
    freevars = free_in(closure.func)
    env = {boundvar: value for boundvar, value in closure.env.items() if boundvar in freevars}
    return Closure(env, closure.func)


def eval_exp(env: Env, exp: Object) -> Object:
    """Evaluate expression `exp` in environment `env` (recursive interpreter)."""
    logger.debug(exp)
    if isinstance(exp, (Int, Float, String, Bytes, Hole, Closure, NativeFunction)):
        # Values evaluate to themselves.
        return exp
    if isinstance(exp, Variant):
        return Variant(exp.tag, eval_exp(env, exp.value))
    if isinstance(exp, Var):
        value = env.get(exp.name)
        if value is None:
            raise NameError(f"name '{exp.name}' is not defined")
        return value
    if isinstance(exp, Binop):
        handler = BINOP_HANDLERS.get(exp.op)
        if handler is None:
            raise NotImplementedError(f"no handler for {exp.op}")
        return handler(env, exp.left, exp.right)
    if isinstance(exp, List):
        return List([eval_exp(env, item) for item in exp.items])
    if isinstance(exp, Record):
        return Record({k: eval_exp(env, exp.data[k]) for k in exp.data})
    if isinstance(exp, Assign):
        # TODO(max): Rework this. There's something about matching that we need
        # to figure out and implement.
        assert isinstance(exp.name, Var)
        value = eval_exp(env, exp.value)
        if isinstance(value, Closure):
            # We want functions to be able to call themselves without using the
            # Y combinator or similar, so we bind functions (and only
            # functions) using a letrec-like strategy. We augment their
            # captured environment with a binding to themselves.
            assert isinstance(value.env, dict)
            value.env[exp.name.name] = value
            # We still improve_closure here even though we also did it on
            # Closure creation because the Closure might not need a binding for
            # itself (it might not be recursive).
            value = improve_closure(value)
        return EnvObject({**env, exp.name.name: value})
    if isinstance(exp, Where):
        assert isinstance(exp.binding, Assign)
        res_env = eval_exp(env, exp.binding)
        assert isinstance(res_env, EnvObject)
        new_env = {**env, **res_env.env}
        return eval_exp(new_env, exp.body)
    if isinstance(exp, Assert):
        cond = eval_exp(env, exp.cond)
        if cond != TRUE:
            raise AssertionError(f"condition {exp.cond} failed")
        return eval_exp(env, exp.value)
    if isinstance(exp, Function):
        if not isinstance(exp.arg, Var):
            raise RuntimeError(f"expected variable in function definition {exp.arg}")
        value = Closure(env, exp)
        # Capture only the free variables the function actually uses.
        value = improve_closure(value)
        return value
    if isinstance(exp, MatchFunction):
        value = Closure(env, exp)
        value = improve_closure(value)
        return value
    if isinstance(exp, Apply):
        if isinstance(exp.func, Var) and exp.func.name == "$$quote":
            # $$quote returns its argument unevaluated.
            return exp.arg
        callee = eval_exp(env, exp.func)
        arg = eval_exp(env, exp.arg)
        if isinstance(callee, NativeFunction):
            return callee.func(arg)
        if not isinstance(callee, Closure):
            raise TypeError(f"attempted to apply a non-closure of type {type(callee).__name__}")
        if isinstance(callee.func, Function):
            assert isinstance(callee.func.arg, Var)
            new_env = {**callee.env, callee.func.arg.name: arg}
            return eval_exp(new_env, callee.func.body)
        elif isinstance(callee.func, MatchFunction):
            # Try each case in order; the first matching pattern wins.
            for case in callee.func.cases:
                m = match(arg, case.pattern)
                if m is None:
                    continue
                return eval_exp({**callee.env, **m}, case.body)
            raise MatchError("no matching cases")
        else:
            raise TypeError(f"attempted to apply a non-function of type {type(callee.func).__name__}")
    if isinstance(exp, Access):
        obj = eval_exp(env, exp.obj)
        if isinstance(obj, Record):
            # Record access: the accessor must be a literal field name.
            if not isinstance(exp.at, Var):
                raise TypeError(f"cannot access record field using {type(exp.at).__name__}, expected a field name")
            if exp.at.name not in obj.data:
                raise NameError(f"no assignment to {exp.at.name} found in record")
            return obj.data[exp.at.name]
        elif isinstance(obj, List):
            # List access: the accessor is an arbitrary integer expression.
            access_at = eval_exp(env, exp.at)
            if not isinstance(access_at, Int):
                raise TypeError(f"cannot index into list using type {type(access_at).__name__}, expected integer")
            if access_at.value < 0 or access_at.value >= len(obj.items):
                raise ValueError(f"index {access_at.value} out of bounds for list")
            return obj.items[access_at.value]
        raise TypeError(f"attempted to access from type {type(obj).__name__}")
    elif isinstance(exp, Spread):
        raise RuntimeError("cannot evaluate a spread")
    raise NotImplementedError(f"eval_exp not implemented for {exp}")


class ScrapMonad:
    """Threads an environment through a sequence of evaluations."""

    def __init__(self, env: Env) -> None:
        assert isinstance(env, dict)  # for .copy()
        self.env: Env = env.copy()

    def bind(self, exp: Object) -> Tuple[Object, "ScrapMonad"]:
        """Evaluate `exp`; return (result, monad carrying the updated env)."""
        env = self.env
        result = eval_exp(env, exp)
        if isinstance(result, EnvObject):
            # Assignments extend the environment.
            return result, ScrapMonad({**env, **result.env})
        # Other results are remembered under the name "_".
        return result, ScrapMonad({**env, "_": result})


class InferenceError(Exception):
    # Raised for any type-inference failure (unification, unbound vars, ...).
    pass


@dataclasses.dataclass
class MonoType:
    """Base class for monomorphic types (union-find node interface)."""

    def find(self) -> MonoType:
        return self


@dataclasses.dataclass
class TyVar(MonoType):
    # Union-find forwarding pointer; None while the variable is unbound.
    # init=False keeps it out of the generated __init__.
    forwarded: MonoType | None = dataclasses.field(init=False, default=None)
    name: str

    def find(self) -> MonoType:
        # Chase forwarding pointers to this variable's representative.
        result: MonoType = self
        while isinstance(result, TyVar):
            it = result.forwarded
            if it is None:
                return result
            result = it
        return result

    def __str__(self) -> str:
        return f"'{self.name}"

    def make_equal_to(self, other: MonoType) -> None:
        # Union: point the end of this variable's chain at `other`.
        chain_end = self.find()
        if not isinstance(chain_end, TyVar):
            raise InferenceError(f"{self} is already resolved to {chain_end}")
        chain_end.forwarded = other

    def is_unbound(self) -> bool:
        # True while this variable has not been unified with anything.
        return self.forwarded is None


@dataclasses.dataclass
class TyCon(MonoType):
    """A type constructor application, e.g. int, (a list), (a -> b)."""

    name: str
    args: list[MonoType]

    def __str__(self) -> str:
        # TODO(max): Precedence pretty-print type constructors
        if not self.args:
            return self.name
        if len(self.args) == 1:
            return f"({self.args[0]} {self.name})"
        return f"({self.name.join(map(str, self.args))})"


@dataclasses.dataclass
class TyEmptyRow(MonoType):
    # Terminator of a closed record row.
    def __str__(self) -> str:
        return "{}"


@dataclasses.dataclass
class TyRow(MonoType):
    """A record row: known fields plus a rest (TyVar = open, TyEmptyRow = closed)."""

    fields: dict[str, MonoType]
    rest: TyVar | TyEmptyRow = dataclasses.field(default_factory=TyEmptyRow)

    def __post_init__(self) -> None:
        if not self.fields and isinstance(self.rest, TyEmptyRow):
            raise InferenceError("Empty row must have a rest type")

    def __str__(self) -> str:
        flat, rest = row_flatten(self)
        # sort to make tests deterministic
        result = [f"{key}={val}" for key, val in sorted(flat.items())]
        if isinstance(rest, TyVar):
            result.append(f"...{rest}")
        else:
            assert isinstance(rest, TyEmptyRow)
        return "{" + ", ".join(result) + "}"


def row_flatten(rec: MonoType) -> tuple[dict[str, MonoType], TyVar | TyEmptyRow]:
    """Collapse a chain of rows into (all known fields, final rest)."""
    if isinstance(rec, TyVar):
        rec = rec.find()
    if isinstance(rec, TyVar):
        # Still unbound: the whole row is just an open rest variable.
        return {}, rec
    if isinstance(rec, TyRow):
        flat, rest = row_flatten(rec.rest)
        flat.update(rec.fields)
        return flat, rest
    if isinstance(rec, TyEmptyRow):
        return {}, rec
    raise InferenceError(f"Expected record type, got {type(rec)}")


@dataclasses.dataclass
class Forall:
    """A polymorphic type scheme: `tyvars` are universally quantified in `ty`."""

    tyvars: list[TyVar]
    ty: MonoType

    def __str__(self) -> str:
        return f"(forall {', '.join(map(str, self.tyvars))}. {self.ty})"


def func_type(*args: MonoType) -> TyCon:
    """Build a right-associated function type from arg..., result."""
    assert len(args) >= 2
    if len(args) == 2:
        return TyCon("->", list(args))
    return TyCon("->", [args[0], func_type(*args[1:])])


def list_type(arg: MonoType) -> TyCon:
    return TyCon("list", [arg])


def unify_fail(ty1: MonoType, ty2: MonoType) -> None:
    raise InferenceError(f"Unification failed for {ty1} and {ty2}")


def occurs_in(tyvar: TyVar, ty: MonoType) -> bool:
    # Occurs check: does `tyvar` appear inside `ty`? Prevents infinite types.
    if isinstance(ty, TyVar):
        return tyvar == ty
    if isinstance(ty, TyCon):
        return any(occurs_in(tyvar, arg) for arg in ty.args)
    if isinstance(ty, TyEmptyRow):
        return False
    if isinstance(ty, TyRow):
        return any(occurs_in(tyvar, val) for val in ty.fields.values()) or occurs_in(tyvar, ty.rest)
    raise InferenceError(f"Unknown type: {ty}")


def unify_type(ty1: MonoType, ty2: MonoType) -> None:
    """Destructively unify two types (union-find); raises InferenceError."""
    ty1 = ty1.find()
    ty2 = ty2.find()
    if isinstance(ty1, TyVar):
        if occurs_in(ty1, ty2):
            raise InferenceError(f"Occurs check failed for {ty1} and {ty2}")
        ty1.make_equal_to(ty2)
        return
    if isinstance(ty2, TyVar):  # Mirror
        return unify_type(ty2, ty1)
    if isinstance(ty1, TyCon) and isinstance(ty2, TyCon):
        if ty1.name != ty2.name:
            unify_fail(ty1, ty2)
            return
        if len(ty1.args) != len(ty2.args):
            unify_fail(ty1, ty2)
            return
        for l, r in zip(ty1.args, ty2.args):
            unify_type(l, r)
        return
    if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyEmptyRow):
        return
    if isinstance(ty1, TyRow) and isinstance(ty2, TyRow):
        ty1_fields, ty1_rest = row_flatten(ty1)
        ty2_fields, ty2_rest = row_flatten(ty2)
        ty1_missing = {}
        ty2_missing = {}
        all_field_names = set(ty1_fields.keys()) | set(ty2_fields.keys())
        for key in sorted(all_field_names):  # Sort for deterministic error messages
            ty1_val = ty1_fields.get(key)
            ty2_val = ty2_fields.get(key)
            if ty1_val is not None and ty2_val is not None:
                # Field present in both rows: unify the field types.
                unify_type(ty1_val, ty2_val)
            elif ty1_val is None:
                assert ty2_val is not None
                ty1_missing[key] = ty2_val
            elif ty2_val is None:
                assert ty1_val is not None
                ty2_missing[key] = ty1_val
        # In general, we want to:
        # 1) Add missing fields from one row to the other row
        # 2) "Keep the rows unified" by linking each row's rest to the other
        #    row's rest
        if not ty1_missing and not ty2_missing:
            # The rests are either both empty (rows were closed) or both
            # unbound type variables (rows were open); unify the rest variables
            unify_type(ty1_rest, ty2_rest)
            return
        if not ty1_missing:
            # The first row has fields that the second row doesn't have; add
            # them to the second row
            unify_type(ty2_rest, TyRow(ty2_missing, ty1_rest))
            return
        if not ty2_missing:
            # The second row has fields that the first row doesn't have; add
            # them to the first row
            unify_type(ty1_rest, TyRow(ty1_missing, ty2_rest))
            return
        # They each have fields the other lacks; create new rows sharing a rest
        # and add the missing fields to each row
        rest = fresh_tyvar()
        unify_type(ty1_rest, TyRow(ty1_missing, rest))
        unify_type(ty2_rest, TyRow(ty2_missing, rest))
        return
    if isinstance(ty1, TyRow) and isinstance(ty2, TyEmptyRow):
        raise InferenceError(f"Unifying row {ty1} with empty row")
    if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyRow):
        raise InferenceError(f"Unifying empty row with row {ty2}")
    raise InferenceError(f"Cannot unify {ty1} and {ty2}")


# A typing context: variable names mapped to their (possibly polymorphic)
# type schemes.
Context = typing.Mapping[str, Forall]


# Monotonically increasing counter used to name fresh type variables.
fresh_var_counter = 0


def fresh_tyvar(prefix: str = "t") -> TyVar:
    """Create a globally unique, unbound type variable."""
    global fresh_var_counter
    result = f"{prefix}{fresh_var_counter}"
    fresh_var_counter += 1
    return TyVar(result)


def reset_tyvar_counter() -> None:
global fresh_var_counter 1876 fresh_var_counter = 0 1877 1878 1879IntType = TyCon("int", []) 1880StringType = TyCon("string", []) 1881FloatType = TyCon("float", []) 1882BytesType = TyCon("bytes", []) 1883HoleType = TyCon("hole", []) 1884 1885 1886Subst = typing.Mapping[str, MonoType] 1887 1888 1889def apply_ty(ty: MonoType, subst: Subst) -> MonoType: 1890 ty = ty.find() 1891 if isinstance(ty, TyVar): 1892 return subst.get(ty.name, ty) 1893 if isinstance(ty, TyCon): 1894 return TyCon(ty.name, [apply_ty(arg, subst) for arg in ty.args]) 1895 if isinstance(ty, TyEmptyRow): 1896 return ty 1897 if isinstance(ty, TyRow): 1898 rest = apply_ty(ty.rest, subst) 1899 assert isinstance(rest, (TyVar, TyEmptyRow)) 1900 return TyRow({key: apply_ty(val, subst) for key, val in ty.fields.items()}, rest) 1901 raise InferenceError(f"Unknown type: {ty}") 1902 1903 1904def instantiate(scheme: Forall) -> MonoType: 1905 fresh = {tyvar.name: fresh_tyvar() for tyvar in scheme.tyvars} 1906 return apply_ty(scheme.ty, fresh) 1907 1908 1909def ftv_ty(ty: MonoType) -> set[str]: 1910 ty = ty.find() 1911 if isinstance(ty, TyVar): 1912 return {ty.name} 1913 if isinstance(ty, TyCon): 1914 return set().union(*map(ftv_ty, ty.args)) 1915 if isinstance(ty, TyEmptyRow): 1916 return set() 1917 if isinstance(ty, TyRow): 1918 return set().union(*map(ftv_ty, ty.fields.values()), ftv_ty(ty.rest)) 1919 raise InferenceError(f"Unknown type: {ty}") 1920 1921 1922def generalize(ty: MonoType, ctx: Context) -> Forall: 1923 def ftv_scheme(ty: Forall) -> set[str]: 1924 return ftv_ty(ty.ty) - set(tyvar.name for tyvar in ty.tyvars) 1925 1926 def ftv_ctx(ctx: Context) -> set[str]: 1927 return set().union(*(ftv_scheme(scheme) for scheme in ctx.values())) 1928 1929 # TODO(max): Freshen? 
1930 tyvars = ftv_ty(ty) - ftv_ctx(ctx) 1931 return Forall([TyVar(name) for name in sorted(tyvars)], ty) 1932 1933 1934def type_of(expr: Object) -> MonoType: 1935 ty = getattr(expr, "inferred_type", None) 1936 if ty is not None: 1937 assert isinstance(ty, MonoType) 1938 return ty.find() 1939 return set_type(expr, fresh_tyvar()) 1940 1941 1942def set_type(expr: Object, ty: MonoType) -> MonoType: 1943 object.__setattr__(expr, "inferred_type", ty) 1944 return ty 1945 1946 1947def infer_common(expr: Object) -> MonoType: 1948 if isinstance(expr, Int): 1949 return set_type(expr, IntType) 1950 if isinstance(expr, Float): 1951 return set_type(expr, FloatType) 1952 if isinstance(expr, Bytes): 1953 return set_type(expr, BytesType) 1954 if isinstance(expr, Hole): 1955 return set_type(expr, HoleType) 1956 if isinstance(expr, String): 1957 return set_type(expr, StringType) 1958 raise InferenceError(f"{type(expr)} can't be simply inferred") 1959 1960 1961def infer_pattern_type(pattern: Object, ctx: Context) -> MonoType: 1962 assert isinstance(ctx, dict) 1963 if isinstance(pattern, (Int, Float, Bytes, Hole, String)): 1964 return infer_common(pattern) 1965 if isinstance(pattern, Var): 1966 result = fresh_tyvar() 1967 ctx[pattern.name] = Forall([], result) 1968 return set_type(pattern, result) 1969 if isinstance(pattern, List): 1970 list_item_ty = fresh_tyvar() 1971 result_ty = list_type(list_item_ty) 1972 for item in pattern.items: 1973 if isinstance(item, Spread): 1974 if item.name is not None: 1975 ctx[item.name] = Forall([], result_ty) 1976 break 1977 item_ty = infer_pattern_type(item, ctx) 1978 unify_type(list_item_ty, item_ty) 1979 return set_type(pattern, result_ty) 1980 if isinstance(pattern, Record): 1981 fields = {} 1982 rest: TyVar | TyEmptyRow = TyEmptyRow() # Default closed row 1983 for key, value in pattern.data.items(): 1984 if isinstance(value, Spread): 1985 # Open row 1986 rest = fresh_tyvar() 1987 if value.name is not None: 1988 ctx[value.name] = Forall([], rest) 
1989 break 1990 fields[key] = infer_pattern_type(value, ctx) 1991 return set_type(pattern, TyRow(fields, rest)) 1992 raise InferenceError(f"{type(pattern)} isn't allowed in a pattern") 1993 1994 1995def infer_type(expr: Object, ctx: Context) -> MonoType: 1996 if isinstance(expr, (Int, Float, Bytes, Hole, String)): 1997 return infer_common(expr) 1998 if isinstance(expr, Var): 1999 scheme = ctx.get(expr.name) 2000 if scheme is None: 2001 raise InferenceError(f"Unbound variable {expr.name}") 2002 return set_type(expr, instantiate(scheme)) 2003 if isinstance(expr, Function): 2004 arg_tyvar = fresh_tyvar() 2005 assert isinstance(expr.arg, Var) 2006 body_ctx = {**ctx, expr.arg.name: Forall([], arg_tyvar)} 2007 body_ty = infer_type(expr.body, body_ctx) 2008 return set_type(expr, func_type(arg_tyvar, body_ty)) 2009 if isinstance(expr, Binop): 2010 left, right = expr.left, expr.right 2011 op = Var(BinopKind.to_str(expr.op)) 2012 return set_type(expr, infer_type(Apply(Apply(op, left), right), ctx)) 2013 if isinstance(expr, Where): 2014 assert isinstance(expr.binding, Assign) 2015 name, value, body = expr.binding.name.name, expr.binding.value, expr.body 2016 if isinstance(value, (Function, MatchFunction)): 2017 # Letrec 2018 func_ty: MonoType = fresh_tyvar() 2019 value_ty = infer_type(value, {**ctx, name: Forall([], func_ty)}) 2020 else: 2021 # Let 2022 value_ty = infer_type(value, ctx) 2023 value_scheme = generalize(value_ty, ctx) 2024 body_ty = infer_type(body, {**ctx, name: value_scheme}) 2025 return set_type(expr, body_ty) 2026 if isinstance(expr, List): 2027 list_item_ty = fresh_tyvar() 2028 for item in expr.items: 2029 assert not isinstance(item, Spread), "Spread can only occur in list match (for now)" 2030 item_ty = infer_type(item, ctx) 2031 unify_type(list_item_ty, item_ty) 2032 return set_type(expr, list_type(list_item_ty)) 2033 if isinstance(expr, MatchCase): 2034 pattern_ctx: Context = {} 2035 pattern_ty = infer_pattern_type(expr.pattern, pattern_ctx) 2036 body_ty 
= infer_type(expr.body, {**ctx, **pattern_ctx}) 2037 return set_type(expr, func_type(pattern_ty, body_ty)) 2038 if isinstance(expr, Apply): 2039 func_ty = infer_type(expr.func, ctx) 2040 arg_ty = infer_type(expr.arg, ctx) 2041 result = fresh_tyvar() 2042 unify_type(func_ty, func_type(arg_ty, result)) 2043 return set_type(expr, result) 2044 if isinstance(expr, MatchFunction): 2045 result = fresh_tyvar() 2046 for case in expr.cases: 2047 case_ty = infer_type(case, ctx) 2048 unify_type(result, case_ty) 2049 return set_type(expr, result) 2050 if isinstance(expr, Record): 2051 fields = {} 2052 rest: TyVar | TyEmptyRow = TyEmptyRow() 2053 for key, value in expr.data.items(): 2054 assert not isinstance(value, Spread), "Spread can only occur in record match (for now)" 2055 fields[key] = infer_type(value, ctx) 2056 return set_type(expr, TyRow(fields, rest)) 2057 if isinstance(expr, Access): 2058 obj_ty = infer_type(expr.obj, ctx) 2059 value_ty = fresh_tyvar() 2060 assert isinstance(expr.at, Var) 2061 # "has field" constraint in the form of an open row 2062 unify_type(obj_ty, TyRow({expr.at.name: value_ty}, fresh_tyvar())) 2063 return value_ty 2064 raise InferenceError(f"Unexpected type {type(expr)}") 2065 2066 2067def minimize(ty: MonoType) -> MonoType: 2068 letters = iter("abcdefghijklmnopqrstuvwxyz") 2069 free = ftv_ty(ty) 2070 subst = {ftv: TyVar(next(letters)) for ftv in sorted(free)} 2071 return apply_ty(ty, subst) 2072 2073 2074Number = typing.Union[int, float] 2075 2076 2077class Repr(typing.Protocol): 2078 def __call__(self, obj: Object, prec: Number = 0) -> str: ... 2079 2080 2081# Can't use reprlib.recursive_repr because it doesn't work if the print 2082# function has more than one argument (for example, prec) 2083def handle_recursion(func: Repr) -> Repr: 2084 cache: typing.List[Object] = [] 2085 2086 @functools.wraps(func) 2087 def wrapper(obj: Object, prec: Number = 0) -> str: 2088 for cached in cache: 2089 if obj is cached: 2090 return "..." 
        # (continuation of handle_recursion.wrapper: remember `obj` while
        # printing it so recursive references render as "...")
        cache.append(obj)
        result = func(obj, prec)
        cache.remove(obj)
        return result

    return wrapper


@handle_recursion
def pretty(obj: Object, prec: Number = 0) -> str:
    """Pretty-print `obj`; `prec` is the enclosing operator's precedence.

    Atoms return immediately; operator forms set `result`/`op_prec` and
    fall through to the shared parenthesization check at the bottom.
    """
    if isinstance(obj, Int):
        return str(obj.value)
    if isinstance(obj, Float):
        return str(obj.value)
    if isinstance(obj, String):
        # json.dumps gives a quoted, escaped string literal.
        return json.dumps(obj.value)
    if isinstance(obj, Bytes):
        return f"~~{base64.b64encode(obj.value).decode()}"
    if isinstance(obj, Var):
        return obj.name
    if isinstance(obj, Hole):
        return "()"
    if isinstance(obj, Spread):
        return f"...{obj.name}" if obj.name else "..."
    if isinstance(obj, List):
        return f"[{', '.join(pretty(item) for item in obj.items)}]"
    if isinstance(obj, Record):
        return f"{{{', '.join(f'{key} = {pretty(value)}' for key, value in obj.data.items())}}}"
    if isinstance(obj, Closure):
        keys = list(obj.env.keys())
        return f"Closure({keys}, {pretty(obj.func)})"
    if isinstance(obj, EnvObject):
        return f"EnvObject({repr(obj.env)})"
    if isinstance(obj, NativeFunction):
        return f"NativeFunction(name={obj.name})"
    if isinstance(obj, Relocation):
        return f"Relocation(name={repr(obj.name)})"
    # Operator forms below: exactly one branch runs (the atoms above all
    # returned early), leaving `result` and `op_prec` set for the tail.
    if isinstance(obj, Variant):
        op_prec = PS["#"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"#{obj.tag} {pretty(obj.value, right_prec)}"
    if isinstance(obj, Assign):
        op_prec = PS["="]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.name, left_prec)} = {pretty(obj.value, right_prec)}"
    if isinstance(obj, Binop):
        op_prec = PS[BinopKind.to_str(obj.op)]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.left, left_prec)} {BinopKind.to_str(obj.op)} {pretty(obj.right, right_prec)}"
    if isinstance(obj, Function):
        op_prec = PS["->"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        assert isinstance(obj.arg, Var)
        result = f"{obj.arg.name} -> {pretty(obj.body, right_prec)}"
    if isinstance(obj, MatchFunction):
        op_prec = PS["|"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = "\n".join(
            f"| {pretty(case.pattern, left_prec)} -> {pretty(case.body, right_prec)}" for case in obj.cases
        )
    if isinstance(obj, Where):
        op_prec = PS["."]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.body, left_prec)} . {pretty(obj.binding, right_prec)}"
    if isinstance(obj, Assert):
        op_prec = PS["!"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.value, left_prec)} ! {pretty(obj.cond, right_prec)}"
    if isinstance(obj, Apply):
        op_prec = PS[""]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.func, left_prec)} {pretty(obj.arg, right_prec)}"
    if isinstance(obj, Access):
        op_prec = PS["@"]
        left_prec, right_prec = op_prec.pl, op_prec.pr
        result = f"{pretty(obj.obj, left_prec)} @ {pretty(obj.at, right_prec)}"
    # Parenthesize when the context binds at least as tightly as we do.
    if prec >= op_prec.pl:
        return f"({result})"
    return result


def fetch(url: Object) -> Object:
    """Native $$fetch: GET `url` and return the response body as a String."""
    if not isinstance(url, String):
        raise TypeError(f"fetch expected String, but got {type(url).__name__}")
    with urllib.request.urlopen(url.value) as f:
        return String(f.read().decode("utf-8"))


def make_object(pyobj: object) -> Object:
    """Convert a plain Python value (e.g. parsed JSON) into a language Object."""
    assert not isinstance(pyobj, Object)
    if isinstance(pyobj, int):
        return Int(pyobj)
    if isinstance(pyobj, str):
        return String(pyobj)
    if isinstance(pyobj, list):
        return List([make_object(o) for o in pyobj])
    if isinstance(pyobj, dict):
        # Assumed to only be called with JSON, so string keys.
2189 return Record({key: make_object(value) for key, value in pyobj.items()}) 2190 raise NotImplementedError(type(pyobj)) 2191 2192 2193def jsondecode(obj: Object) -> Object: 2194 if not isinstance(obj, String): 2195 raise TypeError(f"jsondecode expected String, but got {type(obj).__name__}") 2196 data = json.loads(obj.value) 2197 return make_object(data) 2198 2199 2200def listlength(obj: Object) -> Object: 2201 # TODO(max): Implement in scrapscript once list pattern matching is 2202 # implemented. 2203 if not isinstance(obj, List): 2204 raise TypeError(f"listlength expected List, but got {type(obj).__name__}") 2205 return Int(len(obj.items)) 2206 2207 2208def serialize(obj: Object) -> bytes: 2209 serializer = Serializer() 2210 serializer.serialize(obj) 2211 return bytes(serializer.output) 2212 2213 2214def deserialize(data: bytes) -> Object: 2215 deserializer = Deserializer(data) 2216 return deserializer.parse() 2217 2218 2219def deserialize_object(obj: Object) -> Object: 2220 assert isinstance(obj, Bytes) 2221 return deserialize(obj.value) 2222 2223 2224STDLIB = { 2225 "$$add": Closure({}, Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, Var("x"), Var("y"))))), 2226 "$$fetch": NativeFunction("$$fetch", fetch), 2227 "$$jsondecode": NativeFunction("$$jsondecode", jsondecode), 2228 "$$serialize": NativeFunction("$$serialize", lambda obj: Bytes(serialize(obj))), 2229 "$$deserialize": NativeFunction("$$deserialize", deserialize_object), 2230 "$$listlength": NativeFunction("$$listlength", listlength), 2231} 2232 2233 2234PRELUDE = """ 2235id = x -> x 2236 2237. quicksort = 2238 | [] -> [] 2239 | [p, ...xs] -> (concat ((quicksort (ltp xs p)) +< p) (quicksort (gtp xs p)) 2240 . gtp = xs -> p -> filter (x -> x >= p) xs 2241 . ltp = xs -> p -> filter (x -> x < p) xs) 2242 2243. filter = f -> 2244 | [] -> [] 2245 | [x, ...xs] -> f x |> | #true () -> x >+ filter f xs 2246 | #false () -> filter f xs 2247 2248. 
concat = xs -> 2249 | [] -> xs 2250 | [y, ...ys] -> concat (xs +< y) ys 2251 2252. map = f -> 2253 | [] -> [] 2254 | [x, ...xs] -> f x >+ map f xs 2255 2256. range = 2257 | 0 -> [] 2258 | i -> range (i - 1) +< (i - 1) 2259 2260. foldr = f -> a -> 2261 | [] -> a 2262 | [x, ...xs] -> f x (foldr f a xs) 2263 2264. take = 2265 | 0 -> xs -> [] 2266 | n -> 2267 | [] -> [] 2268 | [x, ...xs] -> x >+ take (n - 1) xs 2269 2270. all = f -> 2271 | [] -> #true () 2272 | [x, ...xs] -> f x && all f xs 2273 2274. any = f -> 2275 | [] -> #false () 2276 | [x, ...xs] -> f x || any f xs 2277""" 2278 2279 2280def boot_env() -> Env: 2281 env_object = eval_exp(STDLIB, parse(tokenize(PRELUDE))) 2282 assert isinstance(env_object, EnvObject) 2283 return env_object.env 2284 2285 2286class Completer: 2287 def __init__(self, env: Env) -> None: 2288 self.env: Env = env 2289 self.matches: typing.List[str] = [] 2290 2291 def complete(self, text: str, state: int) -> Optional[str]: 2292 assert "@" not in text, "TODO: handle attr/index access" 2293 if state == 0: 2294 options = sorted(self.env.keys()) 2295 if not text: 2296 self.matches = options[:] 2297 else: 2298 self.matches = [key for key in options if key.startswith(text)] 2299 try: 2300 return self.matches[state] 2301 except IndexError: 2302 return None 2303 2304 2305REPL_HISTFILE = os.path.expanduser(".scrap-history") 2306 2307 2308class ScrapRepl(code.InteractiveConsole): 2309 def __init__(self, *args: Any, **kwargs: Any) -> None: 2310 super().__init__(*args, **kwargs) 2311 self.env: Env = boot_env() 2312 2313 def enable_readline(self) -> None: 2314 assert readline, "Can't enable readline without readline module" 2315 if os.path.exists(REPL_HISTFILE): 2316 readline.read_history_file(REPL_HISTFILE) 2317 # what determines the end of a word; need to set so $ can be part of a 2318 # variable name 2319 readline.set_completer_delims(" \t\n;") 2320 # TODO(max): Add completion per scope, not just for global environment. 
2321 readline.set_completer(Completer(self.env).complete) 2322 readline.parse_and_bind("set show-all-if-ambiguous on") 2323 readline.parse_and_bind("tab: menu-complete") 2324 2325 def finish_readline(self) -> None: 2326 assert readline, "Can't finish readline without readline module" 2327 histfile_size = 1000 2328 readline.set_history_length(histfile_size) 2329 readline.write_history_file(REPL_HISTFILE) 2330 2331 def runsource(self, source: str, filename: str = "<input>", symbol: str = "single") -> bool: 2332 try: 2333 tokens = tokenize(source) 2334 logger.debug("Tokens: %s", tokens) 2335 ast = parse(tokens) 2336 if isinstance(ast, MatchFunction) and not source.endswith("\n"): 2337 # User might be in the middle of typing a multi-line match... 2338 # wait for them to hit Enter once after the last case 2339 return True 2340 logger.debug("AST: %s", ast) 2341 result = eval_exp(self.env, ast) 2342 assert isinstance(self.env, dict) # for .update()/__setitem__ 2343 if isinstance(result, EnvObject): 2344 self.env.update(result.env) 2345 else: 2346 self.env["_"] = result 2347 print(pretty(result)) 2348 except UnexpectedEOFError: 2349 # Need to read more text 2350 return True 2351 except ParseError as e: 2352 print(f"Parse error: {e}", file=sys.stderr) 2353 except Exception as e: 2354 print(f"Error: {e}", file=sys.stderr) 2355 return False 2356 2357 2358def eval_command(args: argparse.Namespace) -> None: 2359 if args.debug: 2360 logging.basicConfig(level=logging.DEBUG) 2361 2362 program = args.program_file.read() 2363 tokens = tokenize(program) 2364 logger.debug("Tokens: %s", tokens) 2365 ast = parse(tokens) 2366 logger.debug("AST: %s", ast) 2367 result = eval_exp(boot_env(), ast) 2368 print(pretty(result)) 2369 2370 2371def check_command(args: argparse.Namespace) -> None: 2372 if args.debug: 2373 logging.basicConfig(level=logging.DEBUG) 2374 2375 program = args.program_file.read() 2376 tokens = tokenize(program) 2377 logger.debug("Tokens: %s", tokens) 2378 ast = 
parse(tokens) 2379 logger.debug("AST: %s", ast) 2380 result = infer_type(ast, OP_ENV) 2381 result = minimize(result) 2382 print(result) 2383 2384 2385def apply_command(args: argparse.Namespace) -> None: 2386 if args.debug: 2387 logging.basicConfig(level=logging.DEBUG) 2388 2389 tokens = tokenize(args.program) 2390 logger.debug("Tokens: %s", tokens) 2391 ast = parse(tokens) 2392 logger.debug("AST: %s", ast) 2393 result = eval_exp(boot_env(), ast) 2394 print(pretty(result)) 2395 2396 2397def repl_command(args: argparse.Namespace) -> None: 2398 if args.debug: 2399 logging.basicConfig(level=logging.DEBUG) 2400 2401 repl = ScrapRepl() 2402 if readline: 2403 repl.enable_readline() 2404 repl.interact(banner="") 2405 if readline: 2406 repl.finish_readline() 2407 2408 2409def env_get_split(key: str, default: Optional[typing.List[str]] = None) -> typing.List[str]: 2410 import shlex 2411 2412 cflags = os.environ.get(key) 2413 if cflags: 2414 return shlex.split(cflags) 2415 if default: 2416 return default 2417 return [] 2418 2419 2420def discover_cflags(cc: typing.List[str], debug: bool = True) -> typing.List[str]: 2421 default_cflags = ["-Wall", "-Wextra", "-fno-strict-aliasing", "-Wno-unused-function"] 2422 # -fno-strict-aliasing is needed because we do pointer casting a bunch 2423 # -Wno-unused-function is needed because we have a bunch of unused 2424 # functions depending on what code is compiled 2425 if debug: 2426 default_cflags += ["-O0", "-ggdb"] 2427 else: 2428 default_cflags += ["-O2", "-DNDEBUG"] 2429 if "cosmo" not in cc[0]: 2430 # cosmocc does not support LTO 2431 default_cflags.append("-flto") 2432 if "mingw" in cc[0]: 2433 # Windows does not support mmap 2434 default_cflags.append("-DSTATIC_HEAP") 2435 return env_get_split("CFLAGS", default_cflags) 2436 2437 2438OP_ENV = { 2439 "+": Forall([], func_type(IntType, IntType, IntType)), 2440 "-": Forall([], func_type(IntType, IntType, IntType)), 2441 "*": Forall([], func_type(IntType, IntType, IntType)), 2442 "/": 
Forall([], func_type(IntType, IntType, FloatType)), 2443 "++": Forall([], func_type(StringType, StringType, StringType)), 2444 ">+": Forall([TyVar("a")], func_type(TyVar("a"), list_type(TyVar("a")), list_type(TyVar("a")))), 2445 "+<": Forall([TyVar("a")], func_type(list_type(TyVar("a")), TyVar("a"), list_type(TyVar("a")))), 2446} 2447 2448 2449def compile_command(args: argparse.Namespace) -> None: 2450 if args.run: 2451 args.compile = True 2452 from compiler import compile_to_string 2453 2454 with open(args.file, "r") as f: 2455 source = f.read() 2456 2457 program = parse(tokenize(source)) 2458 if args.check: 2459 infer_type(program, OP_ENV) 2460 c_program = compile_to_string(program, args.debug) 2461 2462 with open(args.platform, "r") as f: 2463 platform = f.read() 2464 2465 with open(args.output, "w") as f: 2466 f.write(c_program) 2467 f.write(platform) 2468 2469 if args.format: 2470 import subprocess 2471 2472 subprocess.run(["clang-format-15", "-i", args.output], check=True) 2473 2474 if args.compile: 2475 import subprocess 2476 2477 cc = env_get_split("CC", ["clang"]) 2478 cflags = discover_cflags(cc, args.debug) 2479 if args.memory: 2480 cflags += [f"-DMEMORY_SIZE={args.memory}"] 2481 if args.handle_stack_size: 2482 cflags += [f"-DHANDLE_STACK_SIZE={args.handle_stack_size}"] 2483 ldflags = env_get_split("LDFLAGS") 2484 subprocess.run([*cc, "-o", "a.out", *cflags, args.output, *ldflags], check=True) 2485 2486 if args.run: 2487 import subprocess 2488 2489 subprocess.run(["sh", "-c", "./a.out"], check=True) 2490 2491 2492def flat_command(args: argparse.Namespace) -> None: 2493 prog = parse(tokenize(sys.stdin.read())) 2494 serializer = Serializer() 2495 serializer.serialize(prog) 2496 sys.stdout.buffer.write(serializer.output) 2497 2498 2499def server_command(args: argparse.Namespace) -> None: 2500 import http.server 2501 import socketserver 2502 import hashlib 2503 2504 dir = os.path.abspath(args.directory) 2505 if not os.path.isdir(dir): 2506 print(f"Error: 
{dir} is not a valid directory") 2507 sys.exit(1) 2508 2509 scraps = {} 2510 for root, _, files in os.walk(dir): 2511 for file in files: 2512 file_path = os.path.join(root, file) 2513 rel_path = os.path.relpath(file_path, dir) 2514 if file.startswith("$"): 2515 logger.debug(f"Skipping {rel_path}") 2516 continue 2517 rel_path_without_ext = os.path.splitext(rel_path)[0] 2518 with open(file_path, "r") as f: 2519 try: 2520 program = parse(tokenize(f.read())) 2521 serializer = Serializer() 2522 serializer.serialize(program) 2523 serialized = bytes(serializer.output) 2524 scraps[rel_path_without_ext] = serialized 2525 logger.debug(f"Loaded {rel_path_without_ext}") 2526 file_hash = hashlib.sha256(serialized).hexdigest() 2527 scraps[f"${file_hash}"] = serialized 2528 logger.debug(f"Loaded {rel_path_without_ext} as ${file_hash}") 2529 except Exception as e: 2530 logger.error(f"Error processing {file_path}: {e}") 2531 2532 keep_serving = True 2533 2534 class ScrapHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): 2535 def do_QUIT(self) -> None: 2536 self.send_response(200) 2537 self.end_headers() 2538 self.wfile.write(b"Quitting") 2539 nonlocal keep_serving 2540 keep_serving = False 2541 2542 def do_GET(self) -> None: 2543 path = self.path.lstrip("/") 2544 scrap = scraps.get(path) 2545 if scrap is not None: 2546 self.send_response(200) 2547 self.send_header("Content-Type", "application/scrap; charset=binary") 2548 self.send_header("Content-Disposition", f"attachment; filename={json.dumps(f'{path}.scrap')}") 2549 self.send_header("Content-Length", str(len(scrap))) 2550 self.end_headers() 2551 self.wfile.write(scrap) 2552 else: 2553 self.send_response(404) 2554 self.send_header("Content-Type", "text/plain") 2555 self.end_headers() 2556 self.wfile.write(b"File not found") 2557 2558 handler = ScrapHTTPRequestHandler 2559 with socketserver.TCPServer((args.host, args.port), handler) as httpd: 2560 logger.info(f"Serving {dir} at http://{args.host}:{args.port}") 2561 while 
keep_serving: 2562 httpd.handle_request() 2563 2564 2565def main() -> None: 2566 parser = argparse.ArgumentParser(prog="scrapscript") 2567 subparsers = parser.add_subparsers(dest="command") 2568 2569 repl = subparsers.add_parser("repl") 2570 repl.set_defaults(func=repl_command) 2571 repl.add_argument("--debug", action="store_true") 2572 2573 eval_ = subparsers.add_parser("eval") 2574 eval_.set_defaults(func=eval_command) 2575 eval_.add_argument("program_file", type=argparse.FileType("r")) 2576 eval_.add_argument("--debug", action="store_true") 2577 2578 check = subparsers.add_parser("check") 2579 check.set_defaults(func=check_command) 2580 check.add_argument("program_file", type=argparse.FileType("r")) 2581 check.add_argument("--debug", action="store_true") 2582 2583 apply = subparsers.add_parser("apply") 2584 apply.set_defaults(func=apply_command) 2585 apply.add_argument("program") 2586 apply.add_argument("--debug", action="store_true") 2587 2588 comp = subparsers.add_parser("compile") 2589 comp.set_defaults(func=compile_command) 2590 comp.add_argument("file") 2591 comp.add_argument("-o", "--output", default="output.c") 2592 comp.add_argument("--format", action="store_true") 2593 comp.add_argument("--compile", action="store_true") 2594 comp.add_argument("--memory", type=int) 2595 comp.add_argument("--handle-stack-size", type=int) 2596 comp.add_argument("--run", action="store_true") 2597 comp.add_argument("--debug", action="store_true", default=False) 2598 comp.add_argument("--check", action="store_true", default=False) 2599 # The platform is in the same directory as this file 2600 comp.add_argument("--platform", default=os.path.join(os.path.dirname(__file__), "cli.c")) 2601 2602 flat = subparsers.add_parser("flat") 2603 flat.set_defaults(func=flat_command) 2604 2605 yard = subparsers.add_parser("yard") 2606 yard.set_defaults(func=lambda _: yard.print_help()) 2607 yard_subparsers = yard.add_subparsers(dest="yard_command") 2608 2609 yard_server = 
yard_subparsers.add_parser("server") 2610 yard_server.set_defaults(func=server_command) 2611 yard_server.add_argument("directory", type=str, nargs="?", default=".", help="Directory to serve") 2612 yard_server.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind to") 2613 yard_server.add_argument("--port", type=int, default=8080, help="Port to listen on") 2614 2615 args = parser.parse_args() 2616 if not args.command: 2617 args.debug = False 2618 repl_command(args) 2619 else: 2620 args.func(args) 2621 2622 2623if __name__ == "__main__": 2624 # This is so that we can use scrapscript.py as a main but also import 2625 # things from `scrapscript` and not have that be a separate module. 2626 sys.modules["scrapscript"] = sys.modules[__name__] 2627 main()