# Hugging Face `accelerate`: helpers for running PyTorch training/inference
# across multiple devices. Built from the upstream GitHub release tag.
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, pythonAtLeast
, pythonOlder
, pytestCheckHook
, setuptools
, numpy
, packaging
, psutil
, pyyaml
, torch
, evaluate
, parameterized
, transformers
}:

buildPythonPackage rec {
  pname = "accelerate";
  version = "0.24.1";
  format = "pyproject";
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    rev = "refs/tags/v${version}";
    hash = "sha256-DKyFb+4DUMhVUwr+sgF2IaJS9pEj2o2shGYwExfffWg=";
  };

  patches = [
    # https://github.com/huggingface/accelerate/pull/2121
    # Upstream fix for an import error when torch is built without torch.distributed.
    (fetchpatch {
      name = "fix-import-error-without-torch_distributed.patch";
      url = "https://github.com/huggingface/accelerate/commit/42048092eabd67a407ea513a62f2acde97079fbc.patch";
      hash = "sha256-9lvnU6z5ZEFc5RVw2bP0cGVyrwAp/pxX4ZgnmCN7qH8=";
    })
  ];

  nativeBuildInputs = [ setuptools ];

  propagatedBuildInputs = [
    numpy
    packaging
    psutil
    pyyaml
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];

  preCheck = ''
    # Give the tests a writable home directory inside the build sandbox.
    export HOME=$(mktemp -d)
    # Put the freshly installed `accelerate` executable on PATH; presumably some
    # tests invoke the CLI directly.
    export PATH=$out/bin:$PATH
  '';

  pytestFlagsArray = [ "tests" ];

  disabledTests = [
    # try to download data:
    "FeatureExamplesTests"
    "test_infer_auto_device_map_on_t0pp"

    # require socket communication
    "test_explicit_dtypes"
    "test_gated"
    "test_invalid_model_name"
    "test_invalid_model_name_transformers"
    "test_no_metadata"
    "test_no_split_modules"
    "test_remote_code"
    "test_transformers_model"

    # sets the environment variable CC, which conflicts with the standard environment
    "test_patch_environment_key_exists"
  ] ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    "CheckpointTest"
  ] ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
    # RuntimeError: torch_shm_manager: execl failed: Permission denied
    "CheckpointTest"
  ] ++ lib.optionals (pythonAtLeast "3.11") [
    # python3.11 not yet supported for torch.compile
    "test_dynamo_extract_model"
  ];

  # Only x86_64-linux runs these suites; elsewhere they fail with
  # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException.
  disabledTestPaths = lib.optionals (!(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isx86_64)) [
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [
    "accelerate"
  ];

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    # nixpkgs convention: descriptions do not start with an article.
    description = "Simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}