🐍🐍🐍
at dev 3.4 kB view raw
1import torch 2import numpy as np 3from PIL import Image 4from safetensors.torch import save_file as sft_save 5import io 6 7def badfunc(): 8 return 1 / 0 9 10def lerp(a, b, t): 11 return (1-t)*a + t*b 12 13def timed(f): 14 import time 15 def _f(*args, **kwargs): 16 t0 = time.perf_counter() 17 f(*args, **kwargs) 18 print(f"{f.__name__}: {time.perf_counter() - t0}s") 19 return _f 20 21def ifmain(name, provide_arg=None): 22 if name == "__main__": 23 if provide_arg is None: 24 return lambda f: f() 25 return lambda f: f(provide_arg) 26 return lambda f: f 27 28def cpilify(z): 29 z_3d = torch.stack([z.real, z.imag, torch.zeros_like(z.real)]) 30 z_norm = (z_3d / 2 + 0.5).clamp(0, 1) 31 z_np = z_norm.detach().cpu().permute(1, 2, 0).numpy() 32 z_bytes = (z_np * 255).round().astype("uint8") 33 return Image.fromarray(z_bytes) 34 35# complex, from -1 to 1 & -i to i 36def csave(x, f): 37 cpilify(x).save(f"out/{f}.png") 38 39# monochrome, 0 to 1 40def mpilify_cpu(z): 41 _z = z.cpu().clamp_(0,1).mul_(255).round() 42 z_np = _z.unsqueeze(2).expand(-1,-1,3).type(torch.uint8).numpy() 43 return Image.fromarray(z_np) 44 45def mpilify(z): 46 _z = torch.clone(z).clamp_(0,1).mul_(255).round() 47 z_np = _z.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8).cpu().numpy() 48 return Image.fromarray(z_np) 49 50def mstreamify(z): 51 return torch.clone(z).clamp_(0,1).mul_(255).round().unsqueeze(2).expand(-1,-1,3).type(torch.uint8).cpu().numpy().tobytes() 52 53def msave_cpu(x, f): 54 mpilify_cpu(x).save(f"out/{f}.png") 55 56def msave(x, f): 57 mpilify(x).save(f"out/{f}.png") 58 59def msave_alt(x, f): 60 with io.BytesIO() as buffer: 61 mpilify(x).save(buffer, format="png") 62 buffer.getvalue() 63 #_z = torch.clone(x).clamp_(0,1).mul_(255).round() 64 #z_np = _z.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8) 65 #sft_save({"":_z.type(torch.uint8)}, f"out/{f}.mono.sft") 66 #torch.save(z_np, "out/{f}.pt") 67 68# 3 channels 69def pilify(z): 70 z_norm = z.clamp(0, 1) 71 z_np = 
z_norm.detach().cpu().permute(1, 2, 0).numpy() 72 z_bytes = (z_np * 255).round().astype("uint8") 73 return Image.fromarray(z_bytes) 74 75def load_image_tensor(path): 76 with Image.open(path) as pil_image: 77 np_image = np.array(pil_image).astype(np.float32) / 255.0 78 return torch.from_numpy(np_image).permute(2,0,1) 79 80def save(x, f): 81 pilify(x).save(f"out/{f}.png") 82 83def streamify(z): 84 z_norm = z.clamp(0, 1) 85 z_np = z_norm.detach().cpu().permute(1, 2, 0).numpy() 86 return (z_np * 255).round().astype("uint8").tobytes() 87 88# grid of complex numbers 89def cgrid_legacy(h,w,center,span,ctype=torch.cdouble,dtype=torch.double,**_): 90 g = torch.zeros([h, w], dtype=ctype) 91 92 low = center - span / 2 93 hi = center + span / 2 94 95 yspace = torch.linspace(low.imag, hi.imag, h, dtype=dtype) 96 xspace = torch.linspace(low.real, hi.real, w, dtype=dtype) 97 98 for _x in range(h): 99 g[_x] += xspace 100 for _y in range(w): 101 g[:, _y] += yspace * 1j 102 103 return g 104 105 106# result, iterations; iterations == -1 if no convergence before limit 107def gauss_seidel(a, b): 108 x = torch.zeros_like(b) 109 itlim = 1000 110 for it in range(1, itlim): 111 xn = torch.zeros_like(x) 112 for i in range(a.shape[0]): 113 s1 = a[i, :i].dot(xn[:i]) 114 s2 = a[i, i+1:].dot(x[i+1:]) 115 xn[i] = (b[i] - s1 - s2) / a[i, i] 116 if torch.allclose(x, xn, rtol=1e-8): 117 return xn, it 118 x = xn 119 return x, -1 120 121