# Nixpkgs expression for bitsandbytes. The native library is built with
# upstream's Makefile (CUDA or CPU-only, following torch's configuration)
# and then packaged by the pyproject build.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  python,
  pythonOlder,
  setuptools,
  wheel,
  torch,
  scipy,
  symlinkJoin,
}:

let
  pname = "bitsandbytes";
  version = "0.42.0";

  # CUDA configuration is taken from torch so both packages agree on whether
  # (and against which CUDA package set) they are built.
  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaVersion;

  # CUDA libraries shared by the compile-time tree and the link-time inputs.
  # cuDNN is deliberately not listed: torch depends on it, and nothing in this
  # expression references it (presumably the bitsandbytes kernels do not need it).
  cuda-common-redist = with cudaPackages; [
    cuda_cccl # <thrust/*>
    libcublas # cublas_v2.h
    libcurand
    libcusolver # cusolverDn.h
    libcusparse # cusparse.h
  ];

  # Single tree with the compiler, runtime headers and the common libraries.
  # Used as CUDA_HOME for the Makefile build and as the CUDA library search
  # path patched into cuda_setup/main.py.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths =
      with cudaPackages;
      [
        cuda_cudart # cuda_runtime.h cuda_runtime_api.h
        cuda_nvcc
      ]
      ++ cuda-common-redist;
  };

  # Libraries the built extension links against (no compiler).
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cuda-common-redist;
  };

  # CUDA_VERSION value for the upstream Makefile: the major.minor version with
  # the dot removed (e.g. "12.1" -> "121"). Only forced in CUDA builds.
  makeCudaVersion = lib.concatStrings (lib.splitVersion cudaPackages.cudaMajorMinorVersion);
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = "bitsandbytes";
    rev = "refs/tags/${version}";
    hash = "sha256-PZxsFJ6WpfeQqRQrRRBZfZfNY6/TfJFLBeknX24OXcU=";
  };

  # - Use the compiler from PATH and the Nix `lib` layout in the Makefile.
  # - Make the runtime loader look for the compiled library inside this
  #   package's own site-packages directory.
  # - In CUDA builds, point the hard-coded /usr/local/cuda fallback at the
  #   Nix-provided CUDA tree.
  postPatch =
    ''
      substituteInPlace Makefile --replace "/usr/bin/g++" "g++" --replace "lib64" "lib"
      substituteInPlace bitsandbytes/cuda_setup/main.py \
        --replace "binary_path = package_dir / self.binary_name" \
                  "binary_path = Path('$out/${python.sitePackages}/${pname}')/self.binary_name"
    ''
    + lib.optionalString cudaSupport ''
      substituteInPlace bitsandbytes/cuda_setup/main.py \
        --replace "/usr/local/cuda/lib64" "${cuda-native-redist}/lib"
    '';

  # Build the native library with upstream's Makefile before the pyproject
  # build runs. CUDA_HOME is exported only in CUDA builds. Setting it
  # unconditionally would force cuda-native-redist, and with it the CUDA
  # toolkit, into CPU-only builds. The export persists for the remaining
  # phases, which run in the same builder shell.
  preBuild =
    if cudaSupport then
      ''
        export CUDA_HOME="${cuda-native-redist}"
        make CUDA_VERSION=${makeCudaVersion} cuda${cudaPackages.cudaMajorVersion}x
      ''
    else
      ''
        make CUDA_VERSION=CPU cpuonly
      '';

  nativeBuildInputs = [
    setuptools
    wheel
  ] ++ lib.optionals cudaSupport [ cuda-native-redist ];

  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [ "bitsandbytes" ];

  meta = {
    description = "8-bit CUDA functions for PyTorch";
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    changelog = "https://github.com/TimDettmers/bitsandbytes/releases/tag/${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}