# 🐍🐍🐍
import math
import time
from pathlib import Path
from random import random

import torch

from lib.util import *
8
# Type aliases (PEP 695 `type` statements, Python 3.12+).
# A (lo, hi) pair of floats, e.g. the (x_min, x_max) bounds built in run().
type fp_range = tuple[float, float]
# (x_range, y_range): an axis-aligned rectangle in the plane.
type fp_region2 = tuple[fp_range, fp_range]
# A pair of float coordinates (not referenced by the code shown here).
type fp_coords2 = tuple[float, float]
# (height, width) of a pixel grid.
type hw = tuple[int, int]
# (region, (height, width)): maps a plane rectangle onto a pixel grid.
type region_mapping = tuple[fp_region2, hw]
14
def main():
    """Configure a ``ps_particle`` run and hand ``run`` to ``schedule``.

    Every setting below is a plain local variable. ``run`` reads these same
    names as globals, so the names themselves are part of the contract.
    NOTE(review): this presumably relies on ``schedule`` (from ``lib.util``)
    exposing this frame's locals to ``run`` -- confirm before renaming anything.
    """
    import math

    # --- identity / numeric types -----------------------------------------
    name = "ps_particle"
    device = "cuda"
    t_real = torch.float64
    t_complex = torch.complex128

    tau = math.tau

    # --- dynamics (used by run()'s step function) -------------------------
    flow = 8                      # magnitude of the drift vector
    ebb = 1 - (1 / 5) + 0.0001    # multiple of the drift subtracted each step
    phi = 0                       # phase offset passed to blus()
    randomize = False             # redraw the drift angle every step
    random_range = tau            # width of that angle window (alt: tau / 50)
    rescaling = False             # divide magnitudes by their max each step
    show_gridlines = True         # draw grid lines at tenths on saved frames

    particle_count = 1_000_000
    iterations = 200              # alt: 2000001

    # --- output resolution: the longer image side is 2 ** scale_power ------
    scale_power = 12              # 9 -> 512, 10 -> 1024, 11 -> 2048, 12 -> 4096
    scale = 1 << scale_power

    # --- view of the plane -------------------------------------------------
    origin = (0, -(flow * ebb) / 2)
    span = (10, 10)
    stretch = (1, 1)
    # Each entry ((xa, xb), (ya, yb)) zooms into that fractional sub-rectangle
    # of the current view, e.g. ((0.5, 0.8), (0.1, 0.5)).
    zooms = []

    # --- cadence (in iterations) -------------------------------------------
    save_every = 5000             # single-iteration frame images
    agg_every = 1                 # accumulated images (alt: 1000)
    yell_every = 1000             # progress prints
    grid_on_agg = False           # subtract grid lines from accumulated images

    quantile_range = torch.tensor([0.05, 0.95], dtype=t_real, device=device)

    schedule(run, None)
57
def blus(a, b, phi):
    """Grow ``a`` along its own direction by a component of ``b``.

    The result keeps the angle of ``a`` and has radial value
    ``|a| + |b| * cos(angle(b) - angle(a) - phi)``, i.e. ``|a|`` plus the
    component of ``b`` along the direction ``angle(a) + phi``. That value is
    passed straight to ``torch.polar``, so a negative value lands on the
    opposite side of the origin.
    """
    heading = a.angle()
    relative = b.angle() - heading - phi
    radius = a.abs() + b.abs() * torch.cos(relative)
    return torch.polar(radius, heading)
