# Nix expression building the `trfl` Python package
# (https://github.com/deepmind/trfl) with buildPythonPackage.
{
  lib,
  fetchFromGitHub,
  buildPythonPackage,
  numpy,
  absl-py,
  dm-tree,
  wrapt,
  tensorflow,
  tensorflow-probability,
  pytestCheckHook,
  nose,
}:

buildPythonPackage {
  pname = "trfl";
  version = "1.2.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "deepmind";
    # Literal repo name instead of `pname`, so renaming the attribute cannot
    # silently change what is fetched.
    repo = "trfl";
    # Pinned to a commit hash rather than a release tag.
    # NOTE(review): confirm this commit corresponds to the 1.2.0 release.
    rev = "ed6eff5b79ed56923bcb102e152c01ea52451d4c";
    hash = "sha256-UsDUKJCHSJ4gP+P95Pm7RsPpqTJqJhrsW47C7fTZ77I=";
  };

  # Python libraries that trfl needs at runtime must be propagated so they
  # end up on the PYTHONPATH of anything that depends on trfl; plain
  # buildInputs are not propagated to consumers.
  propagatedBuildInputs = [
    absl-py
    dm-tree
    numpy
    wrapt
    tensorflow
    tensorflow-probability
  ];

  # Only used if the check phase is re-enabled (see doCheck below).
  nativeCheckInputs = [
    nose
    pytestCheckHook
  ];

  # Smoke test that still runs while the test suite is disabled.
  pythonImportsCheck = [ "trfl" ];

  # Tests currently fail with assertion errors, so the check phase is
  # disabled. disabledTestPaths records the failing modules so that checks
  # can be re-enabled selectively later.
  doCheck = false;

  disabledTestPaths = [
    # AssertionErrors
    "trfl/indexing_ops_test.py"
    "trfl/vtrace_ops_test.py"
    "trfl/value_ops_test.py"
    "trfl/target_update_ops_test.py"
    "trfl/sequence_ops_test.py"
    "trfl/retrace_ops_test.py"
    "trfl/policy_ops_test.py"
    "trfl/policy_gradient_ops_test.py"
    "trfl/pixel_control_ops_test.py"
    "trfl/periodic_ops_test.py"
    "trfl/dpg_ops_test.py"
    "trfl/distribution_ops_test.py"
    "trfl/dist_value_ops_test.py"
    "trfl/discrete_policy_gradient_ops_test.py"
    "trfl/continuous_retrace_ops_test.py"
    "trfl/clipping_ops_test.py"
    "trfl/action_value_ops_test.py"
  ];

  meta = {
    description = "TensorFlow Reinforcement Learning";
    homepage = "https://github.com/deepmind/trfl";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}