# Nix expression for Keras 3, the multi-backend (TensorFlow / JAX / PyTorch)
# implementation of the Keras API, built from the upstream GitHub release tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonAtLeast,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  h5py,
  ml-dtypes,
  namex,
  numpy,
  optree,
  packaging,
  rich,
  tensorflow,
  distutils,

  # tests
  dm-tree,
  jax,
  jaxlib,
  pandas,
  pydot,
  pytestCheckHook,
  tf-keras,
  torch,
}:

buildPythonPackage rec {
  pname = "keras";
  version = "3.7.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "keras-team";
    repo = "keras";
    rev = "refs/tags/v${version}";
    hash = "sha256-qidY1OmlOYPKVoxryx1bEukA7IS6rPV4jqlnuf3y39w=";
  };

  build-system = [
    setuptools
  ];

  dependencies =
    [
      absl-py
      h5py
      ml-dtypes
      namex
      numpy
      optree
      packaging
      rich
      tensorflow
    ]
    # distutils was removed from the Python standard library in 3.12 (PEP 632),
    # so on those interpreters it has to be provided as a separate package.
    # NOTE(review): which part of keras or its import chain still needs
    # distutils is not visible here — confirm before dropping this.
    ++ lib.optionals (pythonAtLeast "3.12") [ distutils ];

  # Smoke test: both the top-level package and the `keras._tf_keras`
  # submodule must import successfully.
  pythonImportsCheck = [
    "keras"
    "keras._tf_keras"
  ];

  nativeCheckInputs = [
    dm-tree
    jax
    jaxlib
    pandas
    pydot
    pytestCheckHook
    tf-keras
    torch
  ];

  # The Nix build sandbox has no writable $HOME; give the test suite a
  # throwaway one.
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests =
    [
      # Tries to install the package in the sandbox
      "test_keras_imports"

      # TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
      "test_reloading_default_saved_model"
    ]
    # `stdenv.isDarwin` is a deprecated alias; query the host platform directly.
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # AttributeError: module 'numpy' has no attribute 'float128'. Did you mean: 'float16'?
      "test_spectrogram_error"
    ];

  disabledTestPaths = [
    # Datasets are downloaded from the internet
    "integration_tests/dataset_tests"

    # TypeError: test_custom_fit.<locals>.CustomModel.train_step() missing 1 required positional argument: 'data'
    "integration_tests/jax_custom_fit_test.py"

    # RuntimeError: Virtual devices cannot be modified after being initialized
    "integration_tests/tf_distribute_training_test.py"

    # AttributeError: 'CustomModel' object has no attribute 'zero_grad'
    "integration_tests/torch_custom_fit_test.py"

    # Fails for an unclear reason:
    #   self.assertLen(list(net.parameters()), 2)
    #   AssertionError: 0 != 2
    "integration_tests/torch_workflow_test.py"

    # Most tests require internet access
    "keras/src/applications/applications_test.py"

    # TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
    "keras/src/backend/tensorflow/saved_model_test.py"
    "keras/src/export/export_lib_test.py"

    # KeyError: 'Unable to synchronously open object (bad object header version number)'
    "keras/src/saving/file_editor_test.py"
  ];

  meta = {
    description = "Multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch";
    homepage = "https://keras.io";
    changelog = "https://github.com/keras-team/keras/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}