//! model weight loading and NER inference pipeline.
//!
//! loads en_core_web_sm weights from a flat binary (header + contiguous float32s),
//! then runs the full pipeline: hash embed → CNN encode → linear → parser scoring.
//! follows the karpathy/llama2.c pattern: mmap/embed bytes, slice into named regions.

const std = @import("std");
const ops = @import("ops.zig");
const embed = @import("embed.zig");
const parser = @import("parser.zig");

/// maximum tokens per document (coral limits text to 500 chars ≈ ~120 tokens)
pub const MAX_TOKENS = 128;

// binary header: 64 little-endian u32 fields (magic, version, dimensions,
// embed table sizes/seeds); the float32 weight data follows immediately after.
const HEADER_MAGIC = 0x5350435A; // "SPCZ"
const HEADER_VERSION = 1;
const HEADER_UINT32S = 64;
const HEADER_BYTES = HEADER_UINT32S * 4;

/// one residual CNN block: maxout convolution weights plus layernorm params.
pub const CnnBlock = struct {
    W: []const f32, // (nO * nP, nI) = (width*3, width*3) — maxout conv weights
    b: []const f32, // (nO * nP,) = (width*3,) — conv bias
    G: []const f32, // (width,) — layernorm gain
    b_ln: []const f32, // (width,) — layernorm bias
};

