this repo has no description
1import unittest
2import re
3from typing import Optional
4import urllib.request
5
6# ruff: noqa: F405
7# ruff: noqa: F403
8from scrapscript import *
9
10
class PeekableTests(unittest.TestCase):
    # Exercises the Peekable iterator wrapper: construction, iteration, and peek()/next() behavior.

    def test_can_create_peekable(self) -> None:
        # Smoke test: construction alone must not raise.
        Peekable(iter([1, 2, 3]))

    def test_can_iterate_over_peekable(self) -> None:
        sequence = [1, 2, 3]
        # Compare whole lists: an element-wise loop would pass vacuously if the
        # Peekable yielded too few items (or none at all).
        self.assertEqual(list(Peekable(iter(sequence))), sequence)

    def test_peek_next(self) -> None:
        iterator = Peekable(iter([1, 2, 3]))
        # peek() must report the upcoming item without consuming it, so the
        # following next() returns that same value.
        for expected in (1, 2, 3):
            self.assertEqual(iterator.peek(), expected)
            self.assertEqual(next(iterator), expected)
        # Once exhausted, both peek() and next() signal StopIteration.
        with self.assertRaises(StopIteration):
            iterator.peek()
        with self.assertRaises(StopIteration):
            next(iterator)

    def test_can_peek_peekable(self) -> None:
        sequence = [1, 2, 3]
        p = Peekable(iter(sequence))
        self.assertEqual(p.peek(), 1)
        # Ensure we can peek repeatedly
        self.assertEqual(p.peek(), 1)
        # Peeking must not drop the peeked item: a full iteration still yields every
        # element. Whole-list comparison also catches a short iteration.
        self.assertEqual(list(p), sequence)

    def test_peek_on_empty_peekable_raises_stop_iteration(self) -> None:
        empty = Peekable(iter([]))
        with self.assertRaises(StopIteration):
            empty.peek()

    def test_next_on_empty_peekable_raises_stop_iteration(self) -> None:
        empty = Peekable(iter([]))
        with self.assertRaises(StopIteration):
            next(empty)
51
52
class TokenizerTests(unittest.TestCase):
    # Tests for tokenize() output and for the source positions tracked by Lexer.
    # In the position tests below, lineno and colno are 1-based, byteno is 0-based,
    # and a token's end position is inclusive (it points at the token's last character).

    def test_tokenize_digit(self) -> None:
        self.assertEqual(list(tokenize("1")), [IntLit(1)])

    def test_tokenize_multiple_digits(self) -> None:
        self.assertEqual(list(tokenize("123")), [IntLit(123)])

    def test_tokenize_negative_int(self) -> None:
        # The sign is emitted as its own Operator token, not folded into the literal.
        self.assertEqual(list(tokenize("-123")), [Operator("-"), IntLit(123)])

    def test_tokenize_float(self) -> None:
        self.assertEqual(list(tokenize("3.14")), [FloatLit(3.14)])

    def test_tokenize_negative_float(self) -> None:
        self.assertEqual(list(tokenize("-3.14")), [Operator("-"), FloatLit(3.14)])

    @unittest.skip("TODO: support floats with no integer part")
    def test_tokenize_float_with_no_integer_part(self) -> None:
        self.assertEqual(list(tokenize(".14")), [FloatLit(0.14)])

    def test_tokenize_float_with_no_decimal_part(self) -> None:
        self.assertEqual(list(tokenize("10.")), [FloatLit(10.0)])

    def test_tokenize_float_with_multiple_decimal_points_raises_parse_error(self) -> None:
        with self.assertRaisesRegex(ParseError, re.escape("unexpected token '.'")):
            list(tokenize("1.0.1"))

    def test_tokenize_binop(self) -> None:
        self.assertEqual(list(tokenize("1 + 2")), [IntLit(1), Operator("+"), IntLit(2)])

    def test_tokenize_binop_no_spaces(self) -> None:
        self.assertEqual(list(tokenize("1+2")), [IntLit(1), Operator("+"), IntLit(2)])

    def test_tokenize_two_oper_chars_returns_two_ops(self) -> None:
        self.assertEqual(list(tokenize(",:")), [Operator(","), Operator(":")])

    def test_tokenize_binary_sub_no_spaces(self) -> None:
        self.assertEqual(list(tokenize("1-2")), [IntLit(1), Operator("-"), IntLit(2)])

    def test_tokenize_binop_var(self) -> None:
        ops = ["+", "-", "*", "/", "^", "%", "==", "/=", "<", ">", "<=", ">=", "&&", "||", "++", ">+", "+<"]
        for op in ops:
            # subTest reports each operator separately instead of stopping at the first failure.
            with self.subTest(op=op):
                self.assertEqual(list(tokenize(f"a {op} b")), [Name("a"), Operator(op), Name("b")])
                self.assertEqual(list(tokenize(f"a{op}b")), [Name("a"), Operator(op), Name("b")])

    def test_tokenize_var(self) -> None:
        self.assertEqual(list(tokenize("abc")), [Name("abc")])

    @unittest.skip("TODO: make this fail to tokenize")
    def test_tokenize_var_with_quote(self) -> None:
        self.assertEqual(list(tokenize("sha1'abc")), [Name("sha1'abc")])

    def test_tokenize_dollar_sha1_var(self) -> None:
        self.assertEqual(list(tokenize("$sha1'foo")), [Name("$sha1'foo")])

    def test_tokenize_dollar_dollar_var(self) -> None:
        self.assertEqual(list(tokenize("$$bills")), [Name("$$bills")])

    def test_tokenize_dot_dot_raises_parse_error(self) -> None:
        with self.assertRaisesRegex(ParseError, re.escape("unexpected token '..'")):
            list(tokenize(".."))

    def test_tokenize_spread(self) -> None:
        self.assertEqual(list(tokenize("...")), [Operator("...")])

    def test_ignore_whitespace(self) -> None:
        self.assertEqual(list(tokenize("1\n+\t2")), [IntLit(1), Operator("+"), IntLit(2)])

    def test_ignore_line_comment(self) -> None:
        # Everything from `--` to the end of the line is dropped.
        self.assertEqual(list(tokenize("-- 1\n2")), [IntLit(2)])

    def test_tokenize_string(self) -> None:
        self.assertEqual(list(tokenize('"hello"')), [StringLit("hello")])

    def test_tokenize_string_with_spaces(self) -> None:
        self.assertEqual(list(tokenize('"hello world"')), [StringLit("hello world")])

    def test_tokenize_string_missing_end_quote_raises_parse_error(self) -> None:
        with self.assertRaisesRegex(UnexpectedEOFError, "while reading string"):
            list(tokenize('"hello'))

    def test_tokenize_with_trailing_whitespace(self) -> None:
        self.assertEqual(list(tokenize("- ")), [Operator("-")])
        self.assertEqual(list(tokenize("-- ")), [])
        self.assertEqual(list(tokenize("+ ")), [Operator("+")])
        self.assertEqual(list(tokenize("123 ")), [IntLit(123)])
        self.assertEqual(list(tokenize("abc ")), [Name("abc")])
        self.assertEqual(list(tokenize("[ ")), [LeftBracket()])
        self.assertEqual(list(tokenize("] ")), [RightBracket()])

    def test_tokenize_empty_list(self) -> None:
        # No whitespace between the brackets; the spaced form is covered by
        # test_tokenize_empty_list_with_spaces.
        self.assertEqual(list(tokenize("[]")), [LeftBracket(), RightBracket()])

    def test_tokenize_empty_list_with_spaces(self) -> None:
        self.assertEqual(list(tokenize("[ ]")), [LeftBracket(), RightBracket()])

    def test_tokenize_list_with_items(self) -> None:
        self.assertEqual(
            list(tokenize("[ 1 , 2 ]")), [LeftBracket(), IntLit(1), Operator(","), IntLit(2), RightBracket()]
        )

    def test_tokenize_list_with_no_spaces(self) -> None:
        self.assertEqual(list(tokenize("[1,2]")), [LeftBracket(), IntLit(1), Operator(","), IntLit(2), RightBracket()])

    def test_tokenize_function(self) -> None:
        self.assertEqual(
            list(tokenize("a -> b -> a + b")),
            [Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")],
        )

    def test_tokenize_function_with_no_spaces(self) -> None:
        self.assertEqual(
            list(tokenize("a->b->a+b")),
            [Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")],
        )

    def test_tokenize_where(self) -> None:
        self.assertEqual(list(tokenize("a . b")), [Name("a"), Operator("."), Name("b")])

    def test_tokenize_assert(self) -> None:
        self.assertEqual(list(tokenize("a ? b")), [Name("a"), Operator("?"), Name("b")])

    def test_tokenize_hastype(self) -> None:
        self.assertEqual(list(tokenize("a : b")), [Name("a"), Operator(":"), Name("b")])

    def test_tokenize_minus_returns_minus(self) -> None:
        self.assertEqual(list(tokenize("-")), [Operator("-")])

    def test_tokenize_tilde_raises_parse_error(self) -> None:
        # re.escape keeps the pattern literal, consistent with the other parse-error tests.
        with self.assertRaisesRegex(ParseError, re.escape("unexpected token '~'")):
            list(tokenize("~"))

    def test_tokenize_tilde_equals_raises_parse_error(self) -> None:
        with self.assertRaisesRegex(ParseError, re.escape("unexpected token '~'")):
            list(tokenize("~="))

    def test_tokenize_tilde_tilde_returns_empty_bytes(self) -> None:
        self.assertEqual(list(tokenize("~~")), [BytesLit("", 64)])

    def test_tokenize_bytes_returns_bytes_base64(self) -> None:
        # With no explicit `NN'` prefix the literal is tagged as base 64.
        self.assertEqual(list(tokenize("~~QUJD")), [BytesLit("QUJD", 64)])

    def test_tokenize_bytes_base85(self) -> None:
        self.assertEqual(list(tokenize("~~85'K|(_")), [BytesLit("K|(_", 85)])

    def test_tokenize_bytes_base64(self) -> None:
        self.assertEqual(list(tokenize("~~64'QUJD")), [BytesLit("QUJD", 64)])

    def test_tokenize_bytes_base32(self) -> None:
        self.assertEqual(list(tokenize("~~32'IFBEG===")), [BytesLit("IFBEG===", 32)])

    def test_tokenize_bytes_base16(self) -> None:
        self.assertEqual(list(tokenize("~~16'414243")), [BytesLit("414243", 16)])

    def test_tokenize_hole(self) -> None:
        self.assertEqual(list(tokenize("()")), [LeftParen(), RightParen()])

    def test_tokenize_hole_with_spaces(self) -> None:
        self.assertEqual(list(tokenize("( )")), [LeftParen(), RightParen()])

    def test_tokenize_parenthetical_expression(self) -> None:
        self.assertEqual(list(tokenize("(1+2)")), [LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen()])

    def test_tokenize_pipe(self) -> None:
        self.assertEqual(
            list(tokenize("1 |> f . f = a -> a + 1")),
            [
                IntLit(1),
                Operator("|>"),
                Name("f"),
                Operator("."),
                Name("f"),
                Operator("="),
                Name("a"),
                Operator("->"),
                Name("a"),
                Operator("+"),
                IntLit(1),
            ],
        )

    def test_tokenize_reverse_pipe(self) -> None:
        self.assertEqual(
            list(tokenize("f <| 1 . f = a -> a + 1")),
            [
                Name("f"),
                Operator("<|"),
                IntLit(1),
                Operator("."),
                Name("f"),
                Operator("="),
                Name("a"),
                Operator("->"),
                Name("a"),
                Operator("+"),
                IntLit(1),
            ],
        )

    def test_tokenize_record_no_fields(self) -> None:
        self.assertEqual(
            list(tokenize("{ }")),
            [LeftBrace(), RightBrace()],
        )

    def test_tokenize_record_no_fields_no_spaces(self) -> None:
        self.assertEqual(
            list(tokenize("{}")),
            [LeftBrace(), RightBrace()],
        )

    def test_tokenize_record_one_field(self) -> None:
        self.assertEqual(
            list(tokenize("{ a = 4 }")),
            [LeftBrace(), Name("a"), Operator("="), IntLit(4), RightBrace()],
        )

    def test_tokenize_record_multiple_fields(self) -> None:
        self.assertEqual(
            list(tokenize('{ a = 4, b = "z" }')),
            [
                LeftBrace(),
                Name("a"),
                Operator("="),
                IntLit(4),
                Operator(","),
                Name("b"),
                Operator("="),
                StringLit("z"),
                RightBrace(),
            ],
        )

    def test_tokenize_record_access(self) -> None:
        self.assertEqual(
            list(tokenize("r@a")),
            [Name("r"), Operator("@"), Name("a")],
        )

    def test_tokenize_right_eval(self) -> None:
        self.assertEqual(list(tokenize("a!b")), [Name("a"), Operator("!"), Name("b")])

    def test_tokenize_match(self) -> None:
        self.assertEqual(
            list(tokenize("g = | 1 -> 2 | 2 -> 3")),
            [
                Name("g"),
                Operator("="),
                Operator("|"),
                IntLit(1),
                Operator("->"),
                IntLit(2),
                Operator("|"),
                IntLit(2),
                Operator("->"),
                IntLit(3),
            ],
        )

    def test_tokenize_compose(self) -> None:
        self.assertEqual(
            list(tokenize("f >> g")),
            [Name("f"), Operator(">>"), Name("g")],
        )

    def test_tokenize_compose_reverse(self) -> None:
        self.assertEqual(
            list(tokenize("f << g")),
            [Name("f"), Operator("<<"), Name("g")],
        )

    def test_first_lineno_is_one(self) -> None:
        lexer = Lexer("abc")
        self.assertEqual(lexer.lineno, 1)

    def test_first_colno_is_one(self) -> None:
        lexer = Lexer("abc")
        self.assertEqual(lexer.colno, 1)

    def test_first_line_is_empty(self) -> None:
        lexer = Lexer("abc")
        self.assertEqual(lexer.line, "")

    def test_read_char_increments_colno(self) -> None:
        lexer = Lexer("abc")
        lexer.read_char()
        self.assertEqual(lexer.colno, 2)
        self.assertEqual(lexer.lineno, 1)

    def test_read_newline_increments_lineno(self) -> None:
        lexer = Lexer("ab\nc")
        # Consume "a", "b" and the newline; the column restarts at 1 on the new line.
        lexer.read_char()
        lexer.read_char()
        lexer.read_char()
        self.assertEqual(lexer.lineno, 2)
        self.assertEqual(lexer.colno, 1)

    def test_read_char_increments_byteno(self) -> None:
        lexer = Lexer("abc")
        lexer.read_char()
        self.assertEqual(lexer.byteno, 1)
        lexer.read_char()
        self.assertEqual(lexer.byteno, 2)
        lexer.read_char()
        self.assertEqual(lexer.byteno, 3)

    def test_read_char_appends_to_line(self) -> None:
        lexer = Lexer("ab\nc")
        lexer.read_char()
        lexer.read_char()
        self.assertEqual(lexer.line, "ab")
        # Reading the newline resets the current-line buffer.
        lexer.read_char()
        self.assertEqual(lexer.line, "")

    def test_read_token_sets_start_and_end_linenos(self) -> None:
        lexer = Lexer("a b \n c d")
        a = lexer.read_token()
        b = lexer.read_token()
        c = lexer.read_token()
        d = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)

        self.assertEqual(b.source_extent.start.lineno, 1)
        self.assertEqual(b.source_extent.end.lineno, 1)

        self.assertEqual(c.source_extent.start.lineno, 2)
        self.assertEqual(c.source_extent.end.lineno, 2)

        self.assertEqual(d.source_extent.start.lineno, 2)
        self.assertEqual(d.source_extent.end.lineno, 2)

    def test_read_token_sets_source_extents_for_variables(self) -> None:
        lexer = Lexer("aa bbbb \n ccccc ddddddd")

        a = lexer.read_token()
        b = lexer.read_token()
        c = lexer.read_token()
        d = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)
        self.assertEqual(a.source_extent.start.colno, 1)
        self.assertEqual(a.source_extent.end.colno, 2)
        self.assertEqual(a.source_extent.start.byteno, 0)
        self.assertEqual(a.source_extent.end.byteno, 1)

        self.assertEqual(b.source_extent.start.lineno, 1)
        self.assertEqual(b.source_extent.end.lineno, 1)
        self.assertEqual(b.source_extent.start.colno, 4)
        self.assertEqual(b.source_extent.end.colno, 7)
        self.assertEqual(b.source_extent.start.byteno, 3)
        self.assertEqual(b.source_extent.end.byteno, 6)

        # colno restarts after the newline while byteno keeps counting from the start of input.
        self.assertEqual(c.source_extent.start.lineno, 2)
        self.assertEqual(c.source_extent.end.lineno, 2)
        self.assertEqual(c.source_extent.start.colno, 2)
        self.assertEqual(c.source_extent.end.colno, 6)
        self.assertEqual(c.source_extent.start.byteno, 10)
        self.assertEqual(c.source_extent.end.byteno, 14)

        self.assertEqual(d.source_extent.start.lineno, 2)
        self.assertEqual(d.source_extent.end.lineno, 2)
        self.assertEqual(d.source_extent.start.colno, 8)
        self.assertEqual(d.source_extent.end.colno, 14)
        self.assertEqual(d.source_extent.start.byteno, 16)
        self.assertEqual(d.source_extent.end.byteno, 22)

    def test_read_token_correctly_sets_source_extents_for_variants(self) -> None:
        lexer = Lexer("# \n\r\n\t abc")

        a = lexer.read_token()
        b = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)
        self.assertEqual(a.source_extent.start.colno, 1)
        # TODO(max): Should tabs count as one column?
        self.assertEqual(a.source_extent.end.colno, 1)

        self.assertEqual(b.source_extent.start.lineno, 3)
        self.assertEqual(b.source_extent.end.lineno, 3)
        self.assertEqual(b.source_extent.start.colno, 3)
        self.assertEqual(b.source_extent.end.colno, 5)

    def test_read_token_correctly_sets_source_extents_for_strings(self) -> None:
        lexer = Lexer('"今日は、Maxさん。"')
        a = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)

        # Columns count characters (12 including both quotes) ...
        self.assertEqual(a.source_extent.start.colno, 1)
        self.assertEqual(a.source_extent.end.colno, 12)

        # ... while byteno counts encoded bytes, so the multi-byte characters widen the extent.
        self.assertEqual(a.source_extent.start.byteno, 0)
        self.assertEqual(a.source_extent.end.byteno, 25)

    def test_read_token_correctly_sets_source_extents_for_byte_literals(self) -> None:
        lexer = Lexer("~~QUJD ~~85'K|(_ ~~64'QUJD\n ~~32'IFBEG=== ~~16'414243")
        a = lexer.read_token()
        b = lexer.read_token()
        c = lexer.read_token()
        d = lexer.read_token()
        e = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)
        self.assertEqual(a.source_extent.start.colno, 1)
        self.assertEqual(a.source_extent.end.colno, 6)
        self.assertEqual(a.source_extent.start.byteno, 0)
        self.assertEqual(a.source_extent.end.byteno, 5)

        self.assertEqual(b.source_extent.start.lineno, 1)
        self.assertEqual(b.source_extent.end.lineno, 1)
        self.assertEqual(b.source_extent.start.colno, 8)
        self.assertEqual(b.source_extent.end.colno, 16)
        self.assertEqual(b.source_extent.start.byteno, 7)
        self.assertEqual(b.source_extent.end.byteno, 15)

        self.assertEqual(c.source_extent.start.lineno, 1)
        self.assertEqual(c.source_extent.end.lineno, 1)
        self.assertEqual(c.source_extent.start.colno, 18)
        self.assertEqual(c.source_extent.end.colno, 26)
        self.assertEqual(c.source_extent.start.byteno, 17)
        self.assertEqual(c.source_extent.end.byteno, 25)

        self.assertEqual(d.source_extent.start.lineno, 2)
        self.assertEqual(d.source_extent.end.lineno, 2)
        self.assertEqual(d.source_extent.start.colno, 2)
        self.assertEqual(d.source_extent.end.colno, 14)
        self.assertEqual(d.source_extent.start.byteno, 28)
        self.assertEqual(d.source_extent.end.byteno, 40)

        self.assertEqual(e.source_extent.start.lineno, 2)
        self.assertEqual(e.source_extent.end.lineno, 2)
        self.assertEqual(e.source_extent.start.colno, 16)
        self.assertEqual(e.source_extent.end.colno, 26)
        self.assertEqual(e.source_extent.start.byteno, 42)
        self.assertEqual(e.source_extent.end.byteno, 52)

    def test_read_token_correctly_sets_source_extents_for_numbers(self) -> None:
        lexer = Lexer("123 123.456")
        a = lexer.read_token()
        b = lexer.read_token()

        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)
        self.assertEqual(a.source_extent.start.colno, 1)
        self.assertEqual(a.source_extent.end.colno, 3)
        self.assertEqual(a.source_extent.start.byteno, 0)
        self.assertEqual(a.source_extent.end.byteno, 2)

        self.assertEqual(b.source_extent.start.lineno, 1)
        self.assertEqual(b.source_extent.end.lineno, 1)
        self.assertEqual(b.source_extent.start.colno, 5)
        self.assertEqual(b.source_extent.end.colno, 11)
        self.assertEqual(b.source_extent.start.byteno, 4)
        self.assertEqual(b.source_extent.end.byteno, 10)

    def test_read_token_correctly_sets_source_extents_for_operators(self) -> None:
        lexer = Lexer("> >>")
        a = lexer.read_token()
        b = lexer.read_token()

        # A one-character operator starts and ends at the same position.
        self.assertEqual(a.source_extent.start.lineno, 1)
        self.assertEqual(a.source_extent.end.lineno, 1)
        self.assertEqual(a.source_extent.start.colno, 1)
        self.assertEqual(a.source_extent.end.colno, 1)
        self.assertEqual(a.source_extent.start.byteno, 0)
        self.assertEqual(a.source_extent.end.byteno, 0)

        self.assertEqual(b.source_extent.start.lineno, 1)
        self.assertEqual(b.source_extent.end.lineno, 1)
        self.assertEqual(b.source_extent.start.colno, 3)
        self.assertEqual(b.source_extent.end.colno, 4)
        self.assertEqual(b.source_extent.start.byteno, 2)
        self.assertEqual(b.source_extent.end.byteno, 3)

    def test_tokenize_list_with_only_spread(self) -> None:
        self.assertEqual(list(tokenize("[ ... ]")), [LeftBracket(), Operator("..."), RightBracket()])

    def test_tokenize_list_with_spread(self) -> None:
        self.assertEqual(
            list(tokenize("[ 1 , ... ]")),
            [
                LeftBracket(),
                IntLit(1),
                Operator(","),
                Operator("..."),
                RightBracket(),
            ],
        )

    def test_tokenize_list_with_spread_no_spaces(self) -> None:
        self.assertEqual(
            list(tokenize("[ 1,... ]")),
            [
                LeftBracket(),
                IntLit(1),
                Operator(","),
                Operator("..."),
                RightBracket(),
            ],
        )

    def test_tokenize_list_with_named_spread(self) -> None:
        self.assertEqual(
            list(tokenize("[1,...rest]")),
            [
                LeftBracket(),
                IntLit(1),
                Operator(","),
                Operator("..."),
                Name("rest"),
                RightBracket(),
            ],
        )

    def test_tokenize_record_with_only_spread(self) -> None:
        self.assertEqual(
            list(tokenize("{ ... }")),
            [
                LeftBrace(),
                Operator("..."),
                RightBrace(),
            ],
        )

    def test_tokenize_record_with_spread(self) -> None:
        self.assertEqual(
            list(tokenize("{ x = 1, ...}")),
            [
                LeftBrace(),
                Name("x"),
                Operator("="),
                IntLit(1),
                Operator(","),
                Operator("..."),
                RightBrace(),
            ],
        )

    def test_tokenize_record_with_spread_no_spaces(self) -> None:
        self.assertEqual(
            list(tokenize("{x=1,...}")),
            [
                LeftBrace(),
                Name("x"),
                Operator("="),
                IntLit(1),
                Operator(","),
                Operator("..."),
                RightBrace(),
            ],
        )

    def test_tokenize_variant_with_whitespace(self) -> None:
        # Whitespace (including newlines) between `#` and the name is skipped.
        self.assertEqual(list(tokenize("# \n\r\n\t abc")), [Hash(), Name("abc")])

    def test_tokenize_variant_with_no_space(self) -> None:
        self.assertEqual(list(tokenize("#abc")), [Hash(), Name("abc")])
