{
  lib,
  config,
  buildPythonPackage,
  fetchFromGitHub,
  substituteAll,
  addDriverRunpath,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  cudaPackages,
  ocl-icd,
  rocmPackages,
  stdenv,

  # build-system
  setuptools,

  # optional-dependencies
  llvmlite,
  triton,
  unicorn,

  # tests
  blobfile,
  bottle,
  clang,
  hexdump,
  hypothesis,
  librosa,
  networkx,
  numpy,
  onnx,
  pillow,
  pytest-xdist,
  pytestCheckHook,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,

  # This package itself; used by `passthru.tests` to build a CUDA-enabled variant.
  tinygrad,
}:

buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.10.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    rev = "refs/tags/v${version}";
    hash = "sha256-IIyTb3jDUSEP2IXK6DLsI15E5N34Utt7xv86aTHpXf8=";
  };

  patches = [
    # `substituteAll` fills any @driverLink@ / @libnvrtc@ placeholders in the patch
    # before it is applied. When CUDA support is disabled, @libnvrtc@ receives a hint
    # message instead of a library path.
    (substituteAll {
      src = ./fix-dlopen-cuda.patch;
      inherit (addDriverRunpath) driverLink;
      libnvrtc =
        if cudaSupport then
          "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
        else
          "Please import nixpkgs with `config.cudaSupport = true`";
    })
  ];

  postPatch =
    # Patch `clang` directly in the source file: replace the bare `'clang'` command
    # name with the absolute path of the clang executable.
    ''
      substituteInPlace tinygrad/runtime/ops_clang.py \
        --replace-fail "'clang'" "'${lib.getExe clang}'"
    ''
    # Linux only: load libOpenCL from ocl-icd instead of resolving it at runtime
    # through `ctypes.util.find_library`.
    + lib.optionalString stdenv.hostPlatform.isLinux ''
      substituteInPlace tinygrad/runtime/autogen/opencl.py \
        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
    ''
    # `cuda_fp16.h` and co. are needed at runtime to compile kernels: swap the
    # hard-coded system include directories for the cuda_cudart headers.
    + lib.optionalString cudaSupport ''
      substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
        --replace-fail \
          ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \
          ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
    ''
    # Point the HIP / comgr library paths at rocmPackages instead of /opt/rocm.
    # Note: the last `--replace-fail` of each command must NOT end with a `\`,
    # otherwise the next `substituteInPlace` line is swallowed as extra arguments.
    + lib.optionalString rocmSupport ''
      substituteInPlace tinygrad/runtime/autogen/hip.py \
        --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
        --replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so"

      substituteInPlace tinygrad/runtime/autogen/comgr.py \
        --replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so"
    '';

  build-system = [ setuptools ];

  optional-dependencies = {
    llvm = [ llvmlite ];
    arm = [ unicorn ];
    triton = [ triton ];
  };

  pythonImportsCheck =
    [
      "tinygrad"
    ]
    # With CUDA enabled, also check that the NV runtime module imports.
    ++ lib.optionals cudaSupport [
      "tinygrad.runtime.ops_nv"
    ];

  nativeCheckInputs = [
    blobfile
    bottle
    clang
    hexdump
    hypothesis
    librosa
    networkx
    numpy
    onnx
    pillow
    pytest-xdist
    pytestCheckHook
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
  ] ++ networkx.optional-dependencies.extra;

  # Point HOME at a fresh writable temporary directory for the test run.
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests =
    [
      # Fixed in https://github.com/tinygrad/tinygrad/pull/7792
      # TODO: re-enable at next release
      "test_kernel_cache_in_action"

      # Require internet access
      "test_benchmark_openpilot_model"
      "test_bn_alone"
      "test_bn_linear"
      "test_bn_mnist"
      "test_car"
      "test_chicken"
      "test_chicken_bigbatch"
      "test_conv_mnist"
      "testCopySHMtoDefault"
      "test_data_parallel_resnet"
      "test_e2e_big"
      "test_fetch_small"
      "test_huggingface_enet_safetensors"
      "test_index_mnist"
      "test_linear_mnist"
      "test_load_convnext"
      "test_load_enet"
      "test_load_enet_alt"
      "test_load_llama2bfloat"
      "test_load_resnet"
      "test_mnist_val"
      "test_openpilot_model"
      "test_resnet"
      "test_shufflenet"
      "test_transcribe_batch12"
      "test_transcribe_batch21"
      "test_transcribe_file1"
      "test_transcribe_file2"
      "test_transcribe_long"
      "test_transcribe_long_no_batch"
      "test_vgg7"
    ]
    ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
      # Fixed in https://github.com/tinygrad/tinygrad/pull/7796
      # TODO: re-enable at next release
      "test_interpolate_bilinear"

      # Fail with AssertionError
      "test_casts_from"
      "test_casts_to"
      "test_int8"
      "test_int8_to_uint16_negative"
    ];

  disabledTestPaths = [
    # Require internet access
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"

    # Files under this directory are not considered tests by upstream and should be skipped
    "extra/"
  ];

  # Also build a variant of this package with CUDA support forced on.
  passthru.tests = {
    withCuda = tinygrad.override { cudaSupport = true; };
  };

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
    changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # Tests segfault on darwin
    badPlatforms = [ lib.systems.inspect.patterns.isDarwin ];
  };
}