// this repo has no description
1//! model weight loading and NER inference pipeline.
2//!
3//! loads en_core_web_sm weights from a flat binary (header + contiguous float32s),
4//! then runs the full pipeline: hash embed → CNN encode → linear → parser scoring.
5//! follows the karpathy/llama2.c pattern: mmap/embed bytes, slice into named regions.
6
7const std = @import("std");
8const ops = @import("ops.zig");
9const embed = @import("embed.zig");
10const parser = @import("parser.zig");
11
/// maximum tokens per document (coral limits text to 500 chars ≈ ~120 tokens)
pub const MAX_TOKENS = 128;

// weight-file framing: a fixed-size header of HEADER_UINT32S u32 fields
// (magic, version, then dimensions — only indices 0..16 are read by
// Model.load), followed immediately by contiguous f32 weight data.
const HEADER_MAGIC = 0x5350435A; // "SPCZ"
const HEADER_VERSION = 1; // bump when the serialized layout changes
const HEADER_UINT32S = 64; // total u32 slots reserved for the header
const HEADER_BYTES = HEADER_UINT32S * 4;
19
/// weights for one CNN residual block: a maxout affine (W, b) applied to the
/// seq2col-expanded input, followed by a layernorm (G, b_ln). applied by
/// Model.encode with a residual add.
pub const CnnBlock = struct {
    W: []const f32, // (nO * nP, nI) = (width*3, width*3)
    b: []const f32, // (nO * nP,) = (width*3,)
    G: []const f32, // (width,) layernorm gain
    b_ln: []const f32, // (width,) layernorm bias
};
26
pub const Model = struct {
    // embedding
    embeds: [4]embed.HashEmbed,
    reduce_W: []const f32, // (width*nP, 4*width)
    reduce_b: []const f32, // (width*nP,)
    reduce_G: []const f32, // (width,)
    reduce_b_ln: []const f32, // (width,)

    // CNN encoder (4 residual blocks)
    cnn: [4]CnnBlock,

    // linear projection (tok2vec → parser input)
    linear_W: []const f32, // (hidden, width)
    linear_b: []const f32, // (hidden,)

    // parser lower (precomputable affine)
    lower_W: []const f32, // (nF * hidden * nP, hidden) = (3 * lower_dim, input_dim)
    lower_b: []const f32, // (lower_dim,)
    lower_pad: []const f32, // (nF * lower_dim,) = (3 * lower_dim,)

    // parser upper
    upper_W: []const f32, // (n_actions, hidden)
    upper_b: []const f32, // (n_actions,)

    // dimensions. load() validates these against the fixed-size stack scratch
    // buffers used by the inference methods below, so after a successful
    // load() they are guaranteed to hold the values noted here.
    width: u32, // tok2vec width (96)
    hidden: u32, // parser hidden / linear output (64)
    n_actions: u32, // 74
    cnn_nP: u32, // CNN maxout pieces (3)
    parser_nP: u32, // parser lower maxout pieces (2)

    /// load model from raw weight bytes (mmap'd, @embedFile'd, or heap).
    /// the returned Model holds slices INTO `bytes`; the caller must keep
    /// `bytes` alive (and 4-byte aligned) for the Model's lifetime.
    /// errors: FileTooSmall, BadMagic, BadVersion, UnalignedBuffer, and
    /// Unsupported* when the header dimensions don't match en_core_web_sm.
    pub fn load(bytes: []const u8) !Model {
        if (bytes.len < HEADER_BYTES) return error.FileTooSmall;

        // header fields are read through an align(1) reinterpretation, so
        // this part tolerates any source alignment; the f32 data region
        // below is what requires 4-byte alignment.
        const header = std.mem.bytesAsValue([HEADER_UINT32S]u32, bytes[0..HEADER_BYTES]);
        if (header[0] != HEADER_MAGIC) return error.BadMagic;
        if (header[1] != HEADER_VERSION) return error.BadVersion;

        const width: u32 = header[2]; // 96
        const cnn_depth: u32 = header[3]; // 4
        const cnn_nP: u32 = header[4]; // 3
        const hidden: u32 = header[5]; // 64
        const parser_nP: u32 = header[6]; // 2
        const parser_nF: u32 = header[7]; // 3
        const n_actions: u32 = header[8]; // 74

        if (cnn_depth != 4) return error.UnsupportedCnnDepth;
        if (parser_nF != 3) return error.UnsupportedParserNF;
        // embedTokens/encode/scoreActions/predict use fixed-size stack
        // scratch buffers sized for exactly these dimensions. reject any
        // other values here rather than tripping a bounds panic (or, in
        // ReleaseFast, corrupting the stack) mid-inference.
        if (width != 96) return error.UnsupportedWidth;
        if (hidden != 64) return error.UnsupportedHidden;
        if (cnn_nP != 3) return error.UnsupportedCnnPieces;
        if (parser_nP != 2) return error.UnsupportedParserPieces;
        if (n_actions != parser.N_ACTIONS) return error.UnsupportedActionCount;

        const embed_nVs = [4]u32{ header[9], header[10], header[11], header[12] };
        const embed_seeds = [4]u32{ header[13], header[14], header[15], header[16] };

        // slice weights from the data region after the header.
        // @alignCast on a misaligned pointer panics in safe modes and is UB
        // in ReleaseFast, so check explicitly and fail with an error.
        if (!std.mem.isAligned(@intFromPtr(bytes.ptr), 4)) return error.UnalignedBuffer;
        const aligned: []align(4) const u8 = @alignCast(bytes[HEADER_BYTES..]);
        const data = std.mem.bytesAsSlice(f32, aligned);
        var off: usize = 0;

        // helper to advance through contiguous weights; bounds-checked so a
        // truncated file yields an error instead of a slice-bounds panic.
        const take = struct {
            fn f(d: []const f32, o: *usize, n: usize) error{FileTooSmall}![]const f32 {
                if (n > d.len - o.*) return error.FileTooSmall;
                const s = d[o.*..][0..n];
                o.* += n;
                return s;
            }
        }.f;

        // 1. hash embed tables (4x)
        var embeds: [4]embed.HashEmbed = undefined;
        for (0..4) |i| {
            const nV = embed_nVs[i];
            const nO = width;
            embeds[i] = .{
                // widen before multiplying: nV comes straight from the file,
                // so u32 arithmetic could overflow on hostile input
                .E = try take(data, &off, @as(usize, nV) * nO),
                .nV = nV,
                .nO = nO,
                .seed = embed_seeds[i],
            };
        }

        // 2. reduction maxout + layernorm
        const reduce_dim = width * cnn_nP; // 288
        const reduce_in = 4 * width; // 384
        const reduce_W = try take(data, &off, reduce_dim * reduce_in);
        const reduce_b = try take(data, &off, reduce_dim);
        const reduce_G = try take(data, &off, width);
        const reduce_b_ln = try take(data, &off, width);

        // 3. CNN blocks (4x)
        var cnn: [4]CnnBlock = undefined;
        const cnn_out = width * cnn_nP; // 288
        const cnn_in = width * 3; // 288 (from seq2col)
        for (0..4) |i| {
            cnn[i] = .{
                .W = try take(data, &off, cnn_out * cnn_in),
                .b = try take(data, &off, cnn_out),
                .G = try take(data, &off, width),
                .b_ln = try take(data, &off, width),
            };
        }

        // 4. linear projection
        const linear_W = try take(data, &off, hidden * width);
        const linear_b = try take(data, &off, hidden);

        // 5. parser lower (precomputable affine)
        const lower_dim = hidden * parser_nP; // 128
        const lower_W = try take(data, &off, parser_nF * lower_dim * hidden);
        const lower_b = try take(data, &off, lower_dim);
        const lower_pad = try take(data, &off, parser_nF * lower_dim);

        // 6. parser upper
        const upper_W = try take(data, &off, n_actions * hidden);
        const upper_b = try take(data, &off, n_actions);

        return .{
            .embeds = embeds,
            .reduce_W = reduce_W,
            .reduce_b = reduce_b,
            .reduce_G = reduce_G,
            .reduce_b_ln = reduce_b_ln,
            .cnn = cnn,
            .linear_W = linear_W,
            .linear_b = linear_b,
            .lower_W = lower_W,
            .lower_b = lower_b,
            .lower_pad = lower_pad,
            .upper_W = upper_W,
            .upper_b = upper_b,
            .width = width,
            .hidden = hidden,
            .n_actions = n_actions,
            .cnn_nP = cnn_nP,
            .parser_nP = parser_nP,
        };
    }

    /// embed all tokens via MultiHashEmbed (hash lookups → maxout → layernorm).
    /// tok_vecs is the (n_tokens, width) output buffer, sized by the caller.
    pub fn embedTokens(
        self: *const Model,
        tokens: []const []const u8,
        tok_vecs: []f32,
    ) void {
        const w = self.width;
        const mhe = embed.MultiHashEmbed{
            .embeds = self.embeds,
            .maxout_W = self.reduce_W,
            .maxout_b = self.reduce_b,
            .ln_G = self.reduce_G,
            .ln_b = self.reduce_b_ln,
            .nO = w,
            .nP = self.cnn_nP,
        };
        // scratch for MultiHashEmbed.forward, sized for width=96 / cnn_nP=3
        // (load() enforces these dimensions)
        var scratch: [4 * 96 + 96 * 3]f32 = undefined;
        for (tokens, 0..) |tok, t| {
            const attrs = embed.extractAttrs(tok);
            mhe.forward(attrs.asArray(), tok_vecs[t * w ..][0..w], &scratch);
        }
    }

    /// CNN padding: number of zero rows added on each side of the sequence.
    /// matches the CNN depth so boundary tokens have valid neighbors for all blocks.
    pub const CNN_PAD = 4;

    /// run 4 CNN residual blocks then linear project to tok2vec_out.
    /// tok_vecs: (n_tokens, width) input, NOT modified.
    /// expanded: scratch buffer, must be >= (n_tokens + 2*CNN_PAD) * width * 3.
    /// tok2vec_out: (n_tokens, hidden) output buffer.
    pub fn encode(
        self: *const Model,
        tok_vecs: []const f32,
        n_tokens: usize,
        expanded: []f32,
        tok2vec_out: []f32,
    ) void {
        const w = self.width;
        const pad = CNN_PAD;
        const padded_len = n_tokens + 2 * pad;
        const cnn_out_dim = w * self.cnn_nP; // 288
        const cnn_in_dim = w * 3; // 288
        // per-token scratch, sized for width=96 / cnn_nP=3 (enforced by load)
        var pre_maxout: [96 * 3]f32 = undefined;
        var post_maxout: [96]f32 = undefined;

        // create padded buffer: [pad zeros | tok_vecs | pad zeros]
        var padded: [(MAX_TOKENS + 2 * CNN_PAD) * 96]f32 = undefined;
        // zero the pad regions
        @memset(padded[0 .. pad * w], 0);
        // copy token vectors into the middle
        @memcpy(padded[pad * w ..][0 .. n_tokens * w], tok_vecs[0 .. n_tokens * w]);
        // zero the trailing pad
        @memset(padded[(pad + n_tokens) * w ..][0 .. pad * w], 0);

        for (0..4) |blk| {
            // seq2col on the full padded sequence
            ops.seq2col(expanded, padded[0 .. padded_len * w], padded_len, w);

            // per-token: maxout + layernorm + residual. padding rows are
            // deliberately included — they accumulate values across blocks,
            // which is what boundary tokens "see" through seq2col.
            for (0..padded_len) |t| {
                const exp_t = expanded[t * cnn_in_dim ..][0..cnn_in_dim];
                ops.matvec_bias(&pre_maxout, exp_t, self.cnn[blk].W, self.cnn[blk].b, cnn_in_dim, cnn_out_dim);
                ops.maxout(&post_maxout, &pre_maxout, w, self.cnn_nP);
                ops.layernorm(&post_maxout, &post_maxout, self.cnn[blk].G, self.cnn[blk].b_ln, w);
                // residual: padded[t] += post_maxout
                const tv = padded[t * w ..][0..w];
                ops.vadd(tv, tv, &post_maxout, w);
            }
        }

        // linear projection on the real tokens (skip padding rows)
        const h = self.hidden;
        for (0..n_tokens) |t| {
            ops.matvec_bias(
                tok2vec_out[t * h ..][0..h],
                padded[(pad + t) * w ..][0..w],
                self.linear_W,
                self.linear_b,
                w,
                h,
            );
        }
    }

    /// compute action scores for the parser at one step.
    /// ctx: [B(0), E(0), B(0)-1] token indices (n_tokens = padding sentinel).
    /// tok2vec_out: (n_tokens, hidden) from encode().
    /// scores: output buffer, must hold at least n_actions floats.
    pub fn scoreActions(
        self: *const Model,
        ctx: [3]u32,
        tok2vec_out: []const f32,
        n_tokens: u32,
        scores: []f32,
    ) void {
        const h: usize = self.hidden;
        const nP: usize = self.parser_nP;
        const lower_dim = h * nP; // 128

        // accumulate 3 features into hidden (sized for hidden=64, nP=2,
        // enforced by load)
        var hidden: [128]f32 = [_]f32{0} ** 128;
        var tmp: [128]f32 = undefined;

        for (0..3) |f| {
            if (ctx[f] >= n_tokens) {
                // out-of-bounds → use the learned padding vector for slot f
                ops.vadd(
                    hidden[0..lower_dim],
                    hidden[0..lower_dim],
                    self.lower_pad[f * lower_dim ..][0..lower_dim],
                    lower_dim,
                );
            } else {
                // W_f @ tok2vec[tok_idx]
                const W_f = self.lower_W[f * lower_dim * h ..][0 .. lower_dim * h];
                const x = tok2vec_out[ctx[f] * h ..][0..h];
                ops.matvec(tmp[0..lower_dim], x, W_f, h, lower_dim);
                ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], tmp[0..lower_dim], lower_dim);
            }
        }

        // add bias
        ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], self.lower_b, lower_dim);

        // maxout: (hidden * nP) → (hidden)
        var maxed: [64]f32 = undefined;
        ops.maxout(maxed[0..h], hidden[0..lower_dim], h, nP);

        // upper: (n_actions, hidden) @ hidden → scores
        const na: usize = self.n_actions;
        ops.matvec_bias(scores[0..na], maxed[0..h], self.upper_W, self.upper_b, h, na);
    }

    /// run the full NER pipeline on pre-tokenized text.
    /// tokens: array of token byte slices (pointing into original text);
    /// anything beyond MAX_TOKENS is silently truncated.
    /// returns parser state with recognized entities (token-index spans).
    pub fn predict(self: *const Model, tokens: []const []const u8) parser.State {
        const n: u32 = @intCast(@min(tokens.len, MAX_TOKENS));
        if (n == 0) return parser.State.init(0);

        const w = self.width;
        const h = self.hidden;

        // scratch buffers (stack-allocated; sized for width=96 / hidden=64,
        // enforced by load)
        var tok_vecs: [MAX_TOKENS * 96]f32 = undefined;
        var expanded: [(MAX_TOKENS + 2 * CNN_PAD) * 96 * 3]f32 = undefined;
        var tok2vec_out: [MAX_TOKENS * 64]f32 = undefined;

        // embed
        self.embedTokens(tokens[0..n], tok_vecs[0 .. n * w]);

        // CNN encode + linear project
        self.encode(
            tok_vecs[0 .. n * w],
            n,
            expanded[0 .. (n + 2 * CNN_PAD) * w * 3],
            tok2vec_out[0 .. n * h],
        );

        // greedy parse: repeatedly score actions, mask invalid ones, and
        // apply the best until the state is final
        var state = parser.State.init(n);
        var scores: [parser.N_ACTIONS]f32 = undefined;

        while (!state.isFinal()) {
            const ctx = state.contextIds();
            self.scoreActions(ctx, tok2vec_out[0 .. n * h], n, &scores);
            const valid = state.validMask();
            const best = parser.argmaxValid(&scores, valid);
            const decoded = parser.decodeAction(best);
            state.apply(decoded.action, decoded.label);
        }

        return state;
    }
};
340
341// === tests ===
342
343const testing = std.testing;
344
test "Model.load validates header" {
    // too small: anything shorter than the fixed-size header is rejected
    try testing.expectError(error.FileTooSmall, Model.load(""));

    // wrong magic: a zeroed header fails the magic check first
    // (const, not var: the buffer is never mutated)
    const bad_magic: [HEADER_BYTES]u8 = [_]u8{0} ** HEADER_BYTES;
    try testing.expectError(error.BadMagic, Model.load(&bad_magic));

    // right magic, wrong version: version (still 0) is checked after magic
    var bad_version: [HEADER_BYTES]u8 = [_]u8{0} ** HEADER_BYTES;
    bad_version[0..4].* = @bitCast(@as(u32, HEADER_MAGIC)); // native endianness, matching bytesAsValue
    try testing.expectError(error.BadVersion, Model.load(&bad_version));
}
353
/// load weight file from disk for tests (returns null if not found or on any
/// read error). on success the caller owns the returned slice and must free
/// it with std.testing.allocator.
fn loadWeightFile() ?[]align(4) const u8 {
    const file = std.fs.cwd().openFile("weights/en_core_web_sm.bin", .{}) catch return null;
    defer file.close();
    const stat = file.stat() catch return null;
    const bytes = std.testing.allocator.alignedAlloc(u8, .@"4", stat.size) catch return null;
    const n = file.readAll(bytes) catch {
        std.testing.allocator.free(bytes);
        return null;
    };
    // a short read means the file shrank between stat() and read. the old
    // code returned bytes[0..n] here, but allocator.free() requires the
    // exact slice that was allocated — freeing a sub-slice is illegal — so
    // treat a short read as failure and always return the full allocation.
    if (n != stat.size) {
        std.testing.allocator.free(bytes);
        return null;
    }
    return bytes;
}
366
test "Model.load from weight file" {
    // skip silently when the weight file is not present in this checkout
    const weights = loadWeightFile() orelse return;
    defer std.testing.allocator.free(weights);

    const model = try Model.load(weights);

    // en_core_web_sm dimensions
    try testing.expectEqual(@as(u32, 96), model.width);
    try testing.expectEqual(@as(u32, 64), model.hidden);
    try testing.expectEqual(@as(u32, 74), model.n_actions);
    try testing.expectEqual(@as(u32, 3), model.cnn_nP);
    try testing.expectEqual(@as(u32, 2), model.parser_nP);

    // verify embed table sizes, table-driven over the four hash embeds
    const expected_rows = [4]u32{ 5000, 1000, 2500, 2500 };
    for (expected_rows, model.embeds) |rows, e| {
        try testing.expectEqual(@as(usize, rows), e.nV);
    }
}
385
test "Model.predict basic NER" {
    const weights = loadWeightFile() orelse return; // skip when weights are absent
    defer std.testing.allocator.free(weights);

    const model = try Model.load(weights);

    // "Barack Obama visited Paris"
    const toks = [_][]const u8{ "Barack", "Obama", "visited", "Paris" };
    const final_state = model.predict(&toks);
    const found = final_state.entities();

    // at least one entity must be recognized in this sentence
    try testing.expect(found.len > 0);

    // log the recognized spans for manual inspection
    for (found) |e| {
        std.debug.print(" [{d}..{d}) {s}\n", .{ e.start, e.end, @tagName(e.label) });
    }
}
404}