617
618
619class ParserTests(unittest.TestCase):
620 def test_parse_with_empty_tokens_raises_parse_error(self) -> None:
621 with self.assertRaises(UnexpectedEOFError) as ctx:
622 parse(Peekable(iter([])))
623 self.assertEqual(ctx.exception.args[0], "unexpected end of input")
624
625 def test_parse_digit_returns_int(self) -> None:
626 self.assertEqual(parse(Peekable(iter([IntLit(1)]))), Int(1))
627
628 def test_parse_digits_returns_int(self) -> None:
629 self.assertEqual(parse(Peekable(iter([IntLit(123)]))), Int(123))
630
631 def test_parse_negative_int_returns_negative_int(self) -> None:
632 self.assertEqual(parse(Peekable(iter([Operator("-"), IntLit(123)]))), Int(-123))
633
634 def test_parse_negative_var_returns_binary_sub_var(self) -> None:
635 self.assertEqual(parse(Peekable(iter([Operator("-"), Name("x")]))), Binop(BinopKind.SUB, Int(0), Var("x")))
636
637 def test_parse_negative_int_binds_tighter_than_plus(self) -> None:
638 self.assertEqual(
639 parse(Peekable(iter([Operator("-"), Name("l"), Operator("+"), Name("r")]))),
640 Binop(BinopKind.ADD, Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")),
641 )
642
643 def test_parse_negative_int_binds_tighter_than_mul(self) -> None:
644 self.assertEqual(
645 parse(Peekable(iter([Operator("-"), Name("l"), Operator("*"), Name("r")]))),
646 Binop(BinopKind.MUL, Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")),
647 )
648
649 def test_parse_negative_int_binds_tighter_than_index(self) -> None:
650 self.assertEqual(
651 parse(Peekable(iter([Operator("-"), Name("l"), Operator("@"), Name("r")]))),
652 Access(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")),
653 )
654
655 def test_parse_negative_int_binds_tighter_than_apply(self) -> None:
656 self.assertEqual(
657 parse(Peekable(iter([Operator("-"), Name("l"), Name("r")]))),
658 Apply(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")),
659 )
660
661 def test_parse_decimal_returns_float(self) -> None:
662 self.assertEqual(parse(Peekable(iter([FloatLit(3.14)]))), Float(3.14))
663
664 def test_parse_negative_float_returns_binary_sub_float(self) -> None:
665 self.assertEqual(parse(Peekable(iter([Operator("-"), FloatLit(3.14)]))), Float(-3.14))
666
667 def test_parse_var_returns_var(self) -> None:
668 self.assertEqual(parse(Peekable(iter([Name("abc_123")]))), Var("abc_123"))
669
670 def test_parse_sha_var_returns_var(self) -> None:
671 self.assertEqual(parse(Peekable(iter([Name("$sha1'abc")]))), Var("$sha1'abc"))
672
673 def test_parse_sha_var_without_quote_returns_var(self) -> None:
674 self.assertEqual(parse(Peekable(iter([Name("$sha1abc")]))), Var("$sha1abc"))
675
676 def test_parse_dollar_returns_var(self) -> None:
677 self.assertEqual(parse(Peekable(iter([Name("$")]))), Var("$"))
678
679 def test_parse_dollar_dollar_returns_var(self) -> None:
680 self.assertEqual(parse(Peekable(iter([Name("$$")]))), Var("$$"))
681
682 @unittest.skip("TODO: make this fail to parse")
683 def test_parse_sha_var_without_dollar_raises_parse_error(self) -> None:
684 with self.assertRaisesRegex(ParseError, "unexpected token"):
685 parse(Peekable(iter([Name("sha1'abc")])))
686
687 def test_parse_dollar_dollar_var_returns_var(self) -> None:
688 self.assertEqual(parse(Peekable(iter([Name("$$bills")]))), Var("$$bills"))
689
690 def test_parse_bytes_returns_bytes(self) -> None:
691 self.assertEqual(parse(Peekable(iter([BytesLit("QUJD", 64)]))), Bytes(b"ABC"))
692
693 def test_parse_binary_add_returns_binop(self) -> None:
694 self.assertEqual(
695 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2)]))), Binop(BinopKind.ADD, Int(1), Int(2))
696 )
697
698 def test_parse_binary_sub_returns_binop(self) -> None:
699 self.assertEqual(
700 parse(Peekable(iter([IntLit(1), Operator("-"), IntLit(2)]))), Binop(BinopKind.SUB, Int(1), Int(2))
701 )
702
703 def test_parse_binary_add_right_returns_binop(self) -> None:
704 self.assertEqual(
705 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2), Operator("+"), IntLit(3)]))),
706 Binop(BinopKind.ADD, Int(1), Binop(BinopKind.ADD, Int(2), Int(3))),
707 )
708
709 def test_mul_binds_tighter_than_add_right(self) -> None:
710 self.assertEqual(
711 parse(Peekable(iter([IntLit(1), Operator("+"), IntLit(2), Operator("*"), IntLit(3)]))),
712 Binop(BinopKind.ADD, Int(1), Binop(BinopKind.MUL, Int(2), Int(3))),
713 )
714
715 def test_mul_binds_tighter_than_add_left(self) -> None:
716 self.assertEqual(
717 parse(Peekable(iter([IntLit(1), Operator("*"), IntLit(2), Operator("+"), IntLit(3)]))),
718 Binop(BinopKind.ADD, Binop(BinopKind.MUL, Int(1), Int(2)), Int(3)),
719 )
720
721 def test_mul_and_div_bind_left_to_right(self) -> None:
722 self.assertEqual(
723 parse(Peekable(iter([IntLit(1), Operator("/"), IntLit(3), Operator("*"), IntLit(3)]))),
724 Binop(BinopKind.MUL, Binop(BinopKind.DIV, Int(1), Int(3)), Int(3)),
725 )
726
727 def test_exp_binds_tighter_than_mul_right(self) -> None:
728 self.assertEqual(
729 parse(Peekable(iter([IntLit(5), Operator("*"), IntLit(2), Operator("^"), IntLit(3)]))),
730 Binop(BinopKind.MUL, Int(5), Binop(BinopKind.EXP, Int(2), Int(3))),
731 )
732
733 def test_list_access_binds_tighter_than_append(self) -> None:
734 self.assertEqual(
735 parse(Peekable(iter([Name("a"), Operator("+<"), Name("ls"), Operator("@"), IntLit(0)]))),
736 Binop(BinopKind.LIST_APPEND, Var("a"), Access(Var("ls"), Int(0))),
737 )
738
739 def test_parse_binary_str_concat_returns_binop(self) -> None:
740 self.assertEqual(
741 parse(Peekable(iter([StringLit("abc"), Operator("++"), StringLit("def")]))),
742 Binop(BinopKind.STRING_CONCAT, String("abc"), String("def")),
743 )
744
745 def test_parse_binary_list_cons_returns_binop(self) -> None:
746 self.assertEqual(
747 parse(Peekable(iter([Name("a"), Operator(">+"), Name("b")]))),
748 Binop(BinopKind.LIST_CONS, Var("a"), Var("b")),
749 )
750
751 def test_parse_binary_list_append_returns_binop(self) -> None:
752 self.assertEqual(
753 parse(Peekable(iter([Name("a"), Operator("+<"), Name("b")]))),
754 Binop(BinopKind.LIST_APPEND, Var("a"), Var("b")),
755 )
756
757 def test_parse_binary_op_returns_binop(self) -> None:
758 ops = ["+", "-", "*", "/", "^", "%", "==", "/=", "<", ">", "<=", ">=", "&&", "||", "++", ">+", "+<"]
759 for op in ops:
760 with self.subTest(op=op):
761 kind = BinopKind.from_str(op)
762 self.assertEqual(
763 parse(Peekable(iter([Name("a"), Operator(op), Name("b")]))), Binop(kind, Var("a"), Var("b"))
764 )
765
766 def test_parse_empty_list(self) -> None:
767 self.assertEqual(
768 parse(Peekable(iter([LeftBracket(), RightBracket()]))),
769 List([]),
770 )
771
772 def test_parse_list_of_ints_returns_list(self) -> None:
773 self.assertEqual(
774 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), IntLit(2), RightBracket()]))),
775 List([Int(1), Int(2)]),
776 )
777
778 def test_parse_list_with_only_comma_raises_parse_error(self) -> None:
779 with self.assertRaises(UnexpectedTokenError) as parse_error:
780 parse(Peekable(iter([LeftBracket(), Operator(","), RightBracket()])))
781
782 self.assertEqual(parse_error.exception.unexpected_token, Operator(","))
783
784 def test_parse_list_with_two_commas_raises_parse_error(self) -> None:
785 with self.assertRaises(UnexpectedTokenError) as parse_error:
786 parse(Peekable(iter([LeftBracket(), Operator(","), Operator(","), RightBracket()])))
787
788 self.assertEqual(parse_error.exception.unexpected_token, Operator(","))
789
790 def test_parse_list_with_trailing_comma_raises_parse_error(self) -> None:
791 with self.assertRaises(UnexpectedTokenError) as parse_error:
792 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), RightBracket()])))
793
794 self.assertEqual(parse_error.exception.unexpected_token, RightBracket())
795
796 def test_parse_assign(self) -> None:
797 self.assertEqual(
798 parse(Peekable(iter([Name("a"), Operator("="), IntLit(1)]))),
799 Assign(Var("a"), Int(1)),
800 )
801
802 def test_parse_function_one_arg_returns_function(self) -> None:
803 self.assertEqual(
804 parse(Peekable(iter([Name("a"), Operator("->"), Name("a"), Operator("+"), IntLit(1)]))),
805 Function(Var("a"), Binop(BinopKind.ADD, Var("a"), Int(1))),
806 )
807
808 def test_parse_function_two_args_returns_functions(self) -> None:
809 self.assertEqual(
810 parse(
811 Peekable(
812 iter([Name("a"), Operator("->"), Name("b"), Operator("->"), Name("a"), Operator("+"), Name("b")])
813 )
814 ),
815 Function(Var("a"), Function(Var("b"), Binop(BinopKind.ADD, Var("a"), Var("b")))),
816 )
817
818 def test_parse_assign_function(self) -> None:
819 self.assertEqual(
820 parse(Peekable(iter([Name("id"), Operator("="), Name("x"), Operator("->"), Name("x")]))),
821 Assign(Var("id"), Function(Var("x"), Var("x"))),
822 )
823
824 def test_parse_function_application_one_arg(self) -> None:
825 self.assertEqual(parse(Peekable(iter([Name("f"), Name("a")]))), Apply(Var("f"), Var("a")))
826
827 def test_parse_function_application_two_args(self) -> None:
828 self.assertEqual(
829 parse(Peekable(iter([Name("f"), Name("a"), Name("b")]))), Apply(Apply(Var("f"), Var("a")), Var("b"))
830 )
831
832 def test_parse_where(self) -> None:
833 self.assertEqual(parse(Peekable(iter([Name("a"), Operator("."), Name("b")]))), Where(Var("a"), Var("b")))
834
835 def test_parse_nested_where(self) -> None:
836 self.assertEqual(
837 parse(Peekable(iter([Name("a"), Operator("."), Name("b"), Operator("."), Name("c")]))),
838 Where(Where(Var("a"), Var("b")), Var("c")),
839 )
840
841 def test_parse_assert(self) -> None:
842 self.assertEqual(parse(Peekable(iter([Name("a"), Operator("?"), Name("b")]))), Assert(Var("a"), Var("b")))
843
844 def test_parse_nested_assert(self) -> None:
845 self.assertEqual(
846 parse(Peekable(iter([Name("a"), Operator("?"), Name("b"), Operator("?"), Name("c")]))),
847 Assert(Assert(Var("a"), Var("b")), Var("c")),
848 )
849
850 def test_parse_mixed_assert_where(self) -> None:
851 self.assertEqual(
852 parse(Peekable(iter([Name("a"), Operator("?"), Name("b"), Operator("."), Name("c")]))),
853 Where(Assert(Var("a"), Var("b")), Var("c")),
854 )
855
856 def test_parse_hastype(self) -> None:
857 self.assertEqual(
858 parse(Peekable(iter([Name("a"), Operator(":"), Name("b")]))), Binop(BinopKind.HASTYPE, Var("a"), Var("b"))
859 )
860
861 def test_parse_hole(self) -> None:
862 self.assertEqual(parse(Peekable(iter([LeftParen(), RightParen()]))), Hole())
863
864 def test_parse_parenthesized_expression(self) -> None:
865 self.assertEqual(
866 parse(Peekable(iter([LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen()]))),
867 Binop(BinopKind.ADD, Int(1), Int(2)),
868 )
869
870 def test_parse_parenthesized_add_mul(self) -> None:
871 self.assertEqual(
872 parse(
873 Peekable(
874 iter([LeftParen(), IntLit(1), Operator("+"), IntLit(2), RightParen(), Operator("*"), IntLit(3)])
875 )
876 ),
877 Binop(BinopKind.MUL, Binop(BinopKind.ADD, Int(1), Int(2)), Int(3)),
878 )
879
880 def test_parse_pipe(self) -> None:
881 self.assertEqual(
882 parse(Peekable(iter([IntLit(1), Operator("|>"), Name("f")]))),
883 Apply(Var("f"), Int(1)),
884 )
885
886 def test_parse_nested_pipe(self) -> None:
887 self.assertEqual(
888 parse(Peekable(iter([IntLit(1), Operator("|>"), Name("f"), Operator("|>"), Name("g")]))),
889 Apply(Var("g"), Apply(Var("f"), Int(1))),
890 )
891
892 def test_parse_reverse_pipe(self) -> None:
893 self.assertEqual(
894 parse(Peekable(iter([Name("f"), Operator("<|"), IntLit(1)]))),
895 Apply(Var("f"), Int(1)),
896 )
897
898 def test_parse_nested_reverse_pipe(self) -> None:
899 self.assertEqual(
900 parse(Peekable(iter([Name("g"), Operator("<|"), Name("f"), Operator("<|"), IntLit(1)]))),
901 Apply(Var("g"), Apply(Var("f"), Int(1))),
902 )
903
904 def test_parse_empty_record(self) -> None:
905 self.assertEqual(parse(Peekable(iter([LeftBrace(), RightBrace()]))), Record({}))
906
907 def test_parse_record_single_field(self) -> None:
908 self.assertEqual(
909 parse(Peekable(iter([LeftBrace(), Name("a"), Operator("="), IntLit(4), RightBrace()]))),
910 Record({"a": Int(4)}),
911 )
912
913 def test_parse_record_with_expression(self) -> None:
914 self.assertEqual(
915 parse(
916 Peekable(
917 iter([LeftBrace(), Name("a"), Operator("="), IntLit(1), Operator("+"), IntLit(2), RightBrace()])
918 )
919 ),
920 Record({"a": Binop(BinopKind.ADD, Int(1), Int(2))}),
921 )
922
923 def test_parse_record_multiple_fields(self) -> None:
924 self.assertEqual(
925 parse(
926 Peekable(
927 iter(
928 [
929 LeftBrace(),
930 Name("a"),
931 Operator("="),
932 IntLit(4),
933 Operator(","),
934 Name("b"),
935 Operator("="),
936 StringLit("z"),
937 RightBrace(),
938 ]
939 )
940 )
941 ),
942 Record({"a": Int(4), "b": String("z")}),
943 )
944
945 def test_non_variable_in_assignment_raises_parse_error(self) -> None:
946 with self.assertRaises(ParseError) as ctx:
947 parse(Peekable(iter([IntLit(3), Operator("="), IntLit(4)])))
948 self.assertEqual(ctx.exception.args[0], "expected variable in assignment Int(value=3)")
949
950 def test_non_assign_in_record_constructor_raises_parse_error(self) -> None:
951 with self.assertRaises(ParseError) as ctx:
952 parse(Peekable(iter([LeftBrace(), IntLit(1), Operator(","), IntLit(2), RightBrace()])))
953 self.assertEqual(ctx.exception.args[0], "failed to parse variable assignment in record constructor")
954
955 def test_parse_right_eval_returns_binop(self) -> None:
956 self.assertEqual(
957 parse(Peekable(iter([Name("a"), Operator("!"), Name("b")]))),
958 Binop(BinopKind.RIGHT_EVAL, Var("a"), Var("b")),
959 )
960
961 def test_parse_right_eval_with_defs_returns_binop(self) -> None:
962 self.assertEqual(
963 parse(Peekable(iter([Name("a"), Operator("!"), Name("b"), Operator("."), Name("c")]))),
964 Binop(BinopKind.RIGHT_EVAL, Var("a"), Where(Var("b"), Var("c"))),
965 )
966
967 def test_parse_match_no_cases_raises_parse_error(self) -> None:
968 with self.assertRaises(ParseError) as ctx:
969 parse(Peekable(iter([Operator("|")])))
970 self.assertEqual(ctx.exception.args[0], "unexpected end of input")
971
972 def test_parse_match_one_case(self) -> None:
973 self.assertEqual(
974 parse(Peekable(iter([Operator("|"), IntLit(1), Operator("->"), IntLit(2)]))),
975 MatchFunction([MatchCase(Int(1), Int(2))]),
976 )
977
978 def test_parse_match_two_cases(self) -> None:
979 self.assertEqual(
980 parse(
981 Peekable(
982 iter(
983 [
984 Operator("|"),
985 IntLit(1),
986 Operator("->"),
987 IntLit(2),
988 Operator("|"),
989 IntLit(2),
990 Operator("->"),
991 IntLit(3),
992 ]
993 )
994 )
995 ),
996 MatchFunction(
997 [
998 MatchCase(Int(1), Int(2)),
999 MatchCase(Int(2), Int(3)),
1000 ]
1001 ),
1002 )
1003
1004 def test_parse_compose(self) -> None:
1005 gensym_reset()
1006 self.assertEqual(
1007 parse(Peekable(iter([Name("f"), Operator(">>"), Name("g")]))),
1008 Function(Var("$v0"), Apply(Var("g"), Apply(Var("f"), Var("$v0")))),
1009 )
1010
1011 def test_parse_compose_reverse(self) -> None:
1012 gensym_reset()
1013 self.assertEqual(
1014 parse(Peekable(iter([Name("f"), Operator("<<"), Name("g")]))),
1015 Function(Var("$v0"), Apply(Var("f"), Apply(Var("g"), Var("$v0")))),
1016 )
1017
1018 def test_parse_double_compose(self) -> None:
1019 gensym_reset()
1020 self.assertEqual(
1021 parse(Peekable(iter([Name("f"), Operator("<<"), Name("g"), Operator("<<"), Name("h")]))),
1022 Function(
1023 Var("$v1"),
1024 Apply(Var("f"), Apply(Function(Var("$v0"), Apply(Var("g"), Apply(Var("h"), Var("$v0")))), Var("$v1"))),
1025 ),
1026 )
1027
1028 def test_boolean_and_binds_tighter_than_or(self) -> None:
1029 self.assertEqual(
1030 parse(Peekable(iter([Name("x"), Operator("||"), Name("y"), Operator("&&"), Name("z")]))),
1031 Binop(BinopKind.BOOL_OR, Var("x"), Binop(BinopKind.BOOL_AND, Var("y"), Var("z"))),
1032 )
1033
1034 def test_parse_list_spread(self) -> None:
1035 self.assertEqual(
1036 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), Operator("..."), RightBracket()]))),
1037 List([Int(1), Spread()]),
1038 )
1039
1040 @unittest.skip("TODO(max): Raise if ...x is used with non-name")
1041 def test_parse_list_with_non_name_expr_after_spread_raises_parse_error(self) -> None:
1042 with self.assertRaisesRegex(ParseError, re.escape("unexpected token IntLit(lineno=-1, value=1)")):
1043 parse(Peekable(iter([LeftBracket(), IntLit(1), Operator(","), Operator("..."), IntLit(2), RightBracket()])))
1044
1045 def test_parse_list_with_named_spread(self) -> None:
1046 self.assertEqual(
1047 parse(
1048 Peekable(
1049 iter(
1050 [
1051 LeftBracket(),
1052 IntLit(1),
1053 Operator(","),
1054 Operator("..."),
1055 Name("rest"),
1056 RightBracket(),
1057 ]
1058 )
1059 )
1060 ),
1061 List([Int(1), Spread("rest")]),
1062 )
1063
1064 def test_parse_list_spread_beginning_raises_parse_error(self) -> None:
1065 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")):
1066 parse(Peekable(iter([LeftBracket(), Operator("..."), Operator(","), IntLit(1), RightBracket()])))
1067
1068 def test_parse_list_named_spread_beginning_raises_parse_error(self) -> None:
1069 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")):
1070 parse(
1071 Peekable(iter([LeftBracket(), Operator("..."), Name("rest"), Operator(","), IntLit(1), RightBracket()]))
1072 )
1073
1074 def test_parse_list_spread_middle_raises_parse_error(self) -> None:
1075 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")):
1076 parse(
1077 Peekable(
1078 iter(
1079 [
1080 LeftBracket(),
1081 IntLit(1),
1082 Operator(","),
1083 Operator("..."),
1084 Operator(","),
1085 IntLit(1),
1086 RightBracket(),
1087 ]
1088 )
1089 )
1090 )
1091
1092 def test_parse_list_named_spread_middle_raises_parse_error(self) -> None:
1093 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of list match")):
1094 parse(
1095 Peekable(
1096 iter(
1097 [
1098 LeftBracket(),
1099 IntLit(1),
1100 Operator(","),
1101 Operator("..."),
1102 Name("rest"),
1103 Operator(","),
1104 IntLit(1),
1105 RightBracket(),
1106 ]
1107 )
1108 )
1109 )
1110
1111 def test_parse_record_spread(self) -> None:
1112 self.assertEqual(
1113 parse(
1114 Peekable(
1115 iter(
1116 [LeftBrace(), Name("x"), Operator("="), IntLit(1), Operator(","), Operator("..."), RightBrace()]
1117 )
1118 )
1119 ),
1120 Record({"x": Int(1), "...": Spread()}),
1121 )
1122
1123 def test_parse_record_spread_beginning_raises_parse_error(self) -> None:
1124 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of record match")):
1125 parse(
1126 Peekable(
1127 iter(
1128 [LeftBrace(), Operator("..."), Operator(","), Name("x"), Operator("="), IntLit(1), RightBrace()]
1129 )
1130 )
1131 )
1132
1133 def test_parse_record_spread_middle_raises_parse_error(self) -> None:
1134 with self.assertRaisesRegex(ParseError, re.escape("spread must come at end of record match")):
1135 parse(
1136 Peekable(
1137 iter(
1138 [
1139 LeftBrace(),
1140 Name("x"),
1141 Operator("="),
1142 IntLit(1),
1143 Operator(","),
1144 Operator("..."),
1145 Operator(","),
1146 Name("y"),
1147 Operator("="),
1148 IntLit(2),
1149 RightBrace(),
1150 ]
1151 )
1152 )
1153 )
1154
1155 def test_parse_record_with_only_comma_raises_parse_error(self) -> None:
1156 with self.assertRaises(UnexpectedTokenError) as parse_error:
1157 parse(Peekable(iter([LeftBrace(), Operator(","), RightBrace()])))
1158
1159 self.assertEqual(parse_error.exception.unexpected_token, Operator(","))
1160
1161 def test_parse_record_with_two_commas_raises_parse_error(self) -> None:
1162 with self.assertRaises(UnexpectedTokenError) as parse_error:
1163 parse(Peekable(iter([LeftBrace(), Operator(","), Operator(","), RightBrace()])))
1164
1165 self.assertEqual(parse_error.exception.unexpected_token, Operator(","))
1166
1167 def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None:
1168 with self.assertRaises(UnexpectedTokenError) as parse_error:
1169 parse(Peekable(iter([LeftBrace(), Name("x"), Operator("="), IntLit(1), Operator(","), RightBrace()])))
1170
1171 self.assertEqual(parse_error.exception.unexpected_token, RightBrace())
1172
1173 def test_parse_variant_returns_variant(self) -> None:
1174 self.assertEqual(parse(Peekable(iter([Hash(), Name("abc"), IntLit(1)]))), Variant("abc", Int(1)))
1175
1176 def test_parse_hash_raises_unexpected_eof_error(self) -> None:
1177 tokens = Peekable(iter([Hash()]))
1178 with self.assertRaises(UnexpectedEOFError):
1179 parse(tokens)
1180
1181 def test_parse_variant_non_name_raises_parse_error(self) -> None:
1182 with self.assertRaises(UnexpectedTokenError) as parse_error:
1183 parse(Peekable(iter([Hash(), IntLit(1)])))
1184
1185 self.assertEqual(parse_error.exception.unexpected_token, IntLit(1))
1186
1187 def test_parse_variant_eof_raises_unexpected_eof_error(self) -> None:
1188 with self.assertRaises(UnexpectedEOFError):
1189 parse(Peekable(iter([Hash()])))
1190
1191 def test_match_with_variant(self) -> None:
1192 ast = parse(tokenize("| #true () -> 123"))
1193 self.assertEqual(ast, MatchFunction([MatchCase(TRUE, Int(123))]))
1194
1195 def test_binary_and_with_variant_args(self) -> None:
1196 ast = parse(tokenize("#true() && #false()"))
1197 self.assertEqual(ast, Binop(BinopKind.BOOL_AND, TRUE, FALSE))
1198
1199 def test_apply_with_variant_args(self) -> None:
1200 ast = parse(tokenize("f #true() #false()"))
1201 self.assertEqual(ast, Apply(Apply(Var("f"), TRUE), FALSE))
1202
1203 def test_parse_int_preserves_source_extent(self) -> None:
1204 source_extent = SourceExtent(
1205 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1206 )
1207 int_lit = IntLit(1).with_source(source_extent)
1208 self.assertEqual(parse(Peekable(iter([int_lit]))).source_extent, source_extent)
1209
1210 def test_parse_float_preserves_source_extent(self) -> None:
1211 source_extent = SourceExtent(
1212 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1213 )
1214 float_lit = FloatLit(3.2).with_source(source_extent)
1215 self.assertEqual(parse(Peekable(iter([float_lit]))).source_extent, source_extent)
1216
1217 def test_parse_string_preserves_source_extent(self) -> None:
1218 source_extent = SourceExtent(
1219 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=7, byteno=6)
1220 )
1221 string_lit = StringLit("Hello").with_source(source_extent)
1222 self.assertEqual(parse(Peekable(iter([string_lit]))).source_extent, source_extent)
1223
1224 def test_parse_bytes_preserves_source_extent(self) -> None:
1225 source_extent = SourceExtent(
1226 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=9, byteno=8)
1227 )
1228 bytes_lit = BytesLit("QUJD", 64).with_source(source_extent)
1229 self.assertEqual(parse(Peekable(iter([bytes_lit]))).source_extent, source_extent)
1230
1231 def test_parse_var_preserves_source_extent(self) -> None:
1232 source_extent = SourceExtent(
1233 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1234 )
1235 var = Name("x").with_source(source_extent)
1236 self.assertEqual(parse(Peekable(iter([var]))).source_extent, source_extent)
1237
1238 def test_parse_hole_preserves_source_extent(self) -> None:
1239 left_paren = LeftParen().with_source(
1240 SourceExtent(
1241 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1242 )
1243 )
1244 right_paren = RightParen().with_source(
1245 SourceExtent(
1246 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1247 )
1248 )
1249 hole_source_extent = SourceExtent(
1250 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=2, byteno=1)
1251 )
1252 self.assertEqual(parse(Peekable(iter([left_paren, right_paren]))).source_extent, hole_source_extent)
1253
1254 def test_parenthesized_expression_preserves_source_extent(self) -> None:
1255 left_paren = LeftParen().with_source(
1256 SourceExtent(
1257 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1258 )
1259 )
1260 int_lit = IntLit(1).with_source(
1261 SourceExtent(
1262 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1263 )
1264 )
1265 right_paren = RightParen().with_source(
1266 SourceExtent(
1267 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1268 )
1269 )
1270 parenthesized_int_lit_source_extent = SourceExtent(
1271 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1272 )
1273 self.assertEqual(
1274 parse(Peekable(iter([left_paren, int_lit, right_paren]))).source_extent, parenthesized_int_lit_source_extent
1275 )
1276
1277 def test_parse_spread_preserves_source_extent(self) -> None:
1278 ellipsis = Operator("...").with_source(
1279 SourceExtent(
1280 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1281 )
1282 )
1283 name = Name("x").with_source(
1284 SourceExtent(
1285 start=SourceLocation(lineno=1, colno=4, byteno=3), end=SourceLocation(lineno=1, colno=4, byteno=3)
1286 )
1287 )
1288 spread_source_extent = SourceExtent(
1289 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=4, byteno=3)
1290 )
1291 self.assertEqual(parse(Peekable(iter([ellipsis, name]))).source_extent, spread_source_extent)
1292
1293 def test_parse_binop_preserves_source_extent(self) -> None:
1294 first_addend = IntLit(1).with_source(
1295 SourceExtent(
1296 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1297 )
1298 )
1299 operator = Operator("+").with_source(
1300 SourceExtent(
1301 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1302 )
1303 )
1304 second_addend = IntLit(2).with_source(
1305 SourceExtent(
1306 start=SourceLocation(lineno=2, colno=5, byteno=4), end=SourceLocation(lineno=2, colno=5, byteno=4)
1307 )
1308 )
1309 binop_source_extent = SourceExtent(
1310 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=2, colno=5, byteno=4)
1311 )
1312 self.assertEqual(
1313 parse(Peekable(iter([first_addend, operator, second_addend]))).source_extent, binop_source_extent
1314 )
1315
1316 def test_parse_list_preserves_source_extent(self) -> None:
1317 left_bracket = LeftBracket().with_source(
1318 SourceExtent(
1319 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1320 )
1321 )
1322 one = IntLit(1).with_source(
1323 SourceExtent(
1324 start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1325 )
1326 )
1327 comma = Operator(",").with_source(
1328 SourceExtent(
1329 start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1330 )
1331 )
1332 two = IntLit(2).with_source(
1333 SourceExtent(
1334 start=SourceLocation(lineno=1, colno=5, byteno=4), end=SourceLocation(lineno=1, colno=5, byteno=4)
1335 )
1336 )
1337 right_bracket = RightBracket().with_source(
1338 SourceExtent(
1339 start=SourceLocation(lineno=1, colno=6, byteno=5), end=SourceLocation(lineno=1, colno=6, byteno=5)
1340 )
1341 )
1342 list_source_extent = SourceExtent(
1343 start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=6, byteno=5)
1344 )
1345 self.assertEqual(
1346 parse(Peekable(iter([left_bracket, one, comma, two, right_bracket]))).source_extent, list_source_extent
1347 )
1348
1349
class MatchTests(unittest.TestCase):
    # Most tests call match(value, pattern=...) and compare the result with
    # either None or a dict of bindings. A few parser tests for pipes and
    # match functions are also kept in this class.

    def test_match_hole_with_non_hole_returns_none(self) -> None:
        outcome = match(Int(1), pattern=Hole())
        self.assertEqual(outcome, None)

    def test_match_hole_with_hole_returns_empty_dict(self) -> None:
        outcome = match(Hole(), pattern=Hole())
        self.assertEqual(outcome, {})

    def test_match_with_equal_ints_returns_empty_dict(self) -> None:
        outcome = match(Int(1), pattern=Int(1))
        self.assertEqual(outcome, {})

    def test_match_with_inequal_ints_returns_none(self) -> None:
        outcome = match(Int(2), pattern=Int(1))
        self.assertEqual(outcome, None)

    def test_match_int_with_non_int_returns_none(self) -> None:
        outcome = match(String("abc"), pattern=Int(1))
        self.assertEqual(outcome, None)

    def test_match_with_equal_floats_raises_match_error(self) -> None:
        # A Float pattern raises even when the value is an equal Float.
        expected_pattern = re.escape("pattern matching is not supported for Floats")
        with self.assertRaisesRegex(MatchError, expected_pattern):
            match(Float(1), pattern=Float(1))

    def test_match_with_inequal_floats_raises_match_error(self) -> None:
        expected_pattern = re.escape("pattern matching is not supported for Floats")
        with self.assertRaisesRegex(MatchError, expected_pattern):
            match(Float(2), pattern=Float(1))

    def test_match_float_with_non_float_raises_match_error(self) -> None:
        # The Float pattern raises even though the value is a String.
        expected_pattern = re.escape("pattern matching is not supported for Floats")
        with self.assertRaisesRegex(MatchError, expected_pattern):
            match(String("abc"), pattern=Float(1))

    def test_match_with_equal_strings_returns_empty_dict(self) -> None:
        outcome = match(String("a"), pattern=String("a"))
        self.assertEqual(outcome, {})

    def test_match_with_inequal_strings_returns_none(self) -> None:
        outcome = match(String("b"), pattern=String("a"))
        self.assertEqual(outcome, None)

    def test_match_string_with_non_string_returns_none(self) -> None:
        outcome = match(Int(1), pattern=String("abc"))
        self.assertEqual(outcome, None)

    def test_match_var_returns_dict_with_var_name(self) -> None:
        # A Var pattern binds its name to the matched value.
        outcome = match(String("abc"), pattern=Var("a"))
        self.assertEqual(outcome, {"a": String("abc")})

    def test_match_record_with_non_record_returns_none(self) -> None:
        record_pattern = Record({"x": Var("x"), "y": Var("y")})
        outcome = match(Int(2), pattern=record_pattern)
        self.assertEqual(outcome, None)

    def test_match_record_with_more_fields_in_pattern_returns_none(self) -> None:
        value = Record({"x": Int(1), "y": Int(2)})
        record_pattern = Record({"x": Var("x"), "y": Var("y"), "z": Var("z")})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, None)

    def test_match_record_with_fewer_fields_in_pattern_returns_none(self) -> None:
        value = Record({"x": Int(1), "y": Int(2)})
        record_pattern = Record({"x": Var("x")})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, None)

    def test_match_record_with_vars_returns_dict_with_keys(self) -> None:
        value = Record({"x": Int(1), "y": Int(2)})
        record_pattern = Record({"x": Var("x"), "y": Var("y")})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, {"x": Int(1), "y": Int(2)})

    def test_match_record_with_matching_const_returns_dict_with_other_keys(self) -> None:
        # TODO(max): Should this be the case? I feel like we should return all
        # the keys.
        value = Record({"x": Int(1), "y": Int(2)})
        record_pattern = Record({"x": Int(1), "y": Var("y")})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, {"y": Int(2)})

    def test_match_record_with_non_matching_const_returns_none(self) -> None:
        value = Record({"x": Int(1), "y": Int(2)})
        record_pattern = Record({"x": Int(3), "y": Var("y")})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, None)

    def test_match_list_with_non_list_returns_none(self) -> None:
        list_pattern = List([Var("x"), Var("y")])
        outcome = match(Int(2), pattern=list_pattern)
        self.assertEqual(outcome, None)

    def test_match_list_with_more_fields_in_pattern_returns_none(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Var("x"), Var("y"), Var("z")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, None)

    def test_match_list_with_fewer_fields_in_pattern_returns_none(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Var("x")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, None)

    def test_match_list_with_vars_returns_dict_with_keys(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Var("x"), Var("y")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, {"x": Int(1), "y": Int(2)})

    def test_match_list_with_matching_const_returns_dict_with_other_keys(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Int(1), Var("y")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, {"y": Int(2)})

    def test_match_list_with_non_matching_const_returns_none(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Int(3), Var("y")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, None)

    def test_parse_right_pipe(self) -> None:
        # The whole sum on the left of `|>` becomes the argument.
        parsed = parse(tokenize("3 + 4 |> $$quote"))
        argument = Binop(BinopKind.ADD, Int(3), Int(4))
        self.assertEqual(parsed, Apply(Var("$$quote"), argument))

    def test_parse_left_pipe(self) -> None:
        # The whole sum on the right of `<|` becomes the argument.
        parsed = parse(tokenize("$$quote <| 3 + 4"))
        argument = Binop(BinopKind.ADD, Int(3), Int(4))
        self.assertEqual(parsed, Apply(Var("$$quote"), argument))

    def test_parse_match_with_left_apply(self) -> None:
        text = """| a -> b <| c
                  | d -> e"""
        # First check the token stream on its own...
        expected_tokens = [
            Operator("|"),
            Name("a"),
            Operator("->"),
            Name("b"),
            Operator("<|"),
            Name("c"),
            Operator("|"),
            Name("d"),
            Operator("->"),
            Name("e"),
        ]
        self.assertEqual(list(tokenize(text)), expected_tokens)
        # ...then tokenize again, since the first stream has been consumed.
        parsed = parse(tokenize(text))
        first_case = MatchCase(Var("a"), Apply(Var("b"), Var("c")))
        second_case = MatchCase(Var("d"), Var("e"))
        self.assertEqual(parsed, MatchFunction([first_case, second_case]))

    def test_parse_match_with_right_apply(self) -> None:
        text = """
| 1 -> 19
| a -> a |> (x -> x + 1)
"""
        parsed = parse(tokenize(text))
        increment = Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1)))
        first_case = MatchCase(Int(1), Int(19))
        second_case = MatchCase(Var("a"), Apply(increment, Var("a")))
        self.assertEqual(parsed, MatchFunction([first_case, second_case]))

    def test_match_list_with_spread_returns_empty_dict(self) -> None:
        # An anonymous spread matches the remaining elements without binding them.
        value = List([Int(1), Int(2), Int(3), Int(4), Int(5)])
        list_pattern = List([Int(1), Spread()])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, {})

    def test_match_list_with_named_spread_returns_name_bound_to_rest(self) -> None:
        value = List([Int(1), Int(2), Int(3), Int(4)])
        list_pattern = List([Var("a"), Int(2), Spread("rest")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, {"a": Int(1), "rest": List([Int(3), Int(4)])})

    def test_match_list_with_named_spread_returns_name_bound_to_empty_rest(self) -> None:
        value = List([Int(1), Int(2)])
        list_pattern = List([Var("a"), Int(2), Spread("rest")])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, {"a": Int(1), "rest": List([])})

    def test_match_list_with_mismatched_spread_returns_none(self) -> None:
        value = List([Int(1), Int(2), Int(3), Int(4), Int(5)])
        list_pattern = List([Int(1), Int(6), Spread()])
        outcome = match(value, pattern=list_pattern)
        self.assertEqual(outcome, None)

    def test_match_record_with_constant_and_spread_returns_empty_dict(self) -> None:
        value = Record({"a": Int(1), "b": Int(2), "c": Int(3)})
        record_pattern = Record({"a": Int(1), "...": Spread()})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, {})

    def test_match_record_with_var_and_spread_returns_match(self) -> None:
        value = Record({"a": Int(1), "b": Int(2), "c": Int(3)})
        record_pattern = Record({"a": Var("x"), "...": Spread()})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, {"x": Int(1)})

    def test_match_record_with_mismatched_spread_returns_none(self) -> None:
        # The pattern names a key ("d") that the value does not have.
        value = Record({"a": Int(1), "b": Int(2), "c": Int(3)})
        record_pattern = Record({"d": Var("x"), "...": Spread()})
        outcome = match(value, pattern=record_pattern)
        self.assertEqual(outcome, None)

    def test_match_variant_with_equal_tag_returns_empty_dict(self) -> None:
        outcome = match(Variant("abc", Hole()), pattern=Variant("abc", Hole()))
        self.assertEqual(outcome, {})

    def test_match_variant_with_inequal_tag_returns_none(self) -> None:
        outcome = match(Variant("def", Hole()), pattern=Variant("abc", Hole()))
        self.assertEqual(outcome, None)

    def test_match_variant_matches_value(self) -> None:
        # Same tag is not enough: the payload has to match as well.
        value = Variant("abc", Int(123))
        self.assertEqual(match(value, pattern=Variant("abc", Hole())), None)
        same_value = Variant("abc", Int(123))
        self.assertEqual(match(same_value, pattern=Variant("abc", Int(123))), {})

    def test_match_variant_with_different_type_returns_none(self) -> None:
        outcome = match(Int(123), pattern=Variant("abc", Hole()))
        self.assertEqual(outcome, None)
