# Package definition for Hugging Face `accelerate`, built from the tagged
# GitHub source and tested with pytest against its `tests` directory.
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, pythonAtLeast
, pythonOlder
, pytestCheckHook
, setuptools
, numpy
, packaging
, psutil
, pyyaml
, torch
, evaluate
, parameterized
, transformers
}:

let
  # Platform predicates that gate the platform-specific test exclusions below.
  isAarch64Linux = stdenv.isLinux && stdenv.isAarch64;
  isX86_64Darwin = stdenv.isDarwin && stdenv.isx86_64;
  isX86_64Linux = stdenv.isLinux && stdenv.isx86_64;
in
buildPythonPackage rec {
  pname = "accelerate";
  version = "0.23.0";
  format = "pyproject";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    rev = "refs/tags/v${version}";
    hash = "sha256-pFkEgE1NGLPBW1CeGU0RJr+1Nj/y58ZcljyOnJuR47A=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  propagatedBuildInputs = [
    numpy
    packaging
    psutil
    pyyaml
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];

  # Give the tests a writable HOME, and put the freshly installed `accelerate`
  # executable first on PATH so tests that invoke it find this build.
  preCheck = ''
    export HOME=$(mktemp -d)
    export PATH=$out/bin:$PATH
  '';

  pytestFlagsArray = [ "tests" ];

  disabledTests = [
    # These attempt to download data:
    "FeatureExamplesTests"
    "test_infer_auto_device_map_on_t0pp"

    # These require socket communication:
    "test_explicit_dtypes"
    "test_gated"
    "test_invalid_model_name"
    "test_invalid_model_name_transformers"
    "test_no_metadata"
    "test_no_split_modules"
    "test_remote_code"
    "test_transformers_model"

    # Sets the environment variable CC, which conflicts with the standard environment.
    "test_patch_environment_key_exists"
  ] ++ lib.optionals isAarch64Linux [
    # Recurring aarch64-linux failure: RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    "CheckpointTest"
  ] ++ lib.optionals isX86_64Darwin [
    # RuntimeError: torch_shm_manager: execl failed: Permission denied
    "CheckpointTest"
  ] ++ lib.optionals (pythonAtLeast "3.11") [
    # torch.compile does not yet support Python 3.11.
    "test_dynamo_extract_model"
  ];

  # Everywhere except x86_64-linux, these files raise numerous
  # torch.multiprocessing.spawn.ProcessRaisedException errors.
  disabledTestPaths = lib.optionals (!isX86_64Linux) [
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [ "accelerate" ];

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    description = "A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}