this repo has no description
at trunk 4122 lines 153 kB view raw
1import unittest 2import re 3from typing import Optional 4import urllib.request 5 6# ruff: noqa: F405 7# ruff: noqa: F403 8from scrapscript import * 9 10 11class PeekableTests(unittest.TestCase): 12 def test_can_create_peekable(self) -> None: 13 Peekable(iter([1, 2, 3])) 14 15 def test_can_iterate_over_peekable(self) -> None: 16 sequence = [1, 2, 3] 17 for idx, e in enumerate(Peekable(iter(sequence))): 18 self.assertEqual(sequence[idx], e) 19 20 def test_peek_next(self) -> None: 21 iterator = Peekable(iter([1, 2, 3])) 22 self.assertEqual(iterator.peek(), 1) 23 self.assertEqual(next(iterator), 1) 24 self.assertEqual(iterator.peek(), 2) 25 self.assertEqual(next(iterator), 2) 26 self.assertEqual(iterator.peek(), 3) 27 self.assertEqual(next(iterator), 3) 28 with self.assertRaises(StopIteration): 29 iterator.peek() 30 with self.assertRaises(StopIteration): 31 next(iterator) 32 33 def test_can_peek_peekable(self) -> None: 34 sequence = [1, 2, 3] 35 p = Peekable(iter(sequence)) 36 self.assertEqual(p.peek(), 1) 37 # Ensure we can peek repeatedly 38 self.assertEqual(p.peek(), 1) 39 for idx, e in enumerate(p): 40 self.assertEqual(sequence[idx], e) 41 42 def test_peek_on_empty_peekable_raises_stop_iteration(self) -> None: 43 empty = Peekable(iter([])) 44 with self.assertRaises(StopIteration): 45 empty.peek() 46 47 def test_next_on_empty_peekable_raises_stop_iteration(self) -> None: 48 empty = Peekable(iter([])) 49 with self.assertRaises(StopIteration): 50 next(empty) 51 52 53class TokenizerTests(unittest.TestCase): 54 def test_tokenize_digit(self) -> None: 55 self.assertEqual(list(tokenize("1")), [IntLit(1)]) 56 57 def test_tokenize_multiple_digits(self) -> None: 58 self.assertEqual(list(tokenize("123")), [IntLit(123)]) 59 60 def test_tokenize_negative_int(self) -> None: 61 self.assertEqual(list(tokenize("-123")), [Operator("-"), IntLit(123)]) 62 63 def test_tokenize_float(self) -> None: 64 self.assertEqual(list(tokenize("3.14")), [FloatLit(3.14)]) 65 66 def 
test_tokenize_negative_float(self) -> None: 67 self.assertEqual(list(tokenize("-3.14")), [Operator("-"), FloatLit(3.14)]) 68 69 @unittest.skip("TODO: support floats with no integer part") 70 def test_tokenize_float_with_no_integer_part(self) -> None: 71 self.assertEqual(list(tokenize(".14")), [FloatLit(0.14)]) 72 73 def test_tokenize_float_with_no_decimal_part(self) -> None: 74 self.assertEqual(list(tokenize("10.")), [FloatLit(10.0)]) 75 76 def test_tokenize_float_with_multiple_decimal_points_raises_parse_error(self) -> None: 77 with self.assertRaisesRegex(ParseError, re.escape("unexpected token '.'")): 78 list(tokenize("1.0.1")) 79 80 def test_tokenize_binop(self) -> None: 81 self.assertEqual(list(tokenize("1 + 2")), [IntLit(1), Operator("+"), IntLit(2)]) 82 83 def test_tokenize_binop_no_spaces(self) -> None: 84 self.assertEqual(list(tokenize("1+2")), [IntLit(1), Operator("+"), IntLit(2)]) 85 86 def test_tokenize_two_oper_chars_returns_two_ops(self) -> None: 87 self.assertEqual(list(tokenize(",:")), [Operator(","), Operator(":")]) 88 89 def test_tokenize_binary_sub_no_spaces(self) -> None: 90 self.assertEqual(list(tokenize("1-2")), [IntLit(1), Operator("-"), IntLit(2)]) 91 92 def test_tokenize_binop_var(self) -> None: 93 ops = ["+", "-", "*", "/", "^", "%", "==", "/=", "<", ">", "<=", ">=", "&&", "||", "++", ">+", "+<"] 94 for op in ops: 95 with self.subTest(op=op): 96 self.assertEqual(list(tokenize(f"a {op} b")), [Name("a"), Operator(op), Name("b")]) 97 self.assertEqual(list(tokenize(f"a{op}b")), [Name("a"), Operator(op), Name("b")]) 98 99 def test_tokenize_var(self) -> None: 100 self.assertEqual(list(tokenize("abc")), [Name("abc")]) 101 102 @unittest.skip("TODO: make this fail to tokenize") 103 def test_tokenize_var_with_quote(self) -> None: 104 self.assertEqual(list(tokenize("sha1'abc")), [Name("sha1'abc")]) 105 106 def test_tokenize_dollar_sha1_var(self) -> None: 107 self.assertEqual(list(tokenize("$sha1'foo")), [Name("$sha1'foo")]) 108 109 def 
test_tokenize_dollar_dollar_var(self) -> None: 110 self.assertEqual(list(tokenize("$$bills")), [Name("$$bills")]) 111 112 def test_tokenize_dot_dot_raises_parse_error(self) -> None: 113 with self.assertRaisesRegex(ParseError, re.escape("unexpected token '..'")): 114 list(tokenize("..")) 115 116 def test_tokenize_spread(self) -> None: 117 self.assertEqual(list(tokenize("...")), [Operator("...")]) 118 119 def test_ignore_whitespace(self) -> None: 120 self.assertEqual(list(tokenize("1\n+\t2")), [IntLit(1), Operator("+"), IntLit(2)]) 121 122 def test_ignore_line_comment(self) -> None: 123 self.assertEqual(list(tokenize("-- 1\n2")), [IntLit(2)]) 124 125 def test_tokenize_string(self) -> None: 126 self.assertEqual(list(tokenize('"hello"')), [StringLit("hello")]) 127 128 def test_tokenize_string_with_spaces(self) -> None: 129 self.assertEqual(list(tokenize('"hello world"')), [StringLit("hello world")]) 130 131 def test_tokenize_string_missing_end_quote_raises_parse_error(self) -> None: 132 with self.assertRaisesRegex(UnexpectedEOFError, "while reading string"): 133 list(tokenize('"hello')) 134 135 def test_tokenize_with_trailing_whitespace(self) -> None: 136 self.assertEqual(list(tokenize("- ")), [Operator("-")]) 137 self.assertEqual(list(tokenize("-- ")), []) 138 self.assertEqual(list(tokenize("+ ")), [Operator("+")]) 139 self.assertEqual(list(tokenize("123 ")), [IntLit(123)]) 140 self.assertEqual(list(tokenize("abc ")), [Name("abc")]) 141 self.assertEqual(list(tokenize("[ ")), [LeftBracket()]) 142 self.assertEqual(list(tokenize("] ")), [RightBracket()]) 143 144 def test_tokenize_empty_list(self) -> None: 145 self.assertEqual(list(tokenize("[ ]")), [LeftBracket(), RightBracket()]) 146 147 def test_tokenize_empty_list_with_spaces(self) -> None: 148 self.assertEqual(list(tokenize("[ ]")), [LeftBracket(), RightBracket()]) 149 150 def test_tokenize_list_with_items(self) -> None: 151 self.assertEqual( 152 list(tokenize("[ 1 , 2 ]")), [LeftBracket(), IntLit(1), Operator(","), 
IntLit(2), RightBracket()] 153 ) 154 155 def test_tokenize_list_with_no_spaces(self) -> None: 156 self.assertEqual(list(tokenize("[1,2]")), [LeftBracket(), IntLit(1), Operator(","), IntLit(2), RightBracket()]) 157 158 def test_tokenize_function(self) -> None: 159 self.assertEqual( 160 list(tokenize("a -> b -> a + b")), 161 [Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")], 162 ) 163 164 def test_tokenize_function_with_no_spaces(self) -> None: 165 self.assertEqual( 166 list(tokenize("a->b->a+b")), 167 [Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")], 168 ) 169 170 def test_tokenize_where(self) -> None: 171 self.assertEqual(list(tokenize("a . b")), [Name("a"), Operator("."), Name("b")]) 172 173 def test_tokenize_assert(self) -> None: 174 self.assertEqual(list(tokenize("a ? b")), [Name("a"), Operator("?"), Name("b")]) 175 176 def test_tokenize_hastype(self) -> None: 177 self.assertEqual(list(tokenize("a : b")), [Name("a"), Operator(":"), Name("b")]) 178 179 def test_tokenize_minus_returns_minus(self) -> None: 180 self.assertEqual(list(tokenize("-")), [Operator("-")]) 181 182 def test_tokenize_tilde_raises_parse_error(self) -> None: 183 with self.assertRaisesRegex(ParseError, "unexpected token '~'"): 184 list(tokenize("~")) 185 186 def test_tokenize_tilde_equals_raises_parse_error(self) -> None: 187 with self.assertRaisesRegex(ParseError, "unexpected token '~'"): 188 list(tokenize("~=")) 189 190 def test_tokenize_tilde_tilde_returns_empty_bytes(self) -> None: 191 self.assertEqual(list(tokenize("~~")), [BytesLit("", 64)]) 192 193 def test_tokenize_bytes_returns_bytes_base64(self) -> None: 194 self.assertEqual(list(tokenize("~~QUJD")), [BytesLit("QUJD", 64)]) 195 196 def test_tokenize_bytes_base85(self) -> None: 197 self.assertEqual(list(tokenize("~~85'K|(_")), [BytesLit("K|(_", 85)]) 198 199 def test_tokenize_bytes_base64(self) -> None: 200 self.assertEqual(list(tokenize("~~64'QUJD")), 
[BytesLit("QUJD", 64)]) 201 202 def test_tokenize_bytes_base32(self) -> None: 203 self.assertEqual(list(tokenize("~~32'IFBEG===")), [BytesLit("IFBEG===", 32)]) 204 205 def test_tokenize_bytes_base16(self) -> None: 206 self.assertEqual(list(tokenize("~~16'414243")), [BytesLit("414243", 16)]) 207 208 def test_tokenize_hole(self) -> None: 209 self.assertEqual(list(tokenize("()")), [LeftParen(), RightParen()]) 210 211 def test_tokenize_hole_with_spaces(self) -> None: 212 self.assertEqual(list(tokenize("( )")), [LeftParen(), RightParen()]) 213 214 def test_tokenize_parenthetical_expression(self) -> None: 215 self.assertEqual(list(tokenize("(1+2)")), [LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen()]) 216 217 def test_tokenize_pipe(self) -> None: 218 self.assertEqual( 219 list(tokenize("1 |> f . f = a -> a + 1")), 220 [ 221 IntLit(1), 222 Operator("|>"), 223 Name("f"), 224 Operator("."), 225 Name("f"), 226 Operator("="), 227 Name("a"), 228 Operator("->"), 229 Name("a"), 230 Operator("+"), 231 IntLit(1), 232 ], 233 ) 234 235 def test_tokenize_reverse_pipe(self) -> None: 236 self.assertEqual( 237 list(tokenize("f <| 1 . 
f = a -> a + 1")), 238 [ 239 Name("f"), 240 Operator("<|"), 241 IntLit(1), 242 Operator("."), 243 Name("f"), 244 Operator("="), 245 Name("a"), 246 Operator("->"), 247 Name("a"), 248 Operator("+"), 249 IntLit(1), 250 ], 251 ) 252 253 def test_tokenize_record_no_fields(self) -> None: 254 self.assertEqual( 255 list(tokenize("{ }")), 256 [LeftBrace(), RightBrace()], 257 ) 258 259 def test_tokenize_record_no_fields_no_spaces(self) -> None: 260 self.assertEqual( 261 list(tokenize("{}")), 262 [LeftBrace(), RightBrace()], 263 ) 264 265 def test_tokenize_record_one_field(self) -> None: 266 self.assertEqual( 267 list(tokenize("{ a = 4 }")), 268 [LeftBrace(), Name("a"), Operator("="), IntLit(4), RightBrace()], 269 ) 270 271 def test_tokenize_record_multiple_fields(self) -> None: 272 self.assertEqual( 273 list(tokenize('{ a = 4, b = "z" }')), 274 [ 275 LeftBrace(), 276 Name("a"), 277 Operator("="), 278 IntLit(4), 279 Operator(","), 280 Name("b"), 281 Operator("="), 282 StringLit("z"), 283 RightBrace(), 284 ], 285 ) 286 287 def test_tokenize_record_access(self) -> None: 288 self.assertEqual( 289 list(tokenize("r@a")), 290 [Name("r"), Operator("@"), Name("a")], 291 ) 292 293 def test_tokenize_right_eval(self) -> None: 294 self.assertEqual(list(tokenize("a!b")), [Name("a"), Operator("!"), Name("b")]) 295 296 def test_tokenize_match(self) -> None: 297 self.assertEqual( 298 list(tokenize("g = | 1 -> 2 | 2 -> 3")), 299 [ 300 Name("g"), 301 Operator("="), 302 Operator("|"), 303 IntLit(1), 304 Operator("->"), 305 IntLit(2), 306 Operator("|"), 307 IntLit(2), 308 Operator("->"), 309 IntLit(3), 310 ], 311 ) 312 313 def test_tokenize_compose(self) -> None: 314 self.assertEqual( 315 list(tokenize("f >> g")), 316 [Name("f"), Operator(">>"), Name("g")], 317 ) 318 319 def test_tokenize_compose_reverse(self) -> None: 320 self.assertEqual( 321 list(tokenize("f << g")), 322 [Name("f"), Operator("<<"), Name("g")], 323 ) 324 325 def test_first_lineno_is_one(self) -> None: 326 l = Lexer("abc") 327 
self.assertEqual(l.lineno, 1) 328 329 def test_first_colno_is_one(self) -> None: 330 l = Lexer("abc") 331 self.assertEqual(l.colno, 1) 332 333 def test_first_line_is_empty(self) -> None: 334 l = Lexer("abc") 335 self.assertEqual(l.line, "") 336 337 def test_read_char_increments_colno(self) -> None: 338 l = Lexer("abc") 339 l.read_char() 340 self.assertEqual(l.colno, 2) 341 self.assertEqual(l.lineno, 1) 342 343 def test_read_newline_increments_lineno(self) -> None: 344 l = Lexer("ab\nc") 345 l.read_char() 346 l.read_char() 347 l.read_char() 348 self.assertEqual(l.lineno, 2) 349 self.assertEqual(l.colno, 1) 350 351 def test_read_char_increments_byteno(self) -> None: 352 l = Lexer("abc") 353 l.read_char() 354 self.assertEqual(l.byteno, 1) 355 l.read_char() 356 self.assertEqual(l.byteno, 2) 357 l.read_char() 358 self.assertEqual(l.byteno, 3) 359 360 def test_read_char_appends_to_line(self) -> None: 361 l = Lexer("ab\nc") 362 l.read_char() 363 l.read_char() 364 self.assertEqual(l.line, "ab") 365 l.read_char() 366 self.assertEqual(l.line, "") 367 368 def test_read_token_sets_start_and_end_linenos(self) -> None: 369 l = Lexer("a b \n c d") 370 a = l.read_token() 371 b = l.read_token() 372 c = l.read_token() 373 d = l.read_token() 374 375 self.assertEqual(a.source_extent.start.lineno, 1) 376 self.assertEqual(a.source_extent.end.lineno, 1) 377 378 self.assertEqual(b.source_extent.start.lineno, 1) 379 self.assertEqual(b.source_extent.end.lineno, 1) 380 381 self.assertEqual(c.source_extent.start.lineno, 2) 382 self.assertEqual(c.source_extent.end.lineno, 2) 383 384 self.assertEqual(d.source_extent.start.lineno, 2) 385 self.assertEqual(d.source_extent.end.lineno, 2) 386 387 def test_read_token_sets_source_extents_for_variables(self) -> None: 388 l = Lexer("aa bbbb \n ccccc ddddddd") 389 390 a = l.read_token() 391 b = l.read_token() 392 c = l.read_token() 393 d = l.read_token() 394 395 self.assertEqual(a.source_extent.start.lineno, 1) 396 
self.assertEqual(a.source_extent.end.lineno, 1) 397 self.assertEqual(a.source_extent.start.colno, 1) 398 self.assertEqual(a.source_extent.end.colno, 2) 399 self.assertEqual(a.source_extent.start.byteno, 0) 400 self.assertEqual(a.source_extent.end.byteno, 1) 401 402 self.assertEqual(b.source_extent.start.lineno, 1) 403 self.assertEqual(b.source_extent.end.lineno, 1) 404 self.assertEqual(b.source_extent.start.colno, 4) 405 self.assertEqual(b.source_extent.end.colno, 7) 406 self.assertEqual(b.source_extent.start.byteno, 3) 407 self.assertEqual(b.source_extent.end.byteno, 6) 408 409 self.assertEqual(c.source_extent.start.lineno, 2) 410 self.assertEqual(c.source_extent.end.lineno, 2) 411 self.assertEqual(c.source_extent.start.colno, 2) 412 self.assertEqual(c.source_extent.end.colno, 6) 413 self.assertEqual(c.source_extent.start.byteno, 10) 414 self.assertEqual(c.source_extent.end.byteno, 14) 415 416 self.assertEqual(d.source_extent.start.lineno, 2) 417 self.assertEqual(d.source_extent.end.lineno, 2) 418 self.assertEqual(d.source_extent.start.colno, 8) 419 self.assertEqual(d.source_extent.end.colno, 14) 420 self.assertEqual(d.source_extent.start.byteno, 16) 421 self.assertEqual(d.source_extent.end.byteno, 22) 422 423 def test_read_token_correctly_sets_source_extents_for_variants(self) -> None: 424 l = Lexer("# \n\r\n\t abc") 425 426 a = l.read_token() 427 b = l.read_token() 428 429 self.assertEqual(a.source_extent.start.lineno, 1) 430 self.assertEqual(a.source_extent.end.lineno, 1) 431 self.assertEqual(a.source_extent.start.colno, 1) 432 # TODO(max): Should tabs count as one column? 
433 self.assertEqual(a.source_extent.end.colno, 1) 434 435 self.assertEqual(b.source_extent.start.lineno, 3) 436 self.assertEqual(b.source_extent.end.lineno, 3) 437 self.assertEqual(b.source_extent.start.colno, 3) 438 self.assertEqual(b.source_extent.end.colno, 5) 439 440 def test_read_token_correctly_sets_source_extents_for_strings(self) -> None: 441 l = Lexer('"今日は、Maxさん。"') 442 a = l.read_token() 443 444 self.assertEqual(a.source_extent.start.lineno, 1) 445 self.assertEqual(a.source_extent.end.lineno, 1) 446 447 self.assertEqual(a.source_extent.start.colno, 1) 448 self.assertEqual(a.source_extent.end.colno, 12) 449 450 self.assertEqual(a.source_extent.start.byteno, 0) 451 self.assertEqual(a.source_extent.end.byteno, 25) 452 453 def test_read_token_correctly_sets_source_extents_for_byte_literals(self) -> None: 454 l = Lexer("~~QUJD ~~85'K|(_ ~~64'QUJD\n ~~32'IFBEG=== ~~16'414243") 455 a = l.read_token() 456 b = l.read_token() 457 c = l.read_token() 458 d = l.read_token() 459 e = l.read_token() 460 461 self.assertEqual(a.source_extent.start.lineno, 1) 462 self.assertEqual(a.source_extent.end.lineno, 1) 463 self.assertEqual(a.source_extent.start.colno, 1) 464 self.assertEqual(a.source_extent.end.colno, 6) 465 self.assertEqual(a.source_extent.start.byteno, 0) 466 self.assertEqual(a.source_extent.end.byteno, 5) 467 468 self.assertEqual(b.source_extent.start.lineno, 1) 469 self.assertEqual(b.source_extent.end.lineno, 1) 470 self.assertEqual(b.source_extent.start.colno, 8) 471 self.assertEqual(b.source_extent.end.colno, 16) 472 self.assertEqual(b.source_extent.start.byteno, 7) 473 self.assertEqual(b.source_extent.end.byteno, 15) 474 475 self.assertEqual(c.source_extent.start.lineno, 1) 476 self.assertEqual(c.source_extent.end.lineno, 1) 477 self.assertEqual(c.source_extent.start.colno, 18) 478 self.assertEqual(c.source_extent.end.colno, 26) 479 self.assertEqual(c.source_extent.start.byteno, 17) 480 self.assertEqual(c.source_extent.end.byteno, 25) 481 482 
self.assertEqual(d.source_extent.start.lineno, 2) 483 self.assertEqual(d.source_extent.end.lineno, 2) 484 self.assertEqual(d.source_extent.start.colno, 2) 485 self.assertEqual(d.source_extent.end.colno, 14) 486 self.assertEqual(d.source_extent.start.byteno, 28) 487 self.assertEqual(d.source_extent.end.byteno, 40) 488 489 self.assertEqual(e.source_extent.start.lineno, 2) 490 self.assertEqual(e.source_extent.end.lineno, 2) 491 self.assertEqual(e.source_extent.start.colno, 16) 492 self.assertEqual(e.source_extent.end.colno, 26) 493 self.assertEqual(e.source_extent.start.byteno, 42) 494 self.assertEqual(e.source_extent.end.byteno, 52) 495 496 def test_read_token_correctly_sets_source_extents_for_numbers(self) -> None: 497 l = Lexer("123 123.456") 498 a = l.read_token() 499 b = l.read_token() 500 501 self.assertEqual(a.source_extent.start.lineno, 1) 502 self.assertEqual(a.source_extent.end.lineno, 1) 503 self.assertEqual(a.source_extent.start.colno, 1) 504 self.assertEqual(a.source_extent.end.colno, 3) 505 self.assertEqual(a.source_extent.start.byteno, 0) 506 self.assertEqual(a.source_extent.end.byteno, 2) 507 508 self.assertEqual(b.source_extent.start.lineno, 1) 509 self.assertEqual(b.source_extent.end.lineno, 1) 510 self.assertEqual(b.source_extent.start.colno, 5) 511 self.assertEqual(b.source_extent.end.colno, 11) 512 self.assertEqual(b.source_extent.start.byteno, 4) 513 self.assertEqual(b.source_extent.end.byteno, 10) 514 515 def test_read_token_correctly_sets_source_extents_for_operators(self) -> None: 516 l = Lexer("> >>") 517 a = l.read_token() 518 b = l.read_token() 519 520 self.assertEqual(a.source_extent.start.lineno, 1) 521 self.assertEqual(a.source_extent.end.lineno, 1) 522 self.assertEqual(a.source_extent.start.colno, 1) 523 self.assertEqual(a.source_extent.end.colno, 1) 524 self.assertEqual(a.source_extent.start.byteno, 0) 525 self.assertEqual(a.source_extent.end.byteno, 0) 526 527 self.assertEqual(b.source_extent.start.lineno, 1) 528 
self.assertEqual(b.source_extent.end.lineno, 1) 529 self.assertEqual(b.source_extent.start.colno, 3) 530 self.assertEqual(b.source_extent.end.colno, 4) 531 self.assertEqual(b.source_extent.start.byteno, 2) 532 self.assertEqual(b.source_extent.end.byteno, 3) 533 534 def test_tokenize_list_with_only_spread(self) -> None: 535 self.assertEqual(list(tokenize("[ ... ]")), [LeftBracket(), Operator("..."), RightBracket()]) 536 537 def test_tokenize_list_with_spread(self) -> None: 538 self.assertEqual( 539 list(tokenize("[ 1 , ... ]")), 540 [ 541 LeftBracket(), 542 IntLit(1), 543 Operator(","), 544 Operator("..."), 545 RightBracket(), 546 ], 547 ) 548 549 def test_tokenize_list_with_spread_no_spaces(self) -> None: 550 self.assertEqual( 551 list(tokenize("[ 1,... ]")), 552 [ 553 LeftBracket(), 554 IntLit(1), 555 Operator(","), 556 Operator("..."), 557 RightBracket(), 558 ], 559 ) 560 561 def test_tokenize_list_with_named_spread(self) -> None: 562 self.assertEqual( 563 list(tokenize("[1,...rest]")), 564 [ 565 LeftBracket(), 566 IntLit(1), 567 Operator(","), 568 Operator("..."), 569 Name("rest"), 570 RightBracket(), 571 ], 572 ) 573 574 def test_tokenize_record_with_only_spread(self) -> None: 575 self.assertEqual( 576 list(tokenize("{ ... 
}")), 577 [ 578 LeftBrace(), 579 Operator("..."), 580 RightBrace(), 581 ], 582 ) 583 584 def test_tokenize_record_with_spread(self) -> None: 585 self.assertEqual( 586 list(tokenize("{ x = 1, ...}")), 587 [ 588 LeftBrace(), 589 Name("x"), 590 Operator("="), 591 IntLit(1), 592 Operator(","), 593 Operator("..."), 594 RightBrace(), 595 ], 596 ) 597 598 def test_tokenize_record_with_spread_no_spaces(self) -> None: 599 self.assertEqual( 600 list(tokenize("{x=1,...}")), 601 [ 602 LeftBrace(), 603 Name("x"), 604 Operator("="), 605 IntLit(1), 606 Operator(","), 607 Operator("..."), 608 RightBrace(), 609 ], 610 ) 611 612 def test_tokenize_variant_with_whitespace(self) -> None: 613 self.assertEqual(list(tokenize("# \n\r\n\t abc")), [Hash(), Name("abc")]) 614 615 def test_tokenize_variant_with_no_space(self) -> None: 616 self.assertEqual(list(tokenize("#abc")), [Hash(), Name("abc")]) 617 618 619class ParserTests(unittest.TestCase): 620 def test_parse_with_empty_tokens_raises_parse_error(self) -> None: 621 with self.assertRaises(UnexpectedEOFError) as ctx: 622 parse(Peekable(iter([]))) 623 self.assertEqual(ctx.exception.args[0], "unexpected end of input") 624 625 def test_parse_digit_returns_int(self) -> None: 626 self.assertEqual(parse(Peekable(iter([IntLit(1)]))), Int(1)) 627 628 def test_parse_digits_returns_int(self) -> None: 629 self.assertEqual(parse(Peekable(iter([IntLit(123)]))), Int(123)) 630 631 def test_parse_negative_int_returns_negative_int(self) -> None: 632 self.assertEqual(parse(Peekable(iter([Operator("-"), IntLit(123)]))), Int(-123)) 633 634 def test_parse_negative_var_returns_binary_sub_var(self) -> None: 635 self.assertEqual(parse(Peekable(iter([Operator("-"), Name("x")]))), Binop(BinopKind.SUB, Int(0), Var("x"))) 636 637 def test_parse_negative_int_binds_tighter_than_plus(self) -> None: 638 self.assertEqual( 639 parse(Peekable(iter([Operator("-"), Name("l"), Operator("+"), Name("r")]))), 640 Binop(BinopKind.ADD, Binop(BinopKind.SUB, Int(0), Var("l")), 
Var("r")), 641 ) 642 643 def test_parse_negative_int_binds_tighter_than_mul(self) -> None: 644 self.assertEqual( 645 parse(Peekable(iter([Operator("-"), Name("l"), Operator("*"), Name("r")]))), 646 Binop(BinopKind.MUL, Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")), 647 ) 648 649 def test_parse_negative_int_binds_tighter_than_index(self) -> None: 650 self.assertEqual( 651 parse(Peekable(iter([Operator("-"), Name("l"), Operator("@"), Name("r")]))), 652 Access(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")), 653 ) 654 655 def test_parse_negative_int_binds_tighter_than_apply(self) -> None: 656 self.assertEqual( 657 parse(Peekable(iter([Operator("-"), Name("l"), Name("r")]))), 658 Apply(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")), 659 ) 660 661 def test_parse_decimal_returns_float(self) -> None: 662 self.assertEqual(parse(Peekable(iter([FloatLit(3.14)]))), Float(3.14)) 663 664 def test_parse_negative_float_returns_binary_sub_float(self) -> None: 665 self.assertEqual(parse(Peekable(iter([Operator("-"), FloatLit(3.14)]))), Float(-3.14)) 666 667 def test_parse_var_returns_var(self) -> None: 668 self.assertEqual(parse(Peekable(iter([Name("abc_123")]))), Var("abc_123")) 669 670 def test_parse_sha_var_returns_var(self) -> None: 671 self.assertEqual(parse(Peekable(iter([Name("$sha1'abc")]))), Var("$sha1'abc")) 672 673 def test_parse_sha_var_without_quote_returns_var(self) -> None: 674 self.assertEqual(parse(Peekable(iter([Name("$sha1abc")]))), Var("$sha1abc")) 675 676 def test_parse_dollar_returns_var(self) -> None: 677 self.assertEqual(parse(Peekable(iter([Name("$")]))), Var("$")) 678 679 def test_parse_dollar_dollar_returns_var(self) -> None: 680 self.assertEqual(parse(Peekable(iter([Name("$$")]))), Var("$$")) 681 682 @unittest.skip("TODO: make this fail to parse") 683 def test_parse_sha_var_without_dollar_raises_parse_error(self) -> None: 684 with self.assertRaisesRegex(ParseError, "unexpected token"): 685 parse(Peekable(iter([Name("sha1'abc")]))) 686 687 def 
test_parse_dollar_dollar_var_returns_var(self) -> None: 688 self.assertEqual(parse(Peekable(iter([Name("$$bills")]))), Var("$$bills")) 689 690 def test_parse_bytes_returns_bytes(self) -> None: 691 self.assertEqual(parse(Peekable(iter([BytesLit("QUJD", 64)]))), Bytes(b"ABC")) 692 693 def test_parse_binary_add_returns_binop(self) -> None: 694 self.assertEqual( 695 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2)]))), Binop(BinopKind.ADD, Int(1), Int(2)) 696 ) 697 698 def test_parse_binary_sub_returns_binop(self) -> None: 699 self.assertEqual( 700 parse(Peekable(iter([IntLit(1), Operator("-"), IntLit(2)]))), Binop(BinopKind.SUB, Int(1), Int(2)) 701 ) 702 703 def test_parse_binary_add_right_returns_binop(self) -> None: 704 self.assertEqual( 705 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2), Operator("+"), IntLit(3)]))), 706 Binop(BinopKind.ADD, Int(1), Binop(BinopKind.ADD, Int(2), Int(3))), 707 ) 708 709 def test_mul_binds_tighter_than_add_right(self) -> None: 710 self.assertEqual( 711 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2), Operator("*"), IntLit(3)]))), 712 Binop(BinopKind.ADD, Int(1), Binop(BinopKind.MUL, Int(2), Int(3))), 713 ) 714 715 def test_mul_binds_tighter_than_add_left(self) -> None: 716 self.assertEqual( 717 parse(Peekable(iter([IntLit(1), Operator("*"), IntLit(2), Operator("+"), IntLit(3)]))), 718 Binop(BinopKind.ADD, Binop(BinopKind.MUL, Int(1), Int(2)), Int(3)), 719 ) 720 721 def test_mul_and_div_bind_left_to_right(self) -> None: 722 self.assertEqual( 723 parse(Peekable(iter([IntLit(1), Operator("/"), IntLit(3), Operator("*"), IntLit(3)]))), 724 Binop(BinopKind.MUL, Binop(BinopKind.DIV, Int(1), Int(3)), Int(3)), 725 ) 726 727 def test_exp_binds_tighter_than_mul_right(self) -> None: 728 self.assertEqual( 729 parse(Peekable(iter([IntLit(5), Operator("*"), IntLit(2), Operator("^"), IntLit(3)]))), 730 Binop(BinopKind.MUL, Int(5), Binop(BinopKind.EXP, Int(2), Int(3))), 731 ) 732 733 def 
test_list_access_binds_tighter_than_append(self) -> None: 734 self.assertEqual( 735 parse(Peekable(iter([Name("a"), Operator("+<"), Name("ls"), Operator("@"), IntLit(0)]))), 736 Binop(BinopKind.LIST_APPEND, Var("a"), Access(Var("ls"), Int(0))), 737 ) 738 739 def test_parse_binary_str_concat_returns_binop(self) -> None: 740 self.assertEqual( 741 parse(Peekable(iter([StringLit("abc"), Operator("++"), StringLit("def")]))), 742 Binop(BinopKind.STRING_CONCAT, String("abc"), String("def")), 743 ) 744 745 def test_parse_binary_list_cons_returns_binop(self) -> None: 746 self.assertEqual( 747 parse(Peekable(iter([Name("a"), Operator(">+"), Name("b")]))), 748 Binop(BinopKind.LIST_CONS, Var("a"), Var("b")), 749 ) 750 751 def test_parse_binary_list_append_returns_binop(self) -> None: 752 self.assertEqual( 753 parse(Peekable(iter([Name("a"), Operator("+<"), Name("b")]))), 754 Binop(BinopKind.LIST_APPEND, Var("a"), Var("b")), 755 ) 756 757 def test_parse_binary_op_returns_binop(self) -> None: 758 ops = ["+", "-", "*", "/", "^", "%", "==", "/=", "<", ">", "<=", ">=", "&&", "||", "++", ">+", "+<"] 759 for op in ops: 760 with self.subTest(op=op): 761 kind = BinopKind.from_str(op) 762 self.assertEqual( 763 parse(Peekable(iter([Name("a"), Operator(op), Name("b")]))), Binop(kind, Var("a"), Var("b")) 764 ) 765 766 def test_parse_empty_list(self) -> None: 767 self.assertEqual( 768 parse(Peekable(iter([LeftBracket(), RightBracket()]))), 769 List([]), 770 ) 771 772 def test_parse_list_of_ints_returns_list(self) -> None: 773 self.assertEqual( 774 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), IntLit(2), RightBracket()]))), 775 List([Int(1), Int(2)]), 776 ) 777 778 def test_parse_list_with_only_comma_raises_parse_error(self) -> None: 779 with self.assertRaises(UnexpectedTokenError) as parse_error: 780 parse(Peekable(iter([LeftBracket(), Operator(","), RightBracket()]))) 781 782 self.assertEqual(parse_error.exception.unexpected_token, Operator(",")) 783 784 def 
test_parse_list_with_two_commas_raises_parse_error(self) -> None: 785 with self.assertRaises(UnexpectedTokenError) as parse_error: 786 parse(Peekable(iter([LeftBracket(), Operator(","), Operator(","), RightBracket()]))) 787 788 self.assertEqual(parse_error.exception.unexpected_token, Operator(",")) 789 790 def test_parse_list_with_trailing_comma_raises_parse_error(self) -> None: 791 with self.assertRaises(UnexpectedTokenError) as parse_error: 792 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), RightBracket()]))) 793 794 self.assertEqual(parse_error.exception.unexpected_token, RightBracket()) 795 796 def test_parse_assign(self) -> None: 797 self.assertEqual( 798 parse(Peekable(iter([Name("a"), Operator("="), IntLit(1)]))), 799 Assign(Var("a"), Int(1)), 800 ) 801 802 def test_parse_function_one_arg_returns_function(self) -> None: 803 self.assertEqual( 804 parse(Peekable(iter([Name("a"), Operator("->"), Name("a"), Operator("+"), IntLit(1)]))), 805 Function(Var("a"), Binop(BinopKind.ADD, Var("a"), Int(1))), 806 ) 807 808 def test_parse_function_two_args_returns_functions(self) -> None: 809 self.assertEqual( 810 parse( 811 Peekable( 812 iter([Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")]) 813 ) 814 ), 815 Function(Var("a"), Function(Var("b"), Binop(BinopKind.ADD, Var("a"), Var("b")))), 816 ) 817 818 def test_parse_assign_function(self) -> None: 819 self.assertEqual( 820 parse(Peekable(iter([Name("id"), Operator("="), Name("x"), Operator("->"), Name("x")]))), 821 Assign(Var("id"), Function(Var("x"), Var("x"))), 822 ) 823 824 def test_parse_function_application_one_arg(self) -> None: 825 self.assertEqual(parse(Peekable(iter([Name("f"), Name("a")]))), Apply(Var("f"), Var("a"))) 826 827 def test_parse_function_application_two_args(self) -> None: 828 self.assertEqual( 829 parse(Peekable(iter([Name("f"), Name("a"), Name("b")]))), Apply(Apply(Var("f"), Var("a")), Var("b")) 830 ) 831 832 def test_parse_where(self) -> 
None: 833 self.assertEqual(parse(Peekable(iter([Name("a"), Operator("."), Name("b")]))), Where(Var("a"), Var("b"))) 834 835 def test_parse_nested_where(self) -> None: 836 self.assertEqual( 837 parse(Peekable(iter([Name("a"), Operator("."), Name("b"), Operator("."), Name("c")]))), 838 Where(Where(Var("a"), Var("b")), Var("c")), 839 ) 840 841 def test_parse_assert(self) -> None: 842 self.assertEqual(parse(Peekable(iter([Name("a"), Operator("?"), Name("b")]))), Assert(Var("a"), Var("b"))) 843 844 def test_parse_nested_assert(self) -> None: 845 self.assertEqual( 846 parse(Peekable(iter([Name("a"), Operator("?"), Name("b"), Operator("?"), Name("c")]))), 847 Assert(Assert(Var("a"), Var("b")), Var("c")), 848 ) 849 850 def test_parse_mixed_assert_where(self) -> None: 851 self.assertEqual( 852 parse(Peekable(iter([Name("a"), Operator("?"), Name("b"), Operator("."), Name("c")]))), 853 Where(Assert(Var("a"), Var("b")), Var("c")), 854 ) 855 856 def test_parse_hastype(self) -> None: 857 self.assertEqual( 858 parse(Peekable(iter([Name("a"), Operator(":"), Name("b")]))), Binop(BinopKind.HASTYPE, Var("a"), Var("b")) 859 ) 860 861 def test_parse_hole(self) -> None: 862 self.assertEqual(parse(Peekable(iter([LeftParen(), RightParen()]))), Hole()) 863 864 def test_parse_parenthesized_expression(self) -> None: 865 self.assertEqual( 866 parse(Peekable(iter([LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen()]))), 867 Binop(BinopKind.ADD, Int(1), Int(2)), 868 ) 869 870 def test_parse_parenthesized_add_mul(self) -> None: 871 self.assertEqual( 872 parse( 873 Peekable( 874 iter([LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen(), Operator("*"), IntLit(3)]) 875 ) 876 ), 877 Binop(BinopKind.MUL, Binop(BinopKind.ADD, Int(1), Int(2)), Int(3)), 878 ) 879 880 def test_parse_pipe(self) -> None: 881 self.assertEqual( 882 parse(Peekable(iter([IntLit(1), Operator("|>"), Name("f")]))), 883 Apply(Var("f"), Int(1)), 884 ) 885 886 def test_parse_nested_pipe(self) -> None: 887 
self.assertEqual( 888 parse(Peekable(iter([IntLit(1), Operator("|>"), Name("f"), Operator("|>"), Name("g")]))), 889 Apply(Var("g"), Apply(Var("f"), Int(1))), 890 ) 891 892 def test_parse_reverse_pipe(self) -> None: 893 self.assertEqual( 894 parse(Peekable(iter([Name("f"), Operator("<|"), IntLit(1)]))), 895 Apply(Var("f"), Int(1)), 896 ) 897 898 def test_parse_nested_reverse_pipe(self) -> None: 899 self.assertEqual( 900 parse(Peekable(iter([Name("g"), Operator("<|"), Name("f"), Operator("<|"), IntLit(1)]))), 901 Apply(Var("g"), Apply(Var("f"), Int(1))), 902 ) 903 904 def test_parse_empty_record(self) -> None: 905 self.assertEqual(parse(Peekable(iter([LeftBrace(), RightBrace()]))), Record({})) 906 907 def test_parse_record_single_field(self) -> None: 908 self.assertEqual( 909 parse(Peekable(iter([LeftBrace(), Name("a"), Operator("="), IntLit(4), RightBrace()]))), 910 Record({"a": Int(4)}), 911 ) 912 913 def test_parse_record_with_expression(self) -> None: 914 self.assertEqual( 915 parse( 916 Peekable( 917 iter([LeftBrace(), Name("a"), Operator("="), IntLit(1), Operator("+"), IntLit(2), RightBrace()]) 918 ) 919 ), 920 Record({"a": Binop(BinopKind.ADD, Int(1), Int(2))}), 921 ) 922 923 def test_parse_record_multiple_fields(self) -> None: 924 self.assertEqual( 925 parse( 926 Peekable( 927 iter( 928 [ 929 LeftBrace(), 930 Name("a"), 931 Operator("="), 932 IntLit(4), 933 Operator(","), 934 Name("b"), 935 Operator("="), 936 StringLit("z"), 937 RightBrace(), 938 ] 939 ) 940 ) 941 ), 942 Record({"a": Int(4), "b": String("z")}), 943 ) 944 945 def test_non_variable_in_assignment_raises_parse_error(self) -> None: 946 with self.assertRaises(ParseError) as ctx: 947 parse(Peekable(iter([IntLit(3), Operator("="), IntLit(4)]))) 948 self.assertEqual(ctx.exception.args[0], "expected variable in assignment Int(value=3)") 949 950 def test_non_assign_in_record_constructor_raises_parse_error(self) -> None: 951 with self.assertRaises(ParseError) as ctx: 952 parse(Peekable(iter([LeftBrace(), 
IntLit(1), Operator(","), IntLit(2), RightBrace()]))) 953 self.assertEqual(ctx.exception.args[0], "failed to parse variable assignment in record constructor") 954 955 def test_parse_right_eval_returns_binop(self) -> None: 956 self.assertEqual( 957 parse(Peekable(iter([Name("a"), Operator("!"), Name("b")]))), 958 Binop(BinopKind.RIGHT_EVAL, Var("a"), Var("b")), 959 ) 960 961 def test_parse_right_eval_with_defs_returns_binop(self) -> None: 962 self.assertEqual( 963 parse(Peekable(iter([Name("a"), Operator("!"), Name("b"), Operator("."), Name("c")]))), 964 Binop(BinopKind.RIGHT_EVAL, Var("a"), Where(Var("b"), Var("c"))), 965 ) 966 967 def test_parse_match_no_cases_raises_parse_error(self) -> None: 968 with self.assertRaises(ParseError) as ctx: 969 parse(Peekable(iter([Operator("|")]))) 970 self.assertEqual(ctx.exception.args[0], "unexpected end of input") 971 972 def test_parse_match_one_case(self) -> None: 973 self.assertEqual( 974 parse(Peekable(iter([Operator("|"), IntLit(1), Operator("->"), IntLit(2)]))), 975 MatchFunction([MatchCase(Int(1), Int(2))]), 976 ) 977 978 def test_parse_match_two_cases(self) -> None: 979 self.assertEqual( 980 parse( 981 Peekable( 982 iter( 983 [ 984 Operator("|"), 985 IntLit(1), 986 Operator("->"), 987 IntLit(2), 988 Operator("|"), 989 IntLit(2), 990 Operator("->"), 991 IntLit(3), 992 ] 993 ) 994 ) 995 ), 996 MatchFunction( 997 [ 998 MatchCase(Int(1), Int(2)), 999 MatchCase(Int(2), Int(3)), 1000 ] 1001 ), 1002 ) 1003 1004 def test_parse_compose(self) -> None: 1005 gensym_reset() 1006 self.assertEqual( 1007 parse(Peekable(iter([Name("f"), Operator(">>"), Name("g")]))), 1008 Function(Var("$v0"), Apply(Var("g"), Apply(Var("f"), Var("$v0")))), 1009 ) 1010 1011 def test_parse_compose_reverse(self) -> None: 1012 gensym_reset() 1013 self.assertEqual( 1014 parse(Peekable(iter([Name("f"), Operator("<<"), Name("g")]))), 1015 Function(Var("$v0"), Apply(Var("f"), Apply(Var("g"), Var("$v0")))), 1016 ) 1017 1018 def test_parse_double_compose(self) -> 
None: 1019 gensym_reset() 1020 self.assertEqual( 1021 parse(Peekable(iter([Name("f"), Operator("<<"), Name("g"), Operator("<<"), Name("h")]))), 1022 Function( 1023 Var("$v1"), 1024 Apply(Var("f"), Apply(Function(Var("$v0"), Apply(Var("g"), Apply(Var("h"), Var("$v0")))), Var("$v1"))), 1025 ), 1026 ) 1027 1028 def test_boolean_and_binds_tighter_than_or(self) -> None: 1029 self.assertEqual( 1030 parse(Peekable(iter([Name("x"), Operator("||"), Name("y"), Operator("&&"), Name("z")]))), 1031 Binop(BinopKind.BOOL_OR, Var("x"), Binop(BinopKind.BOOL_AND, Var("y"), Var("z"))), 1032 ) 1033 1034 def test_parse_list_spread(self) -> None: 1035 self.assertEqual( 1036 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), Operator("..."), RightBracket()]))), 1037 List([Int(1), Spread()]), 1038 ) 1039 1040 @unittest.skip("TODO(max): Raise if ...x is used with non-name") 1041 def test_parse_list_with_non_name_expr_after_spread_raises_parse_error(self) -> None: 1042 with self.assertRaisesRegex(ParseError, re.escape("unexpected token IntLit(lineno=-1, value=1)")): 1043 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), Operator("..."), IntLit(2), RightBracket()]))) 1044 1045 def test_parse_list_with_named_spread(self) -> None: 1046 self.assertEqual( 1047 parse( 1048 Peekable( 1049 iter( 1050 [ 1051 LeftBracket(), 1052 IntLit(1), 1053 Operator(","), 1054 Operator("..."), 1055 Name("rest"), 1056 RightBracket(), 1057 ] 1058 ) 1059 ) 1060 ), 1061 List([Int(1), Spread("rest")]), 1062 ) 1063 1064 def test_parse_list_spread_beginning_raises_parse_error(self) -> None: 1065 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")): 1066 parse(Peekable(iter([LeftBracket(), Operator("..."), Operator(","), IntLit(1), RightBracket()]))) 1067 1068 def test_parse_list_named_spread_beginning_raises_parse_error(self) -> None: 1069 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")): 1070 parse( 1071 
Peekable(iter([LeftBracket(), Operator("..."), Name("rest"), Operator(","), IntLit(1), RightBracket()])) 1072 ) 1073 1074 def test_parse_list_spread_middle_raises_parse_error(self) -> None: 1075 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")): 1076 parse( 1077 Peekable( 1078 iter( 1079 [ 1080 LeftBracket(), 1081 IntLit(1), 1082 Operator(","), 1083 Operator("..."), 1084 Operator(","), 1085 IntLit(1), 1086 RightBracket(), 1087 ] 1088 ) 1089 ) 1090 ) 1091 1092 def test_parse_list_named_spread_middle_raises_parse_error(self) -> None: 1093 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")): 1094 parse( 1095 Peekable( 1096 iter( 1097 [ 1098 LeftBracket(), 1099 IntLit(1), 1100 Operator(","), 1101 Operator("..."), 1102 Name("rest"), 1103 Operator(","), 1104 IntLit(1), 1105 RightBracket(), 1106 ] 1107 ) 1108 ) 1109 ) 1110 1111 def test_parse_record_spread(self) -> None: 1112 self.assertEqual( 1113 parse( 1114 Peekable( 1115 iter( 1116 [LeftBrace(), Name("x"), Operator("="), IntLit(1), Operator(","), Operator("..."), RightBrace()] 1117 ) 1118 ) 1119 ), 1120 Record({"x": Int(1), "...": Spread()}), 1121 ) 1122 1123 def test_parse_record_spread_beginning_raises_parse_error(self) -> None: 1124 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of record match")): 1125 parse( 1126 Peekable( 1127 iter( 1128 [LeftBrace(), Operator("..."), Operator(","), Name("x"), Operator("="), IntLit(1), RightBrace()] 1129 ) 1130 ) 1131 ) 1132 1133 def test_parse_record_spread_middle_raises_parse_error(self) -> None: 1134 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of record match")): 1135 parse( 1136 Peekable( 1137 iter( 1138 [ 1139 LeftBrace(), 1140 Name("x"), 1141 Operator("="), 1142 IntLit(1), 1143 Operator(","), 1144 Operator("..."), 1145 Operator(","), 1146 Name("y"), 1147 Operator("="), 1148 IntLit(2), 1149 RightBrace(), 1150 ] 1151 ) 1152 ) 1153 ) 
1154 1155 def test_parse_record_with_only_comma_raises_parse_error(self) -> None: 1156 with self.assertRaises(UnexpectedTokenError) as parse_error: 1157 parse(Peekable(iter([LeftBrace(), Operator(","), RightBrace()]))) 1158 1159 self.assertEqual(parse_error.exception.unexpected_token, Operator(",")) 1160 1161 def test_parse_record_with_two_commas_raises_parse_error(self) -> None: 1162 with self.assertRaises(UnexpectedTokenError) as parse_error: 1163 parse(Peekable(iter([LeftBrace(), Operator(","), Operator(","), RightBrace()]))) 1164 1165 self.assertEqual(parse_error.exception.unexpected_token, Operator(",")) 1166 1167 def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None: 1168 with self.assertRaises(UnexpectedTokenError) as parse_error: 1169 parse(Peekable(iter([LeftBrace(), Name("x"), Operator("="), IntLit(1), Operator(","), RightBrace()]))) 1170 1171 self.assertEqual(parse_error.exception.unexpected_token, RightBrace()) 1172 1173 def test_parse_variant_returns_variant(self) -> None: 1174 self.assertEqual(parse(Peekable(iter([Hash(), Name("abc"), IntLit(1)]))), Variant("abc", Int(1))) 1175 1176 def test_parse_hash_raises_unexpected_eof_error(self) -> None: 1177 tokens = Peekable(iter([Hash()])) 1178 with self.assertRaises(UnexpectedEOFError): 1179 parse(tokens) 1180 1181 def test_parse_variant_non_name_raises_parse_error(self) -> None: 1182 with self.assertRaises(UnexpectedTokenError) as parse_error: 1183 parse(Peekable(iter([Hash(), IntLit(1)]))) 1184 1185 self.assertEqual(parse_error.exception.unexpected_token, IntLit(1)) 1186 1187 def test_parse_variant_eof_raises_unexpected_eof_error(self) -> None: 1188 with self.assertRaises(UnexpectedEOFError): 1189 parse(Peekable(iter([Hash()]))) 1190 1191 def test_match_with_variant(self) -> None: 1192 ast = parse(tokenize("| #true () -> 123")) 1193 self.assertEqual(ast, MatchFunction([MatchCase(TRUE, Int(123))])) 1194 1195 def test_binary_and_with_variant_args(self) -> None: 1196 ast = 
parse(tokenize("#true() && #false()")) 1197 self.assertEqual(ast, Binop(BinopKind.BOOL_AND, TRUE, FALSE)) 1198 1199 def test_apply_with_variant_args(self) -> None: 1200 ast = parse(tokenize("f #true() #false()")) 1201 self.assertEqual(ast, Apply(Apply(Var("f"), TRUE), FALSE)) 1202 1203 def test_parse_int_preserves_source_extent(self) -> None: 1204 source_extent = SourceExtent( 1205 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1206 ) 1207 int_lit = IntLit(1).with_source(source_extent) 1208 self.assertEqual(parse(Peekable(iter([int_lit]))).source_extent, source_extent) 1209 1210 def test_parse_float_preserves_source_extent(self) -> None: 1211 source_extent = SourceExtent( 1212 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2) 1213 ) 1214 float_lit = FloatLit(3.2).with_source(source_extent) 1215 self.assertEqual(parse(Peekable(iter([float_lit]))).source_extent, source_extent) 1216 1217 def test_parse_string_preserves_source_extent(self) -> None: 1218 source_extent = SourceExtent( 1219 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=7, byteno=6) 1220 ) 1221 string_lit = StringLit("Hello").with_source(source_extent) 1222 self.assertEqual(parse(Peekable(iter([string_lit]))).source_extent, source_extent) 1223 1224 def test_parse_bytes_preserves_source_extent(self) -> None: 1225 source_extent = SourceExtent( 1226 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=9, byteno=8) 1227 ) 1228 bytes_lit = BytesLit("QUJD", 64).with_source(source_extent) 1229 self.assertEqual(parse(Peekable(iter([bytes_lit]))).source_extent, source_extent) 1230 1231 def test_parse_var_preserves_source_extent(self) -> None: 1232 source_extent = SourceExtent( 1233 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1234 ) 1235 var = Name("x").with_source(source_extent) 1236 
self.assertEqual(parse(Peekable(iter([var]))).source_extent, source_extent) 1237 1238 def test_parse_hole_preserves_source_extent(self) -> None: 1239 left_paren = LeftParen().with_source( 1240 SourceExtent( 1241 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1242 ) 1243 ) 1244 right_paren = RightParen().with_source( 1245 SourceExtent( 1246 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1) 1247 ) 1248 ) 1249 hole_source_extent = SourceExtent( 1250 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=2, byteno=1) 1251 ) 1252 self.assertEqual(parse(Peekable(iter([left_paren, right_paren]))).source_extent, hole_source_extent) 1253 1254 def test_parenthesized_expression_preserves_source_extent(self) -> None: 1255 left_paren = LeftParen().with_source( 1256 SourceExtent( 1257 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1258 ) 1259 ) 1260 int_lit = IntLit(1).with_source( 1261 SourceExtent( 1262 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1) 1263 ) 1264 ) 1265 right_paren = RightParen().with_source( 1266 SourceExtent( 1267 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2) 1268 ) 1269 ) 1270 parenthesized_int_lit_source_extent = SourceExtent( 1271 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2) 1272 ) 1273 self.assertEqual( 1274 parse(Peekable(iter([left_paren, int_lit, right_paren]))).source_extent, parenthesized_int_lit_source_extent 1275 ) 1276 1277 def test_parse_spread_preserves_source_extent(self) -> None: 1278 ellipsis = Operator("...").with_source( 1279 SourceExtent( 1280 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2) 1281 ) 1282 ) 1283 name = Name("x").with_source( 1284 
SourceExtent( 1285 start=SourceLocation(lineno=1, colno=4, byteno=3), end=SourceLocation(lineno=1, colno=4, byteno=3) 1286 ) 1287 ) 1288 spread_source_extent = SourceExtent( 1289 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=4, byteno=3) 1290 ) 1291 self.assertEqual(parse(Peekable(iter([ellipsis, name]))).source_extent, spread_source_extent) 1292 1293 def test_parse_binop_preserves_source_extent(self) -> None: 1294 first_addend = IntLit(1).with_source( 1295 SourceExtent( 1296 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1297 ) 1298 ) 1299 operator = Operator("+").with_source( 1300 SourceExtent( 1301 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2) 1302 ) 1303 ) 1304 second_addend = IntLit(2).with_source( 1305 SourceExtent( 1306 start=SourceLocation(lineno=2, colno=5, byteno=4), end=SourceLocation(lineno=2, colno=5, byteno=4) 1307 ) 1308 ) 1309 binop_source_extent = SourceExtent( 1310 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=2, colno=5, byteno=4) 1311 ) 1312 self.assertEqual( 1313 parse(Peekable(iter([first_addend, operator, second_addend]))).source_extent, binop_source_extent 1314 ) 1315 1316 def test_parse_list_preserves_source_extent(self) -> None: 1317 left_bracket = LeftBracket().with_source( 1318 SourceExtent( 1319 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0) 1320 ) 1321 ) 1322 one = IntLit(1).with_source( 1323 SourceExtent( 1324 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1) 1325 ) 1326 ) 1327 comma = Operator(",").with_source( 1328 SourceExtent( 1329 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2) 1330 ) 1331 ) 1332 two = IntLit(2).with_source( 1333 SourceExtent( 1334 start=SourceLocation(lineno=1, colno=5, byteno=4), 
end=SourceLocation(lineno=1, colno=5, byteno=4) 1335 ) 1336 ) 1337 right_bracket = RightBracket().with_source( 1338 SourceExtent( 1339 start=SourceLocation(lineno=1, colno=6, byteno=5), end=SourceLocation(lineno=1, colno=6, byteno=5) 1340 ) 1341 ) 1342 list_source_extent = SourceExtent( 1343 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=6, byteno=5) 1344 ) 1345 self.assertEqual( 1346 parse(Peekable(iter([left_bracket, one, comma, two, right_bracket]))).source_extent, list_source_extent 1347 ) 1348 1349 1350class MatchTests(unittest.TestCase): 1351 def test_match_hole_with_non_hole_returns_none(self) -> None: 1352 self.assertEqual(match(Int(1), pattern=Hole()), None) 1353 1354 def test_match_hole_with_hole_returns_empty_dict(self) -> None: 1355 self.assertEqual(match(Hole(), pattern=Hole()), {}) 1356 1357 def test_match_with_equal_ints_returns_empty_dict(self) -> None: 1358 self.assertEqual(match(Int(1), pattern=Int(1)), {}) 1359 1360 def test_match_with_inequal_ints_returns_none(self) -> None: 1361 self.assertEqual(match(Int(2), pattern=Int(1)), None) 1362 1363 def test_match_int_with_non_int_returns_none(self) -> None: 1364 self.assertEqual(match(String("abc"), pattern=Int(1)), None) 1365 1366 def test_match_with_equal_floats_raises_match_error(self) -> None: 1367 with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): 1368 match(Float(1), pattern=Float(1)) 1369 1370 def test_match_with_inequal_floats_raises_match_error(self) -> None: 1371 with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): 1372 match(Float(2), pattern=Float(1)) 1373 1374 def test_match_float_with_non_float_raises_match_error(self) -> None: 1375 with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): 1376 match(String("abc"), pattern=Float(1)) 1377 1378 def test_match_with_equal_strings_returns_empty_dict(self) -> None: 1379 
self.assertEqual(match(String("a"), pattern=String("a")), {}) 1380 1381 def test_match_with_inequal_strings_returns_none(self) -> None: 1382 self.assertEqual(match(String("b"), pattern=String("a")), None) 1383 1384 def test_match_string_with_non_string_returns_none(self) -> None: 1385 self.assertEqual(match(Int(1), pattern=String("abc")), None) 1386 1387 def test_match_var_returns_dict_with_var_name(self) -> None: 1388 self.assertEqual(match(String("abc"), pattern=Var("a")), {"a": String("abc")}) 1389 1390 def test_match_record_with_non_record_returns_none(self) -> None: 1391 self.assertEqual( 1392 match( 1393 Int(2), 1394 pattern=Record({"x": Var("x"), "y": Var("y")}), 1395 ), 1396 None, 1397 ) 1398 1399 def test_match_record_with_more_fields_in_pattern_returns_none(self) -> None: 1400 self.assertEqual( 1401 match( 1402 Record({"x": Int(1), "y": Int(2)}), 1403 pattern=Record({"x": Var("x"), "y": Var("y"), "z": Var("z")}), 1404 ), 1405 None, 1406 ) 1407 1408 def test_match_record_with_fewer_fields_in_pattern_returns_none(self) -> None: 1409 self.assertEqual( 1410 match( 1411 Record({"x": Int(1), "y": Int(2)}), 1412 pattern=Record({"x": Var("x")}), 1413 ), 1414 None, 1415 ) 1416 1417 def test_match_record_with_vars_returns_dict_with_keys(self) -> None: 1418 self.assertEqual( 1419 match( 1420 Record({"x": Int(1), "y": Int(2)}), 1421 pattern=Record({"x": Var("x"), "y": Var("y")}), 1422 ), 1423 {"x": Int(1), "y": Int(2)}, 1424 ) 1425 1426 def test_match_record_with_matching_const_returns_dict_with_other_keys(self) -> None: 1427 # TODO(max): Should this be the case? I feel like we should return all 1428 # the keys. 
1429 self.assertEqual( 1430 match( 1431 Record({"x": Int(1), "y": Int(2)}), 1432 pattern=Record({"x": Int(1), "y": Var("y")}), 1433 ), 1434 {"y": Int(2)}, 1435 ) 1436 1437 def test_match_record_with_non_matching_const_returns_none(self) -> None: 1438 self.assertEqual( 1439 match( 1440 Record({"x": Int(1), "y": Int(2)}), 1441 pattern=Record({"x": Int(3), "y": Var("y")}), 1442 ), 1443 None, 1444 ) 1445 1446 def test_match_list_with_non_list_returns_none(self) -> None: 1447 self.assertEqual( 1448 match( 1449 Int(2), 1450 pattern=List([Var("x"), Var("y")]), 1451 ), 1452 None, 1453 ) 1454 1455 def test_match_list_with_more_fields_in_pattern_returns_none(self) -> None: 1456 self.assertEqual( 1457 match( 1458 List([Int(1), Int(2)]), 1459 pattern=List([Var("x"), Var("y"), Var("z")]), 1460 ), 1461 None, 1462 ) 1463 1464 def test_match_list_with_fewer_fields_in_pattern_returns_none(self) -> None: 1465 self.assertEqual( 1466 match( 1467 List([Int(1), Int(2)]), 1468 pattern=List([Var("x")]), 1469 ), 1470 None, 1471 ) 1472 1473 def test_match_list_with_vars_returns_dict_with_keys(self) -> None: 1474 self.assertEqual( 1475 match( 1476 List([Int(1), Int(2)]), 1477 pattern=List([Var("x"), Var("y")]), 1478 ), 1479 {"x": Int(1), "y": Int(2)}, 1480 ) 1481 1482 def test_match_list_with_matching_const_returns_dict_with_other_keys(self) -> None: 1483 self.assertEqual( 1484 match( 1485 List([Int(1), Int(2)]), 1486 pattern=List([Int(1), Var("y")]), 1487 ), 1488 {"y": Int(2)}, 1489 ) 1490 1491 def test_match_list_with_non_matching_const_returns_none(self) -> None: 1492 self.assertEqual( 1493 match( 1494 List([Int(1), Int(2)]), 1495 pattern=List([Int(3), Var("y")]), 1496 ), 1497 None, 1498 ) 1499 1500 def test_parse_right_pipe(self) -> None: 1501 text = "3 + 4 |> $$quote" 1502 ast = parse(tokenize(text)) 1503 self.assertEqual(ast, Apply(Var("$$quote"), Binop(BinopKind.ADD, Int(3), Int(4)))) 1504 1505 def test_parse_left_pipe(self) -> None: 1506 text = "$$quote <| 3 + 4" 1507 ast = 
parse(tokenize(text)) 1508 self.assertEqual(ast, Apply(Var("$$quote"), Binop(BinopKind.ADD, Int(3), Int(4)))) 1509 1510 def test_parse_match_with_left_apply(self) -> None: 1511 text = """| a -> b <| c 1512 | d -> e""" 1513 tokens = tokenize(text) 1514 self.assertEqual( 1515 list(tokens), 1516 [ 1517 Operator("|"), 1518 Name("a"), 1519 Operator("->"), 1520 Name("b"), 1521 Operator("<|"), 1522 Name("c"), 1523 Operator("|"), 1524 Name("d"), 1525 Operator("->"), 1526 Name("e"), 1527 ], 1528 ) 1529 tokens = tokenize(text) 1530 ast = parse(tokens) 1531 self.assertEqual( 1532 ast, MatchFunction([MatchCase(Var("a"), Apply(Var("b"), Var("c"))), MatchCase(Var("d"), Var("e"))]) 1533 ) 1534 1535 def test_parse_match_with_right_apply(self) -> None: 1536 text = """ 1537| 1 -> 19 1538| a -> a |> (x -> x + 1) 1539""" 1540 tokens = tokenize(text) 1541 ast = parse(tokens) 1542 self.assertEqual( 1543 ast, 1544 MatchFunction( 1545 [ 1546 MatchCase(Int(1), Int(19)), 1547 MatchCase( 1548 Var("a"), 1549 Apply( 1550 Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))), 1551 Var("a"), 1552 ), 1553 ), 1554 ] 1555 ), 1556 ) 1557 1558 def test_match_list_with_spread_returns_empty_dict(self) -> None: 1559 self.assertEqual( 1560 match( 1561 List([Int(1), Int(2), Int(3), Int(4), Int(5)]), 1562 pattern=List([Int(1), Spread()]), 1563 ), 1564 {}, 1565 ) 1566 1567 def test_match_list_with_named_spread_returns_name_bound_to_rest(self) -> None: 1568 self.assertEqual( 1569 match( 1570 List([Int(1), Int(2), Int(3), Int(4)]), 1571 pattern=List([Var("a"), Int(2), Spread("rest")]), 1572 ), 1573 {"a": Int(1), "rest": List([Int(3), Int(4)])}, 1574 ) 1575 1576 def test_match_list_with_named_spread_returns_name_bound_to_empty_rest(self) -> None: 1577 self.assertEqual( 1578 match( 1579 List([Int(1), Int(2)]), 1580 pattern=List([Var("a"), Int(2), Spread("rest")]), 1581 ), 1582 {"a": Int(1), "rest": List([])}, 1583 ) 1584 1585 def test_match_list_with_mismatched_spread_returns_none(self) -> None: 1586 
self.assertEqual( 1587 match( 1588 List([Int(1), Int(2), Int(3), Int(4), Int(5)]), 1589 pattern=List([Int(1), Int(6), Spread()]), 1590 ), 1591 None, 1592 ) 1593 1594 def test_match_record_with_constant_and_spread_returns_empty_dict(self) -> None: 1595 self.assertEqual( 1596 match( 1597 Record({"a": Int(1), "b": Int(2), "c": Int(3)}), 1598 pattern=Record({"a": Int(1), "...": Spread()}), 1599 ), 1600 {}, 1601 ) 1602 1603 def test_match_record_with_var_and_spread_returns_match(self) -> None: 1604 self.assertEqual( 1605 match( 1606 Record({"a": Int(1), "b": Int(2), "c": Int(3)}), 1607 pattern=Record({"a": Var("x"), "...": Spread()}), 1608 ), 1609 {"x": Int(1)}, 1610 ) 1611 1612 def test_match_record_with_mismatched_spread_returns_none(self) -> None: 1613 self.assertEqual( 1614 match( 1615 Record({"a": Int(1), "b": Int(2), "c": Int(3)}), 1616 pattern=Record({"d": Var("x"), "...": Spread()}), 1617 ), 1618 None, 1619 ) 1620 1621 def test_match_variant_with_equal_tag_returns_empty_dict(self) -> None: 1622 self.assertEqual(match(Variant("abc", Hole()), pattern=Variant("abc", Hole())), {}) 1623 1624 def test_match_variant_with_inequal_tag_returns_none(self) -> None: 1625 self.assertEqual(match(Variant("def", Hole()), pattern=Variant("abc", Hole())), None) 1626 1627 def test_match_variant_matches_value(self) -> None: 1628 self.assertEqual(match(Variant("abc", Int(123)), pattern=Variant("abc", Hole())), None) 1629 self.assertEqual(match(Variant("abc", Int(123)), pattern=Variant("abc", Int(123))), {}) 1630 1631 def test_match_variant_with_different_type_returns_none(self) -> None: 1632 self.assertEqual(match(Int(123), pattern=Variant("abc", Hole())), None) 1633 1634 1635class EvalTests(unittest.TestCase): 1636 def test_eval_int_returns_int(self) -> None: 1637 exp = Int(5) 1638 self.assertEqual(eval_exp({}, exp), Int(5)) 1639 1640 def test_eval_float_returns_float(self) -> None: 1641 exp = Float(3.14) 1642 self.assertEqual(eval_exp({}, exp), Float(3.14)) 1643 1644 def 
test_eval_str_returns_str(self) -> None: 1645 exp = String("xyz") 1646 self.assertEqual(eval_exp({}, exp), String("xyz")) 1647 1648 def test_eval_bytes_returns_bytes(self) -> None: 1649 exp = Bytes(b"xyz") 1650 self.assertEqual(eval_exp({}, exp), Bytes(b"xyz")) 1651 1652 def test_eval_with_non_existent_var_raises_name_error(self) -> None: 1653 exp = Var("no") 1654 with self.assertRaises(NameError) as ctx: 1655 eval_exp({}, exp) 1656 self.assertEqual(ctx.exception.args[0], "name 'no' is not defined") 1657 1658 def test_eval_with_bound_var_returns_value(self) -> None: 1659 exp = Var("yes") 1660 env = {"yes": Int(123)} 1661 self.assertEqual(eval_exp(env, exp), Int(123)) 1662 1663 def test_eval_with_binop_add_returns_sum(self) -> None: 1664 exp = Binop(BinopKind.ADD, Int(1), Int(2)) 1665 self.assertEqual(eval_exp({}, exp), Int(3)) 1666 1667 def test_eval_with_nested_binop(self) -> None: 1668 exp = Binop(BinopKind.ADD, Binop(BinopKind.ADD, Int(1), Int(2)), Int(3)) 1669 self.assertEqual(eval_exp({}, exp), Int(6)) 1670 1671 def test_eval_with_binop_add_with_int_string_raises_type_error(self) -> None: 1672 exp = Binop(BinopKind.ADD, Int(1), String("hello")) 1673 with self.assertRaises(TypeError) as ctx: 1674 eval_exp({}, exp) 1675 self.assertEqual(ctx.exception.args[0], "expected Int or Float, got String") 1676 1677 def test_eval_with_binop_sub(self) -> None: 1678 exp = Binop(BinopKind.SUB, Int(1), Int(2)) 1679 self.assertEqual(eval_exp({}, exp), Int(-1)) 1680 1681 def test_eval_with_binop_mul(self) -> None: 1682 exp = Binop(BinopKind.MUL, Int(2), Int(3)) 1683 self.assertEqual(eval_exp({}, exp), Int(6)) 1684 1685 def test_eval_with_binop_div(self) -> None: 1686 exp = Binop(BinopKind.DIV, Int(3), Int(10)) 1687 self.assertEqual(eval_exp({}, exp), Float(0.3)) 1688 1689 def test_eval_with_binop_floor_div(self) -> None: 1690 exp = Binop(BinopKind.FLOOR_DIV, Int(2), Int(3)) 1691 self.assertEqual(eval_exp({}, exp), Int(0)) 1692 1693 def test_eval_with_binop_exp(self) -> None: 
1694 exp = Binop(BinopKind.EXP, Int(2), Int(3)) 1695 self.assertEqual(eval_exp({}, exp), Int(8)) 1696 1697 def test_eval_with_binop_mod(self) -> None: 1698 exp = Binop(BinopKind.MOD, Int(10), Int(4)) 1699 self.assertEqual(eval_exp({}, exp), Int(2)) 1700 1701 def test_eval_with_binop_equal_with_equal_returns_true(self) -> None: 1702 exp = Binop(BinopKind.EQUAL, Int(1), Int(1)) 1703 self.assertEqual(eval_exp({}, exp), TRUE) 1704 1705 def test_eval_with_binop_equal_with_inequal_returns_false(self) -> None: 1706 exp = Binop(BinopKind.EQUAL, Int(1), Int(2)) 1707 self.assertEqual(eval_exp({}, exp), FALSE) 1708 1709 def test_eval_with_binop_not_equal_with_equal_returns_false(self) -> None: 1710 exp = Binop(BinopKind.NOT_EQUAL, Int(1), Int(1)) 1711 self.assertEqual(eval_exp({}, exp), FALSE) 1712 1713 def test_eval_with_binop_not_equal_with_inequal_returns_true(self) -> None: 1714 exp = Binop(BinopKind.NOT_EQUAL, Int(1), Int(2)) 1715 self.assertEqual(eval_exp({}, exp), TRUE) 1716 1717 def test_eval_with_binop_concat_with_strings_returns_string(self) -> None: 1718 exp = Binop(BinopKind.STRING_CONCAT, String("hello"), String(" world")) 1719 self.assertEqual(eval_exp({}, exp), String("hello world")) 1720 1721 def test_eval_with_binop_concat_with_int_string_raises_type_error(self) -> None: 1722 exp = Binop(BinopKind.STRING_CONCAT, Int(123), String(" world")) 1723 with self.assertRaises(TypeError) as ctx: 1724 eval_exp({}, exp) 1725 self.assertEqual(ctx.exception.args[0], "expected String, got Int") 1726 1727 def test_eval_with_binop_concat_with_string_int_raises_type_error(self) -> None: 1728 exp = Binop(BinopKind.STRING_CONCAT, String(" world"), Int(123)) 1729 with self.assertRaises(TypeError) as ctx: 1730 eval_exp({}, exp) 1731 self.assertEqual(ctx.exception.args[0], "expected String, got Int") 1732 1733 def test_eval_with_binop_cons_with_int_list_returns_list(self) -> None: 1734 exp = Binop(BinopKind.LIST_CONS, Int(1), List([Int(2), Int(3)])) 1735 
self.assertEqual(eval_exp({}, exp), List([Int(1), Int(2), Int(3)])) 1736 1737 def test_eval_with_binop_cons_with_list_list_returns_nested_list(self) -> None: 1738 exp = Binop(BinopKind.LIST_CONS, List([]), List([])) 1739 self.assertEqual(eval_exp({}, exp), List([List([])])) 1740 1741 def test_eval_with_binop_cons_with_list_int_raises_type_error(self) -> None: 1742 exp = Binop(BinopKind.LIST_CONS, List([]), Int(123)) 1743 with self.assertRaises(TypeError) as ctx: 1744 eval_exp({}, exp) 1745 self.assertEqual(ctx.exception.args[0], "expected List, got Int") 1746 1747 def test_eval_with_list_append(self) -> None: 1748 exp = Binop(BinopKind.LIST_APPEND, List([Int(1), Int(2)]), Int(3)) 1749 self.assertEqual(eval_exp({}, exp), List([Int(1), Int(2), Int(3)])) 1750 1751 def test_eval_with_list_evaluates_elements(self) -> None: 1752 exp = List( 1753 [ 1754 Binop(BinopKind.ADD, Int(1), Int(2)), 1755 Binop(BinopKind.ADD, Int(3), Int(4)), 1756 ] 1757 ) 1758 self.assertEqual(eval_exp({}, exp), List([Int(3), Int(7)])) 1759 1760 def test_eval_with_function_returns_closure_with_improved_env(self) -> None: 1761 exp = Function(Var("x"), Var("x")) 1762 self.assertEqual(eval_exp({"a": Int(1), "b": Int(2)}, exp), Closure({}, exp)) 1763 1764 def test_eval_with_match_function_returns_closure_with_improved_env(self) -> None: 1765 exp = MatchFunction([]) 1766 self.assertEqual(eval_exp({"a": Int(1), "b": Int(2)}, exp), Closure({}, exp)) 1767 1768 def test_eval_assign_returns_env_object(self) -> None: 1769 exp = Assign(Var("a"), Int(1)) 1770 env: Env = {} 1771 result = eval_exp(env, exp) 1772 self.assertEqual(result, EnvObject({"a": Int(1)})) 1773 1774 def test_eval_assign_function_returns_closure_without_function_in_env(self) -> None: 1775 exp = Assign(Var("a"), Function(Var("x"), Var("x"))) 1776 result = eval_exp({}, exp) 1777 assert isinstance(result, EnvObject) 1778 closure = result.env["a"] 1779 self.assertIsInstance(closure, Closure) 1780 self.assertEqual(closure, Closure({}, 
Function(Var("x"), Var("x")))) 1781 1782 def test_eval_assign_function_returns_closure_with_function_in_env(self) -> None: 1783 exp = Assign(Var("a"), Function(Var("x"), Var("a"))) 1784 result = eval_exp({}, exp) 1785 assert isinstance(result, EnvObject) 1786 closure = result.env["a"] 1787 self.assertIsInstance(closure, Closure) 1788 self.assertEqual(closure, Closure({"a": closure}, Function(Var("x"), Var("a")))) 1789 1790 def test_eval_assign_does_not_modify_env(self) -> None: 1791 exp = Assign(Var("a"), Int(1)) 1792 env: Env = {} 1793 eval_exp(env, exp) 1794 self.assertEqual(env, {}) 1795 1796 def test_eval_where_evaluates_in_order(self) -> None: 1797 exp = Where(Binop(BinopKind.ADD, Var("a"), Int(2)), Assign(Var("a"), Int(1))) 1798 env: Env = {} 1799 self.assertEqual(eval_exp(env, exp), Int(3)) 1800 self.assertEqual(env, {}) 1801 1802 def test_eval_nested_where(self) -> None: 1803 exp = Where( 1804 Where( 1805 Binop(BinopKind.ADD, Var("a"), Var("b")), 1806 Assign(Var("a"), Int(1)), 1807 ), 1808 Assign(Var("b"), Int(2)), 1809 ) 1810 env: Env = {} 1811 self.assertEqual(eval_exp(env, exp), Int(3)) 1812 self.assertEqual(env, {}) 1813 1814 def test_eval_assert_with_truthy_cond_returns_value(self) -> None: 1815 exp = Assert(Int(123), TRUE) 1816 self.assertEqual(eval_exp({}, exp), Int(123)) 1817 1818 def test_eval_assert_with_falsey_cond_raises_assertion_error(self) -> None: 1819 exp = Assert(Int(123), FALSE) 1820 with self.assertRaisesRegex(AssertionError, re.escape("condition #false () failed")): 1821 eval_exp({}, exp) 1822 1823 def test_eval_nested_assert(self) -> None: 1824 exp = Assert(Assert(Int(123), TRUE), TRUE) 1825 self.assertEqual(eval_exp({}, exp), Int(123)) 1826 1827 def test_eval_hole(self) -> None: 1828 exp = Hole() 1829 self.assertEqual(eval_exp({}, exp), Hole()) 1830 1831 def test_eval_function_application_one_arg(self) -> None: 1832 exp = Apply(Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))), Int(2)) 1833 self.assertEqual(eval_exp({}, exp), 
Int(3)) 1834 1835 def test_eval_function_application_two_args(self) -> None: 1836 exp = Apply( 1837 Apply(Function(Var("a"), Function(Var("b"), Binop(BinopKind.ADD, Var("a"), Var("b")))), Int(3)), 1838 Int(2), 1839 ) 1840 self.assertEqual(eval_exp({}, exp), Int(5)) 1841 1842 def test_eval_function_returns_closure_with_captured_env(self) -> None: 1843 exp = Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y"))) 1844 res = eval_exp({"y": Int(5)}, exp) 1845 self.assertIsInstance(res, Closure) 1846 assert isinstance(res, Closure) # for mypy 1847 self.assertEqual(res.env, {"y": Int(5)}) 1848 1849 def test_eval_function_capture_env(self) -> None: 1850 exp = Apply(Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y"))), Int(2)) 1851 self.assertEqual(eval_exp({"y": Int(5)}, exp), Int(7)) 1852 1853 def test_eval_non_function_raises_type_error(self) -> None: 1854 exp = Apply(Int(3), Int(4)) 1855 with self.assertRaisesRegex(TypeError, re.escape("attempted to apply a non-closure of type Int")): 1856 eval_exp({}, exp) 1857 1858 def test_eval_access_from_invalid_object_raises_type_error(self) -> None: 1859 exp = Access(Int(4), String("x")) 1860 with self.assertRaisesRegex(TypeError, re.escape("attempted to access from type Int")): 1861 eval_exp({}, exp) 1862 1863 def test_eval_record_evaluates_value_expressions(self) -> None: 1864 exp = Record({"a": Binop(BinopKind.ADD, Int(1), Int(2))}) 1865 self.assertEqual(eval_exp({}, exp), Record({"a": Int(3)})) 1866 1867 def test_eval_record_access_with_invalid_accessor_raises_type_error(self) -> None: 1868 exp = Access(Record({"a": Int(4)}), Int(0)) 1869 with self.assertRaisesRegex( 1870 TypeError, re.escape("cannot access record field using Int, expected a field name") 1871 ): 1872 eval_exp({}, exp) 1873 1874 def test_eval_record_access_with_unknown_accessor_raises_name_error(self) -> None: 1875 exp = Access(Record({"a": Int(4)}), Var("b")) 1876 with self.assertRaisesRegex(NameError, re.escape("no assignment to b found in 
record")): 1877 eval_exp({}, exp) 1878 1879 def test_eval_record_access(self) -> None: 1880 exp = Access(Record({"a": Int(4)}), Var("a")) 1881 self.assertEqual(eval_exp({}, exp), Int(4)) 1882 1883 def test_eval_list_access_with_invalid_accessor_raises_type_error(self) -> None: 1884 exp = Access(List([Int(4)]), String("hello")) 1885 with self.assertRaisesRegex(TypeError, re.escape("cannot index into list using type String, expected integer")): 1886 eval_exp({}, exp) 1887 1888 def test_eval_list_access_with_out_of_bounds_accessor_raises_value_error(self) -> None: 1889 exp = Access(List([Int(1), Int(2), Int(3)]), Int(4)) 1890 with self.assertRaisesRegex(ValueError, re.escape("index 4 out of bounds for list")): 1891 eval_exp({}, exp) 1892 1893 def test_eval_list_access(self) -> None: 1894 exp = Access(List([String("a"), String("b"), String("c")]), Int(2)) 1895 self.assertEqual(eval_exp({}, exp), String("c")) 1896 1897 def test_right_eval_evaluates_right_hand_side(self) -> None: 1898 exp = Binop(BinopKind.RIGHT_EVAL, Int(1), Int(2)) 1899 self.assertEqual(eval_exp({}, exp), Int(2)) 1900 1901 def test_match_no_cases_raises_match_error(self) -> None: 1902 exp = Apply(MatchFunction([]), Int(1)) 1903 with self.assertRaisesRegex(MatchError, "no matching cases"): 1904 eval_exp({}, exp) 1905 1906 def test_match_int_with_equal_int_matches(self) -> None: 1907 exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(1)) 1908 self.assertEqual(eval_exp({}, exp), Int(2)) 1909 1910 def test_match_int_with_inequal_int_raises_match_error(self) -> None: 1911 exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(3)) 1912 with self.assertRaisesRegex(MatchError, "no matching cases"): 1913 eval_exp({}, exp) 1914 1915 def test_match_string_with_equal_string_matches(self) -> None: 1916 exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("a")) 1917 self.assertEqual(eval_exp({}, exp), String("b")) 1918 1919 def 
test_match_string_with_inequal_string_raises_match_error(self) -> None: 1920 exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("c")) 1921 with self.assertRaisesRegex(MatchError, "no matching cases"): 1922 eval_exp({}, exp) 1923 1924 def test_match_falls_through_to_next(self) -> None: 1925 exp = Apply( 1926 MatchFunction([MatchCase(pattern=Int(3), body=Int(4)), MatchCase(pattern=Int(1), body=Int(2))]), Int(1) 1927 ) 1928 self.assertEqual(eval_exp({}, exp), Int(2)) 1929 1930 def test_eval_compose(self) -> None: 1931 gensym_reset() 1932 exp = parse(tokenize("(x -> x + 3) << (x -> x * 2)")) 1933 env = {"a": Int(1)} 1934 expected = Closure( 1935 {}, 1936 Function( 1937 Var("$v0"), 1938 Apply( 1939 Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(3))), 1940 Apply(Function(Var("x"), Binop(BinopKind.MUL, Var("x"), Int(2))), Var("$v0")), 1941 ), 1942 ), 1943 ) 1944 self.assertEqual(eval_exp(env, exp), expected) 1945 1946 def test_eval_native_function_returns_function(self) -> None: 1947 exp = NativeFunction("times2", lambda x: Int(x.value * 2)) # type: ignore [attr-defined] 1948 self.assertIs(eval_exp({}, exp), exp) 1949 1950 def test_eval_apply_native_function_calls_function(self) -> None: 1951 exp = Apply(NativeFunction("times2", lambda x: Int(x.value * 2)), Int(3)) # type: ignore [attr-defined] 1952 self.assertEqual(eval_exp({}, exp), Int(6)) 1953 1954 def test_eval_apply_quote_returns_ast(self) -> None: 1955 ast = Binop(BinopKind.ADD, Int(1), Int(2)) 1956 exp = Apply(Var("$$quote"), ast) 1957 self.assertIs(eval_exp({}, exp), ast) 1958 1959 def test_eval_apply_closure_with_match_function_has_access_to_closure_vars(self) -> None: 1960 ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), Var("x"))])), Int(2)) 1961 self.assertEqual(eval_exp({}, ast), Int(1)) 1962 1963 def test_eval_less_returns_bool(self) -> None: 1964 ast = Binop(BinopKind.LESS, Int(3), Int(4)) 1965 self.assertEqual(eval_exp({}, ast), TRUE) 1966 
1967 def test_eval_less_on_non_bool_raises_type_error(self) -> None: 1968 ast = Binop(BinopKind.LESS, String("xyz"), Int(4)) 1969 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): 1970 eval_exp({}, ast) 1971 1972 def test_eval_less_equal_returns_bool(self) -> None: 1973 ast = Binop(BinopKind.LESS_EQUAL, Int(3), Int(4)) 1974 self.assertEqual(eval_exp({}, ast), TRUE) 1975 1976 def test_eval_less_equal_on_non_bool_raises_type_error(self) -> None: 1977 ast = Binop(BinopKind.LESS_EQUAL, String("xyz"), Int(4)) 1978 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): 1979 eval_exp({}, ast) 1980 1981 def test_eval_greater_returns_bool(self) -> None: 1982 ast = Binop(BinopKind.GREATER, Int(3), Int(4)) 1983 self.assertEqual(eval_exp({}, ast), FALSE) 1984 1985 def test_eval_greater_on_non_bool_raises_type_error(self) -> None: 1986 ast = Binop(BinopKind.GREATER, String("xyz"), Int(4)) 1987 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): 1988 eval_exp({}, ast) 1989 1990 def test_eval_greater_equal_returns_bool(self) -> None: 1991 ast = Binop(BinopKind.GREATER_EQUAL, Int(3), Int(4)) 1992 self.assertEqual(eval_exp({}, ast), FALSE) 1993 1994 def test_eval_greater_equal_on_non_bool_raises_type_error(self) -> None: 1995 ast = Binop(BinopKind.GREATER_EQUAL, String("xyz"), Int(4)) 1996 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): 1997 eval_exp({}, ast) 1998 1999 def test_boolean_and_evaluates_args(self) -> None: 2000 ast = Binop(BinopKind.BOOL_AND, TRUE, Var("a")) 2001 self.assertEqual(eval_exp({"a": FALSE}, ast), FALSE) 2002 2003 ast = Binop(BinopKind.BOOL_AND, Var("a"), FALSE) 2004 self.assertEqual(eval_exp({"a": TRUE}, ast), FALSE) 2005 2006 def test_boolean_or_evaluates_args(self) -> None: 2007 ast = Binop(BinopKind.BOOL_OR, FALSE, Var("a")) 2008 self.assertEqual(eval_exp({"a": TRUE}, ast), TRUE) 2009 2010 ast = 
Binop(BinopKind.BOOL_OR, Var("a"), TRUE) 2011 self.assertEqual(eval_exp({"a": FALSE}, ast), TRUE) 2012 2013 def test_boolean_and_short_circuit(self) -> None: 2014 def raise_func(message: Object) -> Object: 2015 if not isinstance(message, String): 2016 raise TypeError(f"raise_func expected String, but got {type(message).__name__}") 2017 raise RuntimeError(message) 2018 2019 error = NativeFunction("error", raise_func) 2020 apply = Apply(Var("error"), String("expected failure")) 2021 ast = Binop(BinopKind.BOOL_AND, FALSE, apply) 2022 self.assertEqual(eval_exp({"error": error}, ast), FALSE) 2023 2024 def test_boolean_or_short_circuit(self) -> None: 2025 def raise_func(message: Object) -> Object: 2026 if not isinstance(message, String): 2027 raise TypeError(f"raise_func expected String, but got {type(message).__name__}") 2028 raise RuntimeError(message) 2029 2030 error = NativeFunction("error", raise_func) 2031 apply = Apply(Var("error"), String("expected failure")) 2032 ast = Binop(BinopKind.BOOL_OR, TRUE, apply) 2033 self.assertEqual(eval_exp({"error": error}, ast), TRUE) 2034 2035 def test_boolean_and_on_int_raises_type_error(self) -> None: 2036 exp = Binop(BinopKind.BOOL_AND, Int(1), Int(2)) 2037 with self.assertRaisesRegex(TypeError, re.escape("expected #true or #false, got Int")): 2038 eval_exp({}, exp) 2039 2040 def test_boolean_or_on_int_raises_type_error(self) -> None: 2041 exp = Binop(BinopKind.BOOL_OR, Int(1), Int(2)) 2042 with self.assertRaisesRegex(TypeError, re.escape("expected #true or #false, got Int")): 2043 eval_exp({}, exp) 2044 2045 def test_eval_record_with_spread_fails(self) -> None: 2046 exp = Record({"x": Spread()}) 2047 with self.assertRaisesRegex(RuntimeError, "cannot evaluate a spread"): 2048 eval_exp({}, exp) 2049 2050 def test_eval_variant_returns_variant(self) -> None: 2051 self.assertEqual( 2052 eval_exp( 2053 {}, 2054 Variant("abc", Binop(BinopKind.ADD, Int(1), Int(2))), 2055 ), 2056 Variant("abc", Int(3)), 2057 ) 2058 2059 def 
test_eval_float_and_float_addition_returns_float(self) -> None: 2060 self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Float(1.0), Float(2.0))), Float(3.0)) 2061 2062 def test_eval_int_and_float_addition_returns_float(self) -> None: 2063 self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Int(1), Float(2.0))), Float(3.0)) 2064 2065 def test_eval_int_and_float_division_returns_float(self) -> None: 2066 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Float(2.0))), Float(0.5)) 2067 2068 def test_eval_float_and_int_division_returns_float(self) -> None: 2069 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Float(1.0), Int(2))), Float(0.5)) 2070 2071 def test_eval_int_and_int_division_returns_float(self) -> None: 2072 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Int(2))), Float(0.5)) 2073 2074 2075class EndToEndTestsBase(unittest.TestCase): 2076 def _run(self, text: str, env: Optional[Env] = None, check: bool = False) -> Object: 2077 tokens = tokenize(text) 2078 ast = parse(tokens) 2079 if check: 2080 infer_type(ast, OP_ENV) 2081 if env is None: 2082 env = boot_env() 2083 return eval_exp(env, ast) 2084 2085 2086class EndToEndTests(EndToEndTestsBase): 2087 def test_int_returns_int(self) -> None: 2088 self.assertEqual(self._run("1"), Int(1)) 2089 2090 def test_float_returns_float(self) -> None: 2091 self.assertEqual(self._run("3.14"), Float(3.14)) 2092 2093 def test_bytes_returns_bytes(self) -> None: 2094 self.assertEqual(self._run("~~QUJD"), Bytes(b"ABC")) 2095 2096 def test_bytes_base85_returns_bytes(self) -> None: 2097 self.assertEqual(self._run("~~85'K|(_"), Bytes(b"ABC")) 2098 2099 def test_bytes_base64_returns_bytes(self) -> None: 2100 self.assertEqual(self._run("~~64'QUJD"), Bytes(b"ABC")) 2101 2102 def test_bytes_base32_returns_bytes(self) -> None: 2103 self.assertEqual(self._run("~~32'IFBEG==="), Bytes(b"ABC")) 2104 2105 def test_bytes_base16_returns_bytes(self) -> None: 2106 self.assertEqual(self._run("~~16'414243"), Bytes(b"ABC")) 2107 
2108 def test_int_add_returns_int(self) -> None: 2109 self.assertEqual(self._run("1 + 2"), Int(3)) 2110 2111 def test_int_sub_returns_int(self) -> None: 2112 self.assertEqual(self._run("1 - 2"), Int(-1)) 2113 2114 def test_string_concat_returns_string(self) -> None: 2115 self.assertEqual(self._run('"abc" ++ "def"'), String("abcdef")) 2116 2117 def test_list_cons_returns_list(self) -> None: 2118 self.assertEqual(self._run("1 >+ [2,3]"), List([Int(1), Int(2), Int(3)])) 2119 2120 def test_list_cons_nested_returns_list(self) -> None: 2121 self.assertEqual(self._run("1 >+ 2 >+ [3,4]"), List([Int(1), Int(2), Int(3), Int(4)])) 2122 2123 def test_list_append_returns_list(self) -> None: 2124 self.assertEqual(self._run("[1,2] +< 3"), List([Int(1), Int(2), Int(3)])) 2125 2126 def test_list_append_nested_returns_list(self) -> None: 2127 self.assertEqual(self._run("[1,2] +< 3 +< 4"), List([Int(1), Int(2), Int(3), Int(4)])) 2128 2129 def test_empty_list(self) -> None: 2130 self.assertEqual(self._run("[ ]"), List([])) 2131 2132 def test_empty_list_with_no_spaces(self) -> None: 2133 self.assertEqual(self._run("[]"), List([])) 2134 2135 def test_list_of_ints(self) -> None: 2136 self.assertEqual(self._run("[ 1 , 2 ]"), List([Int(1), Int(2)])) 2137 2138 def test_list_of_exprs(self) -> None: 2139 self.assertEqual( 2140 self._run("[ 1 + 2 , 3 + 4 ]"), 2141 List([Int(3), Int(7)]), 2142 ) 2143 2144 def test_where(self) -> None: 2145 self.assertEqual(self._run("a + 2 . a = 1"), Int(3)) 2146 2147 def test_nested_where(self) -> None: 2148 self.assertEqual(self._run("a + b . a = 1 . b = 2"), Int(3)) 2149 2150 def test_assert_with_truthy_cond_returns_value(self) -> None: 2151 self.assertEqual(self._run("a + 1 ? a == 1 . a = 1"), Int(2)) 2152 2153 def test_assert_with_falsey_cond_raises_assertion_error(self) -> None: 2154 with self.assertRaisesRegex(AssertionError, "condition a == 2 failed"): 2155 self._run("a + 1 ? a == 2 . 
a = 1") 2156 2157 def test_nested_assert(self) -> None: 2158 self.assertEqual(self._run("a + b ? a == 1 ? b == 2 . a = 1 . b = 2"), Int(3)) 2159 2160 def test_hole(self) -> None: 2161 self.assertEqual(self._run("()"), Hole()) 2162 2163 def test_bindings_behave_like_letstar(self) -> None: 2164 with self.assertRaises(NameError) as ctx: 2165 self._run("b . a = 1 . b = a") 2166 self.assertEqual(ctx.exception.args[0], "name 'a' is not defined") 2167 2168 def test_function_application_two_args(self) -> None: 2169 self.assertEqual(self._run("(a -> b -> a + b) 3 2"), Int(5)) 2170 2171 def test_function_create_list_correct_order(self) -> None: 2172 self.assertEqual(self._run("(a -> b -> [a, b]) 3 2"), List([Int(3), Int(2)])) 2173 2174 def test_create_record(self) -> None: 2175 self.assertEqual(self._run("{a = 1 + 3}"), Record({"a": Int(4)})) 2176 2177 def test_access_record(self) -> None: 2178 self.assertEqual(self._run('rec@b . rec = { a = 1, b = "x" }'), String("x")) 2179 2180 def test_access_list(self) -> None: 2181 self.assertEqual(self._run("xs@1 . xs = [1, 2, 3]"), Int(2)) 2182 2183 def test_access_list_var(self) -> None: 2184 self.assertEqual(self._run("xs@y . y = 2 . xs = [1, 2, 3]"), Int(3)) 2185 2186 def test_access_list_expr(self) -> None: 2187 self.assertEqual(self._run("xs@(1+1) . xs = [1, 2, 3]"), Int(3)) 2188 2189 def test_access_list_closure_var(self) -> None: 2190 self.assertEqual( 2191 self._run("list_at 1 [1,2,3] . list_at = idx -> ls -> ls@idx"), 2192 Int(2), 2193 ) 2194 2195 def test_functions_eval_arguments(self) -> None: 2196 self.assertEqual(self._run("(x -> x) c . 
c = 1"), Int(1)) 2197 2198 def test_non_var_function_arg_raises_parse_error(self) -> None: 2199 with self.assertRaises(RuntimeError) as ctx: 2200 self._run("1 -> a") 2201 self.assertEqual(ctx.exception.args[0], "expected variable in function definition 1") 2202 2203 def test_compose(self) -> None: 2204 self.assertEqual(self._run("((a -> a + 1) >> (b -> b * 2)) 3"), Int(8)) 2205 2206 def test_compose_does_not_expose_internal_x(self) -> None: 2207 with self.assertRaisesRegex(NameError, "name 'x' is not defined"): 2208 self._run("f 3 . f = ((y -> x) >> (z -> x))") 2209 2210 def test_double_compose(self) -> None: 2211 self.assertEqual(self._run("((a -> a + 1) >> (x -> x) >> (b -> b * 2)) 3"), Int(8)) 2212 2213 def test_reverse_compose(self) -> None: 2214 self.assertEqual(self._run("((a -> a + 1) << (b -> b * 2)) 3"), Int(7)) 2215 2216 def test_simple_int_match(self) -> None: 2217 self.assertEqual( 2218 self._run( 2219 """ 2220 inc 2 2221 . inc = 2222 | 1 -> 2 2223 | 2 -> 3 2224 """ 2225 ), 2226 Int(3), 2227 ) 2228 2229 def test_match_var_binds_var(self) -> None: 2230 self.assertEqual( 2231 self._run( 2232 """ 2233 id 3 2234 . id = 2235 | x -> x 2236 """ 2237 ), 2238 Int(3), 2239 ) 2240 2241 def test_match_var_binds_first_arm(self) -> None: 2242 self.assertEqual( 2243 self._run( 2244 """ 2245 id 3 2246 . id = 2247 | x -> x 2248 | y -> y * 2 2249 """ 2250 ), 2251 Int(3), 2252 ) 2253 2254 def test_match_function_can_close_over_variables(self) -> None: 2255 self.assertEqual( 2256 self._run( 2257 """ 2258 f 1 2 2259 . f = a -> 2260 | b -> a + b 2261 """ 2262 ), 2263 Int(3), 2264 ) 2265 2266 def test_match_record_binds_var(self) -> None: 2267 self.assertEqual( 2268 self._run( 2269 """ 2270 get_x rec 2271 . rec = { x = 3 } 2272 . get_x = 2273 | { x = x } -> x 2274 """ 2275 ), 2276 Int(3), 2277 ) 2278 2279 def test_match_record_binds_vars(self) -> None: 2280 self.assertEqual( 2281 self._run( 2282 """ 2283 mult rec 2284 . rec = { x = 3, y = 4 } 2285 . 
mult = 2286 | { x = x, y = y } -> x * y 2287 """ 2288 ), 2289 Int(12), 2290 ) 2291 2292 def test_match_record_with_extra_fields_does_not_match(self) -> None: 2293 with self.assertRaises(MatchError): 2294 self._run( 2295 """ 2296 mult rec 2297 . rec = { x = 3 } 2298 . mult = 2299 | { x = x, y = y } -> x * y 2300 """ 2301 ) 2302 2303 def test_match_record_with_constant(self) -> None: 2304 self.assertEqual( 2305 self._run( 2306 """ 2307 mult rec 2308 . rec = { x = 4, y = 5 } 2309 . mult = 2310 | { x = 3, y = y } -> 1 2311 | { x = 4, y = y } -> 2 2312 """ 2313 ), 2314 Int(2), 2315 ) 2316 2317 def test_match_record_with_non_record_fails(self) -> None: 2318 with self.assertRaises(MatchError): 2319 self._run( 2320 """ 2321 get_x 3 2322 . get_x = 2323 | { x = x } -> x 2324 """ 2325 ) 2326 2327 def test_match_record_doubly_binds_vars(self) -> None: 2328 self.assertEqual( 2329 self._run( 2330 """ 2331 get_x rec 2332 . rec = { a = 3, b = 3 } 2333 . get_x = 2334 | { a = x, b = x } -> x 2335 """ 2336 ), 2337 Int(3), 2338 ) 2339 2340 def test_match_record_spread_binds_spread(self) -> None: 2341 self.assertEqual(self._run("(| { a=1, ...rest } -> rest) {a=1, b=2, c=3}"), Record({"b": Int(2), "c": Int(3)})) 2342 2343 def test_match_list_binds_vars(self) -> None: 2344 self.assertEqual( 2345 self._run( 2346 """ 2347 mult xs 2348 . xs = [1, 2, 3, 4, 5] 2349 . mult = 2350 | [1, x, 3, y, 5] -> x * y 2351 """ 2352 ), 2353 Int(8), 2354 ) 2355 2356 def test_match_list_incorrect_length_does_not_match(self) -> None: 2357 with self.assertRaises(MatchError): 2358 self._run( 2359 """ 2360 mult xs 2361 . xs = [1, 2, 3] 2362 . mult = 2363 | [1, 2] -> 1 2364 | [1, 2, 3, 4] -> 1 2365 | [1, 3] -> 1 2366 """ 2367 ) 2368 2369 def test_match_list_with_constant(self) -> None: 2370 self.assertEqual( 2371 self._run( 2372 """ 2373 middle xs 2374 . xs = [4, 5, 6] 2375 . 
middle = 2376 | [1, x, 3] -> x 2377 | [4, x, 6] -> x 2378 | [7, x, 9] -> x 2379 """ 2380 ), 2381 Int(5), 2382 ) 2383 2384 def test_match_list_with_non_list_fails(self) -> None: 2385 with self.assertRaises(MatchError): 2386 self._run( 2387 """ 2388 get_x 3 2389 . get_x = 2390 | [2, x] -> x 2391 """ 2392 ) 2393 2394 def test_match_list_doubly_binds_vars(self) -> None: 2395 self.assertEqual( 2396 self._run( 2397 """ 2398 mult xs 2399 . xs = [1, 2, 3, 2, 1] 2400 . mult = 2401 | [1, x, 3, x, 1] -> x 2402 """ 2403 ), 2404 Int(2), 2405 ) 2406 2407 def test_match_list_spread_binds_spread(self) -> None: 2408 self.assertEqual(self._run("(| [x, ...xs] -> xs) [1, 2]"), List([Int(2)])) 2409 2410 def test_pipe(self) -> None: 2411 self.assertEqual(self._run("1 |> (a -> a + 2)"), Int(3)) 2412 2413 def test_pipe_nested(self) -> None: 2414 self.assertEqual(self._run("1 |> (a -> a + 2) |> (b -> b * 2)"), Int(6)) 2415 2416 def test_reverse_pipe(self) -> None: 2417 self.assertEqual(self._run("(a -> a + 2) <| 1"), Int(3)) 2418 2419 def test_reverse_pipe_nested(self) -> None: 2420 self.assertEqual(self._run("(b -> b * 2) <| (a -> a + 2) <| 1"), Int(6)) 2421 2422 def test_function_can_reference_itself(self) -> None: 2423 result = self._run( 2424 """ 2425 f 1 2426 . f = n -> f 2427 """, 2428 {}, 2429 ) 2430 self.assertEqual(result, Closure({"f": result}, Function(Var("n"), Var("f")))) 2431 2432 def test_function_can_call_itself(self) -> None: 2433 with self.assertRaises(RecursionError): 2434 self._run( 2435 """ 2436 f 1 2437 . f = n -> f n 2438 """ 2439 ) 2440 2441 def test_match_function_can_call_itself(self) -> None: 2442 self.assertEqual( 2443 self._run( 2444 """ 2445 fac 5 2446 . fac = 2447 | 0 -> 1 2448 | 1 -> 1 2449 | n -> n * fac (n - 1) 2450 """ 2451 ), 2452 Int(120), 2453 ) 2454 2455 def test_list_access_binds_tighter_than_append(self) -> None: 2456 self.assertEqual(self._run("[1, 2, 3] +< xs@0 . 
xs = [4]"), List([Int(1), Int(2), Int(3), Int(4)])) 2457 2458 def test_exponentiation(self) -> None: 2459 self.assertEqual(self._run("6 ^ 2"), Int(36)) 2460 2461 def test_modulus(self) -> None: 2462 self.assertEqual(self._run("11 % 3"), Int(2)) 2463 2464 def test_exp_binds_tighter_than_mul(self) -> None: 2465 self.assertEqual(self._run("5 * 2 ^ 3"), Int(40)) 2466 2467 def test_variant_true_returns_true(self) -> None: 2468 self.assertEqual(self._run("# true ()", {}), TRUE) 2469 2470 def test_variant_false_returns_false(self) -> None: 2471 self.assertEqual(self._run("#false ()", {}), FALSE) 2472 2473 def test_boolean_and_binds_tighter_than_or(self) -> None: 2474 self.assertEqual(self._run("#true () || #true () && boom", {}), TRUE) 2475 2476 def test_compare_binds_tighter_than_boolean_and(self) -> None: 2477 self.assertEqual(self._run("1 < 2 && 2 < 1"), FALSE) 2478 2479 def test_match_list_spread(self) -> None: 2480 self.assertEqual( 2481 self._run( 2482 """ 2483 f [2, 4, 6] 2484 . f = 2485 | [] -> 0 2486 | [x, ...] -> x 2487 | c -> 1 2488 """ 2489 ), 2490 Int(2), 2491 ) 2492 2493 def test_match_list_named_spread(self) -> None: 2494 self.assertEqual( 2495 self._run( 2496 """ 2497 tail [1,2,3] 2498 . tail = 2499 | [first, ...rest] -> rest 2500 """ 2501 ), 2502 List([Int(2), Int(3)]), 2503 ) 2504 2505 def test_match_record_spread(self) -> None: 2506 self.assertEqual( 2507 self._run( 2508 """ 2509 f {x = 4, y = 5} 2510 . f = 2511 | {} -> 0 2512 | {x = a, ...} -> a 2513 | c -> 1 2514 """ 2515 ), 2516 Int(4), 2517 ) 2518 2519 def test_match_expr_as_boolean_variants(self) -> None: 2520 self.assertEqual( 2521 self._run( 2522 """ 2523 say (1 < 2) 2524 . say = 2525 | #false () -> "oh no" 2526 | #true () -> "omg" 2527 """ 2528 ), 2529 String("omg"), 2530 ) 2531 2532 def test_match_variant_record(self) -> None: 2533 self.assertEqual( 2534 self._run( 2535 """ 2536 f #add {x = 3, y = 4} 2537 . 
f = 2538 | # b () -> "foo" 2539 | #add {x = x, y = y} -> x + y 2540 """ 2541 ), 2542 Int(7), 2543 ) 2544 2545 def test_int_div_returns_float(self) -> None: 2546 self.assertEqual(self._run("1 / 2 + 3"), Float(3.5)) 2547 with self.assertRaisesRegex(InferenceError, "int and float"): 2548 self._run("1 / 2 + 3", check=True) 2549 2550 def test_eval_count_bits_function_preserves_source_extents(self) -> None: 2551 env_object = self._run( 2552 """ 2553 count_bits = counts -> 2554 | [1, ...bits] -> count_bits { zeros = counts@zeros, ones = 1 + counts@ones } bits 2555 | [0, ...bits] -> count_bits { zeros = 1 + counts@zeros, ones = counts@ones } bits 2556 | [] -> counts 2557 """ 2558 ) 2559 2560 assert isinstance(env_object, EnvObject) 2561 2562 count_bits_closure = env_object.env["count_bits"] 2563 2564 assert isinstance(count_bits_closure, Closure) 2565 2566 outer_function = count_bits_closure.func 2567 outer_function_source_extent = SourceExtent( 2568 start=SourceLocation(lineno=2, colno=22, byteno=22), 2569 end=SourceLocation(lineno=5, colno=24, byteno=241), 2570 ) 2571 2572 assert isinstance(outer_function, Function) 2573 2574 inner_function = outer_function.body 2575 2576 assert isinstance(inner_function, MatchFunction) 2577 2578 match_function_one = inner_function.cases[0] 2579 match_function_one_source_extent = SourceExtent( 2580 start=SourceLocation(lineno=3, colno=11, byteno=42), end=SourceLocation(lineno=3, colno=92, byteno=123) 2581 ) 2582 2583 match_function_two = inner_function.cases[1] 2584 match_function_two_source_extent = SourceExtent( 2585 start=SourceLocation(lineno=4, colno=11, byteno=135), end=SourceLocation(lineno=4, colno=92, byteno=216) 2586 ) 2587 2588 match_function_three = inner_function.cases[2] 2589 match_function_three_source_extent = SourceExtent( 2590 start=SourceLocation(lineno=5, colno=11, byteno=228), end=SourceLocation(lineno=5, colno=24, byteno=241) 2591 ) 2592 2593 match_functions = [match_function_one, match_function_two, 
match_function_three] 2594 match_function_source_extents = [ 2595 match_function_one_source_extent, 2596 match_function_two_source_extent, 2597 match_function_three_source_extent, 2598 ] 2599 2600 self.assertEqual(outer_function.source_extent, outer_function_source_extent) 2601 self.assertTrue( 2602 all( 2603 match_function.source_extent == source_extent 2604 for match_function, source_extent in zip(match_functions, match_function_source_extents) 2605 ) 2606 ) 2607 2608 def test_eval_collatz_function_preserves_source_extents(self) -> None: 2609 env_object = self._run( 2610 """ 2611 collatz = count -> 2612 | 1 -> count 2613 | n -> (n % 2 == 0) |> | #true () -> collatz (count + 1) (n // 2) 2614 | #false () -> collatz (count + 1) (3 * n + 1) 2615 """ 2616 ) 2617 2618 assert isinstance(env_object, EnvObject) 2619 2620 collatz_closure = env_object.env["collatz"] 2621 2622 assert isinstance(collatz_closure, Closure) 2623 2624 outer_function = collatz_closure.func 2625 2626 assert isinstance(outer_function, Function) 2627 2628 outer_function_source_extent = SourceExtent( 2629 start=SourceLocation(lineno=2, colno=19, byteno=19), 2630 end=SourceLocation(lineno=5, colno=79, byteno=205), 2631 ) 2632 2633 inner_function = outer_function.body 2634 2635 assert isinstance(inner_function, MatchFunction) 2636 2637 apply_ast = inner_function.cases[1].body 2638 2639 assert isinstance(apply_ast, Apply) 2640 2641 arg = apply_ast.arg 2642 func = apply_ast.func 2643 2644 arg_source_extent = SourceExtent( 2645 start=SourceLocation(lineno=4, colno=18, byteno=68), end=SourceLocation(lineno=4, colno=29, byteno=79) 2646 ) 2647 2648 func_source_extent = SourceExtent( 2649 start=SourceLocation(lineno=4, colno=34, byteno=84), end=SourceLocation(lineno=5, colno=79, byteno=205) 2650 ) 2651 2652 self.assertEqual(outer_function.source_extent, outer_function_source_extent) 2653 self.assertEqual(arg.source_extent, arg_source_extent) 2654 self.assertEqual(func.source_extent, func_source_extent) 2655 
class ClosureOptimizeTests(unittest.TestCase):
    """Checks ``free_in``: which variable names a node references without binding."""

    def test_int(self) -> None:
        node = Int(1)
        self.assertEqual(free_in(node), set())

    def test_float(self) -> None:
        node = Float(1.0)
        self.assertEqual(free_in(node), set())

    def test_string(self) -> None:
        node = String("x")
        self.assertEqual(free_in(node), set())

    def test_bytes(self) -> None:
        node = Bytes(b"x")
        self.assertEqual(free_in(node), set())

    def test_hole(self) -> None:
        node = Hole()
        self.assertEqual(free_in(node), set())

    def test_spread(self) -> None:
        node = Spread()
        self.assertEqual(free_in(node), set())

    def test_spread_name(self) -> None:
        # TODO(max): Should this be assumed to always be in a place where it
        # defines a name, and therefore never have free variables?
        node = Spread("x")
        self.assertEqual(free_in(node), {"x"})

    def test_nativefunction(self) -> None:
        node = NativeFunction("id", lambda x: x)
        self.assertEqual(free_in(node), set())

    def test_variant(self) -> None:
        node = Variant("x", Var("y"))
        self.assertEqual(free_in(node), {"y"})

    def test_var(self) -> None:
        node = Var("x")
        self.assertEqual(free_in(node), {"x"})

    def test_binop(self) -> None:
        node = Binop(BinopKind.ADD, Var("x"), Var("y"))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_empty_list(self) -> None:
        node = List([])
        self.assertEqual(free_in(node), set())

    def test_list(self) -> None:
        node = List([Var("x"), Var("y")])
        self.assertEqual(free_in(node), {"x", "y"})

    def test_empty_record(self) -> None:
        node = Record({})
        self.assertEqual(free_in(node), set())

    def test_record(self) -> None:
        node = Record({"x": Var("x"), "y": Var("y")})
        self.assertEqual(free_in(node), {"x", "y"})

    def test_function(self) -> None:
        node = parse(tokenize("x -> x + y"))
        self.assertEqual(free_in(node), {"y"})

    def test_nested_function(self) -> None:
        node = parse(tokenize("x -> y -> x + y + z"))
        self.assertEqual(free_in(node), {"z"})

    def test_match_function(self) -> None:
        # Arms 1 and 2 reference x and y; arms 3 and 4 bind x and z themselves.
        node = parse(tokenize("| 1 -> x | 2 -> y | x -> 3 | z -> 4"))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_match_case_int(self) -> None:
        node = MatchCase(Int(1), Var("x"))
        self.assertEqual(free_in(node), {"x"})

    def test_match_case_var(self) -> None:
        node = MatchCase(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y")))
        self.assertEqual(free_in(node), {"y"})

    def test_match_case_list(self) -> None:
        node = MatchCase(List([Var("x")]), Binop(BinopKind.ADD, Var("x"), Var("y")))
        self.assertEqual(free_in(node), {"y"})

    def test_match_case_list_spread(self) -> None:
        node = MatchCase(List([Spread()]), Binop(BinopKind.ADD, Var("xs"), Var("y")))
        self.assertEqual(free_in(node), {"xs", "y"})

    def test_match_case_list_spread_name(self) -> None:
        node = MatchCase(List([Spread("xs")]), Binop(BinopKind.ADD, Var("xs"), Var("y")))
        self.assertEqual(free_in(node), {"y"})

    def test_match_case_record(self) -> None:
        pattern = Record({"x": Int(1), "y": Var("y"), "a": Var("z")})
        body = Binop(BinopKind.ADD, Binop(BinopKind.ADD, Var("x"), Var("y")), Var("z"))
        node = MatchCase(pattern, body)
        self.assertEqual(free_in(node), {"x"})

    def test_match_case_record_spread(self) -> None:
        node = MatchCase(Record({"...": Spread()}), Binop(BinopKind.ADD, Var("x"), Var("y")))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_match_case_record_spread_name(self) -> None:
        node = MatchCase(Record({"...": Spread("x")}), Binop(BinopKind.ADD, Var("x"), Var("y")))
        self.assertEqual(free_in(node), {"y"})

    def test_apply(self) -> None:
        node = Apply(Var("x"), Var("y"))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_access(self) -> None:
        node = Access(Var("x"), Var("y"))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_where(self) -> None:
        node = parse(tokenize("x . x = 1"))
        self.assertEqual(free_in(node), set())

    def test_where_same_name(self) -> None:
        node = parse(tokenize("x . x = x+y"))
        self.assertEqual(free_in(node), {"x", "y"})

    def test_assign(self) -> None:
        node = Assign(Var("x"), Int(1))
        self.assertEqual(free_in(node), set())

    def test_assign_same_name(self) -> None:
        node = Assign(Var("x"), Var("x"))
        self.assertEqual(free_in(node), {"x"})

    def test_closure(self) -> None:
        # TODO(max): Should x be considered free in the closure if it's in the
        # env?
        lambda_body = Function(Var("_"), List([Var("x"), Var("y")]))
        node = Closure({"x": Int(1)}, lambda_body)
        self.assertEqual(free_in(node), {"x", "y"})


class StdLibTests(EndToEndTestsBase):
    """End-to-end checks for the ``$$``-prefixed builtins."""

    def test_stdlib_add(self) -> None:
        result = self._run("$$add 3 4", STDLIB)
        self.assertEqual(result, Int(7))

    def test_stdlib_quote(self) -> None:
        result = self._run("$$quote (3 + 4)")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_quote_pipe(self) -> None:
        result = self._run("3 + 4 |> $$quote")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_quote_reverse_pipe(self) -> None:
        result = self._run("$$quote <| 3 + 4")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_serialize(self) -> None:
        result = self._run("$$serialize 3", STDLIB)
        self.assertEqual(result, Bytes(value=b"i\x06"))

    def test_stdlib_serialize_expr(self) -> None:
        result = self._run("(1+2) |> $$quote |> $$serialize", STDLIB)
        self.assertEqual(result, Bytes(value=b"+\x02+i\x02i\x04"))

    def test_stdlib_deserialize(self) -> None:
        result = self._run("$$deserialize ~~aQY=")
        self.assertEqual(result, Int(3))

    def test_stdlib_deserialize_expr(self) -> None:
        result = self._run("$$deserialize ~~KwIraQJpBA==")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(1), Int(2)))

    def test_stdlib_listlength_empty_list_returns_zero(self) -> None:
        result = self._run("$$listlength []", STDLIB)
        self.assertEqual(result, Int(0))

    def test_stdlib_listlength_returns_length(self) -> None:
        result = self._run("$$listlength [1,2,3]", STDLIB)
        self.assertEqual(result, Int(3))

    def test_stdlib_listlength_with_non_list_raises_type_error(self) -> None:
        with self.assertRaises(TypeError) as caught:
            self._run("$$listlength 1", STDLIB)
        self.assertEqual(caught.exception.args[0], "listlength expected List, but got Int")


