//! BILUO transition-based NER parser.
//!
//! a greedy left-to-right parser that reads token vectors and predicts
//! entity spans using the Begin/In/Last/Unit/Out transition system.
//! this is the same architecture as spaCy's TransitionBasedParser.
//!
//! the parser maintains a state (buffer position, open entity) and at
//! each step predicts the highest-scoring valid action. the "valid"
//! constraints ensure well-formed entity spans (e.g., I-PERSON can
//! only follow B-PERSON or I-PERSON of the same label).

const std = @import("std");
const ops = @import("ops.zig");

/// entity label indices — matches en_core_web_sm's NER action table ordering.
/// this is the training order, NOT alphabetical.
pub const Label = enum(u8) {
    ORG = 0,
    DATE = 1,
    PERSON = 2,
    GPE = 3,
    MONEY = 4,
    CARDINAL = 5,
    NORP = 6,
    PERCENT = 7,
    WORK_OF_ART = 8,
    LOC = 9,
    TIME = 10,
    QUANTITY = 11,
    FAC = 12,
    EVENT = 13,
    ORDINAL = 14,
    PRODUCT = 15,
    LAW = 16,
    LANGUAGE = 17,

    /// number of labels; each one gets a B/I/L/U action in the action table.
    pub const COUNT = 18;
};

// decodeAction converts any index below Label.COUNT straight into a Label via
// @enumFromInt, so COUNT must equal the variant count and the values must be
// dense 0..COUNT-1. enforce that at compile time instead of trusting the
// hand-maintained constant.
comptime {
    const fields = std.meta.fields(Label);
    if (fields.len != Label.COUNT)
        @compileError("Label.COUNT does not match the number of Label variants");
    inline for (fields, 0..) |field, i| {
        if (field.value != i)
            @compileError("Label values must be dense and in declaration order (0..COUNT-1)");
    }
}

/// action types in the BILUO transition system.
pub const Action = enum(u8) {
    BEGIN = 0,
    IN = 1,
    LAST = 2,
    UNIT = 3,
    OUT = 4,
};

/// a recognized entity span.
pub const Entity = struct {
    start: u32, // token index (inclusive)
    end: u32, // token index (exclusive)
    label: Label,
};

/// total number of possible actions: B/I/L/U for each label + filler + OUT.
pub const N_ACTIONS = Label.COUNT * 4 + 2;

/// decode an action index (0..N_ACTIONS-1) into (action_type, label).
/// layout: [B*18, I*18, L*18, U*18, filler, OUT] — matches spaCy's get_class_name().
pub fn decodeAction(idx: usize) struct { action: Action, label: ?Label } {
    const n = Label.COUNT;
    if (idx < n) return .{ .action = .BEGIN, .label = @enumFromInt(@as(u8, @intCast(idx))) };
    if (idx < 2 * n) return .{ .action = .IN, .label = @enumFromInt(@as(u8, @intCast(idx - n))) };
    if (idx < 3 * n) return .{ .action = .LAST, .label = @enumFromInt(@as(u8, @intCast(idx - 2 * n))) };
    if (idx < 4 * n) return .{ .action = .UNIT, .label = @enumFromInt(@as(u8, @intCast(idx - 3 * n))) };
    if (idx == 4 * n + 1) return .{ .action = .OUT, .label = null };
    // idx == 4*n: filler (U-""), always invalid because label is null.
    // any idx past OUT also lands here, so it can never be selected as valid.
    return .{ .action = .UNIT, .label = null };
}

