# a collection of generative art scripts

# Based on a paper by Jascha Sohl-Dickstein
# https://arxiv.org/abs/2402.06184
# "The boundary of neural network trainability is fractal"

# also a blog post: https://sohl-dickstein.github.io/2024/02/12/fractal.html
8import math
9
10import torch
11
12from torch.func import vmap, grad
13
14from pyt.lib.spaces import map_space, grid
15from pyt.lib.util import msave, save
16
# fixed seed so module-level rng state is reproducible across runs
torch.manual_seed(69420)

# all tensors in this script live on the first CUDA device
dev = torch.device("cuda:0")

# working dtype for the random data, weights, and the loss canvas
t_real = torch.float64


# permit TF32 matmuls; only relevant if t_real is ever switched to float32
torch.set_float32_matmul_precision('high')
25
26
# alphas: mean-field neural network parametrization; reference [9] from the Sohl-Dickstein paper:
# Song Mei, Andrea Montanari, and Phan-Minh Nguyen.
# "A mean field view of the landscape of two-layer neural networks."
# Proceedings of the National Academy of Sciences, 115(33):E7665-E7671, 2018.
32
33def persistent():
34 nonlin = "tanh"
35 network_n = 16 # original paper: 16
36 training_steps = 100
37
38 alpha_1 = 1 / network_n
39
40 match nonlin:
41 case "tanh": # TODO others
42 sigma = torch.tanh
43 alpha_0 = math.sqrt(2/network_n)
44 case _:
45 alpha_0 = math.sqrt(1/network_n)
46
47 def y_pred(x, W_0, W_1):
48 return alpha_1 * W_1 @ sigma(alpha_0 * W_0 @ x)
49
50 def calculate_loss(D, W_0, W_1):
51 x, y = D
52 loss = ((y - y_pred(x, W_0, W_1))**2).mean()
53 return (loss, loss)
54
55 run_network = grad(calculate_loss, argnums=(1,2), has_aux=True)
56
57
58 def train_network(eta_0, eta_1, D, W_0_init, W_1_init, training_steps):
59 W_0 = W_0_init
60 W_1 = W_1_init
61
62 loss_init = calculate_loss(D, W_0, W_1)[0].clamp(min=1e-8)
63
64 loss_record = torch.zeros((), device=W_0.device)
65
66 for index in range(training_steps):
67 ((grad_W_0, grad_W_1), loss) = run_network(D, W_0, W_1)
68 W_0 = W_0 - grad_W_0 * eta_0
69 W_1 = W_1 - grad_W_1 * eta_1
70 if index >= training_steps - 20:
71 loss_record = loss_record + loss / loss_init
72
73 return (loss_record / 20)
74
75 train_many = vmap(train_network, in_dims=(0, 0, None, 0, 0, None))
76
77 train_many = torch.compile(train_many)
78
# ---- viewport / render configuration (coordinates are log10 learning rates) ----

# earlier viewport presets, kept for reference:
#span = 250, 250
#origin = (span[0] // 2) - 10, (span[1] // 2) - 10

#span = 1e7, 1e7
#origin = 0, 0

# extent of the rendered view and its origin point
# NOTE(review): exact semantics (center vs corner) are defined by
# pyt.lib.spaces.map_space -- confirm there
span = 5.0 / 3, 5.0 / 3
origin = 2.0, 2.0

# per-axis scaling passed to map_space; 1, 1 leaves the view unstretched
stretch = 1, 1

# no zoom keyframes: render a single static view
zooms = []

# output resolution passed to map_space
scale = 4096

# when True, render_fractal also writes intermediate images as columns finish
save_partials = False

