🐍🐍🐍
at main 6.9 kB view raw
1import time 2from pathlib import Path 3from random import random 4 5import torch 6 7from lib.util import * 8 9type fp_range = tuple[float, float] 10type fp_region2 = tuple[fp_range, fp_range] 11type fp_coords2 = tuple[float, float] 12type hw = tuple[int, int] 13type region_mapping = tuple[fp_region2, hw] 14 15def main(): 16 import math 17 name = "ps_particle" 18 device = "cuda" 19 t_real = torch.double 20 t_complex = torch.cdouble 21 22 tau = 3.14159265358979323 * 2 23 24 flow = 8 25 ebb = 1 - (1 / 5) + 0.0001 26 phi = 0 27 randomize = False 28 random_range = tau #/ 50 29 rescaling = False 30 show_gridlines = True 31 32 particle_count = 10**6 33 34 iterations = 200#0001 35 36 # 2 ** 9 = 512; 10 => 1024; 11 => 2048 37 scale_power = 12 38 scale = 2 ** scale_power 39 40 origin = 0, -(flow*ebb)/2 41 span = 10, 10 42 43 stretch = 1, 1 44 45 zooms = [ 46 #((0.5,0.8),(0.1,0.5)) 47 ] 48 49 save_every = 5000 50 agg_every = 1#000 51 yell_every = 1000 52 grid_on_agg = False#True 53 54 quantile_range = torch.tensor([0.05, 0.95], dtype=t_real, device=device) 55 56 schedule(run, None) 57 58def blus(a, b, phi): 59 theta_a = a.angle() 60 return torch.polar(a.abs() + b.abs() * torch.cos(b.angle() - theta_a - phi), theta_a) 61 62def run(): 63 run_dir = time.strftime(f"%d.%m.%Y/{name}_t%H.%M.%S") 64 Path("out/" + run_dir).mkdir(parents=True, exist_ok=True) 65 Path("out/" + run_dir + "/aggregate").mkdir(parents=True, exist_ok=True) 66 Path("out/" + run_dir + "/frame").mkdir(parents=True, exist_ok=True) 67 68 global span 69 70 x_min = origin[0] - (span[0] / 2) 71 y_min = origin[1] - (span[1] / 2) 72 73 for ((xa, xb), (ya, yb)) in zooms: 74 x_min += span[0] * xa 75 y_min += span[1] * ya 76 span = span[0] * (xb - xa), span[1] * (yb - ya) 77 78 x_max = x_min + span[0] 79 y_max = y_min + span[1] 80 81 aspect = span[0] * stretch[0] / (span[1] * stretch[1]) 82 83 if aspect < 1: 84 h = scale 85 w = int(scale * aspect) 86 else: 87 w = scale 88 h = int(scale / aspect) 89 90 x_range = (x_min, 
x_max) 91 y_range = (y_min, y_max) 92 region = (x_range, y_range) 93 mapping = (region, (h,w)) 94 95 p_positions = (torch.rand([particle_count], device=device, dtype=t_complex)) 96 #p_positions.imag = torch.linspace(0, tau, particle_count) 97 #p_positions.real = torch.linspace(0, 1, particle_count) 98 #p_positions = torch.polar(p_positions.real, p_positions.imag) 99 100 p_colors = torch.ones([particle_count,3], device=device, dtype=t_real) 101 color_rotation = torch.linspace(0, tau / 4, particle_count) 102 103 p_colors[:,0] = torch.frac((p_positions.real) * 1 / (span[0])) 104 p_colors[:,2] = torch.frac((p_positions.imag) * 1 / (span[1])) 105 106 p_positions.real *= 0.025 * (y_max - y_min) 107 p_positions.real += y_min + (0.4875) * (y_max - y_min) 108 p_positions.imag *= 0.025 * (x_max - x_min) * 4 109 p_positions.imag += x_min + 0.501 * (x_max - x_min) 110 111 #p_colors[:,0] = torch.cos(color_rotation) 112 p_colors[:,1] = 1.0 - (p_colors[:,0] + p_colors[:,2])#0.1 113 #p_colors[:,2] *= 0 114 #p_colors[:,2] = torch.sin(color_rotation) 115 116 canvas = torch.zeros([3, h, w], device=device, dtype=t_real) 117 scratch = torch.zeros([h, w, 3], device=device, dtype=t_real) 118 gridlines = torch.zeros([h,w], device=device) 119 120 def project(p): 121 return torch.view_as_real(p).permute(1,0) 122 123 ones = torch.ones_like(p_positions) 124 global direction 125 direction = flow * torch.polar(ones.real, 1 * tau * ones.real) 126 def next_positions(p, i): 127 global direction 128 #return blus(p, ones) - ones 129 if randomize: 130 direction = flow * torch.polar(ones.real, (random() * random_range - random_range / 2) * ones.real) 131 result = blus(p, direction, phi) - direction * ebb 132 res_abs = result.abs() 133 if rescaling: 134 result = torch.polar(res_abs / res_abs.max(), result.angle()) 135 return result 136 137 for i in range(10): 138 frac = i / 10 139 gridlines[math.floor(frac*h), :] = 1 140 gridlines[:, math.floor(frac*w)] = 1 141 142 def insert_at_coords(coords, values, 
target, mapping: region_mapping): 143 (region, hw) = mapping 144 (xrange, yrange) = region 145 (h,w) = hw 146 (x_min, x_max) = xrange 147 (y_min, y_max) = yrange 148 149 mask = torch.ones([particle_count], device=device) 150 mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 151 mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 152 in_range = mask.nonzero().squeeze() 153 154 # TODO: combine coord & value tensors so there's only one index_select necessary 155 coords_filtered = torch.index_select(coords.permute(1,0), 0, in_range) 156 values_filtered = torch.index_select(values, 0, in_range) 157 158 coords_filtered[:,1] -= x_min 159 coords_filtered[:,1] *= (w-1) / (x_max - x_min) 160 coords_filtered[:,0] -= y_min 161 coords_filtered[:,0] *= (h-1) / (y_max - y_min) 162 indices = coords_filtered.long() 163 164 target.index_put_((indices[:,0],indices[:,1]), values_filtered, accumulate=True) 165 166 167 for iteration in range(iterations): 168 p_projected = project(p_positions).clone() 169 170 if iteration % 1 == 0: 171 172 scratch *= 0 173 insert_at_coords(p_projected, p_colors, scratch, mapping) 174 canvas += scratch.permute(2,0,1) 175 176 177 temp = canvas.clone() 178 #for d in range(3): 179 # temp[d] -= temp[d].mean() 180 # temp[d] /= 8 * temp[d].std() 181 # temp[d] -= temp[d].min() 182 183 temp -= temp.min() 184 temp = torch.log(temp) 185 temp /= temp.max() 186 187 if iteration % agg_every == 0: 188 p_low, p_high = torch.quantile(temp[:,(2*h//5):(3*h//5),:], quantile_range) 189 # temp[0] += gridlines 190 # temp[1] += gridlines 191 # temp[2] += gridlines 192 193 if grid_on_agg: 194 save((1 - temp).clamp_(0.0, 1.0) - gridlines, f"{run_dir}/aggregate/_{iteration:06d}") 195 else: 196 save((1 - temp).clamp_(0.0, 1.0), f"{run_dir}/aggregate/_{iteration:06d}") 197 temp = (temp - p_low) / (1e-7 + p_high - p_low) 198 #save(1 - temp, f"{run_dir}/aggregate/{iteration:06d}") 199 #scratch /= scratch.max() 200 #scratch = scratch.sqrt().sqrt() 201 if show_gridlines: 202 
scratch[:,:,1] += gridlines 203 scratch[:,:,2] += gridlines 204 if iteration % save_every == 0: 205 save(1 - scratch.permute(2,0,1), f"{run_dir}/frame/{iteration:06d}") 206 if iteration % yell_every == 0: 207 print(f"{iteration} iterations") 208 209 p_positions = next_positions(p_positions, iteration) 210 211 #for d in range(3): 212 # canvas[d] -= canvas[d].mean() 213 # canvas[d] /= 8 * canvas[d].std() 214 # canvas[d] -= canvas[d].min() 215 216 217 torch.set_default_device(device) 218