/// parser state for a single document.
pub const State = struct {
    /// current buffer position (next token to process)
    buffer_pos: u32 = 0,
    /// total number of tokens
    n_tokens: u32,
    /// currently open entity label (null if no entity open)
    open_label: ?Label = null,
    /// start position of currently open entity (only meaningful while open_label != null)
    open_start: u32 = 0,
    /// collected entities (fixed capacity, no allocation)
    entities_buf: [MAX_ENTITIES]Entity = undefined,
    entities_len: u32 = 0,
    /// number of completed entities that were discarded because entities_buf was full.
    /// non-zero means entities() is a truncated view of the parse.
    n_dropped: u32 = 0,

    pub const MAX_ENTITIES = 128;

    pub fn init(n_tokens: u32) State {
        return .{ .n_tokens = n_tokens };
    }

    /// the entities recorded so far, in the order they were closed.
    pub fn entities(self: *const State) []const Entity {
        return self.entities_buf[0..self.entities_len];
    }

    /// record a finished entity; once the buffer is full, further entities are
    /// counted in n_dropped instead of stored (best-effort, never fails).
    fn appendEntity(self: *State, ent: Entity) void {
        if (self.entities_len < MAX_ENTITIES) {
            self.entities_buf[self.entities_len] = ent;
            self.entities_len += 1;
        } else {
            self.n_dropped += 1;
        }
    }

    /// is the parser done (buffer exhausted)?
    pub fn isFinal(self: State) bool {
        return self.buffer_pos >= self.n_tokens;
    }

    /// tokens remaining in buffer. saturates at 0 so that querying a state that
    /// was advanced past the end does not trap on u32 underflow.
    pub fn remaining(self: State) u32 {
        return self.n_tokens -| self.buffer_pos;
    }

    /// B(0): current token at front of buffer
    pub fn b0(self: State) ?u32 {
        return if (self.buffer_pos < self.n_tokens) self.buffer_pos else null;
    }

    /// E(0): first token of current open entity (-1 / null if none)
    pub fn e0(self: State) ?u32 {
        return if (self.open_label != null) self.open_start else null;
    }

    /// context feature indices for the parser model.
    /// returns [B(0), E(0), B(0)-1], using n_tokens as the "padding" sentinel.
    /// feature 2 is only valid when BOTH B(0) and E(0) are valid (entity is open).
    pub fn contextIds(self: State) [3]u32 {
        const pad = self.n_tokens; // index into padding row
        const b = self.b0() orelse pad;
        const e = self.e0() orelse pad;
        return .{
            b,
            e,
            if (b < pad and e < pad and b > 0) b - 1 else pad,
        };
    }

    /// check whether a given action is valid in the current state.
    /// BEGIN/IN need at least 2 tokens left so the entity can still be closed
    /// by a LAST; IN/LAST must carry the label of the open entity.
    pub fn isValid(self: State, action: Action, label: ?Label) bool {
        return switch (action) {
            .BEGIN => self.open_label == null and self.remaining() >= 2 and label != null,
            .IN => self.open_label != null and self.remaining() >= 2 and
                label != null and label.? == self.open_label.?,
            .LAST => self.open_label != null and label != null and label.? == self.open_label.?,
            .UNIT => self.open_label == null and label != null,
            .OUT => self.open_label == null,
        };
    }

    /// apply an action, mutating the state. every action consumes one token.
    /// precondition: isValid(action, label) holds — LAST unwraps open_label and
    /// UNIT unwraps label, so an invalid action panics (safe builds) or is UB.
    pub fn apply(self: *State, action: Action, label: ?Label) void {
        switch (action) {
            .BEGIN => {
                self.open_label = label;
                self.open_start = self.buffer_pos;
                self.buffer_pos += 1;
            },
            .IN => {
                self.buffer_pos += 1;
            },
            .LAST => {
                // the current token is the final token of the span, so the
                // exclusive end is one past it.
                self.appendEntity(.{
                    .start = self.open_start,
                    .end = self.buffer_pos + 1,
                    .label = self.open_label.?,
                });
                self.open_label = null;
                self.buffer_pos += 1;
            },
            .UNIT => {
                self.appendEntity(.{
                    .start = self.buffer_pos,
                    .end = self.buffer_pos + 1,
                    .label = label.?,
                });
                self.buffer_pos += 1;
            },
            .OUT => {
                self.buffer_pos += 1;
            },
        }
    }

    /// compute a validity mask for all N_ACTIONS actions.
    /// valid[i] = true means action i is allowed in the current state.
    pub fn validMask(self: State) [N_ACTIONS]bool {
        var mask: [N_ACTIONS]bool = undefined;
        for (0..N_ACTIONS) |i| {
            const decoded = decodeAction(i);
            mask[i] = self.isValid(decoded.action, decoded.label);
        }
        return mask;
    }
};

/// greedy argmax over scores, masked to only valid actions.
/// returns the index of the highest-scoring valid action; ties go to the lowest
/// index. a valid action is always returned if one exists, even when every valid
/// score is -inf or NaN (a NaN score never beats a real one). only when no action
/// at all is valid does it fall back to OUT (N_ACTIONS - 1).
/// `scores` must contain at least N_ACTIONS entries.
pub fn argmaxValid(scores: []const f32, valid: [N_ACTIONS]bool) usize {
    std.debug.assert(scores.len >= N_ACTIONS);
    var best_idx: ?usize = null;
    var best_score: f32 = -std.math.inf(f32);
    for (valid, 0..) |is_valid, i| {
        if (!is_valid) continue;
        const s = scores[i];
        // the first valid action is always taken, so a -inf/NaN score can no
        // longer push us into the OUT fallback while an entity is open.
        const better = best_idx == null or s > best_score or
            (std.math.isNan(best_score) and !std.math.isNan(s));
        if (better) {
            best_idx = i;
            best_score = s;
        }
    }
    // fallback: nothing is valid (not expected for a well-formed state) → OUT
    return best_idx orelse N_ACTIONS - 1;
}

