1{
2 lib,
3 stdenv,
4 config,
5 buildPythonPackage,
6 fetchFromGitHub,
7
8 # patches
9 replaceVars,
10 addDriverRunpath,
11 cudaPackages,
12 llvmPackages,
13 ocl-icd,
14 rocmPackages,
15
16 # build-system
17 setuptools,
18
19 # optional-dependencies
20 llvmlite,
21 triton,
22 unicorn,
23
24 # tests
25 pytestCheckHook,
26 writableTmpDirAsHomeHook,
27 blobfile,
28 bottle,
29 capstone,
30 clang,
31 hexdump,
32 hypothesis,
33 jax,
34 librosa,
35 networkx,
36 numpy,
37 onnx,
38 onnxruntime,
39 pillow,
40 pytest-xdist,
41 safetensors,
42 sentencepiece,
43 tiktoken,
44 torch,
45 tqdm,
46 transformers,
47
48 # passthru
49 tinygrad,
50
51 cudaSupport ? config.cudaSupport,
52 rocmSupport ? config.rocmSupport,
53}:
54
55buildPythonPackage rec {
# Package identity. `version` is also interpolated into the git tag fetched below.
pname = "tinygrad";
version = "0.10.2";

# Build through the pyproject-based code path of `buildPythonPackage`.
pyproject = true;

# Upstream sources from GitHub, pinned to the `v<version>` release tag.
src = fetchFromGitHub {
  repo = "tinygrad";
  owner = "tinygrad";
  tag = "v${version}";
  hash = "sha256-BXQMacp6QjlgsVwhp2pxEZkRylZfKQhqIh92/0dPlfg=";
};
66
patches =
  let
    # Value handed to the patch for its `libnvrtc` variable.
    # When CUDA support is disabled there is no library to point at, so the variable is
    # filled with a hint for the user instead. `cudaPackages` is only evaluated in the
    # CUDA branch.
    nvrtcLibrary =
      if !cudaSupport then
        "Please import nixpkgs with `config.cudaSupport = true`"
      else
        "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so";
  in
  [
    # Fill in the `driverLink` and `libnvrtc` variables of the CUDA dlopen patch.
    (replaceVars ./fix-dlopen-cuda.patch {
      driverLink = addDriverRunpath.driverLink;
      libnvrtc = nvrtcLibrary;
    })
  ];
77
# Rewrite the places where tinygrad looks up tools and libraries at runtime so that they
# refer to fixed store paths. The script is assembled from an unconditional part followed
# by platform- and feature-specific parts; every substitution uses `--replace-fail` so the
# build breaks loudly if upstream changes the patched lines.
postPatch = lib.concatStrings [
  # Unconditional substitutions:
  # - `clang` is patched directly in the source file. The unwrapped variant is used to
  #   enable the "native" features currently unavailable in the sandbox.
  # - The `ctypes.util.find_library` lookups for libc and libLLVM become absolute paths.
  ''
    substituteInPlace tinygrad/runtime/ops_cpu.py \
      --replace-fail "getenv(\"CC\", 'clang')" "'${lib.getExe llvmPackages.clang-unwrapped}'"
    substituteInPlace tinygrad/runtime/autogen/libc.py \
      --replace-fail "ctypes.util.find_library('c')" "'${stdenv.cc.libc}/lib/libc.so.6'"
    substituteInPlace tinygrad/runtime/support/llvm.py \
      --replace-fail "ctypes.util.find_library('LLVM')" "'${lib.getLib llvmPackages.llvm}/lib/libLLVM.so'"
  ''

  # The OpenCL loader path is only substituted on Linux.
  (lib.optionalString stdenv.hostPlatform.isLinux ''
    substituteInPlace tinygrad/runtime/autogen/opencl.py \
      --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
  '')

  # test/test_tensor.py imports the PTX variable from tinygrad/runtime/support/compiler_cuda.py.
  # That import loads libnvrtc.so, whose path is not substituted when cudaSupport = false.
  # -> As a fix, the variable is hardcoded to False in the test file.
  (lib.optionalString (!cudaSupport) ''
    substituteInPlace test/test_tensor.py \
      --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
  '')

  # `cuda_fp16.h` and co. are needed at runtime to compile kernels
  (lib.optionalString cudaSupport ''
    substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
      --replace-fail \
      ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \
      ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
  '')

  # Replace the hard-coded /opt/rocm prefixes with the matching ROCm store paths.
  # Both compiler_hip.py replacements are passed to a single invocation, in the same order.
  (lib.optionalString rocmSupport ''
    substituteInPlace tinygrad/runtime/autogen/hip.py \
      --replace-fail "/opt/rocm/" "${rocmPackages.clr}/"

    substituteInPlace tinygrad/runtime/support/compiler_hip.py \
      --replace-fail "/opt/rocm/include" "${rocmPackages.clr}/include" \
      --replace-fail "/opt/rocm/llvm" "${rocmPackages.llvm.llvm}"

    substituteInPlace tinygrad/runtime/autogen/comgr.py \
      --replace-fail "/opt/rocm/" "${rocmPackages.rocm-comgr}/"
  '')
];
124
# Build backend required to build the package.
build-system = [ setuptools ];

# Optional dependency groups exposed by this package.
optional-dependencies.llvm = [ llvmlite ];
optional-dependencies.arm = [ unicorn ];
optional-dependencies.triton = [ triton ];
132
# Modules whose import is checked after the build. The top-level package is always
# checked; the NV runtime module is added only when `cudaSupport` is enabled.
pythonImportsCheck = [ "tinygrad" ] ++ lib.optional cudaSupport "tinygrad.runtime.ops_nv";
140
nativeCheckInputs =
  # Hooks driving the check phase
  [
    pytestCheckHook
    writableTmpDirAsHomeHook
  ]
  # Packages made available to the test suite
  ++ [
    blobfile
    bottle
    capstone
    clang
    hexdump
    hypothesis
    jax
    librosa
    networkx
    numpy
    onnx
    onnxruntime
    pillow
    pytest-xdist
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
  ]
  # Everything from the `extra` optional-dependency group of networkx
  ++ networkx.optional-dependencies.extra;
166
# Tests excluded from the check phase, grouped by the reason they are skipped.
disabledTests =
  let
    # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
    relocationFailures = [
      "test_backward_sum_acc_dtype"
      "test_failure_27"
    ];

    # Flaky:
    # AssertionError: 2.1376906810000946 not less than 2.0
    flaky = [
      "test_recursive_pad"
    ];

    # Require internet access
    needNetwork = [
      "test_benchmark_openpilot_model"
      "test_bn_alone"
      "test_bn_linear"
      "test_bn_mnist"
      "test_car"
      "test_chicken"
      "test_chicken_bigbatch"
      "test_conv_mnist"
      "testCopySHMtoDefault"
      "test_data_parallel_resnet"
      "test_e2e_big"
      "test_fetch_small"
      "test_huggingface_enet_safetensors"
      "test_index_mnist"
      "test_linear_mnist"
      "test_load_convnext"
      "test_load_enet"
      "test_load_enet_alt"
      "test_load_llama2bfloat"
      "test_load_resnet"
      "test_mnist_val"
      "test_openpilot_model"
      "test_resnet"
      "test_shufflenet"
      "test_transcribe_batch12"
      "test_transcribe_batch21"
      "test_transcribe_file1"
      "test_transcribe_file2"
      "test_transcribe_long"
      "test_transcribe_long_no_batch"
      "test_vgg7"
    ];

    # Fail with AssertionError (only disabled on aarch64-linux)
    aarch64LinuxFailures = [
      "test_casts_from"
      "test_casts_to"
      "test_int8"
      "test_int8_to_uint16_negative"
    ];
  in
  relocationFailures
  ++ flaky
  ++ needNetwork
  ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") aarch64LinuxFailures;
217
# Whole files/directories excluded from test collection.
disabledTestPaths =
  # Require internet access
  [
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"
  ]
  # Files under this directory are not considered as tests by upstream and should be skipped
  ++ [
    "extra/"
  ];
227
# Expose a variant of this package overridden with `cudaSupport = true` as a passthru test.
passthru.tests.withCuda = tinygrad.override {
  cudaSupport = true;
};
231
# Package metadata.
meta = {
  homepage = "https://github.com/tinygrad/tinygrad";
  changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
  description = "Simple and powerful neural network framework";
  license = lib.licenses.mit;
  maintainers = [ lib.maintainers.GaetanLepage ];

  # Platforms on which the package is known not to work.
  badPlatforms = [
    # Fatal Python error: Aborted
    # onnxruntime/capi/_pybind_state.py", line 32 in <module>
    "aarch64-linux"

    # Tests segfault on darwin
    lib.systems.inspect.patterns.isDarwin
  ];
};
247}