61
def run():
    """Simulate the particle system and write accumulated and per-frame images.

    Configuration (``name``, ``device``, ``t_real``, ``t_complex``, ``tau``,
    ``origin``, ``span``, ``stretch``, ``zooms``, ``scale``, ``particle_count``,
    ``iterations``, ``flow``, ``ebb``, ``phi``, ``randomize``, ``random_range``,
    ``rescaling``, ``show_gridlines``, ``grid_on_agg``, ``save_every``,
    ``agg_every``, ``yell_every``) is read from module globals.
    NOTE(review): in this file those names are locals of ``main``; presumably
    ``schedule``/``lib.util`` makes them visible here -- confirm.

    Each particle is a complex number whose real part is the vertical (row)
    coordinate and whose imaginary part is the horizontal (column) coordinate.
    Every iteration the particles are splatted into ``scratch`` and added to
    ``canvas``; positions are then advanced with ``blus``.

    Images are written with ``save`` (not defined in this file, presumably from
    ``lib.util``) under ``run_dir``; the directories are created under ``out/``:
      * ``{run_dir}/aggregate/_{iteration:06d}`` every ``agg_every`` iterations:
        the log-scaled, inverted accumulated canvas;
      * ``{run_dir}/frame/{iteration:06d}`` every ``save_every`` iterations:
        the inverted single-iteration splat.
    """
    run_dir = time.strftime(f"%d.%m.%Y/{name}_t%H.%M.%S")
    out_dir = Path("out") / run_dir
    for sub in ("aggregate", "frame"):
        # parents=True also creates out_dir itself.
        (out_dir / sub).mkdir(parents=True, exist_ok=True)

    # Apply the zooms to a local copy of the view. Keeping it local means
    # a repeated call does not compound the zooms through the global `span`.
    x_min = origin[0] - (span[0] / 2)
    y_min = origin[1] - (span[1] / 2)
    view_span = span
    for (xa, xb), (ya, yb) in zooms:
        x_min += view_span[0] * xa
        y_min += view_span[1] * ya
        view_span = view_span[0] * (xb - xa), view_span[1] * (yb - ya)
    x_max = x_min + view_span[0]
    y_max = y_min + view_span[1]

    # The longer image side gets `scale` pixels.
    aspect = view_span[0] * stretch[0] / (view_span[1] * stretch[1])
    if aspect < 1:
        h, w = scale, int(scale * aspect)
    else:
        h, w = int(scale / aspect), scale

    mapping = (((x_min, x_max), (y_min, y_max)), (h, w))

    # Real and imaginary parts are each uniform in [0, 1).
    p_positions = torch.rand([particle_count], device=device, dtype=t_complex)

    # Red/blue encode the initial unit-square sample; green is the remainder.
    p_colors = torch.ones([particle_count, 3], device=device, dtype=t_real)
    p_colors[:, 0] = torch.frac(p_positions.real / view_span[0])
    p_colors[:, 2] = torch.frac(p_positions.imag / view_span[1])
    p_colors[:, 1] = 1.0 - (p_colors[:, 0] + p_colors[:, 2])

    # Squeeze the samples into a thin patch of the view: rows cover 2.5% of
    # the height from 48.75%, columns cover 10% of the width from 50.1%.
    y_extent = y_max - y_min
    x_extent = x_max - x_min
    p_positions.real *= 0.025 * y_extent
    p_positions.real += y_min + 0.4875 * y_extent
    p_positions.imag *= 0.025 * x_extent * 4
    p_positions.imag += x_min + 0.501 * x_extent

    canvas = torch.zeros([3, h, w], device=device, dtype=t_real)
    scratch = torch.zeros([h, w, 3], device=device, dtype=t_real)
    gridlines = torch.zeros([h, w], device=device)
    for i in range(10):
        frac = i / 10
        gridlines[math.floor(frac * h), :] = 1
        gridlines[:, math.floor(frac * w)] = 1

    # Drift vector of magnitude `flow`, shared through the module-level
    # `direction` so next_positions can redraw it.
    ones = torch.ones_like(p_positions)
    global direction
    direction = flow * torch.polar(ones.real, tau * ones.real)

    def project(p):
        """View complex positions as a (2, N) real tensor: row 0 real, row 1 imag."""
        return torch.view_as_real(p).permute(1, 0)

    def next_positions(p, i):
        """Advance positions one step: ``blus(p, direction, phi) - direction * ebb``.

        ``i`` (the iteration index) is unused. With ``randomize`` the global
        ``direction`` is first redrawn with a uniform angle in
        [-random_range / 2, random_range / 2). With ``rescaling`` the result's
        magnitudes are divided by their maximum.
        """
        global direction
        if randomize:
            angle = random() * random_range - random_range / 2
            direction = flow * torch.polar(ones.real, angle * ones.real)
        result = blus(p, direction, phi) - direction * ebb
        if rescaling:
            res_abs = result.abs()
            result = torch.polar(res_abs / res_abs.max(), result.angle())
        return result

    def insert_at_coords(coords, values, target, mapping: region_mapping):
        """Accumulate ``values`` into ``target`` at the pixels hit by ``coords``.

        ``coords`` is (2, N): row 0 is matched against the y range and mapped
        to target rows, row 1 against the x range and mapped to columns.
        Points outside the region (or NaN) are dropped. The rest are scaled
        onto [0, h-1] x [0, w-1], truncated to integer indices, and summed
        with ``index_put_(accumulate=True)``, so coincident points add up.
        """
        ((x_lo, x_hi), (y_lo, y_hi)), (n_rows, n_cols) = mapping
        ys, xs = coords[0], coords[1]
        keep = (xs >= x_lo) & (xs <= x_hi) & (ys >= y_lo) & (ys <= y_hi)
        # Boolean-mask indexing yields 1-D results for any number of kept
        # points, including exactly one or none.
        row_idx = ((ys[keep] - y_lo) * ((n_rows - 1) / (y_hi - y_lo))).long()
        col_idx = ((xs[keep] - x_lo) * ((n_cols - 1) / (x_hi - x_lo))).long()
        target.index_put_((row_idx, col_idx), values[keep], accumulate=True)

    for iteration in range(iterations):
        # Splat this iteration's particles and add them to the running canvas.
        scratch.zero_()
        insert_at_coords(project(p_positions), p_colors, scratch, mapping)
        canvas += scratch.permute(2, 0, 1)

        if iteration % agg_every == 0:
            # Log-scale the canvas and normalise by its max. Pixels at the
            # canvas minimum become -inf and end up clamped to 1.0 after
            # inversion.
            shade = torch.log(canvas - canvas.min())
            shade /= shade.max()
            shade = (1 - shade).clamp_(0.0, 1.0)
            if grid_on_agg:
                shade = shade - gridlines
            save(shade, f"{run_dir}/aggregate/_{iteration:06d}")

        if iteration % save_every == 0:
            # scratch is re-zeroed at the top of the next iteration, so
            # drawing the grid into it does not leak into the canvas.
            if show_gridlines:
                scratch[:, :, 1] += gridlines
                scratch[:, :, 2] += gridlines
            save(1 - scratch.permute(2, 0, 1), f"{run_dir}/frame/{iteration:06d}")

        if iteration % yell_every == 0:
            print(f"{iteration} iterations")

        p_positions = next_positions(p_positions, iteration)
215
216
# Make `device` the default for tensors created without an explicit device.
# NOTE(review): `device` is not assigned at module scope anywhere in this file
# (only as a local inside main()), so this line presumably relies on
# `from lib.util import *` providing it -- confirm. It runs at import time,
# after the definitions above; nothing visible here calls main().
torch.set_default_device(device)
218