🐍🐍🐍
1import torch
2import numpy as np
3from PIL import Image
4from safetensors.torch import save_file as sft_save
5import io
6
def badfunc():
    """Always raise ``ZeroDivisionError`` by dividing one by zero."""
    numerator, denominator = 1, 0
    return numerator / denominator
9
def lerp(a, b, t):
    """Linearly interpolate between ``a`` and ``b``.

    Returns ``a`` when ``t == 0`` and ``b`` when ``t == 1``; ``t`` is not
    clamped, so values outside that range extrapolate.
    """
    weight_a = 1 - t
    return weight_a * a + t * b
12
def timed(f):
    """Decorate ``f`` so each call prints how long it took.

    The elapsed time is measured with ``time.perf_counter`` and printed as
    ``"<function name>: <seconds>s"`` after ``f`` returns. If ``f`` raises,
    the exception propagates and nothing is printed.

    Args:
        f: the callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``f`` and returns whatever
        ``f`` returns.
    """
    import functools
    import time

    # wraps() keeps f's __name__/__doc__ on the wrapper so stacked decorators
    # and debuggers still see the original function's identity.
    @functools.wraps(f)
    def _f(*args, **kwargs):
        t0 = time.perf_counter()
        # Keep the result: the wrapper must be transparent to callers.
        result = f(*args, **kwargs)
        print(f"{f.__name__}: {time.perf_counter() - t0}s")
        return result

    return _f
20
def ifmain(name, provide_arg=None):
    """Build a decorator that runs the decorated function when ``name`` is ``"__main__"``.

    Args:
        name: the module name to test, typically the caller's ``__name__``.
        provide_arg: when not ``None``, passed as the single argument to the
            decorated function; otherwise the function is called with no
            arguments.

    Returns:
        A one-argument callable. When ``name == "__main__"`` it calls the
        function immediately and returns the call's result (so the decorated
        name ends up bound to that result). Otherwise it returns the function
        unchanged.
    """
    if name != "__main__":
        def keep(f):
            return f
        return keep

    if provide_arg is not None:
        def run_with_arg(f):
            return f(provide_arg)
        return run_with_arg

    def run(f):
        return f()
    return run
27
def cpilify(z):
    """Convert a 2-D complex tensor into a three-channel PIL image.

    The real part goes to the first channel, the imaginary part to the
    second, and the third channel is all zeros. Each component ``v`` is
    mapped to ``v / 2 + 0.5`` (so -1..1 becomes 0..1), clamped to 0..1 and
    scaled to 8-bit.
    """
    real_part = z.real
    # Stack to (3, H, W): real, imaginary, empty.
    channels = torch.stack([real_part, z.imag, torch.zeros_like(real_part)])
    unit_range = (channels / 2 + 0.5).clamp(0, 1)
    # Channels-last layout is what Image.fromarray receives.
    hwc = unit_range.detach().cpu().permute(1, 2, 0).numpy()
    pixels = (hwc * 255).round().astype("uint8")
    return Image.fromarray(pixels)
34
35# complex, from -1 to 1 & -i to i
def csave(x, f):
    """Render the complex tensor ``x`` with ``cpilify`` and write it to ``out/{f}.png``.

    The ``out`` directory is not created here.
    """
    image = cpilify(x)
    image.save(f"out/{f}.png")
38
39# monochrome, 0 to 1
def mpilify_cpu(z):
    """Convert a 2-D monochrome tensor (values 0..1) into a PIL image, working on the CPU.

    The tensor is moved to the CPU first, then clamped to 0..1, scaled to
    0..255, rounded, and replicated into three identical channels.

    Args:
        z: 2-D tensor of intensities; values outside 0..1 are clamped.

    Returns:
        A PIL image built from the resulting (H, W, 3) uint8 array.
    """
    # Tensor.cpu() returns the *same* tensor when z already lives on the CPU,
    # so the first operation must be out-of-place; otherwise the caller's data
    # would be clamped and multiplied by 255 behind their back. detach() lets
    # tensors that require grad be converted as well.
    scaled = z.detach().cpu().clamp(0, 1).mul_(255).round()
    # Replicate the single channel three times, channels-last.
    z_np = scaled.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8).numpy()
    return Image.fromarray(z_np)
44
def mpilify(z):
    """Convert a 2-D monochrome tensor (values 0..1) into a PIL image.

    Works on a clone of ``z`` on its current device, so the input is left
    untouched; only the final uint8 data is moved to the CPU. The single
    channel is replicated into three identical channels.
    """
    scaled = torch.clone(z)
    scaled.clamp_(0, 1)
    scaled.mul_(255)
    rounded = scaled.round()
    # (H, W) -> (H, W, 3) by repeating the intensity along a new last axis.
    three_channel = rounded.unsqueeze(2).expand(-1, -1, 3)
    pixels = three_channel.type(torch.uint8).cpu().numpy()
    return Image.fromarray(pixels)
49
def mstreamify(z):
    """Convert a 2-D monochrome tensor (values 0..1) into raw 8-bit pixel bytes.

    Values are clamped to 0..1, scaled to 0..255 and rounded; each pixel is
    emitted as three identical bytes, row by row. The input is cloned first
    and therefore not modified.
    """
    scaled = torch.clone(z)
    scaled.clamp_(0, 1)
    scaled.mul_(255)
    # (H, W) -> (H, W, 3) by repeating the intensity along a new last axis.
    three_channel = scaled.round().unsqueeze(2).expand(-1, -1, 3)
    pixels = three_channel.type(torch.uint8).cpu().numpy()
    return pixels.tobytes()
52
def msave_cpu(x, f):
    """Render the monochrome tensor ``x`` with ``mpilify_cpu`` and write it to ``out/{f}.png``.

    The ``out`` directory is not created here.
    """
    image = mpilify_cpu(x)
    image.save(f"out/{f}.png")
