this repo has no description
at main 325 lines 11 kB view raw
//! BILUO transition-based NER parser.
//!
//! a greedy left-to-right parser that reads token vectors and predicts
//! entity spans using the Begin/In/Last/Unit/Out transition system.
//! this is the same architecture as spaCy's TransitionBasedParser.
//!
//! the parser maintains a state (buffer position, open entity) and at
//! each step predicts the highest-scoring valid action. the "valid"
//! constraints ensure well-formed entity spans (e.g., I-PERSON can
//! only follow B-PERSON or I-PERSON of the same label).

const std = @import("std");
const ops = @import("ops.zig");

/// entity label indices — matches en_core_web_sm's NER action table ordering.
/// this is the training order, NOT alphabetical.
pub const Label = enum(u8) {
    ORG = 0,
    DATE = 1,
    PERSON = 2,
    GPE = 3,
    MONEY = 4,
    CARDINAL = 5,
    NORP = 6,
    PERCENT = 7,
    WORK_OF_ART = 8,
    LOC = 9,
    TIME = 10,
    QUANTITY = 11,
    FAC = 12,
    EVENT = 13,
    ORDINAL = 14,
    PRODUCT = 15,
    LAW = 16,
    LANGUAGE = 17,

    /// number of labels. every action index (and N_ACTIONS) is derived from
    /// this, so it is checked against the enum at compile time below.
    pub const COUNT = 18;

    comptime {
        // hand-maintained constant: fail the build instead of silently
        // mis-decoding actions if a label is added or removed.
        if (COUNT != std.meta.fields(Label).len)
            @compileError("Label.COUNT is out of sync with the number of Label fields");
    }
};

/// action types in the BILUO transition system.
pub const Action = enum(u8) {
    BEGIN = 0,
    IN = 1,
    LAST = 2,
    UNIT = 3,
    OUT = 4,
};

/// a recognized entity span over token indices, half-open: [start, end).
pub const Entity = struct {
    /// index of the first token of the entity (inclusive)
    start: u32,
    /// index one past the last token of the entity (exclusive)
    end: u32,
    label: Label,
};

/// total number of possible actions: B/I/L/U for each label + filler + OUT.
pub const N_ACTIONS = Label.COUNT * 4 + 2;
/// decode an action index (0..N_ACTIONS-1) into (action_type, label).
/// layout: [B*18, I*18, L*18, U*18, filler, OUT] — matches spaCy's get_class_name().
/// the filler slot (index 4*COUNT) decodes to UNIT with a null label, which
/// `State.isValid` always rejects; indices >= N_ACTIONS decode the same way.
pub fn decodeAction(idx: usize) struct { action: Action, label: ?Label } {
    const n = Label.COUNT;
    if (idx < n) return .{ .action = .BEGIN, .label = @enumFromInt(@as(u8, @intCast(idx))) };
    if (idx < 2 * n) return .{ .action = .IN, .label = @enumFromInt(@as(u8, @intCast(idx - n))) };
    if (idx < 3 * n) return .{ .action = .LAST, .label = @enumFromInt(@as(u8, @intCast(idx - 2 * n))) };
    if (idx < 4 * n) return .{ .action = .UNIT, .label = @enumFromInt(@as(u8, @intCast(idx - 3 * n))) };
    if (idx == 4 * n + 1) return .{ .action = .OUT, .label = null };
    // idx == 4*n: filler (U-""), always invalid because label is null.
    // out-of-range indices (> 4*n + 1) also fall through to here.
    return .{ .action = .UNIT, .label = null };
}

