{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  pybind11,
  setuptools,
  setuptools-scm,

  # nativeBuildInputs
  cmake,
  ninja,

  # dependencies
  cloudpickle,
  importlib-metadata,
  numpy,
  orjson,
  packaging,
  torch,

  # tests
  h5py,
  pytestCheckHook,
}:

let
  # Shorthand for the platform switch used by the Darwin-only test exclusions below.
  inherit (stdenv.hostPlatform) isDarwin;
in
buildPythonPackage rec {
  pname = "tensordict";
  version = "0.8.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "tensordict";
    tag = "v${version}";
    hash = "sha256-Qk+pVSwKAIOz6EczGjf4gsOsxAno/vHCgO1EQZDNTsk=";
  };

  build-system = [
    pybind11
    setuptools
    setuptools-scm
  ];

  # cmake/ninja are needed at build time, but the cmake setup hook's own
  # configure phase is disabled; presumably the setuptools build invokes
  # cmake itself — confirm against upstream setup.py when bumping.
  nativeBuildInputs = [
    cmake
    ninja
  ];
  dontUseCmakeConfigure = true;

  dependencies = [
    cloudpickle
    importlib-metadata
    numpy
    orjson
    packaging
    torch
  ];

  nativeCheckInputs = [
    h5py
    pytestCheckHook
  ];

  # Remove the in-tree `tensordict` source directory before running tests so
  # that Python imports the installed package rather than the unbuilt sources.
  preCheck = ''
    rm -rf tensordict
  '';

  pythonImportsCheck = [ "tensordict" ];

  disabledTests = [
    # Fail with FileNotFoundError on '/build/source/tensordict/tensorclass.pyi'
    # (the source tree is removed in preCheck).
    "test_tensorclass_instance_methods"
    "test_tensorclass_stub_methods"

    # Never terminates on some CPUs.
    "test_map_iter_interrupt_early"
  ]
  ++ lib.optionals isDarwin [
    # Hang on Darwin because they use a worker pool.
    "test_chunksize_num_chunks"
    "test_index_with_generator"
    "test_map_exception"
    "test_map"
    "test_multiprocessing"
  ];

  disabledTestPaths = [
    # torch._dynamo.exc.Unsupported: Graph break due to unsupported builtin None.ReferenceType.__new__.
    "test/test_compile.py"
  ]
  ++ lib.optionals isDarwin [
    # Never terminates on Darwin.
    "test/test_distributed.py"
    # Hang after the tests finish, due to worker-pool usage.
    "test/test_h5.py"
    "test/test_memmap.py"
  ];

  meta = {
    description = "Pytorch dedicated tensor container";
    homepage = "https://github.com/pytorch/tensordict";
    changelog = "https://github.com/pytorch/tensordict/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}