# Package definition for Hugging Face `accelerate`, built from the tagged
# GitHub release and tested with pytest.
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, pythonAtLeast
, pythonOlder
, pytestCheckHook
, setuptools
, numpy
, packaging
, psutil
, pyyaml
, torch
, evaluate
, parameterized
, transformers
}:

let
  # Host-platform predicates used to skip tests that fail on specific systems.
  isAarch64Linux = stdenv.isLinux && stdenv.isAarch64;
  isX86_64Darwin = stdenv.isDarwin && stdenv.isx86_64;
  isX86_64Linux = stdenv.isLinux && stdenv.isx86_64;
in
buildPythonPackage rec {
  pname = "accelerate";
  version = "0.24.1";
  format = "pyproject";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    rev = "refs/tags/v${version}";
    hash = "sha256-DKyFb+4DUMhVUwr+sgF2IaJS9pEj2o2shGYwExfffWg=";
  };

  patches = [
    # Upstream fix for an import error when torch.distributed is unavailable:
    # https://github.com/huggingface/accelerate/pull/2121
    (fetchpatch {
      name = "fix-import-error-without-torch_distributed.patch";
      url = "https://github.com/huggingface/accelerate/commit/42048092eabd67a407ea513a62f2acde97079fbc.patch";
      hash = "sha256-9lvnU6z5ZEFc5RVw2bP0cGVyrwAp/pxX4ZgnmCN7qH8=";
    })
  ];

  nativeBuildInputs = [ setuptools ];

  propagatedBuildInputs = [
    numpy
    packaging
    psutil
    pyyaml
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];

  # Give the tests a writable HOME and put the freshly installed
  # `accelerate` executable on PATH.
  preCheck = ''
    export HOME=$(mktemp -d)
    export PATH=$out/bin:$PATH
  '';

  pytestFlagsArray = [ "tests" ];

  disabledTests = [
    # These tests attempt to download data.
    "FeatureExamplesTests"
    "test_infer_auto_device_map_on_t0pp"

    # These tests require socket communication.
    "test_explicit_dtypes"
    "test_gated"
    "test_invalid_model_name"
    "test_invalid_model_name_transformers"
    "test_no_metadata"
    "test_no_split_modules"
    "test_remote_code"
    "test_transformers_model"

    # This test sets the CC environment variable, which conflicts with the
    # one provided by the standard environment.
    "test_patch_environment_key_exists"
  ] ++ lib.optionals (isAarch64Linux || isX86_64Darwin) [
    # aarch64-linux: RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    # x86_64-darwin: RuntimeError: torch_shm_manager: execl failed: Permission denied
    "CheckpointTest"
  ] ++ lib.optionals (pythonAtLeast "3.11") [
    # torch.compile does not yet support Python 3.11.
    "test_dynamo_extract_model"
  ];

  # Everywhere except x86_64-linux, these files fail with numerous
  # torch.multiprocessing.spawn.ProcessRaisedException errors.
  disabledTestPaths = lib.optionals (!isX86_64Linux) [
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [ "accelerate" ];

  meta = {
    description = "A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    homepage = "https://huggingface.co/docs/accelerate";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
    mainProgram = "accelerate";
  };
}