nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
1{
2 stdenv,
3 #bazel_5,
4 bazel,
5 buildBazelPackage,
6 lib,
7 fetchFromGitHub,
8 symlinkJoin,
9 addDriverRunpath,
10 fetchpatch,
11 fetchzip,
12 linkFarm,
13 # Python deps
14 buildPythonPackage,
15 pythonAtLeast,
16 python,
17 # Python libraries
18 numpy,
19 tensorboard,
20 abseil-cpp,
21 absl-py,
22 packaging,
23 setuptools,
24 wheel,
25 google-pasta,
26 opt-einsum,
27 astunparse,
28 h5py,
29 termcolor,
30 grpcio,
31 six,
32 wrapt,
33 protobuf-python,
34 tensorflow-estimator-bin,
35 dill,
36 flatbuffers-python,
37 portpicker,
38 tblib,
39 typing-extensions,
40 # Common deps
41 git,
42 pybind11,
43 which,
44 binutils,
45 glibcLocales,
46 cython,
47 perl,
48 # Common libraries
49 jemalloc,
50 mpi,
51 gast,
52 grpc,
53 sqlite,
54 boringssl,
55 jsoncpp,
56 nsync,
57 curl,
58 snappy-cpp,
59 flatbuffers-core,
60 icu,
61 double-conversion,
62 libpng,
63 libjpeg_turbo,
64 giflib,
65 protobuf-core,
66 # Upstream by default includes cuda support since tensorflow 1.15. We could do
67 # that in nix as well. It would make some things easier and less confusing, but
68 # it would also make the default tensorflow package unfree. See
69 # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
70 config,
71 cudaSupport ? config.cudaSupport,
72 cudaPackages,
73 cudaCapabilities ? cudaPackages.flags.cudaCapabilities,
74 mklSupport ? false,
75 mkl,
76 tensorboardSupport ? true,
77 # XLA without CUDA is broken
78 xlaSupport ? cudaSupport,
79 sse42Support ? stdenv.hostPlatform.sse4_2Support,
80 avx2Support ? stdenv.hostPlatform.avx2Support,
81 fmaSupport ? stdenv.hostPlatform.fmaSupport,
82 cctools,
83 llvmPackages,
84}:
85
let
  # Keep a handle on the stdenv passed in by the caller: the inner `let`
  # rebinds the name `stdenv`, after which the original is only reachable
  # through this alias.
  originalStdenv = stdenv;
in
let
  # Tensorflow looks at many toolchain-related variables which may diverge.
  #
  # Toolchain for cuda-enabled builds.
  #   We want to achieve two things:
  #   1. NVCC should use a compatible back-end (e.g. gcc11 for cuda11)
  #   2. Normal C++ files should be compiled with the same toolchain,
  #      to avoid potential weird dynamic linkage errors at runtime.
  #      This may not be necessary though
  #
  # Toolchain for Darwin:
  #   clang 7 fails to emit a symbol for
  #   __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
  #   translation units, so the build fails at link time
  stdenv =
    let
      # Toolchain used whenever CUDA is disabled.
      hostOnlyStdenv =
        if originalStdenv.hostPlatform.isDarwin then llvmPackages.stdenv else originalStdenv;
    in
    if cudaSupport then cudaPackages.backendStdenv else hostOnlyStdenv;
# CUDA packages that are referred to by bare name further down.
cudatoolkit = cudaPackages.cudatoolkit;
nccl = cudaPackages.nccl;
# use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
# cudaPackages.cudnn led to this:
# https://github.com/tensorflow/tensorflow/issues/60398
#cudnnAttribute = "cudnn_8_6";
cudnnAttribute = "cudnn";
# symlinkJoin of the `dev` and `lib` outputs of the cuDNN package selected
# by `cudnnAttribute`.
cudnnMerged =
  let
    cudnn = cudaPackages.${cudnnAttribute};
  in
  symlinkJoin {
    name = "cudnn-merged";
    paths = map (getOutput: getOutput cudnn) [
      lib.getDev
      lib.getLib
    ];
  };
# Patch tarball published by Gentoo; individual files from it are listed in
# the `patches` attribute of the Bazel build.
gentoo-patches =
  let
    tarball = "tensorflow-patches-2.12.0.tar.bz2";
  in
  fetchzip {
    url = "https://dev.gentoo.org/~perfinion/patches/${tarball}";
    hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
  };