/// full model: every field is a slice into the single weight buffer passed to
/// `load` — the Model borrows that memory, so the caller must keep it alive
/// for the Model's lifetime.
pub const Model = struct {
    // embedding
    embeds: [4]embed.HashEmbed,
    reduce_W: []const f32, // (width*nP, 4*width)
    reduce_b: []const f32, // (width*nP,)
    reduce_G: []const f32, // (width,)
    reduce_b_ln: []const f32, // (width,)

    // CNN encoder (4 residual blocks)
    cnn: [4]CnnBlock,

    // linear projection (tok2vec → parser input)
    linear_W: []const f32, // (hidden, width)
    linear_b: []const f32, // (hidden,)

    // parser lower (precomputable affine)
    lower_W: []const f32, // (nF * hidden * nP, hidden) = (3 * lower_dim, input_dim)
    lower_b: []const f32, // (lower_dim,)
    lower_pad: []const f32, // (nF * lower_dim,) = (3 * lower_dim,)

    // parser upper
    upper_W: []const f32, // (n_actions, hidden)
    upper_b: []const f32, // (n_actions,)

    // dimensions
    width: u32, // tok2vec width (96)
    hidden: u32, // parser hidden / linear output (64)
    n_actions: u32, // 74
    cnn_nP: u32, // CNN maxout pieces (3)
    parser_nP: u32, // parser lower maxout pieces (2)

    /// load model from raw weight bytes (mmap'd, @embedFile'd, or heap).
pub fn load(bytes: []const u8) !Model {
    if (bytes.len < HEADER_BYTES) return error.FileTooSmall;
    const header = std.mem.bytesAsValue([HEADER_UINT32S]u32, bytes[0..HEADER_BYTES]);
    if (header[0] != HEADER_MAGIC) return error.BadMagic;
    if (header[1] != HEADER_VERSION) return error.BadVersion;

    // dimensions from the header (expected values for en_core_web_sm in comments)
    const width: u32 = header[2]; // 96
    const cnn_depth: u32 = header[3]; // 4
    const cnn_nP: u32 = header[4]; // 3
    const hidden: u32 = header[5]; // 64
    const parser_nP: u32 = header[6]; // 2
    const parser_nF: u32 = header[7]; // 3
    const n_actions: u32 = header[8]; // 74
    // the CNN loop and parser feature layout below are hard-wired to these
    if (cnn_depth != 4) return error.UnsupportedCnnDepth;
    if (parser_nF != 3) return error.UnsupportedParserNF;

    const embed_nVs = [4]u32{ header[9], header[10], header[11], header[12] };
    const embed_seeds = [4]u32{ header[13], header[14], header[15], header[16] };

    // slice weights from the data region after the header.
    // NOTE(review): @alignCast panics in safe builds if `bytes` is not
    // 4-byte aligned — callers must supply aligned memory (mmap/page or
    // alignedAlloc, as the tests do). TODO: confirm @embedFile alignment.
    const aligned: []align(4) const u8 = @alignCast(bytes[HEADER_BYTES..]);
    // reject a data region that is not a whole number of f32s up front,
    // rather than tripping bytesAsSlice's divisibility assertion
    if (aligned.len % @sizeOf(f32) != 0) return error.FileTruncated;
    const data = std.mem.bytesAsSlice(f32, aligned);
    var off: usize = 0;

    // bounds-checked cursor over the contiguous weights: a truncated file
    // yields error.FileTruncated instead of an out-of-bounds safety panic.
    const take = struct {
        fn f(d: []const f32, o: *usize, n: usize) error{FileTruncated}![]const f32 {
            if (n > d.len - o.*) return error.FileTruncated;
            const s = d[o.*..][0..n];
            o.* += n;
            return s;
        }
    }.f;

    // 1. hash embed tables (4x)
    var embeds: [4]embed.HashEmbed = undefined;
    for (0..4) |i| {
        const nV = embed_nVs[i];
        const nO = width;
        embeds[i] = .{
            .E = try take(data, &off, nV * nO),
            .nV = nV,
            .nO = nO,
            .seed = embed_seeds[i],
        };
    }

    // 2. reduction maxout + layernorm (4 concatenated embeds → width)
    const reduce_dim = width * cnn_nP; // 288
    const reduce_in = 4 * width; // 384
    const reduce_W = try take(data, &off, reduce_dim * reduce_in);
    const reduce_b = try take(data, &off, reduce_dim);
    const reduce_G = try take(data, &off, width);
    const reduce_b_ln = try take(data, &off, width);

    // 3. CNN blocks (4x)
    var cnn: [4]CnnBlock = undefined;
    const cnn_out = width * cnn_nP; // 288
    const cnn_in = width * 3; // 288 (from seq2col)
    for (0..4) |i| {
        cnn[i] = .{
            .W = try take(data, &off, cnn_out * cnn_in),
            .b = try take(data, &off, cnn_out),
            .G = try take(data, &off, width),
            .b_ln = try take(data, &off, width),
        };
    }

    // 4. linear projection
    const linear_W = try take(data, &off, hidden * width);
    const linear_b = try take(data, &off, hidden);

    // 5. parser lower (precomputable affine)
    const lower_dim = hidden * parser_nP; // 128
    const lower_W = try take(data, &off, parser_nF * lower_dim * hidden);
    const lower_b = try take(data, &off, lower_dim);
    const lower_pad = try take(data, &off, parser_nF * lower_dim);

    // 6. parser upper
    const upper_W = try take(data, &off, n_actions * hidden);
    const upper_b = try take(data, &off, n_actions);

    return .{
        .embeds = embeds,
        .reduce_W = reduce_W,
        .reduce_b = reduce_b,
        .reduce_G = reduce_G,
        .reduce_b_ln = reduce_b_ln,
        .cnn = cnn,
        .linear_W = linear_W,
        .linear_b = linear_b,
        .lower_W = lower_W,
        .lower_b = lower_b,
        .lower_pad = lower_pad,
        .upper_W = upper_W,
        .upper_b = upper_b,
        .width = width,
        .hidden = hidden,
        .n_actions = n_actions,
        .cnn_nP = cnn_nP,
        .parser_nP = parser_nP,
    };
}

/// embed all tokens via MultiHashEmbed (hash lookups → maxout → layernorm).
/// tokens: token byte slices; each is hashed into the four embed tables.
/// tok_vecs: (n_tokens, width) output buffer, caller-allocated.
pub fn embedTokens(
    self: *const Model,
    tokens: []const []const u8,
    tok_vecs: []f32,
) void {
    const w = self.width;
    const mhe = embed.MultiHashEmbed{
        .embeds = self.embeds,
        .maxout_W = self.reduce_W,
        .maxout_b = self.reduce_b,
        .ln_G = self.reduce_G,
        .ln_b = self.reduce_b_ln,
        .nO = w,
        .nP = self.cnn_nP,
    };
    // scratch sized for width=96, cnn_nP=3 — assumes self.width <= 96;
    // load() currently accepts any width, so this is a latent constraint.
    var scratch: [4 * 96 + 96 * 3]f32 = undefined;
    for (tokens, 0..) |tok, t| {
        const attrs = embed.extractAttrs(tok);
        mhe.forward(attrs.asArray(), tok_vecs[t * w ..][0..w], &scratch);
    }
}

/// CNN padding: number of zero rows added on each side of the sequence.
/// matches the CNN depth so boundary tokens have valid neighbors for all blocks.
pub const CNN_PAD = 4;

/// run 4 CNN residual blocks then linear project to tok2vec_out.
/// tok_vecs: (n_tokens, width) input, NOT modified.
/// expanded: scratch buffer, must be >= (n_tokens + 2*CNN_PAD) * width * 3.
/// tok2vec_out: (n_tokens, hidden) output buffer.
/// NOTE(review): local buffers hard-code width=96 / hidden pieces below;
/// assumes self.width == 96 and n_tokens <= MAX_TOKENS — callers (predict)
/// guarantee the latter, but nothing here checks either. TODO confirm.
pub fn encode(
    self: *const Model,
    tok_vecs: []const f32,
    n_tokens: usize,
    expanded: []f32,
    tok2vec_out: []f32,
) void {
    const w = self.width;
    const pad = CNN_PAD;
    const padded_len = n_tokens + 2 * pad;
    const cnn_out_dim = w * self.cnn_nP; // 288
    const cnn_in_dim = w * 3; // 288 (from seq2col)
    // per-token scratch; sized for width=96, cnn_nP=3
    var pre_maxout: [96 * 3]f32 = undefined;
    var post_maxout: [96]f32 = undefined;
    // create padded buffer: [pad zeros | tok_vecs | pad zeros]
    var padded: [(MAX_TOKENS + 2 * CNN_PAD) * 96]f32 = undefined;
    // zero the pad regions
    @memset(padded[0 .. pad * w], 0);
    // copy token vectors into the middle
    @memcpy(padded[pad * w ..][0 .. n_tokens * w], tok_vecs[0 .. n_tokens * w]);
    // zero the trailing pad
    @memset(padded[(pad + n_tokens) * w ..][0 .. pad * w], 0);
    for (0..4) |blk| {
        // seq2col on the full padded sequence: each row becomes the
        // concatenation of [prev, self, next] vectors (width*3 wide)
        ops.seq2col(expanded, padded[0 .. padded_len * w], padded_len, w);
        // per-token: maxout + layernorm + residual (including padding rows —
        // deliberate: pad rows evolve so boundary tokens see consistent
        // neighbors across all 4 blocks; CNN_PAD equals the depth for this)
        for (0..padded_len) |t| {
            const exp_t = expanded[t * cnn_in_dim ..][0..cnn_in_dim];
            ops.matvec_bias(&pre_maxout, exp_t, self.cnn[blk].W, self.cnn[blk].b, cnn_in_dim, cnn_out_dim);
            ops.maxout(&post_maxout, &pre_maxout, w, self.cnn_nP);
            // layernorm is applied in place (dst == src)
            ops.layernorm(&post_maxout, &post_maxout, self.cnn[blk].G, self.cnn[blk].b_ln, w);
            // residual: padded[t] += post_maxout
            const tv = padded[t * w ..][0..w];
            ops.vadd(tv, tv, &post_maxout, w);
        }
    }
    // linear projection on the real tokens (skip padding rows)
    const h = self.hidden;
    for (0..n_tokens) |t| {
        ops.matvec_bias(
            tok2vec_out[t * h ..][0..h],
            padded[(pad + t) * w ..][0..w],
            self.linear_W,
            self.linear_b,
            w,
            h,
        );
    }
}

/// compute action scores for the parser at one step.
/// ctx: [B(0), E(0), B(0)-1] token indices (n_tokens = padding sentinel).
/// tok2vec_out: (n_tokens, hidden) from encode().
/// scores: (n_actions,) output buffer.
/// NOTE(review): local buffers assume lower_dim <= 128 and hidden <= 64
/// (true for en_core_web_sm: 64*2 and 64) — nothing checks this here.
pub fn scoreActions(
    self: *const Model,
    ctx: [3]u32,
    tok2vec_out: []const f32,
    n_tokens: u32,
    scores: []f32,
) void {
    const h: usize = self.hidden;
    const nP: usize = self.parser_nP;
    const lower_dim = h * nP; // 128

    // accumulate 3 features into hidden
    var hidden: [128]f32 = [_]f32{0} ** 128;
    var tmp: [128]f32 = undefined;
    for (0..3) |f| {
        if (ctx[f] >= n_tokens) {
            // out-of-bounds → use the learned padding vector for feature f
            ops.vadd(
                hidden[0..lower_dim],
                hidden[0..lower_dim],
                self.lower_pad[f * lower_dim ..][0..lower_dim],
                lower_dim,
            );
        } else {
            // W_f @ tok2vec[tok_idx]: per-feature slice of the lower weights
            const W_f = self.lower_W[f * lower_dim * h ..][0 .. lower_dim * h];
            const x = tok2vec_out[ctx[f] * h ..][0..h];
            ops.matvec(tmp[0..lower_dim], x, W_f, h, lower_dim);
            ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], tmp[0..lower_dim], lower_dim);
        }
    }
    // add bias (shared across features)
    ops.vadd(hidden[0..lower_dim], hidden[0..lower_dim], self.lower_b, lower_dim);

    // maxout: (hidden * nP) → (hidden)
    var maxed: [64]f32 = undefined;
    ops.maxout(maxed[0..h], hidden[0..lower_dim], h, nP);

    // upper: (n_actions, hidden) @ hidden → scores
    const na: usize = self.n_actions;
    ops.matvec_bias(scores[0..na], maxed[0..h], self.upper_W, self.upper_b, h, na);
}