class PreludeTests(EndToEndTestsBase):
    def test_id_returns_input(self) -> None:
        result = self._run("id 123")
        self.assertEqual(result, Int(123))

    def test_filter_returns_matching(self) -> None:
        program = """
            filter (x -> x < 4) [2, 6, 3, 7, 1, 8]
            """
        self.assertEqual(self._run(program), List([Int(2), Int(3), Int(1)]))

    def test_filter_with_function_returning_non_bool_raises_match_error(self) -> None:
        program = """
            filter (x -> #no ()) [1]
            """
        with self.assertRaises(MatchError):
            self._run(program)

    def test_quicksort(self) -> None:
        program = """
            quicksort [2, 6, 3, 7, 1, 8]
            """
        self.assertEqual(self._run(program), List([Int(1), Int(2), Int(3), Int(6), Int(7), Int(8)]))

    def test_quicksort_with_empty_list(self) -> None:
        program = """
            quicksort []
            """
        self.assertEqual(self._run(program), List([]))

    def test_quicksort_with_non_int_raises_type_error(self) -> None:
        program = """
            quicksort ["a", "c", "b"]
            """
        with self.assertRaises(TypeError):
            self._run(program)

    def test_concat(self) -> None:
        program = """
            concat [1, 2, 3] [4, 5, 6]
            """
        self.assertEqual(self._run(program), List([Int(1), Int(2), Int(3), Int(4), Int(5), Int(6)]))

    def test_concat_with_first_list_empty(self) -> None:
        program = """
            concat [] [4, 5, 6]
            """
        self.assertEqual(self._run(program), List([Int(4), Int(5), Int(6)]))
