1{ lib
2, buildPythonPackage
3, fetchFromGitHub
4, python
5, pythonOlder
6, pytestCheckHook
7, setuptools
8, torch
9, einops
10, lion-pytorch
11, scipy
12, symlinkJoin
13}:
14
let
  pname = "bitsandbytes";
  version = "0.38.0";

  # Build against the exact CUDA package set (and CUDA on/off switch) that the
  # torch we propagate was built with. Only the attributes used below are
  # brought into scope.
  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaVersion;

  # CUDA components needed both at compile time and at link/run time.
  # cuDNN is intentionally not listed: torch pulls it in for itself, and this
  # package is not expected to need it.
  # NOTE(review): confirm against the upstream Makefile when bumping versions.
  cuda-common-redist = with cudaPackages; [
    cuda_cccl # <thrust/*>
    libcublas # cublas_v2.h
    libcurand
    libcusolver # cusolverDn.h
    libcusparse # cusparse.h
  ];

  # Build-time CUDA tree: the common libraries plus the runtime headers and
  # nvcc, merged into one prefix so it can serve as CUDA_HOME.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths = with cudaPackages; [
      cuda_cudart # cuda_runtime.h cuda_runtime_api.h
      cuda_nvcc
    ] ++ cuda-common-redist;
  };

  # Link-time CUDA libraries only (no compiler), used as buildInputs.
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cuda-common-redist;
  };

in
# Python package build. The native library is compiled through the upstream
# Makefile in preBuild: the CUDA target when torch has CUDA support, the
# CPU-only target otherwise.
buildPythonPackage {
  inherit pname version;
  format = "pyproject";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TimDettmers";
    repo = pname;
    rev = "refs/tags/${version}";
    hash = "sha256-gGlbzTDvZNo4MhcYzLvWuB2ec7q+Qt5/LtTbJ0Rc+Kk=";
  };

  # - Makefile: use the g++ found on PATH instead of the hard-coded
  #   /usr/bin/g++, and rewrite `lib64` to `lib` (the CUDA redist here exposes
  #   its libraries under `lib`, as the main.py substitution below also assumes).
  # - cuda_setup/main.py: load the compiled shared library from this package's
  #   installed site-packages directory in $out.
  # - CUDA builds only: point the runtime CUDA library lookup at the redist
  #   instead of /usr/local/cuda/lib64.
  postPatch = ''
    substituteInPlace Makefile --replace "/usr/bin/g++" "g++" --replace "lib64" "lib"
    substituteInPlace bitsandbytes/cuda_setup/main.py \
      --replace "binary_path = package_dir / binary_name" \
        "binary_path = Path('$out/${python.sitePackages}/${pname}')/binary_name"
  '' + lib.optionalString cudaSupport ''
    substituteInPlace bitsandbytes/cuda_setup/main.py \
      --replace "/usr/local/cuda/lib64" "${cuda-native-redist}/lib"
  '';

  # CUDA toolkit root, presumably consumed by the upstream Makefile. Set it
  # only for CUDA builds. Interpolating cuda-native-redist unconditionally made
  # CPU-only builds evaluate the CUDA redist and depend on it at build time.
  # Because nixpkgs CUDA packages are unfree, that could also break evaluation
  # without allowUnfree.
  CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";

  # CUDA: `make cuda<major>x` with CUDA_VERSION set to the major/minor digits
  # without the dot (e.g. "11.7" -> "117"). Otherwise: the CPU-only target.
  preBuild =
    if cudaSupport then
      let
        cudaVersionDigits =
          lib.concatStrings (lib.splitVersion cudaPackages.cudaMajorMinorVersion);
      in
      ''make CUDA_VERSION=${cudaVersionDigits} cuda${cudaPackages.cudaMajorVersion}x''
    else
      ''make CUDA_VERSION=CPU cpuonly'';

  nativeBuildInputs = [ setuptools ] ++ lib.optionals cudaSupport [ cuda-native-redist ];
  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access
  nativeCheckInputs = [ pytestCheckHook einops lion-pytorch scipy ];

  pythonImportsCheck = [
    "bitsandbytes"
  ];

  meta = with lib; {
    homepage = "https://github.com/TimDettmers/bitsandbytes";
    description = "8-bit CUDA functions for PyTorch";
    license = licenses.mit;
    maintainers = with maintainers; [ bcdarwin ];
  };
}