1{
2 stdenv,
3 lib,
4 fetchurl,
5 buildPythonPackage,
6 isPy3k,
7 astor,
8 gast,
9 google-pasta,
10 wrapt,
11 numpy,
12 six,
13 termcolor,
14 packaging,
15 protobuf,
16 absl-py,
17 grpcio,
18 mock,
19 scipy,
20 distutils,
21 wheel,
22 jax,
23 ml-dtypes,
24 opt-einsum,
25 tensorflow-estimator-bin,
26 tensorboard,
27 config,
28 cudaSupport ? config.cudaSupport,
29 cudaPackages,
30 zlib,
31 python,
32 keras-applications,
33 keras-preprocessing,
34 addDriverRunpath,
35 astunparse,
36 flatbuffers,
37 h5py,
38 llvmPackages,
39 typing-extensions,
40}:
41
# We keep this binary build for three reasons:
# - the source build doesn't work on Darwin.
# - the source build is currently brittle and not easy to maintain
# - the source build doesn't work on NVIDIA Jetson platforms

# unsupported combination: there is no CUDA-enabled Darwin wheel
assert !(stdenv.hostPlatform.isDarwin && cudaSupport);

let
  # Attrset of fetchurl arguments keyed by "<system>_<pyVerNoDot>[_gpu|_jetson]"
  # (see `src` below), plus "version" / "version_jetson" entries.
  packages = import ./binary-hashes.nix;
  inherit (cudaPackages) cudatoolkit cudnn;

  # Jetson (aarch64 CUDA) ships a separate wheel with its own version key and
  # needs the cuda-compat runpath hook (see build-system below).
  isCudaJetson = cudaSupport && cudaPackages.cudaFlags.isJetsonBuild;
  # x86_64 CUDA builds select a different dependency set (see `dependencies`).
  isCudaX64 = cudaSupport && stdenv.hostPlatform.isx86_64;