# Link farm whose single entry, `include`, points at the protobuf source.
protobuf-extra = linkFarm "protobuf-extra" (
  lib.singleton {
    name = "include";
    path = protobuf-core.src;
  }
);

# Internal alias for the public `tensorboardSupport` argument.
withTensorboard = tensorboardSupport;
135
# Individual CUDA redistributable packages needed by the build; the trailing
# comments name the header each one is pulled in for.
cudaComponents =
  let
    # Use the build->host spliced variant of a package when splicing provides
    # one, otherwise fall back to the package itself.
    forBuildHost = pkg: pkg.__spliced.buildHost or pkg;
  in
  [
    (forBuildHost cudaPackages.cuda_nvcc)
    (forBuildHost cudaPackages.cuda_nvprune)
    cudaPackages.cuda_cccl # block_load.cuh
    cudaPackages.cuda_cudart # cuda.h
    cudaPackages.cuda_cupti # cupti.h
    cudaPackages.cuda_nvcc # See https://github.com/google/jax/issues/19811
    cudaPackages.cuda_nvml_dev # nvml.h
    cudaPackages.cuda_nvtx # nvToolsExt.h
    cudaPackages.libcublas # cublas_api.h
    cudaPackages.libcufft # cufft.h
    cudaPackages.libcurand # curand.h
    cudaPackages.libcusolver # cusolver_common.h
    cudaPackages.libcusparse # cusparse.h
  ];

# All of the above merged into one tree: for every component its bin, dev,
# lib and static outputs are joined, in that order.
cudatoolkitDevMerged =
  let
    outputGetters = [
      lib.getBin
      lib.getDev
      lib.getLib
      (lib.getOutput "static") # Makes for a very fat closure
    ];
  in
  symlinkJoin {
    name = "cuda-${cudaPackages.cudaMajorMinorVersion}-dev-merged";
    paths = lib.concatMap (component: map (getOutput: getOutput component) outputGetters) cudaComponents;
  };
