{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  pythonOlder,

  # build-system
  setuptools,
  torch,
  which,

  # dependencies
  cloudpickle,
  numpy,
  orjson,

  # checks
  h5py,
  pytestCheckHook,
}:

let
  # Exact system-string match, kept identical to the platform check this
  # package has always used for its aarch64-only exclusions.
  isAarch64Linux = stdenv.hostPlatform.system == "aarch64-linux";

  # Tests excluded on every platform.
  alwaysDisabledTests = [
    # Never finishes.
    "test_copy_onto"

    # Fails with EOFError (noted as MPI related), or with
    #   AssertionError: assert tensor(False)
    "test_mp"

    # Dynamo refuses to run it:
    #   torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: to_module
    #   requires TORCHDYNAMO_INLINE_INBUILT_NN_MODULES to be set.
    "test_functional"

    # Can hang indefinitely depending on the CPU.
    "test_map_iter_interrupt_early"
  ];

  # Additional exclusions applied only when isAarch64Linux holds.
  aarch64LinuxDisabledTests = [
    # Abort with "RuntimeError: internal error".
    "test_add_scale_sequence"
    "test_modules"
    "test_setattr"

    # Multiprocessing test that raises _queue.Empty.
    "test_isend"
  ];

  # Whole test files skipped when the host platform is Darwin.
  darwinDisabledTestPaths = [
    # Inductor backend fails to compile ("OpenMP support not found.").
    "test/test_compile.py"

    # The torch build lacks torch._C._distributed_c10d.
    "test/test_distributed.py"
  ];
in
buildPythonPackage rec {
  pname = "tensordict";
  version = "0.5.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "tensordict";
    rev = "refs/tags/v${version}";
    hash = "sha256-jnRlN9gefR77pioIXf0qM1CP6EtpeQkBvVIecGkb/pw=";
  };

  # torch and which are listed here as well as torch in dependencies below.
  build-system = [
    setuptools
    torch
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    orjson
    torch
  ];

  pythonImportsCheck = [ "tensordict" ];

  nativeCheckInputs = [
    h5py
    pytestCheckHook
  ];

  # Delete the in-tree sources so the tests import the installed package
  # rather than the checkout.
  preCheck = ''
    rm -rf tensordict
  '';

  disabledTests = alwaysDisabledTests ++ lib.optionals isAarch64Linux aarch64LinuxDisabledTests;

  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin darwinDisabledTestPaths;

  meta = {
    description = "Pytorch dedicated tensor container";
    homepage = "https://github.com/pytorch/tensordict";
    changelog = "https://github.com/pytorch/tensordict/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}