+1
.gitignore
+1
.gitignore
+161
sketch/gamma.py
+161
sketch/gamma.py
···
1
+
2
+
import math
3
+
import torch
4
+
from pathlib import Path
5
+
6
+
from lib.util import *
7
+
from lib.spaces import insert_at_coords, map_space, cgrid
8
+
9
+
name = "gamma"
10
+
11
+
device = "cuda"
12
+
torch.set_default_device(device)
13
+
14
+
t_real = torch.double
15
+
t_complex = torch.cdouble
16
+
tau = 2 * math.pi
17
+
18
+
# renderer config
19
+
scale_power = 12
20
+
scale = 2 ** scale_power
21
+
origin = 0, 0
22
+
span = 15, 15
23
+
stretch = 1, 1
24
+
zooms = [
25
+
]
26
+
27
+
# particle distribution: sample domain points
28
+
particle_count = 20000000
29
+
30
+
#p_positions = (torch.rand([particle_count], dtype=t_complex) - (0.5 + 0.5j)) * 40
31
+
32
+
# initialize colors (updated every frame from Gamma)
33
+
#p_colors = torch.zeros([particle_count, 3], device=device, dtype=t_real)
34
+
35
+
mapping = map_space(origin, span, zooms, stretch, scale)
36
+
37
+
(_, (h,w)) = mapping
38
+
grid = cgrid(mapping)
39
+
40
+
scratch = torch.zeros([h, w, 3], device=device, dtype=t_real)
41
+
p_positions = torch.zeros([h*w,2], device=device, dtype=t_real)
42
+
p_positions[:,0] = grid.real.reshape((h*w))
43
+
p_positions[:,1] = grid.imag.reshape((h*w))
44
+
45
+
46
+
47
+
def gamma_approx(z):
48
+
# avoid poles in a naive way
49
+
poles = (z.real <= 0) & torch.isclose(z.real.round(), z.real)
50
+
z = z.clone()
51
+
z[poles] = complex("nan")
52
+
53
+
# reflection for left half-plane
54
+
reflect = z.real < 0.5
55
+
zr = z[reflect]
56
+
if zr.numel() > 0:
57
+
g1mz = gamma_approx(1 - zr)
58
+
z[reflect] = math.pi / (torch.sin(math.pi * zr) * g1mz)
59
+
60
+
# Stirling-like formula on right half-plane
61
+
zpos = z[~reflect]
62
+
if zpos.numel() > 0:
63
+
t = zpos - 0.5
64
+
core = math.sqrt(2 * math.pi) * torch.exp(t * torch.log(zpos) - zpos)
65
+
z[~reflect] = core
66
+
return z
67
+
68
+
69
+
a = 2.0j
70
+
b = 2.0
71
+
kernel_xs = torch.tensor([a + b, a - b, -a + b, -a - b], device=device)
72
+
kernel_xs -= kernel_xs.mean()
73
+
74
+
def velu_like(z):
75
+
acc = z
76
+
for xQ in kernel_xs:
77
+
acc += 1 / (z - xQ)
78
+
return acc
79
+
80
+
81
+
def gamma_to_rgb(gz):
82
+
arg = torch.angle(gz)
83
+
hue = (arg + math.pi) / (2 * math.pi)
84
+
mag = torch.abs(gz)
85
+
logmag = torch.log(mag + 1e-12)
86
+
87
+
# normalize brightness heuristically
88
+
val = (logmag + 5) / 10
89
+
val = val.clamp(0.0, 1.0)
90
+
91
+
# convert HSV to RGB (same logic as your framework)
92
+
h6 = hue * 6
93
+
i = torch.floor(h6).long() % 6
94
+
f = h6 - torch.floor(h6)
95
+
p = val * 0
96
+
q = val * (1 - f)
97
+
t = val * f
98
+
99
+
r = torch.zeros_like(val)
100
+
g = torch.zeros_like(val)
101
+
b = torch.zeros_like(val)
102
+
103
+
m = i == 0; r[m]=val[m]; g[m]=t[m]; b[m]=p[m]
104
+
m = i == 1; r[m]=q[m]; g[m]=val[m]; b[m]=p[m]
105
+
m = i == 2; r[m]=p[m]; g[m]=val[m]; b[m]=t[m]
106
+
m = i == 3; r[m]=p[m]; g[m]=q[m]; b[m]=val[m]
107
+
m = i == 4; r[m]=t[m]; g[m]=p[m]; b[m]=val[m]
108
+
m = i == 5; r[m]=val[m]; g[m]=p[m]; b[m]=q[m]
109
+
110
+
return torch.stack([r,g,b], dim=-1)
111
+
112
+
frame_index = [0]
113
+
114
+
def main():
115
+
schedule(main_render, None)
116
+
117
+
def main_render():
118
+
scratch = torch.zeros([h, w, 3], device=device, dtype=t_real)
119
+
naniter = torch.zeros([h, w], device=device, dtype=t_real)
120
+
121
+
# project function kept from original plume code
122
+
def project(p):
123
+
return torch.view_as_real(p).permute(1,0)
124
+
125
+
# single-frame render (or loop for animations)
126
+
frame_index[0] += 1
127
+
128
+
# compute Gamma(p)
129
+
gz = grid.clone()#gamma_approx(p_positions)
130
+
131
+
for i in range(2):
132
+
gz = gamma_approx(gz)#velu_like(gz)
133
+
134
+
naniter += gz.isnan() * 1.0
135
+
136
+
# convert to RGB
137
+
#rgb = gamma_to_rgb(gz)
138
+
139
+
# project particle coords to pixel grid
140
+
#coords = project(p_positions).clone()
141
+
#insert_at_coords(coords, rgb, scratch, mapping)
142
+
143
+
144
+
#scratch[:,:,0] = gz.real
145
+
#scratch[:,:,2] = gz.imag
146
+
#scratch[:,:,1] = naniter / naniter.max()
147
+
148
+
scratch[:,:,0] = (torch.angle(gz) + math.pi)
149
+
scratch[:,:,1] = torch.log(torch.abs(gz) + 1e-12) / 30
150
+
scratch[:,:,1] /= scratch[:,:,2].max()
151
+
scratch[:,:,2] = torch.cos(scratch[:,:,0] / 2)
152
+
scratch[:,:,0] = torch.abs(torch.sin(scratch[:,:,0] / 2))
153
+
154
+
155
+
# normalization/centering like plume_c does
156
+
#scratch[:,:,2] -= scratch[:,:,2].mean()
157
+
#scratch[:,:,2] /= scratch[:,:,2].std() * 6
158
+
#scratch[:,:,2] += 0.5
159
+
160
+
save(scratch.permute(2,0,1).sqrt(), f"{run_dir}/{i:06d}")
161
+
+222
sketch/ode.py
+222
sketch/ode.py
···
1
+
2
+
import time
3
+
import math
4
+
from pathlib import Path
5
+
6
+
import torch
7
+
8
+
from lib.spaces import insert_at_coords, map_space
9
+
from lib.ode import rk4_step
10
+
from lib.util import *
11
+
12
+
def main():
13
+
schedule(run, None)
14
+
15
+
device = "cuda"
16
+
torch.set_default_device(device)
17
+
t_fp = torch.double
18
+
torch.set_default_dtype(t_fp)
19
+
20
+
dissipation = 0.04 # "b" parameter in the literature
21
+
22
+
tau = 3.14159265358979323 * 2
23
+
24
+
particle_count = 10000000
25
+
iterations = 10000
26
+
save_if = lambda i: True #i > 150 and i % 5 == 0
27
+
rotation_rate = tau / 2000
28
+
29
+
# 2 ** 9 = 512; 10 => 1024; 11 => 2048
30
+
scale_power = 10
31
+
32
+
scale = 2 ** scale_power
33
+
origin = 0, 0
34
+
s = 40
35
+
span = s, s
36
+
zooms = [
37
+
]
38
+
stretch = 1, 1
39
+
40
+
mapping = map_space(origin, span, zooms, stretch, scale)
41
+
(_, (h,w)) = mapping
42
+
43
+
dt = 0.05
44
+
45
+
def proj(u, v):
46
+
dot = (u * v).sum(dim=1)
47
+
scale = (dot / (v.norm(p=2, dim=1) ** 2))
48
+
out = v.clone()
49
+
for dim in range(3):
50
+
out[:,dim] *= scale
51
+
return out
52
+
53
+
def proj_shift(a, b):
54
+
return a + proj(b, a)
55
+
56
+
def get_transformation(direction, flow, ebb, rescaling):
57
+
def transformation(p):
58
+
#return blus(p, ones) - ones
59
+
#if randomize:
60
+
# direction = flow * torch.polar(ones.real, (random() * random_range - random_range / 2) * ones.real)
61
+
if flow == 0:
62
+
result = p - direction * ebb
63
+
else:
64
+
result = proj_shift(p, direction * flow) - direction * ebb
65
+
#res_abs = result.abs()
66
+
if rescaling:
67
+
result /= result.norm(p=2,dim=1).max()
68
+
#result = torch.polar(2 * res_abs / res_abs.max(), result.angle())
69
+
return result
70
+
return transformation
71
+
72
+
def _deriv(b, c):
73
+
def derivative(p):
74
+
return p * ((b * p).sum(dim=0)) - c * b
75
+
return derivative
76
+
77
+
def _sprott2014(a):
78
+
def derivative(p):
79
+
d = p.roll(shifts=-1, dims=0)
80
+
d[0] += 2 * (p[0] * p[1]) + (p[0] * p[2]) # y + 2xy + xz
81
+
d[1] = - (p[0] * p[0]) # -x^2
82
+
d[2] += d[1] - (p[1] * p[1]) # x - x^2 - y^2
83
+
d[1] *= 2 # - 2x^2
84
+
d[1] += a + (p[1] * p[2]) # a - 2x^2 + yz
85
+
return d
86
+
return derivative
87
+
88
+
def _sprottET0():
89
+
def derivative(p):
90
+
d = p.clone()
91
+
d[0] = p[1]
92
+
d[1] = - p[0] + p[1] * p[2]
93
+
d[2] = (p[0] * p[0]) - (4 * p[1] * p[1]) + 1
94
+
return d
95
+
return derivative
96
+
97
+
def _suspension(c):
98
+
damping = 0.02
99
+
r_eq = 1.0
100
+
scale = 0.05
101
+
def deriv(p):
102
+
x, y, z = p[0], p[1], p[2]
103
+
104
+
# Shift x to center the tori at Re = -c/2
105
+
X = x + c/2
106
+
Y = y
107
+
108
+
# Polar coordinates
109
+
r = torch.sqrt(X**2 + Y**2) + 1e-12
110
+
theta = torch.atan2(Y, X)
111
+
112
+
# Radial displacement
113
+
dr = r + torch.cos(theta)
114
+
115
+
# Velocity in 2D
116
+
dx = scale * (dr * (X/r) - X)
117
+
dy = scale * (dr * (Y/r) - Y)
118
+
119
+
# Optional weak radial damping
120
+
dx -= scale * damping * (r - r_eq) * (X/r)
121
+
dy -= scale * damping * (r - r_eq) * (Y/r)
122
+
123
+
# z-axis suspension (monotonic)
124
+
dz = scale * torch.ones_like(x)
125
+
126
+
return torch.stack([dx, dy, dz], dim=0)
127
+
return deriv
128
+
129
+
130
+
def gpt_deriv(b, k=1.0, c=0.1, omega=0.5):
131
+
"""
132
+
Returns a function deriv(points) -> [3,N] where points are 3D column vectors.
133
+
b: tensor of shape [3] (will be normalized internally)
134
+
k: scalar gain for tanh nonlinearity
135
+
c: scalar drift magnitude along -b
136
+
omega: scalar rotation strength around b
137
+
"""
138
+
b = b / b.norm() # normalize once
139
+
140
+
def deriv(points):
141
+
# <a, b> for each column
142
+
#dots = torch.einsum('ij,i->j', points, b) # shape [N]
143
+
dots = (points * b[:]).sum(dim=0) # shape [N]
144
+
145
+
# radial scaling term
146
+
radial = torch.tanh(k * dots).unsqueeze(0) * points # [3,N]
147
+
148
+
# drift along -b
149
+
drift = -c * b[:].expand_as(points) # [3,N]
150
+
151
+
152
+
# rotation around b: b ร a
153
+
rotation = omega * torch.cross(b[:].expand_as(points), points, dim=0)
154
+
155
+
return radial + drift + rotation
156
+
157
+
return deriv
158
+
159
+
p_positions = (torch.rand([3, particle_count], device=device, dtype=t_fp) - 0.5) * 2
160
+
#p_positions[0] *= 0.05
161
+
#p_positions[0] += 0.7
162
+
#p_positions[2] = 0
163
+
direction = torch.zeros_like(p_positions) / math.sqrt(3)
164
+
direction[1] += 1
165
+
166
+
#derivative = _deriv(direction, 0.5)
167
+
#derivative = gpt_deriv(direction, 1.0, 0.1, 0.5)
168
+
#derivative = _sprott2014(0.7)
169
+
#derivative = _sprottET0()
170
+
derivative = _suspension(0.8)
171
+
172
+
step = lambda p, dp, h: p + dp * h * dt
173
+
rk4_curried = lambda p: rk4_step(derivative, step, p)
174
+
175
+
176
+
p_colors = torch.rand([particle_count,3], device=device, dtype=t_fp)
177
+
178
+
color_rotation = torch.linspace(0, tau / 4, particle_count)
179
+
180
+
p_colors[:,0] = (p_positions[0,:] / 2) + 0.5
181
+
p_colors[:,1] = (p_positions[1,:] / 2) + 0.5
182
+
p_colors[:,2] = (p_positions[2,:] / 2) + 0.5
183
+
184
+
def project(p, colors, i):
185
+
c = math.cos(i * rotation_rate)
186
+
s = math.sin(i * rotation_rate)
187
+
rotation = torch.tensor([
188
+
[1, 0, 0],
189
+
[0, c,-s],
190
+
[0, s, c]])
191
+
rotation2 = torch.tensor([
192
+
[ c, 0, s],
193
+
[ 0, 1, 0],
194
+
[-s, 0, c]])
195
+
#rotation = torch.tensor([
196
+
# [c, -s, 0],
197
+
# [s, c, 0],
198
+
# [0, 0, 1]])
199
+
alt_colors = colors.clone()
200
+
res = (rotation @ p)
201
+
#color_filter = (p[0] - 0.7).abs() < 0.01
202
+
#alt_colors[:,0] *= color_filter
203
+
#alt_colors[:,1] *= color_filter
204
+
#alt_colors[:,2] *= color_filter
205
+
return (res[1:3], alt_colors)
206
+
207
+
frame_index = [0]
208
+
209
+
def run():
210
+
scratch = torch.zeros([h, w, 3], device=device, dtype=t_fp)
211
+
212
+
for iteration in range(iterations):
213
+
if save_if(iteration) or iteration == iterations - 1:
214
+
(p_projected, alt_colors) = project(p_positions, p_colors, iteration)
215
+
frame_index[0] += 1
216
+
scratch *= 0
217
+
insert_at_coords(p_projected, alt_colors, scratch, mapping)
218
+
save(scratch.permute(2,0,1), f"{run_dir}/{frame_index[0]:06d}")
219
+
220
+
p_positions.copy_(rk4_curried(p_positions))
221
+
222
+
+4
-4
sketch/old/snakepyt_0_0/ode.py
+4
-4
sketch/old/snakepyt_0_0/ode.py
···
142
142
# [0, 0, 1]])
143
143
alt_colors = colors.clone()
144
144
res = (rotation @ p)
145
-
#color_filter = res[2].abs() < 0.7
146
-
#alt_colors[:,0] *= color_filter
147
-
#alt_colors[:,1] *= color_filter
148
-
#alt_colors[:,2] *= color_filter
145
+
color_filter = res[2].abs() < 0.1
146
+
alt_colors[:,0] *= color_filter
147
+
alt_colors[:,1] *= color_filter
148
+
alt_colors[:,2] *= color_filter
149
149
return (res[0:2], alt_colors)
150
150
151
151
frame_index = [0]
+218
sketch/ps_particle.py
+218
sketch/ps_particle.py
···
1
+
import time
2
+
from pathlib import Path
3
+
from random import random
4
+
5
+
import torch
6
+
7
+
from lib.util import *
8
+
9
+
type fp_range = tuple[float, float]
10
+
type fp_region2 = tuple[fp_range, fp_range]
11
+
type fp_coords2 = tuple[float, float]
12
+
type hw = tuple[int, int]
13
+
type region_mapping = tuple[fp_region2, hw]
14
+
15
+
def main():
16
+
import math
17
+
name = "ps_particle"
18
+
device = "cuda"
19
+
t_real = torch.double
20
+
t_complex = torch.cdouble
21
+
22
+
tau = 3.14159265358979323 * 2
23
+
24
+
flow = 8
25
+
ebb = 1 - (1 / 5) + 0.0001
26
+
phi = 0
27
+
randomize = False
28
+
random_range = tau #/ 50
29
+
rescaling = False
30
+
show_gridlines = True
31
+
32
+
particle_count = 10**6
33
+
34
+
iterations = 200#0001
35
+
36
+
# 2 ** 9 = 512; 10 => 1024; 11 => 2048
37
+
scale_power = 12
38
+
scale = 2 ** scale_power
39
+
40
+
origin = 0, -(flow*ebb)/2
41
+
span = 10, 10
42
+
43
+
stretch = 1, 1
44
+
45
+
zooms = [
46
+
#((0.5,0.8),(0.1,0.5))
47
+
]
48
+
49
+
save_every = 5000
50
+
agg_every = 1#000
51
+
yell_every = 1000
52
+
grid_on_agg = False#True
53
+
54
+
quantile_range = torch.tensor([0.05, 0.95], dtype=t_real, device=device)
55
+
56
+
schedule(run, None)
57
+
58
+
def blus(a, b, phi):
59
+
theta_a = a.angle()
60
+
return torch.polar(a.abs() + b.abs() * torch.cos(b.angle() - theta_a - phi), theta_a)
61
+
62
+
def run():
63
+
run_dir = time.strftime(f"%d.%m.%Y/{name}_t%H.%M.%S")
64
+
Path("out/" + run_dir).mkdir(parents=True, exist_ok=True)
65
+
Path("out/" + run_dir + "/aggregate").mkdir(parents=True, exist_ok=True)
66
+
Path("out/" + run_dir + "/frame").mkdir(parents=True, exist_ok=True)
67
+
68
+
global span
69
+
70
+
x_min = origin[0] - (span[0] / 2)
71
+
y_min = origin[1] - (span[1] / 2)
72
+
73
+
for ((xa, xb), (ya, yb)) in zooms:
74
+
x_min += span[0] * xa
75
+
y_min += span[1] * ya
76
+
span = span[0] * (xb - xa), span[1] * (yb - ya)
77
+
78
+
x_max = x_min + span[0]
79
+
y_max = y_min + span[1]
80
+
81
+
aspect = span[0] * stretch[0] / (span[1] * stretch[1])
82
+
83
+
if aspect < 1:
84
+
h = scale
85
+
w = int(scale * aspect)
86
+
else:
87
+
w = scale
88
+
h = int(scale / aspect)
89
+
90
+
x_range = (x_min, x_max)
91
+
y_range = (y_min, y_max)
92
+
region = (x_range, y_range)
93
+
mapping = (region, (h,w))
94
+
95
+
p_positions = (torch.rand([particle_count], device=device, dtype=t_complex))
96
+
#p_positions.imag = torch.linspace(0, tau, particle_count)
97
+
#p_positions.real = torch.linspace(0, 1, particle_count)
98
+
#p_positions = torch.polar(p_positions.real, p_positions.imag)
99
+
100
+
p_colors = torch.ones([particle_count,3], device=device, dtype=t_real)
101
+
color_rotation = torch.linspace(0, tau / 4, particle_count)
102
+
103
+
p_colors[:,0] = torch.frac((p_positions.real) * 1 / (span[0]))
104
+
p_colors[:,2] = torch.frac((p_positions.imag) * 1 / (span[1]))
105
+
106
+
p_positions.real *= 0.025 * (y_max - y_min)
107
+
p_positions.real += y_min + (0.4875) * (y_max - y_min)
108
+
p_positions.imag *= 0.025 * (x_max - x_min) * 4
109
+
p_positions.imag += x_min + 0.501 * (x_max - x_min)
110
+
111
+
#p_colors[:,0] = torch.cos(color_rotation)
112
+
p_colors[:,1] = 1.0 - (p_colors[:,0] + p_colors[:,2])#0.1
113
+
#p_colors[:,2] *= 0
114
+
#p_colors[:,2] = torch.sin(color_rotation)
115
+
116
+
canvas = torch.zeros([3, h, w], device=device, dtype=t_real)
117
+
scratch = torch.zeros([h, w, 3], device=device, dtype=t_real)
118
+
gridlines = torch.zeros([h,w], device=device)
119
+
120
+
def project(p):
121
+
return torch.view_as_real(p).permute(1,0)
122
+
123
+
ones = torch.ones_like(p_positions)
124
+
global direction
125
+
direction = flow * torch.polar(ones.real, 1 * tau * ones.real)
126
+
def next_positions(p, i):
127
+
global direction
128
+
#return blus(p, ones) - ones
129
+
if randomize:
130
+
direction = flow * torch.polar(ones.real, (random() * random_range - random_range / 2) * ones.real)
131
+
result = blus(p, direction, phi) - direction * ebb
132
+
res_abs = result.abs()
133
+
if rescaling:
134
+
result = torch.polar(res_abs / res_abs.max(), result.angle())
135
+
return result
136
+
137
+
for i in range(10):
138
+
frac = i / 10
139
+
gridlines[math.floor(frac*h), :] = 1
140
+
gridlines[:, math.floor(frac*w)] = 1
141
+
142
+
def insert_at_coords(coords, values, target, mapping: region_mapping):
143
+
(region, hw) = mapping
144
+
(xrange, yrange) = region
145
+
(h,w) = hw
146
+
(x_min, x_max) = xrange
147
+
(y_min, y_max) = yrange
148
+
149
+
mask = torch.ones([particle_count], device=device)
150
+
mask *= (coords[1] >= x_min) * (coords[1] <= x_max)
151
+
mask *= (coords[0] >= y_min) * (coords[0] <= y_max)
152
+
in_range = mask.nonzero().squeeze()
153
+
154
+
# TODO: combine coord & value tensors so there's only one index_select necessary
155
+
coords_filtered = torch.index_select(coords.permute(1,0), 0, in_range)
156
+
values_filtered = torch.index_select(values, 0, in_range)
157
+
158
+
coords_filtered[:,1] -= x_min
159
+
coords_filtered[:,1] *= (w-1) / (x_max - x_min)
160
+
coords_filtered[:,0] -= y_min
161
+
coords_filtered[:,0] *= (h-1) / (y_max - y_min)
162
+
indices = coords_filtered.long()
163
+
164
+
target.index_put_((indices[:,0],indices[:,1]), values_filtered, accumulate=True)
165
+
166
+
167
+
for iteration in range(iterations):
168
+
p_projected = project(p_positions).clone()
169
+
170
+
if iteration % 1 == 0:
171
+
172
+
scratch *= 0
173
+
insert_at_coords(p_projected, p_colors, scratch, mapping)
174
+
canvas += scratch.permute(2,0,1)
175
+
176
+
177
+
temp = canvas.clone()
178
+
#for d in range(3):
179
+
# temp[d] -= temp[d].mean()
180
+
# temp[d] /= 8 * temp[d].std()
181
+
# temp[d] -= temp[d].min()
182
+
183
+
temp -= temp.min()
184
+
temp = torch.log(temp)
185
+
temp /= temp.max()
186
+
187
+
if iteration % agg_every == 0:
188
+
p_low, p_high = torch.quantile(temp[:,(2*h//5):(3*h//5),:], quantile_range)
189
+
# temp[0] += gridlines
190
+
# temp[1] += gridlines
191
+
# temp[2] += gridlines
192
+
193
+
if grid_on_agg:
194
+
save((1 - temp).clamp_(0.0, 1.0) - gridlines, f"{run_dir}/aggregate/_{iteration:06d}")
195
+
else:
196
+
save((1 - temp).clamp_(0.0, 1.0), f"{run_dir}/aggregate/_{iteration:06d}")
197
+
temp = (temp - p_low) / (1e-7 + p_high - p_low)
198
+
#save(1 - temp, f"{run_dir}/aggregate/{iteration:06d}")
199
+
#scratch /= scratch.max()
200
+
#scratch = scratch.sqrt().sqrt()
201
+
if show_gridlines:
202
+
scratch[:,:,1] += gridlines
203
+
scratch[:,:,2] += gridlines
204
+
if iteration % save_every == 0:
205
+
save(1 - scratch.permute(2,0,1), f"{run_dir}/frame/{iteration:06d}")
206
+
if iteration % yell_every == 0:
207
+
print(f"{iteration} iterations")
208
+
209
+
p_positions = next_positions(p_positions, iteration)
210
+
211
+
#for d in range(3):
212
+
# canvas[d] -= canvas[d].mean()
213
+
# canvas[d] /= 8 * canvas[d].std()
214
+
# canvas[d] -= canvas[d].min()
215
+
216
+
217
+
torch.set_default_device(device)
218
+
+396
sketch/sdxl.py
+396
sketch/sdxl.py
···
1
+
import math
2
+
import gc
3
+
4
+
import torch
5
+
import numpy as np
6
+
import torch.nn as nn
7
+
from torch.nn import functional as func
8
+
from PIL import Image
9
+
10
+
from diffusers import UNet2DConditionModel
11
+
import transformers
12
+
13
+
from lib.diffusion.guidance import *
14
+
from lib.diffusion.schedule import *
15
+
from lib.diffusion.sdxl_encoder import PromptEncoder
16
+
from lib.diffusion.sdxl_vae import Decoder, save_approx_decode
17
+
18
+
from lib.ode import _euler_step as euler_step
19
+
from lib.log import Timer
20
+
#from lib.util import pilify
21
+
22
+
'''
23
+
beware: this file is a mess, lots of things are broken and poorly named
24
+
'''
25
+
26
+
def save_raw_latents(latents, path):
27
+
lmin = latents.min()
28
+
l = latents - lmin
29
+
lmax = latents.max()
30
+
l = latents / lmax
31
+
l = l.float() * 127.5 + 127.5
32
+
l = l.detach().cpu().numpy()
33
+
l = l.round().astype("uint8")
34
+
35
+
ims = []
36
+
37
+
for lat in l:
38
+
row1 = np.concatenate([lat[0], lat[1]])
39
+
row2 = np.concatenate([lat[2], lat[3]])
40
+
grid = np.concatenate([row1, row2], axis=1)
41
+
#for channel in lat:
42
+
im = Image.fromarray(grid)
43
+
im = im.resize(size=(grid.shape[1]*4, grid.shape[0]*4), resample=Image.NEAREST)
44
+
ims += [im]
45
+
46
+
for im in ims:
47
+
im.save(path)
48
+
49
+
50
+
def pilify(latents, vae):
51
+
#latents = 1 / vae.config.scaling_factor * latents
52
+
latents = 1 / 0.13025 * latents
53
+
latents = latents.to(torch.float32)#vae.dtype)
54
+
with torch.no_grad():
55
+
images = vae.decode(latents)#.sample
56
+
57
+
images = images.detach().mul_(127.5).add_(127.5).clamp_(0,255).round()
58
+
#return [images]
59
+
images = images.permute(0,2,3,1).cpu().numpy().astype("uint8")
60
+
return [Image.fromarray(image) for image in images]
61
+
62
+
63
+
torch.set_grad_enabled(False)
64
+
torch.backends.cuda.matmul.allow_tf32 = True
65
+
torch.set_float32_matmul_precision("medium")
66
+
67
+
def persistent():
68
+
model_path = "/ssd/0/ml_models/ql/hf-diff/stable-diffusion-xl-base-0.9"
69
+
70
+
main_device = "cuda:0"
71
+
decoder_device = "cuda:1"
72
+
clip_device = "cuda:1"
73
+
74
+
main_dtype = torch.float64
75
+
noise_predictor_dtype = torch.float16
76
+
decoder_dtype = torch.float32
77
+
prompt_encoder_dtype = torch.float16
78
+
79
+
with Timer("total"):
80
+
with Timer("decoder"):
81
+
decoder = Decoder()
82
+
decoder.load_safetensors(model_path)
83
+
decoder.to(device=decoder_device)
84
+
#decoder = torch.compile(decoder, mode="default", fullgraph=True)
85
+
86
+
with Timer("noise_predictor"):
87
+
noise_predictor = UNet2DConditionModel.from_pretrained(
88
+
model_path, subfolder="unet", torch_dtype=noise_predictor_dtype
89
+
)
90
+
noise_predictor.to(device=main_device)
91
+
92
+
# compilation will not actually happen until first use of noise_predictor
93
+
# (as of torch 2.2.2) "default" provides the best result on my machine
94
+
# don't use this if you're gonna be changing resolutions a lot
95
+
#noise_predictor = torch.compile(noise_predictor, mode="default", fullgraph=True)
96
+
97
+
with Timer("clip"):
98
+
prompt_encoder = PromptEncoder(model_path, True, (clip_device, main_device), prompt_encoder_dtype)
99
+
100
+
101
+
102
+
# TODO: these should come from config!
103
+
vae_scale = 0.13025
104
+
decoder_dim_scale = 2 ** 3
105
+
variance_range = (0.00085, 0.012)
106
+
107
+
108
+
seed = lambda run, meta: 2835723 + run
109
+
110
+
meta_count = 1
111
+
112
+
width, height = 16, 16
113
+
114
+
steps = 50
115
+
run_count = 10
116
+
117
+
timestep_power = 1
118
+
timestep_range = (0, 999)
119
+
120
+
# TODO prompts are now in sketch/local/prompts, should be simple enough to fetch them
121
+
122
+
empty_p = ""
123
+
124
+
prompts = {
125
+
"encoder_1": [p3, p3],
126
+
"encoder_2": None,
127
+
"encoder_2_pooled": None
128
+
}
129
+
_lerp = lambda t, a, b: t * b + (1-t) * a
130
+
131
+
combine_predictions = lambda s,r: scaled_CFG(
132
+
difference_scales = [
133
+
(1, -1, lambda x: x)
134
+
#(1, 0, lambda x: _lerp(s/steps, 1.0, 0.3) * x),
135
+
#(2, 0, lambda x: _lerp(s/steps, 0.3, 1.0) * x),
136
+
],
137
+
steering_scale = lambda x: 1 * x,
138
+
#base_term = lambda predictions, true_noise: predictions[0],
139
+
base_term = lambda predictions, true_noise: true_noise,
140
+
total_scale = lambda predictions, cfg_result: cfg_result
141
+
)
142
+
143
+
# TNR
144
+
_scale = lambda step, scale: _lerp(step / (steps - 1), scale / steps, 0.1 * scale / steps)
145
+
a,b,c = -7, 10, 15
146
+
#combine_predictions = lambda step, run: lambda p, n: _scale(step, a + c * run / (run_count-1)) * (n - p[0]) + _scale(step, b - c * run / (run_count-1)) * (n - p[1])
147
+
#combine_predictions = lambda step, run: lambda p, n: _scale(step, 4) * (n - p[1]) + _scale(step, 2.5) * (p[0] - n)
148
+
combine_predictions = lambda step, run: lambda p, n: _scale(step, 3) * (n - p[0])
149
+
#combine_predictions = lambda step, run: lambda p, n: _scale(step, _lerp(step/steps, 0.0, 0.0)) * (n - p[0]) + _scale(step, 2) * (n - p[1])
150
+
151
+
152
+
diffusion_method = "tnr"
153
+
154
+
solver_step = euler_step
155
+
156
+
save_raw = lambda run, step: False
157
+
save_approximates = lambda run, step: True
158
+
save_final = True
159
+
160
+
def pre_step(step, latents):
161
+
return latents
162
+
163
+
def post_step(step, latents):
164
+
return latents
165
+
166
+
def main():
167
+
168
+
schedule(meta_run, range(meta_count))
169
+
170
+
def meta_run(meta_id):
171
+
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
172
+
forward_noise_total = forward_noise_schedule.cumsum(dim=0)
173
+
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar
174
+
partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t]) # alpha_bar_t / alpha_bar_s (but computed more directly from the forward noise)
175
+
part_noise = (1 - forward_signal_product).sqrt() # sigma
176
+
part_signal = forward_signal_product.sqrt() # mu?
177
+
178
+
schedule(run, range(run_count))
179
+
180
+
def get_signal_ratio(from_timestep, to_timestep):
181
+
if from_timestep < to_timestep: # forward
182
+
return 1 / partial_signal_product(from_timestep, to_timestep).sqrt()
183
+
else: # backward
184
+
return partial_signal_product(to_timestep, from_timestep).sqrt()
185
+
186
+
def step_by_noise(latents, noise, from_timestep, to_timestep):
187
+
signal_ratio = get_signal_ratio(from_timestep, to_timestep)
188
+
return latents / signal_ratio + noise * (part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio)
189
+
190
+
def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep):
191
+
signal_ratio = get_signal_ratio(from_timestep, to_timestep)
192
+
return latents / signal_ratio + noise * (1 - 1 / signal_ratio)
193
+
194
+
def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep):
195
+
signal_ratio = get_signal_ratio(from_timestep, to_timestep)
196
+
return latents / signal_ratio + base * part_noise[to_timestep] - combined * (part_noise[from_timestep] / signal_ratio)
197
+
198
+
def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
199
+
signal_ratio = get_signal_ratio(from_timestep, to_timestep)
200
+
diff_coefficient = part_noise[from_timestep] / signal_ratio
201
+
base_coefficient = part_noise[to_timestep] - diff_coefficient
202
+
#print((1/signal_ratio).item(), base_coefficient.item(), diff_coefficient.item())
203
+
return latents / signal_ratio + base_term * base_coefficient + diff_term * diff_coefficient
204
+
205
+
def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
206
+
signal_ratio = get_signal_ratio(from_timestep, to_timestep)
207
+
base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio
208
+
measure = lambda x: x.abs().max().item()
209
+
#print(measure(latents / signal_ratio), measure(base_term * base_coefficient), measure(diff_term))
210
+
return latents / signal_ratio + base_term * base_coefficient + diff_term
211
+
212
+
def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):
213
+
if from_timestep < to_timestep: # forward
214
+
signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt()
215
+
else: # backward
216
+
signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt()
217
+
return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])
218
+
219
+
def index_interpolate(source, index):
220
+
frac, whole = math.modf(index)
221
+
if frac == 0:
222
+
return source[int(whole)]
223
+
return lerp(source[int(whole)], source[int(whole)+1], frac)
224
+
225
+
def run(run_id):
226
+
try:
227
+
_seed = int(seed(run_id, meta_id))
228
+
except:
229
+
_seed = 0
230
+
print(f"non-integer seed, run {run_id}. replaced with 0.")
231
+
232
+
torch.manual_seed(_seed)
233
+
np.random.seed(_seed)
234
+
235
+
diffusion_timesteps = linspace_timesteps(steps+1, timestep_range[1], timestep_range[0], timestep_power)
236
+
237
+
noise_predictor_batch_size = len(prompts["encoder_1"])
238
+
239
+
(all_penult_states, enc2_pooled) = prompt_encoder.encode(prompts["encoder_1"], prompts["encoder_2"], prompts["encoder_2_pooled"])
240
+
241
+
global width
242
+
global height
243
+
if (width < 64): width *= 64
244
+
if (height < 64): height *= 64
245
+
246
+
latents = torch.zeros(
247
+
(1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
248
+
device=main_device,
249
+
dtype=main_dtype
250
+
)
251
+
252
+
noises = torch.randn(
253
+
#(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
254
+
(1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
255
+
device=main_device,
256
+
dtype=main_dtype
257
+
)
258
+
259
+
latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])
260
+
261
+
original_size = (height, width)
262
+
target_size = (height, width)
263
+
crop_coords_top_left = (0, 0)
264
+
265
+
# incomprehensible var name tbh go read the sdxl paper if u want to Understand
266
+
add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype).repeat(noise_predictor_batch_size,1).to("cuda")
267
+
268
+
added_cond_kwargs = {"text_embeds": enc2_pooled.to(noise_predictor_dtype), "time_ids": add_time_ids}
269
+
270
+
for step_index in range(steps):
271
+
noise = noises[0]
272
+
273
+
start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()
274
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()
275
+
276
+
end_noise = part_noise[end_timestep]
277
+
end_signal = part_signal[end_timestep]
278
+
start_noise = part_noise[start_timestep]
279
+
start_signal = part_signal[start_timestep]
280
+
signal_ratio = get_signal_ratio(start_timestep, end_timestep)
281
+
start = start_timestep
282
+
end = end_timestep
283
+
284
+
285
+
sigratio = get_signal_ratio(start_timestep, end_timestep)
286
+
287
+
def predict_noise(latents, step=0):
288
+
return noise_predictor(
289
+
latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),
290
+
index_interpolate(diffusion_timesteps, step_index + step).round().int(),
291
+
encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),
292
+
return_dict=False,
293
+
added_cond_kwargs=added_cond_kwargs
294
+
)[0]
295
+
296
+
def standard_predictor(combiner):
297
+
def _predict(latents, step=0):
298
+
predictions = predict_noise(latents, step)
299
+
return predictions, noise, combiner(predictions, noise)
300
+
return _predict
301
+
302
+
def constructive_predictor(combiner):
303
+
def _predict(latents, step=0):
304
+
noised = step_by_noise(latents, noise, 0, index_interpolate(diffusion_timesteps, step_index + step).round().int())
305
+
predictions = predict_noise(noised, step)
306
+
return predictions, noise, combiner(latents, predictions, noise)
307
+
return _predict
308
+
309
+
310
+
def standard_diffusion_step(latents, noises, start, end):
311
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
312
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
313
+
predictions, true_noise, combined_prediction = noises
314
+
return step_by_noise(latents, combined_prediction, start_timestep, end_timestep)
315
+
316
+
def stupid_simple_step(latents, noises, start, end):
317
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
318
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
319
+
predictions, true_noise, combined_prediction = noises
320
+
return stupid_simple_step_by_noise(latents, combined_prediction, start_timestep, end_timestep)
321
+
322
+
def cfgpp_diffusion_step(choose_base, choose_combined):
323
+
def _diffusion_step(latents, noises, start, end):
324
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
325
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
326
+
return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), start_timestep, end_timestep)
327
+
return _diffusion_step
328
+
329
+
def tnr_diffusion_step(latents, noises, start, end):
330
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
331
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
332
+
predictions, true_noise, combined_prediction = noises
333
+
return tnr_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)
334
+
335
+
def tnrb_diffusion_step(latents, noises, start, end):
336
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
337
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
338
+
predictions, true_noise, combined_prediction = noises
339
+
return tnrb_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)
340
+
341
+
def constructive_step(latents, noises, start, end):
342
+
start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
343
+
end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
344
+
predictions, true_noise, combined_prediction = noises
345
+
return latents + combined_prediction
346
+
347
+
def select_prediction(index):
348
+
return lambda noises: noises[0][index]
349
+
350
+
select_true_noise = lambda noises: noises[1]
351
+
select_combined = lambda noises: noises[2]
352
+
353
+
354
+
if diffusion_method == "standard":
355
+
take_step = standard_diffusion_step
356
+
if diffusion_method == "stupid":
357
+
take_step = stupid_simple_step
358
+
if diffusion_method == "cfg++":
359
+
take_step = cfgpp_diffusion_step(select_prediction(0), select_combined)
360
+
if diffusion_method == "tnr":
361
+
take_step = tnr_diffusion_step
362
+
if diffusion_method == "tnrb":
363
+
take_step = tnrb_diffusion_step
364
+
365
+
if diffusion_method == "cons":
366
+
take_step = constructive_step
367
+
get_derivative = constructive_predictor(combine_predictions)
368
+
else:
369
+
get_derivative = standard_predictor(combine_predictions(step_index, run_id))
370
+
371
+
solver = solver_step
372
+
373
+
latents = pre_step(step_index, latents)
374
+
375
+
latents = solver(get_derivative, take_step, latents)
376
+
377
+
latents = post_step(step_index, latents)
378
+
379
+
if step_index < steps - 1 and diffusion_method != "cons":
380
+
pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])
381
+
else:
382
+
pred_original_sample = latents
383
+
384
+
if save_raw(run_id, step_index):
385
+
save_raw_latents(pred_original_sample, f"out/{run_dir}/{run_id}_raw_{step_index:03d}.png")
386
+
if save_approximates(run_id, step_index):
387
+
save_approx_decode(pred_original_sample, f"out/{run_dir}/{run_id}_approx_{step_index:03d}.png")
388
+
389
+
if save_final:
390
+
images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder)
391
+
392
+
for im in images_pil:
393
+
for n in range(len(images_pil)):
394
+
images_pil[n].save(f"out/{run_dir}/{meta_id:05d}_{n}_{run_id:05d}_{step_index:05d}.png")
395
+
396
+
+120
-44
sketch/webserver.py
+120
-44
sketch/webserver.py
···
6
6
7
7
#from lib import log
8
8
9
-
# TODO make it upgrade to https instead of just fucking throwing exceptions
9
+
MODE = "local"
10
10
11
-
HOST = "0.0.0.0"
12
-
PORT = 1313
13
-
BASE_DIR = Path("./webui").resolve()
14
-
USE_SSL = False
11
+
if MODE == "local":
12
+
HOST = "0.0.0.0"
13
+
PORT, SSL_PORT = 1313, None
14
+
BASE_DIR = Path("./webui").resolve()
15
+
USE_SSL = False
16
+
SSL_CERT = "/home/ponder/ponder/certs/cert.pem"
17
+
SSL_KEY = "/home/ponder/ponder/certs/key.pem"
18
+
elif MODE == "remote":
19
+
HOST = "ponder.ooo"
20
+
PORT, SSL_PORT = 80, 443
21
+
BASE_DIR = Path("./webui").resolve()
22
+
USE_SSL = True
23
+
SSL_KEY="/etc/letsencrypt/live/ponder.ooo/privkey.pem"
24
+
SSL_CERT="/etc/letsencrypt/live/ponder.ooo/fullchain.pem"
15
25
16
26
ROUTES = {
17
27
"/": "content/main.html",
···
26
36
".orb": "text/x-orb"
27
37
}
28
38
29
-
SSL_CERT = "/home/ponder/ponder/certs/cert.pem"
30
-
SSL_KEY = "/home/ponder/ponder/certs/key.pem"
31
39
32
40
def guess_mime_type(path):
33
41
return MIME_TYPES.get(Path(path).suffix, "application/octet-stream")
···
35
43
def build_response(status_code, body=b"", content_type="text/plain"):
36
44
reason = {
37
45
200: "OK",
46
+
301: "Moved Permanently",
38
47
400: "Bad Request",
39
48
404: "Not Found",
40
49
405: "Method Not Allowed",
···
48
57
f"\r\n"
49
58
).encode() + body
50
59
51
-
def route(req_path):
52
-
if req_path == "/":
53
-
return "content/main.html"
54
-
if any(req_path.endswith(x) for x in [".html", ".png"]):
55
-
return f"content/{req_path}"
56
-
if any(req_path.endswith(x) for x in [".wgsl"]):
57
-
return f"content/wgsl/{req_path}"
58
-
if any(req_path.endswith(x) for x in [".orb"]):
59
-
return f"content/orb/{req_path}"
60
-
if req_path.endswith(".js"):
61
-
return f"js/{req_path}"
62
-
return req_path[1:]
60
+
def route(path):
61
+
path = path.relative_to('/') if path.is_absolute() else path
62
+
63
+
if path == Path():
64
+
return BASE_DIR / "content/main.html"
65
+
66
+
suffix = path.suffix
67
+
68
+
if suffix in [".html", ".png"]:
69
+
return BASE_DIR / Path("content") / path
70
+
if suffix in [".wgsl"]:
71
+
return BASE_DIR / Path("content/wgsl") / path
72
+
if suffix in [".orb"]:
73
+
return BASE_DIR / Path("content/orb") / path
74
+
if suffix in [".js"]:
75
+
return BASE_DIR / Path("js") / path
76
+
return BASE_DIR / path
63
77
64
78
def handle_request(request_data):
65
79
try:
···
71
85
if method != "GET":
72
86
return build_response(405, b"Method Not Allowed")
73
87
74
-
req_path = urlparse(unquote(raw_path)).path
88
+
req_path = Path(urlparse(unquote(raw_path)).path).resolve(strict=False)
75
89
76
-
norm_path = os.path.normpath(req_path)
90
+
routed_path = route(req_path).resolve(strict=False)
77
91
78
-
if ".." in norm_path:
92
+
if not routed_path.is_relative_to(BASE_DIR):
79
93
return build_response(400, b"Fuck You")
80
94
81
-
file_path = route(norm_path)
82
-
83
-
if not file_path:
84
-
return build_response(404, b"Not Found")
85
-
86
-
full_path = BASE_DIR / file_path
87
-
if not full_path.exists():
95
+
if not routed_path.exists():
88
96
return build_response(404, b"File not found")
89
97
90
-
with open(full_path, "rb") as f:
98
+
with open(routed_path, "rb") as f:
91
99
body = f.read()
92
-
return build_response(200, body, guess_mime_type(file_path))
100
+
return build_response(200, body, guess_mime_type(routed_path))
93
101
94
102
except Exception as e:
95
103
return build_response(500, f"Server error: {e}".encode())
96
104
97
-
def main():
105
+
106
+
107
+
def build_redirect_response(location):
108
+
return (
109
+
f"HTTP/1.1 301 Moved Permanently\r\n"
110
+
f"Location: {location}\r\n"
111
+
f"Content-Length: 0\r\n"
112
+
f"Connection: close\r\n"
113
+
f"\r\n"
114
+
).encode()
115
+
116
+
def handle_http_redirect(conn):
117
+
request = conn.recv(4096)
118
+
try:
119
+
lines = request.decode().split("\r\n")
120
+
if lines:
121
+
method, raw_path, *_ = lines[0].split()
122
+
path = urlparse(unquote(raw_path)).path
123
+
else:
124
+
path = "/"
125
+
except:
126
+
path = "/"
127
+
redirect = build_redirect_response(f"https://{HOST}{path}")
128
+
conn.sendall(redirect)
129
+
130
+
def http_redirect_server():
131
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
132
+
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
133
+
server.bind((HOST, PORT))
134
+
server.listen()
135
+
print(f"redirecting on {HOST}:{PORT}")
136
+
while True:
137
+
try:
138
+
conn, addr = server.accept()
139
+
except KeyboardInterrupt:
140
+
break
141
+
except:
142
+
continue
143
+
with conn:
144
+
try:
145
+
handle_http_redirect(conn)
146
+
except:
147
+
pass
148
+
149
+
150
+
def ssl_main():
98
151
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
99
152
ssl_ctx.load_cert_chain(certfile=SSL_CERT, keyfile=SSL_KEY)
100
153
154
+
#threading.Thread(target=http_redirect_server, daemon=True).start()
155
+
101
156
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
102
157
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
103
-
server.bind((HOST, PORT))
158
+
server.bind((HOST, SSL_PORT))
104
159
server.listen()
105
-
print(f"Serving HTTPS on {HOST}:{PORT}")
160
+
print(f"Serving HTTPS on {HOST}:{SSL_PORT}")
106
161
while True:
107
162
try:
108
163
conn, addr = server.accept()
109
-
if USE_SSL:
110
-
with ssl_ctx.wrap_socket(conn, server_side=True) as ssl_conn:
111
-
request = ssl_conn.recv(4096)
112
-
response = handle_request(request)
113
-
ssl_conn.sendall(response)
114
-
else:
115
-
request = conn.recv(4096)
164
+
conn.settimeout(10.0)
165
+
with ssl_ctx.wrap_socket(conn, server_side=True) as ssl_conn:
166
+
request = ssl_conn.recv(4096)
116
167
response = handle_request(request)
117
-
conn.sendall(response)
168
+
ssl_conn.sendall(response)
169
+
except KeyboardInterrupt:
170
+
raise
171
+
except TimeoutError:
172
+
continue
173
+
except Exception as e:
174
+
print(f"Error: {e}")
175
+
exit()
176
+
177
+
178
+
def main():
179
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
180
+
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
181
+
server.bind((HOST, PORT))
182
+
server.listen()
183
+
print(f"Serving HTTP on {HOST}:{PORT}")
184
+
while True:
185
+
try:
186
+
conn, addr = server.accept()
187
+
request = conn.recv(4096)
188
+
response = handle_request(request)
189
+
conn.sendall(response)
118
190
except KeyboardInterrupt:
119
191
raise
120
-
except:
192
+
except Exception as e:
193
+
print(e)
121
194
break
122
195
123
196
if __name__ == "__main__":
124
-
main()
197
+
if USE_SSL:
198
+
ssl_main()
199
+
else:
200
+
main()
+1
-1
webui/content/orb/projective_shift.orb
+1
-1
webui/content/orb/projective_shift.orb
+134
-14
webui/content/wgsl/proj_shift.wgsl
+134
-14
webui/content/wgsl/proj_shift.wgsl
···
10
10
max_iter: u32,
11
11
escape_distance: f32,
12
12
psi: f32,
13
+
d: f32,
14
+
twist: f32,
15
+
squoosh_x: f32,
16
+
squoosh_y: f32
13
17
}
14
18
15
19
@group(0) @binding(0) var<uniform> uniforms: Uniforms;
···
36
40
return vec2<f32>(new_mag * cos(x_angle * psi), new_mag * sin(x_angle * psi));
37
41
}
38
42
39
-
fn iterate_polar(x: vec2<f32>, phi: f32, psi: f32, c: f32) -> vec2<f32> {
43
+
fn iterate_polar(x: vec2<f32>, phi: f32, psi: f32, c: f32, d: f32, cosTwist: f32, sinTwist: f32, squoosh_x: f32, squoosh_y: f32) -> vec2<f32> {
40
44
let shifted = projective_shift(x, phi, psi);
41
-
return vec2<f32>(shifted.x - c, shifted.y);
45
+
let halfEbb = vec2<f32>(-c * 0.5, -d * 0.5);
46
+
let mid = shifted + halfEbb;
47
+
let mids = vec2<f32>(mid.x / squoosh_x, mid.y / squoosh_y);
48
+
let rotated = vec2<f32>(mids.x * cosTwist - mids.y * sinTwist, mids.x * sinTwist + mids.y * cosTwist);
49
+
return rotated + halfEbb;
42
50
}
43
51
44
52
fn iterate_cartesian(z: vec2<f32>, phi: f32, c: f32) -> vec2<f32> {
···
68
76
return vec2<f32>(x, y);
69
77
}
70
78
79
+
fn pixel_to_complex_alt(px: u32, py: u32) -> vec2<f32> {
80
+
let aspect = f32(uniforms.width) / f32(uniforms.height);
81
+
let scale = 4.0 / uniforms.zoom;
82
+
let half_width = f32(uniforms.width) * 0.5;
83
+
let half_height = f32(uniforms.height) * 0.5;
84
+
85
+
// map y-axis to magnitude, x-axis to angle
86
+
let angle = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x;
87
+
let mag = (f32(py) - half_height) * scale / (2.0 * f32(uniforms.height)) + uniforms.center_y;
88
+
89
+
return vec2<f32>(mag * cos(angle), mag * sin(angle));
90
+
}
91
+
92
+
fn pixel_to_complex_inverted(px: u32, py: u32) -> vec2<f32> {
93
+
let aspect = f32(uniforms.width) / f32(uniforms.height);
94
+
let scale = 4.0 / uniforms.zoom;
95
+
let half_width = f32(uniforms.width) * 0.5;
96
+
let half_height = f32(uniforms.height) * 0.5;
97
+
98
+
let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x;
99
+
let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y;
100
+
101
+
let z = vec2<f32>(x, y);
102
+
let mag_sq = x * x + y * y;
103
+
104
+
// handle singularity at origin
105
+
if (mag_sq < 1e-10) {
106
+
return vec2<f32>(1e10, 0.0); // or whatever you want infinity to map to
107
+
}
108
+
109
+
return vec2<f32>(x / mag_sq, -y / mag_sq);
110
+
}
111
+
112
+
fn pixel_to_complex_inverted_d(px: u32, py: u32) -> vec2<f32> {
113
+
let aspect = f32(uniforms.width) / f32(uniforms.height);
114
+
let scale = 4.0 / uniforms.zoom;
115
+
let half_width = f32(uniforms.width) * 0.5;
116
+
let half_height = f32(uniforms.height) * 0.5;
117
+
118
+
// get position relative to center
119
+
let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width);
120
+
let y = (f32(py) - half_height) * scale / f32(uniforms.height);
121
+
122
+
let mag_sq = x * x + y * y;
123
+
if (mag_sq < 1e-10) {
124
+
return vec2<f32>(uniforms.center_x + 1e10, uniforms.center_y);
125
+
}
126
+
127
+
// invert then add center back
128
+
return vec2<f32>(
129
+
uniforms.center_x + x / mag_sq,
130
+
uniforms.center_y - y / mag_sq
131
+
);
132
+
}
133
+
134
+
fn pixel_to_complex_inverted_b(px: u32, py: u32) -> vec2<f32> {
135
+
let aspect = f32(uniforms.width) / f32(uniforms.height);
136
+
let scale = 4.0 / uniforms.zoom;
137
+
let half_width = f32(uniforms.width) * 0.5;
138
+
let half_height = f32(uniforms.height) * 0.5;
139
+
140
+
let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x;
141
+
let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y;
142
+
143
+
// translate by (c/2, d/2)
144
+
let tx = x + uniforms.c * 0.5;
145
+
let ty = y + uniforms.d * 0.5;
146
+
147
+
let mag_sq = tx * tx + ty * ty;
148
+
149
+
if (mag_sq < 1e-10) {
150
+
return vec2<f32>(1e10, 0.0);
151
+
}
152
+
153
+
// invert, then translate back
154
+
return vec2<f32>(tx / mag_sq - uniforms.c * 0.5, -ty / mag_sq - uniforms.d * 0.5);
155
+
}
156
+
157
+
fn pixel_to_complex_inverted_c(px: u32, py: u32) -> vec2<f32> {
158
+
let aspect = f32(uniforms.width) / f32(uniforms.height);
159
+
let scale = 4.0 / uniforms.zoom;
160
+
let half_width = f32(uniforms.width) * 0.5;
161
+
let half_height = f32(uniforms.height) * 0.5;
162
+
163
+
let x = (f32(px) - half_width) * scale * aspect / f32(uniforms.width) + uniforms.center_x;
164
+
let y = (f32(py) - half_height) * scale / f32(uniforms.height) + uniforms.center_y;
165
+
166
+
// translate by (c/2, d/2), then scale down by 0.5
167
+
let tx = (x + uniforms.c * 0.5) * 2.0;
168
+
let ty = (y + uniforms.d * 0.5) * 2.0;
169
+
170
+
let mag_sq = tx * tx + ty * ty;
171
+
172
+
if (mag_sq < 1e-10) {
173
+
return vec2<f32>(1e10, 0.0);
174
+
}
175
+
176
+
// invert, scale up by 2x, then translate back
177
+
return vec2<f32>((tx / mag_sq) * 0.5 - uniforms.c * 0.5, (-ty / mag_sq) * 0.5 - uniforms.d * 0.5);
178
+
}
179
+
71
180
fn c_avg(prev: f32, val: f32, n: f32) -> f32 {
72
181
return prev + (val - prev) / n;
73
182
}
···
80
189
return x;
81
190
}
82
191
192
+
fn pcg_hash(x: u32) -> u32 {
193
+
let state = x * 747796405u + 2891336453u;
194
+
let word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
195
+
return (word >> 22u) ^ word;
196
+
}
197
+
83
198
fn random_float(seed: u32) -> f32 {
84
199
return f32(hash(seed)) / 4294967296.0; // 2^32
85
200
}
···
93
208
return;
94
209
}
95
210
96
-
var z = pixel_to_complex(px, py);
97
-
let n_perturbations = 10u;
211
+
var z = ${pixel_mapping}(px, py);
212
+
let orig_z = z;
213
+
let n_perturbations = 1u;
98
214
let escape_threshold = uniforms.escape_distance * uniforms.escape_distance;
99
215
100
216
var escaped = false;
101
217
var iter = 0u;
102
218
var cavg = f32(0.0);
103
-
let epsilon = 1.19e-4;
219
+
let epsilon = 1.19e-6;
220
+
221
+
let transient_skip = u32(f32(uniforms.max_iter) * f32(0.95));
104
222
105
-
let transient_skip = u32(f32(uniforms.max_iter) * f32(0.9));
223
+
let cosTwist = cos(uniforms.twist);
224
+
let sinTwist = sin(uniforms.twist);
106
225
107
226
for (iter = 0u; iter < uniforms.max_iter; iter = iter + 1u) {
108
227
let mag_sq = complex_mag_sq(z);
···
110
229
escaped = true;
111
230
//break;
112
231
}
113
-
var z_p: array<vec2<f32>, 4>;
114
-
var df_z_p: array<vec2<f32>, 4>;
115
-
var abs_df_z_p: array<f32, 4>;
232
+
var z_p: array<vec2<f32>, 10>;
233
+
var df_z_p: array<vec2<f32>, 10>;
234
+
var abs_df_z_p: array<f32, 10>;
116
235
var avg_abs_df_z_p = f32(0.0);
117
236
118
-
var f_z = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c);
237
+
var f_z = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c, uniforms.d, cosTwist, sinTwist, uniforms.squoosh_x, uniforms.squoosh_y);
119
238
120
239
if (iter >= transient_skip) {
121
240
if (cavg < -10.0) { continue; }
122
241
if (cavg > 10.0) { continue; }
123
242
for (var p = 0u; p < n_perturbations; p = p + 1u) {
124
-
let rng_seed = hash(px * 1000u + py * 2000u + p * 5000u);
243
+
let rng_seed = pcg_hash(px + pcg_hash(py << 1)) ^ pcg_hash(p);
125
244
let random_angle = random_float(rng_seed) * 6.283185307;
126
245
let perturb_direction = vec2<f32>(cos(random_angle), sin(random_angle));
127
246
z_p[p] = z + perturb_direction * epsilon;
128
-
df_z_p[p] = iterate_polar(z, uniforms.phi, uniforms.psi, uniforms.c) - f_z;
129
-
abs_df_z_p[p] = abs(complex_mag(df_z_p[p]) / epsilon);
247
+
df_z_p[p] = iterate_polar(z_p[p], uniforms.phi, uniforms.psi, uniforms.c, uniforms.d, cosTwist, sinTwist, uniforms.squoosh_x, uniforms.squoosh_y) - orig_z;
248
+
abs_df_z_p[p] = abs(complex_mag(df_z_p[p]));
130
249
avg_abs_df_z_p += abs_df_z_p[p];
131
250
}
132
251
···
136
255
}
137
256
138
257
z = f_z;
258
+
//z = avg_abs_df_z_p;
139
259
}
140
260
141
261
var color: vec4<f32>;
···
156
276
color = vec4<f32>(r, scaled_mag, b, 1.0);
157
277
}
158
278
159
-
color = vec4<f32>(cavg / 1.0, -cavg / 20.0, -cavg / 20.0, 1.0);
279
+
color = vec4<f32>(cavg / 10.0, -cavg / 5.0, -cavg / 5.0, 1.0);
160
280
161
281
textureStore(output_texture, vec2<i32>(i32(px), i32(py)), color);
162
282
}
+21
-6
webui/js/control/number.js
+21
-6
webui/js/control/number.js
···
136
136
const copyable_value = document.createElement("span");
137
137
copyable_value.innerText = `${spec.value};\n`;
138
138
copyable_value.classList = "copyable-value";
139
-
label_eq.appendChild(copyable_value);
140
139
141
140
label_eq.setAttribute("aria-hidden", true);
142
141
···
158
157
field.step = spec.step;
159
158
field.value = spec.value;
160
159
160
+
const play_button = $element("button");
161
+
play_button.innerText = "โถ/โธ";
162
+
// todo alt text
163
+
164
+
161
165
const set = (value) => {
162
166
slider.value = value;
163
167
field.value = value;
164
168
}
165
169
170
+
const reset_button = $element("button");
171
+
reset_button.innerText = "โณ";
172
+
reset_button.label = "reset";
173
+
reset_button.addEventListener("click", () => {
174
+
set(spec.value);
175
+
spec.onUpdate?.(spec.value, set);
176
+
});
177
+
178
+
166
179
slider.addEventListener("input", () => {
167
180
field.value = slider.value;
168
181
copyable_value.innerText = slider.value;
···
175
188
spec.onUpdate?.(field.value, set);
176
189
});
177
190
178
-
control.appendChild(label);
179
-
control.appendChild(label_eq);
180
-
control.appendChild(field);
181
-
control.appendChild(slider);
182
-
target.appendChild(control);
191
+
target.$with(
192
+
control.$with(
193
+
label, label_eq.$with(copyable_value), field,
194
+
play_button, reset_button,
195
+
slider
196
+
)
197
+
);
183
198
}
184
199
+6
-1
webui/js/control/panel.js
+6
-1
webui/js/control/panel.js
···
9
9
background: var(--main-background);
10
10
font-family: var(--main-font);
11
11
color: var(--main-solid);
12
+
overflow-y: scroll;
13
+
height: 100%;
14
+
width: 100%;
15
+
}
16
+
17
+
.control-panel > * {
12
18
max-width: 300px;
13
-
overflow-y: scroll;
14
19
}
15
20
16
21
.control-panel legend {
+83
-10
webui/js/gpu/proj_shift.js
+83
-10
webui/js/gpu/proj_shift.js
···
24
24
let phi = 0.0;
25
25
let psi = 1.0;
26
26
let c = 0.85;
27
+
let d = 0.0;
27
28
let iterations = 100;
28
29
let escape_distance = 2;
29
30
let centerX = -0.5;
30
31
let centerY = 0.0;
31
32
let zoom = 4.0;
33
+
let twist = 0.0;
34
+
let squoosh_x = 1.0;
35
+
let squoosh_y = 1.0;
32
36
33
37
let showTrajectory = false;
34
38
···
51
55
};
52
56
}
53
57
54
-
function iteratePolar(x, phi, psi, c) {
58
+
function iteratePolar(x, phi, psi, c, d) {
55
59
const shifted = projectiveShift(x, phi, psi);
56
-
return { x: shifted.x - c, y: shifted.y };
60
+
return { x: shifted.x - c, y: shifted.y - d };
57
61
}
58
62
59
63
function computeTrajectory(startZ, maxIters, escapeThreshold) {
···
65
69
if (magSq > escapeThreshold * escapeThreshold) {
66
70
break;
67
71
}
68
-
z = iteratePolar(z, phi, psi, c);
72
+
z = iteratePolar(z, phi, psi, c, d);
69
73
trajectory.push({ x: z.x, y: z.y });
70
74
}
71
75
···
98
102
},
99
103
{
100
104
type: "number",
105
+
label: "d",
106
+
value: d,
107
+
min: 0,
108
+
max: 1,
109
+
step: 0.001,
110
+
onUpdate: (value, set) => {
111
+
d = value;
112
+
render();
113
+
}
114
+
},
115
+
{
116
+
type: "number",
101
117
label: greek["psi"],
102
118
value: psi,
103
119
min: 0,
···
110
126
},
111
127
{
112
128
type: "number",
129
+
label: "twist",
130
+
value: twist,
131
+
min: -$tau/2,
132
+
max: $tau/2,
133
+
step: 0.001,
134
+
onUpdate: (value, set) => {
135
+
twist = value;
136
+
render();
137
+
}
138
+
},
139
+
{
140
+
type: "number",
141
+
label: "squoosh x",
142
+
value: squoosh_x,
143
+
min: 0,
144
+
max: 10,
145
+
step: 0.001,
146
+
onUpdate: (value, set) => {
147
+
squoosh_x = value;
148
+
render();
149
+
}
150
+
},
151
+
{
152
+
type: "number",
153
+
label: "squoosh y",
154
+
value: squoosh_y,
155
+
min: 0,
156
+
max: 10,
157
+
step: 0.001,
158
+
onUpdate: (value, set) => {
159
+
squoosh_y = value;
160
+
render();
161
+
}
162
+
},
163
+
{
164
+
type: "number",
113
165
label: greek["phi"],
114
166
value: phi,
115
167
min: 0,
···
151
203
]]);
152
204
153
205
const renderStack = $div("full");
206
+
renderStack.dataset.name = "renderer";
154
207
renderStack.style.position = "relative";
155
208
156
209
const gpuModule = await $mod("gpu/webgpu", renderStack);
···
162
215
canvas.setAttribute("role", "application");
163
216
canvas.setAttribute("aria-keyshortcuts", "f");
164
217
165
-
const compShader = await $gpu.loadShader("proj_shift");
218
+
const compShader = await $gpu.loadShader("proj_shift", { "pixel_mapping" : "pixel_to_complex" });
166
219
const blitShader = await $gpu.loadShader("blit");
167
220
168
221
if (!compShader || !blitShader) return;
···
205
258
renderStack.appendChild(overlay);
206
259
207
260
function showControls() {
208
-
if (controls.parentNode) return;
261
+
if (!topmost.isConnected) {
262
+
topmost = renderStack;
263
+
}
264
+
else if (topmost.querySelector(".control-panel")) return;
209
265
210
266
return ["show controls", async () => {
211
-
await $mod("layout/split", renderStack.parentNode, [{content: [controls, renderStack], percents: [20, 80]}]);
267
+
const split = await $mod("layout/split", renderStack.parentNode, [{content: [controls, renderStack], percents: [20, 80]}]);
268
+
topmost = split.topmost;
212
269
}];
213
270
}
214
271
···
217
274
return ["show trajectory", () => {showTrajectory = true}];
218
275
}
219
276
277
+
function exitRenderer() {
278
+
// relying on showControls' topmost check to have occurred before this can be called,
279
+
// which is true bc that happens while the context menu is built.
280
+
// this may not remain true if a hotkey is added for exiting w/o opening the context menu
281
+
const target = topmost.parentNode;
282
+
target.replaceChildren();
283
+
$mod("layout/nothing", target);
284
+
}
285
+
220
286
renderStack.$preventCollapse = true;
221
287
renderStack.$contextMenu = {
222
-
items: [showControls, toggleTrajectory]
288
+
items: [showControls, toggleTrajectory, ["exit", exitRenderer]]
223
289
};
224
290
225
-
await $mod("layout/split", target, [{ content: [controls, renderStack], percents: [20, 80]}]);
291
+
const split = await $mod("layout/split", target, [{ content: [controls, renderStack], percents: [20, 80]}]);
292
+
let topmost = split.topmost;
226
293
227
294
let width = canvas.clientWidth;
228
295
let height = canvas.clientHeight;
···
236
303
});
237
304
238
305
const uniformBuffer = $gpu.device.createBuffer({
239
-
size: 10 * 4,
306
+
size: 14 * 4,
240
307
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
241
308
});
242
309
···
307
374
308
375
function updateUniforms() {
309
376
// TODO: save the same buffer
310
-
const uniformData = new ArrayBuffer(10 * 4);
377
+
const uniformData = new ArrayBuffer(14 * 4);
311
378
const view = new DataView(uniformData);
312
379
313
380
/*
···
321
388
max_iter: u32,
322
389
escape_distance: f32,
323
390
psi: f32,
391
+
d: f32,
392
+
twist: f32,
324
393
*/
325
394
326
395
// TODO less brittle hard coded numbers jfc
···
334
403
view.setUint32(28, iterations, true);
335
404
view.setFloat32(32, escape_distance, true);
336
405
view.setFloat32(36, psi, true);
406
+
view.setFloat32(40, d, true);
407
+
view.setFloat32(44, twist, true);
408
+
view.setFloat32(48, squoosh_x, true);
409
+
view.setFloat32(52, squoosh_y, true);
337
410
338
411
$gpu.device.queue.writeBuffer(uniformBuffer, 0, uniformData);
339
412
}
+15
-3
webui/js/gpu/webgpu.js
+15
-3
webui/js/gpu/webgpu.js
···
31
31
const canvasFormat = navigator.gpu.getPreferredCanvasFormat();
32
32
33
33
// TODO make more robust
34
-
async function loadShader(shaderName) {
34
+
async function loadShader(shaderName, substitutions = {}) {
35
35
const response = await fetch(`./${shaderName}.wgsl`);
36
36
const shaderSource = await response.text();
37
-
const module = device.createShaderModule({ code: shaderSource });
37
+
38
+
var substitutionFailure = false;
39
+
const adjustedSource = shaderSource.replace(/\${(\w+)}/g, (match, key) => {
40
+
if (!(key in substitutions)) {
41
+
substitutionFailure = true;
42
+
return "";
43
+
}
44
+
else {
45
+
return substitutions[key];
46
+
}
47
+
});
48
+
49
+
const module = device.createShaderModule({ code: adjustedSource });
38
50
const info = await module.getCompilationInfo();
39
-
if (info.messages.some(m => m.type === "error")) {
51
+
if (info.messages.some(m => m.type === "error") || substitutionFailure) {
40
52
return null;
41
53
}
42
54
return module;
+2
-2
webui/js/layout/nothing.js
+2
-2
webui/js/layout/nothing.js
···
68
68
items: Object.entries(menuItems)
69
69
};
70
70
71
-
const menu = await $mod("control/menu", backdrop, [menuItems]);
71
+
//const menu = await $mod("control/menu", backdrop, [menuItems]);
72
72
73
73
backdrop.addEventListener("keydown", (e) => {
74
74
if (e.key === "Enter" || e.key === "o") {
···
78
78
const centerX = rect.left + rect.width / 2;
79
79
const centerY = rect.top + rect.height / 2;
80
80
81
-
menu.showMenu();
81
+
document.$showMenu(backdrop);
82
82
}
83
83
});
84
84
+17
-16
webui/js/layout/split.js
+17
-16
webui/js/layout/split.js
···
9
9
padding: 0rem;
10
10
}
11
11
12
-
[theme-changed] > .split {
12
+
[data-theme-changed] > .split {
13
13
padding: 0.5rem;
14
14
}
15
15
16
-
.split[orientation=row] {
16
+
.split[data-orientation=row] {
17
17
flex-direction: row;
18
18
}
19
19
20
-
.split[orientation=col] {
20
+
.split[data-orientation=col] {
21
21
flex-direction: column;
22
22
}
23
23
···
32
32
-webkit-user-drag: none;
33
33
}
34
34
35
-
.split[orientation=row] > .splitter {
35
+
.split[data-orientation=row] > .splitter {
36
36
width: 1px;
37
37
cursor: col-resize;
38
38
}
39
39
40
-
.split[orientation=col] > .splitter {
40
+
.split[data-orientation=col] > .splitter {
41
41
height: 1px;
42
42
cursor: row-resize;
43
43
}
···
49
49
pointer-events: auto;
50
50
}
51
51
52
-
.split[orientation=row] > .splitter::before {
52
+
.split[data-orientation=row] > .splitter::before {
53
53
top: 0;
54
54
left: calc(0px - var(--panel-margin));
55
55
width: calc(var(--panel-margin) * 2);
56
56
height: 100%;
57
57
}
58
58
59
-
.split[orientation=col] > .splitter::before {
59
+
.split[data-orientation=col] > .splitter::before {
60
60
left: 0;
61
61
top: calc(0px - var(--panel-margin));
62
62
height: calc(var(--panel-margin) * 2);
63
63
width: 100%;
64
64
}
65
65
66
-
.split[orientation=row] > :first-child {
66
+
.split[data-orientation=row] > :first-child {
67
67
margin-right: var(--panel-margin);
68
68
width: calc(var(--current-portion) - 0.5px - var(--panel-margin));
69
69
}
70
70
71
-
.split[orientation=col] > :first-child {
71
+
.split[data-orientation=col] > :first-child {
72
72
margin-bottom: var(--panel-margin);
73
73
height: calc(var(--current-portion) - 0.5px - var(--panel-margin));
74
74
}
75
75
76
-
.split[orientation=row] > :not(.splitter):not(:first-child):not(:last-child) {
76
+
.split[data-orientation=row] > :not(.splitter):not(:first-child):not(:last-child) {
77
77
margin-left: var(--panel-margin);
78
78
margin-right: var(--panel-margin);
79
79
width: calc(var(--current-portion) - 1px - 2 * var(--panel-margin));
80
80
}
81
81
82
-
.split[orientation=col] > :not(.splitter):not(:first-child):not(:last-child) {
82
+
.split[data-orientation=col] > :not(.splitter):not(:first-child):not(:last-child) {
83
83
margin-top: var(--panel-margin);
84
84
margin-bottom: var(--panel-margin);
85
85
height: calc(var(--current-portion) - 1px - 2 * var(--panel-margin));
86
86
}
87
87
88
-
.split[orientation=row] > :last-child {
88
+
.split[data-orientation=row] > :last-child {
89
89
margin-left: var(--panel-margin);
90
90
width: calc(var(--current-portion) - 0.5px - var(--panel-margin));
91
91
}
92
92
93
-
.split[orientation=col] > :last-child {
93
+
.split[data-orientation=col] > :last-child {
94
94
margin-top: var(--panel-margin);
95
95
height: calc(var(--current-portion) - 0.5px - var(--panel-margin));
96
96
}
···
146
146
var n = content.length;
147
147
148
148
const container = $div("split");
149
-
container.setAttribute("orientation", settings.orientation);
149
+
container.dataset.orientation = settings.orientation;
150
150
var row = settings.orientation === "row";
151
151
152
152
const orientationToggle = [row ? "row->col" : "col->row", () => {
153
153
row = !row;
154
154
settings.orientation = row ? "row" : "col";
155
-
container.setAttribute("orientation", settings.orientation);
155
+
container.dataset.orientation = settings.orientation;
156
156
orientationToggle[0] = row ? "row->col" : "col->row";
157
157
}];
158
158
···
347
347
};
348
348
349
349
return {
350
-
replace: true
350
+
replace: true,
351
+
topmost: container
351
352
};
352
353
}
353
354
+35
webui/js/main.js
+35
webui/js/main.js
···
69
69
enumerable: false
70
70
});
71
71
72
+
73
+
Object.defineProperty(Element.prototype, "$attrs", {
74
+
get() {
75
+
if (this.__attrsProxy) return this.__attrsProxy;
76
+
77
+
const el = this;
78
+
this.__attrsProxy = new Proxy({}, {
79
+
get(_, prop) {
80
+
return el.getAttribute(prop.toString());
81
+
},
82
+
set(_, prop, val) {
83
+
el.setAttribute(prop.toString(), val);
84
+
return true;
85
+
},
86
+
deleteProperty(_, prop) {
87
+
el.removeAttribute(prop.toString());
88
+
return true;
89
+
},
90
+
has(_, prop) {
91
+
return el.hasAttribute(prop.toString());
92
+
},
93
+
ownKeys() {
94
+
return Array.from(el.attributes).map(a => a.name);
95
+
},
96
+
getOwnPropertyDescriptor(_, prop) {
97
+
if (el.hasAttribute(prop.toString())) {
98
+
return { configurable: true, enumerable: true, value: el.getAttribute(prop.toString()) };
99
+
}
100
+
return undefined;
101
+
}
102
+
});
103
+
},
104
+
enumerable: false
105
+
});
106
+
72
107
window.$element = (name) => document.createElement(name);
73
108
window.$div = function (classList = "") {
74
109
const div = $element("div");
+7
-9
webui/js/theme.js
+7
-9
webui/js/theme.js
···
2
2
function checkForParentTheme(element, theme) {
3
3
let parent = element.parentElement;
4
4
while (parent) {
5
-
//if (parent.classList.contains("target")) {
6
-
const parentTheme = parent.getAttribute("theme");
7
-
//if (!parentTheme) return false;
8
-
if (parentTheme)
9
-
return parentTheme !== theme;
10
-
//}
5
+
const parentTheme = parent.dataset.theme;
6
+
7
+
if (parentTheme) return parentTheme !== theme;
8
+
11
9
parent = parent.parentElement;
12
10
}
13
11
···
17
15
export function main(target, initialTheme = null) {
18
16
const storedTheme = localStorage.getItem("theme");
19
17
let theme = initialTheme || storedTheme || "blackboard";
20
-
target.setAttribute("theme", theme);
18
+
target.dataset.theme = theme;
21
19
22
20
if (target === document.body) {
23
21
localStorage.setItem("theme", theme);
24
22
}
25
23
26
24
if (checkForParentTheme(target, theme)) {
27
-
target.setAttribute("theme-changed", "");
25
+
target.dataset.themeChanged = "";
28
26
} else {
29
-
target.removeAttribute("theme-changed");
27
+
delete target.dataset.themeChanged;
30
28
}
31
29
32
30
return { replace: false };
+2
-2
webui/style/theme/themes.css
+2
-2
webui/style/theme/themes.css
···
1
1
2
-
*[theme=whiteboard] {
2
+
*[data-theme=whiteboard] {
3
3
--main-background: #EEDDEE;
4
4
--main-solid: #112233;
5
5
--main-transparent: #112233DD;
···
42
42
--token-css-delimiter: var(--main-solid);
43
43
}
44
44
45
-
*[theme=blackboard] {
45
+
*[data-theme=blackboard] {
46
46
--main-background: #112233;
47
47
--main-solid: #EEDDEE;
48
48
--main-transparent: #EEDDEEDD;