//! model weight loading and NER inference pipeline.
//!
//! loads en_core_web_sm weights from a flat binary (header + contiguous float32s),
//! then runs the full pipeline: hash embed → CNN encode → linear → parser scoring.
//! follows the karpathy/llama2.c pattern: mmap/embed bytes, slice into named regions.

const std = @import("std");
const ops = @import("ops.zig");
const embed = @import("embed.zig");
const parser = @import("parser.zig");

/// maximum tokens per document (coral limits text to 500 chars ≈ ~120 tokens)
pub const MAX_TOKENS = 128;

const HEADER_MAGIC = 0x5350435A; // "SPCZ"
const HEADER_VERSION = 1;
const HEADER_UINT32S = 64;
const HEADER_BYTES = HEADER_UINT32S * 4;

/// one residual CNN block: maxout affine followed by layernorm.
/// all slices point into the loaded weight blob.
pub const CnnBlock = struct {
    W: []const f32, // (nO * nP, nI) = (width*3, width*3)
    b: []const f32, // (nO * nP,) = (width*3,)
    G: []const f32, // (width,)
    b_ln: []const f32, // (width,)
};

pub const Model = struct {
    // embedding
    embeds: [4]embed.HashEmbed,
    reduce_W: []const f32, // (width*nP, 4*width)
    reduce_b: []const f32, // (width*nP,)
    reduce_G: []const f32, // (width,)
    reduce_b_ln: []const f32, // (width,)

    // CNN encoder (4 residual blocks)
    cnn: [4]CnnBlock,

    // linear projection (tok2vec → parser input)
    linear_W: []const f32, // (hidden, width)
    linear_b: []const f32, // (hidden,)

    // parser lower (precomputable affine)
    lower_W: []const f32, // (nF * hidden * nP, hidden) = (3 * lower_dim, input_dim)
    lower_b: []const f32, // (lower_dim,)
    lower_pad: []const f32, // (nF * lower_dim,) = (3 * lower_dim,)

    // parser upper
    upper_W: []const f32, // (n_actions, hidden)
    upper_b: []const f32, // (n_actions,)

    // dimensions
    width: u32, // tok2vec width (96)
    hidden: u32, // parser hidden / linear output (64)
    n_actions: u32, // 74
    cnn_nP: u32, // CNN maxout pieces (3)
    parser_nP: u32, // parser lower maxout pieces (2)

    /// load model from raw weight bytes (mmap'd, @embedFile'd, or heap).
    /// the returned Model borrows `bytes`: every weight slice points into it,
    /// so `bytes` must outlive the Model. the data region after the 256-byte
    /// header must be 4-byte aligned (callers here allocate with align 4).
    /// errors: FileTooSmall, BadMagic, BadVersion, Unsupported* (header dims
    /// that would overflow the fixed stack scratch buffers used at inference).
    pub fn load(bytes: []const u8) !Model {
        if (bytes.len < HEADER_BYTES) return error.FileTooSmall;

        const header = std.mem.bytesAsValue([HEADER_UINT32S]u32, bytes[0..HEADER_BYTES]);
        if (header[0] != HEADER_MAGIC) return error.BadMagic;
        if (header[1] != HEADER_VERSION) return error.BadVersion;

        const width: u32 = header[2]; // 96
        const cnn_depth: u32 = header[3]; // 4
        const cnn_nP: u32 = header[4]; // 3
        const hidden: u32 = header[5]; // 64
        const parser_nP: u32 = header[6]; // 2
        const parser_nF: u32 = header[7]; // 3
        const n_actions: u32 = header[8]; // 74

        if (cnn_depth != 4) return error.UnsupportedCnnDepth;
        if (parser_nF != 3) return error.UnsupportedParserNF;

        // embedTokens/encode/scoreActions/predict use fixed stack scratch
        // buffers sized for en_core_web_sm (width 96, nP 3, hidden 64,
        // lower_dim 128, parser.N_ACTIONS actions). reject any header that
        // would overflow them instead of corrupting the stack.
        if (width > 96) return error.UnsupportedWidth;
        if (cnn_nP > 3) return error.UnsupportedCnnPieces;
        if (hidden > 64) return error.UnsupportedHidden;
        if (@as(u64, hidden) * parser_nP > 128) return error.UnsupportedParserPieces;
        if (n_actions > parser.N_ACTIONS) return error.UnsupportedActionCount;

        const embed_nVs = [4]u32{ header[9], header[10], header[11], header[12] };
        const embed_seeds = [4]u32{ header[13], header[14], header[15], header[16] };

        // slice weights from the data region after the header; trim any
        // trailing bytes that don't form a whole f32 so bytesAsSlice's
        // divisibility precondition holds even for a truncated file
        const tail = bytes[HEADER_BYTES..];
        const aligned: []align(4) const u8 = @alignCast(tail[0 .. tail.len - tail.len % 4]);
        const data = std.mem.bytesAsSlice(f32, aligned);
        var off: usize = 0;

        // helper to advance through contiguous weights, bounds-checked so a
        // truncated file returns an error instead of a safety panic
        const take = struct {
            fn f(d: []const f32, o: *usize, n: usize) error{FileTooSmall}![]const f32 {
                if (n > d.len - o.*) return error.FileTooSmall;
                const s = d[o.*..][0..n];
                o.* += n;
                return s;
            }
        }.f;

        // 1. hash embed tables (4x)
        var embeds: [4]embed.HashEmbed = undefined;
        for (0..4) |i| {
            const nV = embed_nVs[i];
            const nO = width;
            embeds[i] = .{
                // widen before multiplying: nV comes from the file and a u32
                // product could overflow-trap on hostile input
                .E = try take(data, &off, @as(usize, nV) * nO),
                .nV = nV,
                .nO = nO,
                .seed = embed_seeds[i],
            };
        }

        // 2. reduction maxout + layernorm
        const reduce_dim = width * cnn_nP; // 288
        const reduce_in = 4 * width; // 384
        const reduce_W = try take(data, &off, @as(usize, reduce_dim) * reduce_in);
        const reduce_b = try take(data, &off, reduce_dim);
        const reduce_G = try take(data, &off, width);
        const reduce_b_ln = try take(data, &off, width);

        // 3. CNN blocks (4x)
        var cnn: [4]CnnBlock = undefined;
        const cnn_out = width * cnn_nP; // 288
        const cnn_in = width * 3; // 288 (from seq2col)
        for (0..4) |i| {
            cnn[i] = .{
                .W = try take(data, &off, @as(usize, cnn_out) * cnn_in),
                .b = try take(data, &off, cnn_out),
                .G = try take(data, &off, width),
                .b_ln = try take(data, &off, width),
            };
        }

        // 4. linear projection
        const linear_W = try take(data, &off, @as(usize, hidden) * width);
        const linear_b = try take(data, &off, hidden);

        // 5. parser lower (precomputable affine)
        const lower_dim = hidden * parser_nP; // 128
        const lower_W = try take(data, &off, @as(usize, parser_nF) * lower_dim * hidden);
        const lower_b = try take(data, &off, lower_dim);
        const lower_pad = try take(data, &off, @as(usize, parser_nF) * lower_dim);

        // 6. parser upper
        const upper_W = try take(data, &off, @as(usize, n_actions) * hidden);
        const upper_b = try take(data, &off, n_actions);

        return .{
            .embeds = embeds,
            .reduce_W = reduce_W,
            .reduce_b = reduce_b,
            .reduce_G = reduce_G,
            .reduce_b_ln = reduce_b_ln,
            .cnn = cnn,
            .linear_W = linear_W,
            .linear_b = linear_b,
            .lower_W = lower_W,
            .lower_b = lower_b,
            .lower_pad = lower_pad,
            .upper_W = upper_W,
            .upper_b = upper_b,
            .width = width,
            .hidden = hidden,
            .n_actions = n_actions,
            .cnn_nP = cnn_nP,
            .parser_nP = parser_nP,
        };
    }

    /// embed all tokens via MultiHashEmbed (hash lookups → maxout → layernorm).
    /// tok_vecs is (n_tokens, width) output buffer.
    pub fn embedTokens(
        self: *const Model,
        tokens: []const []const u8,
        tok_vecs: []f32,
    ) void {
        const w = self.width;
        const mhe = embed.MultiHashEmbed{
            .embeds = self.embeds,
            .maxout_W = self.reduce_W,
            .maxout_b = self.reduce_b,
            .ln_G = self.reduce_G,
            .ln_b = self.reduce_b_ln,
            .nO = w,
            .nP = self.cnn_nP,
        };
        // scratch sized for the maximum dims load() admits (width ≤ 96, nP ≤ 3)
        var scratch: [4 * 96 + 96 * 3]f32 = undefined;
        for (tokens, 0..) |tok, t| {
            const attrs = embed.extractAttrs(tok);
            mhe.forward(attrs.asArray(), tok_vecs[t * w ..][0..w], &scratch);
        }
    }

    /// CNN padding: number of zero rows added on each side of the sequence.
    /// matches the CNN depth so boundary tokens have valid neighbors for all blocks.
    pub const CNN_PAD = 4;

    /// run 4 CNN residual blocks then linear project to tok2vec_out.
    /// tok_vecs: (n_tokens, width) input, NOT modified.
    /// expanded: scratch buffer, must be >= (n_tokens + 2*CNN_PAD) * width * 3.
    /// tok2vec_out: (n_tokens, hidden) output buffer.
    /// n_tokens must be <= MAX_TOKENS (the padded buffer is stack-sized).
    pub fn encode(
        self: *const Model,
        tok_vecs: []const f32,
        n_tokens: usize,
        expanded: []f32,
        tok2vec_out: []f32,
    ) void {
        std.debug.assert(n_tokens <= MAX_TOKENS); // padded[] below is stack-sized
        const w = self.width;
        const pad = CNN_PAD;
        const padded_len = n_tokens + 2 * pad;
        const cnn_out_dim = w * self.cnn_nP; // 288
        const cnn_in_dim = w * 3; // 288
        var pre_maxout: [96 * 3]f32 = undefined;
        var post_maxout: [96]f32 = undefined;

        // create padded buffer: [pad zeros | tok_vecs | pad zeros]
        var padded: [(MAX_TOKENS + 2 * CNN_PAD) * 96]f32 = undefined;
        // zero the pad regions
        @memset(padded[0 .. pad * w], 0);
        // copy token vectors into the middle
        @memcpy(padded[pad * w ..][0 .. n_tokens * w], tok_vecs[0 .. n_tokens * w]);
        // zero the trailing pad
        @memset(padded[(pad + n_tokens) * w ..][0 .. pad * w], 0);

        for (0..4) |blk| {
            // seq2col on the full padded sequence
            ops.seq2col(expanded, padded[0 .. padded_len * w], padded_len, w);

            // per-token: maxout + layernorm + residual (including padding rows)
            for (0..padded_len) |t| {
                const exp_t = expanded[t * cnn_in_dim ..][0..cnn_in_dim];
                ops.matvec_bias(&pre_maxout, exp_t, self.cnn[blk].W, self.cnn[blk].b, cnn_in_dim, cnn_out_dim);
                ops.maxout(&post_maxout, &pre_maxout, w, self.cnn_nP);
                ops.layernorm(&post_maxout, &post_maxout, self.cnn[blk].G, self.cnn[blk].b_ln, w);
                // residual: padded[t] += post_maxout
                const tv = padded[t * w ..][0..w];
                ops.vadd(tv, tv, &post_maxout, w);
            }
        }

        // linear projection on the real tokens (skip padding rows)
        const h = self.hidden;
        for (0..n_tokens) |t| {
            ops.matvec_bias(
                tok2vec_out[t * h ..][0..h],
                padded[(pad + t) * w ..][0..w],
                self.linear_W,
                self.linear_b,
                w,
                h,
            );
        }
    }

    /// compute action scores for the parser at one step.
    /// ctx: [B(0), E(0), B(0)-1] token indices (n_tokens = padding sentinel).
    /// tok2vec_out: (n_tokens, hidden) from encode().
    pub fn scoreActions(
        self: *const Model,
        ctx: [3]u32,
        tok2vec_out: []const f32,
        n_tokens: u32,
        scores: []f32,
    ) void {
        const h: usize = self.hidden;
        const nP: usize = self.parser_nP;
        const lower_dim = h * nP; // 128

        // accumulate 3 features into hidden (sized for lower_dim ≤ 128,
        // enforced by load())
        var hidden: [128]f32 = [_]f32{0} ** 128;
        var tmp: [128]f32 = undefined;

        for (0..3) |f| {
            if (ctx[f] >= n_tokens) {
                // out-of-bounds → use padding vector
                ops.vadd(
                    hidden[0..lower_dim],
                    hidden[0..lower_dim],
                    self.lower_pad[f * lower_dim ..][0..lower_dim],
                    lower_dim,
                );
            } else {
                // W_f @ tok2vec[tok_idx]
                const W_f = self.lower_W[f * lower_dim * h ..][0 .. lower_dim * h];
                const x = tok2vec_out[ctx[f] * h ..][0..h];
                ops.matvec(tmp[0..lower_dim], x, W_f, h, lower_dim);
                ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], tmp[0..lower_dim], lower_dim);
            }
        }

        // add bias
        ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], self.lower_b, lower_dim);

        // maxout: (hidden * nP) → (hidden)
        var maxed: [64]f32 = undefined;
        ops.maxout(maxed[0..h], hidden[0..lower_dim], h, nP);

        // upper: (n_actions, hidden) @ hidden → scores
        const na: usize = self.n_actions;
        ops.matvec_bias(scores[0..na], maxed[0..h], self.upper_W, self.upper_b, h, na);
    }

    /// run the full NER pipeline on pre-tokenized text.
    /// tokens: array of token byte slices (pointing into original text);
    /// anything beyond MAX_TOKENS is silently ignored.
    /// returns parser state with recognized entities (token-index spans).
    pub fn predict(self: *const Model, tokens: []const []const u8) parser.State {
        const n: u32 = @intCast(@min(tokens.len, MAX_TOKENS));
        if (n == 0) return parser.State.init(0);

        const w = self.width;
        const h = self.hidden;

        // scratch buffers (stack-allocated)
        var tok_vecs: [MAX_TOKENS * 96]f32 = undefined;
        var expanded: [(MAX_TOKENS + 2 * CNN_PAD) * 96 * 3]f32 = undefined;
        var tok2vec_out: [MAX_TOKENS * 64]f32 = undefined;

        // embed
        self.embedTokens(tokens[0..n], tok_vecs[0 .. n * w]);

        // CNN encode + linear project
        self.encode(
            tok_vecs[0 .. n * w],
            n,
            expanded[0 .. (n + 2 * CNN_PAD) * w * 3],
            tok2vec_out[0 .. n * h],
        );

        // greedy parse: pick the best valid action until the state is final
        var state = parser.State.init(n);
        var scores: [parser.N_ACTIONS]f32 = undefined;

        while (!state.isFinal()) {
            const ctx = state.contextIds();
            self.scoreActions(ctx, tok2vec_out[0 .. n * h], n, &scores);
            const valid = state.validMask();
            const best = parser.argmaxValid(&scores, valid);
            const decoded = parser.decodeAction(best);
            state.apply(decoded.action, decoded.label);
        }

        return state;
    }
};

