{
  lib,
  stdenv,
  config,

  # builders and fetchers
  buildPythonPackage,
  fetchFromGitHub,

  # feature toggles, defaulting to the nixpkgs-wide `config` settings
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,

  # patching helpers and native libraries whose paths are written into the sources
  replaceVars,
  addDriverRunpath,
  cudaPackages,
  llvmPackages,
  ocl-icd,
  rocmPackages,

  # build-system
  setuptools,

  # optional-dependencies
  llvmlite,
  triton,
  unicorn,

  # check phase: hooks and test-only Python packages
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  blobfile,
  bottle,
  capstone,
  clang,
  hexdump,
  hypothesis,
  jax,
  librosa,
  ml-dtypes,
  networkx,
  numpy,
  onnx,
  onnxruntime,
  pillow,
  pytest-xdist,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,
  z3-solver,

  # passthru: this package itself, overridden to build the CUDA variant in passthru.tests
  tinygrad,
}:
56
buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.10.3";

  # Build through buildPythonPackage's pyproject (PEP 517) code path.
  pyproject = true;

  # Upstream sources, pinned to the release tag for `version`.
  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    # Upstream tags releases as the version prefixed with "v".
    tag = "v" + version;
    hash = "sha256-IQ0EAjj8kYUwzvMsAiNnvRm/twC40r9JWXUocaETjC8=";
  };
68
69 patches = [
70 (replaceVars ./fix-dlopen-cuda.patch {
71 inherit (addDriverRunpath) driverLink;
72 libnvrtc =
73 if cudaSupport then
74 "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
75 else
76 "Please import nixpkgs with `config.cudaSupport = true`";
77 })
78 ];
79
80 postPatch =
81 # Patch `clang` directly in the source file
82 # Use the unwrapped variant to enable the "native" features currently unavailable in the sandbox
83 ''
84 substituteInPlace tinygrad/runtime/ops_cpu.py \
85 --replace-fail "getenv(\"CC\", 'clang')" "'${lib.getExe llvmPackages.clang-unwrapped}'"
86 ''
87 + ''
88 substituteInPlace tinygrad/runtime/autogen/libc.py \
89 --replace-fail "ctypes.util.find_library('c')" "'${stdenv.cc.libc}/lib/libc.so.6'"
90 ''
91 + ''
92 substituteInPlace tinygrad/runtime/support/llvm.py \
93 --replace-fail "ctypes.util.find_library('LLVM')" "'${lib.getLib llvmPackages.llvm}/lib/libLLVM.so'"
94 ''
95 + lib.optionalString stdenv.hostPlatform.isLinux ''
96 substituteInPlace tinygrad/runtime/autogen/opencl.py \
97 --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
98 ''
99 # test/test_tensor.py imports the PTX variable from the cuda_compiler.py file.
100 # This import leads to loading the libnvrtc.so library that is not substituted when cudaSupport = false.
101 # -> As a fix, we hardcode this variable to False
102 + lib.optionalString (!cudaSupport) ''
103 substituteInPlace test/test_tensor.py \
104 --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
105 ''
106 # `cuda_fp16.h` and co. are needed at runtime to compile kernels
107 + lib.optionalString cudaSupport ''
108 substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
109 --replace-fail \
110 ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \
111 ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
112 ''
113 + lib.optionalString rocmSupport ''
114 substituteInPlace tinygrad/runtime/autogen/hip.py \
115 --replace-fail "/opt/rocm/" "${rocmPackages.clr}/"
116
117 substituteInPlace tinygrad/runtime/support/compiler_hip.py \
118 --replace-fail "/opt/rocm/include" "${rocmPackages.clr}/include"
119
120 substituteInPlace tinygrad/runtime/support/compiler_hip.py \
121 --replace-fail "/opt/rocm/llvm" "${rocmPackages.llvm.llvm}"
122
123 substituteInPlace tinygrad/runtime/autogen/comgr.py \
124 --replace-fail "/opt/rocm/" "${rocmPackages.rocm-comgr}/"
125 '';
126
127 build-system = [ setuptools ];
128
129 optional-dependencies = {
130 llvm = [ llvmlite ];
131 arm = [ unicorn ];
132 triton = [ triton ];
133 };
134
135 pythonImportsCheck =
136 [
137 "tinygrad"
138 ]
139 ++ lib.optionals cudaSupport [
140 "tinygrad.runtime.ops_nv"
141 ];
142
143 nativeCheckInputs = [
144 pytestCheckHook
145 writableTmpDirAsHomeHook
146
147 blobfile
148 bottle
149 capstone
150 clang
151 hexdump
152 hypothesis
153 jax
154 librosa
155 ml-dtypes
156 networkx
157 numpy
158 onnx
159 onnxruntime
160 pillow
161 pytest-xdist
162 safetensors
163 sentencepiece
164 tiktoken
165 torch
166 tqdm
167 transformers
168 z3-solver
169 ] ++ networkx.optional-dependencies.extra;
170
171 disabledTests =
172 [
173 # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
174 "test_backward_sum_acc_dtype"
175 "test_failure_27"
176
177 # Flaky:
178 # AssertionError: 2.1376906810000946 not less than 2.0
179 "test_recursive_pad"
180
181 # Since updated onnx to 1.18.0:
182 # onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ...
183 # Unsupported model IR version: 11, max supported IR version: 10
184 "test_quant_128"
185
186 # Require internet access
187 "test_benchmark_openpilot_model"
188 "test_bn_alone"
189 "test_bn_linear"
190 "test_bn_mnist"
191 "test_car"
192 "test_chicken"
193 "test_chicken_bigbatch"
194 "test_conv_mnist"
195 "testCopySHMtoDefault"
196 "test_data_parallel_resnet"
197 "test_e2e_big"
198 "test_fetch_small"
199 "test_huggingface_enet_safetensors"
200 "test_index_mnist"
201 "test_linear_mnist"
202 "test_load_convnext"
203 "test_load_enet"
204 "test_load_enet_alt"
205 "test_load_llama2bfloat"
206 "test_load_resnet"
207 "test_mnist_val"
208 "test_openpilot_model"
209 "test_resnet"
210 "test_shufflenet"
211 "test_transcribe_batch12"
212 "test_transcribe_batch21"
213 "test_transcribe_file1"
214 "test_transcribe_file2"
215 "test_transcribe_long"
216 "test_transcribe_long_no_batch"
217 "test_vgg7"
218 ]
219 ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
220 # Fail with AssertionError
221 "test_casts_from"
222 "test_casts_to"
223 "test_int8"
224 "test_int8_to_uint16_negative"
225 ];
226
227 disabledTestPaths = [
228 # Require internet access
229 "test/models/test_mnist.py"
230 "test/models/test_real_world.py"
231 "test/testextra/test_lr_scheduler.py"
232
233 # Files under this directory are not considered as tests by upstream and should be skipped
234 "extra/"
235 ];
236
237 passthru.tests = {
238 withCuda = tinygrad.override { cudaSupport = true; };
239 };
240
241 meta = {
242 description = "Simple and powerful neural network framework";
243 homepage = "https://github.com/tinygrad/tinygrad";
244 changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
245 license = lib.licenses.mit;
246 maintainers = with lib.maintainers; [ GaetanLepage ];
247 badPlatforms = [
248 # Fatal Python error: Aborted
249 # onnxruntime/capi/_pybind_state.py", line 32 in <module>
250 "aarch64-linux"
251
252 # Tests segfault on darwin
253 lib.systems.inspect.patterns.isDarwin
254 ];
255 };
256}