{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  pybind11,
  setuptools,
  setuptools-scm,

  # nativeBuildInputs
  cmake,
  ninja,

  # dependencies
  cloudpickle,
  importlib-metadata,
  numpy,
  orjson,
  packaging,
  pyvers,
  torch,

  # tests
  h5py,
  pytestCheckHook,
}:

# Python package `tensordict`, built from the upstream GitHub release tag.
buildPythonPackage rec {
  pname = "tensordict";
  version = "0.9.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "tensordict";
    tag = "v${version}";
    hash = "sha256-OdS9dw/BtSLZuY857O2njlFOMQj5IJ6v9c2aRP+H1Hc=";
  };

  # cmake and ninja are only made available as build tools: the cmake setup
  # hook's own configure phase is switched off so the pyproject build runs
  # instead. Presumably the Python build drives cmake itself to compile the
  # pybind11 extension (TODO confirm against upstream setup.py).
  nativeBuildInputs = [
    cmake
    ninja
  ];
  dontUseCmakeConfigure = true;

  build-system = [
    pybind11
    setuptools
    setuptools-scm
  ];

  dependencies = [
    cloudpickle
    importlib-metadata
    numpy
    orjson
    packaging
    pyvers
    torch
  ];

  nativeCheckInputs = [
    h5py
    pytestCheckHook
  ];

  # Drop the in-tree `tensordict` source directory before running the tests;
  # otherwise it would be imported in place of the installed package.
  preCheck = ''
    rm -rf tensordict
  '';

  disabledTests = [
    # FileNotFoundError: [Errno 2] No such file or directory: '/build/source/tensordict/tensorclass.pyi'
    # (that path lies inside the source directory removed in preCheck)
    "test_tensorclass_instance_methods"
    "test_tensorclass_stub_methods"

    # Hangs forever on some CPUs
    "test_map_iter_interrupt_early"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Hang on Darwin because they use a pool
    "test_chunksize_num_chunks"
    "test_index_with_generator"
    "test_map_exception"
    "test_map"
    "test_multiprocessing"
  ];

  disabledTestPaths = [
    # torch._dynamo.exc.Unsupported: Graph break due to unsupported builtin None.ReferenceType.__new__.
    "test/test_compile.py"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Hangs forever
    "test/test_distributed.py"
    # Hang after the tests finish because they use a pool
    "test/test_h5.py"
    "test/test_memmap.py"
  ];

  pythonImportsCheck = [ "tensordict" ];

  meta = {
    description = "Pytorch dedicated tensor container";
    changelog = "https://github.com/pytorch/tensordict/releases/tag/${src.tag}";
    homepage = "https://github.com/pytorch/tensordict";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}