this repo has no description

Add row polymorphism to type system (#197)

- **WIP: test, but working row poly**
- **.**
- **.**
- **.**
- **Add comments**
- **.**
- **Add tests**
- **.**
- **.**

authored by bernsteinbear.com and committed by

GitHub 93940c22 6b01661d

+374 -35
+374 -35
scrapscript.py
··· 4004 4004 raise InferenceError(f"{self} is already resolved to {chain_end}") 4005 4005 chain_end.forwarded = other 4006 4006 4007 + def is_unbound(self) -> bool: 4008 + return self.forwarded is None 4009 + 4007 4010 4008 4011 @dataclasses.dataclass 4009 4012 class TyCon(MonoType): ··· 4020 4023 4021 4024 4022 4025 @dataclasses.dataclass 4026 + class TyEmptyRow(MonoType): 4027 + def __str__(self) -> str: 4028 + return "{}" 4029 + 4030 + 4031 + empty_row = TyEmptyRow() 4032 + 4033 + 4034 + @dataclasses.dataclass 4035 + class TyRow(MonoType): 4036 + fields: dict[str, MonoType] 4037 + rest: TyVar | TyRow | TyEmptyRow = dataclasses.field(default_factory=TyEmptyRow) 4038 + 4039 + def __post_init__(self) -> None: 4040 + if not self.fields and isinstance(self.rest, TyEmptyRow): 4041 + raise InferenceError("Empty row must have a rest type") 4042 + 4043 + def __str__(self) -> str: 4044 + flat, rest = row_flatten(self) 4045 + # sort to make tests deterministic 4046 + result = [f"{key}={val}" for key, val in sorted(flat.items())] 4047 + if isinstance(rest, TyVar): 4048 + result.append(f"...{rest}") 4049 + else: 4050 + assert isinstance(rest, TyEmptyRow) 4051 + return "{" + ", ".join(result) + "}" 4052 + 4053 + 4054 + def row_flatten(rec: MonoType) -> tuple[dict[str, MonoType], TyVar | TyEmptyRow]: 4055 + if isinstance(rec, TyVar): 4056 + rec = rec.find() 4057 + if isinstance(rec, TyVar): 4058 + return {}, rec 4059 + if isinstance(rec, TyRow): 4060 + flat, rest = row_flatten(rec.rest) 4061 + flat.update(rec.fields) 4062 + return flat, rest 4063 + if isinstance(rec, TyEmptyRow): 4064 + return {}, rec 4065 + raise InferenceError(f"Expected record type, got {type(rec)}") 4066 + 4067 + 4068 + @dataclasses.dataclass 4023 4069 class Forall: 4024 4070 tyvars: list[TyVar] 4025 4071 ty: MonoType ··· 4041 4087 def test_tycon_args(self) -> None: 4042 4088 self.assertEqual(str(TyCon("->", [IntType, IntType])), "(int->int)") 4043 4089 4090 + def test_tyrow_empty_closed(self) -> None: 4091 + self.assertEqual(str(TyEmptyRow()), "{}") 4092 + 4093 + def test_tyrow_empty_open(self) -> None: 4094 + self.assertEqual(str(TyRow({}, TyVar("a"))), "{...'a}") 4095 + 4096 + def test_tyrow_closed(self) -> None: 4097 + self.assertEqual(str(TyRow({"x": IntType, "y": StringType})), "{x=int, y=string}") 4098 + 4099 + def test_tyrow_open(self) -> None: 4100 + self.assertEqual(str(TyRow({"x": IntType, "y": StringType}, TyVar("a"))), "{x=int, y=string, ...'a}") 4101 + 4102 + def test_tyrow_chain(self) -> None: 4103 + inner = TyRow({"x": IntType}) 4104 + inner_var = TyVar("a") 4105 + inner_var.make_equal_to(inner) 4106 + outer = TyRow({"y": StringType}, inner_var) 4107 + self.assertEqual(str(outer), "{x=int, y=string}") 4108 + 4044 4109 def test_forall(self) -> None: 4045 4110 self.assertEqual(str(Forall([TyVar("a"), TyVar("b")], TyVar("a"))), "(forall 'a, 'b. 'a)") 4046 4111 ··· 4065 4130 return tyvar == ty 4066 4131 if isinstance(ty, TyCon): 4067 4132 return any(occurs_in(tyvar, arg) for arg in ty.args) 4133 + if isinstance(ty, TyEmptyRow): 4134 + return False 4135 + if isinstance(ty, TyRow): 4136 + return any(occurs_in(tyvar, val) for val in ty.fields.values()) or occurs_in(tyvar, ty.rest) 4068 4137 raise InferenceError(f"Unknown type: {ty}") 4069 4138 4070 4139 ··· 4088 4157 for l, r in zip(ty1.args, ty2.args): 4089 4158 unify_type(l, r) 4090 4159 return 4091 - raise InferenceError(f"Unexpected types {type(ty1)} and {type(ty2)}") 4160 + if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyEmptyRow): 4161 + return 4162 + if isinstance(ty1, TyRow) and isinstance(ty2, TyRow): 4163 + ty1_fields, ty1_rest = row_flatten(ty1) 4164 + ty2_fields, ty2_rest = row_flatten(ty2) 4165 + ty1_missing = {} 4166 + ty2_missing = {} 4167 + all_field_names = set(ty1_fields.keys()) | set(ty2_fields.keys()) 4168 + for key in sorted(all_field_names): # Sort for deterministic error messages 4169 + ty1_val = ty1_fields.get(key) 4170 + ty2_val = ty2_fields.get(key) 4171 + if ty1_val is not None and ty2_val is not None: 4172 + unify_type(ty1_val, ty2_val) 4173 + elif ty1_val is None: 4174 + assert ty2_val is not None 4175 + ty1_missing[key] = ty2_val 4176 + elif ty2_val is None: 4177 + assert ty1_val is not None 4178 + ty2_missing[key] = ty1_val 4179 + # In general, we want to: 4180 + # 1) Add missing fields from one row to the other row 4181 + # 2) "Keep the rows unified" by linking each row's rest to the other 4182 + # row's rest 4183 + if not ty1_missing and not ty2_missing: 4184 + # The rests are either both empty (rows were closed) or both 4185 + # unbound type variables (rows were open); unify the rest variables 4186 + unify_type(ty1_rest, ty2_rest) 4187 + return 4188 + if not ty1_missing: 4189 + # The first row has fields that the second row doesn't have; add 4190 + # them to the second row 4191 + unify_type(ty2_rest, TyRow(ty2_missing, ty1_rest)) 4192 + return 4193 + if not ty2_missing: 4194 + # The second row has fields that the first row doesn't have; add 4195 + # them to the first row 4196 + unify_type(ty1_rest, TyRow(ty1_missing, ty2_rest)) 4197 + return 4198 + # They each have fields the other lacks; create new rows sharing a rest 4199 + # and add the missing fields to each row 4200 + rest = fresh_tyvar() 4201 + unify_type(ty1_rest, TyRow(ty1_missing, rest)) 4202 + unify_type(ty2_rest, TyRow(ty2_missing, rest)) 4203 + return 4204 + if isinstance(ty1, TyRow) and isinstance(ty2, TyEmptyRow): 4205 + raise InferenceError(f"Unifying row {ty1} with empty row") 4206 + if isinstance(ty1, TyEmptyRow) and isinstance(ty2, TyRow): 4207 + raise InferenceError(f"Unifying empty row with row {ty2}") 4208 + raise InferenceError(f"Cannot unify {ty1} and {ty2}") 4092 4209 4093 4210 4094 4211 Context = typing.Mapping[str, Forall] ··· 4104 4221 return TyVar(result) 4105 4222 4106 4223 4107 - def collect_vars_in_pattern(pattern: Object) -> Context: 4108 - if isinstance(pattern, (Int, Float, String)): 4109 - return {} 4110 - if isinstance(pattern, Var): 4111 - return {pattern.name: Forall([], fresh_tyvar())} 4112 - if isinstance(pattern, List): 4113 - result: dict[str, Forall] = {} 4114 - for item in pattern.items: 4115 - if isinstance(item, Spread): 4116 - if item.name is not None: 4117 - result[item.name] = Forall([], list_type(fresh_tyvar())) 4118 - break 4119 - result.update(collect_vars_in_pattern(item)) 4120 - return result 4121 - raise InferenceError(f"Unexpected type {type(pattern)}") 4122 - 4123 - 4124 4224 IntType = TyCon("int", []) 4125 4225 StringType = TyCon("string", []) 4126 4226 FloatType = TyCon("float", []) ··· 4136 4236 return subst.get(ty.name, ty) 4137 4237 if isinstance(ty, TyCon): 4138 4238 return TyCon(ty.name, [apply_ty(arg, subst) for arg in ty.args]) 4239 + if isinstance(ty, TyEmptyRow): 4240 + return ty 4241 + if isinstance(ty, TyRow): 4242 + rest = apply_ty(ty.rest, subst) 4243 + assert isinstance(rest, (TyVar, TyRow, TyEmptyRow)) 4244 + return TyRow({key: apply_ty(val, subst) for key, val in ty.fields.items()}, rest) 4139 4245 raise InferenceError(f"Unknown type: {ty}") 4140 4246 4141 4247 ··· 4149 4255 return {ty.name} 4150 4256 if isinstance(ty, TyCon): 4151 4257 return set().union(*map(ftv_ty, ty.args)) 4258 + if isinstance(ty, TyEmptyRow): 4259 + return set() 4260 + if isinstance(ty, TyRow): 4261 + return set().union(*map(ftv_ty, ty.fields.values()), ftv_ty(ty.rest)) 4152 4262 raise InferenceError(f"Unknown type: {ty}") 4153 4263 4154 4264 ··· 4173 4283 return recursive_find(found) 4174 4284 if isinstance(ty, TyCon): 4175 4285 return TyCon(ty.name, [recursive_find(arg) for arg in ty.args]) 4286 + if isinstance(ty, TyEmptyRow): 4287 + return ty 4288 + if isinstance(ty, TyRow): 4289 + rest = recursive_find(ty.rest) 4290 + assert isinstance(rest, (TyVar, TyRow, TyEmptyRow)) 4291 + return TyRow({name: recursive_find(ty) for name, ty in ty.fields.items()}, rest) 4176 4292 raise InferenceError(type(ty)) 4177 4293 4178 4294 ··· 4188 4304 return ty 4189 4305 4190 4306 4307 + def infer_common(expr: Object) -> MonoType: 4308 + if isinstance(expr, Int): 4309 + return set_type(expr, IntType) 4310 + if isinstance(expr, Float): 4311 + return set_type(expr, FloatType) 4312 + if isinstance(expr, Bytes): 4313 + return set_type(expr, BytesType) 4314 + if isinstance(expr, Hole): 4315 + return set_type(expr, HoleType) 4316 + if isinstance(expr, String): 4317 + return set_type(expr, StringType) 4318 + raise InferenceError(f"{type(expr)} can't be simply inferred") 4319 + 4320 + 4321 + def infer_pattern_type(pattern: Object, ctx: Context) -> MonoType: 4322 + assert isinstance(ctx, dict) 4323 + if isinstance(pattern, (Int, Float, Bytes, Hole, String)): 4324 + return infer_common(pattern) 4325 + if isinstance(pattern, Var): 4326 + result = fresh_tyvar() 4327 + ctx[pattern.name] = Forall([], result) 4328 + return set_type(pattern, result) 4329 + if isinstance(pattern, List): 4330 + list_item_ty = fresh_tyvar() 4331 + result_ty = list_type(list_item_ty) 4332 + for item in pattern.items: 4333 + if isinstance(item, Spread): 4334 + if item.name is not None: 4335 + ctx[item.name] = Forall([], result_ty) 4336 + break 4337 + item_ty = infer_pattern_type(item, ctx) 4338 + unify_type(list_item_ty, item_ty) 4339 + return set_type(pattern, result_ty) 4340 + if isinstance(pattern, Record): 4341 + fields = {} 4342 + rest: TyVar | TyRow | TyEmptyRow = empty_row # Default closed row 4343 + for key, value in pattern.data.items(): 4344 + if isinstance(value, Spread): 4345 + # Open row 4346 + rest = fresh_tyvar() 4347 + if value.name is not None: 4348 + ctx[value.name] = Forall([], rest) 4349 + break 4350 + fields[key] = infer_pattern_type(value, ctx) 4351 + return set_type(pattern, TyRow(fields, rest)) 4352 + raise InferenceError(f"{type(pattern)} isn't allowed in a pattern") 4353 + 4354 + 4191 4355 def infer_type(expr: Object, ctx: Context) -> MonoType: 4356 + if isinstance(expr, (Int, Float, Bytes, Hole, String)): 4357 + return infer_common(expr) 4192 4358 if isinstance(expr, Var): 4193 4359 scheme = ctx.get(expr.name) 4194 4360 if scheme is None: 4195 4361 raise InferenceError(f"Unbound variable {expr.name}") 4196 4362 return set_type(expr, instantiate(scheme)) 4197 - if isinstance(expr, Int): 4198 - return set_type(expr, IntType) 4199 - if isinstance(expr, Float): 4200 - return set_type(expr, FloatType) 4201 - if isinstance(expr, String): 4202 - return set_type(expr, StringType) 4203 4363 if isinstance(expr, Function): 4204 4364 arg_tyvar = fresh_tyvar() 4205 4365 assert isinstance(expr.arg, Var) ··· 4226 4386 if isinstance(expr, List): 4227 4387 list_item_ty = fresh_tyvar() 4228 4388 for item in expr.items: 4229 - if isinstance(item, Spread): 4230 - break 4389 + assert not isinstance(item, Spread), "Spread can only occur in list match (for now)" 4231 4390 item_ty = infer_type(item, ctx) 4232 4391 unify_type(list_item_ty, item_ty) 4233 4392 return set_type(expr, list_type(list_item_ty)) 4234 4393 if isinstance(expr, MatchCase): 4235 - pattern_ctx = collect_vars_in_pattern(expr.pattern) 4236 - body_ctx = {**ctx, **pattern_ctx} 4237 - pattern_ty = infer_type(expr.pattern, body_ctx) 4238 - body_ty = infer_type(expr.body, body_ctx) 4394 + pattern_ctx: Context = {} 4395 + pattern_ty = infer_pattern_type(expr.pattern, pattern_ctx) 4396 + body_ty = infer_type(expr.body, {**ctx, **pattern_ctx}) 4239 4397 return set_type(expr, func_type(pattern_ty, body_ty)) 4240 4398 if isinstance(expr, Apply): 4241 4399 func_ty = infer_type(expr.func, ctx) ··· 4249 4407 case_ty = infer_type(case, ctx) 4250 4408 unify_type(result, case_ty) 4251 4409 return set_type(expr, result) 4252 - if isinstance(expr, Bytes): 4253 - return set_type(expr, BytesType) 4254 - if isinstance(expr, Hole): 4255 - return set_type(expr, HoleType) 4410 + if isinstance(expr, Record): 4411 + fields = {} 4412 + rest: TyVar | TyRow | TyEmptyRow = empty_row 4413 + for key, value in expr.data.items(): 4414 + assert not isinstance(value, Spread), "Spread can only occur in record match (for now)" 4415 + fields[key] = infer_type(value, ctx) 4416 + return set_type(expr, TyRow(fields, rest)) 4417 + if isinstance(expr, Access): 4418 + obj_ty = infer_type(expr.obj, ctx) 4419 + value_ty = fresh_tyvar() 4420 + assert isinstance(expr.at, Var) 4421 + # "has field" constraint in the form of an open row 4422 + unify_type(obj_ty, TyRow({expr.at.name: value_ty}, fresh_tyvar())) 4423 + return value_ty 4256 4424 raise InferenceError(f"Unexpected type {type(expr)}") 4257 4425 4258 4426 ··· 4315 4483 with self.assertRaisesRegex(InferenceError, "Occurs check failed"): 4316 4484 unify_type(l, r) 4317 4485 4486 + def test_unify_empty_row(self) -> None: 4487 + unify_type(TyEmptyRow(), TyEmptyRow()) 4488 + 4489 + def test_unify_empty_row_open(self) -> None: 4490 + l = TyRow({}, TyVar("a")) 4491 + r = TyRow({}, TyVar("b")) 4492 + unify_type(l, r) 4493 + self.assertIs(l.rest.find(), r.rest.find()) 4494 + 4495 + def test_unify_row_unifies_fields(self) -> None: 4496 + a = TyVar("a") 4497 + b = TyVar("b") 4498 + l = TyRow({"x": a}) 4499 + r = TyRow({"x": b}) 4500 + unify_type(l, r) 4501 + self.assertIs(a.find(), b.find()) 4502 + 4503 + def test_unify_empty_right(self) -> None: 4504 + l = TyRow({"x": IntType}) 4505 + r = TyEmptyRow() 4506 + with self.assertRaisesRegex(InferenceError, "Unifying row {x=int} with empty row"): 4507 + unify_type(l, r) 4508 + 4509 + def test_unify_empty_left(self) -> None: 4510 + l = TyEmptyRow() 4511 + r = TyRow({"x": IntType}) 4512 + with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {x=int}"): 4513 + unify_type(l, r) 4514 + 4515 + def test_unify_missing_closed(self) -> None: 4516 + l = TyRow({"x": IntType}) 4517 + r = TyRow({"y": IntType}) 4518 + with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {y=int, ...'t0}"): 4519 + unify_type(l, r) 4520 + 4521 + def test_unify_left_missing_open(self) -> None: 4522 + l = TyRow({}, TyVar("r0")) 4523 + r = TyRow({"y": IntType}, TyVar("r1")) 4524 + unify_type(l, r) 4525 + self.assertTyEqual(l.rest, TyRow({"y": IntType}, TyVar("r1"))) 4526 + assert isinstance(r.rest, TyVar) 4527 + self.assertTrue(r.rest.is_unbound()) 4528 + 4529 + def test_unify_right_missing_open(self) -> None: 4530 + l = TyRow({"x": IntType}, TyVar("r0")) 4531 + r = TyRow({}, TyVar("r1")) 4532 + unify_type(l, r) 4533 + assert isinstance(l.rest, TyVar) 4534 + self.assertTrue(l.rest.is_unbound()) 4535 + self.assertTyEqual(r.rest, TyRow({"x": IntType}, TyVar("r0"))) 4536 + 4537 + def test_unify_both_missing_open(self) -> None: 4538 + l = TyRow({"x": IntType}, TyVar("r0")) 4539 + r = TyRow({"y": IntType}, TyVar("r1")) 4540 + unify_type(l, r) 4541 + self.assertTyEqual(l.rest, TyRow({"y": IntType}, TyVar("t0"))) 4542 + self.assertTyEqual(r.rest, TyRow({"x": IntType}, TyVar("t0"))) 4543 + 4318 4544 def test_minimize_tyvar(self) -> None: 4319 4545 ty = fresh_tyvar() 4320 4546 self.assertEqual(minimize(ty), TyVar("a")) ··· 4340 4566 self.fail(f"Type mismatch: {l} != {r}") 4341 4567 for l_arg, r_arg in zip(l.args, r.args): 4342 4568 self.assertTyEqual(l_arg, r_arg) 4569 + return True 4570 + if isinstance(l, TyEmptyRow) and isinstance(r, TyEmptyRow): 4571 + return True 4572 + if isinstance(l, TyRow) and isinstance(r, TyRow): 4573 + l_keys = set(l.fields.keys()) 4574 + r_keys = set(r.fields.keys()) 4575 + if l_keys != r_keys: 4576 + self.fail(f"Type mismatch: {l} != {r}") 4577 + for key in l_keys: 4578 + self.assertTyEqual(l.fields[key], r.fields[key]) 4579 + self.assertTyEqual(l.rest, r.rest) 4343 4580 return True 4344 4581 self.fail(f"Type mismatch: {l} != {r}") 4345 4582 ··· 4521 4758 """) 4522 4759 ) 4523 4760 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4524 - self.assertTyEqual(ty, func_type(list_type(TyVar("t9")), IntType)) 4761 + self.assertTyEqual(ty, func_type(list_type(TyVar("t8")), IntType)) 4525 4762 4526 4763 def test_match_list_to_list(self) -> None: 4527 4764 expr = parse(tokenize("| [] -> [] | x -> x")) ··· 4533 4770 ty = infer_type(expr, {}) 4534 4771 self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), TyVar("t4"))) 4535 4772 4773 + def test_match_list_spread_rest(self) -> None: 4774 + expr = parse(tokenize("tail . tail = | [x, ...xs] -> xs")) 4775 + ty = infer_type(expr, {}) 4776 + self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), list_type(TyVar("t4")))) 4777 + 4536 4778 def test_match_list_spread_named(self) -> None: 4537 4779 expr = parse(tokenize("sum . sum = | [] -> 0 | [x, ...xs] -> x + sum xs")) 4538 4780 ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) ··· 4572 4814 expr = parse(tokenize("[1] +< 2")) 4573 4815 ty = infer_type(expr, OP_ENV) 4574 4816 self.assertTyEqual(ty, list_type(IntType)) 4817 + 4818 + def test_record(self) -> None: 4819 + expr = Record({"a": Int(1), "b": String("hello")}) 4820 + ty = infer_type(expr, {}) 4821 + self.assertTyEqual(ty, TyRow({"a": IntType, "b": StringType})) 4822 + 4823 + def test_match_record(self) -> None: 4824 + expr = MatchFunction( 4825 + [ 4826 + MatchCase( 4827 + Record({"x": Var("x")}), 4828 + Var("x"), 4829 + ) 4830 + ] 4831 + ) 4832 + ty = infer_type(expr, {}) 4833 + self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}), TyVar("t1"))) 4834 + 4835 + def test_access_poly(self) -> None: 4836 + expr = Function(Var("r"), Access(Var("r"), Var("x"))) 4837 + ty = infer_type(expr, {}) 4838 + self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t1"))) 4839 + 4840 + def test_apply_row(self) -> None: 4841 + row0 = Record({"x": Int(1)}) 4842 + row1 = Record({"x": Int(1), "y": Int(2)}) 4843 + scheme = Forall([], func_type(TyRow({"x": IntType}, TyVar("a")), IntType)) 4844 + ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme}) 4845 + self.assertTyEqual(ty0, IntType) 4846 + with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {y=int}"): 4847 + infer_type(Apply(Var("f"), row1), {"f": scheme}) 4848 + 4849 + def test_apply_row_polymorphic(self) -> None: 4850 + row0 = Record({"x": Int(1)}) 4851 + row1 = Record({"x": Int(1), "y": Int(2)}) 4852 + row2 = Record({"x": Int(1), "y": Int(2), "z": Int(3)}) 4853 + scheme = Forall([TyVar("a")], func_type(TyRow({"x": IntType}, TyVar("a")), IntType)) 4854 + ty0 = infer_type(Apply(Var("f"), row0), {"f": scheme}) 4855 + self.assertTyEqual(ty0, IntType) 4856 + ty1 = infer_type(Apply(Var("f"), row1), {"f": scheme}) 4857 + self.assertTyEqual(ty1, IntType) 4858 + ty2 = infer_type(Apply(Var("f"), row2), {"f": scheme}) 4859 + self.assertTyEqual(ty2, IntType) 4860 + 4861 + def test_example_rec_access(self) -> None: 4862 + expr = parse(tokenize('rec@a . rec = { a = 1, b = "x" }')) 4863 + ty = infer_type(expr, {}) 4864 + self.assertTyEqual(ty, IntType) 4865 + 4866 + def test_example_rec_access_poly(self) -> None: 4867 + expr = parse( 4868 + tokenize(""" 4869 + (get_x {x=1}) + (get_x {x=2,y=3}) 4870 + . get_x = | { x=x, ... } -> x 4871 + """) 4872 + ) 4873 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4874 + self.assertTyEqual(ty, IntType) 4875 + 4876 + def test_example_rec_access_poly_named_bug(self) -> None: 4877 + expr = parse( 4878 + tokenize(""" 4879 + (filter_x {x=1, y=2}) + 3 4880 + . filter_x = | { x=x, ...xs } -> xs 4881 + """) 4882 + ) 4883 + with self.assertRaisesRegex(InferenceError, "Cannot unify int and {y=int}"): 4884 + infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4885 + 4886 + def test_example_rec_access_rest(self) -> None: 4887 + expr = parse( 4888 + tokenize(""" 4889 + | { x=x, ...xs } -> xs 4890 + """) 4891 + ) 4892 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4893 + self.assertTyEqual(ty, func_type(TyRow({"x": TyVar("t1")}, TyVar("t2")), TyVar("t2"))) 4894 + 4895 + def test_example_match_rec_access_rest(self) -> None: 4896 + expr = parse( 4897 + tokenize(""" 4898 + filter_x {x=1, y=2} 4899 + . filter_x = | { x=x, ...xs } -> xs 4900 + """) 4901 + ) 4902 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4903 + self.assertTyEqual(ty, TyRow({"y": IntType})) 4904 + 4905 + def test_example_rec_access_poly_named(self) -> None: 4906 + expr = parse( 4907 + tokenize(""" 4908 + [(filter_x {x=1, y=2}), (filter_x {x=2, y=3, z=4})] 4909 + . filter_x = | { x=x, ...xs } -> xs 4910 + """) 4911 + ) 4912 + with self.assertRaisesRegex(InferenceError, "Unifying empty row with row {z=int}"): 4913 + infer_type(expr, {}) 4575 4914 4576 4915 4577 4916 class SerializerTests(unittest.TestCase):