{ lib
, fetchFromGitHub
, buildPythonPackage
, numpy
, absl-py
, dm-tree
, wrapt
, tensorflow
, tensorflow-probability
, pytestCheckHook
, nose }:

# Nix package for trfl (TensorFlow Reinforcement Learning), built from the
# deepmind/trfl GitHub repository.
buildPythonPackage rec {
  pname = "trfl";
  version = "1.2.0";

  src = fetchFromGitHub {
    owner = "deepmind";
    # Spelled out instead of `pname` so overriding pname cannot change the
    # fetched source.
    repo = "trfl";
    # NOTE(review): pinned to a commit rather than a release tag — confirm this
    # commit actually corresponds to `version` above.
    rev = "ed6eff5b79ed56923bcb102e152c01ea52451d4c";
    hash = "sha256-UsDUKJCHSJ4gP+P95Pm7RsPpqTJqJhrsW47C7fTZ77I=";
  };

  # Python library dependencies. They are propagated (rather than listed in
  # plain buildInputs) so they are on the module search path of any Python
  # environment that includes trfl, not only during this build.
  propagatedBuildInputs = [
    absl-py
    dm-tree
    numpy
    wrapt
    tensorflow
    tensorflow-probability
  ];

  # Only used by the check phase, which is currently disabled (see doCheck).
  nativeCheckInputs = [
    nose
    pytestCheckHook
  ];

  # Smoke test: the installed package must be importable as `trfl`.
  pythonImportsCheck = [
    "trfl"
  ];

  # Tests currently fail with assertion errors, so the check phase is off.
  # The disabledTestPaths list below is kept so it applies if doCheck is
  # re-enabled; it has no effect while doCheck = false.
  doCheck = false;

  disabledTestPaths = [
    # AssertionErrors
    "trfl/indexing_ops_test.py"
    "trfl/vtrace_ops_test.py"
    "trfl/value_ops_test.py"
    "trfl/target_update_ops_test.py"
    "trfl/sequence_ops_test.py"
    "trfl/retrace_ops_test.py"
    "trfl/policy_ops_test.py"
    "trfl/policy_gradient_ops_test.py"
    "trfl/pixel_control_ops_test.py"
    "trfl/periodic_ops_test.py"
    "trfl/dpg_ops_test.py"
    "trfl/distribution_ops_test.py"
    "trfl/dist_value_ops_test.py"
    "trfl/discrete_policy_gradient_ops_test.py"
    "trfl/continuous_retrace_ops_test.py"
    "trfl/clipping_ops_test.py"
    "trfl/action_value_ops_test.py"
  ];

  meta = {
    description = "TensorFlow Reinforcement Learning";
    homepage = "https://github.com/deepmind/trfl";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.onny ];
  };
}