{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
, pytestCheckHook
, setuptools
, numpy
, packaging
, psutil
, pyyaml
, torch
, evaluate
, parameterized
, transformers
}:

# Hugging Face Accelerate, built from the v<version> tag on GitHub.
buildPythonPackage rec {
  pname = "accelerate";
  version = "0.19.0";
  format = "pyproject";

  # Marked unavailable for interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-gW4wCpkyxoWfxXu8UHZfgopSQhOoPhGgqEqFiHJ+Db4=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  propagatedBuildInputs = [
    numpy
    packaging
    psutil
    pyyaml
    torch
  ];

  # On Darwin the test suite fails with numerous
  # torch.multiprocessing.spawn.ProcessRaisedException errors,
  # so tests only run on other platforms.
  doCheck = !stdenv.isDarwin;

  nativeCheckInputs = [
    evaluate
    parameterized
    pytestCheckHook
    transformers
  ];

  # Give the tests a writable HOME, and put this package's own
  # $out/bin ahead of everything else on PATH while they run.
  preCheck = ''
    export HOME=$(mktemp -d)
    export PATH=$out/bin:$PATH
  '';

  pytestFlagsArray = [ "tests" ];

  disabledTests =
    [
      # These tests try to download data.
      "FeatureExamplesTests"
      "test_infer_auto_device_map_on_t0pp"
      # Known failure with Torch > 2.0, see
      # https://github.com/huggingface/accelerate/pull/1339
      # (drop this entry for the next release).
      "test_gradient_sync_cpu_multi"
    ]
    ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
      # On aarch64-linux this hits the usual
      # "RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly".
      "CheckpointTest"
    ];

  pythonImportsCheck = [ "accelerate" ];

  meta = {
    description = "A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision";
    homepage = "https://huggingface.co/docs/accelerate";
    changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
    mainProgram = "accelerate";
  };
}