55
def msave(x, f):
    """Render the monochrome tensor ``x`` with ``mpilify`` and write it to ``out/{f}.png``.

    The ``out`` directory is not created here.
    """
    image = mpilify(x)
    image.save(f"out/{f}.png")
58
def msave_alt(x, f):
    """Encode the monochrome tensor ``x`` as PNG in memory and return the bytes.

    Nothing is written to disk.

    Args:
        x: 2-D monochrome tensor, rendered through ``mpilify``.
        f: accepted but unused; the parameter mirrors ``msave``'s signature.

    Returns:
        The PNG-encoded image as ``bytes``.
    """
    with io.BytesIO() as buffer:
        mpilify(x).save(buffer, format="png")
        # Read the data before the buffer is closed on leaving the block, and
        # hand it back instead of discarding it.
        return buffer.getvalue()
67
68# 3 channels
def pilify(z):
    """Convert a channels-first 3-D tensor with values in 0..1 into a PIL image.

    Values are clamped to 0..1, the first axis is moved last, and the result
    is scaled to 0..255, rounded and cast to uint8.
    """
    clamped = z.clamp(0, 1)
    # Channels-last layout is what Image.fromarray receives.
    hwc = clamped.detach().cpu().permute(1, 2, 0).numpy()
    scaled = hwc * 255
    pixels = scaled.round().astype("uint8")
    return Image.fromarray(pixels)
74
def load_image_tensor(path):
    """Load an image file into a channels-first float32 tensor scaled by 1/255.

    The decoded pixel array must be 3-D, since its last axis is moved to the
    front. NOTE(review): single-channel images presumably decode to a 2-D
    array and would fail at the permute - confirm against expected inputs.
    """
    with Image.open(path) as source:
        # np.array copies the pixel data, so it stays valid once the file closes.
        raw = np.array(source)
    scaled = raw.astype(np.float32) / 255.0
    return torch.from_numpy(scaled).permute(2, 0, 1)
79
def save(x, f):
    """Render the 3-channel tensor ``x`` with ``pilify`` and write it to ``out/{f}.png``.

    The ``out`` directory is not created here.
    """
    image = pilify(x)
    image.save(f"out/{f}.png")
82
def streamify(z):
    """Convert a channels-first 3-D tensor with values in 0..1 into raw 8-bit bytes.

    Values are clamped to 0..1, the first axis is moved last, and the result
    is scaled to 0..255, rounded and cast to uint8. The bytes are emitted in
    row-major order of the channels-last array.
    """
    hwc = z.clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
    scaled = hwc * 255
    pixels = scaled.round().astype("uint8")
    return pixels.tobytes()
87
88# grid of complex numbers
def cgrid_legacy(h, w, center, span, ctype=torch.cdouble, dtype=torch.double, **_):
    """Build an ``h`` x ``w`` grid of complex numbers covering a rectangle.

    The rectangle runs from ``center - span / 2`` to ``center + span / 2``.
    The real part varies along the columns and the imaginary part along the
    rows, with row 0 holding the lowest imaginary value.

    Args:
        h: number of rows (samples along the imaginary axis).
        w: number of columns (samples along the real axis).
        center: complex centre of the rectangle.
        span: complex extent; its real part is the width, its imaginary part
            the height.
        ctype: complex dtype of the returned grid.
        dtype: real dtype used for the two ``linspace`` axes.
        **_: extra keyword arguments are accepted and ignored.

    Returns:
        A tensor of shape ``(h, w)`` and dtype ``ctype``.
    """
    low = center - span / 2
    hi = center + span / 2

    yspace = torch.linspace(low.imag, hi.imag, h, dtype=dtype)
    xspace = torch.linspace(low.real, hi.real, w, dtype=dtype)

    g = torch.zeros([h, w], dtype=ctype)
    # Two broadcast in-place adds replace the per-row and per-column Python
    # loops; each element sees the same arithmetic in the same order.
    g += xspace.unsqueeze(0)        # (1, w): same real parts on every row
    g += yspace.unsqueeze(1) * 1j   # (h, 1): same imaginary part across a row
    return g
104
105
106# result, iterations; iterations == -1 if no convergence before limit
def gauss_seidel(a, b, itlim=1000, rtol=1e-8):
    """Solve ``a @ x = b`` iteratively with Gauss-Seidel sweeps.

    Starts from the zero vector. Each sweep updates the components in order,
    using already-updated components for indices below ``i`` and the previous
    sweep's values for indices above ``i``. Iteration stops as soon as two
    consecutive sweeps agree under ``torch.allclose``.

    Args:
        a: 2-D coefficient tensor, expected to be square with non-zero
            diagonal entries (each update divides by ``a[i, i]``).
        b: 1-D right-hand-side tensor.
        itlim: exclusive bound on the sweep counter; at most ``itlim - 1``
            sweeps are performed.
        rtol: relative tolerance for the convergence check (the absolute
            tolerance is ``torch.allclose``'s default).

    Returns:
        ``(x, iterations)``: the solution estimate and the 1-based number of
        the sweep at which convergence was detected. ``iterations`` is ``-1``
        when no convergence was detected before the limit; ``x`` is then the
        last estimate.
    """
    x = torch.zeros_like(b)
    for it in range(1, itlim):
        xn = torch.zeros_like(x)
        for i in range(a.shape[0]):
            # Components before i come from this sweep, those after i from
            # the previous one.
            s1 = a[i, :i].dot(xn[:i])
            s2 = a[i, i+1:].dot(x[i+1:])
            xn[i] = (b[i] - s1 - s2) / a[i, i]
        if torch.allclose(x, xn, rtol=rtol):
            return xn, it
        x = xn
    return x, -1
120
121