//! hash embedding layer — the spaCy/Thinc MultiHashEmbed.
//!
//! each token attribute (NORM, PREFIX, SUFFIX, SHAPE) is an opaque uint64.
//! for each attribute, MurmurHash3 produces 4 bucket indices into an
//! embedding table. the 4 looked-up rows are summed to produce the
//! token's embedding for that feature. the 4 feature embeddings are
//! then concatenated.
//!
//! this is the "hash trick" — no vocabulary needed, just a fixed-size
//! table and a hash function. collisions are handled implicitly by
//! the model learning robust representations despite aliasing.

const std = @import("std");
const hash = @import("hash.zig");
const ops = @import("ops.zig");

/// a single hash embedding table.
/// maps uint64 attribute IDs → nO-dimensional vectors via 4-bucket hashing.
pub const HashEmbed = struct {
    /// embedding table, shape (nV, nO), row-major
    E: []const f32,
    /// number of rows in the table (must be > 0 and fit in u32)
    nV: usize,
    /// output dimensionality per feature
    nO: usize,
    /// hash seed (different per feature: NORM=8, PREFIX=9, SUFFIX=10, SHAPE=11)
    seed: u32,

    /// look up a single attribute ID, writing the nO-dim result to `out`.
    /// out must have len >= nO.
    pub fn lookup(self: HashEmbed, id: u64, out: []f32) void {
        std.debug.assert(out.len >= self.nO);
        // guard the bucket reduction below: nV == 0 would be a modulo by
        // zero, and nV must fit in u32 or the @intCast traps.
        std.debug.assert(self.nV > 0 and self.nV <= std.math.maxInt(u32));
        const buckets = hash.murmurhash3_128_uint64(id, self.seed);
        // zero the output, then accumulate 4 rows
        @memset(out[0..self.nO], 0.0);
        inline for (0..4) |k| {
            const row_idx = buckets[k] % @as(u32, @intCast(self.nV));
            const row = self.E[row_idx * self.nO ..][0..self.nO];
            for (0..self.nO) |j| {
                out[j] += row[j];
            }
        }
    }

    /// look up a batch of attribute IDs, writing results to `out`.
    /// out is (batch, nO) row-major and must have len >= ids.len * nO.
    pub fn lookupBatch(self: HashEmbed, ids: []const u64, out: []f32) void {
        // fail fast on an undersized output buffer instead of panicking
        // mid-batch on an out-of-bounds slice inside lookup().
        std.debug.assert(out.len >= ids.len * self.nO);
        for (0..ids.len) |i| {
            self.lookup(ids[i], out[i * self.nO ..][0..self.nO]);
        }
    }
};

/// the full MultiHashEmbed: 4 parallel HashEmbed tables (NORM, PREFIX, SUFFIX, SHAPE),
/// concatenated into a 4*nO dimensional vector, then projected through
/// a Maxout(nO, 4*nO, nP=3) + LayerNorm to produce the final embedding.
pub const MultiHashEmbed = struct {
    embeds: [4]HashEmbed,
    /// maxout weight, shape (nO, nP, 4*nO) = (nO * nP * 4 * nO) floats
    maxout_W: []const f32,
    /// maxout bias, shape (nO, nP) = (nO * nP) floats
    maxout_b: []const f32,
    /// layernorm gain, shape (nO,)
    ln_G: []const f32,
    /// layernorm bias, shape (nO,)
    ln_b: []const f32,
    /// output width (96 for en_core_web_sm)
    nO: usize,
    /// number of maxout pieces (3 for en_core_web_sm)
    nP: usize,

    /// embed a single token's 4 attribute IDs → nO-dimensional vector.
    /// attrs: [NORM, PREFIX, SUFFIX, SHAPE] as uint64s.
    /// scratch must have len >= 4 * nO + nO * nP (concatenated embeddings + pre-maxout buffer).
    /// out must have len >= nO.
    pub fn forward(self: MultiHashEmbed, attrs: [4]u64, out: []f32, scratch: []f32) void {
        const concat_dim = 4 * self.nO;
        const pre_maxout_dim = self.nO * self.nP;
        std.debug.assert(scratch.len >= concat_dim + pre_maxout_dim);
        std.debug.assert(out.len >= self.nO);
        const concat = scratch[0..concat_dim];
        const pre_maxout = scratch[concat_dim..][0..pre_maxout_dim];

        // 4 parallel hash embeddings → concatenate
        for (0..4) |f| {
            self.embeds[f].lookup(attrs[f], concat[f * self.nO ..][0..self.nO]);
        }

        // maxout: W @ concat + b, then take max of nP pieces
        ops.matvec_bias(pre_maxout, concat, self.maxout_W, self.maxout_b, concat_dim, pre_maxout_dim);
        ops.maxout(out, pre_maxout, self.nO, self.nP);

        // layernorm in-place
        ops.layernorm(out, out, self.ln_G, self.ln_b, self.nO);
    }
};

