this repo has no description
at main 238 lines 8.2 kB view raw
//! hash embedding layer — the spaCy/Thinc MultiHashEmbed.
//!
//! each token attribute (NORM, PREFIX, SUFFIX, SHAPE) is an opaque uint64.
//! for each attribute, MurmurHash3 produces 4 bucket indices into an
//! embedding table. the 4 looked-up rows are summed to produce the
//! token's embedding for that feature. the 4 feature embeddings are
//! then concatenated.
//!
//! this is the "hash trick" — no vocabulary needed, just a fixed-size
//! table and a hash function. collisions are handled implicitly by
//! the model learning robust representations despite aliasing.

const std = @import("std");
const hash = @import("hash.zig");
const ops = @import("ops.zig");

/// a single hash embedding table.
/// maps uint64 attribute IDs → nO-dimensional vectors via 4-bucket hashing.
pub const HashEmbed = struct {
    /// embedding table, shape (nV, nO), row-major
    E: []const f32,
    /// number of rows in the table
    nV: usize,
    /// output dimensionality per feature
    nO: usize,
    /// hash seed (different per feature: NORM=8, PREFIX=9, SUFFIX=10, SHAPE=11)
    seed: u32,

    /// look up a single attribute ID, writing the nO-dim result to `out`.
    /// out must have len >= nO.
    pub fn lookup(self: HashEmbed, id: u64, out: []f32) void {
        std.debug.assert(out.len >= self.nO);

        const buckets = hash.murmurhash3_128_uint64(id, self.seed);

        // zero the output, then accumulate 4 rows
        @memset(out[0..self.nO], 0.0);

        inline for (0..4) |k| {
            const row_idx = buckets[k] % @as(u32, @intCast(self.nV));
            const row = self.E[row_idx * self.nO ..][0..self.nO];
            for (0..self.nO) |j| {
                out[j] += row[j];
            }
        }
    }

    /// look up a batch of attribute IDs, writing results to `out`.
    /// out is (batch, nO) row-major; must have len >= ids.len * nO.
    pub fn lookupBatch(self: HashEmbed, ids: []const u64, out: []f32) void {
        // guard the whole destination up front so a short buffer fails
        // immediately instead of partway through the batch.
        std.debug.assert(out.len >= ids.len * self.nO);
        for (ids, 0..) |id, i| {
            self.lookup(id, out[i * self.nO ..][0..self.nO]);
        }
    }
};

/// the full MultiHashEmbed: 4 parallel HashEmbed tables (NORM, PREFIX, SUFFIX, SHAPE),
/// concatenated into a 4*nO dimensional vector, then projected through
/// a Maxout(nO, 4*nO, nP=3) + LayerNorm to produce the final embedding.
pub const MultiHashEmbed = struct {
    embeds: [4]HashEmbed,
    /// maxout weight, shape (nO, nP, 4*nO) = (nO * nP * 4 * nO) floats
    maxout_W: []const f32,
    /// maxout bias, shape (nO, nP) = (nO * nP) floats
    maxout_b: []const f32,
    /// layernorm gain, shape (nO,)
    ln_G: []const f32,
    /// layernorm bias, shape (nO,)
    ln_b: []const f32,
    /// output width (96 for en_core_web_sm)
    nO: usize,
    /// number of maxout pieces (3 for en_core_web_sm)
    nP: usize,

    /// embed a single token's 4 attribute IDs → nO-dimensional vector.
    /// attrs: [NORM, PREFIX, SUFFIX, SHAPE] as uint64s.
    /// scratch must have len >= 4 * nO + nO * nP (concatenated embeddings + pre-maxout buffer).
    /// out must have len >= nO.
    pub fn forward(self: MultiHashEmbed, attrs: [4]u64, out: []f32, scratch: []f32) void {
        const concat_dim = 4 * self.nO;
        const pre_maxout_dim = self.nO * self.nP;
        std.debug.assert(scratch.len >= concat_dim + pre_maxout_dim);
        std.debug.assert(out.len >= self.nO);

        const concat = scratch[0..concat_dim];
        const pre_maxout = scratch[concat_dim..][0..pre_maxout_dim];

        // 4 parallel hash embeddings → concatenate
        for (0..4) |f| {
            self.embeds[f].lookup(attrs[f], concat[f * self.nO ..][0..self.nO]);
        }

        // maxout: W @ concat + b, then take max of nP pieces
        ops.matvec_bias(pre_maxout, concat, self.maxout_W, self.maxout_b, concat_dim, pre_maxout_dim);
        ops.maxout(out, pre_maxout, self.nO, self.nP);

        // layernorm in-place
        ops.layernorm(out, out, self.ln_G, self.ln_b, self.nO);
    }
};

/// token attributes — the 4 features spaCy extracts per token.
pub const TokenAttrs = struct {
    norm: u64, // lowercase form hash
    prefix: u64, // first char hash
    suffix: u64, // last 3 chars hash
    shape: u64, // character class pattern hash

    pub fn asArray(self: TokenAttrs) [4]u64 {
        return .{ self.norm, self.prefix, self.suffix, self.shape };
    }
};

