1{ lib
2, buildPythonPackage
3, fetchFromGitHub
4, python
5, pythonOlder
6, setuptools
7, wheel
8, torch
9, scipy
10, symlinkJoin
11}:
12
13let
14 pname = "bitsandbytes";
15 version = "0.41.0";
16
17 inherit (torch) cudaCapabilities cudaPackages cudaSupport;
18 inherit (cudaPackages) backendStdenv cudaVersion;
19
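  # NOTE: cudnn is deliberately not included in the list below; torch uses
  # cudnn, but this package does not appear to need it directly.
  # (This note seems to originate from the torchvision expression - TODO confirm.)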
22 cuda-common-redist = with cudaPackages; [
23 cuda_cccl # <thrust/*>
24 libcublas # cublas_v2.h
25 libcurand
26 libcusolver # cusolverDn.h
27 libcusparse # cusparse.h
28 ];
29
30 cuda-native-redist = symlinkJoin {
31 name = "cuda-native-redist-${cudaVersion}";
32 paths = with cudaPackages; [
33 cuda_cudart # cuda_runtime.h cuda_runtime_api.h
34 cuda_nvcc
35 ] ++ cuda-common-redist;
36 };
37
38 cuda-redist = symlinkJoin {
39 name = "cuda-redist-${cudaVersion}";
40 paths = cuda-common-redist;
41 };
42
43in
# Builds bitsandbytes from the upstream GitHub tag. The native library is
# compiled with the upstream Makefile in preBuild, either against CUDA (when
# torch was built with CUDA support) or in CPU-only mode.
buildPythonPackage ({
  inherit pname version;
  format = "pyproject";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = pname;
    rev = "refs/tags/${version}";
    hash = "sha256-e6SK2ylITookO6bhpfdRp/V4y2S9rk6Lo1PD3xXrcmM=";
  };

  # - Makefile: use the g++ found on PATH instead of /usr/bin/g++ and look for
  #   libraries under lib rather than lib64.
  # - cuda_setup/main.py: point the native binary lookup at the installed
  #   package directory in $out, and (CUDA builds only) replace the hard-coded
  #   /usr/local/cuda/lib64 with the merged CUDA tree from the Nix store.
  postPatch = ''
    substituteInPlace Makefile --replace "/usr/bin/g++" "g++" --replace "lib64" "lib"
    substituteInPlace bitsandbytes/cuda_setup/main.py  \
      --replace "binary_path = package_dir / self.binary_name"  \
                "binary_path = Path('$out/${python.sitePackages}/${pname}')/self.binary_name"
  '' + lib.optionalString torch.cudaSupport ''
    substituteInPlace bitsandbytes/cuda_setup/main.py  \
      --replace "/usr/local/cuda/lib64" "${cuda-native-redist}/lib"
  '';

  preBuild =
    if torch.cudaSupport then
      let
        inherit (torch.cudaPackages) cudaMajorMinorVersion cudaMajorVersion;
        # Version components concatenated without separators (e.g. "11.8" -> "118").
        # Named differently from the let-bound cudaVersion to avoid shadowing it.
        compactCudaVersion = lib.concatStrings (lib.splitVersion cudaMajorMinorVersion);
      in
      ''make CUDA_VERSION=${compactCudaVersion} cuda${cudaMajorVersion}x''
    else
      ''make CUDA_VERSION=CPU cpuonly'';

  nativeBuildInputs = [ setuptools wheel ] ++ lib.optionals torch.cudaSupport [ cuda-native-redist ];
  buildInputs = lib.optionals torch.cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [
    "bitsandbytes"
  ];

  meta = with lib; {
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    description = "8-bit CUDA functions for PyTorch";
    license = licenses.mit;
    maintainers = with maintainers; [ bcdarwin ];
  };
} // lib.optionalAttrs torch.cudaSupport {
  # Only set CUDA_HOME for CUDA builds. Setting it unconditionally makes the
  # CPU-only build depend on (and evaluate) the CUDA toolchain in
  # cuda-native-redist even though it is never used there.
  CUDA_HOME = "${cuda-native-redist}";
})