🐍🐍🐍
at main 5.9 kB view raw
import time
import math
from pathlib import Path

import torch

from lib.spaces import insert_at_coords, map_space
from lib.ode import rk4_step
from lib.util import *


def main():
    """Entry point: hand `run` to `schedule` (provided by lib.util)."""
    # NOTE(review): there is no `if __name__ == "__main__"` guard; presumably an
    # external runner imports this module and calls main() -- confirm.
    schedule(run, None)


device = "cuda"
torch.set_default_device(device)
t_fp = torch.double
torch.set_default_dtype(t_fp)

dissipation = 0.04  # "b" parameter in the literature (not used by the active field)

tau = math.tau  # 2 * pi

particle_count = 10000000
iterations = 10000
rotation_rate = tau / 2000  # view rotation per iteration, in radians


def save_if(i):
    """Return whether iteration `i` should be rendered (currently: always)."""
    return True  # alternative: i > 150 and i % 5 == 0


# 2 ** 9 = 512; 10 => 1024; 11 => 2048
scale_power = 10

scale = 2 ** scale_power
origin = 0, 0
s = 40
span = s, s
zooms = []
stretch = 1, 1

mapping = map_space(origin, span, zooms, stretch, scale)
# The second element of the mapping supplies the frame size used by run().
(_, (h, w)) = mapping

dt = 0.05


def proj(u, v):
    """Project each row of `u` onto the corresponding row of `v`.

    dim=1 is treated as the component axis, i.e. u and v are expected as
    [N, 3] (one vector per row).
    NOTE(review): particle positions in this file are stored as [3, N];
    confirm the layout before using this with them.
    """
    coeff = (u * v).sum(dim=1) / v.norm(p=2, dim=1) ** 2
    return v * coeff.unsqueeze(1)


def proj_shift(a, b):
    """Return `a` plus the row-wise projection of `b` onto `a`."""
    return a + proj(b, a)


def get_transformation(direction, flow, ebb, rescaling):
    """Build a point transformation p -> p' (row-vector layout, like proj).

    p' = p - ebb * direction, plus -- when flow != 0 -- the projection of
    flow * direction onto p.  With `rescaling`, p' is divided by its largest
    row norm so the longest row has unit length.
    """
    def transformation(p):
        if flow == 0:
            result = p - direction * ebb
        else:
            result = proj_shift(p, direction * flow) - direction * ebb
        if rescaling:
            # In-place is safe: `result` is a freshly allocated tensor above.
            result /= result.norm(p=2, dim=1).max()
        return result

    return transformation


def _deriv(b, c):
    """Return the field p -> p * <b, p> - c * b (dim 0 is the component axis)."""
    def derivative(p):
        return p * (b * p).sum(dim=0) - c * b

    return derivative


def _sprott2014(a):
    """Return the field
        dx = y + 2xy + xz
        dy = a - 2x^2 + yz
        dz = x - x^2 - y^2
    for points stored along dim 0 as (x, y, z).
    """
    def derivative(p):
        x, y, z = p[0], p[1], p[2]
        return torch.stack([
            y + (2 * (x * y) + x * z),
            -2 * (x * x) + (a + y * z),
            x - (x * x + y * y),
        ], dim=0)

    return derivative


def _sprottET0():
    """Return the field dx = y, dy = -x + yz, dz = x^2 - 4y^2 + 1."""
    def derivative(p):
        x, y, z = p[0], p[1], p[2]
        return torch.stack([
            y,
            -x + y * z,
            (x * x) - (4 * y * y) + 1,
        ], dim=0)

    return derivative


def _suspension(c):
    """Return a planar flow in (x, y) plus a constant drift along z.

    `c` shifts the planar centre to x = -c/2.
    """
    damping = 0.02
    r_eq = 1.0
    speed = 0.05  # overall velocity gain

    def deriv(p):
        x, y = p[0], p[1]

        # Shift x to center the tori at Re = -c/2
        X = x + c / 2
        Y = y

        # Polar coordinates (epsilon avoids division by zero at the centre)
        r = torch.sqrt(X**2 + Y**2) + 1e-12
        theta = torch.atan2(Y, X)

        # Radial displacement
        dr = r + torch.cos(theta)

        # Velocity in 2D
        dx = speed * (dr * (X / r) - X)
        dy = speed * (dr * (Y / r) - Y)

        # Weak radial damping pulling r toward r_eq
        dx -= speed * damping * (r - r_eq) * (X / r)
        dy -= speed * damping * (r - r_eq) * (Y / r)

        # z-axis suspension (monotonic)
        dz = speed * torch.ones_like(x)

        return torch.stack([dx, dy, dz], dim=0)

    return deriv


