// nn primitives for NER inference (module docs follow)
1//! neural network primitives for NER inference.
2//!
3//! pure functions over float slices — no allocations, no state.
4//! follows the karpathy/llama2.c style: explicit dimensions,
5//! pre-allocated buffers, zero abstraction over the math.
6
const std = @import("std");

// preferred f32 SIMD lane count for the compile target; falls back to 8
// lanes when the target reports no preference (suggestVectorLength → null)
const VEC_LEN = std.simd.suggestVectorLength(f32) orelse 8;
10
/// matrix-vector multiply: out = W @ x
/// W is (d, n) row-major, x is (n,), out is (d,).
/// SIMD over full VEC_LEN chunks of each row, scalar loop for the tail.
pub fn matvec(out: []f32, x: []const f32, W: []const f32, n: usize, d: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= d);
    std.debug.assert(W.len >= d * n);

    const full_chunks = n / VEC_LEN;
    const leftover = n % VEC_LEN;
    const tail_base = full_chunks * VEC_LEN;

    var row_idx: usize = 0;
    while (row_idx < d) : (row_idx += 1) {
        const w_row = W[row_idx * n ..][0..n];

        // fused multiply-add across VEC_LEN-wide lanes
        var lanes: @Vector(VEC_LEN, f32) = @splat(0.0);
        var chunk: usize = 0;
        while (chunk < full_chunks) : (chunk += 1) {
            const xs: @Vector(VEC_LEN, f32) = x[chunk * VEC_LEN ..][0..VEC_LEN].*;
            const ws: @Vector(VEC_LEN, f32) = w_row[chunk * VEC_LEN ..][0..VEC_LEN].*;
            lanes = @mulAdd(@Vector(VEC_LEN, f32), xs, ws, lanes);
        }
        var acc = @reduce(.Add, lanes);

        // scalar tail: the last n % VEC_LEN elements
        for (0..leftover) |j| {
            acc += w_row[tail_base + j] * x[tail_base + j];
        }
        out[row_idx] = acc;
    }
}
40
/// matrix-vector multiply with bias: out = W @ x + b
/// W is (d, n) row-major, x is (n,), b is (d,), out is (d,).
pub fn matvec_bias(out: []f32, x: []const f32, W: []const f32, b: []const f32, n: usize, d: usize) void {
    // every other routine in this file asserts its slice lengths;
    // without this, a short bias slice is only caught at the access below
    std.debug.assert(b.len >= d);

    matvec(out, x, W, n, d);
    for (0..d) |i| {
        out[i] += b[i];
    }
}
48
/// maxout activation: each output unit takes the max of its nP
/// consecutive "pieces". input is (nO * nP,), output is (nO,).
pub fn maxout(out: []f32, input: []const f32, nO: usize, nP: usize) void {
    std.debug.assert(input.len >= nO * nP);
    std.debug.assert(out.len >= nO);

    for (0..nO) |unit| {
        const pieces = input[unit * nP ..][0..nP];
        var winner = pieces[0];
        for (pieces[1..]) |candidate| {
            winner = if (candidate > winner) candidate else winner;
        }
        out[unit] = winner;
    }
}
64
/// layer normalization: out = G * (x - mean) / sqrt(var + eps) + b
/// operates on a single (n,) row; G (scale) and b (shift) are (n,).
/// two-pass mean/variance (matching numpy's .var()) for float32 stability.
pub fn layernorm(out: []f32, x: []const f32, G: []const f32, b: []const f32, n: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= n);
    std.debug.assert(G.len >= n);
    std.debug.assert(b.len >= n);

    const count: f32 = @floatFromInt(n);

    // first pass: mean of the row
    var total: f32 = 0.0;
    for (x[0..n]) |xi| {
        total += xi;
    }
    const mu = total / count;

    // second pass: population variance = mean of squared deviations
    var sq_dev: f32 = 0.0;
    for (x[0..n]) |xi| {
        const delta = xi - mu;
        sq_dev += delta * delta;
    }
    const variance = sq_dev / count;
    const inv_std = 1.0 / @sqrt(variance + 1e-8);

    // normalize, then apply learned scale and shift
    for (0..n) |i| {
        out[i] = G[i] * (x[i] - mu) * inv_std + b[i];
    }
}
96
/// element-wise vector addition: out[i] = a[i] + b[i]
/// SIMD over full VEC_LEN chunks, scalar loop for the remainder.
pub fn vadd(out: []f32, a: []const f32, b: []const f32, n: usize) void {
    // bounds asserts, consistent with matvec/maxout/layernorm/seq2col
    std.debug.assert(a.len >= n);
    std.debug.assert(b.len >= n);
    std.debug.assert(out.len >= n);

    const n_vec = n / VEC_LEN;
    const n_rem = n % VEC_LEN;

    for (0..n_vec) |v| {
        const va: @Vector(VEC_LEN, f32) = a[v * VEC_LEN ..][0..VEC_LEN].*;
        const vb: @Vector(VEC_LEN, f32) = b[v * VEC_LEN ..][0..VEC_LEN].*;
        // comptime-length re-slice yields *[VEC_LEN]f32 directly, and
        // vectors coerce to arrays — no @ptrCast needed
        out[v * VEC_LEN ..][0..VEC_LEN].* = va + vb;
    }
    // scalar tail
    const tail = n_vec * VEC_LEN;
    for (0..n_rem) |j| {
        out[tail + j] = a[tail + j] + b[tail + j];
    }
}
114
/// expand_window(size=1): per token, concatenate [left, center, right].
/// input is (seq_len, width), output is (seq_len, width * 3).
/// missing neighbors at the sequence boundaries are zero-filled.
pub fn seq2col(out: []f32, input: []const f32, seq_len: usize, width: usize) void {
    std.debug.assert(input.len >= seq_len * width);
    std.debug.assert(out.len >= seq_len * width * 3);

    const stride = width * 3;
    for (0..seq_len) |t| {
        const row = out[t * stride ..][0..stride];
        const left = row[0..width];
        const center = row[width .. width * 2];
        const right = row[width * 2 ..][0..width];

        // left neighbor: zero-pad for the first token
        if (t == 0) {
            @memset(left, 0.0);
        } else {
            @memcpy(left, input[(t - 1) * width ..][0..width]);
        }

        @memcpy(center, input[t * width ..][0..width]);

        // right neighbor: zero-pad for the last token
        if (t + 1 == seq_len) {
            @memset(right, 0.0);
        } else {
            @memcpy(right, input[(t + 1) * width ..][0..width]);
        }
    }
}
144
// === tests ===