/// compute the spaCy "shape" string for a token.
/// rules: uppercase → 'X', lowercase → 'x', digit → 'd', other → literal.
/// consecutive same-class chars collapse after 4 (e.g. "abcdefg" → "xxxx").
pub fn computeShape(token: []const u8, buf: []u8) []const u8 {
    var len: usize = 0;
    var last_class: u8 = 0;
    var class_run: u8 = 0;

    for (token) |c| {
        const class: u8 = if (c >= 'A' and c <= 'Z')
            'X'
        else if (c >= 'a' and c <= 'z')
            'x'
        else if (c >= '0' and c <= '9')
            'd'
        else
            c;

        if (class == last_class) {
            // saturating add: a run of 255+ same-class chars must not
            // overflow the u8 counter (panic in safe modes); anything
            // past 4 is skipped regardless, so saturation is harmless.
            class_run +|= 1;
            if (class_run > 4) continue; // collapse: emit at most 4 of the same class
        } else {
            last_class = class;
            class_run = 1;
        }

        if (len >= buf.len) break;
        buf[len] = class;
        len += 1;
    }

    return buf[0..len];
}

/// extract all 4 attributes for a token.
/// NOTE(review): prefix/suffix operate on bytes, not codepoints — for
/// multi-byte UTF-8 tokens this hashes partial characters; presumably
/// acceptable since the hash only needs to be deterministic. confirm
/// against the spaCy attribute extraction this mirrors.
pub fn extractAttrs(token: []const u8) TokenAttrs {
    // NORM: lowercase (ASCII-only fold)
    var norm_buf: [512]u8 = undefined;
    var norm_len: usize = 0;
    for (token) |c| {
        if (norm_len >= norm_buf.len) break;
        norm_buf[norm_len] = if (c >= 'A' and c <= 'Z') c + 32 else c;
        norm_len += 1;
    }

    // PREFIX: first char
    var prefix_buf: [1]u8 = undefined;
    const prefix_len: usize = if (token.len > 0) 1 else 0;
    if (prefix_len > 0) prefix_buf[0] = token[0];

    // SUFFIX: last 3 chars (whole token when shorter)
    var suffix_buf: [3]u8 = undefined;
    const suffix_start = if (token.len >= 3) token.len - 3 else 0;
    const suffix = token[suffix_start..];
    @memcpy(suffix_buf[0..suffix.len], suffix);

    // SHAPE
    var shape_buf: [128]u8 = undefined;
    const shape = computeShape(token, &shape_buf);

    return .{
        .norm = hash.hashString(norm_buf[0..norm_len]),
        .prefix = hash.hashString(prefix_buf[0..prefix_len]),
        .suffix = hash.hashString(suffix_buf[0..suffix.len]),
        .shape = hash.hashString(shape),
    };
}

// === tests ===

const testing = std.testing;

// cross-validated against spaCy token.shape_
test "computeShape matches spacy" {
    var buf: [64]u8 = undefined;
    try testing.expectEqualStrings("XxxxxX", computeShape("SpaceX", &buf));
    try testing.expectEqualStrings("xxxx", computeShape("hello", &buf));
    try testing.expectEqualStrings("XXXddd", computeShape("ABC123", &buf));
    try testing.expectEqualStrings("Xx", computeShape("Hi", &buf));
    try testing.expectEqualStrings("XXXX", computeShape("HELLO", &buf));
    try testing.expectEqualStrings("xxxx", computeShape("abcdefg", &buf));
}

// regression: a same-class run of 255+ chars used to overflow the u8
// run counter and panic in safe modes.
test "computeShape long run does not overflow" {
    var buf: [8]u8 = undefined;
    const long = "a" ** 300;
    try testing.expectEqualStrings("xxxx", computeShape(long, &buf));
}

test "extractAttrs deterministic" {
    const a1 = extractAttrs("Obama");
    const a2 = extractAttrs("Obama");
    try testing.expectEqual(a1.norm, a2.norm);
    try testing.expectEqual(a1.prefix, a2.prefix);
    try testing.expectEqual(a1.suffix, a2.suffix);
    try testing.expectEqual(a1.shape, a2.shape);
}

test "extractAttrs different tokens" {
    const a1 = extractAttrs("Obama");
    const a2 = extractAttrs("Trump");
    // norms should differ (different lowercase strings)
    try testing.expect(a1.norm != a2.norm);
}

test "HashEmbed lookup" {
    // tiny 4-row, 2-dim embedding table
    const table = [_]f32{
        1, 0, // row 0
        0, 1, // row 1
        2, 3, // row 2
        4, 5, // row 3
    };
    const embed = HashEmbed{
        .E = &table,
        .nV = 4,
        .nO = 2,
        .seed = 8,
    };

    var out: [2]f32 = undefined;
    embed.lookup(42, &out);
    // should produce some non-zero result (sum of 4 rows)
    try testing.expect(out[0] != 0 or out[1] != 0);

    // deterministic
    var out2: [2]f32 = undefined;
    embed.lookup(42, &out2);
    try testing.expectEqual(out[0], out2[0]);
    try testing.expectEqual(out[1], out2[1]);
}