/// run the full NER pipeline on pre-tokenized text.
/// tokens: array of token byte slices (pointing into original text);
/// anything past MAX_TOKENS is silently truncated.
/// returns parser state with recognized entities (token-index spans).
pub fn predict(self: *const Model, tokens: []const []const u8) parser.State {
    const n: u32 = @intCast(@min(tokens.len, MAX_TOKENS));
    if (n == 0) return parser.State.init(0);
    const w = self.width;
    const h = self.hidden;

    // scratch buffers (stack-allocated); sized for width=96 / hidden=64
    var tok_vecs: [MAX_TOKENS * 96]f32 = undefined;
    var expanded: [(MAX_TOKENS + 2 * CNN_PAD) * 96 * 3]f32 = undefined;
    var tok2vec_out: [MAX_TOKENS * 64]f32 = undefined;

    // embed
    self.embedTokens(tokens[0..n], tok_vecs[0 .. n * w]);

    // CNN encode + linear project
    self.encode(
        tok_vecs[0 .. n * w],
        n,
        expanded[0 .. (n + 2 * CNN_PAD) * w * 3],
        tok2vec_out[0 .. n * h],
    );

    // greedy parse: repeatedly pick the best valid transition until final
    var state = parser.State.init(n);
    var scores: [parser.N_ACTIONS]f32 = undefined;
    while (!state.isFinal()) {
        const ctx = state.contextIds();
        self.scoreActions(ctx, tok2vec_out[0 .. n * h], n, &scores);
        const valid = state.validMask();
        const best = parser.argmaxValid(&scores, valid);
        const decoded = parser.decodeAction(best);
        state.apply(decoded.action, decoded.label);
    }
    return state;
}
};

