{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # patches
  replaceVars,
  addDriverRunpath,
  cudaPackages,
  llvmPackages,
  ocl-icd,
  rocmPackages,

  # build-system
  setuptools,

  # optional-dependencies
  llvmlite,
  triton,
  unicorn,

  # tests
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  blobfile,
  bottle,
  capstone,
  clang,
  hexdump,
  hypothesis,
  jax,
  librosa,
  networkx,
  numpy,
  onnx,
  onnxruntime,
  pillow,
  pytest-xdist,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,

  # passthru
  tinygrad,

  # Backend toggles; each defaults to the matching nixpkgs `config` flag.
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
}:

buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.10.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    tag = "v${version}";
    hash = "sha256-BXQMacp6QjlgsVwhp2pxEZkRylZfKQhqIh92/0dPlfg=";
  };

  patches = [
    # Fill the `@driverLink@` and `@libnvrtc@` placeholders of the patch.
    # Without CUDA support, `libnvrtc` is set to an explanatory message instead of a library path.
    (replaceVars ./fix-dlopen-cuda.patch {
      inherit (addDriverRunpath) driverLink;
      libnvrtc =
        if cudaSupport then
          "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
        else
          "Please import nixpkgs with `config.cudaSupport = true`";
    })
  ];

  postPatch =
    # Patch `clang` directly in the source file
    # Use the unwrapped variant to enable the "native" features currently unavailable in the sandbox
    ''
      substituteInPlace tinygrad/runtime/ops_cpu.py \
        --replace-fail "getenv(\"CC\", 'clang')" "'${lib.getExe llvmPackages.clang-unwrapped}'"
    ''
    + ''
      substituteInPlace tinygrad/runtime/autogen/libc.py \
        --replace-fail "ctypes.util.find_library('c')" "'${stdenv.cc.libc}/lib/libc.so.6'"
    ''
    + ''
      substituteInPlace tinygrad/runtime/support/llvm.py \
        --replace-fail "ctypes.util.find_library('LLVM')" "'${lib.getLib llvmPackages.llvm}/lib/libLLVM.so'"
    ''
    + lib.optionalString stdenv.hostPlatform.isLinux ''
      substituteInPlace tinygrad/runtime/autogen/opencl.py \
        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
    ''
    # test/test_tensor.py imports the PTX variable from tinygrad/runtime/support/compiler_cuda.py.
    # That import loads the libnvrtc.so library, whose path is not substituted when cudaSupport = false.
    # -> As a fix, we hardcode this variable to False in the test file.
    + lib.optionalString (!cudaSupport) ''
      substituteInPlace test/test_tensor.py \
        --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
    ''
    # `cuda_fp16.h` and co. are needed at runtime to compile kernels
    + lib.optionalString cudaSupport ''
      substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
        --replace-fail \
        ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \
        ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
    ''
    # Point the hardcoded /opt/rocm locations at the corresponding nixpkgs ROCm packages.
    + lib.optionalString rocmSupport ''
      substituteInPlace tinygrad/runtime/autogen/hip.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.clr}/"

      substituteInPlace tinygrad/runtime/support/compiler_hip.py \
        --replace-fail "/opt/rocm/include" "${rocmPackages.clr}/include" \
        --replace-fail "/opt/rocm/llvm" "${rocmPackages.llvm.llvm}"

      substituteInPlace tinygrad/runtime/autogen/comgr.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.rocm-comgr}/"
    '';

  build-system = [ setuptools ];

  optional-dependencies = {
    llvm = [ llvmlite ];
    arm = [ unicorn ];
    triton = [ triton ];
  };

  pythonImportsCheck =
    [
      "tinygrad"
    ]
    # The NV runtime can only be imported when CUDA support is enabled.
    ++ lib.optionals cudaSupport [
      "tinygrad.runtime.ops_nv"
    ];

  nativeCheckInputs = [
    pytestCheckHook
    writableTmpDirAsHomeHook

    blobfile
    bottle
    capstone
    clang
    hexdump
    hypothesis
    jax
    librosa
    networkx
    numpy
    onnx
    onnxruntime
    pillow
    pytest-xdist
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
  ] ++ networkx.optional-dependencies.extra;

  disabledTests =
    [
      # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
      "test_backward_sum_acc_dtype"
      "test_failure_27"

      # Flaky:
      # AssertionError: 2.1376906810000946 not less than 2.0
      "test_recursive_pad"

      # Require internet access
      "test_benchmark_openpilot_model"
      "test_bn_alone"
      "test_bn_linear"
      "test_bn_mnist"
      "test_car"
      "test_chicken"
      "test_chicken_bigbatch"
      "test_conv_mnist"
      "testCopySHMtoDefault"
      "test_data_parallel_resnet"
      "test_e2e_big"
      "test_fetch_small"
      "test_huggingface_enet_safetensors"
      "test_index_mnist"
      "test_linear_mnist"
      "test_load_convnext"
      "test_load_enet"
      "test_load_enet_alt"
      "test_load_llama2bfloat"
      "test_load_resnet"
      "test_mnist_val"
      "test_openpilot_model"
      "test_resnet"
      "test_shufflenet"
      "test_transcribe_batch12"
      "test_transcribe_batch21"
      "test_transcribe_file1"
      "test_transcribe_file2"
      "test_transcribe_long"
      "test_transcribe_long_no_batch"
      "test_vgg7"
    ]
    ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
      # Fail with AssertionError
      "test_casts_from"
      "test_casts_to"
      "test_int8"
      "test_int8_to_uint16_negative"
    ];

  disabledTestPaths = [
    # Require internet access
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"

    # Files under this directory are not considered as tests by upstream and should be skipped
    "extra/"
  ];

  # Build the CUDA-enabled variant as a test so CUDA-specific breakage is caught.
  passthru.tests = {
    withCuda = tinygrad.override { cudaSupport = true; };
  };

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
    changelog = "https://github.com/tinygrad/tinygrad/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    badPlatforms = [
      # Fatal Python error: Aborted
      # onnxruntime/capi/_pybind_state.py", line 32 in <module>
      "aarch64-linux"

      # Tests segfault on darwin
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}