{
  # Generic nixpkgs helpers.
  lib,
  fetchFromGitHub,
  symlinkJoin,

  # Python packaging machinery; setuptools and wheel are build-time tools.
  python,
  pythonOlder,
  buildPythonPackage,
  setuptools,
  wheel,

  # Python dependencies propagated to consumers of this package.
  scipy,
  torch,
}:
13
let
  pname = "bitsandbytes";
  version = "0.42.0";

  # Follow torch's CUDA configuration, so the native library is built against
  # the same CUDA package set and with the same on/off switch as torch itself.
  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaVersion;

  # CUDA libraries shared by both symlink joins below.
  # NOTE(review): cuDNN is not listed here; confirm that bitsandbytes itself
  # does not need it. It is presumably only required by torch.
  cuda-common-redist = with cudaPackages; [
    cuda_cccl # <thrust/*>
    libcublas # cublas_v2.h
    libcurand
    libcusolver # cusolverDn.h
    libcusparse # cusparse.h
  ];

  # Build-host side: the CUDA runtime headers and nvcc plus the common
  # libraries, joined into a single tree. It is used as CUDA_HOME, as a native
  # build input, and as the library path patched into cuda_setup/main.py.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths =
      with cudaPackages;
      [
        cuda_cudart # cuda_runtime.h cuda_runtime_api.h
        cuda_nvcc
      ]
      ++ cuda-common-redist;
  };

  # Host side: only the common libraries, used as a regular build input.
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cuda-common-redist;
  };
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = "bitsandbytes";
    rev = "refs/tags/${version}";
    hash = "sha256-PZxsFJ6WpfeQqRQrRRBZfZfNY6/TfJFLBeknX24OXcU=";
  };

  # Make the upstream Makefile and the runtime loader work inside the Nix
  # sandbox:
  #  - use g++ from PATH instead of /usr/bin/g++, and lib instead of lib64;
  #  - make cuda_setup look for the compiled library in this package's
  #    installed site-packages directory rather than next to the source file;
  #  - with CUDA enabled, replace the hard-coded /usr/local/cuda/lib64 with
  #    the lib directory of the joined CUDA tree.
  postPatch =
    ''
      substituteInPlace Makefile --replace "/usr/bin/g++" "g++" --replace "lib64" "lib"
      substituteInPlace bitsandbytes/cuda_setup/main.py \
        --replace "binary_path = package_dir / self.binary_name" \
        "binary_path = Path('$out/${python.sitePackages}/${pname}')/self.binary_name"
    ''
    + lib.optionalString cudaSupport ''
      substituteInPlace bitsandbytes/cuda_setup/main.py \
        --replace "/usr/local/cuda/lib64" "${cuda-native-redist}/lib"
    '';

  # Only point the build at a CUDA tree when CUDA support is enabled.
  # Interpolating cuda-native-redist unconditionally made CPU-only builds
  # depend on (and evaluate) the whole CUDA toolkit join.
  CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";

  # Build the native library with the Makefile target that matches the
  # configuration. For CUDA, the target is named after the CUDA major version
  # (e.g. cuda12x), and CUDA_VERSION is the major.minor version with the dot
  # removed (e.g. "12.1" becomes "121").
  preBuild =
    if cudaSupport then
      let
        cudaVersionDigits = lib.concatStrings (lib.splitVersion cudaPackages.cudaMajorMinorVersion);
      in
      ''make CUDA_VERSION=${cudaVersionDigits} cuda${cudaPackages.cudaMajorVersion}x''
    else
      ''make CUDA_VERSION=CPU cpuonly'';

  nativeBuildInputs = [
    setuptools
    wheel
  ] ++ lib.optionals cudaSupport [ cuda-native-redist ];

  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [ "bitsandbytes" ];

  meta = {
    description = "8-bit CUDA functions for PyTorch";
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    changelog = "https://github.com/TimDettmers/bitsandbytes/releases/tag/${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}