/// run the greedy parse loop for a document.
/// scores_fn: given state context IDs, computes scores for all N_ACTIONS actions.
/// scores_fn must write every one of the N_ACTIONS entries of scores_out:
/// the buffer is left undefined between calls. the returned State has consumed
/// every token; read the recognized spans via `entities()`.
pub fn parse(
    n_tokens: u32,
    scores_fn: *const fn (ctx: [3]u32, scores_out: []f32) void,
) State {
    var state = State.init(n_tokens);
    var scores: [N_ACTIONS]f32 = undefined;
    while (!state.isFinal()) {
        const ctx = state.contextIds();
        scores_fn(ctx, &scores);
        // mask out actions that would produce ill-formed spans, then take the best.
        const valid = state.validMask();
        const best = argmaxValid(&scores, valid);
        const decoded = decodeAction(best);
        state.apply(decoded.action, decoded.label);
    }
    return state;
}

// === tests ===

const testing = std.testing;

test "decodeAction round-trip" {
    // layout: [B*18, I*18, L*18, U*18, filler, OUT]
    // index 0 = B-ORG (first label in training order)
    const a0 = decodeAction(0);
    try testing.expectEqual(Action.BEGIN, a0.action);
    try testing.expectEqual(Label.ORG, a0.label.?);

    // index 2 = B-PERSON
    const a2 = decodeAction(2);
    try testing.expectEqual(Action.BEGIN, a2.action);
    try testing.expectEqual(Label.PERSON, a2.label.?);

    // index 18 = I-ORG
    const in0 = decodeAction(18);
    try testing.expectEqual(Action.IN, in0.action);
    try testing.expectEqual(Label.ORG, in0.label.?);

    // index 56 = U-PERSON (54 + 2)
    const up = decodeAction(56);
    try testing.expectEqual(Action.UNIT, up.action);
    try testing.expectEqual(Label.PERSON, up.label.?);

    // index 72 = filler (U-""), label is null
    const filler = decodeAction(72);
    try testing.expectEqual(Action.UNIT, filler.action);
    try testing.expectEqual(@as(?Label, null), filler.label);

    // index 73 = OUT
    const out = decodeAction(73);
    try testing.expectEqual(Action.OUT, out.action);
    try testing.expectEqual(@as(?Label, null), out.label);
}

test "state transitions: simple unit entity" {
    var state = State.init(3);

    // token 0: U-PERSON
    try testing.expect(state.isValid(.UNIT, .PERSON));
    state.apply(.UNIT, .PERSON);
    try testing.expectEqual(@as(u32, 1), state.buffer_pos);

    // token 1: OUT
    try testing.expect(state.isValid(.OUT, null));
    state.apply(.OUT, null);

    // token 2: U-GPE
    state.apply(.UNIT, .GPE);
    try testing.expect(state.isFinal());

    // check entities
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 2), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 1), ents[0].end);
    try testing.expectEqual(Label.GPE, ents[1].label);
}

test "state transitions: multi-token entity" {
    // "Barack Obama" = B-PERSON, L-PERSON
    var state = State.init(4);

    state.apply(.BEGIN, .PERSON);
    try testing.expect(state.open_label != null);
    try testing.expect(!state.isValid(.BEGIN, .ORG)); // can't begin while entity open
    try testing.expect(!state.isValid(.OUT, null)); // can't OUT while entity open
    try testing.expect(state.isValid(.IN, .PERSON)); // can continue
    try testing.expect(state.isValid(.LAST, .PERSON)); // can end
    try testing.expect(!state.isValid(.IN, .ORG)); // wrong label

    state.apply(.LAST, .PERSON);
    try testing.expectEqual(@as(?Label, null), state.open_label);

    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);
}

test "validity: BEGIN requires >= 2 remaining" {
    var state = State.init(1);
    // only 1 token left — can't BEGIN (need room for LAST)
    try testing.expect(!state.isValid(.BEGIN, .PERSON));
    try testing.expect(state.isValid(.UNIT, .PERSON));
    try testing.expect(state.isValid(.OUT, null));
}

test "parse: follows the highest-scoring valid actions" {
    // scorer: B-PERSON on token 0, L-PERSON on token 1, OUT afterwards.
    const Scorer = struct {
        fn score(ctx: [3]u32, out: []f32) void {
            @memset(out, 0);
            const person: usize = @intFromEnum(Label.PERSON);
            switch (ctx[0]) {
                0 => out[person] = 10, // B-PERSON
                1 => out[2 * Label.COUNT + person] = 10, // L-PERSON
                else => out[N_ACTIONS - 1] = 10, // OUT
            }
        }
    };

    const state = parse(3, &Scorer.score);
    try testing.expect(state.isFinal());
    try testing.expectEqual(@as(?Label, null), state.open_label);

    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);
}

test "parse: validity mask forces an open entity to close" {
    // scorer: B-ORG on token 0, otherwise always prefers OUT — which is invalid
    // while the ORG entity is open, so the mask must steer the parser through
    // I-ORG and finally L-ORG on the last token.
    const Scorer = struct {
        fn score(ctx: [3]u32, out: []f32) void {
            @memset(out, 0);
            out[N_ACTIONS - 1] = 10; // OUT
            if (ctx[0] == 0) out[@intFromEnum(Label.ORG)] = 20; // B-ORG
        }
    };

    const state = parse(3, &Scorer.score);
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.ORG, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 3), ents[0].end);
}