in
57buildPythonPackage rec {
  # The CUDA flavour keeps the historical "-gpu" suffix in its name.
  pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
  # Jetson wheels carry their own version entry in binary-hashes.nix.
  version = packages."${"version" + lib.optionalString isCudaJetson "_jetson"}";
  # Prebuilt wheel: no source build happens in this derivation.
  format = "wheel";
61
62 src =
63 let
64 pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
65 platform = stdenv.system;
66 cuda = lib.optionalString cudaSupport (if isCudaJetson then "_jetson" else "_gpu");
67 key = "${platform}_${pyVerNoDot}${cuda}";
68 in
69 fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported configuration: ${key}"));
70
71 buildInputs = [ llvmPackages.openmp ];
72
  # Runtime Python dependencies; these replace the version-pinned
  # Requires-Dist entries that preConfigure strips from the wheel METADATA.
  dependencies = [
    astunparse
    flatbuffers
    typing-extensions
    packaging
    protobuf
    numpy
    scipy
    # NOTE(review): x86_64 CUDA builds pull in jax instead of ml-dtypes;
    # the rationale is not visible here — confirm against the wheel's METADATA.
    (if isCudaX64 then jax else ml-dtypes)
    termcolor
    grpcio
    six
    astor
    absl-py
    gast
    opt-einsum
    google-pasta
    wrapt
    tensorflow-estimator-bin
    tensorboard
    keras-applications
    keras-preprocessing
    h5py
  ] ++ lib.optional (!isPy3k) mock; # mock is only needed on legacy Python 2
97
98 build-system =
99 [
100 distutils
101 wheel
102 ]
103 ++ lib.optionals cudaSupport [ addDriverRunpath ]
104 ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ];
105
  # Repack the upstream wheel before installation: normalize Jetson wheel
  # filenames and relax dependency pins in METADATA so they match what
  # nixpkgs actually provides (see `dependencies` above).
  preConfigure = ''
    unset SOURCE_DATE_EPOCH

    # Make sure that dist and the wheel file are writable.
    chmod u+rwx -R ./dist

    pushd dist

    # Strip leading zeros from the "+nv<YY>.<MM>" local version segment,
    # e.g. *nv24.07* -> *nv24.7* (PEP 440 version normalization)
    for f in tensorflow-*+nv*.whl; do
      mv "$f" "$(sed -E 's/(nv[0-9]+)\.0*([0-9]+)/\1.\2/' <<< "$f")"
    done

    wheel unpack --dest unpacked ./*.whl
    rm ./*.whl
    (
      cd unpacked/tensorflow*
      # Adjust dependency requirements:
      # - Relax flatbuffers, gast, protobuf, tensorboard, and tensorflow-estimator version requirements that don't match what we have packaged
      # - The purpose of python3Packages.libclang is not clear at the moment and we don't have it packaged yet
      # - keras and tensorflow-io-gcs-filesystem will be considered as optional for now.
      # - numpy was pinned to fix some internal tests: https://github.com/tensorflow/tensorflow/issues/60216
      sed -i *.dist-info/METADATA \
        -e "/Requires-Dist: flatbuffers/d" \
        -e "/Requires-Dist: gast/d" \
        -e "/Requires-Dist: keras/d" \
        -e "/Requires-Dist: libclang/d" \
        -e "/Requires-Dist: protobuf/d" \
        -e "/Requires-Dist: tensorboard/d" \
        -e "/Requires-Dist: tensorflow-estimator/d" \
        -e "/Requires-Dist: tensorflow-io-gcs-filesystem/d" \
        -e "s/Requires-Dist: numpy (.*)/Requires-Dist: numpy/"
    )
    wheel pack ./unpacked/tensorflow*

    popd
  '';
143
  postFixup =
    # When using the cpu-only wheel, the final package will be named `tensorflow_cpu`.
    # Then, in each package requiring `tensorflow`, our pythonRuntimeDepsCheck will fail with:
    #   importlib.metadata.PackageNotFoundError: No package metadata was found for tensorflow
    # Hence, we manually rename the package to `tensorflow`.
    lib.optionalString ((builtins.match ".*tensorflow_cpu.*" src.url) != null) ''
      (
        cd $out/${python.sitePackages}

        dest="tensorflow-${version}.dist-info"

        mv tensorflow_cpu-${version}.dist-info "$dest"

        (
          cd "$dest"

          # Rewrite the distribution name in both metadata files so the
          # package is discoverable as "tensorflow".
          substituteInPlace METADATA \
            --replace-fail "tensorflow_cpu" "tensorflow"
          substituteInPlace RECORD \
            --replace-fail "tensorflow_cpu" "tensorflow"
        )
      )
    ''
    # Note that we need to run *after* the fixup phase because the
    # libraries are loaded at runtime. If we run in preFixup then
    # patchelf --shrink-rpath will remove the cuda libraries.
    + (
      let
        # rpaths we only need to add if CUDA is enabled.
        cudapaths = lib.optionals cudaSupport [
          cudatoolkit.out
          cudatoolkit.lib
          cudnn
        ];

        # Always-needed libraries: libstdc++/libgcc_s and zlib.
        libpaths = [
          (lib.getLib stdenv.cc.cc)
          zlib
        ];

        rpath = lib.makeLibraryPath (libpaths ++ cudapaths);
      in
      lib.optionalString stdenv.hostPlatform.isLinux ''
        # This is an array containing all the directories in the tensorflow2
        # package that contain .so files.
        #
        # TODO: Create this list programmatically, and remove paths that aren't
        # actually needed.
        rrPathArr=(
          "$out/${python.sitePackages}/tensorflow/"
          "$out/${python.sitePackages}/tensorflow/core/kernels"
          "$out/${python.sitePackages}/tensorflow/compiler/mlir/stablehlo/"
          "$out/${python.sitePackages}/tensorflow/compiler/tf2tensorrt/"
          "$out/${python.sitePackages}/tensorflow/compiler/tf2xla/ops/"
          "$out/${python.sitePackages}/tensorflow/include/external/ml_dtypes/"
          "$out/${python.sitePackages}/tensorflow/lite/experimental/microfrontend/python/ops/"
          "$out/${python.sitePackages}/tensorflow/lite/python/analyzer_wrapper/"
          "$out/${python.sitePackages}/tensorflow/lite/python/interpreter_wrapper/"
          "$out/${python.sitePackages}/tensorflow/lite/python/metrics/"
          "$out/${python.sitePackages}/tensorflow/lite/python/optimize/"
          "$out/${python.sitePackages}/tensorflow/python/"
          "$out/${python.sitePackages}/tensorflow/python/autograph/impl/testing"
          "$out/${python.sitePackages}/tensorflow/python/client"
          "$out/${python.sitePackages}/tensorflow/python/data/experimental/service"
          "$out/${python.sitePackages}/tensorflow/python/framework"
          "$out/${python.sitePackages}/tensorflow/python/grappler"
          "$out/${python.sitePackages}/tensorflow/python/lib/core"
          "$out/${python.sitePackages}/tensorflow/python/lib/io"
          "$out/${python.sitePackages}/tensorflow/python/platform"
          "$out/${python.sitePackages}/tensorflow/python/profiler/internal"
          "$out/${python.sitePackages}/tensorflow/python/saved_model"
          "$out/${python.sitePackages}/tensorflow/python/util"
          "$out/${python.sitePackages}/tensorflow/tsl/python/lib/core"
          "$out/${python.sitePackages}/tensorflow.libs/"
          "${rpath}"
        )

        # Turn the bash array into a colon-separated list of RPATHs.
        rrPath=$(IFS=$':'; echo "''${rrPathArr[*]}")
        echo "about to run patchelf with the following rpath: $rrPath"

        find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
          echo "about to patchelf $lib..."
          chmod a+rx "$lib"
          patchelf --set-rpath "$rrPath" "$lib"
          ${lib.optionalString cudaSupport ''
            addDriverRunpath "$lib"
          ''}
        done
      ''
    );
235
  # Upstream has a pip hack that results in bin/tensorboard being in both tensorflow
  # and the propagated input tensorboard, which causes environment collisions.
  # Another possibility would be to have tensorboard only in the buildInputs
  # See https://github.com/NixOS/nixpkgs/pull/44381 for more information.
  postInstall = ''
    rm $out/bin/tensorboard
  '';

  # Smoke test: importing these modules should exercise the patched native
  # libraries and catch RPATH mistakes from postFixup.
  pythonImportsCheck = [
    "tensorflow"
    "tensorflow.python"
    "tensorflow.python.framework"
  ];
249
250 meta = {
251 description = "Computation using data flow graphs for scalable machine learning";
252 homepage = "http://tensorflow.org";
253 sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
254 license = lib.licenses.asl20;
255 maintainers = with lib.maintainers; [
256 jyp
257 abbradar
258 ];
259 badPlatforms = [ "x86_64-darwin" ];
260 };
261}