A collection of generative art scripts.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 197 lines 5.4 kB view raw
1 2# Based on a paper by Jascha Sohl-Dickstein 3# https://arxiv.org/abs/2402.06184 4# The boundary of neural network trainability is fractal 5 6# also a blog post: https://sohl-dickstein.github.io/2024/02/12/fractal.html 7 8import math 9 10import torch 11 12from torch.func import vmap, grad 13 14from pyt.lib.spaces import map_space, grid 15from pyt.lib.util import msave, save 16 17torch.manual_seed(69420) 18 19dev = torch.device("cuda:0") 20 21t_real = torch.float64 22 23 24torch.set_float32_matmul_precision('high') 25 26 27# alphas: mean field neural network parametrization; reference [9] from Sohl-Dickstein paper: 28# Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A 29# mean field view of the landscape of two-layer neural networks. 30# Proceedings of the National Academy of Sciences, 115(33): 31# E7665–E7671, 2018. 32 33def persistent(): 34 nonlin = "tanh" 35 network_n = 16 # original paper: 16 36 training_steps = 100 37 38 alpha_1 = 1 / network_n 39 40 match nonlin: 41 case "tanh": # TODO others 42 sigma = torch.tanh 43 alpha_0 = math.sqrt(2/network_n) 44 case _: 45 alpha_0 = math.sqrt(1/network_n) 46 47 def y_pred(x, W_0, W_1): 48 return alpha_1 * W_1 @ sigma(alpha_0 * W_0 @ x) 49 50 def calculate_loss(D, W_0, W_1): 51 x, y = D 52 loss = ((y - y_pred(x, W_0, W_1))**2).mean() 53 return (loss, loss) 54 55 run_network = grad(calculate_loss, argnums=(1,2), has_aux=True) 56 57 58 def train_network(eta_0, eta_1, D, W_0_init, W_1_init, training_steps): 59 W_0 = W_0_init 60 W_1 = W_1_init 61 62 loss_init = calculate_loss(D, W_0, W_1)[0].clamp(min=1e-8) 63 64 loss_record = torch.zeros((), device=W_0.device) 65 66 for index in range(training_steps): 67 ((grad_W_0, grad_W_1), loss) = run_network(D, W_0, W_1) 68 W_0 = W_0 - grad_W_0 * eta_0 69 W_1 = W_1 - grad_W_1 * eta_1 70 if index >= training_steps - 20: 71 loss_record = loss_record + loss / loss_init 72 73 return (loss_record / 20) 74 75 train_many = vmap(train_network, in_dims=(0, 0, None, 0, 0, None)) 76 77 
train_many = torch.compile(train_many) 78 79#span = 250, 250 80#origin = (span[0] // 2) - 10, (span[1] // 2) - 10 81 82#span = 1e7, 1e7 83#origin = 0, 0 84 85span = 5.0 / 3, 5.0 / 3 86origin = 2.0, 2.0 87 88stretch = 1, 1 89 90zooms = [] 91 92scale = 4096 93 94save_partials = False 95 96# TODO save losses directly as safetensors 97 98def render_fractal(seed): 99 torch.manual_seed(seed) 100 101 # always generate random data at fp64, then convert, so that rng is at least somewhat consistent 102 # across dtypes 103 104 dataset_x = torch.randn([network_n, dataset_size], dtype=torch.float64, device=dev).to(t_real) 105 dataset_y = torch.randn((dataset_size,), dtype=torch.float64, device=dev).to(t_real) 106 107 D = (dataset_x, dataset_y) 108 109 _W_0 = torch.randn([network_n, network_n], dtype=torch.float64, device=dev).to(t_real) 110 _W_1 = torch.randn([1, network_n], dtype=torch.float64, device=dev).to(t_real) 111 112 mapping = map_space(origin, span, zooms, stretch, scale) 113 (_, (height,width)) = mapping 114 115 canvas = torch.zeros([height, width], dtype=t_real, device=dev) 116 117 # eta = learning rate 118 etas = grid(mapping).to(dev) 119 etas = torch.pow(10.0, etas) 120 121 eta_0 = etas[:,:,1] 122 eta_1 = etas[:,:,0].flip(0) 123 124 cols_per_chunk = 16 125 126 convergence_threshold = 1.0 127 128 last_report = 0 129 130 for col_start in range(0, width, cols_per_chunk): 131 col_end = col_start + cols_per_chunk 132 e0 = eta_0[:, col_start:col_end].reshape(-1) 133 e1 = eta_1[:, col_start:col_end].reshape(-1) 134 135 _W_0_batch = _W_0.unsqueeze(0).expand(height * cols_per_chunk, -1, -1).contiguous() 136 _W_1_batch = _W_1.unsqueeze(0).expand(height * cols_per_chunk, -1, -1).contiguous() 137 138 res = train_many(e0, e1, D, _W_0_batch, _W_1_batch, training_steps) 139 res = res.nan_to_num(nan=1e6, posinf=1e6, neginf=-1e6) 140 141 canvas[:, col_start:col_end] = res.reshape(height, cols_per_chunk) 142 143 if save_partials and col_end > last_report + 64: 144 last_report = 
col_end 145 c = canvas[:, 0:col_end].clone() 146 147 conv = c < convergence_threshold 148 149 t_conv = 1 - c / convergence_threshold 150 t_div = torch.log1p(c - convergence_threshold) / torch.log1p(torch.tensor(1e6)) 151 152 t_conv = (t_conv * conv) 153 t_div = (t_div * ~conv) 154 155 t_conv /= t_conv.max().clamp(min=1e-6) 156 t_div /= t_div.max().clamp(min=1e-6) 157 158 msave(t_conv, f"{run_dir}/{training_steps:05d}_conv_{col_start}") 159 msave(t_div, f"{run_dir}/{training_steps:05d}_div_{col_start}") 160 161 162 c = canvas.clone() 163 164 ''' 165 c = torch.log1p(c) / torch.log1p(torch.tensor(1e6)) 166 c /= c.max().clamp(min=1e-6) 167 msave(c, f"{run_dir}/{training_steps:05d}") 168 ''' 169 170 conv = c < convergence_threshold 171 172 t_conv = 1 - c / convergence_threshold 173 t_div = torch.log1p(c - convergence_threshold) / torch.log1p(torch.tensor(1e6)) 174 175 t_conv = (t_conv * conv) 176 t_div = (t_div * ~conv) 177 178 t_conv /= t_conv.max().clamp(min=1e-6) 179 t_div /= t_div.max().clamp(min=1e-6) 180 181 #msave(t_conv, f"{run_dir}/{training_steps:05d}_conv_final") 182 #msave(t_div, f"{run_dir}/{training_steps:05d}_div_final") 183 save(torch.stack((t_div, t_conv*0.8, t_conv)).to(dtype=torch.float), f"{run_dir}/{seed:06d}") 184 185 186def main(): 187 188 dataset_size = network_n * (network_n + 1) 189 190 schedule(render_fractal, range(1)) 191 192 193 194def final(): 195 # todo colorize 196 pass 197