this repo has no description
//! BILUO transition-based NER parser.
//!
//! a greedy left-to-right parser that reads token vectors and predicts
//! entity spans using the Begin/In/Last/Unit/Out transition system.
//! this is the same architecture as spaCy's TransitionBasedParser.
//!
//! the parser maintains a state (buffer position, open entity) and at
//! each step predicts the highest-scoring valid action. the "valid"
//! constraints ensure well-formed entity spans (e.g., I-PERSON may
//! only follow B-PERSON or another I-PERSON, so a span never changes label).
11
12const std = @import("std");
13const ops = @import("ops.zig");
14
/// entity label indices — matches en_core_web_sm's NER action table ordering.
/// this is the training order, NOT alphabetical: the explicit values index the
/// model's action table and must not be renumbered or sorted.
pub const Label = enum(u8) {
    ORG = 0,
    DATE = 1,
    PERSON = 2,
    GPE = 3,
    MONEY = 4,
    CARDINAL = 5,
    NORP = 6,
    PERCENT = 7,
    WORK_OF_ART = 8,
    LOC = 9,
    TIME = 10,
    QUANTITY = 11,
    FAC = 12,
    EVENT = 13,
    ORDINAL = 14,
    PRODUCT = 15,
    LAW = 16,
    LANGUAGE = 17,

    /// number of labels; the action-index layout (N_ACTIONS, decodeAction) is derived from it.
    pub const COUNT = 18;

    // COUNT is spelled out by hand, so fail the build if a variant is added or
    // removed without updating it — a mismatch would silently shift every
    // decoded action index.
    comptime {
        if (std.meta.fields(Label).len != COUNT) {
            @compileError("Label.COUNT does not match the number of Label variants");
        }
    }
};
39
/// the five move types of the BILUO transition system. tags are assigned
/// sequentially from 0 (BEGIN = 0 … OUT = 4).
pub const Action = enum(u8) {
    /// open a multi-token entity at the current token
    BEGIN,
    /// continue the open entity through the current token
    IN,
    /// close the open entity with the current token
    LAST,
    /// emit a single-token entity
    UNIT,
    /// current token is outside any entity
    OUT,
};
48
/// a recognized entity span over the half-open token range [start, end).
pub const Entity = struct {
    /// index of the first token in the span (inclusive)
    start: u32,
    /// index one past the last token in the span (exclusive)
    end: u32,
    /// entity type assigned to the span
    label: Label,
};
55
/// size of the action table: one B, I, L and U action per label, followed by
/// the label-less filler slot and then OUT.
pub const N_ACTIONS = 4 * Label.COUNT + 1 + 1;
58
/// decode an action index (0..N_ACTIONS-1) into (action_type, label).
///
/// index layout is [B*18, I*18, L*18, U*18, filler, OUT], matching spaCy's
/// get_class_name(): the first 4*COUNT slots are (move, label) pairs in
/// move-major order, slot 4*COUNT is the label-less filler (U-""), and the
/// slot after it is OUT. the filler — and any index beyond OUT — decodes to
/// UNIT with a null label, which State.isValid always rejects.
pub fn decodeAction(idx: usize) struct { action: Action, label: ?Label } {
    const n = Label.COUNT;
    const labelled_slots = 4 * n;

    if (idx >= labelled_slots) {
        return if (idx == labelled_slots + 1)
            .{ .action = .OUT, .label = null }
        else
            .{ .action = .UNIT, .label = null };
    }

    // the quotient selects the move group, the remainder the label within it
    const move: Action = switch (idx / n) {
        0 => .BEGIN,
        1 => .IN,
        2 => .LAST,
        3 => .UNIT,
        else => unreachable, // idx < 4 * n, so idx / n <= 3
    };
    return .{ .action = move, .label = @enumFromInt(@as(u8, @intCast(idx % n))) };
}
71
/// parser state for a single document.
pub const State = struct {
    /// current buffer position (next token to process)
    buffer_pos: u32 = 0,
    /// total number of tokens
    n_tokens: u32,
    /// currently open entity label (null if no entity open)
    open_label: ?Label = null,
    /// start position of the currently open entity; only meaningful while open_label != null
    open_start: u32 = 0,
    /// collected entities (fixed capacity, no allocation)
    entities_buf: [MAX_ENTITIES]Entity = undefined,
    /// number of initialized entries at the front of entities_buf
    entities_len: u32 = 0,
    /// entities that were recognized but discarded because entities_buf was
    /// already full; non-zero means entities() returns a truncated result
    entities_dropped: u32 = 0,

    /// capacity of entities_buf
    pub const MAX_ENTITIES = 128;

    /// fresh state at the first token with no open entity.
    pub fn init(n_tokens: u32) State {
        return .{ .n_tokens = n_tokens };
    }

    /// entities recognized so far, in the order they were closed.
    /// the returned slice borrows this state's storage.
    pub fn entities(self: *const State) []const Entity {
        return self.entities_buf[0..self.entities_len];
    }

    /// record a finished entity. once the buffer is full, further entities are
    /// counted in entities_dropped instead of being stored, so truncation is
    /// observable rather than silent.
    fn appendEntity(self: *State, ent: Entity) void {
        if (self.entities_len < MAX_ENTITIES) {
            self.entities_buf[self.entities_len] = ent;
            self.entities_len += 1;
        } else {
            self.entities_dropped += 1;
        }
    }

    /// is the parser done (buffer exhausted)?
    pub fn isFinal(self: State) bool {
        return self.buffer_pos >= self.n_tokens;
    }

    /// tokens remaining in the buffer, counting B(0).
    /// assumes buffer_pos <= n_tokens (true for any state reached via valid actions).
    pub fn remaining(self: State) u32 {
        return self.n_tokens - self.buffer_pos;
    }

    /// B(0): the token at the front of the buffer, or null once it is exhausted.
    pub fn b0(self: State) ?u32 {
        return if (self.buffer_pos < self.n_tokens) self.buffer_pos else null;
    }

    /// E(0): first token of the currently open entity, or null when none is open.
    pub fn e0(self: State) ?u32 {
        return if (self.open_label != null) self.open_start else null;
    }

    /// context feature indices for the parser model: [B(0), E(0), B(0)-1].
    /// a missing feature is encoded as n_tokens (the padding row). feature 2 is
    /// only populated when both B(0) and E(0) exist, i.e. while an entity is open.
    pub fn contextIds(self: State) [3]u32 {
        const pad = self.n_tokens; // index into padding row
        const b = self.b0() orelse pad;
        const e = self.e0() orelse pad;
        return .{
            b,
            e,
            if (b < pad and e < pad and b > 0) b - 1 else pad,
        };
    }

    /// check whether a given action is valid in the current state.
    pub fn isValid(self: State, action: Action, label: ?Label) bool {
        return switch (action) {
            // opening needs a second token so the entity can be closed by LAST
            .BEGIN => self.open_label == null and self.remaining() >= 2 and label != null,
            // continuing must leave at least one token for the closing LAST
            .IN => self.open_label != null and self.remaining() >= 2 and
                label != null and label.? == self.open_label.?,
            .LAST => self.open_label != null and
                label != null and label.? == self.open_label.?,
            // a label-less UNIT is the filler slot and is never valid
            .UNIT => self.open_label == null and label != null,
            .OUT => self.open_label == null,
        };
    }

    /// apply an action, mutating the state. every action consumes exactly one token.
    /// the action is assumed to satisfy isValid; e.g. LAST with no open entity or
    /// UNIT with a null label unwraps null (illegal behavior).
    pub fn apply(self: *State, action: Action, label: ?Label) void {
        switch (action) {
            .BEGIN => {
                self.open_label = label;
                self.open_start = self.buffer_pos;
            },
            .IN, .OUT => {},
            .LAST => {
                self.appendEntity(.{
                    .start = self.open_start,
                    .end = self.buffer_pos + 1,
                    .label = self.open_label.?,
                });
                self.open_label = null;
            },
            .UNIT => self.appendEntity(.{
                .start = self.buffer_pos,
                .end = self.buffer_pos + 1,
                .label = label.?,
            }),
        }
        self.buffer_pos += 1;
    }

    /// compute a validity mask for all N_ACTIONS actions.
    /// valid[i] = true means action i is allowed in the current state.
    pub fn validMask(self: State) [N_ACTIONS]bool {
        var mask: [N_ACTIONS]bool = undefined;
        for (&mask, 0..) |*slot, i| {
            const decoded = decodeAction(i);
            slot.* = self.isValid(decoded.action, decoded.label);
        }
        return mask;
    }
};
195
/// greedy argmax over `scores`, restricted to the actions allowed by `valid`.
///
/// returns the index of the highest-scoring valid action; ties go to the
/// lowest index. a valid action is always returned when one exists, even if
/// every valid score is -inf or NaN, so the result can be applied safely.
/// a NaN score loses to any non-NaN score. only when no action at all is
/// valid does this fall back to OUT (N_ACTIONS - 1).
///
/// `scores` must hold at least N_ACTIONS entries.
pub fn argmaxValid(scores: []const f32, valid: [N_ACTIONS]bool) usize {
    std.debug.assert(scores.len >= N_ACTIONS);

    var best: ?usize = null;
    for (valid, 0..) |is_valid, i| {
        if (!is_valid) continue;
        if (best) |b| {
            // `>` is false against a NaN incumbent, so let a real score replace it explicitly
            const replaces_nan = std.math.isNan(scores[b]) and !std.math.isNan(scores[i]);
            if (scores[i] > scores[b] or replaces_nan) best = i;
        } else {
            best = i;
        }
    }

    // fallback: nothing is valid (shouldn't happen for a well-formed state)
    return best orelse N_ACTIONS - 1;
}
215
/// run the greedy transition loop over a document of `n_tokens` tokens.
///
/// at every step the current context ids are passed to `scores_fn`, which is
/// expected to fill `scores_out` (N_ACTIONS entries, initially undefined); the
/// best-scoring valid action is then applied. returns the final state, which
/// holds the recognized entities.
pub fn parse(
    n_tokens: u32,
    scores_fn: *const fn (ctx: [3]u32, scores_out: []f32) void,
) State {
    var state = State.init(n_tokens);
    var score_buf: [N_ACTIONS]f32 = undefined;

    // B(0) exists exactly while the buffer still has tokens
    while (state.b0() != null) {
        scores_fn(state.contextIds(), score_buf[0..]);
        const choice = decodeAction(argmaxValid(&score_buf, state.validMask()));
        state.apply(choice.action, choice.label);
    }
    return state;
}
236
237// === tests ===
238
239const testing = std.testing;
240
test "decodeAction round-trip" {
    // layout: [B*18, I*18, L*18, U*18, filler, OUT]; labels are in training order
    const Case = struct { idx: usize, action: Action, label: ?Label };
    const cases = [_]Case{
        .{ .idx = 0, .action = .BEGIN, .label = .ORG },
        .{ .idx = 2, .action = .BEGIN, .label = .PERSON },
        .{ .idx = 17, .action = .BEGIN, .label = .LANGUAGE },
        .{ .idx = 18, .action = .IN, .label = .ORG },
        .{ .idx = 36, .action = .LAST, .label = .ORG },
        .{ .idx = 53, .action = .LAST, .label = .LANGUAGE },
        .{ .idx = 54, .action = .UNIT, .label = .ORG },
        .{ .idx = 56, .action = .UNIT, .label = .PERSON },
        .{ .idx = 71, .action = .UNIT, .label = .LANGUAGE },
        // filler (U-""): label is null, so it is never a valid action
        .{ .idx = 72, .action = .UNIT, .label = null },
        .{ .idx = 73, .action = .OUT, .label = null },
    };
    for (cases) |c| {
        const got = decodeAction(c.idx);
        try testing.expectEqual(c.action, got.action);
        try testing.expectEqual(c.label, got.label);
    }

    // every labelled slot decodes to (move, label) and re-encodes to its own index
    for (0..4 * Label.COUNT) |i| {
        const d = decodeAction(i);
        const reencoded = @as(usize, @intFromEnum(d.action)) * Label.COUNT +
            @as(usize, @intFromEnum(d.label.?));
        try testing.expectEqual(i, reencoded);
    }
}
273
test "state transitions: simple unit entity" {
    var state = State.init(3);

    // token 0: U-PERSON
    try testing.expect(state.isValid(.UNIT, .PERSON));
    state.apply(.UNIT, .PERSON);
    try testing.expectEqual(@as(u32, 1), state.buffer_pos);
    try testing.expectEqual(@as(?Label, null), state.open_label); // UNIT never leaves an entity open

    // token 1: OUT
    try testing.expect(state.isValid(.OUT, null));
    state.apply(.OUT, null);
    try testing.expectEqual(@as(u32, 2), state.buffer_pos);

    // token 2: U-GPE
    try testing.expect(state.isValid(.UNIT, .GPE));
    state.apply(.UNIT, .GPE);
    try testing.expect(state.isFinal());

    // check both entities, including their spans
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 2), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 1), ents[0].end);
    try testing.expectEqual(Label.GPE, ents[1].label);
    try testing.expectEqual(@as(u32, 2), ents[1].start);
    try testing.expectEqual(@as(u32, 3), ents[1].end);
}
298
test "state transitions: multi-token entity" {
    // "Barack Obama" = B-PERSON, L-PERSON
    var state = State.init(4);

    state.apply(.BEGIN, .PERSON);
    try testing.expectEqual(@as(?Label, .PERSON), state.open_label);
    try testing.expect(!state.isValid(.BEGIN, .ORG)); // can't begin while entity open
    try testing.expect(!state.isValid(.UNIT, .PERSON)); // can't emit a unit entity while open
    try testing.expect(!state.isValid(.OUT, null)); // can't OUT while entity open
    try testing.expect(state.isValid(.IN, .PERSON)); // can continue
    try testing.expect(state.isValid(.LAST, .PERSON)); // can end
    try testing.expect(!state.isValid(.IN, .ORG)); // wrong label
    try testing.expect(!state.isValid(.LAST, .ORG)); // wrong label

    state.apply(.LAST, .PERSON);
    try testing.expectEqual(@as(?Label, null), state.open_label);
    try testing.expectEqual(@as(u32, 2), state.buffer_pos);
    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);
}
318
test "validity: BEGIN requires >= 2 remaining" {
    // `const`: the state is only queried, and an unmutated `var` is a compile error on newer Zig
    const state = State.init(1);
    // only 1 token left — can't BEGIN (need room for LAST)
    try testing.expect(!state.isValid(.BEGIN, .PERSON));
    try testing.expect(state.isValid(.UNIT, .PERSON));
    try testing.expect(state.isValid(.OUT, null));
    // the filler slot (UNIT with no label) is never valid
    try testing.expect(!state.isValid(.UNIT, null));
}

test "parse: greedy loop assembles a two-token entity" {
    // scorer keyed on B(0): B-PERSON at token 0, L-PERSON at token 1, OUT afterwards
    const Scorer = struct {
        fn score(ctx: [3]u32, out: []f32) void {
            @memset(out, 0);
            const person: usize = @intFromEnum(Label.PERSON);
            switch (ctx[0]) {
                0 => out[person] = 1,
                1 => out[2 * Label.COUNT + person] = 1,
                else => out[N_ACTIONS - 1] = 1,
            }
        }
    };

    const state = parse(3, &Scorer.score);
    try testing.expect(state.isFinal());
    try testing.expectEqual(@as(?Label, null), state.open_label);

    const ents = state.entities();
    try testing.expectEqual(@as(usize, 1), ents.len);
    try testing.expectEqual(Label.PERSON, ents[0].label);
    try testing.expectEqual(@as(u32, 0), ents[0].start);
    try testing.expectEqual(@as(u32, 2), ents[0].end);
}