{
  lib,
  torch,
  symlinkJoin,
  buildPythonPackage,
  fetchFromGitHub,
  cmake,
  setuptools,
  scipy,
}:

let
  pname = "bitsandbytes";
  version = "0.45.1";

  # CUDA support follows whatever torch was built with.
  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaMajorMinorVersion;

  # CUDA version with the dot removed (e.g. "12.4" -> "124"); this is the form
  # used in the name of the compiled CUDA library (see postPatch below).
  cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;

  # CUDA components shared by both symlinkJoins below.
  # NOTE: cudnn is intentionally not included; torch links against it itself.
  # NOTE(review): the original rationale for omitting cudnn referred to
  # torchvision — confirm bitsandbytes does not need cudnn either.
  cuda-common-redist = with cudaPackages; [
    (lib.getDev cuda_cccl) # <thrust/*>
    (lib.getDev libcublas) # cublas_v2.h
    (lib.getLib libcublas)
    libcurand
    libcusolver # cusolverDn.h
    (lib.getDev libcusparse) # cusparse.h
    (lib.getLib libcusparse)
    (lib.getDev cuda_cudart) # cuda_runtime.h cuda_runtime_api.h
  ];

  # Build-time CUDA tree: the common set plus cudart (shared and static) and
  # nvcc. Used as CUDA_HOME and as the -I/-L prefix passed to nvcc.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaMajorMinorVersion}";
    paths =
      with cudaPackages;
      [
        (lib.getDev cuda_cudart) # cuda_runtime.h cuda_runtime_api.h
        (lib.getLib cuda_cudart)
        (lib.getStatic cuda_cudart)
        cuda_nvcc
      ]
      ++ cuda-common-redist;
  };

  # CUDA tree passed as a regular build input for the CUDA build.
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaMajorMinorVersion}";
    paths = cuda-common-redist;
  };
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = "bitsandbytes";
    tag = version;
    hash = "sha256-MZ+3mUXaAhRb+rBtE+eQqT3XdtFxlWJc/CmTEwQkKSA=";
  };

  # By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
  # When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally.
  # Indeed, in this case, only `libbitsandbytes_cuda<xxy>.so` (e.g. `libbitsandbytes_cuda124.so`
  # for CUDA 12.4) is built; `libbitsandbytes_cpu.so` is not.
  # Also, hardcode the path to the previously built library instead of relying on
  # `get_cuda_bnb_library_path(cuda_specs)`, which relies on `torch.cuda` too.
  #
  # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
  # This upstream convention could change at some point and thus break the following patch.
  postPatch = lib.optionalString cudaSupport ''
    substituteInPlace bitsandbytes/cextension.py \
      --replace-fail "if cuda_specs:" "if True:" \
      --replace-fail \
        "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
        "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
  '';

  # nvcc is only needed for the CUDA backend; the CPU build must not depend on
  # the CUDA toolchain.
  nativeBuildInputs = [
    cmake
  ]
  ++ lib.optionals cudaSupport [
    cudaPackages.cuda_nvcc
  ];

  build-system = [
    setuptools
  ];

  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  cmakeFlags = [
    (lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
  ];

  # Only reference the CUDA tree when building the CUDA backend; interpolating
  # it unconditionally would drag the whole CUDA join into the CPU build.
  CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";
  NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
    "-I${cuda-native-redist}/include"
    "-L${cuda-native-redist}/lib"
  ];

  # Compile the native library in the CMake build directory, then return to
  # the source root so the Python build runs from there.
  preBuild = ''
    make -j $NIX_BUILD_CORES
    cd .. # leave /build/source/build
  '';

  dependencies = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [ "bitsandbytes" ];

  meta = {
    description = "8-bit CUDA functions for PyTorch";
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    changelog = "https://github.com/TimDettmers/bitsandbytes/releases/tag/${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}