{ lib
, buildPythonPackage
, fetchFromGitHub
, python
, pythonOlder
, setuptools
, wheel
, torch
, scipy
, symlinkJoin
}:

let
  pname = "bitsandbytes";
  version = "0.41.0";

  # Follow torch's CUDA configuration so both packages agree on whether CUDA
  # is enabled and on which CUDA package set is used.
  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaVersion;

  # CUDA libraries shared by the build-time toolkit and the link-time set below.
  # NOTE: cudnn is intentionally not listed; torch links it itself and the
  # bitsandbytes kernels are not known to need it (re-check on version bumps).
  cuda-common-redist = with cudaPackages; [
    cuda_cccl # <thrust/*>
    libcublas # cublas_v2.h
    libcurand
    libcusolver # cusolverDn.h
    libcusparse # cusparse.h
  ];

  # Build-time toolkit: the common libraries plus the CUDA runtime headers and
  # nvcc. Used as CUDA_HOME for the Makefile and as the library directory that
  # postPatch substitutes for /usr/local/cuda/lib64.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths = with cudaPackages; [
      cuda_cudart # cuda_runtime.h cuda_runtime_api.h
      cuda_nvcc
    ] ++ cuda-common-redist;
  };

  # Link-time libraries only (no compiler, no cudart).
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cuda-common-redist;
  };
in
buildPythonPackage {
  inherit pname version;
  format = "pyproject";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = pname;
    rev = "refs/tags/${version}";
    hash = "sha256-e6SK2ylITookO6bhpfdRp/V4y2S9rk6Lo1PD3xXrcmM=";
  };

  # - Makefile: take g++ from PATH instead of /usr/bin, and use `lib` instead
  #   of `lib64` (the directory name used by the CUDA joins above).
  # - cuda_setup/main.py: look for the compiled shared library under this
  #   package's installed site-packages directory instead of `package_dir`.
  # - With CUDA, replace the hard-coded /usr/local/cuda/lib64 search path with
  #   the Nix store toolkit.
  postPatch = ''
    substituteInPlace Makefile --replace "/usr/bin/g++" "g++" --replace "lib64" "lib"
    substituteInPlace bitsandbytes/cuda_setup/main.py \
      --replace "binary_path = package_dir / self.binary_name" \
                "binary_path = Path('$out/${python.sitePackages}/${pname}')/self.binary_name"
  '' + lib.optionalString cudaSupport ''
    substituteInPlace bitsandbytes/cuda_setup/main.py \
      --replace "/usr/local/cuda/lib64" "${cuda-native-redist}/lib"
  '';

  # Only hand a CUDA toolkit to the Makefile when building with CUDA.
  # Interpolating cuda-native-redist unconditionally made CPU-only builds
  # depend on the whole CUDA toolkit even though the cpuonly target never
  # uses it.
  CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";

  # Build the native library via the project's Makefile. With CUDA,
  # CUDA_VERSION is major.minor with the dot removed (e.g. "11.8" -> "118")
  # and the target is cuda<major>x; otherwise the CPU-only target is built.
  preBuild =
    if cudaSupport then
      let
        cudaVersionDigits = lib.concatStrings (lib.splitVersion cudaPackages.cudaMajorMinorVersion);
      in
      "make CUDA_VERSION=${cudaVersionDigits} cuda${cudaPackages.cudaMajorVersion}x"
    else
      "make CUDA_VERSION=CPU cpuonly";

  nativeBuildInputs = [ setuptools wheel ] ++ lib.optionals cudaSupport [ cuda-native-redist ];
  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [
    "bitsandbytes"
  ];

  meta = with lib; {
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    description = "8-bit CUDA functions for PyTorch";
    license = licenses.mit;
    maintainers = with maintainers; [ bcdarwin ];
  };
}