๐Ÿ๐Ÿ๐Ÿ

Compare changes

Choose any two refs to compare.

+1
.gitignore
··· 5 5 6 6 sketch/.old 7 7 sketch/scratch.py 8 + sketch/local 8 9 9 10 .git 10 11 .venv
+161
sketch/gamma.py
··· 1 + 2 + import math 3 + import torch 4 + from pathlib import Path 5 + 6 + from lib.util import * 7 + from lib.spaces import insert_at_coords, map_space, cgrid 8 + 9 + name = "gamma" 10 + 11 + device = "cuda" 12 + torch.set_default_device(device) 13 + 14 + t_real = torch.double 15 + t_complex = torch.cdouble 16 + tau = 2 * math.pi 17 + 18 + # renderer config 19 + scale_power = 12 20 + scale = 2 ** scale_power 21 + origin = 0, 0 22 + span = 15, 15 23 + stretch = 1, 1 24 + zooms = [ 25 + ] 26 + 27 + # particle distribution: sample domain points 28 + particle_count = 20000000 29 + 30 + #p_positions = (torch.rand([particle_count], dtype=t_complex) - (0.5 + 0.5j)) * 40 31 + 32 + # initialize colors (updated every frame from Gamma) 33 + #p_colors = torch.zeros([particle_count, 3], device=device, dtype=t_real) 34 + 35 + mapping = map_space(origin, span, zooms, stretch, scale) 36 + 37 + (_, (h,w)) = mapping 38 + grid = cgrid(mapping) 39 + 40 + scratch = torch.zeros([h, w, 3], device=device, dtype=t_real) 41 + p_positions = torch.zeros([h*w,2], device=device, dtype=t_real) 42 + p_positions[:,0] = grid.real.reshape((h*w)) 43 + p_positions[:,1] = grid.imag.reshape((h*w)) 44 + 45 + 46 + 47 + def gamma_approx(z): 48 + # avoid poles in a naive way 49 + poles = (z.real <= 0) & torch.isclose(z.real.round(), z.real) 50 + z = z.clone() 51 + z[poles] = complex("nan") 52 + 53 + # reflection for left half-plane 54 + reflect = z.real < 0.5 55 + zr = z[reflect] 56 + if zr.numel() > 0: 57 + g1mz = gamma_approx(1 - zr) 58 + z[reflect] = math.pi / (torch.sin(math.pi * zr) * g1mz) 59 + 60 + # Stirling-like formula on right half-plane 61 + zpos = z[~reflect] 62 + if zpos.numel() > 0: 63 + t = zpos - 0.5 64 + core = math.sqrt(2 * math.pi) * torch.exp(t * torch.log(zpos) - zpos) 65 + z[~reflect] = core 66 + return z 67 + 68 + 69 + a = 2.0j 70 + b = 2.0 71 + kernel_xs = torch.tensor([a + b, a - b, -a + b, -a - b], device=device) 72 + kernel_xs -= kernel_xs.mean() 73 + 74 + def velu_like(z): 75 + acc = z 76 + for xQ in kernel_xs: 77 + acc += 1 / (z - xQ) 78 + return acc 79 + 80 + 81 + def gamma_to_rgb(gz): 82 + arg = torch.angle(gz) 83 + hue = (arg + math.pi) / (2 * math.pi) 84 + mag = torch.abs(gz) 85 + logmag = torch.log(mag + 1e-12) 86 + 87 + # normalize brightness heuristically 88 + val = (logmag + 5) / 10 89 + val = val.clamp(0.0, 1.0) 90 + 91 + # convert HSV to RGB (same logic as your framework) 92 + h6 = hue * 6 93 + i = torch.floor(h6).long() % 6 94 + f = h6 - torch.floor(h6) 95 + p = val * 0 96 + q = val * (1 - f) 97 + t = val * f 98 + 99 + r = torch.zeros_like(val) 100 + g = torch.zeros_like(val) 101 + b = torch.zeros_like(val) 102 + 103 + m = i == 0; r[m]=val[m]; g[m]=t[m]; b[m]=p[m] 104 + m = i == 1; r[m]=q[m]; g[m]=val[m]; b[m]=p[m] 105 + m = i == 2; r[m]=p[m]; g[m]=val[m]; b[m]=t[m] 106 + m = i == 3; r[m]=p[m]; g[m]=q[m]; b[m]=val[m] 107 + m = i == 4; r[m]=t[m]; g[m]=p[m]; b[m]=val[m] 108 + m = i == 5; r[m]=val[m]; g[m]=p[m]; b[m]=q[m] 109 + 110 + return torch.stack([r,g,b], dim=-1) 111 + 112 + frame_index = [0] 113 + 114 + def main(): 115 + schedule(main_render, None) 116 + 117 + def main_render(): 118 + scratch = torch.zeros([h, w, 3], device=device, dtype=t_real) 119 + naniter = torch.zeros([h, w], device=device, dtype=t_real) 120 + 121 + # project function kept from original plume code 122 + def project(p): 123 + return torch.view_as_real(p).permute(1,0) 124 + 125 + # single-frame render (or loop for animations) 126 + frame_index[0] += 1 127 + 128 + # compute Gamma(p) 129 + gz = grid.clone()#gamma_approx(p_positions) 130 + 131 + for i in range(2): 132 + gz = gamma_approx(gz)#velu_like(gz) 133 + 134 + naniter += gz.isnan() * 1.0 135 + 136 + # convert to RGB 137 + #rgb = gamma_to_rgb(gz) 138 + 139 + # project particle coords to pixel grid 140 + #coords = project(p_positions).clone() 141 + #insert_at_coords(coords, rgb, scratch, mapping) 142 + 143 + 144 + #scratch[:,:,0] = gz.real 145 + #scratch[:,:,2] = gz.imag 146 + #scratch[:,:,1] = naniter / naniter.max() 147 + 148 + scratch[:,:,0] = (torch.angle(gz) + math.pi) 149 + scratch[:,:,1] = torch.log(torch.abs(gz) + 1e-12) / 30 150 + scratch[:,:,1] /= scratch[:,:,2].max() 151 + scratch[:,:,2] = torch.cos(scratch[:,:,0] / 2) 152 + scratch[:,:,0] = torch.abs(torch.sin(scratch[:,:,0] / 2)) 153 + 154 + 155 + # normalization/centering like plume_c does 156 + #scratch[:,:,2] -= scratch[:,:,2].mean() 157 + #scratch[:,:,2] /= scratch[:,:,2].std() * 6 158 + #scratch[:,:,2] += 0.5 159 + 160 + save(scratch.permute(2,0,1).sqrt(), f"{run_dir}/{i:06d}") 161 +
+222
sketch/ode.py
··· 1 + 2 + import time 3 + import math 4 + from pathlib import Path 5 + 6 + import torch 7 + 8 + from lib.spaces import insert_at_coords, map_space 9 + from lib.ode import rk4_step 10 + from lib.util import * 11 + 12 + def main(): 13 + schedule(run, None) 14 + 15 + device = "cuda" 16 + torch.set_default_device(device) 17 + t_fp = torch.double 18 + torch.set_default_dtype(t_fp) 19 + 20 + dissipation = 0.04 # "b" parameter in the literature 21 + 22 + tau = 3.14159265358979323 * 2 23 + 24 + particle_count = 10000000 25 + iterations = 10000 26 + save_if = lambda i: True #i > 150 and i % 5 == 0 27 + rotation_rate = tau / 2000 28 + 29 + # 2 ** 9 = 512; 10 => 1024; 11 => 2048 30 + scale_power = 10 31 + 32 + scale = 2 ** scale_power 33 + origin = 0, 0 34 + s = 40 35 + span = s, s 36 + zooms = [ 37 + ] 38 + stretch = 1, 1 39 + 40 + mapping = map_space(origin, span, zooms, stretch, scale) 41 + (_, (h,w)) = mapping 42 + 43 + dt = 0.05 44 + 45 + def proj(u, v): 46 + dot = (u * v).sum(dim=1) 47 + scale = (dot / (v.norm(p=2, dim=1) ** 2)) 48 + out = v.clone() 49 + for dim in range(3): 50 + out[:,dim] *= scale 51 + return out 52 + 53 + def proj_shift(a, b): 54 + return a + proj(b, a) 55 + 56 + def get_transformation(direction, flow, ebb, rescaling): 57 + def transformation(p): 58 + #return blus(p, ones) - ones 59 + #if randomize: 60 + # direction = flow * torch.polar(ones.real, (random() * random_range - random_range / 2) * ones.real) 61 + if flow == 0: 62 + result = p - direction * ebb 63 + else: 64 + result = proj_shift(p, direction * flow) - direction * ebb 65 + #res_abs = result.abs() 66 + if rescaling: 67 + result /= result.norm(p=2,dim=1).max() 68 + #result = torch.polar(2 * res_abs / res_abs.max(), result.angle()) 69 + return result 70 + return transformation 71 + 72 + def _deriv(b, c): 73 + def derivative(p): 74 + return p * ((b * p).sum(dim=0)) - c * b 75 + return derivative 76 + 77 + def _sprott2014(a): 78 + def derivative(p): 79 + d = p.roll(shifts=-1, dims=0) 80 + d[0] += 2 * (p[0] * p[1]) + (p[0] * p[2]) # y + 2xy + xz 81 + d[1] = - (p[0] * p[0]) # -x^2 82 + d[2] += d[1] - (p[1] * p[1]) # x - x^2 - y^2 83 + d[1] *= 2 # - 2x^2 84 + d[1] += a + (p[1] * p[2]) # a - 2x^2 + yz 85 + return d 86 + return derivative 87 + 88 + def _sprottET0(): 89 + def derivative(p): 90 + d = p.clone() 91 + d[0] = p[1] 92 + d[1] = - p[0] + p[1] * p[2] 93 + d[2] = (p[0] * p[0]) - (4 * p[1] * p[1]) + 1 94 + return d 95 + return derivative 96 + 97 + def _suspension(c): 98 + damping = 0.02 99 + r_eq = 1.0 100 + scale = 0.05 101 + def deriv(p): 102 + x, y, z = p[0], p[1], p[2] 103 + 104 + # Shift x to center the tori at Re = -c/2 105 + X = x + c/2 106 + Y = y 107 + 108 + # Polar coordinates 109 + r = torch.sqrt(X**2 + Y**2) + 1e-12 110 + theta = torch.atan2(Y, X) 111 + 112 + # Radial displacement 113 + dr = r + torch.cos(theta) 114 + 115 + # Velocity in 2D 116 + dx = scale * (dr * (X/r) - X) 117 + dy = scale * (dr * (Y/r) - Y) 118 + 119 + # Optional weak radial damping 120 + dx -= scale * damping * (r - r_eq) * (X/r) 121 + dy -= scale * damping * (r - r_eq) * (Y/r) 122 + 123 + # z-axis suspension (monotonic) 124 + dz = scale * torch.ones_like(x) 125 + 126 + return torch.stack([dx, dy, dz], dim=0) 127 + return deriv 128 + 129 + 130 + def gpt_deriv(b, k=1.0, c=0.1, omega=0.5): 131 + """ 132 + Returns a function deriv(points) -> [3,N] where points are 3D column vectors. 133 + b: tensor of shape [3] (will be normalized internally) 134 + k: scalar gain for tanh nonlinearity 135 + c: scalar drift magnitude along -b 136 + omega: scalar rotation strength around b 137 + """ 138 + b = b / b.norm() # normalize once 139 + 140 + def deriv(points): 141 + # <a, b> for each column 142 + #dots = torch.einsum('ij,i->j', points, b) # shape [N] 143 + dots = (points * b[:]).sum(dim=0) # shape [N] 144 + 145 + # radial scaling term 146 + radial = torch.tanh(k * dots).unsqueeze(0) * points # [3,N] 147 + 148 + # drift along -b 149 + drift = -c * b[:].expand_as(points) # [3,N] 150 + 151 + 152 + # rotation around b: b ร— a 153 + rotation = omega * torch.cross(b[:].expand_as(points), points, dim=0) 154 + 155 + return radial + drift + rotation 156 + 157 + return deriv 158 + 159 + p_positions = (torch.rand([3, particle_count], device=device, dtype=t_fp) - 0.5) * 2 160 + #p_positions[0] *= 0.05 161 + #p_positions[0] += 0.7 162 + #p_positions[2] = 0 163 + direction = torch.zeros_like(p_positions) / math.sqrt(3) 164 + direction[1] += 1 165 + 166 + #derivative = _deriv(direction, 0.5) 167 + #derivative = gpt_deriv(direction, 1.0, 0.1, 0.5) 168 + #derivative = _sprott2014(0.7) 169 + #derivative = _sprottET0() 170 + derivative = _suspension(0.8) 171 + 172 + step = lambda p, dp, h: p + dp * h * dt 173 + rk4_curried = lambda p: rk4_step(derivative, step, p) 174 + 175 + 176 + p_colors = torch.rand([particle_count,3], device=device, dtype=t_fp) 177 + 178 + color_rotation = torch.linspace(0, tau / 4, particle_count) 179 + 180 + p_colors[:,0] = (p_positions[0,:] / 2) + 0.5 181 + p_colors[:,1] = (p_positions[1,:] / 2) + 0.5 182 + p_colors[:,2] = (p_positions[2,:] / 2) + 0.5 183 + 184 + def project(p, colors, i): 185 + c = math.cos(i * rotation_rate) 186 + s = math.sin(i * rotation_rate) 187 + rotation = torch.tensor([ 188 + [1, 0, 0], 189 + [0, c,-s], 190 + [0, s, c]]) 191 + rotation2 = torch.tensor([ 192 + [ c, 0, s], 193 + [ 0, 1, 0], 194 + [-s, 0, c]]) 195 + #rotation = torch.tensor([ 196 + # [c, -s, 0], 197 + # [s, c, 0], 198 + # [0, 0, 1]]) 199 + alt_colors = colors.clone() 200 + res = (rotation @ p) 201 + #color_filter = (p[0] - 0.7).abs() < 0.01 202 + #alt_colors[:,0] *= color_filter 203 + #alt_colors[:,1] *= color_filter 204 + #alt_colors[:,2] *= color_filter 205 + return (res[1:3], alt_colors) 206 + 207 + frame_index = [0] 208 + 209 + def run(): 210 + scratch = torch.zeros([h, w, 3], device=device, dtype=t_fp) 211 + 212 + for iteration in range(iterations): 213 + if save_if(iteration) or iteration == iterations - 1: 214 + (p_projected, alt_colors) = project(p_positions, p_colors, iteration) 215 + frame_index[0] += 1 216 + scratch *= 0 217 + insert_at_coords(p_projected, alt_colors, scratch, mapping) 218 + save(scratch.permute(2,0,1), f"{run_dir}/{frame_index[0]:06d}") 219 + 220 + p_positions.copy_(rk4_curried(p_positions)) 221 + 222 +
+4 -4
sketch/old/snakepyt_0_0/ode.py
··· 142 142 # [0, 0, 1]]) 143 143 alt_colors = colors.clone() 144 144 res = (rotation @ p) 145 - #color_filter = res[2].abs() < 0.7 146 - #alt_colors[:,0] *= color_filter 147 - #alt_colors[:,1] *= color_filter 148 - #alt_colors[:,2] *= color_filter 145 + color_filter = res[2].abs() < 0.1 146 + alt_colors[:,0] *= color_filter 147 + alt_colors[:,1] *= color_filter 148 + alt_colors[:,2] *= color_filter 149 149 return (res[0:2], alt_colors) 150 150 151 151 frame_index = [0]
+218
sketch/ps_particle.py
··· 1 + import time 2 + from pathlib import Path 3 + from random import random 4 + 5 + import torch 6 + 7 + from lib.util import * 8 + 9 + type fp_range = tuple[float, float] 10 + type fp_region2 = tuple[fp_range, fp_range] 11 + type fp_coords2 = tuple[float, float] 12 + type hw = tuple[int, int] 13 + type region_mapping = tuple[fp_region2, hw] 14 + 15 + def main(): 16 + import math 17 + name = "ps_particle" 18 + device = "cuda" 19 + t_real = torch.double 20 + t_complex = torch.cdouble 21 + 22 + tau = 3.14159265358979323 * 2 23 + 24 + flow = 8 25 + ebb = 1 - (1 / 5) + 0.0001 26 + phi = 0 27 + randomize = False 28 + random_range = tau #/ 50 29 + rescaling = False 30 + show_gridlines = True 31 + 32 + particle_count = 10**6 33 + 34 + iterations = 200#0001 35 + 36 + # 2 ** 9 = 512; 10 => 1024; 11 => 2048 37 + scale_power = 12 38 + scale = 2 ** scale_power 39 + 40 + origin = 0, -(flow*ebb)/2 41 + span = 10, 10 42 + 43 + stretch = 1, 1 44 + 45 + zooms = [ 46 + #((0.5,0.8),(0.1,0.5)) 47 + ] 48 + 49 + save_every = 5000 50 + agg_every = 1#000 51 + yell_every = 1000 52 + grid_on_agg = False#True 53 + 54 + quantile_range = torch.tensor([0.05, 0.95], dtype=t_real, device=device) 55 + 56 + schedule(run, None) 57 + 58 + def blus(a, b, phi): 59 + theta_a = a.angle() 60 + return torch.polar(a.abs() + b.abs() * torch.cos(b.angle() - theta_a - phi), theta_a) 61 + 62 + def run(): 63 + run_dir = time.strftime(f"%d.%m.%Y/{name}_t%H.%M.%S") 64 + Path("out/" + run_dir).mkdir(parents=True, exist_ok=True) 65 + Path("out/" + run_dir + "/aggregate").mkdir(parents=True, exist_ok=True) 66 + Path("out/" + run_dir + "/frame").mkdir(parents=True, exist_ok=True) 67 + 68 + global span 69 + 70 + x_min = origin[0] - (span[0] / 2) 71 + y_min = origin[1] - (span[1] / 2) 72 + 73 + for ((xa, xb), (ya, yb)) in zooms: 74 + x_min += span[0] * xa 75 + y_min += span[1] * ya 76 + span = span[0] * (xb - xa), span[1] * (yb - ya) 77 + 78 + x_max = x_min + span[0] 79 + y_max = y_min + span[1] 80 + 81 + aspect = span[0] * stretch[0] / (span[1] * stretch[1]) 82 + 83 + if aspect < 1: 84 + h = scale 85 + w = int(scale * aspect) 86 + else: 87 + w = scale 88 + h = int(scale / aspect) 89 + 90 + x_range = (x_min, x_max) 91 + y_range = (y_min, y_max) 92 + region = (x_range, y_range) 93 + mapping = (region, (h,w)) 94 + 95 + p_positions = (torch.rand([particle_count], device=device, dtype=t_complex)) 96 + #p_positions.imag = torch.linspace(0, tau, particle_count) 97 + #p_positions.real = torch.linspace(0, 1, particle_count) 98 + #p_positions = torch.polar(p_positions.real, p_positions.imag) 99 + 100 + p_colors = torch.ones([particle_count,3], device=device, dtype=t_real) 101 + color_rotation = torch.linspace(0, tau / 4, particle_count) 102 + 103 + p_colors[:,0] = torch.frac((p_positions.real) * 1 / (span[0])) 104 + p_colors[:,2] = torch.frac((p_positions.imag) * 1 / (span[1])) 105 + 106 + p_positions.real *= 0.025 * (y_max - y_min) 107 + p_positions.real += y_min + (0.4875) * (y_max - y_min) 108 + p_positions.imag *= 0.025 * (x_max - x_min) * 4 109 + p_positions.imag += x_min + 0.501 * (x_max - x_min) 110 + 111 + #p_colors[:,0] = torch.cos(color_rotation) 112 + p_colors[:,1] = 1.0 - (p_colors[:,0] + p_colors[:,2])#0.1 113 + #p_colors[:,2] *= 0 114 + #p_colors[:,2] = torch.sin(color_rotation) 115 + 116 + canvas = torch.zeros([3, h, w], device=device, dtype=t_real) 117 + scratch = torch.zeros([h, w, 3], device=device, dtype=t_real) 118 + gridlines = torch.zeros([h,w], device=device) 119 + 120 + def project(p): 121 + return torch.view_as_real(p).permute(1,0) 122 + 123 + ones = torch.ones_like(p_positions) 124 + global direction 125 + direction = flow * torch.polar(ones.real, 1 * tau * ones.real) 126 + def next_positions(p, i): 127 + global direction 128 + #return blus(p, ones) - ones 129 + if randomize: 130 + direction = flow * torch.polar(ones.real, (random() * random_range - random_range / 2) * ones.real) 131 + result = blus(p, direction, phi) - direction * ebb 132 + res_abs = result.abs() 133 + if rescaling: 134 + result = torch.polar(res_abs / res_abs.max(), result.angle()) 135 + return result 136 + 137 + for i in range(10): 138 + frac = i / 10 139 + gridlines[math.floor(frac*h), :] = 1 140 + gridlines[:, math.floor(frac*w)] = 1 141 + 142 + def insert_at_coords(coords, values, target, mapping: region_mapping): 143 + (region, hw) = mapping 144 + (xrange, yrange) = region 145 + (h,w) = hw 146 + (x_min, x_max) = xrange 147 + (y_min, y_max) = yrange 148 + 149 + mask = torch.ones([particle_count], device=device) 150 + mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 151 + mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 152 + in_range = mask.nonzero().squeeze() 153 + 154 + # TODO: combine coord & value tensors so there's only one index_select necessary 155 + coords_filtered = torch.index_select(coords.permute(1,0), 0, in_range) 156 + values_filtered = torch.index_select(values, 0, in_range) 157 + 158 + coords_filtered[:,1] -= x_min 159 + coords_filtered[:,1] *= (w-1) / (x_max - x_min) 160 + coords_filtered[:,0] -= y_min 161 + coords_filtered[:,0] *= (h-1) / (y_max - y_min) 162 + indices = coords_filtered.long() 163 + 164 + target.index_put_((indices[:,0],indices[:,1]), values_filtered, accumulate=True) 165 + 166 + 167 + for iteration in range(iterations): 168 + p_projected = project(p_positions).clone() 169 + 170 + if iteration % 1 == 0: 171 + 172 + scratch *= 0 173 + insert_at_coords(p_projected, p_colors, scratch, mapping) 174 + canvas += scratch.permute(2,0,1) 175 + 176 + 177 + temp = canvas.clone() 178 + #for d in range(3): 179 + # temp[d] -= temp[d].mean() 180 + # temp[d] /= 8 * temp[d].std() 181 + # temp[d] -= temp[d].min() 182 + 183 + temp -= temp.min() 184 + temp = torch.log(temp) 185 + temp /= temp.max() 186 + 187 + if iteration % agg_every == 0: 188 + p_low, p_high = torch.quantile(temp[:,(2*h//5):(3*h//5),:], quantile_range) 189 + # temp[0] += gridlines 190 + # temp[1] += gridlines 191 + # temp[2] += gridlines 192 + 193 + if grid_on_agg: 194 + save((1 - temp).clamp_(0.0, 1.0) - gridlines, f"{run_dir}/aggregate/_{iteration:06d}") 195 + else: 196 + save((1 - temp).clamp_(0.0, 1.0), f"{run_dir}/aggregate/_{iteration:06d}") 197 + temp = (temp - p_low) / (1e-7 + p_high - p_low) 198 + #save(1 - temp, f"{run_dir}/aggregate/{iteration:06d}") 199 + #scratch /= scratch.max() 200 + #scratch = scratch.sqrt().sqrt() 201 + if show_gridlines: 202 + scratch[:,:,1] += gridlines 203 + scratch[:,:,2] += gridlines 204 + if iteration % save_every == 0: 205 + save(1 - scratch.permute(2,0,1), f"{run_dir}/frame/{iteration:06d}") 206 + if iteration % yell_every == 0: 207 + print(f"{iteration} iterations") 208 + 209 + p_positions = next_positions(p_positions, iteration) 210 + 211 + #for d in range(3): 212 + # canvas[d] -= canvas[d].mean() 213 + # canvas[d] /= 8 * canvas[d].std() 214 + # canvas[d] -= canvas[d].min() 215 + 216 + 217 + torch.set_default_device(device) 218 +
+396
sketch/sdxl.py
··· 1 + import math 2 + import gc 3 + 4 + import torch 5 + import numpy as np 6 + import torch.nn as nn 7 + from torch.nn import functional as func 8 + from PIL import Image 9 + 10 + from diffusers import UNet2DConditionModel 11 + import transformers 12 + 13 + from lib.diffusion.guidance import * 14 + from lib.diffusion.schedule import * 15 + from lib.diffusion.sdxl_encoder import PromptEncoder 16 + from lib.diffusion.sdxl_vae import Decoder, save_approx_decode 17 + 18 + from lib.ode import _euler_step as euler_step 19 + from lib.log import Timer 20 + #from lib.util import pilify 21 + 22 + ''' 23 + beware: this file is a mess, lots of things are broken and poorly named 24 + ''' 25 + 26 + def save_raw_latents(latents, path): 27 + lmin = latents.min() 28 + l = latents - lmin 29 + lmax = latents.max() 30 + l = latents / lmax 31 + l = l.float() * 127.5 + 127.5 32 + l = l.detach().cpu().numpy() 33 + l = l.round().astype("uint8") 34 + 35 + ims = [] 36 + 37 + for lat in l: 38 + row1 = np.concatenate([lat[0], lat[1]]) 39 + row2 = np.concatenate([lat[2], lat[3]]) 40 + grid = np.concatenate([row1, row2], axis=1) 41 + #for channel in lat: 42 + im = Image.fromarray(grid) 43 + im = im.resize(size=(grid.shape[1]*4, grid.shape[0]*4), resample=Image.NEAREST) 44 + ims += [im] 45 + 46 + for im in ims: 47 + im.save(path) 48 + 49 + 50 + def pilify(latents, vae): 51 + #latents = 1 / vae.config.scaling_factor * latents 52 + latents = 1 / 0.13025 * latents 53 + latents = latents.to(torch.float32)#vae.dtype) 54 + with torch.no_grad(): 55 + images = vae.decode(latents)#.sample 56 + 57 + images = images.detach().mul_(127.5).add_(127.5).clamp_(0,255).round() 58 + #return [images] 59 + images = images.permute(0,2,3,1).cpu().numpy().astype("uint8") 60 + return [Image.fromarray(image) for image in images] 61 + 62 + 63 + torch.set_grad_enabled(False) 64 + torch.backends.cuda.matmul.allow_tf32 = True 65 + torch.set_float32_matmul_precision("medium") 66 + 67 + def persistent(): 68 + model_path = "/ssd/0/ml_models/ql/hf-diff/stable-diffusion-xl-base-0.9" 69 + 70 + main_device = "cuda:0" 71 + decoder_device = "cuda:1" 72 + clip_device = "cuda:1" 73 + 74 + main_dtype = torch.float64 75 + noise_predictor_dtype = torch.float16 76 + decoder_dtype = torch.float32 77 + prompt_encoder_dtype = torch.float16 78 + 79 + with Timer("total"): 80 + with Timer("decoder"): 81 + decoder = Decoder() 82 + decoder.load_safetensors(model_path) 83 + decoder.to(device=decoder_device) 84 + #decoder = torch.compile(decoder, mode="default", fullgraph=True) 85 + 86 + with Timer("noise_predictor"): 87 + noise_predictor = UNet2DConditionModel.from_pretrained( 88 + model_path, subfolder="unet", torch_dtype=noise_predictor_dtype 89 + ) 90 + noise_predictor.to(device=main_device) 91 + 92 + # compilation will not actually happen until first use of noise_predictor 93 + # (as of torch 2.2.2) "default" provides the best result on my machine 94 + # don't use this if you're gonna be changing resolutions a lot 95 + #noise_predictor = torch.compile(noise_predictor, mode="default", fullgraph=True) 96 + 97 + with Timer("clip"): 98 + prompt_encoder = PromptEncoder(model_path, True, (clip_device, main_device), prompt_encoder_dtype) 99 + 100 + 101 + 102 + # TODO: these should come from config! 103 + vae_scale = 0.13025 104 + decoder_dim_scale = 2 ** 3 105 + variance_range = (0.00085, 0.012) 106 + 107 + 108 + seed = lambda run, meta: 2835723 + run 109 + 110 + meta_count = 1 111 + 112 + width, height = 16, 16 113 + 114 + steps = 50 115 + run_count = 10 116 + 117 + timestep_power = 1 118 + timestep_range = (0, 999) 119 + 120 + # TODO prompts are now in sketch/local/prompts, should be simple enough to fetch them 121 + 122 + empty_p = "" 123 + 124 + prompts = { 125 + "encoder_1": [p3, p3], 126 + "encoder_2": None, 127 + "encoder_2_pooled": None 128 + } 129 + _lerp = lambda t, a, b: t * b + (1-t) * a 130 + 131 + combine_predictions = lambda s,r: scaled_CFG( 132 + difference_scales = [ 133 + (1, -1, lambda x: x) 134 + #(1, 0, lambda x: _lerp(s/steps, 1.0, 0.3) * x), 135 + #(2, 0, lambda x: _lerp(s/steps, 0.3, 1.0) * x), 136 + ], 137 + steering_scale = lambda x: 1 * x, 138 + #base_term = lambda predictions, true_noise: predictions[0], 139 + base_term = lambda predictions, true_noise: true_noise, 140 + total_scale = lambda predictions, cfg_result: cfg_result 141 + ) 142 + 143 + # TNR 144 + _scale = lambda step, scale: _lerp(step / (steps - 1), scale / steps, 0.1 * scale / steps) 145 + a,b,c = -7, 10, 15 146 + #combine_predictions = lambda step, run: lambda p, n: _scale(step, a + c * run / (run_count-1)) * (n - p[0]) + _scale(step, b - c * run / (run_count-1)) * (n - p[1]) 147 + #combine_predictions = lambda step, run: lambda p, n: _scale(step, 4) * (n - p[1]) + _scale(step, 2.5) * (p[0] - n) 148 + combine_predictions = lambda step, run: lambda p, n: _scale(step, 3) * (n - p[0]) 149 + #combine_predictions = lambda step, run: lambda p, n: _scale(step, _lerp(step/steps, 0.0, 0.0)) * (n - p[0]) + _scale(step, 2) * (n - p[1]) 150 + 151 + 152 + diffusion_method = "tnr" 153 + 154 + solver_step = euler_step 155 + 156 + save_raw = lambda run, step: False 157 + save_approximates = lambda run, step: True 158 + save_final = True 159 + 160 + def pre_step(step, latents): 161 + return latents 162 + 163 + def post_step(step, latents): 164 + return latents 165 + 166 + def main(): 167 + 168 + schedule(meta_run, range(meta_count)) 169 + 170 + def meta_run(meta_id): 171 + forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta 172 + forward_noise_total = forward_noise_schedule.cumsum(dim=0) 173 + forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar 174 + partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t]) # alpha_bar_t / alpha_bar_s (but computed more directly from the forward noise) 175 + part_noise = (1 - forward_signal_product).sqrt() # sigma 176 + part_signal = forward_signal_product.sqrt() # mu? 177 + 178 + schedule(run, range(run_count)) 179 + 180 + def get_signal_ratio(from_timestep, to_timestep): 181 + if from_timestep < to_timestep: # forward 182 + return 1 / partial_signal_product(from_timestep, to_timestep).sqrt() 183 + else: # backward 184 + return partial_signal_product(to_timestep, from_timestep).sqrt() 185 + 186 + def step_by_noise(latents, noise, from_timestep, to_timestep): 187 + signal_ratio = get_signal_ratio(from_timestep, to_timestep) 188 + return latents / signal_ratio + noise * (part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio) 189 + 190 + def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep): 191 + signal_ratio = get_signal_ratio(from_timestep, to_timestep) 192 + return latents / signal_ratio + noise * (1 - 1 / signal_ratio) 193 + 194 + def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep): 195 + signal_ratio = get_signal_ratio(from_timestep, to_timestep) 196 + return latents / signal_ratio + base * part_noise[to_timestep] - combined * (part_noise[from_timestep] / signal_ratio) 197 + 198 + def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep): 199 + signal_ratio = get_signal_ratio(from_timestep, to_timestep) 200 + diff_coefficient = part_noise[from_timestep] / signal_ratio 201 + base_coefficient = part_noise[to_timestep] - diff_coefficient 202 + #print((1/signal_ratio).item(), base_coefficient.item(), diff_coefficient.item()) 203 + return latents / signal_ratio + base_term * base_coefficient + diff_term * diff_coefficient 204 + 205 + def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep): 206 + signal_ratio = get_signal_ratio(from_timestep, to_timestep) 207 + base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio 208 + measure = lambda x: x.abs().max().item() 209 + #print(measure(latents / signal_ratio), measure(base_term * base_coefficient), measure(diff_term)) 210 + return latents / signal_ratio + base_term * base_coefficient + diff_term 211 + 212 + def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep): 213 + if from_timestep < to_timestep: # forward 214 + signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt() 215 + else: # backward 216 + signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt() 217 + return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep]) 218 + 219 + def index_interpolate(source, index): 220 + frac, whole = math.modf(index) 221 + if frac == 0: 222 + return source[int(whole)] 223 + return lerp(source[int(whole)], source[int(whole)+1], frac) 224 + 225 + def run(run_id): 226 + try: 227 + _seed = int(seed(run_id, meta_id)) 228 + except: 229 + _seed = 0 230 + print(f"non-integer seed, run {run_id}. replaced with 0.") 231 + 232 + torch.manual_seed(_seed) 233 + np.random.seed(_seed) 234 + 235 + diffusion_timesteps = linspace_timesteps(steps+1, timestep_range[1], timestep_range[0], timestep_power) 236 + 237 + noise_predictor_batch_size = len(prompts["encoder_1"]) 238 + 239 + (all_penult_states, enc2_pooled) = prompt_encoder.encode(prompts["encoder_1"], prompts["encoder_2"], prompts["encoder_2_pooled"]) 240 + 241 + global width 242 + global height 243 + if (width < 64): width *= 64 244 + if (height < 64): height *= 64 245 + 246 + latents = torch.zeros( 247 + (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale), 248 + device=main_device, 249 + dtype=main_dtype 250 + ) 251 + 252 + noises = torch.randn( 253 + #(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale), 254 + (1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale), 255 + device=main_device, 256 + dtype=main_dtype 257 + ) 258 + 259 + latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0]) 260 + 261 + original_size = (height, width) 262 + target_size = (height, width) 263 + crop_coords_top_left = (0, 0) 264 + 265 + # incomprehensible var name tbh go read the sdxl paper if u want to Understand 266 + add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype).repeat(noise_predictor_batch_size,1).to("cuda") 267 + 268 + added_cond_kwargs = {"text_embeds": enc2_pooled.to(noise_predictor_dtype), "time_ids": add_time_ids} 269 + 270 + for step_index in range(steps): 271 + noise = noises[0] 272 + 273 + start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int() 274 + end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int() 275 + 276 + end_noise = part_noise[end_timestep] 277 + end_signal = part_signal[end_timestep] 278 + start_noise = part_noise[start_timestep] 279 + start_signal = part_signal[start_timestep] 280 + signal_ratio = get_signal_ratio(start_timestep, end_timestep) 281 + start = start_timestep 282 + end = end_timestep 283 + 284 + 285 + sigratio = get_signal_ratio(start_timestep, end_timestep) 286 + 287 + def predict_noise(latents, step=0): 288 + return noise_predictor( 289 + latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype), 290 + index_interpolate(diffusion_timesteps, step_index + step).round().int(), 291 + encoder_hidden_states=all_penult_states.to(noise_predictor_dtype), 292 + return_dict=False, 293 + added_cond_kwargs=added_cond_kwargs 294 + )[0] 295 + 296 + def standard_predictor(combiner): 297 + def _predict(latents, step=0): 298 + predictions = predict_noise(latents, step) 299 + return predictions, noise, combiner(predictions, noise) 300 + return _predict 301 + 302 + def constructive_predictor(combiner): 303 + def _predict(latents, step=0): 304 + noised = step_by_noise(latents, noise, 0, index_interpolate(diffusion_timesteps, step_index + step).round().int()) 305 + predictions = predict_noise(noised, step) 306 + return predictions, noise, combiner(latents, predictions, noise) 307 + return _predict 308 + 309 + 310 + def standard_diffusion_step(latents, noises, start, end): 311 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 312 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 313 + predictions, true_noise, combined_prediction = noises 314 + return step_by_noise(latents, combined_prediction, start_timestep, end_timestep) 315 + 316 + def stupid_simple_step(latents, noises, start, end): 317 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 318 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 319 + predictions, true_noise, combined_prediction = noises 320 + return stupid_simple_step_by_noise(latents, combined_prediction, start_timestep, end_timestep) 321 + 322 + def cfgpp_diffusion_step(choose_base, choose_combined): 323 + def _diffusion_step(latents, noises, start, end): 324 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 325 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 326 + return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), start_timestep, end_timestep) 327 + return _diffusion_step 328 + 329 + def tnr_diffusion_step(latents, noises, start, end): 330 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 331 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 332 + predictions, true_noise, combined_prediction = noises 333 + return tnr_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep) 334 + 335 + def tnrb_diffusion_step(latents, noises, start, end): 336 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 337 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 338 + predictions, true_noise, combined_prediction = noises 339 + return tnrb_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep) 340 + 341 + def constructive_step(latents, noises, start, end): 342 + start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int() 343 + end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int() 344 + predictions, true_noise, combined_prediction = noises 345 + return latents + combined_prediction 346 + 347 + def select_prediction(index): 348 + return lambda noises: noises[0][index] 349 + 350 + select_true_noise = lambda noises: noises[1] 351 + select_combined = lambda noises: noises[2] 352 + 353 + 354 + if diffusion_method == "standard": 355 + take_step = standard_diffusion_step 356 + if diffusion_method == "stupid": 357 + take_step = stupid_simple_step 358 + if diffusion_method == "cfg++": 359 + take_step = cfgpp_diffusion_step(select_prediction(0), select_combined) 360 + if diffusion_method == "tnr": 361 + take_step = tnr_diffusion_step 362 + if diffusion_method == "tnrb": 363 + take_step = tnrb_diffusion_step 364 + 365 + if diffusion_method == "cons": 366 + take_step = constructive_step 367 + get_derivative = constructive_predictor(combine_predictions) 368 + else: 369 + get_derivative = standard_predictor(combine_predictions(step_index, run_id)) 370 + 371 + solver = solver_step 372 + 373 + latents = pre_step(step_index, latents) 374 + 375 + latents = solver(get_derivative, take_step, latents) 376 + 377 + latents = post_step(step_index, latents) 378 + 379 + if step_index < steps - 1 and diffusion_method != "cons": 380 + pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1]) 381 + else: 382 + pred_original_sample = latents 383 + 384 + if save_raw(run_id, step_index): 385 + save_raw_latents(pred_original_sample, f"out/{run_dir}/{run_id}_raw_{step_index:03d}.png") 386 + if save_approximates(run_id, step_index): 387 + save_approx_decode(pred_original_sample, f"out/{run_dir}/{run_id}_approx_{step_index:03d}.png") 388 + 389 + if save_final: 390 + images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder) 391 + 392 + for im in images_pil: 393 + for n in range(len(images_pil)): 394 + images_pil[n].save(f"out/{run_dir}/{meta_id:05d}_{n}_{run_id:05d}_{step_index:05d}.png") 395 + 396 +
+120 -44
sketch/webserver.py
··· 6 6 7 7 #from lib import log 8 8 9 - # TODO make it upgrade to https instead of just fucking throwing exceptions 9 + MODE = "local" 10 10 11 - HOST = "0.0.0.0" 12 - PORT = 1313 13 - BASE_DIR = Path("./webui").resolve() 14 - USE_SSL = False 11 + if MODE == "local": 12 + HOST = "0.0.0.0" 13 + PORT, SSL_PORT = 1313, None 14 + BASE_DIR = Path("./webui").resolve() 15 + USE_SSL = False 16 + SSL_CERT = "/home/ponder/ponder/certs/cert.pem" 17 + SSL_KEY = "/home/ponder/ponder/certs/key.pem" 18 + elif MODE == "remote": 19 + HOST = "ponder.ooo" 20 + PORT, SSL_PORT = 80, 443 21 + BASE_DIR = Path("./webui").resolve() 22 + USE_SSL = True 23 + SSL_KEY="/etc/letsencrypt/live/ponder.ooo/privkey.pem" 24 + SSL_CERT="/etc/letsencrypt/live/ponder.ooo/fullchain.pem" 15 25 16 26 ROUTES = { 17 27 "/": "content/main.html", ··· 26 36 ".orb": "text/x-orb" 27 37 } 28 38 29 - SSL_CERT = "/home/ponder/ponder/certs/cert.pem" 30 - SSL_KEY = "/home/ponder/ponder/certs/key.pem" 31 39 32 40 def guess_mime_type(path): 33 41 return MIME_TYPES.get(Path(path).suffix, "application/octet-stream") ··· 35 43 def build_response(status_code, body=b"", content_type="text/plain"): 36 44 reason = { 37 45 200: "OK", 46 + 301: "Moved Permanently", 38 47 400: "Bad Request", 39 48 404: "Not Found", 40 49 405: "Method Not Allowed", ··· 48 57 f"\r\n" 49 58 ).encode() + body 50 59 51 - def route(req_path): 52 - if req_path == "/": 53 - return "content/main.html" 54 - if any(req_path.endswith(x) for x in [".html", ".png"]): 55 - return f"content/{req_path}" 56 - if any(req_path.endswith(x) for x in [".wgsl"]): 57 - return f"content/wgsl/{req_path}" 58 - if any(req_path.endswith(x) for x in [".orb"]): 59 - return f"content/orb/{req_path}" 60 - if req_path.endswith(".js"): 61 - return f"js/{req_path}" 62 - return req_path[1:] 60 + def route(path): 61 + path = path.relative_to('/') if path.is_absolute() else path 62 + 63 + if path == Path(): 64 + return BASE_DIR / "content/main.html" 65 + 66 + suffix = path.suffix 67 + 68 + if suffix in [".html", ".png"]: 69 + return BASE_DIR / Path("content") / path 70 + if suffix in [".wgsl"]: 71 + return BASE_DIR / Path("content/wgsl") / path 72 + if suffix in [".orb"]: 73 + return BASE_DIR / Path("content/orb") / path 74 + if suffix in [".js"]: 75 + return BASE_DIR / Path("js") / path 76 + return BASE_DIR / path 63 77 64 78 def handle_request(request_data): 65 79 try: ··· 71 85 if method != "GET": 72 86 return build_response(405, b"Method Not Allowed") 73 87 74 - req_path = urlparse(unquote(raw_path)).path 88 + req_path = Path(urlparse(unquote(raw_path)).path).resolve(strict=False) 75 89 76 - norm_path = os.path.normpath(req_path) 90 + routed_path = route(req_path).resolve(strict=False) 77 91 78 - if ".." in norm_path: 92 + if not routed_path.is_relative_to(BASE_DIR): 79 93 return build_response(400, b"Fuck You") 80 94 81 - file_path = route(norm_path) 82 - 83 - if not file_path: 84 - return build_response(404, b"Not Found") 85 - 86 - full_path = BASE_DIR / file_path 87 - if not full_path.exists(): 95 + if not routed_path.exists(): 88 96 return build_response(404, b"File not found") 89 97 90 - with open(full_path, "rb") as f: 98 + with open(routed_path, "rb") as f: 91 99 body = f.read() 92 - return build_response(200, body, guess_mime_type(file_path)) 100 + return build_response(200, body, guess_mime_type(routed_path)) 93 101 94 102 except Exception as e: 95 103 return build_response(500, f"Server error: {e}".encode()) 96 104 97 - def main(): 105 + 106 + 107 + def build_redirect_response(location): 108 + return ( 109 + f"HTTP/1.1 301 Moved Permanently\r\n" 110 + f"Location: {location}\r\n" 111 + f"Content-Length: 0\r\n" 112 + f"Connection: close\r\n" 113 + f"\r\n" 114 + ).encode() 115 + 116 + def handle_http_redirect(conn): 117 + request = conn.recv(4096) 118 + try: 119 + lines = request.decode().split("\r\n") 120 + if lines: 121 + method, raw_path, *_ = lines[0].split() 122 + path = urlparse(unquote(raw_path)).path 123 + else: 124 + path = "/" 125 + except: 126 + path = "/" 127 + redirect = build_redirect_response(f"https://{HOST}{path}") 128 + conn.sendall(redirect) 129 + 130 + def http_redirect_server(): 131 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: 132 + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 133 + server.bind((HOST, PORT)) 134 + server.listen() 135 + print(f"redirecting on {HOST}:{PORT}") 136 + while True: 137 + try: 138 + conn, addr = server.accept() 139 + except KeyboardInterrupt: 140 + break 141 + except: 142 + continue 143 + with conn: 144 + try: 145 + handle_http_redirect(conn) 146 + except: 147 + pass 148 + 149 + 150 + def ssl_main(): 98 151 ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 99 152 ssl_ctx.load_cert_chain(certfile=SSL_CERT, keyfile=SSL_KEY) 100 153 154 + #threading.Thread(target=http_redirect_server, daemon=True).start() 155 + 101 156 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: 102 157 server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 103 - server.bind((HOST, PORT)) 158 + server.bind((HOST, SSL_PORT)) 104 159 server.listen() 105 - print(f"Serving HTTPS on {HOST}:{PORT}") 160 + print(f"Serving HTTPS on {HOST}:{SSL_PORT}") 106 161 while True: 107 162 try: 108 163 conn, addr = server.accept() 109 - if USE_SSL: 110 - with ssl_ctx.wrap_socket(conn, server_side=True) as ssl_conn: 111 - request = ssl_conn.recv(4096) 112 - response = handle_request(request) 113 - ssl_conn.sendall(response) 114 - else: 115 - request = conn.recv(4096) 164 + conn.settimeout(10.0) 165 + with ssl_ctx.wrap_socket(conn, server_side=True) as ssl_conn: 166 + request = ssl_conn.recv(4096) 116 167 response = handle_request(request) 117 - conn.sendall(response) 168 + ssl_conn.sendall(response) 169 + except KeyboardInterrupt: 170 + raise 171 + except TimeoutError: 172 + continue 173 + except Exception as e: 174 + print(f"Error: {e}") 175 + exit() 176 + 177 + 178 + def main(): 179 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: 180 + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 181 + server.bind((HOST, PORT)) 182 + server.listen() 183 + print(f"Serving HTTP on {HOST}:{PORT}") 184 + while True: 185 + try: 186 + conn, addr = server.accept() 187 + request = conn.recv(4096) 188 + response = handle_request(request) 189 + conn.sendall(response) 118 190 except KeyboardInterrupt: 119 191 raise 120 - except: 192 + except Exception as e: 193 + print(e) 121 194 break 122 195 123 196 if __name__ == "__main__": 124 - main() 197 + if USE_SSL: 198 + ssl_main() 199 + else: 200 + main()
+1 -1
webui/content/orb/projective_shift.orb
··· 8 8 in 9 9 10 10 z_{n+1} = (r + cos(theta-phi))e^{i theta psi} - c 11 - where z_n = r cos theta 11 + where z_n = r e^{i theta} 12 12 } } 13 13 14 14 b{click and drag} to move around, b{scroll} to zoom in and out; b{press F} to go fullscreen
+134 -14
webui/content/wgsl/proj_shift.wgsl
··· 10 10 max_iter: u32, 11 11 escape_distance: f32, 12 12 psi: f32, 13 + d: f32, 14 + twist: f32, 15 + squoosh_x: f32, 16 + squoosh_y: f32 13 17 } 14 18 15 19 @group(0) @binding(0) var<uniform> uniforms: Uniforms; ··· 36 40 return vec2<f32>(new_mag * cos(x_angle * psi), new_mag * sin(x_angle * psi)); 37 41 } 38 42 39 - fn iterate_polar(x: vec2<f32>, phi: f32, psi: f32, c: f32) -> vec2<f32> { 43 + fn iterate_polar(x: vec2<f32>, phi: f32, psi: f32, c: f32, d: f32, cosTwist: f32, sinTwist: f32, squoosh_x: f32, squoosh_y: f32) -> vec2<f32> { 40 44 let shifted = projective_shift(x, phi, psi); 41 - return vec2<f32>(shifted.x - c, shifted.y); 45 + let halfEbb = vec2<f32>(-c * 0.5, -d * 0.5); 46 + let mid = shifted + halfEbb; 47 + let mids = vec2<f32>(mid.x / squoosh_x, mid.y / squoosh_y); 48 + let rotated = vec2<f32>(mids.x * cosTwist - mids.y * sinTwist, mids.x * sinTwist + mids.y * cosTwist); 49 + return rotated + halfEbb; 42 50 } 43 51 44 52 fn iterate_cartesian(z: vec2<f32>, phi: f32, c: f32) -> vec2<f32> { ··· 68 76 return vec2<f32>(x, y); 69 77 } 70 78 79 + fn pixel_to_complex_alt(px: u32, py: u32) -> vec2<f32> { 80 + let aspect = f32(uniforms.width) / f32(uniforms.height); 81 + let scale = 4.0 / uniforms.zoom; 82 + let half_width = f32(uniforms.width) * 0.5; 83 + let half_height = f32(uniforms.height) * 0.5; 84 + 85 + // map y-axis to magnitude, x-axis to angle 86 + let angle = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x; 87 + let mag = (f32(py) - half_height) * scale / (2.0 * f32(uniforms.height)) + uniforms.center_y; 88 + 89 + return vec2<f32>(mag * cos(angle), mag * sin(angle)); 90 + } 91 + 92 + fn pixel_to_complex_inverted(px: u32, py: u32) -> vec2<f32> { 93 + let aspect = f32(uniforms.width) / f32(uniforms.height); 94 + let scale = 4.0 / uniforms.zoom; 95 + let half_width = f32(uniforms.width) * 0.5; 96 + let half_height = f32(uniforms.height) * 0.5; 97 + 98 + let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x; 99 + let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y; 100 + 101 + let z = vec2<f32>(x, y); 102 + let mag_sq = x * x + y * y; 103 + 104 + // handle singularity at origin 105 + if (mag_sq < 1e-10) { 106 + return vec2<f32>(1e10, 0.0); // or whatever you want infinity to map to 107 + } 108 + 109 + return vec2<f32>(x / mag_sq, -y / mag_sq); 110 + } 111 + 112 + fn pixel_to_complex_inverted_d(px: u32, py: u32) -> vec2<f32> { 113 + let aspect = f32(uniforms.width) / f32(uniforms.height); 114 + let scale = 4.0 / uniforms.zoom; 115 + let half_width = f32(uniforms.width) * 0.5; 116 + let half_height = f32(uniforms.height) * 0.5; 117 + 118 + // get position relative to center 119 + let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width); 120 + let y = (f32(py) - half_height) * scale / f32(uniforms.height); 121 + 122 + let mag_sq = x * x + y * y; 123 + if (mag_sq < 1e-10) { 124 + return vec2<f32>(uniforms.center_x + 1e10, uniforms.center_y); 125 + } 126 + 127 + // invert then add center back 128 + return vec2<f32>( 129 + uniforms.center_x + x / mag_sq, 130 + uniforms.center_y - y / mag_sq 131 + ); 132 + } 133 + 134 + fn pixel_to_complex_inverted_b(px: u32, py: u32) -> vec2<f32> { 135 + let aspect = f32(uniforms.width) / f32(uniforms.height); 136 + let scale = 4.0 / uniforms.zoom; 137 + let half_width = f32(uniforms.width) * 0.5; 138 + let half_height = f32(uniforms.height) * 0.5; 139 + 140 + let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x; 141 + let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y; 142 + 143 + // translate by (c/2, d/2) 144 + let tx = x + uniforms.c * 0.5; 145 + let ty = y + uniforms.d * 0.5; 146 + 147 + let mag_sq = tx * tx + ty * ty; 148 + 149 + if (mag_sq < 1e-10) { 150 + return vec2<f32>(1e10, 0.0); 151 + } 152 + 153 + // invert, then translate back 154 + return vec2<f32>(tx / mag_sq - uniforms.c * 0.5, -ty / mag_sq - uniforms.d * 0.5); 155 + } 156 + 157 + fn pixel_to_complex_inverted_c(px: u32, py: u32) -> vec2<f32> { 158 + let aspect = f32(uniforms.width) / f32(uniforms.height); 159 + let scale = 4.0 / uniforms.zoom; 160 + let half_width = f32(uniforms.width) * 0.5; 161 + let half_height = f32(uniforms.height) * 0.5; 162 + 163 + let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x; 164 + let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y; 165 + 166 + // translate by (c/2, d/2), then scale down by 0.5 167 + let tx = (x + uniforms.c * 0.5) * 2.0; 168 + let ty = (y + uniforms.d * 0.5) * 2.0; 169 + 170 + let mag_sq = tx * tx + ty * ty; 171 + 172 + if (mag_sq < 1e-10) { 173 + return vec2<f32>(1e10, 0.0); 174 + } 175 + 176 + // invert, scale up by 2x, then translate back 177 + return vec2<f32>((tx / mag_sq) * 0.5 - uniforms.c * 0.5, (-ty / mag_sq) * 0.5 - uniforms.d * 0.5); 178 + } 179 + 71 180 fn c_avg(prev: f32, val: f32, n: f32) -> f32 { 72 181 return prev + (val - prev) / n; 73 182 } ··· 80 189 return x; 81 190 } 82 191 192 + fn pcg_hash(x: u32) -> u32 { 193 + let state = x * 747796405u + 2891336453u; 194 + let word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; 195 + return (word >> 22u) ^ word; 196 + } 197 + 83 198 fn random_float(seed: u32) -> f32 { 84 199 return f32(hash(seed)) / 4294967296.0; // 2^32 85 200 } ··· 93 208 return; 94 209 } 95 210 96 - var z = pixel_to_complex(px, py); 97 - let n_perturbations = 10u; 211 + var z = ${pixel_mapping}(px, py); 212 + let orig_z = z; 213 + let n_perturbations = 1u; 98 214 let escape_threshold = uniforms.escape_distance * uniforms.escape_distance; 99 215 100 216 var escaped = false; 101 217 var iter = 0u; 102 218 var cavg = f32(0.0); 103 - let epsilon = 1.19e-4; 219 + let epsilon = 1.19e-6; 220 + 221 + let transient_skip = u32(f32(uniforms.max_iter) * f32(0.95)); 104 222 105 - let transient_skip = u32(f32(uniforms.max_iter) * f32(0.9)); 223 + let cosTwist = cos(uniforms.twist); 224 + let sinTwist = sin(uniforms.twist); 106 225 107 226 for (iter = 0u; iter < uniforms.max_iter; iter = iter + 1u) { 108 227 let mag_sq = complex_mag_sq(z); ··· 110 229 escaped = true; 111 230 //break; 112 231 } 113 - var z_p: array<vec2<f32>, 4>; 114 - var df_z_p: array<vec2<f32>, 4>; 115 - var abs_df_z_p: array<f32, 4>; 232 + var z_p: array<vec2<f32>, 10>; 233 + var df_z_p: array<vec2<f32>, 10>; 234 + var abs_df_z_p: array<f32, 10>; 116 235 var avg_abs_df_z_p = f32(0.0); 117 236 118 - var f_z = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c); 237 + var f_z = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c, uniforms.d, cosTwist, sinTwist, uniforms.squoosh_x, uniforms.squoosh_y); 119 238 120 239 if (iter >= transient_skip) { 121 240 if (cavg < -10.0) { continue; } 122 241 if (cavg > 10.0) { continue; } 123 242 for (var p = 0u; p < n_perturbations; p = p + 1u) { 124 - let rng_seed = hash(px * 1000u + py * 2000u + p * 5000u); 243 + let rng_seed = pcg_hash(px + pcg_hash(py << 1)) ^ pcg_hash(p); 125 244 let random_angle = random_float(rng_seed) * 6.283185307; 126 245 let perturb_direction = vec2<f32>(cos(random_angle), sin(random_angle)); 127 246 z_p[p] = z + perturb_direction * epsilon; 128 - df_z_p[p] = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c) - f_z; 129 - abs_df_z_p[p] = abs(complex_mag(df_z_p[p]) / epsilon); 247 + df_z_p[p] = iterate_polar(z_p[p], uniforms.phi, uniforms.psi, uniforms.c, uniforms.d, cosTwist, sinTwist, uniforms.squoosh_x, uniforms.squoosh_y) - orig_z; 248 + abs_df_z_p[p] = abs(complex_mag(df_z_p[p])); 130 249 avg_abs_df_z_p += abs_df_z_p[p]; 131 250 } 132 251 ··· 136 255 } 137 256 138 257 z = f_z; 258 + //z = avg_abs_df_z_p; 139 259 } 140 260 141 261 var color: vec4<f32>; ··· 156 276 color = vec4<f32>(r, scaled_mag, b, 1.0); 157 277 } 158 278 159 - color = vec4<f32>(cavg / 1.0, -cavg / 20.0, -cavg / 20.0, 1.0); 279 + color = vec4<f32>(cavg / 10.0, -cavg / 5.0, -cavg / 5.0, 1.0); 160 280 161 281 textureStore(output_texture, vec2<i32>(i32(px), i32(py)), color); 162 282 }
+17
webui/js/control/button.js
··· 1 + 2 + $css(` 3 + 4 + `); 5 + 6 + const defaults = { 7 + label: "button", 8 + action: () => {} 9 + }; 10 + 11 + export async function main(target, spec) { 12 + spec = { ...defaults, ...spec }; 13 + 14 + // todo 15 + } 16 + 17 +
+57 -27
webui/js/control/menu.js
··· 4 4 position: absolute; 5 5 margin: 0; 6 6 padding: 0; 7 + top: 0; 8 + left: 0; 7 9 border: none; 8 10 width: 100%; 9 11 height: 100%; 10 12 display: none; 13 + pointer-events: none; 11 14 } 12 15 13 16 .context-menu { ··· 19 22 font-size: 0.875rem; 20 23 user-select: none; 21 24 z-index: 10; 25 + pointer-events: auto; 22 26 } 23 27 24 - .context-menu[centered] { 28 + .context-menu[data-centered] { 25 29 position: absolute; 26 30 left: 50%; 27 31 top: 50%; ··· 100 104 menu.setAttribute("role", "menu"); 101 105 menu.setAttribute("aria-orientation", "vertical"); 102 106 103 - menu.addEventListener("mouseenter", () => { 104 - //menu.firstChild?.blur(); 105 - menu.focus(); 106 - }); 107 - 108 - const onBackdropClick = (e) => { 109 - if (e.target !== backdrop) return; 110 - e.preventDefault(); 107 + menu.addEventListener("focusout", (e) => { 108 + if (menu.contains(e.relatedTarget)) return; 111 109 112 110 backdrop.style.display = "none"; 113 111 menu.$.previousFocus?.focus(); 114 - 115 - // don't make user click twice when clicking away from the context menu 116 - const clickTarget = document.elementFromPoint(e.clientX, e.clientY); 117 - if (clickTarget) { 118 - clickTarget.focus(); 119 - clickTarget.dispatchEvent(new MouseEvent(e.type, { 120 - bubbles: true, 121 - cancelable: true, 122 - clientX: e.clientX, 123 - clientY: e.clientY 124 - })); 125 - } 126 - }; 127 - 128 - backdrop.addEventListener("click", onBackdropClick); 129 - backdrop.addEventListener("contextmenu", onBackdropClick); 112 + }); 130 113 131 114 menu.addEventListener("keydown", (e) => { 132 115 if (!["ArrowDown", "ArrowUp", "j", "k", "Escape"].includes(e.key)) return; ··· 161 144 const showMenu = (target, position = null) => { 162 145 document.body.appendChild(backdrop); 163 146 backdrop.style.display = "block"; 147 + 164 148 menu.$.previousFocus = document.activeElement; 165 149 menu.firstChild?.focus(); 166 150 167 151 const bounds = target.getBoundingClientRect(); 168 152 169 153 if (!position) { 170 - menu.setAttribute("centered", ""); 154 + menu.dataset.centered = ""; 171 155 menu.style.left = ""; 172 156 menu.style.top = ""; 173 157 return; ··· 175 159 176 160 const {x,y} = position; 177 161 178 - menu.removeAttribute("centered"); 162 + delete menu.dataset.centered; 179 163 menu.style.left = x + "px"; 180 164 menu.style.top = y + "px"; 181 165 ··· 221 205 const select = async () => { 222 206 backdrop.style.display = "none"; 223 207 menu.$.previousFocus?.focus(); 208 + 224 209 await item[1](); 225 210 }; 226 211 ··· 240 225 showMenu(e.target, {x: e.clientX, y: e.clientY}); 241 226 }); 242 227 228 + document.$showMenu = (target) => { 229 + menu.replaceChildren(); 230 + 231 + const items = collectItems(target); 232 + 233 + if (items.length === 0) return; 234 + 235 + items.forEach(item => { 236 + if (!item) return; 237 + if (item === "separator") { // TODO improve this 238 + const separator = document.createElement("div"); 239 + separator.className = "context-menu-separator"; 240 + menu.appendChild(separator); 241 + return; 242 + } 243 + 244 + const menuItem = document.createElement("button"); 245 + menuItem.className = "context-menu-item"; 246 + menu.setAttribute("role", "menuItem"); 247 + menu.setAttribute("tabIndex", "-1"); 248 + 249 + menuItem.textContent = item[0]; 250 + 251 + const select = async () => { 252 + backdrop.style.display = "none"; 253 + menu.$.previousFocus?.focus(); 254 + 255 + console.log("sel"); 256 + await item[1](); 257 + }; 258 + 259 + menuItem.onclick = select; 260 + menuItem.addEventListener("keydown", (e) => { 261 + if (e.key === "o" || e.key === "Enter") { 262 + select(); 263 + e.stopPropagation(); 264 + } 265 + }); 266 + 267 + menu.appendChild(menuItem); 268 + }); 269 + 270 + showMenu(target); 271 + }; 272 +
+21 -6
webui/js/control/number.js
··· 136 136 const copyable_value = document.createElement("span"); 137 137 copyable_value.innerText = `${spec.value};\n`; 138 138 copyable_value.classList = "copyable-value"; 139 - label_eq.appendChild(copyable_value); 140 139 141 140 label_eq.setAttribute("aria-hidden", true); 142 141 ··· 158 157 field.step = spec.step; 159 158 field.value = spec.value; 160 159 160 + const play_button = $element("button"); 161 + play_button.innerText = "โ–ถ/โธ"; 162 + // todo alt text 163 + 164 + 161 165 const set = (value) => { 162 166 slider.value = value; 163 167 field.value = value; 164 168 } 165 169 170 + const reset_button = $element("button"); 171 + reset_button.innerText = "โŸณ"; 172 + reset_button.label = "reset"; 173 + reset_button.addEventListener("click", () => { 174 + set(spec.value); 175 + spec.onUpdate?.(spec.value, set); 176 + }); 177 + 178 + 166 179 slider.addEventListener("input", () => { 167 180 field.value = slider.value; 168 181 copyable_value.innerText = slider.value; ··· 175 188 spec.onUpdate?.(field.value, set); 176 189 }); 177 190 178 - control.appendChild(label); 179 - control.appendChild(label_eq); 180 - control.appendChild(field); 181 - control.appendChild(slider); 182 - target.appendChild(control); 191 + target.$with( 192 + control.$with( 193 + label, label_eq.$with(copyable_value), field, 194 + play_button, reset_button, 195 + slider 196 + ) 197 + ); 183 198 } 184 199
+6 -1
webui/js/control/panel.js
··· 9 9 background: var(--main-background); 10 10 font-family: var(--main-font); 11 11 color: var(--main-solid); 12 + overflow-y: scroll; 13 + height: 100%; 14 + width: 100%; 15 + } 16 + 17 + .control-panel > * { 12 18 max-width: 300px; 13 - overflow-y: scroll; 14 19 } 15 20 16 21 .control-panel legend {
+83 -10
webui/js/gpu/proj_shift.js
··· 24 24 let phi = 0.0; 25 25 let psi = 1.0; 26 26 let c = 0.85; 27 + let d = 0.0; 27 28 let iterations = 100; 28 29 let escape_distance = 2; 29 30 let centerX = -0.5; 30 31 let centerY = 0.0; 31 32 let zoom = 4.0; 33 + let twist = 0.0; 34 + let squoosh_x = 1.0; 35 + let squoosh_y = 1.0; 32 36 33 37 let showTrajectory = false; 34 38 ··· 51 55 }; 52 56 } 53 57 54 - function iteratePolar(x, phi, psi, c) { 58 + function iteratePolar(x, phi, psi, c, d) { 55 59 const shifted = projectiveShift(x, phi, psi); 56 - return { x: shifted.x - c, y: shifted.y }; 60 + return { x: shifted.x - c, y: shifted.y - d }; 57 61 } 58 62 59 63 function computeTrajectory(startZ, maxIters, escapeThreshold) { ··· 65 69 if (magSq > escapeThreshold * escapeThreshold) { 66 70 break; 67 71 } 68 - z = iteratePolar(z, phi, psi, c); 72 + z = iteratePolar(z, phi, psi, c, d); 69 73 trajectory.push({ x: z.x, y: z.y }); 70 74 } 71 75 ··· 98 102 }, 99 103 { 100 104 type: "number", 105 + label: "d", 106 + value: d, 107 + min: 0, 108 + max: 1, 109 + step: 0.001, 110 + onUpdate: (value, set) => { 111 + d = value; 112 + render(); 113 + } 114 + }, 115 + { 116 + type: "number", 101 117 label: greek["psi"], 102 118 value: psi, 103 119 min: 0, ··· 110 126 }, 111 127 { 112 128 type: "number", 129 + label: "twist", 130 + value: twist, 131 + min: -$tau/2, 132 + max: $tau/2, 133 + step: 0.001, 134 + onUpdate: (value, set) => { 135 + twist = value; 136 + render(); 137 + } 138 + }, 139 + { 140 + type: "number", 141 + label: "squoosh x", 142 + value: squoosh_x, 143 + min: 0, 144 + max: 10, 145 + step: 0.001, 146 + onUpdate: (value, set) => { 147 + squoosh_x = value; 148 + render(); 149 + } 150 + }, 151 + { 152 + type: "number", 153 + label: "squoosh y", 154 + value: squoosh_y, 155 + min: 0, 156 + max: 10, 157 + step: 0.001, 158 + onUpdate: (value, set) => { 159 + squoosh_y = value; 160 + render(); 161 + } 162 + }, 163 + { 164 + type: "number", 113 165 label: greek["phi"], 114 166 value: phi, 115 167 min: 0, ··· 151 203 ]]); 152 204 153 205 const renderStack = $div("full"); 206 + renderStack.dataset.name = "renderer"; 154 207 renderStack.style.position = "relative"; 155 208 156 209 const gpuModule = await $mod("gpu/webgpu", renderStack); ··· 162 215 canvas.setAttribute("role", "application"); 163 216 canvas.setAttribute("aria-keyshortcuts", "f"); 164 217 165 - const compShader = await $gpu.loadShader("proj_shift"); 218 + const compShader = await $gpu.loadShader("proj_shift", { "pixel_mapping" : "pixel_to_complex" }); 166 219 const blitShader = await $gpu.loadShader("blit"); 167 220 168 221 if (!compShader || !blitShader) return; ··· 205 258 renderStack.appendChild(overlay); 206 259 207 260 function showControls() { 208 - if (controls.parentNode) return; 261 + if (!topmost.isConnected) { 262 + topmost = renderStack; 263 + } 264 + else if (topmost.querySelector(".control-panel")) return; 209 265 210 266 return ["show controls", async () => { 211 - await $mod("layout/split", renderStack.parentNode, [{content: [controls, renderStack], percents: [20, 80]}]); 267 + const split = await $mod("layout/split", renderStack.parentNode, [{content: [controls, renderStack], percents: [20, 80]}]); 268 + topmost = split.topmost; 212 269 }]; 213 270 } 214 271 ··· 217 274 return ["show trajectory", () => {showTrajectory = true}]; 218 275 } 219 276 277 + function exitRenderer() { 278 + // relying on showControls' topmost check to have occurred before this can be called, 279 + // which is true bc that happens while the context menu is built. 280 + // this may not remain true if a hotkey is added for exiting w/o opening the context menu 281 + const target = topmost.parentNode; 282 + target.replaceChildren(); 283 + $mod("layout/nothing", target); 284 + } 285 + 220 286 renderStack.$preventCollapse = true; 221 287 renderStack.$contextMenu = { 222 - items: [showControls, toggleTrajectory] 288 + items: [showControls, toggleTrajectory, ["exit", exitRenderer]] 223 289 }; 224 290 225 - await $mod("layout/split", target, [{ content: [controls, renderStack], percents: [20, 80]}]); 291 + const split = await $mod("layout/split", target, [{ content: [controls, renderStack], percents: [20, 80]}]); 292 + let topmost = split.topmost; 226 293 227 294 let width = canvas.clientWidth; 228 295 let height = canvas.clientHeight; ··· 236 303 }); 237 304 238 305 const uniformBuffer = $gpu.device.createBuffer({ 239 - size: 10 * 4, 306 + size: 14 * 4, 240 307 usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, 241 308 }); 242 309 ··· 307 374 308 375 function updateUniforms() { 309 376 // TODO: save the same buffer 310 - const uniformData = new ArrayBuffer(10 * 4); 377 + const uniformData = new ArrayBuffer(14 * 4); 311 378 const view = new DataView(uniformData); 312 379 313 380 /* ··· 321 388 max_iter: u32, 322 389 escape_distance: f32, 323 390 psi: f32, 391 + d: f32, 392 + twist: f32, 324 393 */ 325 394 326 395 // TODO less brittle hard coded numbers jfc ··· 334 403 view.setUint32(28, iterations, true); 335 404 view.setFloat32(32, escape_distance, true); 336 405 view.setFloat32(36, psi, true); 406 + view.setFloat32(40, d, true); 407 + view.setFloat32(44, twist, true); 408 + view.setFloat32(48, squoosh_x, true); 409 + view.setFloat32(52, squoosh_y, true); 337 410 338 411 $gpu.device.queue.writeBuffer(uniformBuffer, 0, uniformData); 339 412 }
+15 -3
webui/js/gpu/webgpu.js
··· 31 31 const canvasFormat = navigator.gpu.getPreferredCanvasFormat(); 32 32 33 33 // TODO make more robust 34 - async function loadShader(shaderName) { 34 + async function loadShader(shaderName, substitutions = {}) { 35 35 const response = await fetch(`./${shaderName}.wgsl`); 36 36 const shaderSource = await response.text(); 37 - const module = device.createShaderModule({ code: shaderSource }); 37 + 38 + var substitutionFailure = false; 39 + const adjustedSource = shaderSource.replace(/\${(\w+)}/g, (match, key) => { 40 + if (!(key in substitutions)) { 41 + substitutionFailure = true; 42 + return ""; 43 + } 44 + else { 45 + return substitutions[key]; 46 + } 47 + }); 48 + 49 + const module = device.createShaderModule({ code: adjustedSource }); 38 50 const info = await module.getCompilationInfo(); 39 - if (info.messages.some(m => m.type === "error")) { 51 + if (info.messages.some(m => m.type === "error") || substitutionFailure) { 40 52 return null; 41 53 } 42 54 return module;
+2 -2
webui/js/layout/nothing.js
··· 68 68 items: Object.entries(menuItems) 69 69 }; 70 70 71 - const menu = await $mod("control/menu", backdrop, [menuItems]); 71 + //const menu = await $mod("control/menu", backdrop, [menuItems]); 72 72 73 73 backdrop.addEventListener("keydown", (e) => { 74 74 if (e.key === "Enter" || e.key === "o") { ··· 78 78 const centerX = rect.left + rect.width / 2; 79 79 const centerY = rect.top + rect.height / 2; 80 80 81 - menu.showMenu(); 81 + document.$showMenu(backdrop); 82 82 } 83 83 }); 84 84
+17 -16
webui/js/layout/split.js
··· 9 9 padding: 0rem; 10 10 } 11 11 12 - [theme-changed] > .split { 12 + [data-theme-changed] > .split { 13 13 padding: 0.5rem; 14 14 } 15 15 16 - .split[orientation=row] { 16 + .split[data-orientation=row] { 17 17 flex-direction: row; 18 18 } 19 19 20 - .split[orientation=col] { 20 + .split[data-orientation=col] { 21 21 flex-direction: column; 22 22 } 23 23 ··· 32 32 -webkit-user-drag: none; 33 33 } 34 34 35 - .split[orientation=row] > .splitter { 35 + .split[data-orientation=row] > .splitter { 36 36 width: 1px; 37 37 cursor: col-resize; 38 38 } 39 39 40 - .split[orientation=col] > .splitter { 40 + .split[data-orientation=col] > .splitter { 41 41 height: 1px; 42 42 cursor: row-resize; 43 43 } ··· 49 49 pointer-events: auto; 50 50 } 51 51 52 - .split[orientation=row] > .splitter::before { 52 + .split[data-orientation=row] > .splitter::before { 53 53 top: 0; 54 54 left: calc(0px - var(--panel-margin)); 55 55 width: calc(var(--panel-margin) * 2); 56 56 height: 100%; 57 57 } 58 58 59 - .split[orientation=col] > .splitter::before { 59 + .split[data-orientation=col] > .splitter::before { 60 60 left: 0; 61 61 top: calc(0px - var(--panel-margin)); 62 62 height: calc(var(--panel-margin) * 2); 63 63 width: 100%; 64 64 } 65 65 66 - .split[orientation=row] > :first-child { 66 + .split[data-orientation=row] > :first-child { 67 67 margin-right: var(--panel-margin); 68 68 width: calc(var(--current-portion) - 0.5px - var(--panel-margin)); 69 69 } 70 70 71 - .split[orientation=col] > :first-child { 71 + .split[data-orientation=col] > :first-child { 72 72 margin-bottom: var(--panel-margin); 73 73 height: calc(var(--current-portion) - 0.5px - var(--panel-margin)); 74 74 } 75 75 76 - .split[orientation=row] > :not(.splitter):not(:first-child):not(:last-child) { 76 + .split[data-orientation=row] > :not(.splitter):not(:first-child):not(:last-child) { 77 77 margin-left: var(--panel-margin); 78 78 margin-right: var(--panel-margin); 79 79 width: calc(var(--current-portion) - 1px - 2 * var(--panel-margin)); 80 80 } 81 81 82 - .split[orientation=col] > :not(.splitter):not(:first-child):not(:last-child) { 82 + .split[data-orientation=col] > :not(.splitter):not(:first-child):not(:last-child) { 83 83 margin-top: var(--panel-margin); 84 84 margin-bottom: var(--panel-margin); 85 85 height: calc(var(--current-portion) - 1px - 2 * var(--panel-margin)); 86 86 } 87 87 88 - .split[orientation=row] > :last-child { 88 + .split[data-orientation=row] > :last-child { 89 89 margin-left: var(--panel-margin); 90 90 width: calc(var(--current-portion) - 0.5px - var(--panel-margin)); 91 91 } 92 92 93 - .split[orientation=col] > :last-child { 93 + .split[data-orientation=col] > :last-child { 94 94 margin-top: var(--panel-margin); 95 95 height: calc(var(--current-portion) - 0.5px - var(--panel-margin)); 96 96 } ··· 146 146 var n = content.length; 147 147 148 148 const container = $div("split"); 149 - container.setAttribute("orientation", settings.orientation); 149 + container.dataset.orientation = settings.orientation; 150 150 var row = settings.orientation === "row"; 151 151 152 152 const orientationToggle = [row ? "row->col" : "col->row", () => { 153 153 row = !row; 154 154 settings.orientation = row ? "row" : "col"; 155 - container.setAttribute("orientation", settings.orientation); 155 + container.dataset.orientation = settings.orientation; 156 156 orientationToggle[0] = row ? "row->col" : "col->row"; 157 157 }]; 158 158 ··· 347 347 }; 348 348 349 349 return { 350 - replace: true 350 + replace: true, 351 + topmost: container 351 352 }; 352 353 } 353 354
+35
webui/js/main.js
··· 69 69 enumerable: false 70 70 }); 71 71 72 + 73 + Object.defineProperty(Element.prototype, "$attrs", { 74 + get() { 75 + if (this.__attrsProxy) return this.__attrsProxy; 76 + 77 + const el = this; 78 + this.__attrsProxy = new Proxy({}, { 79 + get(_, prop) { 80 + return el.getAttribute(prop.toString()); 81 + }, 82 + set(_, prop, val) { 83 + el.setAttribute(prop.toString(), val); 84 + return true; 85 + }, 86 + deleteProperty(_, prop) { 87 + el.removeAttribute(prop.toString()); 88 + return true; 89 + }, 90 + has(_, prop) { 91 + return el.hasAttribute(prop.toString()); 92 + }, 93 + ownKeys() { 94 + return Array.from(el.attributes).map(a => a.name); 95 + }, 96 + getOwnPropertyDescriptor(_, prop) { 97 + if (el.hasAttribute(prop.toString())) { 98 + return { configurable: true, enumerable: true, value: el.getAttribute(prop.toString()) }; 99 + } 100 + return undefined; 101 + } 102 + }); 103 + }, 104 + enumerable: false 105 + }); 106 + 72 107 window.$element = (name) => document.createElement(name); 73 108 window.$div = function (classList = "") { 74 109 const div = $element("div");
+7 -9
webui/js/theme.js
··· 2 2 function checkForParentTheme(element, theme) { 3 3 let parent = element.parentElement; 4 4 while (parent) { 5 - //if (parent.classList.contains("target")) { 6 - const parentTheme = parent.getAttribute("theme"); 7 - //if (!parentTheme) return false; 8 - if (parentTheme) 9 - return parentTheme !== theme; 10 - //} 5 + const parentTheme = parent.dataset.theme; 6 + 7 + if (parentTheme) return parentTheme !== theme; 8 + 11 9 parent = parent.parentElement; 12 10 } 13 11 ··· 17 15 export function main(target, initialTheme = null) { 18 16 const storedTheme = localStorage.getItem("theme"); 19 17 let theme = initialTheme || storedTheme || "blackboard"; 20 - target.setAttribute("theme", theme); 18 + target.dataset.theme = theme; 21 19 22 20 if (target === document.body) { 23 21 localStorage.setItem("theme", theme); 24 22 } 25 23 26 24 if (checkForParentTheme(target, theme)) { 27 - target.setAttribute("theme-changed", ""); 25 + target.dataset.themeChanged = ""; 28 26 } else { 29 - target.removeAttribute("theme-changed"); 27 + delete target.dataset.themeChanged; 30 28 } 31 29 32 30 return { replace: false };
+2 -2
webui/style/theme/themes.css
··· 1 1 2 - *[theme=whiteboard] { 2 + *[data-theme=whiteboard] { 3 3 --main-background: #EEDDEE; 4 4 --main-solid: #112233; 5 5 --main-transparent: #112233DD; ··· 42 42 --token-css-delimiter: var(--main-solid); 43 43 } 44 44 45 - *[theme=blackboard] { 45 + *[data-theme=blackboard] { 46 46 --main-background: #112233; 47 47 --main-solid: #EEDDEE; 48 48 --main-transparent: #EEDDEEDD;