{
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  pythonAtLeast,

  # buildInputs
  llvmPackages,

  # build-system
  setuptools,

  # dependencies
  huggingface-hub,
  numpy,
  packaging,
  psutil,
  pyyaml,
  safetensors,
  torch,

  # tests
  addBinToPathHook,
  evaluate,
  parameterized,
  pytestCheckHook,
  transformers,
  config,
  cudatoolkit,
  writableTmpDirAsHomeHook,
}:

buildPythonPackage rec {
  pname = "accelerate";
  version = "1.5.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    tag = "v${version}";
    hash = "sha256-J4eDm/PcyKK3256l6CAWUj4AWTB6neTKgxbBmul0BPE=";
  };

  patches = [
    # Fix tests on darwin: https://github.com/huggingface/accelerate/pull/3464
    (fetchpatch {
      url = "https://github.com/huggingface/accelerate/commit/8b31a2fe2c6d0246fff9885fb1f8456fb560abc7.patch";
      hash = "sha256-Ek9Ou4Y/H1jt3qanf2g3HowBoTsN/bn4yV9O3ogcXMo=";
    })
  ];

  buildInputs = [ llvmPackages.openmp ];

  build-system = [ setuptools ];

  dependencies = [
    huggingface-hub
    numpy
    packaging
    psutil
    pyyaml
    safetensors
    torch
  ];

  nativeCheckInputs = [
    addBinToPathHook
    evaluate
    parameterized
    pytestCheckHook
    transformers
    writableTmpDirAsHomeHook
  ];

  # Triton needs ptxas at runtime for the dynamo/compile tests; only
  # available (and only allowed) when CUDA support is enabled.
  preCheck = lib.optionalString config.cudaSupport ''
    export TRITON_PTXAS_PATH="${lib.getExe' cudatoolkit "ptxas"}"
  '';
  pytestFlagsArray = [ "tests" ];
  disabledTests =
    [
      # try to download data:
      "FeatureExamplesTests"
      "test_infer_auto_device_map_on_t0pp"

      # require socket communication
      "test_explicit_dtypes"
      "test_gated"
      "test_invalid_model_name"
      "test_invalid_model_name_transformers"
      "test_no_metadata"
      "test_no_split_modules"
      "test_remote_code"
      "test_transformers_model"

      # nondeterministic, tests GC behaviour by thresholding global ram usage
      "test_free_memory_dereferences_prepared_components"

      # set the environment variable, CC, which conflicts with standard environment
      "test_patch_environment_key_exists"
    ]
    ++ lib.optionals ((pythonAtLeast "3.13") || (torch.rocmSupport or false)) [
      # RuntimeError: Dynamo is not supported on Python 3.13+
      # OR torch.compile tests broken on torch 2.5 + rocm
      "test_can_unwrap_distributed_compiled_model_keep_torch_compile"
      "test_can_unwrap_distributed_compiled_model_remove_torch_compile"
      "test_convert_to_fp32"
      "test_send_to_device_compiles"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
      "CheckpointTest"
      # TypeError: unsupported operand type(s) for /: 'NoneType' and 'int' (it seems cpuinfo doesn't work here)
      "test_mpi_multicpu_config_cmd"
    ]
    ++ lib.optionals (!config.cudaSupport) [
      # requires ptxas from cudatoolkit, which is unfree
      "test_dynamo_extract_model"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # RuntimeError: 'accelerate-launch /nix/store/a7vhm7b74a7bmxc35j26s9iy1zfaqjs...
      "test_accelerate_test"
      "test_init_trackers"
      "test_log"
      "test_log_with_tensor"

      # After enabling MPS in pytorch, these tests started failing
      "test_accelerated_optimizer_step_was_skipped"
      "test_auto_wrap_policy"
      "test_autocast_kwargs"
      "test_automatic_loading"
      "test_backward_prefetch"
      "test_can_resume_training"
      "test_can_resume_training_checkpoints_relative_path"
      "test_can_resume_training_with_folder"
      "test_can_unwrap_model_fp16"
      "test_checkpoint_deletion"
      "test_cpu_offload"
      "test_cpu_ram_efficient_loading"
      "test_grad_scaler_kwargs"
      "test_invalid_registration"
      "test_map_location"
      "test_mixed_precision"
      "test_mixed_precision_buffer_autocast_override"
      "test_project_dir"
      "test_project_dir_with_config"
      "test_sharding_strategy"
      "test_state_dict_type"
      "test_with_save_limit"
      "test_with_scheduler"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
      # RuntimeError: torch_shm_manager: execl failed: Permission denied
      "CheckpointTest"
    ];

  disabledTestPaths = lib.optionals (!(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isx86_64)) [
    # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException:
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [ "accelerate" ];

  # Some tests spawn local worker processes that talk over the loopback
  # interface; allow that inside the darwin sandbox.
  __darwinAllowLocalNetworking = true;

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    description = "Simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}