1634
1635class EvalTests(unittest.TestCase):
1636 def test_eval_int_returns_int(self) -> None:
1637 exp = Int(5)
1638 self.assertEqual(eval_exp({}, exp), Int(5))
1639
1640 def test_eval_float_returns_float(self) -> None:
1641 exp = Float(3.14)
1642 self.assertEqual(eval_exp({}, exp), Float(3.14))
1643
1644 def test_eval_str_returns_str(self) -> None:
1645 exp = String("xyz")
1646 self.assertEqual(eval_exp({}, exp), String("xyz"))
1647
1648 def test_eval_bytes_returns_bytes(self) -> None:
1649 exp = Bytes(b"xyz")
1650 self.assertEqual(eval_exp({}, exp), Bytes(b"xyz"))
1651
1652 def test_eval_with_non_existent_var_raises_name_error(self) -> None:
1653 exp = Var("no")
1654 with self.assertRaises(NameError) as ctx:
1655 eval_exp({}, exp)
1656 self.assertEqual(ctx.exception.args[0], "name 'no' is not defined")
1657
1658 def test_eval_with_bound_var_returns_value(self) -> None:
1659 exp = Var("yes")
1660 env = {"yes": Int(123)}
1661 self.assertEqual(eval_exp(env, exp), Int(123))
1662
1663 def test_eval_with_binop_add_returns_sum(self) -> None:
1664 exp = Binop(BinopKind.ADD, Int(1), Int(2))
1665 self.assertEqual(eval_exp({}, exp), Int(3))
1666
1667 def test_eval_with_nested_binop(self) -> None:
1668 exp = Binop(BinopKind.ADD, Binop(BinopKind.ADD, Int(1), Int(2)), Int(3))
1669 self.assertEqual(eval_exp({}, exp), Int(6))
1670
1671 def test_eval_with_binop_add_with_int_string_raises_type_error(self) -> None:
1672 exp = Binop(BinopKind.ADD, Int(1), String("hello"))
1673 with self.assertRaises(TypeError) as ctx:
1674 eval_exp({}, exp)
1675 self.assertEqual(ctx.exception.args[0], "expected Int or Float, got String")
1676
1677 def test_eval_with_binop_sub(self) -> None:
1678 exp = Binop(BinopKind.SUB, Int(1), Int(2))
1679 self.assertEqual(eval_exp({}, exp), Int(-1))
1680
1681 def test_eval_with_binop_mul(self) -> None:
1682 exp = Binop(BinopKind.MUL, Int(2), Int(3))
1683 self.assertEqual(eval_exp({}, exp), Int(6))
1684
1685 def test_eval_with_binop_div(self) -> None:
1686 exp = Binop(BinopKind.DIV, Int(3), Int(10))
1687 self.assertEqual(eval_exp({}, exp), Float(0.3))
1688
1689 def test_eval_with_binop_floor_div(self) -> None:
1690 exp = Binop(BinopKind.FLOOR_DIV, Int(2), Int(3))
1691 self.assertEqual(eval_exp({}, exp), Int(0))
1692
1693 def test_eval_with_binop_exp(self) -> None:
1694 exp = Binop(BinopKind.EXP, Int(2), Int(3))
1695 self.assertEqual(eval_exp({}, exp), Int(8))
1696
1697 def test_eval_with_binop_mod(self) -> None:
1698 exp = Binop(BinopKind.MOD, Int(10), Int(4))
1699 self.assertEqual(eval_exp({}, exp), Int(2))
1700
1701 def test_eval_with_binop_equal_with_equal_returns_true(self) -> None:
1702 exp = Binop(BinopKind.EQUAL, Int(1), Int(1))
1703 self.assertEqual(eval_exp({}, exp), TRUE)
1704
1705 def test_eval_with_binop_equal_with_inequal_returns_false(self) -> None:
1706 exp = Binop(BinopKind.EQUAL, Int(1), Int(2))
1707 self.assertEqual(eval_exp({}, exp), FALSE)
1708
1709 def test_eval_with_binop_not_equal_with_equal_returns_false(self) -> None:
1710 exp = Binop(BinopKind.NOT_EQUAL, Int(1), Int(1))
1711 self.assertEqual(eval_exp({}, exp), FALSE)
1712
1713 def test_eval_with_binop_not_equal_with_inequal_returns_true(self) -> None:
1714 exp = Binop(BinopKind.NOT_EQUAL, Int(1), Int(2))
1715 self.assertEqual(eval_exp({}, exp), TRUE)
1716
1717 def test_eval_with_binop_concat_with_strings_returns_string(self) -> None:
1718 exp = Binop(BinopKind.STRING_CONCAT, String("hello"), String(" world"))
1719 self.assertEqual(eval_exp({}, exp), String("hello world"))
1720
1721 def test_eval_with_binop_concat_with_int_string_raises_type_error(self) -> None:
1722 exp = Binop(BinopKind.STRING_CONCAT, Int(123), String(" world"))
1723 with self.assertRaises(TypeError) as ctx:
1724 eval_exp({}, exp)
1725 self.assertEqual(ctx.exception.args[0], "expected String, got Int")
1726
1727 def test_eval_with_binop_concat_with_string_int_raises_type_error(self) -> None:
1728 exp = Binop(BinopKind.STRING_CONCAT, String(" world"), Int(123))
1729 with self.assertRaises(TypeError) as ctx:
1730 eval_exp({}, exp)
1731 self.assertEqual(ctx.exception.args[0], "expected String, got Int")
1732
1733 def test_eval_with_binop_cons_with_int_list_returns_list(self) -> None:
1734 exp = Binop(BinopKind.LIST_CONS, Int(1), List([Int(2), Int(3)]))
1735 self.assertEqual(eval_exp({}, exp), List([Int(1), Int(2), Int(3)]))
1736
1737 def test_eval_with_binop_cons_with_list_list_returns_nested_list(self) -> None:
1738 exp = Binop(BinopKind.LIST_CONS, List([]), List([]))
1739 self.assertEqual(eval_exp({}, exp), List([List([])]))
1740
1741 def test_eval_with_binop_cons_with_list_int_raises_type_error(self) -> None:
1742 exp = Binop(BinopKind.LIST_CONS, List([]), Int(123))
1743 with self.assertRaises(TypeError) as ctx:
1744 eval_exp({}, exp)
1745 self.assertEqual(ctx.exception.args[0], "expected List, got Int")
1746
1747 def test_eval_with_list_append(self) -> None:
1748 exp = Binop(BinopKind.LIST_APPEND, List([Int(1), Int(2)]), Int(3))
1749 self.assertEqual(eval_exp({}, exp), List([Int(1), Int(2), Int(3)]))
1750
1751 def test_eval_with_list_evaluates_elements(self) -> None:
1752 exp = List(
1753 [
1754 Binop(BinopKind.ADD, Int(1), Int(2)),
1755 Binop(BinopKind.ADD, Int(3), Int(4)),
1756 ]
1757 )
1758 self.assertEqual(eval_exp({}, exp), List([Int(3), Int(7)]))
1759
1760 def test_eval_with_function_returns_closure_with_improved_env(self) -> None:
1761 exp = Function(Var("x"), Var("x"))
1762 self.assertEqual(eval_exp({"a": Int(1), "b": Int(2)}, exp), Closure({}, exp))
1763
1764 def test_eval_with_match_function_returns_closure_with_improved_env(self) -> None:
1765 exp = MatchFunction([])
1766 self.assertEqual(eval_exp({"a": Int(1), "b": Int(2)}, exp), Closure({}, exp))
1767
1768 def test_eval_assign_returns_env_object(self) -> None:
1769 exp = Assign(Var("a"), Int(1))
1770 env: Env = {}
1771 result = eval_exp(env, exp)
1772 self.assertEqual(result, EnvObject({"a": Int(1)}))
1773
1774 def test_eval_assign_function_returns_closure_without_function_in_env(self) -> None:
1775 exp = Assign(Var("a"), Function(Var("x"), Var("x")))
1776 result = eval_exp({}, exp)
1777 assert isinstance(result, EnvObject)
1778 closure = result.env["a"]
1779 self.assertIsInstance(closure, Closure)
1780 self.assertEqual(closure, Closure({}, Function(Var("x"), Var("x"))))
1781
1782 def test_eval_assign_function_returns_closure_with_function_in_env(self) -> None:
1783 exp = Assign(Var("a"), Function(Var("x"), Var("a")))
1784 result = eval_exp({}, exp)
1785 assert isinstance(result, EnvObject)
1786 closure = result.env["a"]
1787 self.assertIsInstance(closure, Closure)
1788 self.assertEqual(closure, Closure({"a": closure}, Function(Var("x"), Var("a"))))
1789
1790 def test_eval_assign_does_not_modify_env(self) -> None:
1791 exp = Assign(Var("a"), Int(1))
1792 env: Env = {}
1793 eval_exp(env, exp)
1794 self.assertEqual(env, {})
1795
1796 def test_eval_where_evaluates_in_order(self) -> None:
1797 exp = Where(Binop(BinopKind.ADD, Var("a"), Int(2)), Assign(Var("a"), Int(1)))
1798 env: Env = {}
1799 self.assertEqual(eval_exp(env, exp), Int(3))
1800 self.assertEqual(env, {})
1801
1802 def test_eval_nested_where(self) -> None:
1803 exp = Where(
1804 Where(
1805 Binop(BinopKind.ADD, Var("a"), Var("b")),
1806 Assign(Var("a"), Int(1)),
1807 ),
1808 Assign(Var("b"), Int(2)),
1809 )
1810 env: Env = {}
1811 self.assertEqual(eval_exp(env, exp), Int(3))
1812 self.assertEqual(env, {})
1813
1814 def test_eval_assert_with_truthy_cond_returns_value(self) -> None:
1815 exp = Assert(Int(123), TRUE)
1816 self.assertEqual(eval_exp({}, exp), Int(123))
1817
1818 def test_eval_assert_with_falsey_cond_raises_assertion_error(self) -> None:
1819 exp = Assert(Int(123), FALSE)
1820 with self.assertRaisesRegex(AssertionError, re.escape("condition #false () failed")):
1821 eval_exp({}, exp)
1822
1823 def test_eval_nested_assert(self) -> None:
1824 exp = Assert(Assert(Int(123), TRUE), TRUE)
1825 self.assertEqual(eval_exp({}, exp), Int(123))
1826
1827 def test_eval_hole(self) -> None:
1828 exp = Hole()
1829 self.assertEqual(eval_exp({}, exp), Hole())
1830
1831 def test_eval_function_application_one_arg(self) -> None:
1832 exp = Apply(Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))), Int(2))
1833 self.assertEqual(eval_exp({}, exp), Int(3))
1834
1835 def test_eval_function_application_two_args(self) -> None:
1836 exp = Apply(
1837 Apply(Function(Var("a"), Function(Var("b"), Binop(BinopKind.ADD, Var("a"), Var("b")))), Int(3)),
1838 Int(2),
1839 )
1840 self.assertEqual(eval_exp({}, exp), Int(5))
1841
1842 def test_eval_function_returns_closure_with_captured_env(self) -> None:
1843 exp = Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y")))
1844 res = eval_exp({"y": Int(5)}, exp)
1845 self.assertIsInstance(res, Closure)
1846 assert isinstance(res, Closure) # for mypy
1847 self.assertEqual(res.env, {"y": Int(5)})
1848
1849 def test_eval_function_capture_env(self) -> None:
1850 exp = Apply(Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y"))), Int(2))
1851 self.assertEqual(eval_exp({"y": Int(5)}, exp), Int(7))
1852
1853 def test_eval_non_function_raises_type_error(self) -> None:
1854 exp = Apply(Int(3), Int(4))
1855 with self.assertRaisesRegex(TypeError, re.escape("attempted to apply a non-closure of type Int")):
1856 eval_exp({}, exp)
1857
1858 def test_eval_access_from_invalid_object_raises_type_error(self) -> None:
1859 exp = Access(Int(4), String("x"))
1860 with self.assertRaisesRegex(TypeError, re.escape("attempted to access from type Int")):
1861 eval_exp({}, exp)
1862
1863 def test_eval_record_evaluates_value_expressions(self) -> None:
1864 exp = Record({"a": Binop(BinopKind.ADD, Int(1), Int(2))})
1865 self.assertEqual(eval_exp({}, exp), Record({"a": Int(3)}))
1866
1867 def test_eval_record_access_with_invalid_accessor_raises_type_error(self) -> None:
1868 exp = Access(Record({"a": Int(4)}), Int(0))
1869 with self.assertRaisesRegex(
1870 TypeError, re.escape("cannot access record field using Int, expected a field name")
1871 ):
1872 eval_exp({}, exp)
1873
1874 def test_eval_record_access_with_unknown_accessor_raises_name_error(self) -> None:
1875 exp = Access(Record({"a": Int(4)}), Var("b"))
1876 with self.assertRaisesRegex(NameError, re.escape("no assignment to b found in record")):
1877 eval_exp({}, exp)
1878
1879 def test_eval_record_access(self) -> None:
1880 exp = Access(Record({"a": Int(4)}), Var("a"))
1881 self.assertEqual(eval_exp({}, exp), Int(4))
1882
1883 def test_eval_list_access_with_invalid_accessor_raises_type_error(self) -> None:
1884 exp = Access(List([Int(4)]), String("hello"))
1885 with self.assertRaisesRegex(TypeError, re.escape("cannot index into list using type String, expected integer")):
1886 eval_exp({}, exp)
1887
1888 def test_eval_list_access_with_out_of_bounds_accessor_raises_value_error(self) -> None:
1889 exp = Access(List([Int(1), Int(2), Int(3)]), Int(4))
1890 with self.assertRaisesRegex(ValueError, re.escape("index 4 out of bounds for list")):
1891 eval_exp({}, exp)
1892
1893 def test_eval_list_access(self) -> None:
1894 exp = Access(List([String("a"), String("b"), String("c")]), Int(2))
1895 self.assertEqual(eval_exp({}, exp), String("c"))
1896
1897 def test_right_eval_evaluates_right_hand_side(self) -> None:
1898 exp = Binop(BinopKind.RIGHT_EVAL, Int(1), Int(2))
1899 self.assertEqual(eval_exp({}, exp), Int(2))
1900
1901 def test_match_no_cases_raises_match_error(self) -> None:
1902 exp = Apply(MatchFunction([]), Int(1))
1903 with self.assertRaisesRegex(MatchError, "no matching cases"):
1904 eval_exp({}, exp)
1905
1906 def test_match_int_with_equal_int_matches(self) -> None:
1907 exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(1))
1908 self.assertEqual(eval_exp({}, exp), Int(2))
1909
1910 def test_match_int_with_inequal_int_raises_match_error(self) -> None:
1911 exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(3))
1912 with self.assertRaisesRegex(MatchError, "no matching cases"):
1913 eval_exp({}, exp)
1914
1915 def test_match_string_with_equal_string_matches(self) -> None:
1916 exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("a"))
1917 self.assertEqual(eval_exp({}, exp), String("b"))
1918
1919 def test_match_string_with_inequal_string_raises_match_error(self) -> None:
1920 exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("c"))
1921 with self.assertRaisesRegex(MatchError, "no matching cases"):
1922 eval_exp({}, exp)
1923
1924 def test_match_falls_through_to_next(self) -> None:
1925 exp = Apply(
1926 MatchFunction([MatchCase(pattern=Int(3), body=Int(4)), MatchCase(pattern=Int(1), body=Int(2))]), Int(1)
1927 )
1928 self.assertEqual(eval_exp({}, exp), Int(2))
1929
1930 def test_eval_compose(self) -> None:
1931 gensym_reset()
1932 exp = parse(tokenize("(x -> x + 3) << (x -> x * 2)"))
1933 env = {"a": Int(1)}
1934 expected = Closure(
1935 {},
1936 Function(
1937 Var("$v0"),
1938 Apply(
1939 Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(3))),
1940 Apply(Function(Var("x"), Binop(BinopKind.MUL, Var("x"), Int(2))), Var("$v0")),
1941 ),
1942 ),
1943 )
1944 self.assertEqual(eval_exp(env, exp), expected)
1945
1946 def test_eval_native_function_returns_function(self) -> None:
1947 exp = NativeFunction("times2", lambda x: Int(x.value * 2)) # type: ignore [attr-defined]
1948 self.assertIs(eval_exp({}, exp), exp)
1949
1950 def test_eval_apply_native_function_calls_function(self) -> None:
1951 exp = Apply(NativeFunction("times2", lambda x: Int(x.value * 2)), Int(3)) # type: ignore [attr-defined]
1952 self.assertEqual(eval_exp({}, exp), Int(6))
1953
1954 def test_eval_apply_quote_returns_ast(self) -> None:
1955 ast = Binop(BinopKind.ADD, Int(1), Int(2))
1956 exp = Apply(Var("$$quote"), ast)
1957 self.assertIs(eval_exp({}, exp), ast)
1958
1959 def test_eval_apply_closure_with_match_function_has_access_to_closure_vars(self) -> None:
1960 ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), Var("x"))])), Int(2))
1961 self.assertEqual(eval_exp({}, ast), Int(1))
1962
1963 def test_eval_less_returns_bool(self) -> None:
1964 ast = Binop(BinopKind.LESS, Int(3), Int(4))
1965 self.assertEqual(eval_exp({}, ast), TRUE)
1966
1967 def test_eval_less_on_non_bool_raises_type_error(self) -> None:
1968 ast = Binop(BinopKind.LESS, String("xyz"), Int(4))
1969 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
1970 eval_exp({}, ast)
1971
1972 def test_eval_less_equal_returns_bool(self) -> None:
1973 ast = Binop(BinopKind.LESS_EQUAL, Int(3), Int(4))
1974 self.assertEqual(eval_exp({}, ast), TRUE)
1975
1976 def test_eval_less_equal_on_non_bool_raises_type_error(self) -> None:
1977 ast = Binop(BinopKind.LESS_EQUAL, String("xyz"), Int(4))
1978 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
1979 eval_exp({}, ast)
1980
1981 def test_eval_greater_returns_bool(self) -> None:
1982 ast = Binop(BinopKind.GREATER, Int(3), Int(4))
1983 self.assertEqual(eval_exp({}, ast), FALSE)
1984
1985 def test_eval_greater_on_non_bool_raises_type_error(self) -> None:
1986 ast = Binop(BinopKind.GREATER, String("xyz"), Int(4))
1987 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
1988 eval_exp({}, ast)
1989
1990 def test_eval_greater_equal_returns_bool(self) -> None:
1991 ast = Binop(BinopKind.GREATER_EQUAL, Int(3), Int(4))
1992 self.assertEqual(eval_exp({}, ast), FALSE)
1993
1994 def test_eval_greater_equal_on_non_bool_raises_type_error(self) -> None:
1995 ast = Binop(BinopKind.GREATER_EQUAL, String("xyz"), Int(4))
1996 with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")):
1997 eval_exp({}, ast)
1998
1999 def test_boolean_and_evaluates_args(self) -> None:
2000 ast = Binop(BinopKind.BOOL_AND, TRUE, Var("a"))
2001 self.assertEqual(eval_exp({"a": FALSE}, ast), FALSE)
2002
2003 ast = Binop(BinopKind.BOOL_AND, Var("a"), FALSE)
2004 self.assertEqual(eval_exp({"a": TRUE}, ast), FALSE)
2005
2006 def test_boolean_or_evaluates_args(self) -> None:
2007 ast = Binop(BinopKind.BOOL_OR, FALSE, Var("a"))
2008 self.assertEqual(eval_exp({"a": TRUE}, ast), TRUE)
2009
2010 ast = Binop(BinopKind.BOOL_OR, Var("a"), TRUE)
2011 self.assertEqual(eval_exp({"a": FALSE}, ast), TRUE)
2012
2013 def test_boolean_and_short_circuit(self) -> None:
2014 def raise_func(message: Object) -> Object:
2015 if not isinstance(message, String):
2016 raise TypeError(f"raise_func expected String, but got {type(message).__name__}")
2017 raise RuntimeError(message)
2018
2019 error = NativeFunction("error", raise_func)
2020 apply = Apply(Var("error"), String("expected failure"))
2021 ast = Binop(BinopKind.BOOL_AND, FALSE, apply)
2022 self.assertEqual(eval_exp({"error": error}, ast), FALSE)
2023
2024 def test_boolean_or_short_circuit(self) -> None:
2025 def raise_func(message: Object) -> Object:
2026 if not isinstance(message, String):
2027 raise TypeError(f"raise_func expected String, but got {type(message).__name__}")
2028 raise RuntimeError(message)
2029
2030 error = NativeFunction("error", raise_func)
2031 apply = Apply(Var("error"), String("expected failure"))
2032 ast = Binop(BinopKind.BOOL_OR, TRUE, apply)
2033 self.assertEqual(eval_exp({"error": error}, ast), TRUE)
2034
2035 def test_boolean_and_on_int_raises_type_error(self) -> None:
2036 exp = Binop(BinopKind.BOOL_AND, Int(1), Int(2))
2037 with self.assertRaisesRegex(TypeError, re.escape("expected #true or #false, got Int")):
2038 eval_exp({}, exp)
2039
2040 def test_boolean_or_on_int_raises_type_error(self) -> None:
2041 exp = Binop(BinopKind.BOOL_OR, Int(1), Int(2))
2042 with self.assertRaisesRegex(TypeError, re.escape("expected #true or #false, got Int")):
2043 eval_exp({}, exp)
2044
2045 def test_eval_record_with_spread_fails(self) -> None:
2046 exp = Record({"x": Spread()})
2047 with self.assertRaisesRegex(RuntimeError, "cannot evaluate a spread"):
2048 eval_exp({}, exp)
2049
2050 def test_eval_variant_returns_variant(self) -> None:
2051 self.assertEqual(
2052 eval_exp(
2053 {},
2054 Variant("abc", Binop(BinopKind.ADD, Int(1), Int(2))),
2055 ),
2056 Variant("abc", Int(3)),
2057 )
2058
2059 def test_eval_float_and_float_addition_returns_float(self) -> None:
2060 self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Float(1.0), Float(2.0))), Float(3.0))
2061
2062 def test_eval_int_and_float_addition_returns_float(self) -> None:
2063 self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Int(1), Float(2.0))), Float(3.0))
2064
2065 def test_eval_int_and_float_division_returns_float(self) -> None:
2066 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Float(2.0))), Float(0.5))
2067
2068 def test_eval_float_and_int_division_returns_float(self) -> None:
2069 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Float(1.0), Int(2))), Float(0.5))
2070
2071 def test_eval_int_and_int_division_returns_float(self) -> None:
2072 self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Int(2))), Float(0.5))
2073
2074
class EndToEndTestsBase(unittest.TestCase):
    # Shared driver for tests that start from scrapscript source text.

    def _run(self, text: str, env: Optional[Env] = None, check: bool = False) -> Object:
        # Tokenize and parse `text`, optionally run type inference against
        # OP_ENV when `check` is set, then evaluate the AST and return the
        # result. When no `env` is supplied, a fresh boot_env() is used.
        program = parse(tokenize(text))
        if check:
            infer_type(program, OP_ENV)
        runtime_env = boot_env() if env is None else env
        return eval_exp(runtime_env, program)
