this repo has no description

Add Damas-Hindley-Milner type inference (#195)

Still TODO:

* Record typing
* Row polymorphism
* Variant typing

Co-authored-by: @rdck

authored by bernsteinbear.com and committed by

GitHub f6575eb4 6b9fb033

+658 -17
+24 -13
compiler.py
··· 27 27 Variant, 28 28 Where, 29 29 free_in, 30 - parse, 31 - tokenize, 30 + type_of, 31 + IntType, 32 + StringType, 33 + parse, # needed for /compilerepl 34 + tokenize, # needed for /compilerepl 32 35 ) 33 36 34 37 Env = Dict[str, str] ··· 128 131 self._emit("abort();") 129 132 self._emit("}") 130 133 134 + def _guard_int(self, exp: Object, c_name: str) -> str: 135 + if type_of(exp) != IntType: 136 + self._guard(f"is_num({c_name})") 137 + 138 + def _guard_str(self, exp: Object, c_name: str) -> str: 139 + if type_of(exp) != StringType: 140 + self._guard(f"is_string({c_name})") 141 + 131 142 def _mktemp(self, exp: str) -> str: 132 143 temp = self.gensym() 133 144 return self._handle(temp, exp) ··· 179 190 return self.make_closure(env, fn) 180 191 181 192 def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> Env: 193 + # TODO(max): Give `arg` an AST node so we can track its inferred type 194 + # and make use of that in pattern matching 182 195 if isinstance(pattern, Int): 183 196 self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}") 184 197 return {} ··· 365 378 right = self.compile(env, exp.right) 366 379 if exp.op == BinopKind.ADD: 367 380 self._debug("collect(heap);") 368 - self._guard(f"is_num({left})") 369 - self._guard(f"is_num({right})") 381 + self._guard_int(exp.left, left) 382 + self._guard_int(exp.right, right) 370 383 return self._mktemp(f"num_add({left}, {right})") 371 384 if exp.op == BinopKind.MUL: 372 385 self._debug("collect(heap);") 373 - self._guard(f"is_num({left})") 374 - self._guard(f"is_num({right})") 386 + self._guard_int(exp.left, left) 387 + self._guard_int(exp.right, right) 375 388 return self._mktemp(f"num_mul({left}, {right})") 376 389 if exp.op == BinopKind.SUB: 377 390 self._debug("collect(heap);") 378 - self._guard(f"is_num({left})") 379 - self._guard(f"is_num({right})") 391 + self._guard_int(exp.left, left) 392 + self._guard_int(exp.right, right) 380 393 return self._mktemp(f"num_sub({left}, {right})") 381 394 if exp.op == BinopKind.LIST_CONS: 382 395 self._debug("collect(heap);") 383 396 return self._mktemp(f"list_cons({left}, {right})") 384 397 if exp.op == BinopKind.STRING_CONCAT: 385 398 self._debug("collect(heap);") 386 - self._guard(f"is_string({left})") 387 - self._guard(f"is_string({right})") 399 + self._guard_str(exp.left, left) 400 + self._guard_str(exp.right, right) 388 401 return self._mktemp(f"string_concat({left}, {right})") 389 402 raise NotImplementedError(f"binop {exp.op}") 390 403 if isinstance(exp, Where): ··· 433 446 raise NotImplementedError(f"exp {type(exp)} {exp}") 434 447 435 448 436 - def compile_to_string(source: str, debug: bool) -> str: 437 - program = parse(tokenize(source)) 438 - 449 + def compile_to_string(program: Object, debug: bool) -> str: 439 450 main_fn = CompiledFunction("scrap_main", params=[]) 440 451 compiler = Compiler(main_fn) 441 452 compiler.debug = debug
+3 -2
compiler_tests.py
··· 2 2 import unittest 3 3 import subprocess 4 4 5 - from scrapscript import env_get_split, discover_cflags 5 + from scrapscript import env_get_split, discover_cflags, parse, tokenize 6 6 from compiler import compile_to_string 7 7 8 8 ··· 15 15 cc = env_get_split("CC", shlex.split(sysconfig.get_config_var("CC"))) 16 16 cflags = discover_cflags(cc, debug) 17 17 cflags += [f"-DMEMORY_SIZE={memory}"] 18 - c_code = compile_to_string(source, debug) 18 + program = parse(tokenize(source)) 19 + c_code = compile_to_string(program, debug) 19 20 with tempfile.NamedTemporaryFile(mode="w", suffix=".c", delete=False) as c_file: 20 21 c_file.write(c_code) 21 22 # The platform is in the same directory as this file
+3 -1
compilerepl.html
··· 68 68 async function sendRequest(exp) { 69 69 const compiler = document.compiler; 70 70 try { 71 - return {result: compiler.compile_to_string(exp, false), ok: true}; 71 + const tokens = compiler.tokenize(exp); 72 + const ast = compiler.parse(tokens); 73 + return {result: compiler.compile_to_string(ast, false), ok: true}; 72 74 } catch (e) { 73 75 return {text: () => e.toString(), ok: false}; 74 76 }
+628 -1
scrapscript.py
··· 1 1 #!/usr/bin/env python3.10 2 + from __future__ import annotations 2 3 import argparse 3 4 import base64 4 5 import code ··· 3970 3971 ) 3971 3972 3972 3973 3974 + class InferenceError(Exception): 3975 + pass 3976 + 3977 + 3978 + @dataclasses.dataclass 3979 + class MonoType: 3980 + def find(self) -> MonoType: 3981 + return self 3982 + 3983 + 3984 + @dataclasses.dataclass 3985 + class TyVar(MonoType): 3986 + forwarded: MonoType | None = dataclasses.field(init=False, default=None) 3987 + name: str 3988 + 3989 + def find(self) -> MonoType: 3990 + result: MonoType = self 3991 + while isinstance(result, TyVar): 3992 + it = result.forwarded 3993 + if it is None: 3994 + return result 3995 + result = it 3996 + return result 3997 + 3998 + def __str__(self) -> str: 3999 + return f"'{self.name}" 4000 + 4001 + def make_equal_to(self, other: MonoType) -> None: 4002 + chain_end = self.find() 4003 + if not isinstance(chain_end, TyVar): 4004 + raise InferenceError(f"{self} is already resolved to {chain_end}") 4005 + chain_end.forwarded = other 4006 + 4007 + 4008 + @dataclasses.dataclass 4009 + class TyCon(MonoType): 4010 + name: str 4011 + args: list[MonoType] 4012 + 4013 + def __str__(self) -> str: 4014 + # TODO(max): Precedence pretty-print type constructors 4015 + if not self.args: 4016 + return self.name 4017 + if len(self.args) == 1: 4018 + return f"({self.args[0]} {self.name})" 4019 + return f"({self.name.join(map(str, self.args))})" 4020 + 4021 + 4022 + @dataclasses.dataclass 4023 + class Forall: 4024 + tyvars: list[TyVar] 4025 + ty: MonoType 4026 + 4027 + def __str__(self) -> str: 4028 + return f"(forall {', '.join(map(str, self.tyvars))}. {self.ty})" 4029 + 4030 + 4031 + class TypeStrTests(unittest.TestCase): 4032 + def test_tyvar(self) -> None: 4033 + self.assertEqual(str(TyVar("a")), "'a") 4034 + 4035 + def test_tycon(self) -> None: 4036 + self.assertEqual(str(TyCon("int", [])), "int") 4037 + 4038 + def test_tycon_one_arg(self) -> None: 4039 + self.assertEqual(str(TyCon("list", [IntType])), "(int list)") 4040 + 4041 + def test_tycon_args(self) -> None: 4042 + self.assertEqual(str(TyCon("->", [IntType, IntType])), "(int->int)") 4043 + 4044 + def test_forall(self) -> None: 4045 + self.assertEqual(str(Forall([TyVar("a"), TyVar("b")], TyVar("a"))), "(forall 'a, 'b. 'a)") 4046 + 4047 + 4048 + def func_type(*args: MonoType) -> TyCon: 4049 + assert len(args) >= 2 4050 + if len(args) == 2: 4051 + return TyCon("->", list(args)) 4052 + return TyCon("->", [args[0], func_type(*args[1:])]) 4053 + 4054 + 4055 + def list_type(arg: MonoType) -> TyCon: 4056 + return TyCon("list", [arg]) 4057 + 4058 + 4059 + def unify_fail(ty1: MonoType, ty2: MonoType) -> None: 4060 + raise InferenceError(f"Unification failed for {ty1} and {ty2}") 4061 + 4062 + 4063 + def occurs_in(tyvar: TyVar, ty: MonoType) -> bool: 4064 + if isinstance(ty, TyVar): 4065 + return tyvar == ty 4066 + if isinstance(ty, TyCon): 4067 + return any(occurs_in(tyvar, arg) for arg in ty.args) 4068 + raise InferenceError(f"Unknown type: {ty}") 4069 + 4070 + 4071 + def unify_type(ty1: MonoType, ty2: MonoType) -> None: 4072 + ty1 = ty1.find() 4073 + ty2 = ty2.find() 4074 + if isinstance(ty1, TyVar): 4075 + if occurs_in(ty1, ty2): 4076 + raise InferenceError(f"Occurs check failed for {ty1} and {ty2}") 4077 + ty1.make_equal_to(ty2) 4078 + return 4079 + if isinstance(ty2, TyVar): # Mirror 4080 + return unify_type(ty2, ty1) 4081 + if isinstance(ty1, TyCon) and isinstance(ty2, TyCon): 4082 + if ty1.name != ty2.name: 4083 + unify_fail(ty1, ty2) 4084 + return 4085 + if len(ty1.args) != len(ty2.args): 4086 + unify_fail(ty1, ty2) 4087 + return 4088 + for l, r in zip(ty1.args, ty2.args): 4089 + unify_type(l, r) 4090 + return 4091 + raise InferenceError(f"Unexpected types {type(ty1)} and {type(ty2)}") 4092 + 4093 + 4094 + Context = typing.Mapping[str, Forall] 4095 + 4096 + 4097 + fresh_var_counter = 0 4098 + 4099 + 4100 + def fresh_tyvar(prefix: str = "t") -> TyVar: 4101 + global fresh_var_counter 4102 + result = f"{prefix}{fresh_var_counter}" 4103 + fresh_var_counter += 1 4104 + return TyVar(result) 4105 + 4106 + 4107 + def collect_vars_in_pattern(pattern: Object) -> Context: 4108 + if isinstance(pattern, (Int, Float, String)): 4109 + return {} 4110 + if isinstance(pattern, Var): 4111 + return {pattern.name: Forall([], fresh_tyvar())} 4112 + if isinstance(pattern, List): 4113 + result: dict[str, Forall] = {} 4114 + for item in pattern.items: 4115 + if isinstance(item, Spread): 4116 + if item.name is not None: 4117 + result[item.name] = Forall([], list_type(fresh_tyvar())) 4118 + break 4119 + result.update(collect_vars_in_pattern(item)) 4120 + return result 4121 + raise InferenceError(f"Unexpected type {type(pattern)}") 4122 + 4123 + 4124 + IntType = TyCon("int", []) 4125 + StringType = TyCon("string", []) 4126 + FloatType = TyCon("float", []) 4127 + BytesType = TyCon("bytes", []) 4128 + HoleType = TyCon("hole", []) 4129 + 4130 + 4131 + Subst = typing.Mapping[str, MonoType] 4132 + 4133 + 4134 + def apply_ty(ty: MonoType, subst: Subst) -> MonoType: 4135 + if isinstance(ty, TyVar): 4136 + return subst.get(ty.name, ty) 4137 + if isinstance(ty, TyCon): 4138 + return TyCon(ty.name, [apply_ty(arg, subst) for arg in ty.args]) 4139 + raise InferenceError(f"Unknown type: {ty}") 4140 + 4141 + 4142 + def instantiate(scheme: Forall) -> MonoType: 4143 + fresh = {tyvar.name: fresh_tyvar() for tyvar in scheme.tyvars} 4144 + return apply_ty(scheme.ty, fresh) 4145 + 4146 + 4147 + def ftv_ty(ty: MonoType) -> set[str]: 4148 + if isinstance(ty, TyVar): 4149 + return {ty.name} 4150 + if isinstance(ty, TyCon): 4151 + return set().union(*map(ftv_ty, ty.args)) 4152 + raise InferenceError(f"Unknown type: {ty}") 4153 + 4154 + 4155 + def generalize(ty: MonoType, ctx: Context) -> Forall: 4156 + def ftv_scheme(ty: Forall) -> set[str]: 4157 + return ftv_ty(ty.ty) - set(tyvar.name for tyvar in ty.tyvars) 4158 + 4159 + def ftv_ctx(ctx: Context) -> set[str]: 4160 + return set().union(*(ftv_scheme(scheme) for scheme in ctx.values())) 4161 + 4162 + # TODO(max): Freshen? 4163 + # TODO(max): Test with free type variable in the context 4164 + tyvars = ftv_ty(ty) - ftv_ctx(ctx) 4165 + return Forall([TyVar(name) for name in sorted(tyvars)], ty) 4166 + 4167 + 4168 + def recursive_find(ty: MonoType) -> MonoType: 4169 + if isinstance(ty, TyVar): 4170 + found = ty.find() 4171 + if ty is found: 4172 + return found 4173 + return recursive_find(found) 4174 + if isinstance(ty, TyCon): 4175 + return TyCon(ty.name, [recursive_find(arg) for arg in ty.args]) 4176 + raise InferenceError(type(ty)) 4177 + 4178 + 4179 + def type_of(expr: Object) -> MonoType: 4180 + ty = getattr(expr, "inferred_type", None) 4181 + if ty is not None: 4182 + return recursive_find(ty) 4183 + return set_type(expr, fresh_tyvar()) 4184 + 4185 + 4186 + def set_type(expr: Object, ty: MonoType) -> MonoType: 4187 + object.__setattr__(expr, "inferred_type", ty) 4188 + return ty 4189 + 4190 + 4191 + def infer_type(expr: Object, ctx: Context) -> MonoType: 4192 + if isinstance(expr, Var): 4193 + scheme = ctx.get(expr.name) 4194 + if scheme is None: 4195 + raise InferenceError(f"Unbound variable {expr.name}") 4196 + return set_type(expr, instantiate(scheme)) 4197 + if isinstance(expr, Int): 4198 + return set_type(expr, IntType) 4199 + if isinstance(expr, Float): 4200 + return set_type(expr, FloatType) 4201 + if isinstance(expr, String): 4202 + return set_type(expr, StringType) 4203 + if isinstance(expr, Function): 4204 + arg_tyvar = fresh_tyvar() 4205 + assert isinstance(expr.arg, Var) 4206 + body_ctx = {**ctx, expr.arg.name: Forall([], arg_tyvar)} 4207 + body_ty = infer_type(expr.body, body_ctx) 4208 + return set_type(expr, func_type(arg_tyvar, body_ty)) 4209 + if isinstance(expr, Binop): 4210 + left, right = expr.left, expr.right 4211 + op = Var(BinopKind.to_str(expr.op)) 4212 + return set_type(expr, infer_type(Apply(Apply(op, left), right), ctx)) 4213 + if isinstance(expr, Where): 4214 + assert isinstance(expr.binding, Assign) 4215 + name, value, body = expr.binding.name.name, expr.binding.value, expr.body 4216 + if isinstance(value, (Function, MatchFunction)): 4217 + # Letrec 4218 + func_ty: MonoType = fresh_tyvar() 4219 + value_ty = infer_type(value, {**ctx, name: Forall([], func_ty)}) 4220 + else: 4221 + # Let 4222 + value_ty = infer_type(value, ctx) 4223 + value_scheme = generalize(recursive_find(value_ty), ctx) 4224 + body_ty = infer_type(body, {**ctx, name: value_scheme}) 4225 + return set_type(expr, body_ty) 4226 + if isinstance(expr, List): 4227 + list_item_ty = fresh_tyvar() 4228 + for item in expr.items: 4229 + if isinstance(item, Spread): 4230 + break 4231 + item_ty = infer_type(item, ctx) 4232 + unify_type(list_item_ty, item_ty) 4233 + return set_type(expr, list_type(list_item_ty)) 4234 + if isinstance(expr, MatchCase): 4235 + pattern_ctx = collect_vars_in_pattern(expr.pattern) 4236 + body_ctx = {**ctx, **pattern_ctx} 4237 + pattern_ty = infer_type(expr.pattern, body_ctx) 4238 + body_ty = infer_type(expr.body, body_ctx) 4239 + return set_type(expr, func_type(pattern_ty, body_ty)) 4240 + if isinstance(expr, Apply): 4241 + func_ty = infer_type(expr.func, ctx) 4242 + arg_ty = infer_type(expr.arg, ctx) 4243 + result = fresh_tyvar() 4244 + unify_type(func_ty, func_type(arg_ty, result)) 4245 + return set_type(expr, result) 4246 + if isinstance(expr, MatchFunction): 4247 + result = fresh_tyvar() 4248 + for case in expr.cases: 4249 + case_ty = infer_type(case, ctx) 4250 + unify_type(result, case_ty) 4251 + return set_type(expr, result) 4252 + if isinstance(expr, Bytes): 4253 + return set_type(expr, BytesType) 4254 + if isinstance(expr, Hole): 4255 + return set_type(expr, HoleType) 4256 + raise InferenceError(f"Unexpected type {type(expr)}") 4257 + 4258 + 4259 + def minimize(ty: MonoType) -> MonoType: 4260 + letters = iter("abcdefghijklmnopqrstuvwxyz") 4261 + free = ftv_ty(ty) 4262 + subst = {ftv: TyVar(next(letters)) for ftv in sorted(free)} 4263 + return apply_ty(ty, subst) 4264 + 4265 + 4266 + class InferTypeTests(unittest.TestCase): 4267 + def setUp(self) -> None: 4268 + global fresh_var_counter 4269 + fresh_var_counter = 0 4270 + 4271 + def test_unify_tyvar_tyvar(self) -> None: 4272 + a = TyVar("a") 4273 + b = TyVar("b") 4274 + unify_type(a, b) 4275 + self.assertIs(a.find(), b.find()) 4276 + 4277 + def test_unify_tyvar_tycon(self) -> None: 4278 + a = TyVar("a") 4279 + unify_type(a, IntType) 4280 + self.assertIs(a.find(), IntType) 4281 + b = TyVar("b") 4282 + unify_type(b, IntType) 4283 + self.assertIs(b.find(), IntType) 4284 + 4285 + def test_unify_tycon_tycon_name_mismatch(self) -> None: 4286 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4287 + unify_type(IntType, StringType) 4288 + 4289 + def test_unify_tycon_tycon_arity_mismatch(self) -> None: 4290 + l = TyCon("x", [TyVar("a")]) 4291 + r = TyCon("x", []) 4292 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4293 + unify_type(l, r) 4294 + 4295 + def test_unify_tycon_tycon_unifies_arg(self) -> None: 4296 + a = TyVar("a") 4297 + b = TyVar("b") 4298 + l = TyCon("x", [a]) 4299 + r = TyCon("x", [b]) 4300 + unify_type(l, r) 4301 + self.assertIs(a.find(), b.find()) 4302 + 4303 + def test_unify_tycon_tycon_unifies_args(self) -> None: 4304 + a, b, c, d = map(TyVar, "abcd") 4305 + l = func_type(a, b) 4306 + r = func_type(c, d) 4307 + unify_type(l, r) 4308 + self.assertIs(a.find(), c.find()) 4309 + self.assertIs(b.find(), d.find()) 4310 + self.assertIsNot(a.find(), b.find()) 4311 + 4312 + def test_unify_recursive_fails(self) -> None: 4313 + l = TyVar("a") 4314 + r = TyCon("x", [TyVar("a")]) 4315 + with self.assertRaisesRegex(InferenceError, "Occurs check failed"): 4316 + unify_type(l, r) 4317 + 4318 + def test_minimize_tyvar(self) -> None: 4319 + ty = fresh_tyvar() 4320 + self.assertEqual(minimize(ty), TyVar("a")) 4321 + 4322 + def test_minimize_tycon(self) -> None: 4323 + ty = func_type(TyVar("t0"), TyVar("t1"), TyVar("t0")) 4324 + self.assertEqual(minimize(ty), func_type(TyVar("a"), TyVar("b"), TyVar("a"))) 4325 + 4326 + def infer(self, expr: Object, ctx: Context) -> MonoType: 4327 + return minimize(recursive_find(infer_type(expr, ctx))) 4328 + 4329 + def assertTyEqual(self, l: MonoType, r: MonoType) -> bool: 4330 + l = l.find() 4331 + r = r.find() 4332 + if isinstance(l, TyVar) and isinstance(r, TyVar): 4333 + if l != r: 4334 + self.fail(f"Type mismatch: {l} != {r}") 4335 + return True 4336 + if isinstance(l, TyCon) and isinstance(r, TyCon): 4337 + if l.name != r.name: 4338 + self.fail(f"Type mismatch: {l} != {r}") 4339 + if len(l.args) != len(r.args): 4340 + self.fail(f"Type mismatch: {l} != {r}") 4341 + for l_arg, r_arg in zip(l.args, r.args): 4342 + self.assertTyEqual(l_arg, r_arg) 4343 + return True 4344 + self.fail(f"Type mismatch: {l} != {r}") 4345 + 4346 + def test_unbound_var(self) -> None: 4347 + with self.assertRaisesRegex(InferenceError, "Unbound variable"): 4348 + self.infer(Var("a"), {}) 4349 + 4350 + def test_var_instantiates_scheme(self) -> None: 4351 + ty = self.infer(Var("a"), {"a": Forall([TyVar("b")], TyVar("b"))}) 4352 + self.assertTyEqual(ty, TyVar("a")) 4353 + 4354 + def test_int(self) -> None: 4355 + ty = self.infer(Int(123), {}) 4356 + self.assertTyEqual(ty, IntType) 4357 + 4358 + def test_float(self) -> None: 4359 + ty = self.infer(Float(1.0), {}) 4360 + self.assertTyEqual(ty, FloatType) 4361 + 4362 + def test_string(self) -> None: 4363 + ty = self.infer(String("abc"), {}) 4364 + self.assertTyEqual(ty, StringType) 4365 + 4366 + def test_function_returns_arg(self) -> None: 4367 + ty = self.infer(Function(Var("x"), Var("x")), {}) 4368 + self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 4369 + 4370 + def test_nested_function_outer(self) -> None: 4371 + ty = self.infer(Function(Var("x"), Function(Var("y"), Var("x"))), {}) 4372 + self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("a"))) 4373 + 4374 + def test_nested_function_inner(self) -> None: 4375 + ty = self.infer(Function(Var("x"), Function(Var("y"), Var("y"))), {}) 4376 + self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("b"), TyVar("b"))) 4377 + 4378 + def test_apply_id_int(self) -> None: 4379 + func = Function(Var("x"), Var("x")) 4380 + arg = Int(123) 4381 + ty = self.infer(Apply(func, arg), {}) 4382 + self.assertTyEqual(ty, IntType) 4383 + 4384 + def test_apply_two_arg_returns_function(self) -> None: 4385 + func = Function(Var("x"), Function(Var("y"), Var("x"))) 4386 + arg = Int(123) 4387 + ty = self.infer(Apply(func, arg), {}) 4388 + self.assertTyEqual(ty, func_type(TyVar("a"), IntType)) 4389 + 4390 + def test_binop_add_constrains_int(self) -> None: 4391 + expr = Binop(BinopKind.ADD, Var("x"), Var("y")) 4392 + ty = self.infer( 4393 + expr, 4394 + { 4395 + "x": Forall([], TyVar("a")), 4396 + "y": Forall([], TyVar("b")), 4397 + "+": Forall([], func_type(IntType, IntType, IntType)), 4398 + }, 4399 + ) 4400 + self.assertTyEqual(ty, IntType) 4401 + 4402 + def test_binop_add_function_constrains_int(self) -> None: 4403 + x = Var("x") 4404 + y = Var("y") 4405 + expr = Function(Var("x"), Function(Var("y"), Binop(BinopKind.ADD, x, y))) 4406 + ty = self.infer(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4407 + self.assertTyEqual(ty, func_type(IntType, IntType, IntType)) 4408 + self.assertTyEqual(type_of(x), IntType) 4409 + self.assertTyEqual(type_of(y), IntType) 4410 + 4411 + def test_let(self) -> None: 4412 + expr = Where(Var("f"), Assign(Var("f"), Function(Var("x"), Var("x")))) 4413 + ty = self.infer(expr, {}) 4414 + self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 4415 + 4416 + def test_apply_monotype_to_different_types_raises(self) -> None: 4417 + expr = Where( 4418 + Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))), 4419 + Assign(Var("y"), Apply(Var("f"), Float(123.0))), 4420 + ) 4421 + ctx = {"f": Forall([], func_type(TyVar("a"), TyVar("a")))} 4422 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4423 + self.infer(expr, ctx) 4424 + 4425 + def test_apply_polytype_to_different_types(self) -> None: 4426 + expr = Where( 4427 + Where(Var("x"), Assign(Var("x"), Apply(Var("f"), Int(123)))), 4428 + Assign(Var("y"), Apply(Var("f"), Float(123.0))), 4429 + ) 4430 + ty = self.infer(expr, {"f": Forall([TyVar("a")], func_type(TyVar("a"), TyVar("a")))}) 4431 + self.assertTyEqual(ty, IntType) 4432 + 4433 + def test_id(self) -> None: 4434 + expr = Function(Var("x"), Var("x")) 4435 + ty = self.infer(expr, {}) 4436 + self.assertTyEqual(ty, func_type(TyVar("a"), TyVar("a"))) 4437 + 4438 + def test_empty_list(self) -> None: 4439 + expr = List([]) 4440 + ty = infer_type(expr, {}) 4441 + self.assertTyEqual(ty, TyCon("list", [TyVar("t0")])) 4442 + 4443 + def test_list_int(self) -> None: 4444 + expr = List([Int(123)]) 4445 + ty = infer_type(expr, {}) 4446 + self.assertTyEqual(ty, TyCon("list", [IntType])) 4447 + 4448 + def test_list_mismatch(self) -> None: 4449 + expr = List([Int(123), Float(123.0)]) 4450 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4451 + infer_type(expr, {}) 4452 + 4453 + def test_recursive_fact(self) -> None: 4454 + expr = parse(tokenize("fact . fact = | 0 -> 1 | n -> n * fact (n-1)")) 4455 + ty = infer_type( 4456 + expr, 4457 + { 4458 + "*": Forall([], func_type(IntType, IntType, IntType)), 4459 + "-": Forall([], func_type(IntType, IntType, IntType)), 4460 + }, 4461 + ) 4462 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4463 + 4464 + def test_match_int_int(self) -> None: 4465 + expr = parse(tokenize("| 0 -> 1")) 4466 + ty = infer_type(expr, {}) 4467 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4468 + 4469 + def test_match_int_int_two_cases(self) -> None: 4470 + expr = parse(tokenize("| 0 -> 1 | 1 -> 2")) 4471 + ty = infer_type(expr, {}) 4472 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4473 + 4474 + def test_match_int_int_int_float(self) -> None: 4475 + expr = parse(tokenize("| 0 -> 1 | 1 -> 2.0")) 4476 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4477 + infer_type(expr, {}) 4478 + 4479 + def test_match_int_int_float_int(self) -> None: 4480 + expr = parse(tokenize("| 0 -> 1 | 1.0 -> 2")) 4481 + with self.assertRaisesRegex(InferenceError, "Unification failed"): 4482 + infer_type(expr, {}) 4483 + 4484 + def test_match_var(self) -> None: 4485 + expr = parse(tokenize("| x -> x + 1")) 4486 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4487 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4488 + 4489 + def test_match_int_var(self) -> None: 4490 + expr = parse(tokenize("| 0 -> 1 | x -> x")) 4491 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4492 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4493 + 4494 + def test_match_list_of_int(self) -> None: 4495 + expr = parse(tokenize("| [x] -> x + 1")) 4496 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4497 + self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 4498 + 4499 + def test_match_list_of_int_to_list(self) -> None: 4500 + expr = parse(tokenize("| [x] -> [x + 1]")) 4501 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4502 + self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType))) 4503 + 4504 + def test_match_list_of_int_to_int(self) -> None: 4505 + expr = parse(tokenize("| [] -> 0 | [x] -> 1 | [x, y] -> x+y")) 4506 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4507 + self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 4508 + 4509 + def test_recursive_var_is_unbound(self) -> None: 4510 + expr = parse(tokenize("a . a = a")) 4511 + with self.assertRaisesRegex(InferenceError, "Unbound variable"): 4512 + infer_type(expr, {}) 4513 + 4514 + def test_recursive(self) -> None: 4515 + expr = parse( 4516 + tokenize(""" 4517 + length 4518 + . length = 4519 + | [] -> 0 4520 + | [x, ...xs] -> 1 + length xs 4521 + """) 4522 + ) 4523 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4524 + self.assertTyEqual(ty, func_type(list_type(TyVar("t9")), IntType)) 4525 + 4526 + def test_match_list_to_list(self) -> None: 4527 + expr = parse(tokenize("| [] -> [] | x -> x")) 4528 + ty = infer_type(expr, {}) 4529 + self.assertTyEqual(ty, func_type(list_type(TyVar("t1")), list_type(TyVar("t1")))) 4530 + 4531 + def test_match_list_spread(self) -> None: 4532 + expr = parse(tokenize("head . head = | [x, ...] -> x")) 4533 + ty = infer_type(expr, {}) 4534 + self.assertTyEqual(ty, func_type(list_type(TyVar("t4")), TyVar("t4"))) 4535 + 4536 + def test_match_list_spread_named(self) -> None: 4537 + expr = parse(tokenize("sum . sum = | [] -> 0 | [x, ...xs] -> x + sum xs")) 4538 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4539 + self.assertTyEqual(ty, func_type(list_type(IntType), IntType)) 4540 + 4541 + def test_match_list_int_to_list(self) -> None: 4542 + expr = parse(tokenize("| [] -> [3] | x -> x")) 4543 + ty = infer_type(expr, {}) 4544 + self.assertTyEqual(ty, func_type(list_type(IntType), list_type(IntType))) 4545 + 4546 + def test_inc(self) -> None: 4547 + expr = parse(tokenize("inc . inc = | 0 -> 1 | 1 -> 2 | a -> a + 1")) 4548 + ty = infer_type(expr, {"+": Forall([], func_type(IntType, IntType, IntType))}) 4549 + self.assertTyEqual(ty, func_type(IntType, IntType)) 4550 + 4551 + def test_bytes(self) -> None: 4552 + expr = Bytes(b"abc") 4553 + ty = infer_type(expr, {}) 4554 + self.assertTyEqual(ty, BytesType) 4555 + 4556 + def test_hole(self) -> None: 4557 + expr = Hole() 4558 + ty = infer_type(expr, {}) 4559 + self.assertTyEqual(ty, HoleType) 4560 + 4561 + 3973 4562 class SerializerTests(unittest.TestCase): 3974 4563 def _serialize(self, obj: Object) -> bytes: 3975 4564 serializer = Serializer() ··· 4650 5239 print(pretty(result)) 4651 5240 4652 5241 5242 + def check_command(args: argparse.Namespace) -> None: 5243 + if args.debug: 5244 + logging.basicConfig(level=logging.DEBUG) 5245 + 5246 + program = args.program_file.read() 5247 + tokens = tokenize(program) 5248 + logger.debug("Tokens: %s", tokens) 5249 + ast = parse(tokens) 5250 + logger.debug("AST: %s", ast) 5251 + result = infer_type(ast, OP_ENV) 5252 + result = recursive_find(result) 5253 + result = minimize(result) 5254 + print(result) 5255 + 5256 + 4653 5257 def apply_command(args: argparse.Namespace) -> None: 4654 5258 if args.debug: 4655 5259 logging.basicConfig(level=logging.DEBUG) ··· 4711 5315 return env_get_split("CFLAGS", default_cflags) 4712 5316 4713 5317 5318 + OP_ENV = { 5319 + "+": Forall([], func_type(IntType, IntType, IntType)), 5320 + "-": Forall([], func_type(IntType, IntType, IntType)), 5321 + "*": Forall([], func_type(IntType, IntType, IntType)), 5322 + "/": Forall([], func_type(IntType, IntType, IntType)), 5323 + "++": Forall([], func_type(StringType, StringType, StringType)), 5324 + } 5325 + 5326 + 4714 5327 def compile_command(args: argparse.Namespace) -> None: 5328 + if args.run: 5329 + args.compile = True 4715 5330 from compiler import compile_to_string 4716 5331 4717 5332 with open(args.file, "r") as f: 4718 5333 source = f.read() 4719 5334 4720 - c_program = compile_to_string(source, args.debug) 5335 + program = parse(tokenize(source)) 5336 + if args.check: 5337 + infer_type(program, OP_ENV) 5338 + c_program = compile_to_string(program, args.debug) 4721 5339 4722 5340 with open(args.platform, "r") as f: 4723 5341 platform = f.read() ··· 4772 5390 eval_.add_argument("program_file", type=argparse.FileType("r")) 4773 5391 eval_.add_argument("--debug", action="store_true") 4774 5392 5393 + check = subparsers.add_parser("check") 5394 + check.set_defaults(func=check_command) 5395 + check.add_argument("program_file", type=argparse.FileType("r")) 5396 + check.add_argument("--debug", action="store_true") 5397 + 4775 5398 apply = subparsers.add_parser("apply") 4776 5399 apply.set_defaults(func=apply_command) 4777 5400 apply.add_argument("program") ··· 4786 5409 comp.add_argument("--memory", type=int) 4787 5410 comp.add_argument("--run", action="store_true") 4788 5411 comp.add_argument("--debug", action="store_true", default=False) 5412 + comp.add_argument("--check", action="store_true", default=False) 4789 5413 # The platform is in the same directory as this file 4790 5414 comp.add_argument("--platform", default=os.path.join(os.path.dirname(__file__), "cli.c")) 4791 5415 ··· 4801 5425 4802 5426 4803 5427 if __name__ == "__main__": 5428 + # This is so that we can use scrapscript.py as a main but also import 5429 + # things from `scrapscript` and not have that be a separate module. 5430 + sys.modules["scrapscript"] = sys.modules[__name__] 4804 5431 main()