1{
2 lib,
3 config,
4 buildPythonPackage,
5 fetchFromGitHub,
6 substituteAll,
7 addDriverRunpath,
8 cudaSupport ? config.cudaSupport,
9 rocmSupport ? config.rocmSupport,
10 cudaPackages,
11 ocl-icd,
12 stdenv,
13 rocmPackages,
14 # build-system
15 setuptools,
16 wheel,
17 # dependencies
18 numpy,
19 tqdm,
20 # nativeCheckInputs
21 clang,
22 hexdump,
23 hypothesis,
24 librosa,
25 onnx,
26 pillow,
27 pytest-xdist,
28 pytestCheckHook,
29 safetensors,
30 sentencepiece,
31 tiktoken,
32 torch,
33 transformers,
34}:
35
36buildPythonPackage rec {
37 pname = "tinygrad";
38 version = "0.9.0";
39 pyproject = true;
40
41 src = fetchFromGitHub {
42 owner = "tinygrad";
43 repo = "tinygrad";
44 rev = "refs/tags/v${version}";
45 hash = "sha256-opBxciETZruZjHqz/3vO7rogzjvVJKItulIiok/Zs2Y=";
46 };
47
48 patches = [
49 (substituteAll {
50 src = ./fix-dlopen-cuda.patch;
51 inherit (addDriverRunpath) driverLink;
52 libnvrtc =
53 if cudaSupport then
54 "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
55 else
56 "Please import nixpkgs with `config.cudaSupport = true`";
57 })
58 ];
59
60 postPatch =
61 ''
62 substituteInPlace tinygrad/runtime/autogen/opencl.py \
63 --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
64 ''
65 # hipGetDevicePropertiesR0600 is a symbol from rocm-6. We are currently at rocm-5.
66 # We are not sure that this works. Remove when rocm gets updated to version 6.
67 + lib.optionalString rocmSupport ''
68 substituteInPlace extra/hip_gpu_driver/hip_ioctl.py \
69 --replace-fail "processor = platform.processor()" "processor = ${stdenv.hostPlatform.linuxArch}"
70 substituteInPlace tinygrad/runtime/autogen/hip.py \
71 --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
72 --replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so" \
73 --replace-fail "hipGetDevicePropertiesR0600" "hipGetDeviceProperties"
74
75 substituteInPlace tinygrad/runtime/autogen/comgr.py \
76 --replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so"
77 '';
78
79 build-system = [
80 setuptools
81 wheel
82 ];
83
84 dependencies =
85 [
86 numpy
87 tqdm
88 ]
89 ++ lib.optionals stdenv.isDarwin [
90 # pyobjc-framework-libdispatch
91 # pyobjc-framework-metal
92 ];
93
94 pythonImportsCheck = [ "tinygrad" ];
95
96 nativeCheckInputs = [
97 clang
98 hexdump
99 hypothesis
100 librosa
101 onnx
102 pillow
103 pytest-xdist
104 pytestCheckHook
105 safetensors
106 sentencepiece
107 tiktoken
108 torch
109 transformers
110 ];
111
112 preCheck = ''
113 export HOME=$(mktemp -d)
114 '';
115
116 disabledTests =
117 [
118 # Require internet access
119 "test_benchmark_openpilot_model"
120 "test_bn_alone"
121 "test_bn_linear"
122 "test_bn_mnist"
123 "test_car"
124 "test_chicken"
125 "test_chicken_bigbatch"
126 "test_conv_mnist"
127 "testCopySHMtoDefault"
128 "test_data_parallel_resnet"
129 "test_e2e_big"
130 "test_fetch_small"
131 "test_huggingface_enet_safetensors"
132 "test_linear_mnist"
133 "test_load_convnext"
134 "test_load_enet"
135 "test_load_enet_alt"
136 "test_load_llama2bfloat"
137 "test_load_resnet"
138 "test_openpilot_model"
139 "test_resnet"
140 "test_shufflenet"
141 "test_transcribe_batch12"
142 "test_transcribe_batch21"
143 "test_transcribe_file1"
144 "test_transcribe_file2"
145 "test_transcribe_long"
146 "test_transcribe_long_no_batch"
147 "test_vgg7"
148 ]
149 # Fail on aarch64-linux with AssertionError
150 ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
151 "test_casts_to"
152 "test_casts_to"
153 "test_int8_to_uint16_negative"
154 "test_casts_to"
155 "test_casts_to"
156 "test_casts_from"
157 "test_casts_to"
158 "test_int8"
159 "test_casts_to"
160 ];
161
162 disabledTestPaths =
163 [
164 # Require internet access
165 "test/models/test_mnist.py"
166 "test/models/test_real_world.py"
167 "test/testextra/test_lr_scheduler.py"
168 ]
169 ++ lib.optionals (!rocmSupport) [ "extra/hip_gpu_driver/" ];
170
171 meta = with lib; {
172 description = "Simple and powerful neural network framework";
173 homepage = "https://github.com/tinygrad/tinygrad";
174 changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
175 license = licenses.mit;
176 maintainers = with maintainers; [ GaetanLepage ];
177 # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal
178 broken = stdenv.isDarwin;
179 };
180}