{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  which,
  setuptools,
  # runtime dependencies
  numpy,
  torch,
  # check dependencies
  pytestCheckHook,
  pytest-cov-stub,
  # pytest-mpi,
  pytest-timeout,
  # pytorch-image-models,
  hydra-core,
  fairscale,
  scipy,
  cmake,
  ninja,
  triton,
  networkx,
  # apex,
  einops,
  transformers,
  timm,
  # flash-attn,
}:
let
  # Reuse torch's CUDA configuration so xformers is built for the same
  # CUDA toolkit and GPU architectures as the torch it links against.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  version = "0.0.28.post3";

  # CUDA builds use cudaPackages.backendStdenv (a stdenv whose host compiler
  # nvcc supports); CPU-only builds use the regular stdenv.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
# Passing `stdenv` as a plain package attribute does not change the stdenv
# buildPythonPackage builds with, so it is swapped via `.override` instead.
buildPythonPackage.override { stdenv = effectiveStdenv; } {
  pname = "xformers";
  inherit version;
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    tag = "v${version}";
    hash = "sha256-23tnhCHK+Z0No8fqZxkgDFp2VIgXZR4jpM+pkb/vvmw=";
    fetchSubmodules = true;
  };

  # Judging by its name, the patch lets the build run without a .git checkout.
  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # Write xformers/version.py explicitly so the package reports the
  # Nix-provided version. MAX_JOBS is read by torch's C++/CUDA extension
  # builder to bound the number of parallel compile jobs.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  env = lib.optionalAttrs cudaSupport {
    # Compile kernels for exactly the GPU architectures torch was built for.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      # flash-attn build
      cuda_cudart # cuda_runtime_api.h
      libcusparse # cusparse.h
      cuda_cccl # nv/target
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcurand # curand_kernel.h
    ]
  );

  nativeBuildInputs = [
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Upstream has a broken 0.03 version that automated bulk updates would pick up:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is only a check input; keep its setup hook from replacing the
  # setuptools-driven configure step.
  dontUseCmakeConfigure = true;

  # Tests are disabled because several check dependencies (the commented-out
  # entries in the argument list and below) are not available.
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Link to the same tag that `src` fetches (tags are prefixed with "v").
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}