{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pythonOlder,

  # build tools
  which,

  # runtime dependencies
  numpy,
  torch,

  # check dependencies
  # (also needed but missing, hence doCheck = false:
  #  pytest-mpi, pytorch-image-models, apex, flash-attn)
  cmake,
  einops,
  fairscale,
  hydra-core,
  networkx,
  openai-triton,
  pytest-cov,
  pytest-timeout,
  pytestCheckHook,
  scipy,
  timm,
  transformers,
}:
let
  version = "0.0.23.post1";

  # Reuse torch's CUDA configuration: whether CUDA is enabled,
  # the CUDA package set, and the GPU capabilities to compile for.
  cudaSupport = torch.cudaSupport;
  cudaPackages = torch.cudaPackages;
  cudaCapabilities = torch.cudaCapabilities;
in
buildPythonPackage {
  pname = "xformers";
  inherit version;
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    # Upstream release tags carry a "v" prefix.
    rev = "refs/tags/v${version}";
    hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY=";
    # Also fetch the repository's git submodules.
    fetchSubmodules = true;
  };

  # Judging by its name, this patch lets the build run without git metadata
  # (fetchFromGitHub does not keep the .git directory by default).
  patches = [ ./0001-fix-allow-building-without-git.patch ];

  preBuild =
    # Write the version module directly with the packaged version string.
    ''
      cat << EOF > ./xformers/version.py
      # noqa: C801
      __version__ = "${version}"
      EOF
    ''
    # For CUDA builds, point the build at nvcc and restrict the compiled
    # architectures to the capabilities torch was configured with.
    + lib.optionalString cudaSupport ''
      export CUDA_HOME=${cudaPackages.cuda_nvcc}
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
    '';

  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      # flash-attn build
      cuda_cudart # cuda_runtime_api.h
      libcusparse.dev # cusparse.h
      cuda_cccl.dev # nv/target
      libcublas.dev # cublas_v2.h
      libcusolver.dev # cusolverDn.h
      libcurand.dev # curand_kernel.h
    ]
  );

  nativeBuildInputs = [ which ];

  propagatedBuildInputs = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Opt out of automated bulk updates; upstream has a broken 0.03 version:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is listed among the check inputs; keep its setup hook from
  # replacing the setuptools configure/build flow.
  dontUseCmakeConfigure = true;

  # Tests are disabled because several check dependencies are missing
  # (pytest-mpi, pytorch-image-models, apex, flash-attn).
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    openai-triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "XFormers: A collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Must reference the same "v"-prefixed tag that `src` is fetched from.
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}