//! neural network primitives for NER inference.
//!
//! pure functions over float slices — no allocations, no state.
//! follows the karpathy/llama2.c style: explicit dimensions,
//! pre-allocated buffers, zero abstraction over the math.

const std = @import("std");

const VEC_LEN = std.simd.suggestVectorLength(f32) orelse 8;

/// matrix-vector multiply: out = W @ x
/// W is (d, n) row-major, x is (n,), out is (d,).
pub fn matvec(out: []f32, x: []const f32, W: []const f32, n: usize, d: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= d);
    std.debug.assert(W.len >= d * n);

    const full_chunks = n / VEC_LEN;
    const leftover = n % VEC_LEN;
    const split = full_chunks * VEC_LEN;

    var row_idx: usize = 0;
    while (row_idx < d) : (row_idx += 1) {
        const weights = W[row_idx * n ..][0..n];

        // vectorized dot product over the full VEC_LEN-sized chunks
        var acc: @Vector(VEC_LEN, f32) = @splat(0.0);
        var c: usize = 0;
        while (c < full_chunks) : (c += 1) {
            const xs: @Vector(VEC_LEN, f32) = x[c * VEC_LEN ..][0..VEC_LEN].*;
            const ws: @Vector(VEC_LEN, f32) = weights[c * VEC_LEN ..][0..VEC_LEN].*;
            acc = @mulAdd(@Vector(VEC_LEN, f32), xs, ws, acc);
        }
        var dot = @reduce(.Add, acc);

        // scalar tail: elements past the last full vector chunk
        var k: usize = split;
        while (k < split + leftover) : (k += 1) {
            dot += weights[k] * x[k];
        }
        out[row_idx] = dot;
    }
}

/// matrix-vector multiply with bias: out = W @ x + b
pub fn matvec_bias(out: []f32, x: []const f32, W: []const f32, b: []const f32, n: usize, d: usize) void {
    matvec(out, x, W, n, d);
    for (out[0..d], b[0..d]) |*o, bias| {
        o.* += bias;
    }
}

/// maxout: for each output unit, take the max of nP pieces.
/// input is (nO * nP,), output is (nO,).
pub fn maxout(out: []f32, input: []const f32, nO: usize, nP: usize) void {
    std.debug.assert(input.len >= nO * nP);
    std.debug.assert(out.len >= nO);

    for (out[0..nO], 0..) |*dst, unit| {
        const pieces = input[unit * nP ..][0..nP];
        var best = pieces[0];
        // note: `>` comparison (not @max) so NaN handling matches the original
        for (pieces[1..]) |candidate| {
            if (candidate > best) best = candidate;
        }
        dst.* = best;
    }
}

/// layer normalization: out = G * (x - mean) / sqrt(var + eps) + b
/// operates per-row: x is (batch, n), G and b are (n,).
/// for single-row (typical in inference): batch=1, just pass a (n,) slice.
/// uses the two-pass algorithm (matching numpy's .var()) for float32 stability.
pub fn layernorm(out: []f32, x: []const f32, G: []const f32, b: []const f32, n: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= n);
    std.debug.assert(G.len >= n);
    std.debug.assert(b.len >= n);

    // pass 1: compute mean
    var sum: f32 = 0.0;
    for (0..n) |i| {
        sum += x[i];
    }
    const nf: f32 = @floatFromInt(n);
    const mean = sum / nf;

    // pass 2: compute variance as mean of squared deviations (matches numpy)
    var var_sum: f32 = 0.0;
    for (0..n) |i| {
        const d = x[i] - mean;
        var_sum += d * d;
    }
    const variance = var_sum / nf;

    const rstd = 1.0 / @sqrt(variance + 1e-8);
    for (0..n) |i| {
        out[i] = G[i] * (x[i] - mean) * rstd + b[i];
    }
}

/// element-wise vector addition: out[i] = a[i] + b[i]
pub fn vadd(out: []f32, a: []const f32, b: []const f32, n: usize) void {
    // bounds asserts added for consistency with matvec/maxout/layernorm
    std.debug.assert(a.len >= n);
    std.debug.assert(b.len >= n);
    std.debug.assert(out.len >= n);

    const n_vec = n / VEC_LEN;
    const n_rem = n % VEC_LEN;
    for (0..n_vec) |v| {
        const va: @Vector(VEC_LEN, f32) = a[v * VEC_LEN ..][0..VEC_LEN].*;
        const vb: @Vector(VEC_LEN, f32) = b[v * VEC_LEN ..][0..VEC_LEN].*;
        // a comptime-length slice is already a *[VEC_LEN]f32, and a vector
        // coerces to its array counterpart — no @ptrCast needed
        out[v * VEC_LEN ..][0..VEC_LEN].* = va + vb;
    }
    // scalar tail for the remainder that doesn't fill a vector
    const tail = n_vec * VEC_LEN;
    for (0..n_rem) |j| {
        out[tail + j] = a[tail + j] + b[tail + j];
    }
}

