🐍🐍🐍
at dev 76 lines 2.0 kB view raw
import math

import numpy as np
import torch


def id_(x):
    """Return ``x`` unchanged (identity function, handy as a default curve)."""
    return x


def lerp(a, b, t):
    """Linearly interpolate between ``a`` (at t=0) and ``b`` (at t=1)."""
    return (1 - t) * a + t * b


def _exp(z):
    """Exponential that accepts torch tensors, numpy arrays and plain numbers."""
    if isinstance(z, torch.Tensor):
        return z.exp()
    # Plain numbers and numpy arrays have no ``.exp()`` method; np.exp handles both.
    # Overflow yields inf, which drives the sigmoid below to its correct limit of 0,
    # so the overflow warning is silenced rather than raised.
    with np.errstate(over="ignore"):
        return np.exp(z)


def shifted_sigmoid(shift_x, shift_y, rate_h, range_v):
    """Build a sigmoid curve centred at (``shift_x``, ``shift_y``).

    The returned function computes
    ``shift_y + range_v * (1 / (1 + exp((shift_x - x) / rate_h)) - 0.5)``,
    i.e. it rises from ``shift_y - range_v / 2`` to ``shift_y + range_v / 2``
    (for positive ``rate_h``), with ``rate_h`` controlling horizontal stretch.

    The returned function accepts a torch tensor (evaluated elementwise, returns
    a tensor) or a plain number / numpy array.
    """
    def f_sigmoid(x):
        e = _exp((shift_x - x) / rate_h)
        sig = 1 / (1 + e)
        return shift_y + range_v * (sig - 0.5)
    return f_sigmoid


def sigmoid(rate_h, range_v):
    """Sigmoid centred at the origin; see ``shifted_sigmoid``."""
    return shifted_sigmoid(0, 0, rate_h, range_v)


def scale_f(f, scale_x, scale_y):
    """Return ``x -> scale_y * f(x / scale_x)`` (stretch ``f`` along both axes)."""
    def _f(x):
        return scale_y * f(x / scale_x)
    return _f


def _distortion_mask(distort, count, like):
    """Return the 1-D tensor ``[distort(0), ..., distort(count - 1)]`` on ``like``'s dtype/device."""
    return torch.tensor(
        [float(distort(i)) for i in range(count)],
        dtype=like.dtype,
        device=like.device,
    )


def svd_distort(tensor, distort):
    """Rescale the singular values of a matrix, or of every matrix in a batch.

    distort: index of singular value [0, N) => strength of singular value
    (a number, usually in [0.0, 1.0]). The i-th singular value (descending
    order, as returned by ``torch.linalg.svd``) of each matrix is multiplied
    by ``distort(i)``.

    ``tensor`` may have any shape ``(..., m, n)``; square or not. A new tensor
    of the same shape is returned and the input is not modified.
    """
    # Reduced SVD: U (..., m, k), S (..., k), Vh (..., k, n) with k = min(m, n),
    # so the reconstruction below is (..., m, n) even for non-square matrices.
    U, S, Vh = torch.linalg.svd(tensor, full_matrices=False)
    # The mask depends only on the singular-value index, so one 1-D mask
    # broadcasts over all leading batch dimensions.
    mask = _distortion_mask(distort, S.shape[-1], S)
    # S is always real; cast so the matmul also works for complex inputs.
    return U @ torch.diag_embed((S * mask).to(U.dtype)) @ Vh


def scale_embeddings(tensor, scaling):
    """Multiply each slice ``tensor[i]`` along dim 0 by ``scaling(i)``, in place.

    Returns None; ``tensor`` itself is modified.
    """
    for i in range(tensor.shape[0]):
        tensor[i] *= scaling(i)


def svd_distort_embeddings(tensor, distort):
    """Rescale the singular values of a 2-D ``tensor`` by ``distort(i)``.

    Equivalent to ``svd_distort`` (which it delegates to). Works for both
    wide and tall matrices. Returns a new tensor; the input is not modified.
    """
    return svd_distort(tensor, distort)


def index_interpolate(source, index):
    """Read ``source`` at a fractional ``index`` by linear interpolation.

    Integral indices return ``source[index]`` directly. Otherwise the two
    neighbouring entries ``source[floor]`` and ``source[floor + 1]`` are
    blended with ``lerp``. An index beyond the second-to-last element with a
    non-zero fraction raises IndexError.
    NOTE(review): negative indices use math.modf (truncation toward zero), so
    they extrapolate rather than interpolate — confirm whether they are needed.
    """
    frac, whole = math.modf(index)
    if frac == 0:
        return source[int(whole)]
    return lerp(source[int(whole)], source[int(whole) + 1], frac)