# TensorFlow for Python, built in two stages: `pkg` runs the Bazel build of
# //tensorflow/tools/pip_package:build_pip_package and stores the staged pip
# package tree, which is then installed as a Python package by
# buildPythonPackage (it is used as that package's `src`).
{ stdenv, buildBazelPackage, lib, fetchFromGitHub, fetchpatch, symlinkJoin
, buildPythonPackage, isPy3k, pythonOlder, pythonAtLeast
, which, swig, binutils, glibcLocales
, python, jemalloc, openmpi
, numpy, six, protobuf, tensorflow-tensorboard, backports_weakref, mock, enum34, absl-py
, cudaSupport ? false, nvidia_x11 ? null, cudatoolkit ? null, cudnn ? null
# XLA without CUDA is broken
, xlaSupport ? cudaSupport
# Default from ./configure script
, cudaCapabilities ? [ "3.5" "5.2" ]
# Each of the following adds the matching -m flag via --copt (see bazelFlags).
, sse42Support ? false
, avx2Support ? false
, fmaSupport ? false
}:

assert cudaSupport -> nvidia_x11 != null
                   && cudatoolkit != null
                   && cudnn != null;

# unsupported combination
assert ! (stdenv.isDarwin && cudaSupport);

let

  # tensorflow-tensorboard is only propagated for Python < 3.6; otherwise
  # "--no-dependencies" is passed through installFlags (both below).
  withTensorboard = pythonOlder "3.6";

  # One tree holding both the `out` and `lib` outputs of cudatoolkit; it is
  # exported as CUDA_TOOLKIT_PATH in preConfigure.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # Map a Nix boolean to the "1"/"0" strings used by the TF_* variables.
  tfFeature = x: if x then "1" else "0";

  version = "1.5.0";

  # Intermediate Bazel build; its output is consumed only as `src` of the
  # Python package defined after `in`.
  pkg = buildBazelPackage {
    name = "tensorflow-build-${version}";

    src = fetchFromGitHub {
      owner = "tensorflow";
      repo = "tensorflow";
      rev = "v${version}";
      sha256 = "1c4djsaip901nasm7a6dsimr02bsv70a7b1g0kysb4n39qpdh22q";
    };

    patches = [
      # Fix build with Bazel >= 0.10
      (fetchpatch {
        url = "https://github.com/tensorflow/tensorflow/commit/6fcfab770c2672e2250e0f5686b9545d99eb7b2b.patch";
        sha256 = "0p61za1mx3a7gj1s5lsps16fcw18iwnvq2b46v1kyqfgq77a12vb";
      })
      (fetchpatch {
        url = "https://github.com/tensorflow/tensorflow/commit/3f57956725b553d196974c9ad31badeb3eabf8bb.patch";
        sha256 = "11dja5gqy0qw27sc9b6yw9r0lfk8dznb32vrqqfcnypk2qmv26va";
      })
    ];

    nativeBuildInputs = [ swig which ];

    buildInputs = [ python jemalloc openmpi glibcLocales numpy ]
      ++ lib.optionals cudaSupport [ cudatoolkit cudnn nvidia_x11 ];

    # Export the settings for TensorFlow's ./configure script before it runs.
    # MPI is tied to cudaSupport per the linked upstream issue.
    preConfigure = ''
      patchShebangs configure

      export PYTHON_BIN_PATH="${python.interpreter}"
      export PYTHON_LIB_PATH="$NIX_BUILD_TOP/site-packages"
      export TF_NEED_GCP=1
      export TF_NEED_HDFS=1
      export TF_ENABLE_XLA=${tfFeature xlaSupport}
      export CC_OPT_FLAGS=" "
      # https://github.com/tensorflow/tensorflow/issues/14454
      export TF_NEED_MPI=${tfFeature cudaSupport}
      export TF_NEED_CUDA=${tfFeature cudaSupport}
      ${lib.optionalString cudaSupport ''
        export CUDA_TOOLKIT_PATH=${cudatoolkit_joined}
        export TF_CUDA_VERSION=${cudatoolkit.majorVersion}
        export CUDNN_INSTALL_PATH=${cudnn}
        export TF_CUDNN_VERSION=${cudnn.majorVersion}
        export GCC_HOST_COMPILER_PATH=${cudatoolkit.cc}/bin/gcc
        export TF_CUDA_COMPUTE_CAPABILITIES=${lib.concatStringsSep "," cudaCapabilities}
      ''}

      mkdir -p "$PYTHON_LIB_PATH"
    '';

    # Link the CUDA libraries explicitly when building with CUDA.
    NIX_LDFLAGS = lib.optionals cudaSupport [ "-lcublas" "-lcudnn" "-lcuda" "-lcudart" ];

    hardeningDisable = [ "all" ];

    bazelFlags = [ "--config=opt" ]
      ++ lib.optional sse42Support "--copt=-msse4.2"
      ++ lib.optional avx2Support "--copt=-mavx2"
      ++ lib.optional fmaSupport "--copt=-mfma"
      ++ lib.optional cudaSupport "--config=cuda";

    bazelTarget = "//tensorflow/tools/pip_package:build_pip_package";

    fetchAttrs = {
      # Remove bazel_tools and the local_* external repositories from the
      # fetched dependency tree, whose contents are pinned by sha256 below.
      preInstall = ''
        rm -rf $bazelOut/external/{bazel_tools,\@bazel_tools.marker,local_*,\@local_*}
      '';

      sha256 = "1nc98aqrp14q7llypcwaa0kdn9xi7r0p1mnd3vmmn1m299py33ca";
    };

    buildAttrs = {
      # Fix shebangs in the source tree and point any CROSSTOOL file's
      # hard-coded /usr/bin/ar at the ar from binutils.
      preBuild = ''
        patchShebangs .
        find -type f -name CROSSTOOL\* -exec sed -i \
          -e 's,/usr/bin/ar,${binutils.bintools}/bin/ar,g' \
          {} \;
      '';

      # Rewrite the line of build_pip_package that mentions bdist_wheel so
      # that the script instead copies its current directory (presumably the
      # staged pip package tree -- verify) into $out and exits, then run it.
      installPhase = ''
        sed -i 's,.*bdist_wheel.*,cp -rL . "$out"; exit 0,' bazel-bin/tensorflow/tools/pip_package/build_pip_package
        bazel-bin/tensorflow/tools/pip_package/build_pip_package $PWD/dist
      '';
    };

    # Fixup is skipped: this output is only used as the `src` below.
    dontFixup = true;
  };

in buildPythonPackage rec {
  pname = "tensorflow";
  inherit version;
  name = "${pname}-${version}";

  src = pkg;

  installFlags = lib.optional (!withTensorboard) "--no-dependencies";

  # enum34 is dropped from setup.py on Python >= 3.4.
  postPatch = lib.optionalString (pythonAtLeast "3.4") ''
    sed -i '/enum34/d' setup.py
  '';

  propagatedBuildInputs = [ numpy six protobuf absl-py ]
    ++ lib.optional (!isPy3k) mock
    ++ lib.optionals (pythonOlder "3.4") [ backports_weakref enum34 ]
    ++ lib.optional withTensorboard tensorflow-tensorboard;

  # Actual tests are slow and impure.
  checkPhase = ''
    ${python.interpreter} -c "import tensorflow"
  '';

  meta = with lib; {
    description = "Computation using data flow graphs for scalable machine learning";
    homepage = "http://tensorflow.org";
    license = licenses.asl20;
    maintainers = with maintainers; [ jyp abbradar ];
    platforms = with platforms; if cudaSupport then linux else linux ++ darwin;
    broken = !(xlaSupport -> cudaSupport);
  };
}