/// parser state for a single document.
pub const State = struct {
    /// current buffer position (next token to process)
    buffer_pos: u32 = 0,
    /// total number of tokens
    n_tokens: u32,
    /// currently open entity label (null if no entity open)
    open_label: ?Label = null,
    /// start position of currently open entity
    open_start: u32 = 0,
    /// collected entities (fixed capacity, no allocation)
    entities_buf: [MAX_ENTITIES]Entity = undefined,
    entities_len: u32 = 0,
    /// number of entities discarded because `entities_buf` was already full.
    /// non-zero means `entities()` returns a truncated result.
    entities_dropped: u32 = 0,

    pub const MAX_ENTITIES = 128;

    pub fn init(n_tokens: u32) State {
        return .{ .n_tokens = n_tokens };
    }

    /// entities recognized so far, in the order they were closed.
    /// the slice points into `self` and is only valid while `self` is.
    pub fn entities(self: *const State) []const Entity {
        return self.entities_buf[0..self.entities_len];
    }

    /// record a finished entity; past MAX_ENTITIES it is counted in
    /// `entities_dropped` instead of being stored.
    fn appendEntity(self: *State, ent: Entity) void {
        if (self.entities_len < MAX_ENTITIES) {
            self.entities_buf[self.entities_len] = ent;
            self.entities_len += 1;
        } else {
            self.entities_dropped +|= 1;
        }
    }

    /// is the parser done (buffer exhausted)?
    pub fn isFinal(self: State) bool {
        return self.buffer_pos >= self.n_tokens;
    }

    /// tokens remaining in buffer (0 once final; saturates instead of
    /// underflowing if `apply` was called past the end).
    pub fn remaining(self: State) u32 {
        return self.n_tokens -| self.buffer_pos;
    }

    /// B(0): current token at front of buffer
    pub fn b0(self: State) ?u32 {
        return if (self.buffer_pos < self.n_tokens) self.buffer_pos else null;
    }

    /// E(0): first token of current open entity (-1 / null if none)
    pub fn e0(self: State) ?u32 {
        return if (self.open_label != null) self.open_start else null;
    }

    /// context feature indices for the parser model.
    /// returns [B(0), E(0), B(0)-1], using n_tokens as the "padding" sentinel.
    /// feature 2 is only valid when BOTH B(0) and E(0) are valid (entity is open).
    pub fn contextIds(self: State) [3]u32 {
        const pad = self.n_tokens; // index into padding row
        const b = self.b0() orelse pad;
        const e = self.e0() orelse pad;
        return .{
            b,
            e,
            if (b < pad and e < pad and b > 0) b - 1 else pad,
        };
    }

    /// check whether a given action is valid in the current state.
    /// BEGIN and IN require at least 2 tokens left so a LAST always fits;
    /// LAST/IN must carry the label of the currently open entity.
    pub fn isValid(self: State, action: Action, label: ?Label) bool {
        return switch (action) {
            .BEGIN => self.open_label == null and self.remaining() >= 2 and label != null,
            .IN => self.open_label != null and self.remaining() >= 2 and
                label != null and label.? == self.open_label.?,
            .LAST => self.open_label != null and
                label != null and label.? == self.open_label.?,
            .UNIT => self.open_label == null and label != null,
            .OUT => self.open_label == null,
        };
    }

    /// apply an action, mutating the state. every action consumes exactly
    /// one token. callers are expected to pass only actions `isValid`
    /// accepts: LAST unwraps `open_label` and UNIT unwraps `label`, and
    /// IN ignores `label` entirely.
    pub fn apply(self: *State, action: Action, label: ?Label) void {
        switch (action) {
            .BEGIN => {
                self.open_label = label;
                self.open_start = self.buffer_pos;
                self.buffer_pos += 1;
            },
            .IN => {
                self.buffer_pos += 1;
            },
            .LAST => {
                self.appendEntity(.{
                    .start = self.open_start,
                    .end = self.buffer_pos + 1,
                    .label = self.open_label.?,
                });
                self.open_label = null;
                self.buffer_pos += 1;
            },
            .UNIT => {
                self.appendEntity(.{
                    .start = self.buffer_pos,
                    .end = self.buffer_pos + 1,
                    .label = label.?,
                });
                self.buffer_pos += 1;
            },
            .OUT => {
                self.buffer_pos += 1;
            },
        }
    }

    /// compute a validity mask for all N_ACTIONS actions.
    /// valid[i] = true means action i is allowed in the current state.
    pub fn validMask(self: State) [N_ACTIONS]bool {
        var mask: [N_ACTIONS]bool = undefined;
        for (0..N_ACTIONS) |i| {
            const decoded = decodeAction(i);
            mask[i] = self.isValid(decoded.action, decoded.label);
        }
        return mask;
    }
};

/// greedy argmax over scores, masked to only valid actions.
/// returns the index of the highest-scoring valid action; ties go to the
/// lowest index. NaN scores rank as -inf. if every valid action scores
/// -inf/NaN, the lowest-index valid action is returned (never an invalid
/// one). OUT (N_ACTIONS - 1) is returned only if no action is valid at all.
/// `scores` must hold at least N_ACTIONS entries.
pub fn argmaxValid(scores: []const f32, valid: [N_ACTIONS]bool) usize {
    std.debug.assert(scores.len >= N_ACTIONS);

    var best_idx: ?usize = null;
    var best_score: f32 = -std.math.inf(f32);

    for (valid, 0..) |is_valid, i| {
        if (!is_valid) continue;
        // NaN never compares greater, so rank it as -inf rather than let it
        // block or skew the choice.
        const score = if (std.math.isNan(scores[i])) -std.math.inf(f32) else scores[i];
        // the first valid action is always taken, so a -inf-only row still
        // yields a valid action instead of the OUT fallback.
        if (best_idx == null or score > best_score) {
            best_idx = i;
            best_score = score;
        }
    }

    // fallback: nothing is valid (cannot happen for a non-final state,
    // since either OUT or LAST-<open label> is always valid).
    return best_idx orelse N_ACTIONS - 1;
}

