this repo has no description
at main 244 lines 7.5 kB view raw
//! neural network primitives for NER inference.
//!
//! pure functions over float slices — no allocations, no state.
//! follows the karpathy/llama2.c style: explicit dimensions,
//! pre-allocated buffers, zero abstraction over the math.

const std = @import("std");

/// preferred SIMD lane count for f32 on the build target; falls back
/// to 8 lanes when the target reports no preference.
const VEC_LEN = std.simd.suggestVectorLength(f32) orelse 8;

/// matrix-vector multiply: out = W @ x
/// W is (d, n) row-major, x is (n,), out is (d,).
pub fn matvec(out: []f32, x: []const f32, W: []const f32, n: usize, d: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= d);
    std.debug.assert(W.len >= d * n);

    const n_vec = n / VEC_LEN;
    const n_rem = n % VEC_LEN;

    for (0..d) |i| {
        const row = W[i * n ..][0..n];
        var vsum: @Vector(VEC_LEN, f32) = @splat(0.0);

        // SIMD body: fused multiply-add over the full lanes.
        for (0..n_vec) |v| {
            const vx: @Vector(VEC_LEN, f32) = x[v * VEC_LEN ..][0..VEC_LEN].*;
            const vw: @Vector(VEC_LEN, f32) = row[v * VEC_LEN ..][0..VEC_LEN].*;
            vsum = @mulAdd(@Vector(VEC_LEN, f32), vx, vw, vsum);
        }
        var val = @reduce(.Add, vsum);

        // scalar tail: the n % VEC_LEN leftover elements.
        const tail = n_vec * VEC_LEN;
        for (0..n_rem) |j| {
            val += row[tail + j] * x[tail + j];
        }
        out[i] = val;
    }
}

/// matrix-vector multiply with bias: out = W @ x + b
/// W is (d, n) row-major, x is (n,), b and out are (d,).
pub fn matvec_bias(out: []f32, x: []const f32, W: []const f32, b: []const f32, n: usize, d: usize) void {
    // fix: validate the bias slice, matching the length checks every
    // other primitive in this file performs on its inputs.
    std.debug.assert(b.len >= d);
    matvec(out, x, W, n, d);
    for (0..d) |i| {
        out[i] += b[i];
    }
}
/// maxout: for each output unit, take the max of nP pieces.
/// input is (nO * nP,), output is (nO,).
pub fn maxout(out: []f32, input: []const f32, nO: usize, nP: usize) void {
    std.debug.assert(input.len >= nO * nP);
    std.debug.assert(out.len >= nO);

    for (0..nO) |i| {
        // seed with piece 0, then scan the remaining nP-1 pieces.
        var best: f32 = input[i * nP];
        for (1..nP) |p| {
            const val = input[i * nP + p];
            if (val > best) best = val;
        }
        out[i] = best;
    }
}

/// layer normalization: out = G * (x - mean) / sqrt(var + eps) + b
/// operates per-row: x is (batch, n), G and b are (n,).
/// for single-row (typical in inference): batch=1, just pass a (n,) slice.
/// uses the two-pass algorithm (matching numpy's .var()) for float32 stability.
pub fn layernorm(out: []f32, x: []const f32, G: []const f32, b: []const f32, n: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= n);
    std.debug.assert(G.len >= n);
    std.debug.assert(b.len >= n);

    // pass 1: compute mean
    var sum: f32 = 0.0;
    for (0..n) |i| {
        sum += x[i];
    }
    const nf: f32 = @floatFromInt(n);
    const mean = sum / nf;

    // pass 2: compute variance as mean of squared deviations (matches numpy)
    var var_sum: f32 = 0.0;
    for (0..n) |i| {
        const d = x[i] - mean;
        var_sum += d * d;
    }
    const variance = var_sum / nf;
    // 1e-8 guards against division by zero on constant inputs.
    const rstd = 1.0 / @sqrt(variance + 1e-8);

    for (0..n) |i| {
        out[i] = G[i] * (x[i] - mean) * rstd + b[i];
    }
}

