# Nix expression for the Hugging Face `accelerate` Python package, built from
# the tagged upstream GitHub release and checked with pytest.
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, pythonAtLeast
, pythonOlder
, pytestCheckHook
, setuptools
, numpy
, packaging
, psutil
, pyyaml
, torch
, evaluate
, parameterized
, transformers
}:

buildPythonPackage rec {
  pname = "accelerate";
  version = "0.21.0";
  format = "pyproject";
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "huggingface";
    # Spelled out instead of reusing `pname`, so a rename of the package
    # attribute cannot silently change which upstream repository is fetched.
    repo = "accelerate";
    rev = "refs/tags/v${version}";
    hash = "sha256-BwM3gyNhsRkxtxLNrycUGwBmXf8eq/7b56/ykMryt5w=";
  };

  patches = [
    # fix import error when torch>=2.0.1 and torch.distributed is disabled
    # https://github.com/huggingface/accelerate/pull/1800
    (fetchpatch {
      url = "https://github.com/huggingface/accelerate/commit/32701039d302d3875c50c35ab3e76c467755eae9.patch";
      hash = "sha256-Hth7qyOfx1sC8UaRdbYTnyRXD/VRKf41GtLc0ee1t2I=";
    })
  ];

  nativeBuildInputs = [ setuptools ];

  propagatedBuildInputs = [
    numpy
    packaging
    psutil
    pyyaml
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];

  preCheck = ''
    # The build sandbox has no writable HOME; give the tests a throwaway one
    # (presumably for caches/config files they write — confirm if removed).
    export HOME=$(mktemp -d)
    # Put the freshly installed executables (e.g. the `accelerate` CLI) on PATH
    # so tests that shell out can find them.
    export PATH=$out/bin:$PATH
  '';

  pytestFlagsArray = [ "tests" ];

  disabledTests = [
    # These tests try to download data:
    "FeatureExamplesTests"
    "test_infer_auto_device_map_on_t0pp"
  ] ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    "CheckpointTest"
  ] ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
    # RuntimeError: torch_shm_manager: execl failed: Permission denied
    "CheckpointTest"
  ] ++ lib.optionals (pythonAtLeast "3.11") [
    # python3.11 is not yet supported by torch.compile
    "test_dynamo_extract_model"
  ];

  # Everywhere except x86_64-linux, these files fail with numerous
  # torch.multiprocessing.spawn.ProcessRaisedException errors.
  disabledTestPaths = lib.optionals (!(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isx86_64)) [
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [
    "accelerate"
  ];

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    description = "Simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}