2891 def test_concat_with_second_list_empty(self) -> None: 2892 self.assertEqual( 2893 self._run( 2894 """ 2895 concat [1, 2, 3] [] 2896 """ 2897 ), 2898 List([Int(1), Int(2), Int(3)]), 2899 ) 2900 2901 def test_concat_with_both_lists_empty(self) -> None: 2902 self.assertEqual( 2903 self._run( 2904 """ 2905 concat [] [] 2906 """ 2907 ), 2908 List([]), 2909 ) 2910 2911 def test_map(self) -> None: 2912 self.assertEqual( 2913 self._run( 2914 """ 2915 map (x -> x * 2) [3, 1, 2] 2916 """ 2917 ), 2918 List([Int(6), Int(2), Int(4)]), 2919 ) 2920 2921 def test_map_with_non_function_raises_type_error(self) -> None: 2922 with self.assertRaises(TypeError): 2923 self._run( 2924 """ 2925 map 4 [3, 1, 2] 2926 """ 2927 ) 2928 2929 def test_map_with_non_list_raises_match_error(self) -> None: 2930 with self.assertRaises(MatchError): 2931 self._run( 2932 """ 2933 map (x -> x * 2) 3 2934 """ 2935 ) 2936 2937 def test_range(self) -> None: 2938 self.assertEqual( 2939 self._run( 2940 """ 2941 range 3 2942 """ 2943 ), 2944 List([Int(0), Int(1), Int(2)]), 2945 ) 2946 2947 def test_range_with_non_int_raises_type_error(self) -> None: 2948 with self.assertRaises(TypeError): 2949 self._run( 2950 """ 2951 range "a" 2952 """ 2953 ) 2954 2955 def test_foldr(self) -> None: 2956 self.assertEqual( 2957 self._run( 2958 """ 2959 foldr (x -> a -> a + x) 0 [1, 2, 3] 2960 """ 2961 ), 2962 Int(6), 2963 ) 2964 2965 def test_foldr_on_empty_list_returns_empty_list(self) -> None: 2966 self.assertEqual( 2967 self._run( 2968 """ 2969 foldr (x -> a -> a + x) 0 [] 2970 """ 2971 ), 2972 Int(0), 2973 ) 2974 2975 def test_take(self) -> None: 2976 self.assertEqual( 2977 self._run( 2978 """ 2979 take 3 [1, 2, 3, 4, 5] 2980 """ 2981 ), 2982 List([Int(1), Int(2), Int(3)]), 2983 ) 2984 2985 def test_take_n_more_than_list_length_returns_full_list(self) -> None: 2986 self.assertEqual( 2987 self._run( 2988 """ 2989 take 5 [1, 2, 3] 2990 """ 2991 ), 2992 List([Int(1), Int(2), Int(3)]), 2993 ) 2994 2995 def 
test_take_with_non_int_raises_type_error(self) -> None: 2996 with self.assertRaises(TypeError): 2997 self._run( 2998 """ 2999 take "a" [1, 2, 3] 3000 """ 3001 ) 3002 3003 def test_all_returns_true(self) -> None: 3004 self.assertEqual( 3005 self._run( 3006 """ 3007 all (x -> x < 5) [1, 2, 3, 4] 3008 """ 3009 ), 3010 TRUE, 3011 ) 3012 3013 def test_all_returns_false(self) -> None: 3014 self.assertEqual( 3015 self._run( 3016 """ 3017 all (x -> x < 5) [2, 4, 6] 3018 """ 3019 ), 3020 FALSE, 3021 ) 3022 3023 def test_all_with_empty_list_returns_true(self) -> None: 3024 self.assertEqual( 3025 self._run( 3026 """ 3027 all (x -> x == 5) [] 3028 """ 3029 ), 3030 TRUE, 3031 ) 3032 3033 def test_all_with_non_bool_raises_type_error(self) -> None: 3034 with self.assertRaises(TypeError): 3035 self._run( 3036 """ 3037 all (x -> x) [1, 2, 3] 3038 """ 3039 ) 3040 3041 def test_all_short_circuits(self) -> None: 3042 self.assertEqual( 3043 self._run( 3044 """ 3045 all (x -> x > 1) [1, "a", "b"] 3046 """ 3047 ), 3048 FALSE, 3049 ) 3050 3051 def test_any_returns_true(self) -> None: 3052 self.assertEqual( 3053 self._run( 3054 """ 3055 any (x -> x < 4) [1, 3, 5] 3056 """ 3057 ), 3058 TRUE, 3059 ) 3060 3061 def test_any_returns_false(self) -> None: 3062 self.assertEqual( 3063 self._run( 3064 """ 3065 any (x -> x < 3) [4, 5, 6] 3066 """ 3067 ), 3068 FALSE, 3069 ) 3070 3071 def test_any_with_empty_list_returns_false(self) -> None: 3072 self.assertEqual( 3073 self._run( 3074 """ 3075 any (x -> x == 5) [] 3076 """ 3077 ), 3078 FALSE, 3079 ) 3080 3081 def test_any_with_non_bool_raises_type_error(self) -> None: 3082 with self.assertRaises(TypeError): 3083 self._run( 3084 """ 3085 any (x -> x) [1, 2, 3] 3086 """ 3087 ) 3088 3089 def test_any_short_circuits(self) -> None: 3090 self.assertEqual( 3091 self._run( 3092 """ 3093 any (x -> x > 1) [2, "a", "b"] 3094 """ 3095 ), 3096 Variant("true", Hole()), 3097 ) 3098 3099 def test_mul_and_div_have_left_to_right_precedence(self) -> None: 3100 
self.assertEqual( 3101 self._run( 3102 """ 3103 1 / 3 * 3 3104 """ 3105 ), 3106 Float(1.0), 3107 ) 3108 3109 3110class TypeStrTests(unittest.TestCase): 3111 def test_tyvar(self) -> None: 3112 self.assertEqual(str(TyVar("a")), "'a") 3113 3114 def test_tycon(self) -> None: 3115 self.assertEqual(str(TyCon("int", [])), "int") 3116 3117 def test_tycon_one_arg(self) -> None: 3118 self.assertEqual(str(TyCon("list", [IntType])), "(int list)") 3119 3120 def test_tycon_args(self) -> None: 3121 self.assertEqual(str(TyCon("->", [IntType, IntType])), "(int->int)") 3122 3123 def test_tyrow_empty_closed(self) -> None: 3124 self.assertEqual(str(TyEmptyRow()), "{}") 3125 3126 def test_tyrow_empty_open(self) -> None: 3127 self.assertEqual(str(TyRow({}, TyVar("a"))), "{...'a}") 3128 3129 def test_tyrow_closed(self) -> None: 3130 self.assertEqual(str(TyRow({"x": IntType, "y": StringType})), "{x=int, y=string}") 3131 3132 def test_tyrow_open(self) -> None: 3133 self.assertEqual(str(TyRow({"x": IntType, "y": StringType}, TyVar("a"))), "{x=int, y=string, ...'a}") 3134 3135 def test_tyrow_chain(self) -> None: 3136 inner = TyRow({"x": IntType}) 3137 inner_var = TyVar("a") 3138 inner_var.make_equal_to(inner) 3139 outer = TyRow({"y": StringType}, inner_var) 3140 self.assertEqual(str(outer), "{x=int, y=string}") 3141 3142 def test_forall(self) -> None: 3143 self.assertEqual(str(Forall([TyVar("a"), TyVar("b")], TyVar("a"))), "(forall 'a, 'b. 
'a)") 3144 3145 3146class InferTypeTests(unittest.TestCase): 3147 def setUp(self) -> None: 3148 reset_tyvar_counter() 3149 3150 def test_unify_tyvar_tyvar(self) -> None: 3151 a = TyVar("a") 3152 b = TyVar("b") 3153 unify_type(a, b) 3154 self.assertIs(a.find(), b.find()) 3155 3156 def test_unify_tyvar_tycon(self) -> None: 3157 a = TyVar("a") 3158 unify_type(a, IntType) 3159 self.assertIs(a.find(), IntType) 3160 b = TyVar("b") 3161 unify_type(b, IntType) 3162 self.assertIs(b.find(), IntType) 3163 3164 def test_unify_tycon_tycon_name_mismatch(self) -> None: 3165 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3166 unify_type(IntType, StringType) 3167 3168 def test_unify_tycon_tycon_arity_mismatch(self) -> None: 3169 l = TyCon("x", [TyVar("a")]) 3170 r = TyCon("x", []) 3171 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3172 unify_type(l, r) 3173 3174 def test_unify_tycon_tycon_unifies_arg(self) -> None: 3175 a = TyVar("a") 3176 b = TyVar("b") 3177 l = TyCon("x", [a]) 3178 r = TyCon("x", [b]) 3179 unify_type(l, r) 3180 self.assertIs(a.find(), b.find()) 3181 3182 def test_unify_tycon_tycon_unifies_args(self) -> None: 3183 a, b, c, d = map(TyVar, "abcd") 3184 l = func_type(a, b) 3185 r = func_type(c, d) 3186 unify_type(l, r) 3187 self.assertIs(a.find(), c.find()) 3188 self.assertIs(b.find(), d.find()) 3189 self.assertIsNot(a.find(), b.find()) 3190 3191 def test_unify_recursive_fails(self) -> None: 3192 l = TyVar("a") 3193 r = TyCon("x", [TyVar("a")]) 3194 with self.assertRaisesRegex(InferenceError, "Occurs check failed"): 3195 unify_type(l, r) 3196 3197 def test_unify_empty_row(self) -> None: 3198 unify_type(TyEmptyRow(), TyEmptyRow()) 3199 3200 def test_unify_empty_row_open(self) -> None: 3201 l = TyRow({}, TyVar("a")) 3202 r = TyRow({}, TyVar("b")) 3203 unify_type(l, r) 3204 self.assertIs(l.rest.find(), r.rest.find()) 3205 3206 def test_unify_row_unifies_fields(self) -> None: 3207 a = TyVar("a") 3208 b = TyVar("b") 3209 l = 
TyRow({"x": a}) 3210 r = TyRow({"x": b}) 3211 unify_type(l, r) 3212 self.assertIs(a.find(), b.find()) 3213 3214 def test_unify_empty_right(self) -> None: 3215 l = TyRow({"x": IntType}) 3216 r = TyEmptyRow() 3217 with self.assertRaisesRegex(InferenceError, "Unifying row {x=int} with empty row"): 3218 unify_type(l, r) 3219 3220 def test_unify_empty_left(self) -> None: 3221 l = TyEmptyRow() 3222 r = TyRow({"x": IntType}) 3223 with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {x=int}"): 3224 unify_type(l, r) 3225 3226 def test_unify_missing_closed(self) -> None: 3227 l = TyRow({"x": IntType}) 3228 r = TyRow({"y": IntType}) 3229 with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {y=int, ...'t0}"): 3230 unify_type(l, r) 3231 3232 def test_unify_one_open_one_closed(self) -> None: 3233 rest = TyVar("r1") 3234 l = TyRow({"x": IntType}) 3235 r = TyRow({"x": IntType}, rest) 3236 unify_type(l, r) 3237 self.assertTyEqual(rest.find(), TyEmptyRow()) 3238 3239 def test_unify_one_open_more_than_one_closed(self) -> None: 3240 rest = TyVar("r1") 3241 l = TyRow({"x": IntType}) 3242 r = TyRow({"x": IntType, "y": StringType}, rest) 3243 with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {y=string, ...'r1}"): 3244 unify_type(l, r) 3245 3246 def test_unify_one_open_one_closed_more(self) -> None: 3247 rest = TyVar("r1") 3248 l = TyRow({"x": IntType, "y": StringType}) 3249 r = TyRow({"x": IntType}, rest) 3250 unify_type(l, r) 3251 self.assertTyEqual(rest.find(), TyRow({"y": StringType})) 3252 3253 def test_unify_left_missing_open(self) -> None: 3254 l = TyRow({}, TyVar("r0")) 3255 r = TyRow({"y": IntType}, TyVar("r1")) 3256 unify_type(l, r) 3257 self.assertTyEqual(l.rest, TyRow({"y": IntType}, TyVar("r1"))) 3258 assert isinstance(r.rest, TyVar) 3259 self.assertTrue(r.rest.is_unbound()) 3260 3261 def test_unify_right_missing_open(self) -> None: 3262 l = TyRow({"x": IntType}, TyVar("r0")) 3263 r = TyRow({}, TyVar("r1")) 3264 
unify_type(l, r) 3265 assert isinstance(l.rest, TyVar) 3266 self.assertTrue(l.rest.is_unbound()) 3267 self.assertTyEqual(r.rest, TyRow({"x": IntType}, TyVar("r0"))) 3268 3269 def test_unify_both_missing_open(self) -> None: 3270 l = TyRow({"x": IntType}, TyVar("r0")) 3271 r = TyRow({"y": IntType}, TyVar("r1")) 3272 unify_type(l, r) 3273 self.assertTyEqual(l.rest, TyRow({"y": IntType}, TyVar("t0"))) 3274 self.assertTyEqual(r.rest, TyRow({"x": IntType}, TyVar("t0"))) 3275 3276 def test_minimize_tyvar(self) -> None: 3277 ty = fresh_tyvar() 3278 self.assertEqual(minimize(ty), TyVar("a")) 3279 3280 def test_minimize_tycon(self) -> None: 3281 ty = func_type(TyVar("t0"), TyVar("t1"), TyVar("t0")) 3282 self.assertEqual(minimize(ty), func_type(TyVar("a"), TyVar("b"), TyVar("a"))) 3283 3284 def infer(self, expr: Object, ctx: Context) -> MonoType: 3285 return minimize(infer_type(expr, ctx)) 3286 3287 def assertTyEqual(self, l: MonoType, r: MonoType) -> bool: 3288 l = l.find() 3289 r = r.find() 3290 if isinstance(l, TyVar) and isinstance(r, TyVar): 3291 if l != r: 3292 self.fail(f"Type mismatch: {l} != {r}") 3293 return True 3294 if isinstance(l, TyCon) and isinstance(r, TyCon): 3295 if l.name != r.name: 3296 self.fail(f"Type mismatch: {l} != {r}") 3297 if len(l.args) != len(r.args): 3298 self.fail(f"Type mismatch: {l} != {r}") 3299 for l_arg, r_arg in zip(l.args, r.args): 3300 self.assertTyEqual(l_arg, r_arg) 3301 return True 3302 if isinstance(l, TyEmptyRow) and isinstance(r, TyEmptyRow): 3303 return True 3304 if isinstance(l, TyRow) and isinstance(r, TyRow): 3305 l_keys = set(l.fields.keys()) 3306 r_keys = set(r.fields.keys()) 3307 if l_keys != r_keys: 3308 self.fail(f"Type mismatch: {l} != {r}") 3309 for key in l_keys: 3310 self.assertTyEqual(l.fields[key], r.fields[key]) 3311 self.assertTyEqual(l.rest, r.rest) 3312 return True 3313 self.fail(f"Type mismatch: {l} != {r}") 3314 3315 def test_unbound_var(self) -> None: 3316 with self.assertRaisesRegex(InferenceError, "Unbound 
variable"): 3317 self.infer(Var("a"), {}) 3318 3319 def test_var_instantiates_scheme(self) -> None: 3320 ty = self.infer(Var("a"), {"a": Forall([TyVar("b")], TyVar("b"))}) 3321 self.assertTyEqual(ty, TyVar("a")) 3322 3323 def test_int(self) -> None: 3324 ty = self.infer(Int(123), {}) 3325 self.assertTyEqual(ty, IntType) 3326 3327 def test_float(self) -> None: 3328 ty = self.infer(Float(1.0), {}) 3329 self.assertTyEqual(ty, FloatType) 3330 3331 def test_string(self) -> None: 3332 ty = self.infer(String("abc"), {}) 3333 self.assertTyEqual(ty, StringType) 3334 3335 def test_function_returns_arg(self) -> None: 3336 ty = self.infer(Function(Var("x"), Var("x")), {}) 3337 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 3338 3339 def test_nested_function_outer(self) -> None: 3340 ty = self.infer(Function(Var("x"), Function(Var("y"), Var("x"))), {}) 3341 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("a"))) 3342 3343 def test_nested_function_inner(self) -> None: 3344 ty = self.infer(Function(Var("x"), Function(Var("y"), Var("y"))), {}) 3345 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("b"))) 3346 3347 def test_apply_id_int(self) -> None: 3348 func = Function(Var("x"), Var("x")) 3349 arg = Int(123) 3350 ty = self.infer(Apply(func, arg), {}) 3351 self.assertTyEqual(ty, IntType) 3352 3353 def test_apply_two_arg_returns_function(self) -> None: 3354 func = Function(Var("x"), Function(Var("y"), Var("x"))) 3355 arg = Int(123) 3356 ty = self.infer(Apply(func, arg), {}) 3357 self.assertTyEqual(ty, func_type(TyVar("a"), IntType)) 3358 3359 def test_binop_add_constrains_int(self) -> None: 3360 expr = Binop(BinopKind.ADD, Var("x"), Var("y")) 3361 ty = self.infer( 3362 expr, 3363 { 3364 "x": Forall([], TyVar("a")), 3365 "y": Forall([], TyVar("b")), 3366 "+": Forall([], func_type(IntType, IntType, IntType)), 3367 }, 3368 ) 3369 self.assertTyEqual(ty, IntType) 3370 3371 def test_binop_add_function_constrains_int(self) -> None: 3372 x = Var("x") 
3373 y = Var("y") 3374 expr = Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, x, y))) 3375 ty = self.infer(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3376 self.assertTyEqual(ty, func_type(IntType, IntType, IntType)) 3377 self.assertTyEqual(type_of(x), IntType) 3378 self.assertTyEqual(type_of(y), IntType) 3379 3380 def test_let(self) -> None: 3381 expr = Where(Var("f"), Assign(Var("f"), Function(Var("x"), Var("x")))) 3382 ty = self.infer(expr, {}) 3383 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 3384 3385 def test_apply_monotype_to_different_types_raises(self) -> None: 3386 expr = Where( 3387 Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))), 3388 Assign(Var("y"), Apply(Var("f"), Float(123.0))), 3389 ) 3390 ctx = {"f": Forall([], func_type(TyVar("a"), TyVar("a")))} 3391 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3392 self.infer(expr, ctx) 3393 3394 def test_apply_polytype_to_different_types(self) -> None: 3395 expr = Where( 3396 Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))), 3397 Assign(Var("y"), Apply(Var("f"), Float(123.0))), 3398 ) 3399 ty = self.infer(expr, {"f": Forall([TyVar("a")], func_type(TyVar("a"), TyVar("a")))}) 3400 self.assertTyEqual(ty, IntType) 3401 3402 def test_generalization(self) -> None: 3403 # From https://okmij.org/ftp/ML/generalization.html 3404 expr = parse(tokenize("x -> (y . y = x)")) 3405 ty = self.infer(expr, {}) 3406 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 3407 3408 def test_generalization2(self) -> None: 3409 # From https://okmij.org/ftp/ML/generalization.html 3410 expr = parse(tokenize("x -> (y . 
y = z -> x z)")) 3411 ty = self.infer(expr, {}) 3412 self.assertTyEqual(ty, func_type(func_type(TyVar("a"), TyVar("b")), func_type(TyVar("a"), TyVar("b")))) 3413 3414 def test_id(self) -> None: 3415 expr = Function(Var("x"), Var("x")) 3416 ty = self.infer(expr, {}) 3417 self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 3418 3419 def test_empty_list(self) -> None: 3420 expr = List([]) 3421 ty = infer_type(expr, {}) 3422 self.assertTyEqual(ty, TyCon("list", [TyVar("t0")])) 3423 3424 def test_list_int(self) -> None: 3425 expr = List([Int(123)]) 3426 ty = infer_type(expr, {}) 3427 self.assertTyEqual(ty, TyCon("list", [IntType])) 3428 3429 def test_list_mismatch(self) -> None: 3430 expr = List([Int(123), Float(123.0)]) 3431 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3432 infer_type(expr, {}) 3433 3434 def test_recursive_fact(self) -> None: 3435 expr = parse(tokenize("fact . fact = | 0 -> 1 | n -> n * fact (n-1)")) 3436 ty = infer_type( 3437 expr, 3438 { 3439 "*": Forall([], func_type(IntType, IntType, IntType)), 3440 "-": Forall([], func_type(IntType, IntType, IntType)), 3441 }, 3442 ) 3443 self.assertTyEqual(ty, func_type(IntType, IntType)) 3444 3445 def test_match_int_int(self) -> None: 3446 expr = parse(tokenize("| 0 -> 1")) 3447 ty = infer_type(expr, {}) 3448 self.assertTyEqual(ty, func_type(IntType, IntType)) 3449 3450 def test_match_int_int_two_cases(self) -> None: 3451 expr = parse(tokenize("| 0 -> 1 | 1 -> 2")) 3452 ty = infer_type(expr, {}) 3453 self.assertTyEqual(ty, func_type(IntType, IntType)) 3454 3455 def test_match_int_int_int_float(self) -> None: 3456 expr = parse(tokenize("| 0 -> 1 | 1 -> 2.0")) 3457 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3458 infer_type(expr, {}) 3459 3460 def test_match_int_int_float_int(self) -> None: 3461 expr = parse(tokenize("| 0 -> 1 | 1.0 -> 2")) 3462 with self.assertRaisesRegex(InferenceError, "Unification failed"): 3463 infer_type(expr, {}) 3464 3465 def 
test_match_var(self) -> None: 3466 expr = parse(tokenize("| x -> x + 1")) 3467 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3468 self.assertTyEqual(ty, func_type(IntType, IntType)) 3469 3470 def test_match_int_var(self) -> None: 3471 expr = parse(tokenize("| 0 -> 1 | x -> x")) 3472 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3473 self.assertTyEqual(ty, func_type(IntType, IntType)) 3474 3475 def test_match_list_of_int(self) -> None: 3476 expr = parse(tokenize("| [x] -> x + 1")) 3477 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3478 self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 3479 3480 def test_match_list_of_int_to_list(self) -> None: 3481 expr = parse(tokenize("| [x] -> [x + 1]")) 3482 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3483 self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType))) 3484 3485 def test_match_list_of_int_to_int(self) -> None: 3486 expr = parse(tokenize("| [] -> 0 | [x] -> 1 | [x, y] -> x+y")) 3487 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3488 self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 3489 3490 def test_recursive_var_is_unbound(self) -> None: 3491 expr = parse(tokenize("a . a = a")) 3492 with self.assertRaisesRegex(InferenceError, "Unbound variable"): 3493 infer_type(expr, {}) 3494 3495 def test_recursive(self) -> None: 3496 expr = parse( 3497 tokenize( 3498 """ 3499 length 3500 . 
length = 3501 | [] -> 0 3502 | [x, ...xs] -> 1 + length xs 3503 """ 3504 ) 3505 ) 3506 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3507 self.assertTyEqual(ty, func_type(list_type(TyVar("t8")), IntType)) 3508 3509 def test_match_list_to_list(self) -> None: 3510 expr = parse(tokenize("| [] -> [] | x -> x")) 3511 ty = infer_type(expr, {}) 3512 self.assertTyEqual(ty, func_type(list_type(TyVar("t1")), list_type(TyVar("t1")))) 3513 3514 def test_match_list_spread(self) -> None: 3515 expr = parse(tokenize("head . head = | [x, ...] -> x")) 3516 ty = infer_type(expr, {}) 3517 self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), TyVar("t4"))) 3518 3519 def test_match_list_spread_rest(self) -> None: 3520 expr = parse(tokenize("tail . tail = | [x, ...xs] -> xs")) 3521 ty = infer_type(expr, {}) 3522 self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), list_type(TyVar("t4")))) 3523 3524 def test_match_list_spread_named(self) -> None: 3525 expr = parse(tokenize("sum . sum = | [] -> 0 | [x, ...xs] -> x + sum xs")) 3526 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3527 self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 3528 3529 def test_match_list_int_to_list(self) -> None: 3530 expr = parse(tokenize("| [] -> [3] | x -> x")) 3531 ty = infer_type(expr, {}) 3532 self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType))) 3533 3534 def test_inc(self) -> None: 3535 expr = parse(tokenize("inc . 
inc = | 0 -> 1 | 1 -> 2 | a -> a + 1")) 3536 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3537 self.assertTyEqual(ty, func_type(IntType, IntType)) 3538 3539 def test_bytes(self) -> None: 3540 expr = Bytes(b"abc") 3541 ty = infer_type(expr, {}) 3542 self.assertTyEqual(ty, BytesType) 3543 3544 def test_hole(self) -> None: 3545 expr = Hole() 3546 ty = infer_type(expr, {}) 3547 self.assertTyEqual(ty, HoleType) 3548 3549 def test_string_concat(self) -> None: 3550 expr = parse(tokenize('"a" ++ "b"')) 3551 ty = infer_type(expr, OP_ENV) 3552 self.assertTyEqual(ty, StringType) 3553 3554 def test_cons(self) -> None: 3555 expr = parse(tokenize("1 >+ [2]")) 3556 ty = infer_type(expr, OP_ENV) 3557 self.assertTyEqual(ty, list_type(IntType)) 3558 3559 def test_append(self) -> None: 3560 expr = parse(tokenize("[1] +< 2")) 3561 ty = infer_type(expr, OP_ENV) 3562 self.assertTyEqual(ty, list_type(IntType)) 3563 3564 def test_record(self) -> None: 3565 expr = Record({"a": Int(1), "b": String("hello")}) 3566 ty = infer_type(expr, {}) 3567 self.assertTyEqual(ty, TyRow({"a": IntType, "b": StringType})) 3568 3569 def test_match_record(self) -> None: 3570 expr = MatchFunction( 3571 [ 3572 MatchCase( 3573 Record({"x": Var("x")}), 3574 Var("x"), 3575 ) 3576 ] 3577 ) 3578 ty = infer_type(expr, {}) 3579 self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}), TyVar("t1"))) 3580 3581 def test_access_poly(self) -> None: 3582 expr = Function(Var("r"), Access(Var("r"), Var("x"))) 3583 ty = infer_type(expr, {}) 3584 self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t1"))) 3585 3586 def test_apply_row(self) -> None: 3587 row0 = Record({"x": Int(1)}) 3588 row1 = Record({"x": Int(1), "y": Int(2)}) 3589 scheme = Forall([], func_type(TyRow({"x": IntType}, TyVar("a")), IntType)) 3590 ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme}) 3591 self.assertTyEqual(ty0, IntType) 3592 with self.assertRaisesRegex(InferenceError, "Unifying empty 
row with row {y=int}"): 3593 infer_type(Apply(Var("f"), row1), {"f": scheme}) 3594 3595 def test_apply_row_polymorphic(self) -> None: 3596 row0 = Record({"x": Int(1)}) 3597 row1 = Record({"x": Int(1), "y": Int(2)}) 3598 row2 = Record({"x": Int(1), "y": Int(2), "z": Int(3)}) 3599 scheme = Forall([TyVar("a")], func_type(TyRow({"x": IntType}, TyVar("a")), IntType)) 3600 ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme}) 3601 self.assertTyEqual(ty0, IntType) 3602 ty1 = infer_type(Apply(Var("f"), row1), {"f": scheme}) 3603 self.assertTyEqual(ty1, IntType) 3604 ty2 = infer_type(Apply(Var("f"), row2), {"f": scheme}) 3605 self.assertTyEqual(ty2, IntType) 3606 3607 def test_example_rec_access(self) -> None: 3608 expr = parse(tokenize('rec@a . rec = { a = 1, b = "x" }')) 3609 ty = infer_type(expr, {}) 3610 self.assertTyEqual(ty, IntType) 3611 3612 def test_example_rec_access_poly(self) -> None: 3613 expr = parse( 3614 tokenize( 3615 """ 3616(get_x {x=1}) + (get_x {x=2,y=3}) 3617. get_x = | { x=x, ... } -> x 3618""" 3619 ) 3620 ) 3621 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3622 self.assertTyEqual(ty, IntType) 3623 3624 def test_example_rec_access_poly_named_bug(self) -> None: 3625 expr = parse( 3626 tokenize( 3627 """ 3628(filter_x {x=1, y=2}) + 3 3629. filter_x = | { x=x, ...xs } -> xs 3630""" 3631 ) 3632 ) 3633 with self.assertRaisesRegex(InferenceError, "Cannot unify int and {y=int}"): 3634 infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3635 3636 def test_example_rec_access_rest(self) -> None: 3637 expr = parse( 3638 tokenize( 3639 """ 3640| { x=x, ...xs } -> xs 3641""" 3642 ) 3643 ) 3644 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3645 self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t2"))) 3646 3647 def test_example_match_rec_access_rest(self) -> None: 3648 expr = parse( 3649 tokenize( 3650 """ 3651filter_x {x=1, y=2} 3652. 
filter_x = | { x=x, ...xs } -> xs 3653""" 3654 ) 3655 ) 3656 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 3657 self.assertTyEqual(ty, TyRow({"y": IntType})) 3658 3659 def test_example_rec_access_poly_named(self) -> None: 3660 expr = parse( 3661 tokenize( 3662 """ 3663[(filter_x {x=1, y=2}), (filter_x {x=2, y=3, z=4})] 3664. filter_x = | { x=x, ...xs } -> xs 3665""" 3666 ) 3667 ) 3668 with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {z=int}"): 3669 infer_type(expr, {}) 3670 3671 3672class SerializerTests(unittest.TestCase): 3673 def _serialize(self, obj: Object) -> bytes: 3674 serializer = Serializer() 3675 serializer.serialize(obj) 3676 return bytes(serializer.output) 3677 3678 def test_short(self) -> None: 3679 self.assertEqual(self._serialize(Int(-1)), TYPE_SHORT + b"\x01") 3680 self.assertEqual(self._serialize(Int(0)), TYPE_SHORT + b"\x00") 3681 self.assertEqual(self._serialize(Int(1)), TYPE_SHORT + b"\x02") 3682 self.assertEqual(self._serialize(Int(-(2**33))), TYPE_SHORT + b"\xff\xff\xff\xff?") 3683 self.assertEqual(self._serialize(Int(2**33)), TYPE_SHORT + b"\x80\x80\x80\x80@") 3684 self.assertEqual(self._serialize(Int(-(2**63))), TYPE_SHORT + b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01") 3685 self.assertEqual(self._serialize(Int(2**63 - 1)), TYPE_SHORT + b"\xfe\xff\xff\xff\xff\xff\xff\xff\xff\x01") 3686 3687 def test_long(self) -> None: 3688 self.assertEqual( 3689 self._serialize(Int(2**100)), 3690 TYPE_LONG + b"\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00", 3691 ) 3692 self.assertEqual( 3693 self._serialize(Int(-(2**100))), 3694 TYPE_LONG + b"\x04\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x1f\x00\x00\x00", 3695 ) 3696 3697 def test_string(self) -> None: 3698 self.assertEqual(self._serialize(String("hello")), TYPE_STRING + b"\nhello") 3699 3700 def test_empty_list(self) -> None: 3701 obj = List([]) 3702 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x00") 3703 3704 
def test_list(self) -> None: 3705 obj = List([Int(123), Int(456)]) 3706 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x04i\xf6\x01i\x90\x07") 3707 3708 def test_self_referential_list(self) -> None: 3709 obj = List([]) 3710 obj.items.append(obj) 3711 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x02r\x00") 3712 3713 def test_variant(self) -> None: 3714 obj = Variant("abc", Int(123)) 3715 self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x06abci\xf6\x01") 3716 3717 def test_record(self) -> None: 3718 obj = Record({"x": Int(1), "y": Int(2)}) 3719 self.assertEqual(self._serialize(obj), TYPE_RECORD + b"\x04\x02xi\x02\x02yi\x04") 3720 3721 def test_var(self) -> None: 3722 obj = Var("x") 3723 self.assertEqual(self._serialize(obj), TYPE_VAR + b"\x02x") 3724 3725 def test_function(self) -> None: 3726 obj = Function(Var("x"), Var("x")) 3727 self.assertEqual(self._serialize(obj), TYPE_FUNCTION + b"v\x02xv\x02x") 3728 3729 def test_empty_match_function(self) -> None: 3730 obj = MatchFunction([]) 3731 self.assertEqual(self._serialize(obj), TYPE_MATCH_FUNCTION + b"\x00") 3732 3733 def test_match_function(self) -> None: 3734 obj = MatchFunction([MatchCase(Int(1), Var("x")), MatchCase(List([Int(1)]), Var("y"))]) 3735 self.assertEqual(self._serialize(obj), TYPE_MATCH_FUNCTION + b"\x04i\x02v\x02x\xdb\x02i\x02v\x02y") 3736 3737 def test_closure(self) -> None: 3738 obj = Closure({}, Function(Var("x"), Var("x"))) 3739 self.assertEqual(self._serialize(obj), ref(TYPE_CLOSURE) + b"fv\x02xv\x02x\x00") 3740 3741 def test_self_referential_closure(self) -> None: 3742 obj = Closure({}, Function(Var("x"), Var("x"))) 3743 assert isinstance(obj.env, dict) # For mypy 3744 obj.env["self"] = obj 3745 self.assertEqual(self._serialize(obj), ref(TYPE_CLOSURE) + b"fv\x02xv\x02x\x02\x08selfr\x00") 3746 3747 def test_bytes(self) -> None: 3748 obj = Bytes(b"abc") 3749 self.assertEqual(self._serialize(obj), TYPE_BYTES + b"\x06abc") 3750 3751 def test_float(self) -> None: 
3752 obj = Float(3.14) 3753 self.assertEqual(self._serialize(obj), TYPE_FLOAT + b"\x1f\x85\xebQ\xb8\x1e\t@") 3754 3755 def test_hole(self) -> None: 3756 self.assertEqual(self._serialize(Hole()), TYPE_HOLE) 3757 3758 def test_assign(self) -> None: 3759 obj = Assign(Var("x"), Int(123)) 3760 self.assertEqual(self._serialize(obj), TYPE_ASSIGN + b"v\x02xi\xf6\x01") 3761 3762 def test_binop(self) -> None: 3763 obj = Binop(BinopKind.ADD, Int(3), Int(4)) 3764 self.assertEqual(self._serialize(obj), TYPE_BINOP + b"\x02+i\x06i\x08") 3765 3766 def test_apply(self) -> None: 3767 obj = Apply(Var("f"), Var("x")) 3768 self.assertEqual(self._serialize(obj), TYPE_APPLY + b"v\x02fv\x02x") 3769 3770 def test_where(self) -> None: 3771 obj = Where(Var("a"), Var("b")) 3772 self.assertEqual(self._serialize(obj), TYPE_WHERE + b"v\x02av\x02b") 3773 3774 def test_access(self) -> None: 3775 obj = Access(Var("a"), Var("b")) 3776 self.assertEqual(self._serialize(obj), TYPE_ACCESS + b"v\x02av\x02b") 3777 3778 def test_spread(self) -> None: 3779 self.assertEqual(self._serialize(Spread()), TYPE_SPREAD) 3780 self.assertEqual(self._serialize(Spread("rest")), TYPE_NAMED_SPREAD + b"\x08rest") 3781 3782 def test_true_variant(self) -> None: 3783 obj = Variant("true", Hole()) 3784 self.assertEqual(self._serialize(obj), TYPE_TRUE) 3785 3786 def test_false_variant(self) -> None: 3787 obj = Variant("false", Hole()) 3788 self.assertEqual(self._serialize(obj), TYPE_FALSE) 3789 3790 def test_true_variant_with_non_hole_uses_regular_variant(self) -> None: 3791 obj = Variant("true", Int(123)) 3792 self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x08truei\xf6\x01") 3793 3794 def test_false_variant_with_non_hole_uses_regular_variant(self) -> None: 3795 obj = Variant("false", Int(123)) 3796 self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x0afalsei\xf6\x01") 3797 3798 3799class RoundTripSerializationTests(unittest.TestCase): 3800 def _serialize(self, obj: Object) -> bytes: 3801 serializer = 
Serializer() 3802 serializer.serialize(obj) 3803 return bytes(serializer.output) 3804 3805 def _deserialize(self, flat: bytes) -> Object: 3806 deserializer = Deserializer(flat) 3807 return deserializer.parse() 3808 3809 def _serde(self, obj: Object) -> Object: 3810 flat = self._serialize(obj) 3811 return self._deserialize(flat) 3812 3813 def _rt(self, obj: Object) -> None: 3814 result = self._serde(obj) 3815 self.assertEqual(result, obj) 3816 3817 def test_short(self) -> None: 3818 for i in range(-(2**16), 2**16): 3819 self._rt(Int(i)) 3820 3821 self._rt(Int(-(2**63))) 3822 self._rt(Int(2**63 - 1)) 3823 3824 def test_long(self) -> None: 3825 self._rt(Int(2**100)) 3826 self._rt(Int(-(2**100))) 3827 3828 def test_string(self) -> None: 3829 self._rt(String("")) 3830 self._rt(String("a")) 3831 self._rt(String("hello")) 3832 3833 def test_list(self) -> None: 3834 self._rt(List([])) 3835 self._rt(List([Int(123), Int(345)])) 3836 3837 def test_self_referential_list(self) -> None: 3838 ls = List([]) 3839 ls.items.append(ls) 3840 result = self._serde(ls) 3841 self.assertIsInstance(result, List) 3842 assert isinstance(result, List) # For mypy 3843 self.assertIsInstance(result.items, list) 3844 self.assertEqual(len(result.items), 1) 3845 self.assertIs(result.items[0], result) 3846 3847 def test_record(self) -> None: 3848 self._rt(Record({"x": Int(1), "y": Int(2)})) 3849 3850 def test_variant(self) -> None: 3851 self._rt(Variant("abc", Int(123))) 3852 3853 def test_var(self) -> None: 3854 self._rt(Var("x")) 3855 3856 def test_function(self) -> None: 3857 self._rt(Function(Var("x"), Var("x"))) 3858 3859 def test_empty_match_function(self) -> None: 3860 self._rt(MatchFunction([])) 3861 3862 def test_match_function(self) -> None: 3863 obj = MatchFunction([MatchCase(Int(1), Var("x")), MatchCase(List([Int(1)]), Var("y"))]) 3864 self._rt(obj) 3865 3866 def test_closure(self) -> None: 3867 self._rt(Closure({}, Function(Var("x"), Var("x")))) 3868 3869 def 
test_self_referential_closure(self) -> None: 3870 obj = Closure({}, Function(Var("x"), Var("x"))) 3871 assert isinstance(obj.env, dict) # For mypy 3872 obj.env["self"] = obj 3873 result = self._serde(obj) 3874 self.assertIsInstance(result, Closure) 3875 assert isinstance(result, Closure) # For mypy 3876 self.assertIsInstance(result.env, dict) 3877 self.assertEqual(len(result.env), 1) 3878 self.assertIs(result.env["self"], result) 3879 3880 def test_bytes(self) -> None: 3881 self._rt(Bytes(b"abc")) 3882 3883 def test_float(self) -> None: 3884 self._rt(Float(3.14)) 3885 3886 def test_hole(self) -> None: 3887 self._rt(Hole()) 3888 3889 def test_assign(self) -> None: 3890 self._rt(Assign(Var("x"), Int(123))) 3891 3892 def test_binop(self) -> None: 3893 self._rt(Binop(BinopKind.ADD, Int(3), Int(4))) 3894 3895 def test_apply(self) -> None: 3896 self._rt(Apply(Var("f"), Var("x"))) 3897 3898 def test_where(self) -> None: 3899 self._rt(Where(Var("a"), Var("b"))) 3900 3901 def test_access(self) -> None: 3902 self._rt(Access(Var("a"), Var("b"))) 3903 3904 def test_spread(self) -> None: 3905 self._rt(Spread()) 3906 self._rt(Spread("rest")) 3907 3908 3909class ScrapMonadTests(unittest.TestCase): 3910 def test_create_copies_env(self) -> None: 3911 env = {"a": Int(123)} 3912 result = ScrapMonad(env) 3913 self.assertEqual(result.env, env) 3914 self.assertIsNot(result.env, env) 3915 3916 def test_bind_returns_new_monad(self) -> None: 3917 env = {"a": Int(123)} 3918 orig = ScrapMonad(env) 3919 result, next_monad = orig.bind(Assign(Var("b"), Int(456))) 3920 self.assertEqual(orig.env, {"a": Int(123)}) 3921 self.assertEqual(next_monad.env, {"a": Int(123), "b": Int(456)}) 3922 3923 3924class PrettyPrintTests(unittest.TestCase): 3925 def test_pretty_print_int(self) -> None: 3926 obj = Int(1) 3927 self.assertEqual(pretty(obj), "1") 3928 3929 def test_pretty_print_float(self) -> None: 3930 obj = Float(3.14) 3931 self.assertEqual(pretty(obj), "3.14") 3932 3933 def 
test_pretty_print_string(self) -> None: 3934 obj = String("hello") 3935 self.assertEqual(pretty(obj), '"hello"') 3936 3937 def test_pretty_print_bytes(self) -> None: 3938 obj = Bytes(b"abc") 3939 self.assertEqual(pretty(obj), "~~YWJj") 3940 3941 def test_pretty_print_var(self) -> None: 3942 obj = Var("ref") 3943 self.assertEqual(pretty(obj), "ref") 3944 3945 def test_pretty_print_hole(self) -> None: 3946 obj = Hole() 3947 self.assertEqual(pretty(obj), "()") 3948 3949 def test_pretty_print_spread(self) -> None: 3950 obj = Spread() 3951 self.assertEqual(pretty(obj), "...") 3952 3953 def test_pretty_print_named_spread(self) -> None: 3954 obj = Spread("rest") 3955 self.assertEqual(pretty(obj), "...rest") 3956 3957 def test_pretty_print_binop(self) -> None: 3958 obj = Binop(BinopKind.ADD, Int(1), Int(2)) 3959 self.assertEqual(pretty(obj), "1 + 2") 3960 3961 def test_pretty_print_binop_precedence(self) -> None: 3962 obj = Binop(BinopKind.ADD, Int(1), Binop(BinopKind.MUL, Int(2), Int(3))) 3963 self.assertEqual(pretty(obj), "1 + 2 * 3") 3964 obj = Binop(BinopKind.MUL, Binop(BinopKind.ADD, Int(1), Int(2)), Int(3)) 3965 self.assertEqual(pretty(obj), "(1 + 2) * 3") 3966 3967 def test_pretty_print_int_list(self) -> None: 3968 obj = List([Int(1), Int(2), Int(3)]) 3969 self.assertEqual(pretty(obj), "[1, 2, 3]") 3970 3971 def test_pretty_print_str_list(self) -> None: 3972 obj = List([String("1"), String("2"), String("3")]) 3973 self.assertEqual(pretty(obj), '["1", "2", "3"]') 3974 3975 def test_pretty_print_recursion(self) -> None: 3976 obj = List([]) 3977 obj.items.append(obj) 3978 self.assertEqual(pretty(obj), "[...]") 3979 3980 def test_pretty_print_assign(self) -> None: 3981 obj = Assign(Var("x"), Int(3)) 3982 self.assertEqual(pretty(obj), "x = 3") 3983 3984 def test_pretty_print_function(self) -> None: 3985 obj = Function(Var("x"), Binop(BinopKind.ADD, Int(1), Var("x"))) 3986 self.assertEqual(pretty(obj), "x -> 1 + x") 3987 3988 def test_pretty_print_nested_function(self) -> 
None: 3989 obj = Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, Var("x"), Var("y")))) 3990 self.assertEqual(pretty(obj), "x -> y -> x + y") 3991 3992 def test_pretty_print_apply(self) -> None: 3993 obj = Apply(Var("x"), Var("y")) 3994 self.assertEqual(pretty(obj), "x y") 3995 3996 def test_pretty_print_compose(self) -> None: 3997 gensym_reset() 3998 obj = parse(tokenize("(x -> x + 3) << (x -> x * 2)")) 3999 self.assertEqual( 4000 pretty(obj), 4001 "$v0 -> (x -> x + 3) ((x -> x * 2) $v0)", 4002 ) 4003 gensym_reset() 4004 obj = parse(tokenize("(x -> x + 3) >> (x -> x * 2)")) 4005 self.assertEqual( 4006 pretty(obj), 4007 "$v0 -> (x -> x * 2) ((x -> x + 3) $v0)", 4008 ) 4009 4010 def test_pretty_print_where(self) -> None: 4011 obj = Where( 4012 Binop(BinopKind.ADD, Var("a"), Var("b")), 4013 Assign(Var("a"), Int(1)), 4014 ) 4015 self.assertEqual(pretty(obj), "a + b . a = 1") 4016 4017 def test_pretty_print_assert(self) -> None: 4018 obj = Assert(Int(123), Variant("true", String("foo"))) 4019 self.assertEqual(pretty(obj), '123 ! 
#true "foo"') 4020 4021 def test_pretty_print_envobject(self) -> None: 4022 obj = EnvObject({"x": Int(1)}) 4023 self.assertEqual(pretty(obj), "EnvObject({'x': Int(value=1)})") 4024 4025 def test_pretty_print_matchfunction(self) -> None: 4026 obj = MatchFunction([MatchCase(Var("y"), Var("x"))]) 4027 self.assertEqual(pretty(obj), "| y -> x") 4028 4029 def test_pretty_print_matchfunction_precedence(self) -> None: 4030 obj = MatchFunction( 4031 [ 4032 MatchCase(Var("a"), MatchFunction([MatchCase(Var("b"), Var("c"))])), 4033 MatchCase(Var("x"), MatchFunction([MatchCase(Var("y"), Var("z"))])), 4034 ] 4035 ) 4036 self.assertEqual( 4037 pretty(obj), 4038 """\ 4039| a -> (| b -> c) 4040| x -> (| y -> z)""", 4041 ) 4042 4043 def test_pretty_print_relocation(self) -> None: 4044 obj = Relocation("relocate") 4045 self.assertEqual(pretty(obj), "Relocation(name='relocate')") 4046 4047 def test_pretty_print_nativefunction(self) -> None: 4048 obj = NativeFunction("times2", lambda x: Int(x.value * 2)) # type: ignore [attr-defined] 4049 self.assertEqual(pretty(obj), "NativeFunction(name=times2)") 4050 4051 def test_pretty_print_closure(self) -> None: 4052 obj = Closure({"a": Int(123)}, Function(Var("x"), Var("x"))) 4053 self.assertEqual(pretty(obj), "Closure(['a'], x -> x)") 4054 4055 def test_pretty_print_record(self) -> None: 4056 obj = Record({"a": Int(1), "b": Int(2)}) 4057 self.assertEqual(pretty(obj), "{a = 1, b = 2}") 4058 4059 def test_pretty_print_access(self) -> None: 4060 obj = Access(Record({"a": Int(4)}), Var("a")) 4061 self.assertEqual(pretty(obj), "{a = 4} @ a") 4062 4063 def test_pretty_print_variant(self) -> None: 4064 obj = Variant("x", Int(123)) 4065 self.assertEqual(pretty(obj), "#x 123") 4066 4067 obj = Variant("x", Function(Var("a"), Var("b"))) 4068 self.assertEqual(pretty(obj), "#x (a -> b)") 4069 4070 4071class ServerCommandTests(unittest.TestCase): 4072 def setUp(self) -> None: 4073 import threading 4074 import time 4075 import os 4076 import socket 4077 
import argparse 4078 from scrapscript import server_command 4079 4080 # Find a random available port 4081 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 4082 s.bind(("127.0.0.1", 0)) 4083 self.host, self.port = s.getsockname() 4084 4085 args = argparse.Namespace( 4086 directory=os.path.join(os.path.dirname(__file__), "examples"), 4087 host=self.host, 4088 port=self.port, 4089 ) 4090 4091 self.server_thread = threading.Thread(target=server_command, args=(args,)) 4092 self.server_thread.daemon = True 4093 self.server_thread.start() 4094 4095 # Wait for the server to start 4096 while True: 4097 try: 4098 with socket.create_connection((self.host, self.port), timeout=0.1) as s: 4099 break 4100 except (ConnectionRefusedError, socket.timeout): 4101 time.sleep(0.01) 4102 4103 def tearDown(self) -> None: 4104 quit_request = urllib.request.Request(f"http://{self.host}:{self.port}/", method="QUIT") 4105 urllib.request.urlopen(quit_request) 4106 4107 def test_server_serves_scrap_by_path(self) -> None: 4108 response = urllib.request.urlopen(f"http://{self.host}:{self.port}/0_home/factorial") 4109 self.assertEqual(response.status, 200) 4110 4111 def test_server_serves_scrap_by_hash(self) -> None: 4112 response = urllib.request.urlopen(f"http://{self.host}:{self.port}/$09242a8dfec0ed32eb9ddd5452f0082998712d35306fec2042bad8ac5b6e9580") 4113 self.assertEqual(response.status, 200) 4114 4115 def test_server_fails_missing_scrap(self) -> None: 4116 with self.assertRaises(urllib.error.HTTPError) as cm: 4117 urllib.request.urlopen(f"http://{self.host}:{self.port}/foo") 4118 self.assertEqual(cm.exception.code, 404) 4119 4120 4121if __name__ == "__main__": 4122 unittest.main()