// this repo has no description
1//! hash embedding layer — the spaCy/Thinc MultiHashEmbed.
2//!
3//! each token attribute (NORM, PREFIX, SUFFIX, SHAPE) is an opaque uint64.
4//! for each attribute, MurmurHash3 produces 4 bucket indices into an
5//! embedding table. the 4 looked-up rows are summed to produce the
6//! token's embedding for that feature. the 4 feature embeddings are
7//! then concatenated.
8//!
9//! this is the "hash trick" — no vocabulary needed, just a fixed-size
10//! table and a hash function. collisions are handled implicitly by
11//! the model learning robust representations despite aliasing.
12
13const std = @import("std");
14const hash = @import("hash.zig");
15const ops = @import("ops.zig");
16
/// a single hash embedding table.
/// maps uint64 attribute IDs → nO-dimensional vectors via 4-bucket hashing.
pub const HashEmbed = struct {
    /// embedding table, shape (nV, nO), row-major
    E: []const f32,
    /// number of rows in the table
    nV: usize,
    /// output dimensionality per feature
    nO: usize,
    /// hash seed (different per feature: NORM=8, PREFIX=9, SUFFIX=10, SHAPE=11)
    seed: u32,

    /// look up a single attribute ID, writing the nO-dim result to `out`.
    /// out must have len >= nO.
    pub fn lookup(self: HashEmbed, id: u64, out: []f32) void {
        std.debug.assert(out.len >= self.nO);
        // a zero-row table would make every modulo below divide by zero
        std.debug.assert(self.nV > 0);

        const buckets = hash.murmurhash3_128_uint64(id, self.seed);

        // zero the output, then accumulate 4 rows
        @memset(out[0..self.nO], 0.0);

        inline for (0..4) |k| {
            // reduce via peer-resolved usize arithmetic instead of narrowing
            // nV with @intCast(u32): the old cast was illegal (safety panic /
            // UB) for tables with more than maxInt(u32) rows.
            const row_idx = buckets[k] % self.nV;
            const row = self.E[row_idx * self.nO ..][0..self.nO];
            for (0..self.nO) |j| {
                out[j] += row[j];
            }
        }
    }

    /// look up a batch of attribute IDs, writing results to `out`.
    /// out is (batch, nO) row-major and must have len >= ids.len * nO.
    pub fn lookupBatch(self: HashEmbed, ids: []const u64, out: []f32) void {
        // check the whole destination up front so a short buffer fails here,
        // not partway through the batch
        std.debug.assert(out.len >= ids.len * self.nO);
        for (ids, 0..) |id, i| {
            self.lookup(id, out[i * self.nO ..][0..self.nO]);
        }
    }
};
56
/// the full MultiHashEmbed: 4 parallel HashEmbed tables (NORM, PREFIX, SUFFIX, SHAPE),
/// concatenated into a 4*nO dimensional vector, then projected through
/// a Maxout(nO, 4*nO, nP=3) + LayerNorm to produce the final embedding.
pub const MultiHashEmbed = struct {
    embeds: [4]HashEmbed,
    /// maxout weight, shape (nO, nP, 4*nO) = (nO * nP * 4 * nO) floats
    maxout_W: []const f32,
    /// maxout bias, shape (nO, nP) = (nO * nP) floats
    maxout_b: []const f32,
    /// layernorm gain, shape (nO,)
    ln_G: []const f32,
    /// layernorm bias, shape (nO,)
    ln_b: []const f32,
    /// output width (96 for en_core_web_sm)
    nO: usize,
    /// number of maxout pieces (3 for en_core_web_sm)
    nP: usize,

    /// embed a single token's 4 attribute IDs → nO-dimensional vector.
    /// attrs: [NORM, PREFIX, SUFFIX, SHAPE] as uint64s.
    /// scratch must have len >= 4 * nO + nO * nP (concat buffer + pre-maxout buffer).
    /// out must have len >= nO.
    pub fn forward(self: MultiHashEmbed, attrs: [4]u64, out: []f32, scratch: []f32) void {
        const cat_width = 4 * self.nO;
        const piece_width = self.nO * self.nP;
        std.debug.assert(out.len >= self.nO);
        std.debug.assert(scratch.len >= cat_width + piece_width);

        // scratch layout: [ concatenated embeddings | pre-maxout activations ]
        const cat = scratch[0..cat_width];
        const pieces = scratch[cat_width..][0..piece_width];

        // each feature's table writes its nO-dim slot of the concat buffer
        for (self.embeds, 0..) |table, feat| {
            table.lookup(attrs[feat], cat[feat * self.nO ..][0..self.nO]);
        }

        // affine projection into nO*nP pieces, then max over the nP pieces
        ops.matvec_bias(pieces, cat, self.maxout_W, self.maxout_b, cat_width, piece_width);
        ops.maxout(out, pieces, self.nO, self.nP);

        // final layernorm, computed in place over `out`
        ops.layernorm(out, out, self.ln_G, self.ln_b, self.nO);
    }
};
101
/// token attributes — the 4 features spaCy extracts per token.
pub const TokenAttrs = struct {
    norm: u64, // lowercase form hash
    prefix: u64, // first char hash
    suffix: u64, // last 3 chars hash
    shape: u64, // character class pattern hash

    /// flatten into the fixed [NORM, PREFIX, SUFFIX, SHAPE] order
    /// expected by MultiHashEmbed.forward.
    pub fn asArray(self: TokenAttrs) [4]u64 {
        return [4]u64{ self.norm, self.prefix, self.suffix, self.shape };
    }
};
113
/// compute the spaCy "shape" string for a token.
/// rules: uppercase → 'X', lowercase → 'x', digit → 'd', other → literal byte.
/// consecutive same-class chars collapse after 4 (e.g. "abcdefg" → "xxxx").
/// returns a slice of `buf`; output is truncated if `buf` is too small.
pub fn computeShape(token: []const u8, buf: []u8) []const u8 {
    var len: usize = 0;
    var last_class: u8 = 0;
    var class_run: u8 = 0;

    for (token) |c| {
        const class: u8 = if (c >= 'A' and c <= 'Z')
            'X'
        else if (c >= 'a' and c <= 'z')
            'x'
        else if (c >= '0' and c <= '9')
            'd'
        else
            c;

        if (class == last_class) {
            // saturating add: a plain `+= 1` overflows u8 (panic in safe
            // modes, UB in ReleaseFast) on a same-class run of 256+ chars.
            // the exact count past 4 never matters, only that it is > 4.
            class_run +|= 1;
            if (class_run > 4) continue; // collapse: emit at most 4 of the same class
        } else {
            last_class = class;
            class_run = 1;
        }

        if (len >= buf.len) break;
        buf[len] = class;
        len += 1;
    }

    return buf[0..len];
}
147
/// extract all 4 attributes for a token.
/// NOTE(review): every feature here operates on raw bytes, so multi-byte
/// UTF-8 characters are split mid-sequence — presumably fine for ASCII
/// input; confirm against the model's expected hashing for non-ASCII.
pub fn extractAttrs(token: []const u8) TokenAttrs {
    // NORM: ASCII-lowercased copy of the token (truncated at 512 bytes)
    var norm_buf: [512]u8 = undefined;
    var norm_len: usize = 0;
    for (token) |byte| {
        if (norm_len >= norm_buf.len) break;
        norm_buf[norm_len] = if (byte >= 'A' and byte <= 'Z') byte + 32 else byte;
        norm_len += 1;
    }

    // PREFIX: the first byte, or empty for an empty token
    var prefix_buf: [1]u8 = undefined;
    var prefix_len: usize = 0;
    if (token.len > 0) {
        prefix_buf[0] = token[0];
        prefix_len = 1;
    }

    // SUFFIX: up to the last 3 bytes (shorter tokens use the whole token)
    var suffix_buf: [3]u8 = undefined;
    const tail = token[token.len - @min(token.len, 3) ..];
    @memcpy(suffix_buf[0..tail.len], tail);

    // SHAPE: character-class pattern (see computeShape)
    var shape_buf: [128]u8 = undefined;
    const shape = computeShape(token, &shape_buf);

    return .{
        .norm = hash.hashString(norm_buf[0..norm_len]),
        .prefix = hash.hashString(prefix_buf[0..prefix_len]),
        .suffix = hash.hashString(suffix_buf[0..tail.len]),
        .shape = hash.hashString(shape),
    };
}
181
182// === tests ===
183
184const testing = std.testing;
185
// cross-validated against spaCy token.shape_
test "computeShape matches spacy" {
    const cases = [_]struct { token: []const u8, shape: []const u8 }{
        .{ .token = "SpaceX", .shape = "XxxxxX" },
        .{ .token = "hello", .shape = "xxxx" },
        .{ .token = "ABC123", .shape = "XXXddd" },
        .{ .token = "Hi", .shape = "Xx" },
        .{ .token = "HELLO", .shape = "XXXX" },
        .{ .token = "abcdefg", .shape = "xxxx" },
    };

    var scratch: [64]u8 = undefined;
    for (cases) |case| {
        try testing.expectEqualStrings(case.shape, computeShape(case.token, &scratch));
    }
}
196
test "extractAttrs deterministic" {
    // the same token must always hash to the same 4 attribute IDs;
    // TokenAttrs is all-u64, so field-wise expectEqual covers every attribute
    const first = extractAttrs("Obama");
    const second = extractAttrs("Obama");
    try testing.expectEqual(first, second);
}
205
test "extractAttrs different tokens" {
    // distinct lowercase strings should hash to distinct NORM ids
    try testing.expect(extractAttrs("Obama").norm != extractAttrs("Trump").norm);
}
212
test "HashEmbed lookup" {
    // tiny 4-row, 2-dim embedding table
    const rows = [_]f32{
        1, 0, // row 0
        0, 1, // row 1
        2, 3, // row 2
        4, 5, // row 3
    };
    const embed = HashEmbed{
        .E = &rows,
        .nV = 4,
        .nO = 2,
        .seed = 8,
    };

    var first: [2]f32 = undefined;
    embed.lookup(42, &first);
    // sum of 4 table rows — all entries are non-negative and every row has
    // a positive component, so the result cannot be all zeros
    try testing.expect(first[0] != 0 or first[1] != 0);

    // the same ID must always produce the same vector
    var second: [2]f32 = undefined;
    embed.lookup(42, &second);
    try testing.expectEqual(first, second);
}
238}