/// token attributes — the 4 features spaCy extracts per token.
pub const TokenAttrs = struct {
    norm: u64, // lowercase form hash
    prefix: u64, // first char hash
    suffix: u64, // last 3 chars hash
    shape: u64, // character class pattern hash

    /// pack the 4 attributes in MultiHashEmbed feature order.
    pub fn asArray(self: TokenAttrs) [4]u64 {
        return .{ self.norm, self.prefix, self.suffix, self.shape };
    }
};

/// compute the spaCy "shape" string for a token.
/// rules: uppercase → 'X', lowercase → 'x', digit → 'd', other → literal.
/// consecutive same-class chars collapse after 4 (e.g. "abcdefg" → "xxxx").
/// classification is byte-wise ASCII; non-ASCII bytes pass through literally.
/// the shape is truncated if `buf` fills up.
pub fn computeShape(token: []const u8, buf: []u8) []const u8 {
    var len: usize = 0;
    var last_class: u8 = 0;
    var class_run: u8 = 0;
    for (token) |c| {
        const class: u8 = if (c >= 'A' and c <= 'Z')
            'X'
        else if (c >= 'a' and c <= 'z')
            'x'
        else if (c >= '0' and c <= '9')
            'd'
        else
            c;
        if (class == last_class) {
            // saturating add: a run of 256+ same-class chars must not
            // overflow the u8 counter (plain += panics in safe builds,
            // UB in ReleaseFast); past 4 the exact count is irrelevant.
            class_run +|= 1;
            if (class_run > 4) continue; // collapse: emit at most 4 of the same class
        } else {
            last_class = class;
            class_run = 1;
        }
        if (len >= buf.len) break;
        buf[len] = class;
        len += 1;
    }
    return buf[0..len];
}

/// extract all 4 attributes for a token.
pub fn extractAttrs(token: []const u8) TokenAttrs {
    // NORM: ASCII-lowercased copy of the token (truncated to the 512-byte
    // buffer; longer tokens hash only their first 512 bytes).
    var lower: [512]u8 = undefined;
    var lower_len: usize = 0;
    for (token) |byte| {
        if (lower_len >= lower.len) break;
        lower[lower_len] = if (byte >= 'A' and byte <= 'Z') byte + 32 else byte;
        lower_len += 1;
    }

    // PREFIX: first byte as a slice (empty slice for an empty token).
    // NOTE(review): byte-wise, not codepoint-wise — assumes ASCII-ish input.
    const prefix = token[0..@min(token.len, 1)];

    // SUFFIX: last 3 bytes, or the whole token when shorter.
    const suffix = token[token.len - @min(token.len, 3) ..];

    // SHAPE: character-class pattern string.
    var shape_scratch: [128]u8 = undefined;
    const shape = computeShape(token, &shape_scratch);

    return .{
        .norm = hash.hashString(lower[0..lower_len]),
        .prefix = hash.hashString(prefix),
        .suffix = hash.hashString(suffix),
        .shape = hash.hashString(shape),
    };
}

// === tests ===

const testing = std.testing;

// cross-validated against spaCy token.shape_
test "computeShape matches spacy" {
    var buf: [64]u8 = undefined;
    const cases = [_]struct { token: []const u8, shape: []const u8 }{
        .{ .token = "SpaceX", .shape = "XxxxxX" },
        .{ .token = "hello", .shape = "xxxx" },
        .{ .token = "ABC123", .shape = "XXXddd" },
        .{ .token = "Hi", .shape = "Xx" },
        .{ .token = "HELLO", .shape = "XXXX" },
        .{ .token = "abcdefg", .shape = "xxxx" },
    };
    for (cases) |case| {
        try testing.expectEqualStrings(case.shape, computeShape(case.token, &buf));
    }
}

test "extractAttrs deterministic" {
    const first = extractAttrs("Obama");
    const second = extractAttrs("Obama");
    try testing.expectEqual(first.norm, second.norm);
    try testing.expectEqual(first.prefix, second.prefix);
    try testing.expectEqual(first.suffix, second.suffix);
    try testing.expectEqual(first.shape, second.shape);
}

test "extractAttrs different tokens" {
    // norms should differ (different lowercase strings)
    try testing.expect(extractAttrs("Obama").norm != extractAttrs("Trump").norm);
}

test "HashEmbed lookup" {
    // tiny 4-row, 2-dim embedding table
    const table = [_]f32{
        1, 0, // row 0
        0, 1, // row 1
        2, 3, // row 2
        4, 5, // row 3
    };
    const embed = HashEmbed{
        .E = &table,
        .nV = 4,
        .nO = 2,
        .seed = 8,
    };

    var first: [2]f32 = undefined;
    embed.lookup(42, &first);
    // should produce some non-zero result (sum of 4 rows)
    try testing.expect(first[0] != 0 or first[1] != 0);

    // deterministic
    var second: [2]f32 = undefined;
    embed.lookup(42, &second);
    try testing.expectEqualSlices(f32, &first, &second);
}