// === tests ===
const testing = std.testing;

test "Model.load validates header" {
    // too small
    try testing.expectError(error.FileTooSmall, Model.load(""));
    // wrong magic (all-zero header)
    const bad_header: [HEADER_BYTES]u8 = [_]u8{0} ** HEADER_BYTES;
    try testing.expectError(error.BadMagic, Model.load(&bad_header));
}

/// load weight file from disk for tests (returns null if not found).
/// caller owns the returned buffer; free with std.testing.allocator.
fn loadWeightFile() ?[]align(4) const u8 {
    const file = std.fs.cwd().openFile("weights/en_core_web_sm.bin", .{}) catch return null;
    defer file.close();
    const stat = file.stat() catch return null;
    const bytes = std.testing.allocator.alignedAlloc(u8, .@"4", stat.size) catch return null;
    const n = file.readAll(bytes) catch {
        std.testing.allocator.free(bytes);
        return null;
    };
    // a short read would return a slice shorter than the allocation, and the
    // caller's free() would then be handed the wrong length — treat a partial
    // read the same as a missing file.
    if (n != stat.size) {
        std.testing.allocator.free(bytes);
        return null;
    }
    return bytes;
}

test "Model.load from weight file" {
    const weights = loadWeightFile() orelse return; // skip if weights not available
    defer std.testing.allocator.free(weights);
    const m = try Model.load(weights);
    try testing.expectEqual(@as(u32, 96), m.width);
    try testing.expectEqual(@as(u32, 64), m.hidden);
    try testing.expectEqual(@as(u32, 74), m.n_actions);
    try testing.expectEqual(@as(u32, 3), m.cnn_nP);
    try testing.expectEqual(@as(u32, 2), m.parser_nP);
    // verify embed table sizes
    try testing.expectEqual(@as(usize, 5000), m.embeds[0].nV);
    try testing.expectEqual(@as(usize, 1000), m.embeds[1].nV);
    try testing.expectEqual(@as(usize, 2500), m.embeds[2].nV);
    try testing.expectEqual(@as(usize, 2500), m.embeds[3].nV);
}

test "Model.predict basic NER" {
    const weights = loadWeightFile() orelse return; // skip if weights not available
    defer std.testing.allocator.free(weights);
    const m = try Model.load(weights);

    // "Barack Obama visited Paris"
    const tokens = [_][]const u8{ "Barack", "Obama", "visited", "Paris" };
    const state = m.predict(&tokens);
    const ents = state.entities();

    // should find at least one entity
    try testing.expect(ents.len > 0);

    // log what we found
    for (ents) |e| {
        std.debug.print(" [{d}..{d}) {s}\n", .{ e.start, e.end, @tagName(e.label) });
    }
}