// === tests ===

const testing = std.testing;

test "Model.load validates header" {
    // too small
    try testing.expectError(error.FileTooSmall, Model.load(""));

    // wrong magic
    var bad_header: [HEADER_BYTES]u8 = [_]u8{0} ** HEADER_BYTES;
    try testing.expectError(error.BadMagic, Model.load(&bad_header));
}

/// load weight file from disk for tests (returns null if not found).
/// caller frees the returned slice with std.testing.allocator.
fn loadWeightFile() ?[]align(4) const u8 {
    const file = std.fs.cwd().openFile("weights/en_core_web_sm.bin", .{}) catch return null;
    defer file.close();
    const stat = file.stat() catch return null;
    const bytes = std.testing.allocator.alignedAlloc(u8, .@"4", stat.size) catch return null;
    const n = file.readAll(bytes) catch {
        std.testing.allocator.free(bytes);
        return null;
    };
    if (n != stat.size) {
        // partial read: returning bytes[0..n] would make the caller free a
        // slice shorter than the allocation, which the testing allocator
        // rejects. free the whole allocation and report "unavailable".
        std.testing.allocator.free(bytes);
        return null;
    }
    return bytes;
}

test "Model.load from weight file" {
    const weights = loadWeightFile() orelse return; // skip if weights not available
    defer std.testing.allocator.free(weights);

    const m = try Model.load(weights);

    try testing.expectEqual(@as(u32, 96), m.width);
    try testing.expectEqual(@as(u32, 64), m.hidden);
    try testing.expectEqual(@as(u32, 74), m.n_actions);
    try testing.expectEqual(@as(u32, 3), m.cnn_nP);
    try testing.expectEqual(@as(u32, 2), m.parser_nP);

    // verify embed table sizes
    try testing.expectEqual(@as(usize, 5000), m.embeds[0].nV);
    try testing.expectEqual(@as(usize, 1000), m.embeds[1].nV);
    try testing.expectEqual(@as(usize, 2500), m.embeds[2].nV);
    try testing.expectEqual(@as(usize, 2500), m.embeds[3].nV);
}

test "Model.predict basic NER" {
    const weights = loadWeightFile() orelse return;
    defer std.testing.allocator.free(weights);

    const m = try Model.load(weights);

    // "Barack Obama visited Paris"
    const tokens = [_][]const u8{ "Barack", "Obama", "visited", "Paris" };
    const state = m.predict(&tokens);
    const ents = state.entities();

    // should find at least one entity
    try testing.expect(ents.len > 0);

    // log what we found
    for (ents) |e| {
        std.debug.print(" [{d}..{d}) {s}\n", .{ e.start, e.end, @tagName(e.label) });
    }
}