# TODO save losses directly as safetensors
97
def render_fractal(seed):
    """Render one trainability-fractal image for the given rng seed.

    Each pixel corresponds to a (eta_0, eta_1) learning-rate pair; its value
    is the tail-averaged normalized training loss from train_many, colorized
    into converged (blue-ish) and diverged (red) channels.

    NOTE(review): network_n, dataset_size, training_steps, train_many and
    run_dir are not defined at module scope in this file; presumably the pyt
    framework injects them from persistent()/main() -- confirm before reuse.
    """
    torch.manual_seed(seed)

    # always generate random data at fp64, then convert, so that rng is at
    # least somewhat consistent across dtypes
    dataset_x = torch.randn([network_n, dataset_size], dtype=torch.float64, device=dev).to(t_real)
    dataset_y = torch.randn((dataset_size,), dtype=torch.float64, device=dev).to(t_real)

    D = (dataset_x, dataset_y)

    # shared initial weights; every pixel's training run starts from these
    _W_0 = torch.randn([network_n, network_n], dtype=torch.float64, device=dev).to(t_real)
    _W_1 = torch.randn([1, network_n], dtype=torch.float64, device=dev).to(t_real)

    mapping = map_space(origin, span, zooms, stretch, scale)
    (_, (height, width)) = mapping

    canvas = torch.zeros([height, width], dtype=t_real, device=dev)

    # eta = learning rate; grid values are log10 exponents, so exponentiate
    etas = grid(mapping).to(dev)
    etas = torch.pow(10.0, etas)

    eta_0 = etas[:, :, 1]
    eta_1 = etas[:, :, 0].flip(0)

    # image columns trained per batched call
    # NOTE(review): the final reshape below assumes width is divisible by
    # cols_per_chunk -- confirm scale always satisfies that
    cols_per_chunk = 16

    # normalized losses below this count as "converged"
    convergence_threshold = 1.0

    def _conv_div_channels(c):
        # Split a normalized-loss canvas into a converged channel (linear in
        # loss below the threshold) and a diverged channel (log-scaled above
        # it, relative to the 1e6 clamp), each rescaled to peak at 1.
        conv = c < convergence_threshold

        t_conv = 1 - c / convergence_threshold
        t_div = torch.log1p(c - convergence_threshold) / torch.log1p(torch.tensor(1e6))

        # zero out the channel that does not apply to each pixel
        t_conv = (t_conv * conv)
        t_div = (t_div * ~conv)

        # clamp guards against division by zero on an all-zero channel
        t_conv /= t_conv.max().clamp(min=1e-6)
        t_div /= t_div.max().clamp(min=1e-6)

        return t_conv, t_div

    last_report = 0

    for col_start in range(0, width, cols_per_chunk):
        col_end = col_start + cols_per_chunk
        e0 = eta_0[:, col_start:col_end].reshape(-1)
        e1 = eta_1[:, col_start:col_end].reshape(-1)

        # one private copy of the initial weights per pixel in the chunk
        _W_0_batch = _W_0.unsqueeze(0).expand(height * cols_per_chunk, -1, -1).contiguous()
        _W_1_batch = _W_1.unsqueeze(0).expand(height * cols_per_chunk, -1, -1).contiguous()

        res = train_many(e0, e1, D, _W_0_batch, _W_1_batch, training_steps)
        # diverged runs produce nan/inf; map them to a large finite value
        res = res.nan_to_num(nan=1e6, posinf=1e6, neginf=-1e6)

        canvas[:, col_start:col_end] = res.reshape(height, cols_per_chunk)

        # optional progress snapshots every 64 finished columns
        if save_partials and col_end > last_report + 64:
            last_report = col_end
            t_conv, t_div = _conv_div_channels(canvas[:, 0:col_end].clone())

            msave(t_conv, f"{run_dir}/{training_steps:05d}_conv_{col_start}")
            msave(t_div, f"{run_dir}/{training_steps:05d}_div_{col_start}")

    t_conv, t_div = _conv_div_channels(canvas.clone())

    #msave(t_conv, f"{run_dir}/{training_steps:05d}_conv_final")
    #msave(t_div, f"{run_dir}/{training_steps:05d}_div_final")
    # stack as RGB: diverged in red, converged in green/blue
    save(torch.stack((t_div, t_conv*0.8, t_conv)).to(dtype=torch.float), f"{run_dir}/{seed:06d}")
184
185
def main():
    """Entry point: size the dataset and schedule a single render."""

    # dataset size equals the network's total parameter count: W_0 is
    # n x n and W_1 is 1 x n, so n*n + n = n * (n + 1)
    # NOTE(review): network_n is a local of persistent(), not a module
    # global; presumably the pyt framework shares those locals with this
    # function -- confirm before refactoring.
    dataset_size = network_n * (network_n + 1)

    # NOTE(review): schedule is not imported in this file; presumably
    # supplied by the pyt framework. range(1) renders only seed 0.
    schedule(render_fractal, range(1))
191
192
193
def final():
    """Finalization hook; intentionally a no-op for now."""
    # TODO: colorize
197