test "argmaxValid: never returns an invalid action when valid scores are -inf or NaN" {
    // entity open on the final token: LAST-PERSON is the only valid action.
    var state = State.init(2);
    state.apply(.BEGIN, .PERSON);
    const mask = state.validMask();
    const last_person = 2 * Label.COUNT + @as(usize, @intFromEnum(Label.PERSON));

    // only OUT has a finite score, but OUT is invalid while an entity is open
    var scores = [_]f32{-std.math.inf(f32)} ** N_ACTIONS;
    scores[N_ACTIONS - 1] = 0;
    try std.testing.expectEqual(last_person, argmaxValid(&scores, mask));

    // all-NaN scores must not push the choice to an invalid action either
    const nans = [_]f32{std.math.nan(f32)} ** N_ACTIONS;
    try std.testing.expectEqual(last_person, argmaxValid(&nans, mask));

    // ordinary case: an invalid top score is masked, best valid score wins
    const fresh_mask = State.init(3).validMask();
    var ordinary = [_]f32{0} ** N_ACTIONS;
    ordinary[Label.COUNT + 2] = 100; // I-PERSON: invalid with no open entity
    ordinary[N_ACTIONS - 1] = 1; // OUT
    try std.testing.expectEqual(@as(usize, N_ACTIONS - 1), argmaxValid(&ordinary, fresh_mask));
}
/// run the greedy parse loop for a document.
/// scores_fn: given state context IDs (see `State.contextIds`), writes a score
/// for each of the N_ACTIONS actions into `scores_out` (len == N_ACTIONS).
/// `scores_out` is reset to -inf before every call, so an entry the callback
/// leaves unwritten reads as -inf instead of uninitialized memory.
/// every step consumes exactly one token, so scores_fn is called n_tokens times.
pub fn parse(
    n_tokens: u32,
    scores_fn: *const fn (ctx: [3]u32, scores_out: []f32) void,
) State {
    var state = State.init(n_tokens);
    var scores: [N_ACTIONS]f32 = undefined;

    while (!state.isFinal()) {
        const ctx = state.contextIds();
        @memset(&scores, -std.math.inf(f32));
        scores_fn(ctx, &scores);
        const valid = state.validMask();
        const best = argmaxValid(&scores, valid);
        const decoded = decodeAction(best);
        state.apply(decoded.action, decoded.label);
    }

    return state;
}

// === tests ===

const testing = std.testing;

test "decodeAction round-trip" {
    // layout: [B*18, I*18, L*18, U*18, filler, OUT]
    // index 0 = B-ORG (first label in training order)
    const a0 = decodeAction(0);
    try testing.expectEqual(Action.BEGIN, a0.action);
    try testing.expectEqual(Label.ORG, a0.label.?);

    // index 2 = B-PERSON
    const a2 = decodeAction(2);
    try testing.expectEqual(Action.BEGIN, a2.action);
    try testing.expectEqual(Label.PERSON, a2.label.?);

    // index 18 = I-ORG
    const in0 = decodeAction(18);
    try testing.expectEqual(Action.IN, in0.action);
    try testing.expectEqual(Label.ORG, in0.label.?);

    // index 36 = L-ORG
    const l0 = decodeAction(36);
    try testing.expectEqual(Action.LAST, l0.action);
    try testing.expectEqual(Label.ORG, l0.label.?);

    // index 56 = U-PERSON (54 + 2)
    const up = decodeAction(56);
    try testing.expectEqual(Action.UNIT, up.action);
    try testing.expectEqual(Label.PERSON, up.label.?);

    // index 72 = filler (U-""), label is null
    const filler = decodeAction(72);
    try testing.expectEqual(Action.UNIT, filler.action);
    try testing.expectEqual(@as(?Label, null), filler.label);

    // index 73 = OUT
    const out = decodeAction(73);
    try testing.expectEqual(Action.OUT, out.action);
    try testing.expectEqual(@as(?Label, null), out.label);

    // out-of-range indices decode to the always-invalid filler
    const past_end = decodeAction(N_ACTIONS);
    try testing.expectEqual(Action.UNIT, past_end.action);
    try testing.expectEqual(@as(?Label, null), past_end.label);
}

