{
  lib,
  stdenv,
  buildPythonPackage,
  # NOTE(review): not referenced below; kept so existing `.override` calls
  # that pass it keep evaluating.
  pythonOlder,
  fetchFromGitHub,
  which,
  setuptools,
  # runtime dependencies
  numpy,
  torch,
  # check dependencies (ninja and openmp below are also used for the build)
  pytestCheckHook,
  pytest-cov-stub,
  # , pytest-mpi            (missing; see doCheck)
  pytest-timeout,
  # , pytorch-image-models  (missing; see doCheck)
  hydra-core,
  fairscale,
  scipy,
  cmake,
  ninja,
  triton,
  networkx,
  #, apex                   (missing; see doCheck)
  einops,
  transformers,
  timm,
  #, flash-attn             (missing; see doCheck)
  openmp,
}:
let
  # Follow torch's CUDA configuration so xformers' extensions are built
  # against the same CUDA toolkit and for the same GPU architectures.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  # version 0.0.32.post2 was confirmed to break CUDA.
  # Remove this note once the latest published revision "just works".
  version = "0.0.30";
in
buildPythonPackage {
  pname = "xformers";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    # Upstream release tags carry a "v" prefix.
    tag = "v${version}";
    fetchSubmodules = true;
    hash = "sha256-ozaw9z8qnGpZ28LQNtwmKeVnrn7KDWNeJKtT6g6Q/W0=";
  };

  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # There is no .git directory in the fetched source, so write the version
  # module by hand (the patch above lets the build proceed without git).
  # MAX_JOBS caps the number of parallel extension compile jobs at the
  # number of cores Nix grants this build.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  # Restrict the CUDA kernels to the architectures torch was built for.
  env = lib.optionalAttrs cudaSupport {
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  # CUDA builds use the CUDA package set's stdenv, so the host compiler
  # matches what nvcc expects.
  stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;

  buildInputs =
    # NOTE(review): OpenMP is only added on Darwin, presumably because the
    # Darwin clang stdenv does not provide it by default — confirm.
    lib.optional stdenv.hostPlatform.isDarwin openmp
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        # flash-attn build
        cuda_cudart # cuda_runtime_api.h
        libcusparse # cusparse.h
        cuda_cccl # nv/target
        libcublas # cublas_v2.h
        libcusolver # cusolverDn.h
        libcurand # curand_kernel.h
      ]
    );

  nativeBuildInputs = [
    ninja
    which
  ]
  ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ])
  ++ lib.optional stdenv.hostPlatform.isDarwin openmp.dev;

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Has broken 0.03 version:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # NOTE(review): presumably keeps the cmake setup hook (cmake is a check
  # input) from replacing the setuptools configure phase — confirm.
  dontUseCmakeConfigure = true;

  # Tests are disabled because some check dependencies (pytest-mpi,
  # pytorch-image-models, apex, flash-attn) are not available; they are
  # left commented out in the argument list and below.
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Point at the same "v"-prefixed tag that src is fetched from; a bare
    # version number is not a valid ref in the upstream repository.
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}