1{
2 lib,
3 torch,
4 symlinkJoin,
5 buildPythonPackage,
6 fetchFromGitHub,
7 cmake,
8 setuptools,
9 scipy,
10}:
11
12let
13 pname = "bitsandbytes";
14 version = "0.45.1";
15
16 inherit (torch) cudaPackages cudaSupport;
17 inherit (cudaPackages) cudaMajorMinorVersion;
18
19 cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;
20
21 # NOTE: torchvision doesn't use cudnn; torch does!
22 # For this reason it is not included.
23 cuda-common-redist = with cudaPackages; [
24 (lib.getDev cuda_cccl) # <thrust/*>
25 (lib.getDev libcublas) # cublas_v2.h
26 (lib.getLib libcublas)
27 libcurand
28 libcusolver # cusolverDn.h
29 (lib.getDev libcusparse) # cusparse.h
30 (lib.getLib libcusparse) # cusparse.h
31 (lib.getDev cuda_cudart) # cuda_runtime.h cuda_runtime_api.h
32 ];
33
34 cuda-native-redist = symlinkJoin {
35 name = "cuda-native-redist-${cudaMajorMinorVersion}";
36 paths =
37 with cudaPackages;
38 [
39 (lib.getDev cuda_cudart) # cuda_runtime.h cuda_runtime_api.h
40 (lib.getLib cuda_cudart)
41 (lib.getStatic cuda_cudart)
42 cuda_nvcc
43 ]
44 ++ cuda-common-redist;
45 };
46
47 cuda-redist = symlinkJoin {
48 name = "cuda-redist-${cudaMajorMinorVersion}";
49 paths = cuda-common-redist;
50 };
51in
52buildPythonPackage {
53 inherit pname version;
54 pyproject = true;
55
56 src = fetchFromGitHub {
57 owner = "TimDettmers";
58 repo = "bitsandbytes";
59 tag = version;
60 hash = "sha256-MZ+3mUXaAhRb+rBtE+eQqT3XdtFxlWJc/CmTEwQkKSA=";
61 };
62
63 # By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
64 # When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally.
65 # Indeed, in this case, only `libbitsandbytes_cuda124.so` is built. `libbitsandbytes_cpu.so` is not.
66 # Also, hardcode the path to the previously built library instead of relying on
67 # `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
68 #
69 # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
70 # This upstream convention could change at some point and thus break the following patch.
71 postPatch = lib.optionalString cudaSupport ''
72 substituteInPlace bitsandbytes/cextension.py \
73 --replace-fail "if cuda_specs:" "if True:" \
74 --replace-fail \
75 "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
76 "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
77 '';
78
79 nativeBuildInputs = [
80 cmake
81 cudaPackages.cuda_nvcc
82 ];
83
84 build-system = [
85 setuptools
86 ];
87
88 buildInputs = lib.optionals cudaSupport [ cuda-redist ];
89
90 cmakeFlags = [
91 (lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
92 ];
93 CUDA_HOME = "${cuda-native-redist}";
94 NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
95 "-I${cuda-native-redist}/include"
96 "-L${cuda-native-redist}/lib"
97 ];
98
99 preBuild = ''
100 make -j $NIX_BUILD_CORES
101 cd .. # leave /build/source/build
102 '';
103
104 dependencies = [
105 scipy
106 torch
107 ];
108
109 doCheck = false; # tests require CUDA and also GPU access
110
111 pythonImportsCheck = [ "bitsandbytes" ];
112
113 meta = {
114 description = "8-bit CUDA functions for PyTorch";
115 homepage = "https://github.com/TimDettmers/bitsandbytes";
116 changelog = "https://github.com/TimDettmers/bitsandbytes/releases/tag/${version}";
117 license = lib.licenses.mit;
118 maintainers = with lib.maintainers; [ bcdarwin ];
119 };
120}