161
# Tensorflow expects bintools at hard-coded paths, e.g. /usr/bin/ar
# The only way to overcome that is to set GCC_HOST_COMPILER_PREFIX,
# but that path must contain cc as well, so we merge them
cudatoolkit_cc_joined =
  let
    compiler = stdenv.cc;
  in
  symlinkJoin {
    name = compiler.name + "-merged";
    paths = [
      compiler
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

# Needed for _some_ system libraries, grep INCLUDEDIR.
includes_joined = symlinkJoin {
  name = "tensorflow-deps-merged";
  paths = lib.singleton jsoncpp;
};

# Render a boolean as the "1"/"0" strings used for the TF_* switches in `env`.
tfFeature = enabled: if enabled then "1" else "0";

version = "2.13.0";
format = "setuptools";
# CUDA builds are published under a "-gpu" suffixed name.
variant = if cudaSupport then "-gpu" else "";
pname = "tensorflow" + variant;
185
# Python interpreter plus the modules that must be importable while the
# wheel is being built.
pythonEnv =
  let
    # python deps needed during wheel build time (not runtime, see the buildPythonPackage part for that)
    # This list can likely be shortened, but each trial takes multiple hours so won't bother for now.
    wheelBuildDeps = [
      absl-py
      astunparse
      dill
      flatbuffers-python
      gast
      google-pasta
      grpcio
      h5py
      numpy
      opt-einsum
      packaging
      protobuf-python
      setuptools
      six
      tblib
      tensorboard
      tensorflow-estimator-bin
      termcolor
      typing-extensions
      wheel
      wrapt
    ];
  in
  # The package-set argument is ignored on purpose: the list above is built
  # from the packages passed to this file.
  python.withPackages (_: wheelBuildDeps);
211
# Shared recipe for the two patched Bazel repositories below: enter the
# `repo` directory of the fetched Bazel deps, apply `patches` there, run
# `postPatch`, and install the resulting directory as $out.
mkDarwinPatchedRepo =
  {
    repo,
    patches,
    postPatch,
  }:
  stdenv.mkDerivation {
    pname = "${repo}-${pname}";
    inherit version patches postPatch;

    src = _bazel-build.deps;

    # Patches are relative to the repository, not to the deps root.
    prePatch = "pushd ${repo}";

    dontConfigure = true;
    dontBuild = true;

    installPhase = ''
      runHook preInstall

      mv ${repo}/ "$out"

      runHook postInstall
    '';
  };

rules_cc_darwin_patched = mkDarwinPatchedRepo {
  repo = "rules_cc";
  patches = [
    # https://github.com/bazelbuild/rules_cc/issues/122
    (fetchpatch {
      name = "tensorflow-rules_cc-libtool-path.patch";
      url = "https://github.com/bazelbuild/rules_cc/commit/8c427ab30bf213630dc3bce9d2e9a0e29d1787db.diff";
      hash = "sha256-C4v6HY5+jm0ACUZ58gBPVejCYCZfuzYKlHZ0m2qDHCk=";
    })

    # https://github.com/bazelbuild/rules_cc/pull/124
    (fetchpatch {
      name = "tensorflow-rules_cc-install_name_tool-path.patch";
      url = "https://github.com/bazelbuild/rules_cc/commit/156497dc89100db8a3f57b23c63724759d431d05.diff";
      hash = "sha256-NES1KeQmMiUJQVoV6dS4YGRxxkZEjOpFSCyOq9HZYO0=";
    })
  ];
  postPatch = "popd";
};
llvm-raw_darwin_patched = mkDarwinPatchedRepo {
  repo = "llvm-raw";
  patches = [
    # Fix a vendored config.h that requires the 10.13 SDK
    ./llvm_bazel_fix_macos_10_12_sdk.patch
  ];
  # Create BUILD and WORKSPACE files in the repository before leaving it.
  postPatch = ''
    touch {BUILD,WORKSPACE}
    popd
  '';
};
# The Bazel build that is actually consumed: `_bazel-build` as-is, or, on
# Darwin, the same build with the two patched repositories swapped in and AR
# pointed at cctools' libtool.
bazel-build =
  let
    darwinOverrides = prev: {
      bazelFlags = prev.bazelFlags ++ [
        "--override_repository=rules_cc=${rules_cc_darwin_patched}"
        "--override_repository=llvm-raw=${llvm-raw_darwin_patched}"
      ];
      preBuild = ''
        export AR="${cctools}/bin/libtool"
      '';
    };
  in
  if stdenv.hostPlatform.isDarwin then _bazel-build.overrideAttrs darwinOverrides else _bazel-build;
287
288 _bazel-build = buildBazelPackage.override { inherit stdenv; } {
# Name, version and Bazel come straight from the enclosing scope.
#bazel = bazel_5;
inherit pname version bazel;

# Upstream sources at the release tag matching `version`.
src = fetchFromGitHub {
  owner = "tensorflow";
  repo = "tensorflow";
  tag = "v" + version;
  hash = "sha256-Rq5pAVmxlWBVnph20fkAwbfy+iuBNlfFy14poDPd5h0=";
};
299
# On update, it can be useful to steal the changes from gentoo
# https://gitweb.gentoo.org/repo/gentoo.git/tree/sci-libs/tensorflow

nativeBuildInputs = [
  which
  pythonEnv
  cython
  perl
  protobuf-core
  protobuf-extra
]
++ lib.optionals cudaSupport [ addDriverRunpath ];

buildInputs =
  let
    # Necessary to fix the "`GLIBCXX_3.4.30' not found" error
    icuWithBuildStdenv = icu.override { inherit stdenv; };
    # pybind11 with its buildPythonPackage overridden to use the stdenv
    # selected for this build.
    pybind11WithBuildStdenv = pybind11.override (prev: {
      buildPythonPackage = prev.buildPythonPackage.override {
        inherit stdenv;
      };
    });
  in
  [
    jemalloc
    mpi
    glibcLocales
    git

    # libs taken from system through the TF_SYS_LIBS mechanism
    abseil-cpp
    boringssl
    curl
    double-conversion
    flatbuffers-core
    giflib
    grpc
    icuWithBuildStdenv
    jsoncpp
    libjpeg_turbo
    libpng
    pybind11WithBuildStdenv
    snappy-cpp
    sqlite
  ]
  ++ lib.optionals cudaSupport [
    cudatoolkit
    cudnnMerged
  ]
  ++ lib.optional mklSupport mkl
  ++ lib.optional (!stdenv.hostPlatform.isDarwin) nsync;
346
# Environment read by TensorFlow's ./configure and by the Bazel build.
env =
  let
    # Yields its argument for CUDA builds and the empty string otherwise.
    whenCuda = lib.optionalString cudaSupport;

    # Take as many libraries from the system as possible. Keep in sync with
    # list of valid syslibs in
    # https://github.com/tensorflow/tensorflow/blob/master/third_party/systemlibs/syslibs_configure.bzl
    systemLibs = [
      "absl_py"
      "astor_archive"
      "astunparse_archive"
      "boringssl"
      "com_google_absl"
      # Not packaged in nixpkgs
      # "com_github_googleapis_googleapis"
      # "com_github_googlecloudplatform_google_cloud_cpp"
      "com_github_grpc_grpc"
      "com_google_protobuf"
      # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
      # "com_googlesource_code_re2"
      "curl"
      "cython"
      "dill_archive"
      "double_conversion"
      "flatbuffers"
      "functools32_archive"
      "gast_archive"
      "gif"
      "hwloc"
      "icu"
      "jsoncpp_git"
      "libjpeg_turbo"
      "nasm"
      "opt_einsum_archive"
      "org_sqlite"
      "pasta"
      "png"
      "pybind11"
      "six_archive"
      "snappy"
      "tblib_archive"
      "termcolor_archive"
      "typing_extensions_archive"
      "wrapt"
      "zlib"
    ]
    # nsync fails to build on darwin
    ++ lib.optional (!stdenv.hostPlatform.isDarwin) "nsync";
  in
  {
    # arbitrarily set to the current latest bazel version, overly careful
    TF_IGNORE_MAX_BAZEL_VERSION = true;

    LIBTOOL = lib.optionalString stdenv.hostPlatform.isDarwin "${cctools}/bin/libtool";

    TF_SYSTEM_LIBS = lib.concatStringsSep "," systemLibs;

    INCLUDEDIR = "${includes_joined}/include";

    # This is needed for the Nix-provided protobuf dependency to work,
    # as otherwise the rule `link_proto_files` tries to create the links
    # to `/usr/include/...` which results in build failures.
    PROTOBUF_INCLUDE_PATH = "${protobuf-core}/include";

    PYTHON_BIN_PATH = pythonEnv.interpreter;

    TF_NEED_GCP = true;
    TF_NEED_HDFS = true;
    TF_ENABLE_XLA = tfFeature xlaSupport;

    CC_OPT_FLAGS = " ";

    # https://github.com/tensorflow/tensorflow/issues/14454
    TF_NEED_MPI = tfFeature cudaSupport;

    TF_NEED_CUDA = tfFeature cudaSupport;
    TF_CUDA_PATHS = whenCuda "${cudatoolkitDevMerged},${cudnnMerged},${lib.getLib nccl}";
    TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities;

    # Needed even when we override stdenv: e.g. for ar
    GCC_HOST_COMPILER_PREFIX = whenCuda "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = whenCuda "${cudatoolkit_cc_joined}/bin/cc";

    # https://github.com/tensorflow/tensorflow/pull/39470
    NIX_CFLAGS_COMPILE = "-Wno-stringop-truncation";
  };
430
# Patch order: Gentoo system-library patches, the conda-forge protobuf
# linkage fix, local patches, then one aarch64-darwin-only patch.
patches =
  map (file: "${gentoo-patches}/${file}") [
    "0002-systemlib-Latest-absl-LTS-has-split-cord-libs.patch"
    "0005-systemlib-Updates-for-Abseil-20220623-LTS.patch"
    "0007-systemlibs-Add-well_known_types_py_pb2-target.patch"
  ]
  ++ [
    # https://github.com/conda-forge/tensorflow-feedstock/pull/329/commits/0a63c5a962451b4da99a9948323d8b3ed462f461
    (fetchpatch {
      name = "fix-layout-proto-duplicate-loading.patch";
      url = "https://raw.githubusercontent.com/conda-forge/tensorflow-feedstock/0a63c5a962451b4da99a9948323d8b3ed462f461/recipe/patches/0001-Omit-linking-to-layout_proto_cc-if-protobuf-linkage-.patch";
      hash = "sha256-/7buV6DinKnrgfqbe7KKSh9rCebeQdXv2Uj+Xg/083w=";
    })
    ./com_google_absl_add_log.patch
    ./absl_py_argparse_flags.patch
    ./protobuf_python.patch
    ./pybind11_protobuf_python_runtime_dep.patch
    ./pybind11_protobuf_newer_version.patch
  ]
  ++ lib.optional (stdenv.hostPlatform.system == "aarch64-darwin") ./absl_to_std.patch;
448
# Always drop the Bazel version pin and fix shebangs; additionally strip the
# tensorboard requirement from setup.py when tensorboard support is off.
postPatch =
  let
    unpinBazel = ''
      # bazel 3.3 should work just as well as bazel 3.1
      rm -f .bazelversion
      patchShebangs .
    '';
    dropTensorboardRequirement = ''
      # Tensorboard pulls in a bunch of dependencies, some of which may
      # include security vulnerabilities. So we make it optional.
      # https://github.com/tensorflow/tensorflow/issues/20280#issuecomment-400230560
      sed -i '/tensorboard ~=/d' tensorflow/tools/pip_package/setup.py
    '';
  in
  unpinBazel + lib.optionalString (!withTensorboard) dropTensorboardRequirement;
460
# Prepare the environment for TensorFlow's ./configure script.
preConfigure =
  let
    # Space-separated SIMD flags for the instruction sets enabled through
    # the sse42Support / avx2Support / fmaSupport arguments.
    ccOptFlags = lib.concatStringsSep " " (
      lib.optional sse42Support "-msse4.2"
      ++ lib.optional avx2Support "-mavx2"
      ++ lib.optional fmaSupport "-mfma"
    );
  in
  ''
    patchShebangs configure

    # dummy ldconfig
    mkdir dummy-ldconfig
    echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
    chmod +x dummy-ldconfig/ldconfig
    export PATH="$PWD/dummy-ldconfig:$PATH"

    export PYTHON_LIB_PATH="$NIX_BUILD_TOP/site-packages"
    export CC_OPT_FLAGS="${ccOptFlags}"
    mkdir -p "$PYTHON_LIB_PATH"

    # To avoid mixing Python 2 and Python 3
    unset PYTHONPATH
  '';

# Run upstream's ./configure between the standard pre/post hooks.
configurePhase = ''
  runHook preConfigure
  ./configure
  runHook postConfigure
'';
491
hardeningDisable = [ "format" ];

# Flags passed to `bazel build`.
bazelBuildFlags = [
  "--config=opt" # optimize using the flags set in the configure phase
]
++ lib.optionals stdenv.cc.isClang [
  "--cxxopt=-x"
  "--cxxopt=c++"
  "--host_cxxopt=-x"
  "--host_cxxopt=c++"

  # workaround for https://github.com/bazelbuild/bazel/issues/15359
  "--spawn_strategy=sandboxed"
]
++ lib.optionals mklSupport [ "--config=mkl" ];

# One Bazel label per list element: the pip package builder and the C
# library tarball. Previously both labels were packed into a single
# space-separated string, which only worked through shell word-splitting.
# NOTE(review): buildBazelPackage is expected to join this list with spaces,
# making the generated command line identical — confirm against the builder.
bazelTargets = [
  "//tensorflow/tools/pip_package:build_pip_package"
  "//tensorflow/tools/lib_package:libtensorflow"
];

removeRulesCC = false;
# Without this Bazel complains about sandbox violations.
dontAddBazelOpts = true;
515
# Hash of the fetched Bazel dependencies. It differs per platform and, on
# Linux, between CUDA and non-CUDA builds.
fetchAttrs = {
  sha256 =
    let
      inherit (stdenv.hostPlatform) system;
      # Pick the CUDA or the CPU-only hash according to `cudaSupport`.
      byCudaSupport = { cuda, cpu }: if cudaSupport then cuda else cpu;
      depsHashes = {
        x86_64-linux = byCudaSupport {
          cuda = "sha256-5VFMNHeLrUxW5RTr6EhT3pay9nWJ5JkZTGirDds5QkU=";
          cpu = "sha256-KzgWV69Btr84FdwQ5JI2nQEsqiPg1/+TWdbw5bmxXOE=";
        };
        aarch64-linux = byCudaSupport {
          cuda = "sha256-ty5+51BwHWE1xR4/0WcWTp608NzSAS/iiyN+9zx7/wI=";
          cpu = "sha256-9btXrNHqd720oXTPDhSmFidv5iaZRLjCVX8opmrMjXk=";
        };
        x86_64-darwin = "sha256-gqb03kB0z2pZQ6m1fyRp1/Nbt8AVVHWpOJSeZNCLc4w=";
        aarch64-darwin = "sha256-WdgAaFZU+ePwWkVBhLzjlNT7ELfGHOTaMdafcAMD5yo=";
      };
    in
    depsHashes.${system} or (throw "unsupported system ${system}");
};
534
# Attributes of the actual build: the C library goes to `out`, the
# unpacked pip package sources go to `python`.
buildAttrs =
  let
    # Schema compiler from the flatbuffers package passed to this file.
    flatc = "${flatbuffers-core}/bin/flatc";
  in
  {
    outputs = [
      "out"
      "python"
    ];

    # need to rebuild schemas since we use a different flatbuffers version
    preBuild = ''
      (cd tensorflow/lite/schema;${flatc} --gen-object-api -c schema.fbs)
      (cd tensorflow/lite/schema;${flatc} --gen-object-api -c conversion_metadata.fbs)
      (cd tensorflow/lite/acceleration/configuration;${flatc} -o configuration.fbs --proto configuration.proto)
      sed -i s,tflite.proto,tflite,g tensorflow/lite/acceleration/configuration/configuration.fbs/configuration.fbs
      (cd tensorflow/lite/acceleration/configuration;${flatc} --gen-compare --gen-object-api -c configuration.fbs/configuration.fbs)
      cp -r tensorflow/lite/acceleration/configuration/configuration.fbs tensorflow/lite/experimental/acceleration/configuration
      (cd tensorflow/lite/experimental/acceleration/configuration;${flatc} -c configuration.fbs/configuration.fbs)
      (cd tensorflow/lite/delegates/gpu/cl;${flatc} -c compiled_program_cache.fbs)
      (cd tensorflow/lite/delegates/gpu/cl;${flatc} -I $NIX_BUILD_TOP/source -c serialization.fbs)
      (cd tensorflow/lite/delegates/gpu/common;${flatc} -I $NIX_BUILD_TOP/source -c gpu_model.fbs)
      (cd tensorflow/lite/delegates/gpu/common/task;${flatc} -c serialization_base.fbs)
      patchShebangs .
    '';

    installPhase = ''
      mkdir -p "$out"
      tar -xf bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz -C "$out"
      # Write pkgconfig file.
      mkdir "$out/lib/pkgconfig"
      cat > "$out/lib/pkgconfig/tensorflow.pc" << EOF
      Name: TensorFlow
      Version: ${version}
      Description: Library for computation using data flow graphs for scalable machine learning
      Requires:
      Libs: -L$out/lib -ltensorflow
      Cflags: -I$out/include/tensorflow
      EOF

      # build the source code, then copy it to $python (build_pip_package
      # actually builds a symlink farm so we must dereference them).
      bazel-bin/tensorflow/tools/pip_package/build_pip_package --src "$PWD/dist"
      cp -Lr "$PWD/dist" "$python"
    '';

    # CUDA builds only: run addDriverRunpath on every shared object in $out.
    postFixup = lib.optionalString cudaSupport ''
      find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
        addDriverRunpath "$lib"
      done
    '';

    requiredSystemFeatures = [ "big-parallel" ];
  };
585
# Package metadata; the Python package at the bottom of the file inherits
# it from the Bazel build.
meta = {
  badPlatforms = lib.optionals cudaSupport lib.platforms.darwin;
  changelog = "https://github.com/tensorflow/tensorflow/releases/tag/v${version}";
  description = "Computation using data flow graphs for scalable machine learning";
  # HTTPS form of the project URL (was plain http://).
  homepage = "https://www.tensorflow.org";
  license = lib.licenses.asl20;
  maintainers = [ ];
  platforms = with lib.platforms; linux ++ darwin;
  broken =
    # Dependencies are EOL and have been removed; an update
    # to a newer TensorFlow version will be required to fix the
    # source build.
    true
    # The remaining conditions are currently short-circuited by the `true`
    # above; they are kept for when the source build is revived.
    || stdenv.hostPlatform.isDarwin
    || !(xlaSupport -> cudaSupport)
    || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
    || !(cudaSupport -> cudaPackages ? cudatoolkit);
}
// lib.optionalAttrs stdenv.hostPlatform.isDarwin {
  timeout = 86400; # 24 hours
  maxSilent = 14400; # 4h, double the default of 7200s
};
608 };
609in
# Python package wrapped around the Bazel build: it installs the pip package
# sources from the `python` output of `bazel-build` via setuptools.
buildPythonPackage {
  __structuredAttrs = true;
  inherit version pname format;
  disabled = pythonAtLeast "3.13";

  src = bazel-build.python;

  # Adjust dependency requirements:
  # - Drop tensorflow-io dependency until we get it to build
  # - Relax flatbuffers and gast version requirements
  # - The purpose of python3Packages.libclang is not clear at the moment and we don't have it packaged yet
  # - keras will be considered as optional for now.
  #
  # The last `-e` expression deliberately has no trailing backslash: a
  # dangling line continuation would make sed swallow whatever command is
  # appended to this hook later (e.g. through overrideAttrs).
  postPatch = ''
    sed -i setup.py \
      -e '/tensorflow-io-gcs-filesystem/,+1d' \
      -e "s/'flatbuffers[^']*',/'flatbuffers',/" \
      -e "s/'gast[^']*',/'gast',/" \
      -e "/'libclang[^']*',/d" \
      -e "/'keras[^']*')\?,/d" \
      -e "s/'protobuf[^']*',/'protobuf',/"
  '';

  # Upstream has a pip hack that results in bin/tensorboard being in both tensorflow
  # and the propagated input tensorboard, which causes environment collisions.
  # Another possibility would be to have tensorboard only in the buildInputs
  # https://github.com/tensorflow/tensorflow/blob/v1.7.1/tensorflow/tools/pip_package/setup.py#L79
  postInstall = ''
    rm $out/bin/tensorboard
  '';

  # Passed to setup.py so the distribution is named after `pname`.
  setupPyGlobalFlags = [
    "--project_name"
    pname
  ];

  # tensorflow/tools/pip_package/setup.py
  propagatedBuildInputs = [
    absl-py
    abseil-cpp
    astunparse
    flatbuffers-python
    gast
    google-pasta
    grpcio
    h5py
    numpy
    opt-einsum
    packaging
    protobuf-python
    six
    tensorflow-estimator-bin
    termcolor
    typing-extensions
    wrapt
  ]
  ++ lib.optionals withTensorboard [ tensorboard ];

  nativeBuildInputs = lib.optionals cudaSupport [ addDriverRunpath ];

  # CUDA builds only: for every shared object, run addDriverRunpath and
  # prepend the CUDA toolkit, cuDNN and NCCL library directories to its
  # existing rpath.
  postFixup = lib.optionalString cudaSupport ''
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addDriverRunpath "$lib"

      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnnMerged}/lib:${lib.getLib nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  # Actual tests are slow and impure.
  # TODO try to run them anyway
  # TODO better test (files in tensorflow/tools/ci_build/builds/*test)
  # TEST_PACKAGES in tensorflow/tools/pip_package/setup.py
  nativeCheckInputs = [
    dill
    portpicker
    tblib
  ];
  # Smoke test: import tensorflow, print a constant, and check the sign of
  # the softmax gradient at and before the chosen index.
  checkPhase = ''
    ${python.interpreter} <<EOF
    # A simple "Hello world"
    import tensorflow as tf
    hello = tf.constant("Hello, world!")
    tf.print(hello)

    tf.random.set_seed(0)
    width = 512
    choice = 48
    t_in = tf.Variable(tf.random.uniform(shape=[width]))
    with tf.GradientTape() as tape:
        t_out = tf.slice(tf.nn.softmax(t_in), [choice], [1])
    diff = tape.gradient(t_out, t_in)
    assert(0 < tf.reduce_min(tf.slice(diff, [choice], [1])))
    assert(0 > tf.reduce_max(tf.slice(diff, [1], [choice - 1])))
    EOF
  '';
  # Regression test for #77626 removed because there is no more `tensorflow.contrib`.

  passthru = {
    deps = bazel-build.deps;
    libtensorflow = bazel-build.out;
  };

  inherit (bazel-build) meta;
}