{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  which,
  # runtime dependencies
  numpy,
  torch,
  # check dependencies
  pytestCheckHook,
  pytest-cov,
  # pytest-mpi, (not packaged)
  pytest-timeout,
  # pytorch-image-models, (not packaged)
  hydra-core,
  fairscale,
  scipy,
  cmake,
  triton,
  networkx,
  # apex, (not packaged)
  einops,
  transformers,
  timm,
  # flash-attn, (not packaged)
}:
# Builds the xformers Python package from the upstream GitHub tag.
# CUDA support and the list of target GPU capabilities follow the `torch`
# package this derivation is built against.
let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  version = "0.0.23.post1";

  # CUDA builds use cudaPackages.backendStdenv; CPU-only builds use the
  # regular stdenv.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
# Setting `stdenv` as an ordinary attribute of the argument set does not
# change the toolchain used for the build; it has to be injected through
# `.override` on buildPythonPackage itself.
buildPythonPackage.override { stdenv = effectiveStdenv; } {
  pname = "xformers";
  inherit version;
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    rev = "refs/tags/v${version}";
    hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY=";
    fetchSubmodules = true;
  };

  patches = [ ./0001-fix-allow-building-without-git.patch ];

  # The build cannot derive the version from git metadata (see the patch
  # above), so write xformers/version.py with the packaged version instead.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF
  '';

  # Restrict the compiled CUDA kernels to the GPU capabilities torch was
  # built for (semicolon-separated, as TORCH_CUDA_ARCH_LIST expects).
  env = lib.optionalAttrs cudaSupport {
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      # headers needed by the bundled flash-attn build
      cuda_cudart # cuda_runtime_api.h
      libcusparse # cusparse.h
      cuda_cccl # nv/target
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcurand # curand_kernel.h
    ]
  );

  nativeBuildInputs = [ which ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  propagatedBuildInputs = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Upstream published a broken "0.03" version, so automated bulk updates
  # are disabled for this package:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is listed among the check inputs; keep its setup hook from
  # taking over the configure phase of this setuptools build.
  dontUseCmakeConfigure = true;

  # Tests are disabled because several check dependencies (the commented-out
  # entries in the argument list) are not available.
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "XFormers: A collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Upstream tags are prefixed with "v" (see `src.rev`).
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}