/// expand_window(size=1): for each token, concatenate [left, center, right].
/// input is (seq_len, width), output is (seq_len, width * 3).
/// pads with zeros at boundaries.
pub fn seq2col(out: []f32, input: []const f32, seq_len: usize, width: usize) void {
    std.debug.assert(input.len >= seq_len * width);
    std.debug.assert(out.len >= seq_len * width * 3);

    const out_width = width * 3;
    for (0..seq_len) |t| {
        const dst = out[t * out_width ..][0..out_width];
        // left neighbor (zero if t == 0)
        if (t > 0) {
            @memcpy(dst[0..width], input[(t - 1) * width ..][0..width]);
        } else {
            @memset(dst[0..width], 0.0);
        }
        // center
        @memcpy(dst[width..][0..width], input[t * width ..][0..width]);
        // right neighbor (zero if t == seq_len - 1)
        if (t + 1 < seq_len) {
            @memcpy(dst[width * 2 ..][0..width], input[(t + 1) * width ..][0..width]);
        } else {
            @memset(dst[width * 2 ..][0..width], 0.0);
        }
    }
}

// === tests ===

const testing = std.testing;
const eps = 1e-4;

fn expectApprox(expected: f32, actual: f32) !void {
    try testing.expectApproxEqAbs(expected, actual, eps);
}

test "matvec identity-like" {
    // 2x2 identity matrix times [3, 7] = [3, 7]
    const W = [_]f32{ 1, 0, 0, 1 };
    const x = [_]f32{ 3, 7 };
    var out: [2]f32 = undefined;
    matvec(&out, &x, &W, 2, 2);
    try expectApprox(3.0, out[0]);
    try expectApprox(7.0, out[1]);
}

test "matvec general" {
    // [[1, 2], [3, 4]] @ [5, 6] = [17, 39]
    const W = [_]f32{ 1, 2, 3, 4 };
    const x = [_]f32{ 5, 6 };
    var out: [2]f32 = undefined;
    matvec(&out, &x, &W, 2, 2);
    try expectApprox(17.0, out[0]);
    try expectApprox(39.0, out[1]);
}

test "maxout basic" {
    // nO=2, nP=3: input is [1, 5, 3, 2, 8, 4]
    // unit 0: max(1, 5, 3) = 5
    // unit 1: max(2, 8, 4) = 8
    const input = [_]f32{ 1, 5, 3, 2, 8, 4 };
    var out: [2]f32 = undefined;
    maxout(&out, &input, 2, 3);
    try expectApprox(5.0, out[0]);
    try expectApprox(8.0, out[1]);
}

test "layernorm basic" {
    // normalize [1, 2, 3, 4] with G=1, b=0
    const x = [_]f32{ 1, 2, 3, 4 };
    const G = [_]f32{ 1, 1, 1, 1 };
    const b = [_]f32{ 0, 0, 0, 0 };
    var out: [4]f32 = undefined;
    layernorm(&out, &x, &G, &b, 4);
    // mean=2.5, var=1.25, rstd ~= 0.894427:
    // assert the exact expected values, not just the signs
    try expectApprox(-1.3416, out[0]);
    try expectApprox(-0.4472, out[1]);
    try expectApprox(0.4472, out[2]);
    try expectApprox(1.3416, out[3]);
    // normalized output should sum to ~0
    try expectApprox(0.0, out[0] + out[1] + out[2] + out[3]);
}

test "seq2col basic" {
    // 3 tokens, width 2: [[1,2], [3,4], [5,6]]
    // token 0: [0,0, 1,2, 3,4]
    // token 1: [1,2, 3,4, 5,6]
    // token 2: [3,4, 5,6, 0,0]
    const input = [_]f32{ 1, 2, 3, 4, 5, 6 };
    var out: [18]f32 = undefined;
    seq2col(&out, &input, 3, 2);
    // token 0
    try expectApprox(0, out[0]);
    try expectApprox(0, out[1]);
    try expectApprox(1, out[2]);
    try expectApprox(2, out[3]);
    try expectApprox(3, out[4]);
    try expectApprox(4, out[5]);
    // token 1
    try expectApprox(1, out[6]);
    try expectApprox(2, out[7]);
    try expectApprox(3, out[8]);
    try expectApprox(4, out[9]);
    try expectApprox(5, out[10]);
    try expectApprox(6, out[11]);
    // token 2
    try expectApprox(3, out[12]);
    try expectApprox(4, out[13]);
    try expectApprox(5, out[14]);
    try expectApprox(6, out[15]);
    try expectApprox(0, out[16]);
    try expectApprox(0, out[17]);
}

test "vadd basic" {
    const a = [_]f32{ 1, 2, 3 };
    const b = [_]f32{ 4, 5, 6 };
    var out: [3]f32 = undefined;
    vadd(&out, &a, &b, 3);
    try expectApprox(5, out[0]);
    try expectApprox(7, out[1]);
    try expectApprox(9, out[2]);
}