2084
2085
class EndToEndTests(EndToEndTestsBase):
    # Source-text-in, value-out tests: each case hands scrapscript source to
    # `_run` (tokenize -> parse -> eval) and checks the resulting object or
    # the error that is raised.

    def test_int_returns_int(self) -> None:
        self.assertEqual(self._run("1"), Int(1))

    def test_float_returns_float(self) -> None:
        self.assertEqual(self._run("3.14"), Float(3.14))

    def test_bytes_returns_bytes(self) -> None:
        self.assertEqual(self._run("~~QUJD"), Bytes(b"ABC"))

    def test_bytes_base85_returns_bytes(self) -> None:
        self.assertEqual(self._run("~~85'K|(_"), Bytes(b"ABC"))

    def test_bytes_base64_returns_bytes(self) -> None:
        self.assertEqual(self._run("~~64'QUJD"), Bytes(b"ABC"))

    def test_bytes_base32_returns_bytes(self) -> None:
        self.assertEqual(self._run("~~32'IFBEG==="), Bytes(b"ABC"))

    def test_bytes_base16_returns_bytes(self) -> None:
        self.assertEqual(self._run("~~16'414243"), Bytes(b"ABC"))

    def test_int_add_returns_int(self) -> None:
        self.assertEqual(self._run("1 + 2"), Int(3))

    def test_int_sub_returns_int(self) -> None:
        self.assertEqual(self._run("1 - 2"), Int(-1))

    def test_string_concat_returns_string(self) -> None:
        self.assertEqual(self._run('"abc" ++ "def"'), String("abcdef"))

    def test_list_cons_returns_list(self) -> None:
        self.assertEqual(self._run("1 >+ [2,3]"), List([Int(1), Int(2), Int(3)]))

    def test_list_cons_nested_returns_list(self) -> None:
        self.assertEqual(self._run("1 >+ 2 >+ [3,4]"), List([Int(1), Int(2), Int(3), Int(4)]))

    def test_list_append_returns_list(self) -> None:
        self.assertEqual(self._run("[1,2] +< 3"), List([Int(1), Int(2), Int(3)]))

    def test_list_append_nested_returns_list(self) -> None:
        self.assertEqual(self._run("[1,2] +< 3 +< 4"), List([Int(1), Int(2), Int(3), Int(4)]))

    def test_empty_list(self) -> None:
        self.assertEqual(self._run("[ ]"), List([]))

    def test_empty_list_with_no_spaces(self) -> None:
        self.assertEqual(self._run("[]"), List([]))

    def test_list_of_ints(self) -> None:
        self.assertEqual(self._run("[ 1 , 2 ]"), List([Int(1), Int(2)]))

    def test_list_of_exprs(self) -> None:
        self.assertEqual(self._run("[ 1 + 2 , 3 + 4 ]"), List([Int(3), Int(7)]))

    def test_where(self) -> None:
        self.assertEqual(self._run("a + 2 . a = 1"), Int(3))

    def test_nested_where(self) -> None:
        self.assertEqual(self._run("a + b . a = 1 . b = 2"), Int(3))

    def test_assert_with_truthy_cond_returns_value(self) -> None:
        self.assertEqual(self._run("a + 1 ? a == 1 . a = 1"), Int(2))

    def test_assert_with_falsey_cond_raises_assertion_error(self) -> None:
        with self.assertRaisesRegex(AssertionError, "condition a == 2 failed"):
            self._run("a + 1 ? a == 2 . a = 1")

    def test_nested_assert(self) -> None:
        self.assertEqual(self._run("a + b ? a == 1 ? b == 2 . a = 1 . b = 2"), Int(3))

    def test_hole(self) -> None:
        self.assertEqual(self._run("()"), Hole())

    def test_bindings_behave_like_letstar(self) -> None:
        with self.assertRaises(NameError) as ctx:
            self._run("b . a = 1 . b = a")
        # Checked after the with-block: statements following the raising call
        # inside the block would never execute.
        self.assertEqual(ctx.exception.args[0], "name 'a' is not defined")

    def test_function_application_two_args(self) -> None:
        self.assertEqual(self._run("(a -> b -> a + b) 3 2"), Int(5))

    def test_function_create_list_correct_order(self) -> None:
        self.assertEqual(self._run("(a -> b -> [a, b]) 3 2"), List([Int(3), Int(2)]))

    def test_create_record(self) -> None:
        self.assertEqual(self._run("{a = 1 + 3}"), Record({"a": Int(4)}))

    def test_access_record(self) -> None:
        self.assertEqual(self._run('rec@b . rec = { a = 1, b = "x" }'), String("x"))

    def test_access_list(self) -> None:
        self.assertEqual(self._run("xs@1 . xs = [1, 2, 3]"), Int(2))

    def test_access_list_var(self) -> None:
        self.assertEqual(self._run("xs@y . y = 2 . xs = [1, 2, 3]"), Int(3))

    def test_access_list_expr(self) -> None:
        self.assertEqual(self._run("xs@(1+1) . xs = [1, 2, 3]"), Int(3))

    def test_access_list_closure_var(self) -> None:
        program = "list_at 1 [1,2,3] . list_at = idx -> ls -> ls@idx"
        self.assertEqual(self._run(program), Int(2))

    def test_functions_eval_arguments(self) -> None:
        self.assertEqual(self._run("(x -> x) c . c = 1"), Int(1))

    def test_non_var_function_arg_raises_parse_error(self) -> None:
        with self.assertRaises(RuntimeError) as ctx:
            self._run("1 -> a")
        # Checked after the with-block so the assertion actually runs.
        self.assertEqual(ctx.exception.args[0], "expected variable in function definition 1")

    def test_compose(self) -> None:
        self.assertEqual(self._run("((a -> a + 1) >> (b -> b * 2)) 3"), Int(8))

    def test_compose_does_not_expose_internal_x(self) -> None:
        with self.assertRaisesRegex(NameError, "name 'x' is not defined"):
            self._run("f 3 . f = ((y -> x) >> (z -> x))")

    def test_double_compose(self) -> None:
        self.assertEqual(self._run("((a -> a + 1) >> (x -> x) >> (b -> b * 2)) 3"), Int(8))

    def test_reverse_compose(self) -> None:
        self.assertEqual(self._run("((a -> a + 1) << (b -> b * 2)) 3"), Int(7))

    def test_simple_int_match(self) -> None:
        program = """
        inc 2
        . inc =
          | 1 -> 2
          | 2 -> 3
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_var_binds_var(self) -> None:
        program = """
        id 3
        . id =
          | x -> x
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_var_binds_first_arm(self) -> None:
        program = """
        id 3
        . id =
          | x -> x
          | y -> y * 2
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_function_can_close_over_variables(self) -> None:
        program = """
        f 1 2
        . f = a ->
          | b -> a + b
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_record_binds_var(self) -> None:
        program = """
        get_x rec
        . rec = { x = 3 }
        . get_x =
          | { x = x } -> x
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_record_binds_vars(self) -> None:
        program = """
        mult rec
        . rec = { x = 3, y = 4 }
        . mult =
          | { x = x, y = y } -> x * y
        """
        self.assertEqual(self._run(program), Int(12))

    def test_match_record_with_extra_fields_does_not_match(self) -> None:
        # The pattern names a field (`y`) that the record does not have.
        program = """
        mult rec
        . rec = { x = 3 }
        . mult =
          | { x = x, y = y } -> x * y
        """
        with self.assertRaises(MatchError):
            self._run(program)

    def test_match_record_with_constant(self) -> None:
        program = """
        mult rec
        . rec = { x = 4, y = 5 }
        . mult =
          | { x = 3, y = y } -> 1
          | { x = 4, y = y } -> 2
        """
        self.assertEqual(self._run(program), Int(2))

    def test_match_record_with_non_record_fails(self) -> None:
        program = """
        get_x 3
        . get_x =
          | { x = x } -> x
        """
        with self.assertRaises(MatchError):
            self._run(program)

    def test_match_record_doubly_binds_vars(self) -> None:
        program = """
        get_x rec
        . rec = { a = 3, b = 3 }
        . get_x =
          | { a = x, b = x } -> x
        """
        self.assertEqual(self._run(program), Int(3))

    def test_match_record_spread_binds_spread(self) -> None:
        remaining = self._run("(| { a=1, ...rest } -> rest) {a=1, b=2, c=3}")
        self.assertEqual(remaining, Record({"b": Int(2), "c": Int(3)}))

    def test_match_list_binds_vars(self) -> None:
        program = """
        mult xs
        . xs = [1, 2, 3, 4, 5]
        . mult =
          | [1, x, 3, y, 5] -> x * y
        """
        self.assertEqual(self._run(program), Int(8))

    def test_match_list_incorrect_length_does_not_match(self) -> None:
        program = """
        mult xs
        . xs = [1, 2, 3]
        . mult =
          | [1, 2] -> 1
          | [1, 2, 3, 4] -> 1
          | [1, 3] -> 1
        """
        with self.assertRaises(MatchError):
            self._run(program)

    def test_match_list_with_constant(self) -> None:
        program = """
        middle xs
        . xs = [4, 5, 6]
        . middle =
          | [1, x, 3] -> x
          | [4, x, 6] -> x
          | [7, x, 9] -> x
        """
        self.assertEqual(self._run(program), Int(5))

    def test_match_list_with_non_list_fails(self) -> None:
        program = """
        get_x 3
        . get_x =
          | [2, x] -> x
        """
        with self.assertRaises(MatchError):
            self._run(program)

    def test_match_list_doubly_binds_vars(self) -> None:
        program = """
        mult xs
        . xs = [1, 2, 3, 2, 1]
        . mult =
          | [1, x, 3, x, 1] -> x
        """
        self.assertEqual(self._run(program), Int(2))

    def test_match_list_spread_binds_spread(self) -> None:
        self.assertEqual(self._run("(| [x, ...xs] -> xs) [1, 2]"), List([Int(2)]))

    def test_pipe(self) -> None:
        self.assertEqual(self._run("1 |> (a -> a + 2)"), Int(3))

    def test_pipe_nested(self) -> None:
        self.assertEqual(self._run("1 |> (a -> a + 2) |> (b -> b * 2)"), Int(6))

    def test_reverse_pipe(self) -> None:
        self.assertEqual(self._run("(a -> a + 2) <| 1"), Int(3))

    def test_reverse_pipe_nested(self) -> None:
        self.assertEqual(self._run("(b -> b * 2) <| (a -> a + 2) <| 1"), Int(6))

    def test_function_can_reference_itself(self) -> None:
        program = """
        f 1
        . f = n -> f
        """
        # Run against an explicitly empty environment; the returned closure
        # is expected to hold a reference to itself under the name `f`.
        result = self._run(program, {})
        self.assertEqual(result, Closure({"f": result}, Function(Var("n"), Var("f"))))

    def test_function_can_call_itself(self) -> None:
        # `f n = f n` never terminates, which surfaces as a RecursionError.
        program = """
        f 1
        . f = n -> f n
        """
        with self.assertRaises(RecursionError):
            self._run(program)

    def test_match_function_can_call_itself(self) -> None:
        program = """
        fac 5
        . fac =
          | 0 -> 1
          | 1 -> 1
          | n -> n * fac (n - 1)
        """
        self.assertEqual(self._run(program), Int(120))

    def test_list_access_binds_tighter_than_append(self) -> None:
        appended = self._run("[1, 2, 3] +< xs@0 . xs = [4]")
        self.assertEqual(appended, List([Int(1), Int(2), Int(3), Int(4)]))

    def test_exponentiation(self) -> None:
        self.assertEqual(self._run("6 ^ 2"), Int(36))

    def test_modulus(self) -> None:
        self.assertEqual(self._run("11 % 3"), Int(2))

    def test_exp_binds_tighter_than_mul(self) -> None:
        self.assertEqual(self._run("5 * 2 ^ 3"), Int(40))

    def test_variant_true_returns_true(self) -> None:
        self.assertEqual(self._run("# true ()", {}), TRUE)

    def test_variant_false_returns_false(self) -> None:
        self.assertEqual(self._run("#false ()", {}), FALSE)

    def test_boolean_and_binds_tighter_than_or(self) -> None:
        # `boom` is unbound in the empty env, so this only passes if the
        # `&& boom` operand is never evaluated.
        self.assertEqual(self._run("#true () || #true () && boom", {}), TRUE)

    def test_compare_binds_tighter_than_boolean_and(self) -> None:
        self.assertEqual(self._run("1 < 2 && 2 < 1"), FALSE)

    def test_match_list_spread(self) -> None:
        program = """
        f [2, 4, 6]
        . f =
          | [] -> 0
          | [x, ...] -> x
          | c -> 1
        """
        self.assertEqual(self._run(program), Int(2))

    def test_match_list_named_spread(self) -> None:
        program = """
        tail [1,2,3]
        . tail =
          | [first, ...rest] -> rest
        """
        self.assertEqual(self._run(program), List([Int(2), Int(3)]))

    def test_match_record_spread(self) -> None:
        program = """
        f {x = 4, y = 5}
        . f =
          | {} -> 0
          | {x = a, ...} -> a
          | c -> 1
        """
        self.assertEqual(self._run(program), Int(4))

    def test_match_expr_as_boolean_variants(self) -> None:
        program = """
        say (1 < 2)
        . say =
          | #false () -> "oh no"
          | #true () -> "omg"
        """
        self.assertEqual(self._run(program), String("omg"))

    def test_match_variant_record(self) -> None:
        program = """
        f #add {x = 3, y = 4}
        . f =
          | # b () -> "foo"
          | #add {x = x, y = y} -> x + y
        """
        self.assertEqual(self._run(program), Int(7))

    def test_int_div_returns_float(self) -> None:
        program = "1 / 2 + 3"
        # Evaluation alone accepts the expression...
        self.assertEqual(self._run(program), Float(3.5))
        # ...but with check=True type inference runs first and rejects it.
        with self.assertRaisesRegex(InferenceError, "int and float"):
            self._run(program, check=True)

    def test_eval_count_bits_function_preserves_source_extents(self) -> None:
        # The expected extents below are line/column/byte positions inside
        # this exact text, so every leading space is spelled out explicitly
        # instead of depending on how a triple-quoted literal is indented.
        program = "\n".join(
            [
                "",
                " " * 8 + "count_bits = counts ->",
                " " * 10 + "| [1, ...bits] -> count_bits { zeros = counts@zeros, ones = 1 + counts@ones } bits",
                " " * 10 + "| [0, ...bits] -> count_bits { zeros = 1 + counts@zeros, ones = counts@ones } bits",
                " " * 10 + "| [] -> counts",
                "",
            ]
        )
        env_object = self._run(program)
        assert isinstance(env_object, EnvObject)

        count_bits_closure = env_object.env["count_bits"]
        assert isinstance(count_bits_closure, Closure)

        outer_function = count_bits_closure.func
        assert isinstance(outer_function, Function)

        inner_function = outer_function.body
        assert isinstance(inner_function, MatchFunction)

        expected_outer_extent = SourceExtent(
            start=SourceLocation(lineno=2, colno=22, byteno=22),
            end=SourceLocation(lineno=5, colno=24, byteno=241),
        )
        self.assertEqual(outer_function.source_extent, expected_outer_extent)

        # One expected extent per match arm, in source order.
        expected_case_extents = [
            SourceExtent(
                start=SourceLocation(lineno=3, colno=11, byteno=42),
                end=SourceLocation(lineno=3, colno=92, byteno=123),
            ),
            SourceExtent(
                start=SourceLocation(lineno=4, colno=11, byteno=135),
                end=SourceLocation(lineno=4, colno=92, byteno=216),
            ),
            SourceExtent(
                start=SourceLocation(lineno=5, colno=11, byteno=228),
                end=SourceLocation(lineno=5, colno=24, byteno=241),
            ),
        ]
        cases = [inner_function.cases[0], inner_function.cases[1], inner_function.cases[2]]
        # Compare arm by arm so a failure reports which arm is wrong and
        # shows both extents, rather than a bare "False is not true".
        for index, (case, expected_extent) in enumerate(zip(cases, expected_case_extents)):
            with self.subTest(case=index):
                self.assertEqual(case.source_extent, expected_extent)

    def test_eval_collatz_function_preserves_source_extents(self) -> None:
        # As above, the fixture's whitespace is part of what is being tested:
        # the `#false` arm sits directly under the `|` of the `#true` arm.
        program = "\n".join(
            [
                "",
                " " * 8 + "collatz = count ->",
                " " * 10 + "| 1 -> count",
                " " * 10 + "| n -> (n % 2 == 0) |> | #true () -> collatz (count + 1) (n // 2)",
                " " * 33 + "| #false () -> collatz (count + 1) (3 * n + 1)",
                "",
            ]
        )
        env_object = self._run(program)
        assert isinstance(env_object, EnvObject)

        collatz_closure = env_object.env["collatz"]
        assert isinstance(collatz_closure, Closure)

        outer_function = collatz_closure.func
        assert isinstance(outer_function, Function)

        inner_function = outer_function.body
        assert isinstance(inner_function, MatchFunction)

        # Body of the `n -> ...` arm: `(n % 2 == 0) |> <match function>`.
        apply_ast = inner_function.cases[1].body
        assert isinstance(apply_ast, Apply)

        expected_outer_extent = SourceExtent(
            start=SourceLocation(lineno=2, colno=19, byteno=19),
            end=SourceLocation(lineno=5, colno=79, byteno=205),
        )
        expected_arg_extent = SourceExtent(
            start=SourceLocation(lineno=4, colno=18, byteno=68),
            end=SourceLocation(lineno=4, colno=29, byteno=79),
        )
        expected_func_extent = SourceExtent(
            start=SourceLocation(lineno=4, colno=34, byteno=84),
            end=SourceLocation(lineno=5, colno=79, byteno=205),
        )

        self.assertEqual(outer_function.source_extent, expected_outer_extent)
        self.assertEqual(apply_ast.arg.source_extent, expected_arg_extent)
        self.assertEqual(apply_ast.func.source_extent, expected_func_extent)
2655
2656
class ClosureOptimizeTests(unittest.TestCase):
    # Each test feeds one kind of AST node to free_in() and checks the set of
    # names it reports.

    def test_int(self) -> None:
        names = free_in(Int(1))
        self.assertEqual(names, set())

    def test_float(self) -> None:
        names = free_in(Float(1.0))
        self.assertEqual(names, set())

    def test_string(self) -> None:
        names = free_in(String("x"))
        self.assertEqual(names, set())

    def test_bytes(self) -> None:
        names = free_in(Bytes(b"x"))
        self.assertEqual(names, set())

    def test_hole(self) -> None:
        names = free_in(Hole())
        self.assertEqual(names, set())

    def test_spread(self) -> None:
        names = free_in(Spread())
        self.assertEqual(names, set())

    def test_spread_name(self) -> None:
        # TODO(max): Should this be assumed to always be in a place where it
        # defines a name, and therefore never have free variables?
        names = free_in(Spread("x"))
        self.assertEqual(names, {"x"})

    def test_nativefunction(self) -> None:
        names = free_in(NativeFunction("id", lambda x: x))
        self.assertEqual(names, set())

    def test_variant(self) -> None:
        names = free_in(Variant("x", Var("y")))
        self.assertEqual(names, {"y"})

    def test_var(self) -> None:
        names = free_in(Var("x"))
        self.assertEqual(names, {"x"})

    def test_binop(self) -> None:
        names = free_in(Binop(BinopKind.ADD, Var("x"), Var("y")))
        self.assertEqual(names, {"x", "y"})

    def test_empty_list(self) -> None:
        names = free_in(List([]))
        self.assertEqual(names, set())

    def test_list(self) -> None:
        names = free_in(List([Var("x"), Var("y")]))
        self.assertEqual(names, {"x", "y"})

    def test_empty_record(self) -> None:
        names = free_in(Record({}))
        self.assertEqual(names, set())

    def test_record(self) -> None:
        names = free_in(Record({"x": Var("x"), "y": Var("y")}))
        self.assertEqual(names, {"x", "y"})

    def test_function(self) -> None:
        # The parameter `x` is bound; only `y` is reported.
        names = free_in(parse(tokenize("x -> x + y")))
        self.assertEqual(names, {"y"})

    def test_nested_function(self) -> None:
        # Both parameters are bound; only `z` is reported.
        names = free_in(parse(tokenize("x -> y -> x + y + z")))
        self.assertEqual(names, {"z"})

    def test_match_function(self) -> None:
        # `x` and `y` appear in bodies of arms whose patterns do not bind them.
        names = free_in(parse(tokenize("| 1 -> x | 2 -> y | x -> 3 | z -> 4")))
        self.assertEqual(names, {"x", "y"})

    def test_match_case_int(self) -> None:
        case = MatchCase(Int(1), Var("x"))
        self.assertEqual(free_in(case), {"x"})

    def test_match_case_var(self) -> None:
        body = Binop(BinopKind.ADD, Var("x"), Var("y"))
        case = MatchCase(Var("x"), body)
        self.assertEqual(free_in(case), {"y"})

    def test_match_case_list(self) -> None:
        body = Binop(BinopKind.ADD, Var("x"), Var("y"))
        case = MatchCase(List([Var("x")]), body)
        self.assertEqual(free_in(case), {"y"})

    def test_match_case_list_spread(self) -> None:
        # An anonymous spread binds nothing, so `xs` stays in the result.
        body = Binop(BinopKind.ADD, Var("xs"), Var("y"))
        case = MatchCase(List([Spread()]), body)
        self.assertEqual(free_in(case), {"xs", "y"})

    def test_match_case_list_spread_name(self) -> None:
        # A named spread binds `xs`.
        body = Binop(BinopKind.ADD, Var("xs"), Var("y"))
        case = MatchCase(List([Spread("xs")]), body)
        self.assertEqual(free_in(case), {"y"})

    def test_match_case_record(self) -> None:
        # The pattern binds `y` and `z` (via field `a`); field `x` is matched
        # against a constant, so `x` in the body is still reported.
        pattern = Record({"x": Int(1), "y": Var("y"), "a": Var("z")})
        body = Binop(BinopKind.ADD, Binop(BinopKind.ADD, Var("x"), Var("y")), Var("z"))
        self.assertEqual(free_in(MatchCase(pattern, body)), {"x"})

    def test_match_case_record_spread(self) -> None:
        body = Binop(BinopKind.ADD, Var("x"), Var("y"))
        case = MatchCase(Record({"...": Spread()}), body)
        self.assertEqual(free_in(case), {"x", "y"})

    def test_match_case_record_spread_name(self) -> None:
        body = Binop(BinopKind.ADD, Var("x"), Var("y"))
        case = MatchCase(Record({"...": Spread("x")}), body)
        self.assertEqual(free_in(case), {"y"})

    def test_apply(self) -> None:
        names = free_in(Apply(Var("x"), Var("y")))
        self.assertEqual(names, {"x", "y"})

    def test_access(self) -> None:
        names = free_in(Access(Var("x"), Var("y")))
        self.assertEqual(names, {"x", "y"})

    def test_where(self) -> None:
        names = free_in(parse(tokenize("x . x = 1")))
        self.assertEqual(names, set())

    def test_where_same_name(self) -> None:
        names = free_in(parse(tokenize("x . x = x+y")))
        self.assertEqual(names, {"x", "y"})

    def test_assign(self) -> None:
        names = free_in(Assign(Var("x"), Int(1)))
        self.assertEqual(names, set())

    def test_assign_same_name(self) -> None:
        names = free_in(Assign(Var("x"), Var("x")))
        self.assertEqual(names, {"x"})

    def test_closure(self) -> None:
        # TODO(max): Should x be considered free in the closure if it's in the
        # env?
        func = Function(Var("_"), List([Var("x"), Var("y")]))
        closure = Closure({"x": Int(1)}, func)
        self.assertEqual(free_in(closure), {"x", "y"})
2779
2780
class StdLibTests(EndToEndTestsBase):
    # End-to-end checks of the `$$`-prefixed builtins, run through `_run`.
    # Some tests pass STDLIB explicitly as the environment; the others call
    # `_run` with the source text only.

    def test_stdlib_add(self) -> None:
        result = self._run("$$add 3 4", STDLIB)
        self.assertEqual(result, Int(7))

    def test_stdlib_quote(self) -> None:
        # Quoting yields the unevaluated expression tree rather than Int(7).
        result = self._run("$$quote (3 + 4)")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_quote_pipe(self) -> None:
        # The left operand of `|>` reaches `$$quote` unevaluated.
        result = self._run("3 + 4 |> $$quote")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_quote_reverse_pipe(self) -> None:
        # The right operand of `<|` reaches `$$quote` unevaluated.
        result = self._run("$$quote <| 3 + 4")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(3), Int(4)))

    def test_stdlib_serialize(self) -> None:
        result = self._run("$$serialize 3", STDLIB)
        self.assertEqual(result, Bytes(value=b"i\x06"))

    def test_stdlib_serialize_expr(self) -> None:
        # Quote first so that the expression itself, not its value, is
        # serialized.
        result = self._run("(1+2) |> $$quote |> $$serialize", STDLIB)
        self.assertEqual(result, Bytes(value=b"+\x02+i\x02i\x04"))

    def test_stdlib_deserialize(self) -> None:
        result = self._run("$$deserialize ~~aQY=")
        self.assertEqual(result, Int(3))

    def test_stdlib_deserialize_expr(self) -> None:
        # Deserializing an expression yields the tree, not its value.
        result = self._run("$$deserialize ~~KwIraQJpBA==")
        self.assertEqual(result, Binop(BinopKind.ADD, Int(1), Int(2)))

    def test_stdlib_listlength_empty_list_returns_zero(self) -> None:
        result = self._run("$$listlength []", STDLIB)
        self.assertEqual(result, Int(0))

    def test_stdlib_listlength_returns_length(self) -> None:
        result = self._run("$$listlength [1,2,3]", STDLIB)
        self.assertEqual(result, Int(3))

    def test_stdlib_listlength_with_non_list_raises_type_error(self) -> None:
        with self.assertRaises(TypeError) as raised:
            self._run("$$listlength 1", STDLIB)
        # The exception is only available once the `with` block has exited.
        message = raised.exception.args[0]
        self.assertEqual(message, "listlength expected List, but got Int")
2819
2820
class PreludeTests(EndToEndTestsBase):
    # End-to-end checks of the prelude functions (id, filter, quicksort,
    # concat, map, range, foldr, take, all, any). Every program is run through
    # `_run` with the source text only, i.e. without an explicit environment.

    def test_id_returns_input(self) -> None:
        self.assertEqual(self._run("id 123"), Int(123))

    def test_filter_returns_matching(self) -> None:
        # Matching elements are kept in their original order.
        result = self._run(
            """
            filter (x -> x < 4) [2, 6, 3, 7, 1, 8]
            """
        )
        self.assertEqual(result, List([Int(2), Int(3), Int(1)]))

    def test_filter_with_function_returning_non_bool_raises_match_error(self) -> None:
        # The predicate returns the variant `#no ()`, which is neither boolean.
        with self.assertRaises(MatchError):
            self._run(
                """
                filter (x -> #no ()) [1]
                """
            )

    def test_quicksort(self) -> None:
        result = self._run(
            """
            quicksort [2, 6, 3, 7, 1, 8]
            """
        )
        self.assertEqual(result, List([Int(1), Int(2), Int(3), Int(6), Int(7), Int(8)]))

    def test_quicksort_with_empty_list(self) -> None:
        result = self._run(
            """
            quicksort []
            """
        )
        self.assertEqual(result, List([]))

    def test_quicksort_with_non_int_raises_type_error(self) -> None:
        with self.assertRaises(TypeError):
            self._run(
                """
                quicksort ["a", "c", "b"]
                """
            )

    def test_concat(self) -> None:
        result = self._run(
            """
            concat [1, 2, 3] [4, 5, 6]
            """
        )
        self.assertEqual(result, List([Int(1), Int(2), Int(3), Int(4), Int(5), Int(6)]))

    def test_concat_with_first_list_empty(self) -> None:
        result = self._run(
            """
            concat [] [4, 5, 6]
            """
        )
        self.assertEqual(result, List([Int(4), Int(5), Int(6)]))

    def test_concat_with_second_list_empty(self) -> None:
        result = self._run(
            """
            concat [1, 2, 3] []
            """
        )
        self.assertEqual(result, List([Int(1), Int(2), Int(3)]))

    def test_concat_with_both_lists_empty(self) -> None:
        result = self._run(
            """
            concat [] []
            """
        )
        self.assertEqual(result, List([]))

    def test_map(self) -> None:
        # Results keep the order of the input list.
        result = self._run(
            """
            map (x -> x * 2) [3, 1, 2]
            """
        )
        self.assertEqual(result, List([Int(6), Int(2), Int(4)]))

    def test_map_with_non_function_raises_type_error(self) -> None:
        with self.assertRaises(TypeError):
            self._run(
                """
                map 4 [3, 1, 2]
                """
            )

    def test_map_with_non_list_raises_match_error(self) -> None:
        with self.assertRaises(MatchError):
            self._run(
                """
                map (x -> x * 2) 3
                """
            )

    def test_range(self) -> None:
        # `range 3` counts from 0 and excludes 3.
        result = self._run(
            """
            range 3
            """
        )
        self.assertEqual(result, List([Int(0), Int(1), Int(2)]))

    def test_range_with_non_int_raises_type_error(self) -> None:
        with self.assertRaises(TypeError):
            self._run(
                """
                range "a"
                """
            )

    def test_foldr(self) -> None:
        result = self._run(
            """
            foldr (x -> a -> a + x) 0 [1, 2, 3]
            """
        )
        self.assertEqual(result, Int(6))

    def test_foldr_on_empty_list_returns_empty_list(self) -> None:
        # NOTE: despite the test name, folding an empty list yields the
        # initial accumulator (0 here), not an empty list.
        result = self._run(
            """
            foldr (x -> a -> a + x) 0 []
            """
        )
        self.assertEqual(result, Int(0))

    def test_take(self) -> None:
        result = self._run(
            """
            take 3 [1, 2, 3, 4, 5]
            """
        )
        self.assertEqual(result, List([Int(1), Int(2), Int(3)]))

    def test_take_n_more_than_list_length_returns_full_list(self) -> None:
        result = self._run(
            """
            take 5 [1, 2, 3]
            """
        )
        self.assertEqual(result, List([Int(1), Int(2), Int(3)]))

    def test_take_with_non_int_raises_type_error(self) -> None:
        with self.assertRaises(TypeError):
            self._run(
                """
                take "a" [1, 2, 3]
                """
            )

    def test_all_returns_true(self) -> None:
        result = self._run(
            """
            all (x -> x < 5) [1, 2, 3, 4]
            """
        )
        self.assertEqual(result, TRUE)

    def test_all_returns_false(self) -> None:
        result = self._run(
            """
            all (x -> x < 5) [2, 4, 6]
            """
        )
        self.assertEqual(result, FALSE)

    def test_all_with_empty_list_returns_true(self) -> None:
        result = self._run(
            """
            all (x -> x == 5) []
            """
        )
        self.assertEqual(result, TRUE)

    def test_all_with_non_bool_raises_type_error(self) -> None:
        # The predicate returns its (integer) argument instead of a boolean.
        with self.assertRaises(TypeError):
            self._run(
                """
                all (x -> x) [1, 2, 3]
                """
            )

    def test_all_short_circuits(self) -> None:
        # The first element already fails the predicate. Getting FALSE rather
        # than an error shows the string elements were never compared
        # (presumably `"a" > 1` would raise).
        result = self._run(
            """
            all (x -> x > 1) [1, "a", "b"]
            """
        )
        self.assertEqual(result, FALSE)

    def test_any_returns_true(self) -> None:
        result = self._run(
            """
            any (x -> x < 4) [1, 3, 5]
            """
        )
        self.assertEqual(result, TRUE)

    def test_any_returns_false(self) -> None:
        result = self._run(
            """
            any (x -> x < 3) [4, 5, 6]
            """
        )
        self.assertEqual(result, FALSE)

    def test_any_with_empty_list_returns_false(self) -> None:
        result = self._run(
            """
            any (x -> x == 5) []
            """
        )
        self.assertEqual(result, FALSE)

    def test_any_with_non_bool_raises_type_error(self) -> None:
        # The predicate returns its (integer) argument instead of a boolean.
        with self.assertRaises(TypeError):
            self._run(
                """
                any (x -> x) [1, 2, 3]
                """
            )

    def test_any_short_circuits(self) -> None:
        # The first element already satisfies the predicate. Getting TRUE
        # rather than an error shows the string elements were never compared
        # (presumably `"a" > 1` would raise).
        result = self._run(
            """
            any (x -> x > 1) [2, "a", "b"]
            """
        )
        # Use the shared TRUE constant, like the other boolean assertions in
        # this class, instead of spelling out Variant("true", Hole()).
        self.assertEqual(result, TRUE)

    def test_mul_and_div_have_left_to_right_precedence(self) -> None:
        # `1 / 3 * 3` must group as `(1 / 3) * 3`, giving 1.0.
        result = self._run(
            """
            1 / 3 * 3
            """
        )
        self.assertEqual(result, Float(1.0))
3108
3109
class TypeStrTests(unittest.TestCase):
    # Pins the `str()` rendering of type variables, type constructors, rows
    # and `Forall` schemes.

    def test_tyvar(self) -> None:
        rendered = str(TyVar("a"))
        self.assertEqual(rendered, "'a")

    def test_tycon(self) -> None:
        # A constructor without arguments renders as its bare name.
        rendered = str(TyCon("int", []))
        self.assertEqual(rendered, "int")

    def test_tycon_one_arg(self) -> None:
        # With a single argument, the argument comes before the name.
        rendered = str(TyCon("list", [IntType]))
        self.assertEqual(rendered, "(int list)")

    def test_tycon_args(self) -> None:
        # With two arguments, the name is rendered between them.
        rendered = str(TyCon("->", [IntType, IntType]))
        self.assertEqual(rendered, "(int->int)")

    def test_tyrow_empty_closed(self) -> None:
        rendered = str(TyEmptyRow())
        self.assertEqual(rendered, "{}")

    def test_tyrow_empty_open(self) -> None:
        # An open row renders its rest variable after `...`.
        rendered = str(TyRow({}, TyVar("a")))
        self.assertEqual(rendered, "{...'a}")

    def test_tyrow_closed(self) -> None:
        closed_row = TyRow({"x": IntType, "y": StringType})
        self.assertEqual(str(closed_row), "{x=int, y=string}")

    def test_tyrow_open(self) -> None:
        open_row = TyRow({"x": IntType, "y": StringType}, TyVar("a"))
        self.assertEqual(str(open_row), "{x=int, y=string, ...'a}")

    def test_tyrow_chain(self) -> None:
        # The rest variable is made equal to another row; printing the outer
        # row shows the fields of both as one flat row.
        tail_row = TyRow({"x": IntType})
        link = TyVar("a")
        link.make_equal_to(tail_row)
        chained = TyRow({"y": StringType}, link)
        self.assertEqual(str(chained), "{x=int, y=string}")

    def test_forall(self) -> None:
        scheme = Forall([TyVar("a"), TyVar("b")], TyVar("a"))
        self.assertEqual(str(scheme), "(forall 'a, 'b. 'a)")
3144
3145
class InferTypeTests(unittest.TestCase):
    # Tests for type unification (`unify_type`), type-variable renaming
    # (`minimize`) and type inference (`infer_type`).
    #
    # Expected error messages that contain regex metacharacters (`{`, `}`,
    # `.`) are wrapped in `re.escape`, because `assertRaisesRegex` treats its
    # second argument as a regular expression and the text should match
    # literally.

    def setUp(self) -> None:
        # Restart fresh type-variable numbering so that names such as 't0 in
        # the expectations below do not depend on which tests ran earlier.
        reset_tyvar_counter()

    def test_unify_tyvar_tyvar(self) -> None:
        a = TyVar("a")
        b = TyVar("b")
        unify_type(a, b)
        self.assertIs(a.find(), b.find())

    def test_unify_tyvar_tycon(self) -> None:
        # Two separate variables can each be unified with the same constructor.
        a = TyVar("a")
        unify_type(a, IntType)
        self.assertIs(a.find(), IntType)
        b = TyVar("b")
        unify_type(b, IntType)
        self.assertIs(b.find(), IntType)

    def test_unify_tycon_tycon_name_mismatch(self) -> None:
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            unify_type(IntType, StringType)

    def test_unify_tycon_tycon_arity_mismatch(self) -> None:
        # Same constructor name, different number of arguments.
        left = TyCon("x", [TyVar("a")])
        right = TyCon("x", [])
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            unify_type(left, right)

    def test_unify_tycon_tycon_unifies_arg(self) -> None:
        a = TyVar("a")
        b = TyVar("b")
        left = TyCon("x", [a])
        right = TyCon("x", [b])
        unify_type(left, right)
        self.assertIs(a.find(), b.find())

    def test_unify_tycon_tycon_unifies_args(self) -> None:
        # Arguments unify position by position; a and b must stay distinct.
        a, b, c, d = map(TyVar, "abcd")
        left = func_type(a, b)
        right = func_type(c, d)
        unify_type(left, right)
        self.assertIs(a.find(), c.find())
        self.assertIs(b.find(), d.find())
        self.assertIsNot(a.find(), b.find())

    def test_unify_recursive_fails(self) -> None:
        # 'a cannot be unified with a type that mentions 'a.
        left = TyVar("a")
        right = TyCon("x", [TyVar("a")])
        with self.assertRaisesRegex(InferenceError, "Occurs check failed"):
            unify_type(left, right)

    def test_unify_empty_row(self) -> None:
        # Only checks that unifying two empty rows does not raise.
        unify_type(TyEmptyRow(), TyEmptyRow())

    def test_unify_empty_row_open(self) -> None:
        left = TyRow({}, TyVar("a"))
        right = TyRow({}, TyVar("b"))
        unify_type(left, right)
        self.assertIs(left.rest.find(), right.rest.find())

    def test_unify_row_unifies_fields(self) -> None:
        a = TyVar("a")
        b = TyVar("b")
        left = TyRow({"x": a})
        right = TyRow({"x": b})
        unify_type(left, right)
        self.assertIs(a.find(), b.find())

    def test_unify_empty_right(self) -> None:
        left = TyRow({"x": IntType})
        right = TyEmptyRow()
        expected = re.escape("Unifying row {x=int} with empty row")
        with self.assertRaisesRegex(InferenceError, expected):
            unify_type(left, right)

    def test_unify_empty_left(self) -> None:
        left = TyEmptyRow()
        right = TyRow({"x": IntType})
        expected = re.escape("Unifying empty row with row {x=int}")
        with self.assertRaisesRegex(InferenceError, expected):
            unify_type(left, right)

    def test_unify_missing_closed(self) -> None:
        # Two closed rows with different fields cannot be unified.
        left = TyRow({"x": IntType})
        right = TyRow({"y": IntType})
        expected = re.escape("Unifying empty row with row {y=int, ...'t0}")
        with self.assertRaisesRegex(InferenceError, expected):
            unify_type(left, right)

    def test_unify_one_open_one_closed(self) -> None:
        # With identical fields, the open row's rest resolves to the empty row.
        rest = TyVar("r1")
        left = TyRow({"x": IntType})
        right = TyRow({"x": IntType}, rest)
        unify_type(left, right)
        self.assertTyEqual(rest.find(), TyEmptyRow())

    def test_unify_one_open_more_than_one_closed(self) -> None:
        # The open row has a field (y) that the closed row lacks.
        rest = TyVar("r1")
        left = TyRow({"x": IntType})
        right = TyRow({"x": IntType, "y": StringType}, rest)
        expected = re.escape("Unifying empty row with row {y=string, ...'r1}")
        with self.assertRaisesRegex(InferenceError, expected):
            unify_type(left, right)

    def test_unify_one_open_one_closed_more(self) -> None:
        # The closed row's extra field ends up in the open row's rest.
        rest = TyVar("r1")
        left = TyRow({"x": IntType, "y": StringType})
        right = TyRow({"x": IntType}, rest)
        unify_type(left, right)
        self.assertTyEqual(rest.find(), TyRow({"y": StringType}))

    def test_unify_left_missing_open(self) -> None:
        left = TyRow({}, TyVar("r0"))
        right = TyRow({"y": IntType}, TyVar("r1"))
        unify_type(left, right)
        self.assertTyEqual(left.rest, TyRow({"y": IntType}, TyVar("r1")))
        assert isinstance(right.rest, TyVar)  # Narrow the type for mypy
        self.assertTrue(right.rest.is_unbound())

    def test_unify_right_missing_open(self) -> None:
        left = TyRow({"x": IntType}, TyVar("r0"))
        right = TyRow({}, TyVar("r1"))
        unify_type(left, right)
        assert isinstance(left.rest, TyVar)  # Narrow the type for mypy
        self.assertTrue(left.rest.is_unbound())
        self.assertTyEqual(right.rest, TyRow({"x": IntType}, TyVar("r0")))

    def test_unify_both_missing_open(self) -> None:
        # Each side's rest receives the other side's field, followed by the
        # same fresh rest variable 't0.
        left = TyRow({"x": IntType}, TyVar("r0"))
        right = TyRow({"y": IntType}, TyVar("r1"))
        unify_type(left, right)
        self.assertTyEqual(left.rest, TyRow({"y": IntType}, TyVar("t0")))
        self.assertTyEqual(right.rest, TyRow({"x": IntType}, TyVar("t0")))

    def test_minimize_tyvar(self) -> None:
        ty = fresh_tyvar()
        self.assertEqual(minimize(ty), TyVar("a"))

    def test_minimize_tycon(self) -> None:
        # Repeated occurrences of one variable map to the same short name.
        ty = func_type(TyVar("t0"), TyVar("t1"), TyVar("t0"))
        self.assertEqual(minimize(ty), func_type(TyVar("a"), TyVar("b"), TyVar("a")))

    def infer(self, expr: Object, ctx: Context) -> MonoType:
        """Infer the type of `expr` in `ctx` and pass it through `minimize`."""
        return minimize(infer_type(expr, ctx))

    def assertTyEqual(self, l: MonoType, r: MonoType) -> bool:
        """Fail the test unless `l` and `r` are structurally the same type.

        Both sides are passed through `find()` first. Type variables are
        compared with `!=`; type constructors by name, arity and (recursively)
        arguments; rows by field names, (recursively) field types and rest.
        Returns True when the types match; otherwise `self.fail` is called.
        """
        lhs = l.find()
        rhs = r.find()
        if isinstance(lhs, TyVar) and isinstance(rhs, TyVar):
            if lhs != rhs:
                self.fail(f"Type mismatch: {lhs} != {rhs}")
            return True
        if isinstance(lhs, TyCon) and isinstance(rhs, TyCon):
            if lhs.name != rhs.name:
                self.fail(f"Type mismatch: {lhs} != {rhs}")
            if len(lhs.args) != len(rhs.args):
                self.fail(f"Type mismatch: {lhs} != {rhs}")
            for lhs_arg, rhs_arg in zip(lhs.args, rhs.args):
                self.assertTyEqual(lhs_arg, rhs_arg)
            return True
        if isinstance(lhs, TyEmptyRow) and isinstance(rhs, TyEmptyRow):
            return True
        if isinstance(lhs, TyRow) and isinstance(rhs, TyRow):
            lhs_keys = set(lhs.fields.keys())
            rhs_keys = set(rhs.fields.keys())
            if lhs_keys != rhs_keys:
                self.fail(f"Type mismatch: {lhs} != {rhs}")
            for key in lhs_keys:
                self.assertTyEqual(lhs.fields[key], rhs.fields[key])
            self.assertTyEqual(lhs.rest, rhs.rest)
            return True
        # Any remaining combination is two different kinds of type.
        self.fail(f"Type mismatch: {lhs} != {rhs}")

    def test_unbound_var(self) -> None:
        with self.assertRaisesRegex(InferenceError, "Unbound variable"):
            self.infer(Var("a"), {})

    def test_var_instantiates_scheme(self) -> None:
        ty = self.infer(Var("a"), {"a": Forall([TyVar("b")], TyVar("b"))})
        self.assertTyEqual(ty, TyVar("a"))

    def test_int(self) -> None:
        ty = self.infer(Int(123), {})
        self.assertTyEqual(ty, IntType)

    def test_float(self) -> None:
        ty = self.infer(Float(1.0), {})
        self.assertTyEqual(ty, FloatType)

    def test_string(self) -> None:
        ty = self.infer(String("abc"), {})
        self.assertTyEqual(ty, StringType)

    def test_function_returns_arg(self) -> None:
        ty = self.infer(Function(Var("x"), Var("x")), {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

    def test_nested_function_outer(self) -> None:
        # The result type is the type of the outer parameter.
        ty = self.infer(Function(Var("x"), Function(Var("y"), Var("x"))), {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("a")))

    def test_nested_function_inner(self) -> None:
        # The result type is the type of the inner parameter.
        ty = self.infer(Function(Var("x"), Function(Var("y"), Var("y"))), {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("b")))

    def test_apply_id_int(self) -> None:
        func = Function(Var("x"), Var("x"))
        arg = Int(123)
        ty = self.infer(Apply(func, arg), {})
        self.assertTyEqual(ty, IntType)

    def test_apply_two_arg_returns_function(self) -> None:
        # Supplying only the first argument leaves a function of the second.
        func = Function(Var("x"), Function(Var("y"), Var("x")))
        arg = Int(123)
        ty = self.infer(Apply(func, arg), {})
        self.assertTyEqual(ty, func_type(TyVar("a"), IntType))

    def test_binop_add_constrains_int(self) -> None:
        expr = Binop(BinopKind.ADD, Var("x"), Var("y"))
        ty = self.infer(
            expr,
            {
                "x": Forall([], TyVar("a")),
                "y": Forall([], TyVar("b")),
                "+": Forall([], func_type(IntType, IntType, IntType)),
            },
        )
        self.assertTyEqual(ty, IntType)

    def test_binop_add_function_constrains_int(self) -> None:
        # Keep references to the Var nodes so their types can be checked after
        # inference.
        x = Var("x")
        y = Var("y")
        expr = Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, x, y)))
        ty = self.infer(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(IntType, IntType, IntType))
        self.assertTyEqual(type_of(x), IntType)
        self.assertTyEqual(type_of(y), IntType)

    def test_let(self) -> None:
        expr = Where(Var("f"), Assign(Var("f"), Function(Var("x"), Var("x"))))
        ty = self.infer(expr, {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

    def test_apply_monotype_to_different_types_raises(self) -> None:
        # f's scheme quantifies over nothing, so applying it to both an Int
        # and a Float makes unification fail.
        expr = Where(
            Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))),
            Assign(Var("y"), Apply(Var("f"), Float(123.0))),
        )
        ctx = {"f": Forall([], func_type(TyVar("a"), TyVar("a")))}
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            self.infer(expr, ctx)

    def test_apply_polytype_to_different_types(self) -> None:
        # Same expression as the monotype test, but f's scheme quantifies over
        # 'a, so both applications are accepted.
        expr = Where(
            Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))),
            Assign(Var("y"), Apply(Var("f"), Float(123.0))),
        )
        ty = self.infer(expr, {"f": Forall([TyVar("a")], func_type(TyVar("a"), TyVar("a")))})
        self.assertTyEqual(ty, IntType)

    def test_generalization(self) -> None:
        # From https://okmij.org/ftp/ML/generalization.html
        expr = parse(tokenize("x -> (y . y = x)"))
        ty = self.infer(expr, {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

    def test_generalization2(self) -> None:
        # From https://okmij.org/ftp/ML/generalization.html
        expr = parse(tokenize("x -> (y . y = z -> x z)"))
        ty = self.infer(expr, {})
        self.assertTyEqual(ty, func_type(func_type(TyVar("a"), TyVar("b")), func_type(TyVar("a"), TyVar("b"))))

    def test_id(self) -> None:
        expr = Function(Var("x"), Var("x"))
        ty = self.infer(expr, {})
        self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a")))

    # The tests below call `infer_type` directly, without `minimize`, so
    # expectations that mention type variables use the raw generated names
    # ('t0, 't1, ...).

    def test_empty_list(self) -> None:
        expr = List([])
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, TyCon("list", [TyVar("t0")]))

    def test_list_int(self) -> None:
        expr = List([Int(123)])
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, TyCon("list", [IntType]))

    def test_list_mismatch(self) -> None:
        expr = List([Int(123), Float(123.0)])
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            infer_type(expr, {})

    def test_recursive_fact(self) -> None:
        expr = parse(tokenize("fact . fact = | 0 -> 1 | n -> n * fact (n-1)"))
        ty = infer_type(
            expr,
            {
                "*": Forall([], func_type(IntType, IntType, IntType)),
                "-": Forall([], func_type(IntType, IntType, IntType)),
            },
        )
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_match_int_int(self) -> None:
        expr = parse(tokenize("| 0 -> 1"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_match_int_int_two_cases(self) -> None:
        expr = parse(tokenize("| 0 -> 1 | 1 -> 2"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_match_int_int_int_float(self) -> None:
        # The case bodies have different types.
        expr = parse(tokenize("| 0 -> 1 | 1 -> 2.0"))
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            infer_type(expr, {})

    def test_match_int_int_float_int(self) -> None:
        # The case patterns have different types.
        expr = parse(tokenize("| 0 -> 1 | 1.0 -> 2"))
        with self.assertRaisesRegex(InferenceError, "Unification failed"):
            infer_type(expr, {})

    def test_match_var(self) -> None:
        expr = parse(tokenize("| x -> x + 1"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_match_int_var(self) -> None:
        expr = parse(tokenize("| 0 -> 1 | x -> x"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_match_list_of_int(self) -> None:
        expr = parse(tokenize("| [x] -> x + 1"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(list_type(IntType), IntType))

    def test_match_list_of_int_to_list(self) -> None:
        expr = parse(tokenize("| [x] -> [x + 1]"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType)))

    def test_match_list_of_int_to_int(self) -> None:
        expr = parse(tokenize("| [] -> 0 | [x] -> 1 | [x, y] -> x+y"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(list_type(IntType), IntType))

    def test_recursive_var_is_unbound(self) -> None:
        expr = parse(tokenize("a . a = a"))
        with self.assertRaisesRegex(InferenceError, "Unbound variable"):
            infer_type(expr, {})

    def test_recursive(self) -> None:
        expr = parse(
            tokenize(
                """
                length
                . length =
                | [] -> 0
                | [x, ...xs] -> 1 + length xs
                """
            )
        )
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(list_type(TyVar("t8")), IntType))

    def test_match_list_to_list(self) -> None:
        expr = parse(tokenize("| [] -> [] | x -> x"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(list_type(TyVar("t1")), list_type(TyVar("t1"))))

    def test_match_list_spread(self) -> None:
        expr = parse(tokenize("head . head = | [x, ...] -> x"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), TyVar("t4")))

    def test_match_list_spread_rest(self) -> None:
        expr = parse(tokenize("tail . tail = | [x, ...xs] -> xs"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), list_type(TyVar("t4"))))

    def test_match_list_spread_named(self) -> None:
        expr = parse(tokenize("sum . sum = | [] -> 0 | [x, ...xs] -> x + sum xs"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(list_type(IntType), IntType))

    def test_match_list_int_to_list(self) -> None:
        expr = parse(tokenize("| [] -> [3] | x -> x"))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType)))

    def test_inc(self) -> None:
        expr = parse(tokenize("inc . inc = | 0 -> 1 | 1 -> 2 | a -> a + 1"))
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(IntType, IntType))

    def test_bytes(self) -> None:
        expr = Bytes(b"abc")
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, BytesType)

    def test_hole(self) -> None:
        expr = Hole()
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, HoleType)

    def test_string_concat(self) -> None:
        expr = parse(tokenize('"a" ++ "b"'))
        ty = infer_type(expr, OP_ENV)
        self.assertTyEqual(ty, StringType)

    def test_cons(self) -> None:
        expr = parse(tokenize("1 >+ [2]"))
        ty = infer_type(expr, OP_ENV)
        self.assertTyEqual(ty, list_type(IntType))

    def test_append(self) -> None:
        expr = parse(tokenize("[1] +< 2"))
        ty = infer_type(expr, OP_ENV)
        self.assertTyEqual(ty, list_type(IntType))

    def test_record(self) -> None:
        expr = Record({"a": Int(1), "b": String("hello")})
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, TyRow({"a": IntType, "b": StringType}))

    def test_match_record(self) -> None:
        expr = MatchFunction(
            [
                MatchCase(
                    Record({"x": Var("x")}),
                    Var("x"),
                )
            ]
        )
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}), TyVar("t1")))

    def test_access_poly(self) -> None:
        # Accessing a field of an unknown record infers an open row argument.
        expr = Function(Var("r"), Access(Var("r"), Var("x")))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t1")))

    def test_apply_row(self) -> None:
        # The scheme quantifies over nothing, so the rest variable 'a is
        # shared by both applications: the first succeeds, the second (with
        # an extra field y) fails.
        row0 = Record({"x": Int(1)})
        row1 = Record({"x": Int(1), "y": Int(2)})
        scheme = Forall([], func_type(TyRow({"x": IntType}, TyVar("a")), IntType))
        ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme})
        self.assertTyEqual(ty0, IntType)
        expected = re.escape("Unifying empty row with row {y=int}")
        with self.assertRaisesRegex(InferenceError, expected):
            infer_type(Apply(Var("f"), row1), {"f": scheme})

    def test_apply_row_polymorphic(self) -> None:
        # With 'a quantified, records with any extra fields are accepted.
        row0 = Record({"x": Int(1)})
        row1 = Record({"x": Int(1), "y": Int(2)})
        row2 = Record({"x": Int(1), "y": Int(2), "z": Int(3)})
        scheme = Forall([TyVar("a")], func_type(TyRow({"x": IntType}, TyVar("a")), IntType))
        ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme})
        self.assertTyEqual(ty0, IntType)
        ty1 = infer_type(Apply(Var("f"), row1), {"f": scheme})
        self.assertTyEqual(ty1, IntType)
        ty2 = infer_type(Apply(Var("f"), row2), {"f": scheme})
        self.assertTyEqual(ty2, IntType)

    def test_example_rec_access(self) -> None:
        expr = parse(tokenize('rec@a . rec = { a = 1, b = "x" }'))
        ty = infer_type(expr, {})
        self.assertTyEqual(ty, IntType)

    def test_example_rec_access_poly(self) -> None:
        expr = parse(
            tokenize(
                """
(get_x {x=1}) + (get_x {x=2,y=3})
. get_x = | { x=x, ... } -> x
"""
            )
        )
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, IntType)

    def test_example_rec_access_poly_named_bug(self) -> None:
        # filter_x returns the remaining record {y=int}, which cannot be an
        # operand of an int-only `+`.
        expr = parse(
            tokenize(
                """
(filter_x {x=1, y=2}) + 3
. filter_x = | { x=x, ...xs } -> xs
"""
            )
        )
        expected = re.escape("Cannot unify int and {y=int}")
        with self.assertRaisesRegex(InferenceError, expected):
            infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})

    def test_example_rec_access_rest(self) -> None:
        # The named spread's type is the row's rest variable.
        expr = parse(
            tokenize(
                """
| { x=x, ...xs } -> xs
"""
            )
        )
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t2")))

    def test_example_match_rec_access_rest(self) -> None:
        expr = parse(
            tokenize(
                """
filter_x {x=1, y=2}
. filter_x = | { x=x, ...xs } -> xs
"""
            )
        )
        ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))})
        self.assertTyEqual(ty, TyRow({"y": IntType}))

    def test_example_rec_access_poly_named(self) -> None:
        # The two list elements would have row types {y=int} and
        # {y=int, z=int}, which do not unify.
        expr = parse(
            tokenize(
                """
[(filter_x {x=1, y=2}), (filter_x {x=2, y=3, z=4})]
. filter_x = | { x=x, ...xs } -> xs
"""
            )
        )
        expected = re.escape("Unifying empty row with row {z=int}")
        with self.assertRaisesRegex(InferenceError, expected):
            infer_type(expr, {})
3670
3671
3672class SerializerTests(unittest.TestCase):
3673 def _serialize(self, obj: Object) -> bytes:
3674 serializer = Serializer()
3675 serializer.serialize(obj)
3676 return bytes(serializer.output)
3677
3678 def test_short(self) -> None:
3679 self.assertEqual(self._serialize(Int(-1)), TYPE_SHORT + b"\x01")
3680 self.assertEqual(self._serialize(Int(0)), TYPE_SHORT + b"\x00")
3681 self.assertEqual(self._serialize(Int(1)), TYPE_SHORT + b"\x02")
3682 self.assertEqual(self._serialize(Int(-(2**33))), TYPE_SHORT + b"\xff\xff\xff\xff?")
3683 self.assertEqual(self._serialize(Int(2**33)), TYPE_SHORT + b"\x80\x80\x80\x80@")
3684 self.assertEqual(self._serialize(Int(-(2**63))), TYPE_SHORT + b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01")
3685 self.assertEqual(self._serialize(Int(2**63 - 1)), TYPE_SHORT + b"\xfe\xff\xff\xff\xff\xff\xff\xff\xff\x01")
3686
3687 def test_long(self) -> None:
3688 self.assertEqual(
3689 self._serialize(Int(2**100)),
3690 TYPE_LONG + b"\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00",
3691 )
3692 self.assertEqual(
3693 self._serialize(Int(-(2**100))),
3694 TYPE_LONG + b"\x04\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x1f\x00\x00\x00",
3695 )
3696
3697 def test_string(self) -> None:
3698 self.assertEqual(self._serialize(String("hello")), TYPE_STRING + b"\nhello")
3699
3700 def test_empty_list(self) -> None:
3701 obj = List([])
3702 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x00")
3703
3704 def test_list(self) -> None:
3705 obj = List([Int(123), Int(456)])
3706 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x04i\xf6\x01i\x90\x07")
3707
3708 def test_self_referential_list(self) -> None:
3709 obj = List([])
3710 obj.items.append(obj)
3711 self.assertEqual(self._serialize(obj), ref(TYPE_LIST) + b"\x02r\x00")
3712
3713 def test_variant(self) -> None:
3714 obj = Variant("abc", Int(123))
3715 self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x06abci\xf6\x01")
3716
3717 def test_record(self) -> None:
3718 obj = Record({"x": Int(1), "y": Int(2)})
3719 self.assertEqual(self._serialize(obj), TYPE_RECORD + b"\x04\x02xi\x02\x02yi\x04")
3720
3721 def test_var(self) -> None:
3722 obj = Var("x")
3723 self.assertEqual(self._serialize(obj), TYPE_VAR + b"\x02x")
3724
3725 def test_function(self) -> None:
3726 obj = Function(Var("x"), Var("x"))
3727 self.assertEqual(self._serialize(obj), TYPE_FUNCTION + b"v\x02xv\x02x")
3728
3729 def test_empty_match_function(self) -> None:
3730 obj = MatchFunction([])
3731 self.assertEqual(self._serialize(obj), TYPE_MATCH_FUNCTION + b"\x00")
3732
3733 def test_match_function(self) -> None:
3734 obj = MatchFunction([MatchCase(Int(1), Var("x")), MatchCase(List([Int(1)]), Var("y"))])
3735 self.assertEqual(self._serialize(obj), TYPE_MATCH_FUNCTION + b"\x04i\x02v\x02x\xdb\x02i\x02v\x02y")
3736
3737 def test_closure(self) -> None:
3738 obj = Closure({}, Function(Var("x"), Var("x")))
3739 self.assertEqual(self._serialize(obj), ref(TYPE_CLOSURE) + b"fv\x02xv\x02x\x00")
3740
def test_self_referential_closure(self) -> None:
    # Build a closure whose environment maps "self" back to the closure.
    closure = Closure({}, Function(Var("x"), Var("x")))
    assert isinstance(closure.env, dict)  # For mypy
    closure.env["self"] = closure
    encoded = self._serialize(closure)
    self.assertEqual(encoded, ref(TYPE_CLOSURE) + b"fv\x02xv\x02x\x02\x08selfr\x00")
3746
def test_bytes(self) -> None:
    # Bytes(b"abc") is expected to encode as the bytes tag plus b"\x06abc".
    encoded = self._serialize(Bytes(b"abc"))
    self.assertEqual(encoded, TYPE_BYTES + b"\x06abc")
3750
def test_float(self) -> None:
    # The expected eight payload bytes are 3.14 as a little-endian IEEE-754
    # double (0x40091EB851EB851F).
    encoded = self._serialize(Float(3.14))
    self.assertEqual(encoded, TYPE_FLOAT + b"\x1f\x85\xebQ\xb8\x1e\t@")
3754
def test_hole(self) -> None:
    # A Hole is expected to encode as its type tag alone, with no payload.
    encoded = self._serialize(Hole())
    self.assertEqual(encoded, TYPE_HOLE)
3757
def test_assign(self) -> None:
    # The assignment x = 123: assign tag, then the name and the value.
    assignment = Assign(Var("x"), Int(123))
    encoded = self._serialize(assignment)
    self.assertEqual(encoded, TYPE_ASSIGN + b"v\x02xi\xf6\x01")
3761
def test_binop(self) -> None:
    # The addition 3 + 4: binop tag, the operator text, then both operands.
    addition = Binop(BinopKind.ADD, Int(3), Int(4))
    encoded = self._serialize(addition)
    self.assertEqual(encoded, TYPE_BINOP + b"\x02+i\x06i\x08")
3765
def test_apply(self) -> None:
    # The application f x: apply tag followed by both variables.
    application = Apply(Var("f"), Var("x"))
    encoded = self._serialize(application)
    self.assertEqual(encoded, TYPE_APPLY + b"v\x02fv\x02x")
3769
def test_where(self) -> None:
    # Where(a, b): where tag followed by both variables.
    where_expr = Where(Var("a"), Var("b"))
    encoded = self._serialize(where_expr)
    self.assertEqual(encoded, TYPE_WHERE + b"v\x02av\x02b")
3773
def test_access(self) -> None:
    # Access(a, b): access tag followed by both variables.
    access_expr = Access(Var("a"), Var("b"))
    encoded = self._serialize(access_expr)
    self.assertEqual(encoded, TYPE_ACCESS + b"v\x02av\x02b")
3777
def test_spread(self) -> None:
    # An anonymous spread is a bare tag; a named spread uses a different tag
    # followed by the name payload.
    expectations = (
        (Spread(), TYPE_SPREAD),
        (Spread("rest"), TYPE_NAMED_SPREAD + b"\x08rest"),
    )
    for spread, expected in expectations:
        self.assertEqual(self._serialize(spread), expected)
3781
def test_true_variant(self) -> None:
    # Variant "true" wrapping a Hole is expected to collapse to TYPE_TRUE alone.
    encoded = self._serialize(Variant("true", Hole()))
    self.assertEqual(encoded, TYPE_TRUE)
3785
def test_false_variant(self) -> None:
    # Variant "false" wrapping a Hole is expected to collapse to TYPE_FALSE alone.
    encoded = self._serialize(Variant("false", Hole()))
    self.assertEqual(encoded, TYPE_FALSE)
3789
def test_true_variant_with_non_hole_uses_regular_variant(self) -> None:
    # With a non-Hole payload, "true" must use the generic variant encoding.
    encoded = self._serialize(Variant("true", Int(123)))
    self.assertEqual(encoded, TYPE_VARIANT + b"\x08truei\xf6\x01")
3793
def test_false_variant_with_non_hole_uses_regular_variant(self) -> None:
    # With a non-Hole payload, "false" must use the generic variant encoding.
    encoded = self._serialize(Variant("false", Int(123)))
    self.assertEqual(encoded, TYPE_VARIANT + b"\x0afalsei\xf6\x01")
3797
3798
class RoundTripSerializationTests(unittest.TestCase):
    # Each test feeds an object through Serializer and then Deserializer and
    # inspects what comes back.

    def _serialize(self, obj: Object) -> bytes:
        # Copy the serializer's output buffer into an immutable bytes value.
        writer = Serializer()
        writer.serialize(obj)
        return bytes(writer.output)

    def _deserialize(self, flat: bytes) -> Object:
        # Decode a flat byte string back into an object.
        return Deserializer(flat).parse()

    def _serde(self, obj: Object) -> Object:
        # Serialize, then immediately decode the resulting bytes.
        return self._deserialize(self._serialize(obj))

    def _rt(self, obj: Object) -> None:
        # A round trip passes when the decoded object compares equal to the input.
        decoded = self._serde(obj)
        self.assertEqual(decoded, obj)

    def test_short(self) -> None:
        # Exhaustively cover [-2**16, 2**16), then the signed 64-bit extremes.
        bound = 2**16
        for value in range(-bound, bound):
            self._rt(Int(value))

        for extreme in (-(2**63), 2**63 - 1):
            self._rt(Int(extreme))

    def test_long(self) -> None:
        # Magnitudes well beyond 64 bits, positive first and then negative.
        for value in (2**100, -(2**100)):
            self._rt(Int(value))

    def test_string(self) -> None:
        # Empty, single-character and multi-character strings.
        for text in ("", "a", "hello"):
            self._rt(String(text))

    def test_list(self) -> None:
        empty = List([])
        self._rt(empty)
        populated = List([Int(123), Int(345)])
        self._rt(populated)

    def test_self_referential_list(self) -> None:
        # Checks shape and identity directly instead of going through _rt.
        cyclic = List([])
        cyclic.items.append(cyclic)
        decoded = self._serde(cyclic)
        self.assertIsInstance(decoded, List)
        assert isinstance(decoded, List)  # For mypy
        self.assertIsInstance(decoded.items, list)
        self.assertEqual(len(decoded.items), 1)
        # The decoded list must contain itself, not a copy.
        self.assertIs(decoded.items[0], decoded)

    def test_record(self) -> None:
        fields = {"x": Int(1), "y": Int(2)}
        self._rt(Record(fields))

    def test_variant(self) -> None:
        tagged = Variant("abc", Int(123))
        self._rt(tagged)

    def test_var(self) -> None:
        name = Var("x")
        self._rt(name)

    def test_function(self) -> None:
        identity = Function(Var("x"), Var("x"))
        self._rt(identity)

    def test_empty_match_function(self) -> None:
        no_cases = MatchFunction([])
        self._rt(no_cases)

    def test_match_function(self) -> None:
        cases = [
            MatchCase(Int(1), Var("x")),
            MatchCase(List([Int(1)]), Var("y")),
        ]
        self._rt(MatchFunction(cases))

    def test_closure(self) -> None:
        identity = Function(Var("x"), Var("x"))
        self._rt(Closure({}, identity))

    def test_self_referential_closure(self) -> None:
        # Checks shape and identity directly instead of going through _rt.
        closure = Closure({}, Function(Var("x"), Var("x")))
        assert isinstance(closure.env, dict)  # For mypy
        closure.env["self"] = closure
        decoded = self._serde(closure)
        self.assertIsInstance(decoded, Closure)
        assert isinstance(decoded, Closure)  # For mypy
        self.assertIsInstance(decoded.env, dict)
        self.assertEqual(len(decoded.env), 1)
        # The decoded environment must point back at the decoded closure.
        self.assertIs(decoded.env["self"], decoded)

    def test_bytes(self) -> None:
        blob = Bytes(b"abc")
        self._rt(blob)

    def test_float(self) -> None:
        number = Float(3.14)
        self._rt(number)

    def test_hole(self) -> None:
        hole = Hole()
        self._rt(hole)

    def test_assign(self) -> None:
        assignment = Assign(Var("x"), Int(123))
        self._rt(assignment)

    def test_binop(self) -> None:
        addition = Binop(BinopKind.ADD, Int(3), Int(4))
        self._rt(addition)

    def test_apply(self) -> None:
        application = Apply(Var("f"), Var("x"))
        self._rt(application)

    def test_where(self) -> None:
        where_expr = Where(Var("a"), Var("b"))
        self._rt(where_expr)

    def test_access(self) -> None:
        access_expr = Access(Var("a"), Var("b"))
        self._rt(access_expr)

    def test_spread(self) -> None:
        # Anonymous spread first, then a named one.
        for spread in (Spread(), Spread("rest")):
            self._rt(spread)
3907
3908
class ScrapMonadTests(unittest.TestCase):
    # Checks that ScrapMonad does not share or mutate the environment it is given.

    def test_create_copies_env(self) -> None:
        original_env = {"a": Int(123)}
        monad = ScrapMonad(original_env)
        # Same contents, but a distinct dict object.
        self.assertEqual(monad.env, original_env)
        self.assertIsNot(monad.env, original_env)

    def test_bind_returns_new_monad(self) -> None:
        start = ScrapMonad({"a": Int(123)})
        binding = Assign(Var("b"), Int(456))
        # Only the second element of the returned pair is inspected here.
        _, successor = start.bind(binding)
        # The original monad's environment is untouched ...
        self.assertEqual(start.env, {"a": Int(123)})
        # ... while the returned monad also carries the new binding.
        self.assertEqual(successor.env, {"a": Int(123), "b": Int(456)})
3922
3923
class PrettyPrintTests(unittest.TestCase):
    # Every test renders one object with pretty() and compares the resulting text.

    def _assert_pretty(self, obj: Object, expected: str) -> None:
        # Shared check: pretty(obj) must equal the expected text exactly.
        self.assertEqual(pretty(obj), expected)

    def test_pretty_print_int(self) -> None:
        self._assert_pretty(Int(1), "1")

    def test_pretty_print_float(self) -> None:
        self._assert_pretty(Float(3.14), "3.14")

    def test_pretty_print_string(self) -> None:
        self._assert_pretty(String("hello"), '"hello"')

    def test_pretty_print_bytes(self) -> None:
        # "YWJj" is the base64 encoding of b"abc".
        self._assert_pretty(Bytes(b"abc"), "~~YWJj")

    def test_pretty_print_var(self) -> None:
        self._assert_pretty(Var("ref"), "ref")

    def test_pretty_print_hole(self) -> None:
        self._assert_pretty(Hole(), "()")

    def test_pretty_print_spread(self) -> None:
        self._assert_pretty(Spread(), "...")

    def test_pretty_print_named_spread(self) -> None:
        self._assert_pretty(Spread("rest"), "...rest")

    def test_pretty_print_binop(self) -> None:
        self._assert_pretty(Binop(BinopKind.ADD, Int(1), Int(2)), "1 + 2")

    def test_pretty_print_binop_precedence(self) -> None:
        # A multiplication nested on the right of an addition needs no parentheses.
        product = Binop(BinopKind.MUL, Int(2), Int(3))
        self._assert_pretty(Binop(BinopKind.ADD, Int(1), product), "1 + 2 * 3")
        # An addition nested on the left of a multiplication must be parenthesized.
        total = Binop(BinopKind.ADD, Int(1), Int(2))
        self._assert_pretty(Binop(BinopKind.MUL, total, Int(3)), "(1 + 2) * 3")

    def test_pretty_print_int_list(self) -> None:
        numbers = [Int(1), Int(2), Int(3)]
        self._assert_pretty(List(numbers), "[1, 2, 3]")

    def test_pretty_print_str_list(self) -> None:
        strings = [String("1"), String("2"), String("3")]
        self._assert_pretty(List(strings), '["1", "2", "3"]')

    def test_pretty_print_recursion(self) -> None:
        # A list containing itself is expected to print as "[...]".
        cyclic = List([])
        cyclic.items.append(cyclic)
        self._assert_pretty(cyclic, "[...]")

    def test_pretty_print_assign(self) -> None:
        self._assert_pretty(Assign(Var("x"), Int(3)), "x = 3")

    def test_pretty_print_function(self) -> None:
        body = Binop(BinopKind.ADD, Int(1), Var("x"))
        self._assert_pretty(Function(Var("x"), body), "x -> 1 + x")

    def test_pretty_print_nested_function(self) -> None:
        inner = Function(Var("y"), Binop(BinopKind.ADD, Var("x"), Var("y")))
        self._assert_pretty(Function(Var("x"), inner), "x -> y -> x + y")

    def test_pretty_print_apply(self) -> None:
        self._assert_pretty(Apply(Var("x"), Var("y")), "x y")

    def test_pretty_print_compose(self) -> None:
        # gensym_reset() runs before each parse so the generated name is $v0
        # in both expected strings.
        scenarios = (
            ("(x -> x + 3) << (x -> x * 2)", "$v0 -> (x -> x + 3) ((x -> x * 2) $v0)"),
            ("(x -> x + 3) >> (x -> x * 2)", "$v0 -> (x -> x * 2) ((x -> x + 3) $v0)"),
        )
        for source, expected in scenarios:
            gensym_reset()
            parsed = parse(tokenize(source))
            self.assertEqual(pretty(parsed), expected)

    def test_pretty_print_where(self) -> None:
        body = Binop(BinopKind.ADD, Var("a"), Var("b"))
        binding = Assign(Var("a"), Int(1))
        self._assert_pretty(Where(body, binding), "a + b . a = 1")

    def test_pretty_print_assert(self) -> None:
        condition = Variant("true", String("foo"))
        self._assert_pretty(Assert(Int(123), condition), '123 ! #true "foo"')

    def test_pretty_print_envobject(self) -> None:
        self._assert_pretty(EnvObject({"x": Int(1)}), "EnvObject({'x': Int(value=1)})")

    def test_pretty_print_matchfunction(self) -> None:
        single_case = MatchFunction([MatchCase(Var("y"), Var("x"))])
        self._assert_pretty(single_case, "| y -> x")

    def test_pretty_print_matchfunction_precedence(self) -> None:
        # Match functions nested in a case body are parenthesized, one case per line.
        first_body = MatchFunction([MatchCase(Var("b"), Var("c"))])
        second_body = MatchFunction([MatchCase(Var("y"), Var("z"))])
        nested = MatchFunction(
            [
                MatchCase(Var("a"), first_body),
                MatchCase(Var("x"), second_body),
            ]
        )
        self._assert_pretty(nested, "| a -> (| b -> c)\n| x -> (| y -> z)")

    def test_pretty_print_relocation(self) -> None:
        self._assert_pretty(Relocation("relocate"), "Relocation(name='relocate')")

    def test_pretty_print_nativefunction(self) -> None:
        double = NativeFunction("times2", lambda x: Int(x.value * 2))  # type: ignore [attr-defined]
        self._assert_pretty(double, "NativeFunction(name=times2)")

    def test_pretty_print_closure(self) -> None:
        closure = Closure({"a": Int(123)}, Function(Var("x"), Var("x")))
        self._assert_pretty(closure, "Closure(['a'], x -> x)")

    def test_pretty_print_record(self) -> None:
        self._assert_pretty(Record({"a": Int(1), "b": Int(2)}), "{a = 1, b = 2}")

    def test_pretty_print_access(self) -> None:
        record = Record({"a": Int(4)})
        self._assert_pretty(Access(record, Var("a")), "{a = 4} @ a")

    def test_pretty_print_variant(self) -> None:
        self._assert_pretty(Variant("x", Int(123)), "#x 123")

        # A function payload is expected to be wrapped in parentheses.
        self._assert_pretty(Variant("x", Function(Var("a"), Var("b"))), "#x (a -> b)")
4069
4070
class ServerCommandTests(unittest.TestCase):
    # End-to-end tests: setUp starts `server_command` on a background thread
    # bound to a free localhost port, and tearDown sends it a QUIT request.

    # Upper bound, in seconds, on how long setUp waits for the server to start
    # accepting connections before failing instead of hanging the suite.
    STARTUP_TIMEOUT_SECONDS = 10.0

    def _url(self, path: str) -> str:
        # Build an absolute URL on the server under test; `path` starts with "/".
        return f"http://{self.host}:{self.port}{path}"

    def setUp(self) -> None:
        import threading
        import time
        import os
        import socket
        import argparse
        from scrapscript import server_command

        # Find a random available port: bind to port 0, read back what the OS
        # picked, then release it so the server can bind it. Another process
        # could take the port in between; that race is inherent to this approach.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("127.0.0.1", 0))
            self.host, self.port = probe.getsockname()

        args = argparse.Namespace(
            directory=os.path.join(os.path.dirname(__file__), "examples"),
            host=self.host,
            port=self.port,
        )

        self.server_thread = threading.Thread(target=server_command, args=(args,))
        # Daemon thread: a wedged server must not keep the test process alive.
        self.server_thread.daemon = True
        self.server_thread.start()

        # Wait for the server to start, but never forever: give up when the
        # thread has already exited or the deadline has passed.
        deadline = time.monotonic() + self.STARTUP_TIMEOUT_SECONDS
        while True:
            try:
                with socket.create_connection((self.host, self.port), timeout=0.1):
                    break
            except (ConnectionRefusedError, socket.timeout):
                if not self.server_thread.is_alive():
                    self.fail("server thread exited before accepting connections")
                if time.monotonic() >= deadline:
                    self.fail(
                        f"server did not accept connections on {self.host}:{self.port} "
                        f"within {self.STARTUP_TIMEOUT_SECONDS} seconds"
                    )
                time.sleep(0.01)

    def tearDown(self) -> None:
        # Ask the server to stop; close the response so its socket is released.
        quit_request = urllib.request.Request(self._url("/"), method="QUIT")
        with urllib.request.urlopen(quit_request):
            pass

    def test_server_serves_scrap_by_path(self) -> None:
        with urllib.request.urlopen(self._url("/0_home/factorial")) as response:
            self.assertEqual(response.status, 200)

    def test_server_serves_scrap_by_hash(self) -> None:
        scrap_hash = "09242a8dfec0ed32eb9ddd5452f0082998712d35306fec2042bad8ac5b6e9580"
        with urllib.request.urlopen(self._url(f"/${scrap_hash}")) as response:
            self.assertEqual(response.status, 200)

    def test_server_fails_missing_scrap(self) -> None:
        # Imported explicitly rather than relying on urllib.request having
        # loaded urllib.error as a side effect.
        from urllib.error import HTTPError

        with self.assertRaises(HTTPError) as cm:
            urllib.request.urlopen(self._url("/foo"))
        # HTTPError wraps the open response; close it once the test is done.
        self.addCleanup(cm.exception.close)
        self.assertEqual(cm.exception.code, 404)
4119
4120
# When this file is executed as a script, hand control to unittest's
# command-line runner.
if __name__ == "__main__":
    unittest.main()