nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
# Package expression for xformers ("Collection of composable Transformer
# building blocks"), built with buildPythonPackage from the upstream GitHub
# tag. When torch is built with CUDA support, the build switches to the CUDA
# backend stdenv and pulls in the CUDA toolkit components listed below.
{
  lib,
  stdenv,
  buildPythonPackage,
  # NOTE(review): pythonOlder is not referenced anywhere in this expression;
  # it is kept in the argument set so existing callers/overrides keep working.
  pythonOlder,
  fetchFromGitHub,
  which,
  setuptools,
  # runtime dependencies
  numpy,
  torch,
  # check dependencies
  pytestCheckHook,
  pytest-cov-stub,
  # , pytest-mpi
  pytest-timeout,
  # , pytorch-image-models
  hydra-core,
  fairscale,
  scipy,
  cmake,
  ninja,
  triton,
  networkx,
  #, apex
  einops,
  transformers,
  timm,
  #, flash-attn
  openmp,
}:
let
  # Take the CUDA configuration from torch so this package is always built
  # with the same CUDA settings as the torch it depends on.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  # version 0.0.32.post2 was confirmed to break CUDA.
  # Remove this note once the latest published revision "just works".
  version = "0.0.30";

  # Use the CUDA backend stdenv for CUDA builds, the regular stdenv otherwise.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
buildPythonPackage.override { stdenv = effectiveStdenv; } {
  pname = "xformers";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    # Upstream release tags carry a "v" prefix.
    tag = "v${version}";
    fetchSubmodules = true;
    hash = "sha256-ozaw9z8qnGpZ28LQNtwmKeVnrn7KDWNeJKtT6g6Q/W0=";
  };

  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # Write the pinned version into xformers/version.py before building
  # (presumably needed because the build otherwise derives it from git, see
  # the patch above -- TODO confirm), and cap the build parallelism at the
  # number of cores Nix allotted to this build.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  # Only set for CUDA builds: semicolon-separated list of the CUDA
  # capabilities torch was configured with.
  env = lib.attrsets.optionalAttrs cudaSupport {
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  buildInputs =
    lib.optional stdenv.hostPlatform.isDarwin openmp
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        # flash-attn build
        cuda_cudart # cuda_runtime_api.h
        libcusparse # cusparse.h
        cuda_cccl # nv/target
        libcublas # cublas_v2.h
        libcusolver # cusolverDn.h
        libcurand # curand_kernel.h
      ]
    );

  nativeBuildInputs = [
    ninja
    which
  ]
  ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ])
  ++ lib.optional stdenv.hostPlatform.isDarwin openmp.dev;

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Has broken 0.03 version:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is listed in nativeCheckInputs below; disable its configure hook so
  # it does not take over the configure phase of this pyproject build.
  dontUseCmakeConfigure = true;

  # see commented out missing packages
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Point at the same "v"-prefixed tag that src is fetched from; the
    # un-prefixed ref does not match the tag used by src.
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}
132}