Clone of https://github.com/NixOS/nixpkgs.git (used to stress-test knotserver)
Branch: master · Size: 2.8 kB · View raw
# nixpkgs package definition for xformers (facebookresearch/xformers),
# optionally built with CUDA kernels when the torch it is built against has
# cudaSupport enabled.
{
  lib,
  stdenv,
  buildPythonPackage,
  # Not referenced below; kept so the argument set (and any existing
  # `.override { pythonOlder = …; }` calls) stays stable.
  pythonOlder,
  fetchFromGitHub,

  # build dependencies
  setuptools,
  ninja,
  which,
  openmp,

  # runtime dependencies
  numpy,
  torch,

  # check dependencies
  pytestCheckHook,
  pytest-cov-stub,
  pytest-timeout,
  hydra-core,
  fairscale,
  scipy,
  cmake,
  triton,
  networkx,
  einops,
  transformers,
  timm, # pytorch-image-models
  # missing check dependencies: pytest-mpi, apex, flash-attn
}:
let
  # CUDA configuration is taken from the torch we build against, so that
  # xformers and torch always agree on CUDA support and target capabilities.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  # version 0.0.32.post2 was confirmed to break CUDA.
  # Remove this note once the latest published revision "just works".
  version = "0.0.30";
in
buildPythonPackage {
  pname = "xformers";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    tag = "v${version}";
    fetchSubmodules = true;
    hash = "sha256-ozaw9z8qnGpZ28LQNtwmKeVnrn7KDWNeJKtT6g6Q/W0=";
  };

  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # Write xformers/version.py with the pinned version. The fetched source is
  # presumably not a git checkout (see the patch above), so upstream's own
  # version detection cannot be relied on.
  # MAX_JOBS caps parallel compile jobs (read by torch's extension builder);
  # tie it to the cores Nix grants this build.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  env = lib.optionalAttrs cudaSupport {
    # Compile CUDA kernels for the same capabilities torch itself targets.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  # CUDA builds use the CUDA package set's backendStdenv.
  stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;

  buildInputs =
    lib.optional stdenv.hostPlatform.isDarwin openmp
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        # flash-attn build
        cuda_cudart # cuda_runtime_api.h
        libcusparse # cusparse.h
        cuda_cccl # nv/target
        libcublas # cublas_v2.h
        libcusolver # cusolverDn.h
        libcurand # curand_kernel.h
      ]
    );

  nativeBuildInputs = [
    ninja
    which
  ]
  ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ])
  ++ lib.optional stdenv.hostPlatform.isDarwin openmp.dev;

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Upstream has a broken 0.03 version (presumably it compares newer than the
  # real 0.0.x releases), so automated bulk updates must not touch this package:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is only a check input; keep its setup hook from running a configure phase.
  dontUseCmakeConfigure = true;

  # Tests are disabled because some check dependencies are missing
  # (pytest-mpi, apex, flash-attn; see the comments in the argument list
  # and in nativeCheckInputs).
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Upstream release tags carry a "v" prefix (matching src.tag above).
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}