🐍🐍🐍
1import torch
2import numpy as np
3import math
4
def id_(x):
    """Identity function: hand back the argument unchanged (the very same object)."""
    result = x
    return result
7
def lerp(a, b, t):
    """Linearly interpolate between ``a`` and ``b``.

    Computed as ``(1 - t) * a + t * b``: ``t == 0`` yields ``a``, ``t == 1``
    yields ``b``, and values of ``t`` outside ``[0, 1]`` extrapolate. Any
    operands supporting ``*`` and ``+`` with the weights (numbers, tensors,
    arrays) are accepted.
    """
    weight_a = 1 - t
    return weight_a * a + t * b
10
def shifted_sigmoid(shift_x, shift_y, rate_h, range_v):
    """Build a logistic curve centred at ``(shift_x, shift_y)``.

    The returned function computes
    ``shift_y + range_v * (1 / (1 + exp((shift_x - x) / rate_h)) - 0.5)``.
    For positive ``rate_h`` and ``range_v``, this rises from
    ``shift_y - range_v / 2`` to ``shift_y + range_v / 2`` as ``x`` increases.

    Args:
        shift_x: x value of the curve's midpoint.
        shift_y: output value at the midpoint.
        rate_h: horizontal scale; larger magnitudes give a gentler slope.
        range_v: total vertical span of the curve.

    Returns:
        A one-argument function ``f_sigmoid(x)``. ``x`` may be a torch tensor
        (evaluated elementwise, giving a tensor) or a plain scalar such as an
        ``int`` index (giving a number).
    """
    def f_sigmoid(x):
        z = (shift_x - x) / rate_h
        if hasattr(z, "exp"):
            # Tensor path: elementwise exp; overflow to inf yields sig == 0,
            # which is the correct limit.
            sig = 1 / (1 + z.exp())
        elif z >= 0:
            # Scalar path. Plain numbers have no ``.exp()`` method, and
            # math.exp raises OverflowError rather than returning inf, so use
            # the algebraically equal form exp(-z) / (1 + exp(-z)) for z >= 0.
            e = math.exp(-z)
            sig = e / (1 + e)
        else:
            sig = 1 / (1 + math.exp(z))
        return shift_y + range_v * (sig - 0.5)
    return f_sigmoid
17
def sigmoid(rate_h, range_v):
    """Logistic curve centred at the origin.

    Equivalent to ``shifted_sigmoid`` with both shifts set to zero. ``rate_h``
    sets the horizontal scale and ``range_v`` the total vertical span.
    """
    return shifted_sigmoid(shift_x=0, shift_y=0, rate_h=rate_h, range_v=range_v)
20
21
def scale_f(f, scale_x, scale_y):
    """Return ``f`` rescaled on both axes, i.e. ``x -> scale_y * f(x / scale_x)``.

    ``scale_x`` divides the input before ``f`` is applied; ``scale_y``
    multiplies the result.
    """
    def scaled(value):
        return scale_y * f(value / scale_x)

    return scaled
26
def svd_distort(tensor, distort):
    """Rescale the singular values of every matrix in ``tensor``.

    Each matrix is formed by the last two dimensions of ``tensor``; any leading
    dimensions are treated as batch dimensions. Each matrix is decomposed as
    ``U @ diag(S) @ Vh`` and rebuilt with its ``i``-th singular value
    multiplied by ``distort(i)``.

    Args:
        tensor: tensor of shape ``(..., m, n)`` with at least two dimensions.
        distort: callable mapping a singular-value index ``i`` in
            ``[0, min(m, n))`` to a strength (a number, usually in
            ``[0.0, 1.0]``). It is called once per index, and the same factor
            is applied to every batch entry.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    # Reduced SVD: U is (..., m, k), S is (..., k), Vh is (..., k, n) with
    # k = min(m, n). The product below is then shape-consistent for non-square
    # matrices too; the full SVD would give (m, m) and (n, n) factors.
    U, S, Vh = torch.linalg.svd(tensor, full_matrices=False)

    k = S.shape[-1]
    svd_mask = torch.ones(k, dtype=S.dtype, device=S.device)
    for i in range(k):
        svd_mask[i] = distort(i)

    # The 1-D mask broadcasts over any batch dimensions of S.
    return U @ torch.diag_embed(S * svd_mask) @ Vh
43
def scale_embeddings(tensor, scaling):
    """Multiply each entry along the first dimension of ``tensor`` in place.

    For every ``i`` in ``range(tensor.shape[0])``, ``tensor[i]`` is multiplied
    by ``scaling(i)``. The input is modified in place and ``None`` is returned.

    Args:
        tensor: tensor (or array) to modify in place. Presumably an embedding
            table with one row per token/position; TODO confirm against callers.
        scaling: callable mapping a row index to its multiplier.
    """
    for i in range(tensor.shape[0]):
        # Index assignment writes through to ``tensor``.
        tensor[i] *= scaling(i)
48
def svd_distort_embeddings(tensor, distort):
    """Rescale the singular values of ``tensor`` and rebuild it.

    ``tensor`` is decomposed as ``U @ diag(S) @ Vh`` and reconstructed with its
    ``i``-th singular value multiplied by ``distort(i)``. Any dimensions before
    the last two are treated as batch dimensions and share the same factors.

    Args:
        tensor: matrix (or batch of matrices) of shape ``(..., m, n)``; it is
            not modified.
        distort: callable mapping a singular-value index ``i`` in
            ``[0, min(m, n))`` to its strength.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    # Reduced SVD keeps the factors shape-compatible for both wide (m < n) and
    # tall (m > n) inputs: U is (..., m, k), S is (..., k), Vh is (..., k, n).
    U, S, Vh = torch.linalg.svd(tensor, full_matrices=False)

    k = S.shape[-1]
    distortion_mask = torch.ones(k, dtype=S.dtype, device=S.device)
    for i in range(k):
        distortion_mask[i] = distort(i)

    return U @ torch.diag_embed(S * distortion_mask) @ Vh
71
def index_interpolate(source, index):
    """Read ``source`` at a possibly fractional ``index``.

    The index is split by ``math.modf`` into an integer part (truncated toward
    zero) and a fractional part. If the fractional part is zero,
    ``source[integer part]`` is returned directly. Otherwise the entries at the
    integer part and the next position are blended with ``lerp``, using the
    fractional part as the weight.
    """
    fraction, integral = math.modf(index)
    base = int(integral)
    if fraction != 0:
        return lerp(source[base], source[base + 1], fraction)
    return source[base]