def gpt_deriv(b, k=1.0, c=0.1, omega=0.5):
    """
    Returns a function deriv(points) -> [3, N] where points is [3, N]
    (one 3D column vector per point).
    b: tensor of shape [3] (one shared axis) or [3, N] (per-point axes);
       normalized once, by its overall norm
    k: scalar gain for tanh nonlinearity
    c: scalar drift magnitude along -b
    omega: scalar rotation strength around b
    """
    b = b / b.norm()  # normalize once
    if b.dim() == 1:
        # A [3] vector must become a [3, 1] column to broadcast over the N
        # point columns; otherwise broadcasting aligns it with the N axis.
        b = b.reshape(-1, 1)

    def deriv(points):
        # <a, b> for each column
        dots = (points * b).sum(dim=0)  # [N]

        # radial scaling term
        radial = torch.tanh(k * dots).unsqueeze(0) * points  # [3, N]

        b_full = b.expand_as(points)  # [3, N]

        # drift along -b
        drift = -c * b_full

        # rotation around b: b × a
        rotation = omega * torch.cross(b_full, points, dim=0)

        return radial + drift + rotation

    return deriv


# Initial positions: uniform in [-1, 1)^3, one particle per column ([3, N]).
p_positions = (torch.rand([3, particle_count], device=device, dtype=t_fp) - 0.5) * 2
# p_positions[0] *= 0.05
# p_positions[0] += 0.7
# p_positions[2] = 0

# Unit +y vector for every particle; only used by the commented-out fields below.
# NOTE(review): an earlier version divided this zero tensor by sqrt(3), which had
# no effect; if a diagonal unit vector (ones / sqrt(3)) was intended, revisit.
direction = torch.zeros_like(p_positions)
direction[1] = 1.0

# derivative = _deriv(direction, 0.5)
# derivative = gpt_deriv(direction, 1.0, 0.1, 0.5)
# derivative = _sprott2014(0.7)
# derivative = _sprottET0()
derivative = _suspension(0.8)


def step(p, dp, h):
    """Advance `p` along `dp` by h * dt."""
    return p + dp * h * dt


def rk4_curried(p):
    """Apply one lib.ode.rk4_step of `derivative` (using `step`) to `p`."""
    return rk4_step(derivative, step, p)


# The random values are fully overwritten below; the draw is kept as-is.
p_colors = torch.rand([particle_count, 3], device=device, dtype=t_fp)

# NOTE(review): not referenced anywhere in this file.
color_rotation = torch.linspace(0, tau / 4, particle_count)

# Colour each particle by its initial position, mapping [-1, 1) to [0, 1).
p_colors.copy_(p_positions.T / 2 + 0.5)


def project(p, colors, i):
    """Rotate [3, N] points about the x axis by i * rotation_rate.

    Returns the rotated (y, z) rows ([2, N]) and a copy of `colors`.
    """
    angle = i * rotation_rate
    c = math.cos(angle)
    s = math.sin(angle)
    rotation = torch.tensor([
        [1.0, 0.0, 0.0],
        [0.0, c, -s],
        [0.0, s, c],
    ])
    # Alternatives: y-axis rotation [[c, 0, s], [0, 1, 0], [-s, 0, c]],
    # z-axis rotation [[c, -s, 0], [s, c, 0], [0, 0, 1]].
    rotated = rotation @ p
    alt_colors = colors.clone()
    # Optional highlight filter (disabled):
    # color_filter = (p[0] - 0.7).abs() < 0.01
    # alt_colors *= color_filter.unsqueeze(1)
    return rotated[1:3], alt_colors


# Single-element list so run() can increment it without `global`.
frame_index = [0]


def run():
    """Integrate the particles for `iterations` steps, saving frames.

    Before each step, a frame is rendered when save_if(iteration) is true,
    and always on the final iteration.  Frames are written as
    f"{run_dir}/{n:06d}", numbered from 1.
    """
    scratch = torch.zeros([h, w, 3], device=device, dtype=t_fp)
    last_iteration = iterations - 1

    for iteration in range(iterations):
        if save_if(iteration) or iteration == last_iteration:
            p_projected, alt_colors = project(p_positions, p_colors, iteration)
            frame_index[0] += 1
            scratch.zero_()
            insert_at_coords(p_projected, alt_colors, scratch, mapping)
            save(scratch.permute(2, 0, 1), f"{run_dir}/{frame_index[0]:06d}")

        p_positions.copy_(rk4_curried(p_positions))