test "state transitions: simple unit entity" {
    var state = State.init(3);

    // token 0: U-PERSON
    try testing.expect(state.isValid(.UNIT, .PERSON));
    state.apply(.UNIT, .PERSON);
    try testing.expectEqual(@as(u32, 1), state.buffer_pos);

    // token 1: OUT
    try testing.expect(state.isValid(.OUT, null));
    state.apply(.OUT, null);

    // token 2: U-GPE
    try testing.expect(state.isValid(.UNIT, .GPE));
    state.apply(.UNIT, .GPE);
    try testing.expect(state.isFinal());

    // check entities
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 2), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 1), ents[0].end);
    try testing.expectEqual(Label.GPE, ents[1].label);
    try testing.expectEqual(@as(u32, 2), ents[1].start);
    try testing.expectEqual(@as(u32, 3), ents[1].end);
}

test "state transitions: multi-token entity" {
    // "Barack Obama" = B-PERSON, L-PERSON
    var state = State.init(4);

    state.apply(.BEGIN, .PERSON);
    try testing.expect(state.open_label != null);
    try testing.expect(!state.isValid(.BEGIN, .ORG)); // can't begin while entity open
    try testing.expect(!state.isValid(.OUT, null)); // can't OUT while entity open
    try testing.expect(state.isValid(.IN, .PERSON)); // can continue
    try testing.expect(state.isValid(.LAST, .PERSON)); // can end
    try testing.expect(!state.isValid(.IN, .ORG)); // wrong label

    state.apply(.LAST, .PERSON);
    try testing.expectEqual(@as(?Label, null), state.open_label);
    try testing.expectEqual(@as(u32, 2), state.buffer_pos);
    try testing.expect(!state.isFinal());
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);
}

test "validity: BEGIN requires >= 2 remaining" {
    var state = State.init(1);
    // only 1 token left — can't BEGIN (need room for LAST)
    try testing.expect(!state.isValid(.BEGIN, .PERSON));
    try testing.expect(state.isValid(.UNIT, .PERSON));
    try testing.expect(state.isValid(.OUT, null));
    // no entity open: LAST is impossible, and the null-label filler is never valid
    try testing.expect(!state.isValid(.LAST, .PERSON));
    try testing.expect(!state.isValid(.UNIT, null));
}

test "contextIds: padding sentinel and entity features" {
    var state = State.init(3);
    try testing.expectEqual([3]u32{ 0, 3, 3 }, state.contextIds());
    state.apply(.BEGIN, .PERSON);
    try testing.expectEqual([3]u32{ 1, 0, 0 }, state.contextIds());
    state.apply(.LAST, .PERSON);
    try testing.expectEqual([3]u32{ 2, 3, 3 }, state.contextIds());
    state.apply(.OUT, null);
    try testing.expectEqual([3]u32{ 3, 3, 3 }, state.contextIds());
}

test "parse: end-to-end greedy decode" {
    // scores B-PERSON at token 0, L-PERSON at token 1, OUT elsewhere
    const Scorer = struct {
        fn score(ctx: [3]u32, scores_out: []f32) void {
            @memset(scores_out, 0);
            const person: usize = @intFromEnum(Label.PERSON);
            switch (ctx[0]) {
                0 => scores_out[person] = 10, // B-PERSON
                1 => scores_out[2 * Label.COUNT + person] = 10, // L-PERSON
                else => scores_out[N_ACTIONS - 1] = 10, // OUT
            }
        }
    };

    const state = parse(4, &Scorer.score);
    try testing.expect(state.isFinal());
    try testing.expectEqual(@as(?Label, null), state.open_label);
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);

    // an empty document never calls the scorer and yields no entities
    const empty = parse(0, &Scorer.score);
    try testing.expect(empty.isFinal());
    try testing.expectEqual(@as(usize, 0), empty.entities().len);
}

test "parse: invalid top-scoring action is masked out" {
    // always prefers B-ORG, which is invalid for a single-token document
    const Scorer = struct {
        fn score(ctx: [3]u32, scores_out: []f32) void {
            _ = ctx;
            @memset(scores_out, 0);
            const org: usize = @intFromEnum(Label.ORG);
            scores_out[org] = 10; // B-ORG
            scores_out[3 * Label.COUNT + org] = 5; // U-ORG
        }
    };

    const state = parse(1, &Scorer.score);
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.ORG, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 1), ents[0].end);
}