{
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pythonAtLeast,

  # buildInputs
  llvmPackages,

  # build-system
  setuptools,

  # dependencies
  huggingface-hub,
  numpy,
  packaging,
  psutil,
  pyyaml,
  safetensors,
  torch,

  # tests
  evaluate,
  parameterized,
  pytestCheckHook,
  transformers,
  config,
  cudatoolkit,
}:

buildPythonPackage rec {
  pname = "accelerate";
  version = "1.2.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    tag = "v${version}";
    hash = "sha256-KnFf6ge0vUR/C7Rh/c6ZttCGKo9OUIWCUYxk5O+OW7c=";
  };

  buildInputs = [ llvmPackages.openmp ];

  build-system = [ setuptools ];

  dependencies = [
    huggingface-hub
    numpy
    packaging
    psutil
    pyyaml
    safetensors
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];
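
  # Tests need a writable HOME (presumably for the Hugging Face cache) and the installed
  # CLI on PATH, e.g. for the accelerate-launch invocations exercised below.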
  preCheck =
    ''
      export HOME=$(mktemp -d)
      export PATH=$out/bin:$PATH
    ''
    + lib.optionalString config.cudaSupport ''
      export TRITON_PTXAS_PATH="${cudatoolkit}/bin/ptxas"
    '';
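
  # Run only the upstream test suite under tests/.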
  pytestFlagsArray = [ "tests" ];

  disabledTests =
    [
      # try to download data:
      "FeatureExamplesTests"
      "test_infer_auto_device_map_on_t0pp"

      # require socket communication
      "test_explicit_dtypes"
      "test_gated"
      "test_invalid_model_name"
      "test_invalid_model_name_transformers"
      "test_no_metadata"
      "test_no_split_modules"
      "test_remote_code"
      "test_transformers_model"

      # nondeterministic; tests GC behaviour by thresholding global RAM usage
      "test_free_memory_dereferences_prepared_components"

      # sets the environment variable CC, which conflicts with the standard environment
      "test_patch_environment_key_exists"
    ]
    ++ lib.optionals (pythonAtLeast "3.12") [
      # RuntimeError: Dynamo is not supported on Python 3.12+
      "test_convert_to_fp32"
      "test_dynamo_extract_model"
      "test_send_to_device_compiles"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
      "CheckpointTest"
      # TypeError: unsupported operand type(s) for /: 'NoneType' and 'int' (it seems cpuinfo doesn't work here)
      "test_mpi_multicpu_config_cmd"
    ]
    ++ lib.optionals (!config.cudaSupport) [
      # requires ptxas from cudatoolkit, which is unfree
      "test_dynamo_extract_model"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # RuntimeError: 'accelerate-launch /nix/store/a7vhm7b74a7bmxc35j26s9iy1zfaqjs...
      "test_accelerate_test"
      "test_init_trackers"
      "test_log"
      "test_log_with_tensor"

      # After enabling MPS in pytorch, these tests started failing
      "test_accelerated_optimizer_step_was_skipped"
      "test_auto_wrap_policy"
      "test_autocast_kwargs"
      "test_automatic_loading"
      "test_backward_prefetch"
      "test_can_resume_training"
      "test_can_resume_training_checkpoints_relative_path"
      "test_can_resume_training_with_folder"
      "test_can_unwrap_model_fp16"
      "test_checkpoint_deletion"
      "test_cpu_offload"
      "test_cpu_ram_efficient_loading"
      "test_grad_scaler_kwargs"
      "test_invalid_registration"
      "test_map_location"
      "test_mixed_precision"
      "test_mixed_precision_buffer_autocast_override"
      "test_project_dir"
      "test_project_dir_with_config"
      "test_sharding_strategy"
      "test_state_dict_type"
      "test_with_save_limit"
      "test_with_scheduler"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
      # RuntimeError: torch_shm_manager: execl failed: Permission denied
      "CheckpointTest"
    ];

  disabledTestPaths = lib.optionals (!(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isx86_64)) [
    # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException:
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [ "accelerate" ];

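  # Some tests rely on local socket communication (see the disabled tests above), so
  # allow that inside the Darwin sandbox.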
  __darwinAllowLocalNetworking = true;

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    description = "Simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}