/// element-wise vector addition: out[i] = a[i] + b[i]
pub fn vadd(out: []f32, a: []const f32, b: []const f32, n: usize) void {
    // fix: validate slice lengths, matching the checks in matvec/maxout/layernorm
    // (previously this was the only primitive with no bounds assertions).
    std.debug.assert(a.len >= n);
    std.debug.assert(b.len >= n);
    std.debug.assert(out.len >= n);

    const n_vec = n / VEC_LEN;
    const n_rem = n % VEC_LEN;

    for (0..n_vec) |v| {
        const va: @Vector(VEC_LEN, f32) = a[v * VEC_LEN ..][0..VEC_LEN].*;
        const vb: @Vector(VEC_LEN, f32) = b[v * VEC_LEN ..][0..VEC_LEN].*;
        // store through the comptime-length sub-array directly; the
        // @ptrCast the old code used was unnecessary.
        out[v * VEC_LEN ..][0..VEC_LEN].* = va + vb;
    }
    // scalar tail for the remaining n % VEC_LEN elements.
    const tail = n_vec * VEC_LEN;
    for (0..n_rem) |j| {
        out[tail + j] = a[tail + j] + b[tail + j];
    }
}
/// expand_window(size=1): for each token, concatenate [left, center, right].
/// input is (seq_len, width), output is (seq_len, width * 3).
/// pads with zeros at boundaries.
pub fn seq2col(out: []f32, input: []const f32, seq_len: usize, width: usize) void {
    std.debug.assert(input.len >= seq_len * width);
    std.debug.assert(out.len >= seq_len * width * 3);

    const out_width = width * 3;
    for (0..seq_len) |t| {
        const row = out[t * out_width ..][0..out_width];
        // start from an all-zero row so boundary slots simply stay zero.
        @memset(row, 0.0);

        // copy each in-bounds neighbor (offsets -1, 0, +1) into its slot.
        for (0..3) |slot| {
            if (slot == 0 and t == 0) continue; // no left neighbor
            if (slot == 2 and t + 1 == seq_len) continue; // no right neighbor
            const src = t + slot - 1;
            @memcpy(row[slot * width ..][0..width], input[src * width ..][0..width]);
        }
    }
}
// === tests ===

const testing = std.testing;
const eps = 1e-4;

/// absolute-tolerance float comparison against the shared eps.
fn expectApprox(expected: f32, actual: f32) !void {
    try testing.expectApproxEqAbs(expected, actual, eps);
}

test "matvec identity-like" {
    // identity(2) @ [3, 7] == [3, 7]
    const W = [_]f32{ 1, 0, 0, 1 };
    const x = [_]f32{ 3, 7 };
    var got: [2]f32 = undefined;
    matvec(&got, &x, &W, 2, 2);
    for ([_]f32{ 3, 7 }, got) |exp_v, act_v| try expectApprox(exp_v, act_v);
}

test "matvec general" {
    // [[1, 2], [3, 4]] @ [5, 6] == [17, 39]
    const W = [_]f32{ 1, 2, 3, 4 };
    const x = [_]f32{ 5, 6 };
    var got: [2]f32 = undefined;
    matvec(&got, &x, &W, 2, 2);
    for ([_]f32{ 17, 39 }, got) |exp_v, act_v| try expectApprox(exp_v, act_v);
}

test "maxout basic" {
    // nO=2, nP=3: max over each group of three pieces.
    const input = [_]f32{ 1, 5, 3, 2, 8, 4 };
    var got: [2]f32 = undefined;
    maxout(&got, &input, 2, 3);
    for ([_]f32{ 5, 8 }, got) |exp_v, act_v| try expectApprox(exp_v, act_v);
}

test "layernorm basic" {
    // normalize [1, 2, 3, 4] with G=1, b=0: mean 2.5, var 1.25.
    const x = [_]f32{ 1, 2, 3, 4 };
    const G = [_]f32{ 1, 1, 1, 1 };
    const b = [_]f32{ 0, 0, 0, 0 };
    var got: [4]f32 = undefined;
    layernorm(&got, &x, &G, &b, 4);

    // signs: first half below the mean, second half above.
    try testing.expect(got[0] < 0 and got[1] < 0);
    try testing.expect(got[2] > 0 and got[3] > 0);
    // a normalized row sums to ~0.
    try expectApprox(0.0, got[0] + got[1] + got[2] + got[3]);
}

test "seq2col basic" {
    // 3 tokens, width 2: each output row is [left, center, right], zero-padded.
    const input = [_]f32{ 1, 2, 3, 4, 5, 6 };
    var got: [18]f32 = undefined;
    seq2col(&got, &input, 3, 2);

    const want = [_]f32{
        0, 0, 1, 2, 3, 4, // token 0: no left neighbor
        1, 2, 3, 4, 5, 6, // token 1: fully interior
        3, 4, 5, 6, 0, 0, // token 2: no right neighbor
    };
    for (want, got) |exp_v, act_v| try expectApprox(exp_v, act_v);
}

test "vadd basic" {
    const a = [_]f32{ 1, 2, 3 };
    const b = [_]f32{ 4, 5, 6 };
    var got: [3]f32 = undefined;
    vadd(&got, &a, &b, 3);
    for ([_]f32{ 5, 7, 9 }, got) |exp_v, act_v| try expectApprox(exp_v, act_v);
}