# nixpkgs derivation for the `tensordict` Python package (github:pytorch/tensordict).
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  setuptools,
  torch,
  wheel,
  which,
  cloudpickle,
  numpy,
  h5py,
  pytestCheckHook,
  stdenv,
  pythonAtLeast,
}:

buildPythonPackage rec {
  pname = "tensordict";
  version = "0.4.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "tensordict";
    rev = "refs/tags/v${version}";
    hash = "sha256-wKEzNaaazGEkoElzp93RIlq/r5uRUdM7UyDy/DygIEc=";
  };

  # NOTE(review): torch is listed here as well as in `dependencies`, presumably
  # because the build step imports it (e.g. to build a native extension) — confirm
  # against upstream setup.py before removing either entry.
  build-system = [
    setuptools
    torch
    wheel
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    torch
  ];

  pythonImportsCheck = [ "tensordict" ];

  # Remove the in-tree `tensordict` source directory before running the tests;
  # otherwise it shadows the installed package on the import path.
  preCheck = ''
    rm -rf tensordict
  '';

  nativeCheckInputs = [
    h5py
    pytestCheckHook
  ];

  # RuntimeError: internal error (observed on aarch64-linux only)
  disabledTests = lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
    "test_add_scale_sequence"
    "test_modules"
    "test_setattr"
  ];

  # ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package
  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [ "test/test_distributed.py" ];

  meta = {
    description = "Pytorch dedicated tensor container";
    changelog = "https://github.com/pytorch/tensordict/releases/tag/v${version}";
    homepage = "https://github.com/pytorch/tensordict";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # No python 3.12 support yet: https://github.com/pytorch/rl/issues/2035
    broken = pythonAtLeast "3.12";
  };
}