{
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pythonAtLeast,
  pythonOlder,
  llvmPackages,
  pytest7CheckHook,
  setuptools,
  numpy,
  packaging,
  psutil,
  pyyaml,
  safetensors,
  torch,
  config,
  cudatoolkit,
  evaluate,
  parameterized,
  transformers,
}:

# Hugging Face `accelerate`: built from the upstream GitHub tag with setuptools,
# checked with pytest against the `tests` directory.
buildPythonPackage rec {
  pname = "accelerate";
  version = "0.30.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "accelerate";
    rev = "refs/tags/v${version}";
    hash = "sha256-E20pI5BrcTrMYrhriuOUl5/liSaQQy6eqRyCoauwb9Q=";
  };

  buildInputs = [ llvmPackages.openmp ];

  build-system = [ setuptools ];

  dependencies = [
    numpy
    packaging
    psutil
    pyyaml
    safetensors
    torch
  ];

  nativeCheckInputs = [
    evaluate
    parameterized
    pytest7CheckHook
    transformers
  ];

  # HOME is pointed at a fresh temporary directory (presumably because the tests
  # write caches/config there — confirm), and $out/bin is put on PATH so the tests
  # can invoke the just-installed executables such as `accelerate`.
  # With CUDA enabled, Triton is told where to find `ptxas` from cudatoolkit.
  preCheck =
    ''
      export HOME=$(mktemp -d)
      export PATH=$out/bin:$PATH
    ''
    + lib.optionalString config.cudaSupport ''
      export TRITON_PTXAS_PATH="${cudatoolkit}/bin/ptxas"
    '';
  pytestFlagsArray = [ "tests" ];
  disabledTests =
    [
      # tries to download data:
      "FeatureExamplesTests"
      "test_infer_auto_device_map_on_t0pp"

      # require socket communication
      "test_explicit_dtypes"
      "test_gated"
      "test_invalid_model_name"
      "test_invalid_model_name_transformers"
      "test_no_metadata"
      "test_no_split_modules"
      "test_remote_code"
      "test_transformers_model"

      # nondeterministic, tests GC behaviour by thresholding global ram usage
      "test_free_memory_dereferences_prepared_components"

      # sets the environment variable CC, which conflicts with the one provided by the standard environment
      "test_patch_environment_key_exists"
    ]
    ++ lib.optionals (pythonAtLeast "3.12") [
      # RuntimeError: Dynamo is not supported on Python 3.12+
      "test_convert_to_fp32"
      "test_send_to_device_compiles"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
      "CheckpointTest"
      # TypeError: unsupported operand type(s) for /: 'NoneType' and 'int' (it seems cpuinfo doesn't work here)
      "test_mpi_multicpu_config_cmd"
    ]
    ++ lib.optionals (!config.cudaSupport) [
      # requires ptxas from cudatoolkit, which is unfree
      "test_dynamo_extract_model"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
      # RuntimeError: torch_shm_manager: execl failed: Permission denied
      "CheckpointTest"
    ];

  # Only x86_64-linux runs these multiprocessing-heavy test files.
  disabledTestPaths =
    lib.optionals (!(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isx86_64))
      [
        # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException:
        "tests/test_cpu.py"
        "tests/test_grad_sync.py"
        "tests/test_metrics.py"
        "tests/test_scheduler.py"
      ];

  pythonImportsCheck = [ "accelerate" ];

  __darwinAllowLocalNetworking = true;

  meta = {
    homepage = "https://huggingface.co/docs/accelerate";
    description = "Simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
    mainProgram = "accelerate";
  };
}