# Nixpkgs derivation for xformers (facebookresearch/xformers), built against the
# `torch` passed in: CUDA support, the CUDA package set and the list of target
# CUDA capabilities are all inherited from that torch.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  which,
  setuptools,
  # runtime dependencies
  numpy,
  torch,
  # check dependencies
  pytestCheckHook,
  pytest-cov-stub,
  # pytest-mpi: missing package (one reason doCheck is false)
  pytest-timeout,
  # pytorch-image-models: missing package
  hydra-core,
  fairscale,
  scipy,
  cmake,
  ninja,
  triton,
  networkx,
  # apex: missing package
  einops,
  transformers,
  timm,
  # flash-attn: missing package
}:
let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  version = "0.0.28.post3";
in
buildPythonPackage {
  pname = "xformers";
  inherit version;
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    tag = "v${version}";
    hash = "sha256-23tnhCHK+Z0No8fqZxkgDFp2VIgXZR4jpM+pkb/vvmw=";
    fetchSubmodules = true;
  };

  # Judging by its file name, this patch lets the build run without git metadata.
  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # Write xformers/version.py ourselves so the version comes from this file
  # rather than from the checkout, then let the (torch) extension build use
  # as many parallel jobs as Nix gives us via MAX_JOBS.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  env = lib.optionalAttrs cudaSupport {
    # Semicolon-separated GPU architectures to compile kernels for,
    # taken from the capabilities torch itself was built with.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  # CUDA builds use the CUDA package set's backendStdenv; CPU builds the default stdenv.
  stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;

  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      # flash-attn build
      cuda_cudart # cuda_runtime_api.h
      libcusparse # cusparse.h
      cuda_cccl # nv/target
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcurand # curand_kernel.h
    ]
  );

  nativeBuildInputs = [
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Upstream has a broken 0.03 version that bulk updaters would pick up:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is only listed as a check input; never run its configure hook.
  dontUseCmakeConfigure = true;

  # Tests are disabled because some check dependencies are missing
  # (see the commented-out packages in the argument list).
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Release tags are v-prefixed (see src.tag), so the changelog link must be too.
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}