const testing = std.testing;
// shared absolute tolerance for float comparisons in the tests below
const eps = 1e-4;

// convenience wrapper: approximate equality at the shared tolerance
fn expectApprox(expected: f32, actual: f32) !void {
    try testing.expectApproxEqAbs(expected, actual, eps);
}
153
test "matvec identity-like" {
    // identity weights pass the input through: I @ [3, 7] = [3, 7]
    const W = [_]f32{ 1, 0, 0, 1 };
    const x = [_]f32{ 3, 7 };
    var out: [2]f32 = undefined;
    matvec(&out, &x, &W, 2, 2);
    const expected = [_]f32{ 3.0, 7.0 };
    for (expected, out) |e, got| {
        try expectApprox(e, got);
    }
}
163
test "matvec general" {
    // [[1, 2], [3, 4]] @ [5, 6] = [1*5+2*6, 3*5+4*6] = [17, 39]
    const W = [_]f32{ 1, 2, 3, 4 };
    const x = [_]f32{ 5, 6 };
    var out: [2]f32 = undefined;
    matvec(&out, &x, &W, 2, 2);
    const expected = [_]f32{ 17.0, 39.0 };
    for (expected, out) |e, got| {
        try expectApprox(e, got);
    }
}
173
test "maxout basic" {
    // two units with three pieces each: max(1,5,3)=5, max(2,8,4)=8
    const input = [_]f32{ 1, 5, 3, 2, 8, 4 };
    var out: [2]f32 = undefined;
    maxout(&out, &input, 2, 3);
    const expected = [_]f32{ 5.0, 8.0 };
    for (expected, out) |e, got| {
        try expectApprox(e, got);
    }
}
184
test "layernorm basic" {
    // normalize [1, 2, 3, 4] with unit scale and zero shift
    const x = [_]f32{ 1, 2, 3, 4 };
    const G = [_]f32{ 1, 1, 1, 1 };
    const b = [_]f32{ 0, 0, 0, 0 };
    var out: [4]f32 = undefined;
    layernorm(&out, &x, &G, &b, 4);

    // mean=2.5, var=1.25 -> roughly [-1.342, -0.447, 0.447, 1.342]:
    // inputs below the mean map negative, above the mean map positive
    const positive = [_]bool{ false, false, true, true };
    for (out, positive) |val, expect_pos| {
        try testing.expect(if (expect_pos) val > 0 else val < 0);
    }
    // normalized output is centered: components cancel to ~0
    var total: f32 = 0.0;
    for (out) |val| total += val;
    try expectApprox(0.0, total);
}
201
test "seq2col basic" {
    // 3 tokens of width 2: [[1,2], [3,4], [5,6]]
    const input = [_]f32{ 1, 2, 3, 4, 5, 6 };
    var out: [18]f32 = undefined;
    seq2col(&out, &input, 3, 2);

    // each row is [left | center | right], zero-padded at the edges
    const expected = [18]f32{
        0, 0, 1, 2, 3, 4, // token 0: no left neighbor
        1, 2, 3, 4, 5, 6, // token 1: full window
        3, 4, 5, 6, 0, 0, // token 2: no right neighbor
    };
    for (expected, out) |e, got| {
        try expectApprox(e, got);
    }
}
235
test "vadd basic" {
    // [1,2,3] + [4,5,6] = [5,7,9]
    const a = [_]f32{ 1, 2, 3 };
    const b = [_]f32{ 4, 5, 6 };
    var out: [3]f32 = undefined;
    vadd(&out, &a, &b, 3);
    const expected = [_]f32{ 5, 7, 9 };
    for (expected, out) |e, got| {
        try expectApprox(e, got);
    }
}