# PyTorch 0.4.1: tensors and dynamic neural networks with optional CUDA support.
# CUDA is off by default; enable with `cudaSupport = true` and a non-null
# `cudatoolkit` (and optionally `cudnn`).
{ buildPythonPackage, pythonOlder,
  cudaSupport ? false, cudatoolkit ? null, cudnn ? null,
  fetchFromGitHub, lib, numpy, pyyaml, cffi, typing, cmake,
  linkFarm, symlinkJoin,
  utillinux, which }:

# cuDNN is only meaningful on top of a CUDA toolkit, and enabling CUDA
# support requires the toolkit to be provided.
assert cudnn == null || cudatoolkit != null;
assert !cudaSupport || cudatoolkit != null;

let
  # The pytorch build expects a single CUDA prefix; join the split
  # `out`/`lib` outputs of cudatoolkit back together.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
  # LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub
  # libcuda.so from cudatoolkit for running tests, so that we don't have
  # to recompile pytorch on every update to nvidia-x11 or the kernel.
  cudaStub = linkFarm "cuda-stub" [{
    name = "libcuda.so.1";
    path = "${cudatoolkit}/lib/stubs/libcuda.so";
  }];
  # Prepended to the test invocation; the escaped ''${LD_LIBRARY_PATH...}
  # expansion happens at check time in the builder shell, not at eval time.
  cudaStubEnv = lib.optionalString cudaSupport
    "LD_LIBRARY_PATH=${cudaStub}\${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} ";

in buildPythonPackage rec {
  version = "0.4.1";
  pname = "pytorch";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "v${version}";
    fetchSubmodules = true;
    sha256 = "1cr8h47jxgfar5bamyvlayvqymnb2qvp7rr0ka2d2d4rdldf9lrp";
  };

  # Build with the toolkit's matching gcc, and point the build at cuDNN
  # headers when available.
  preConfigure = lib.optionalString cudaSupport ''
    export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++
  '' + lib.optionalString (cudaSupport && cudnn != null) ''
    export CUDNN_INCLUDE_DIR=${cudnn}/include
  '';

  buildInputs = [
    cmake
    numpy.blas
    utillinux
    which
  ] ++ lib.optionals cudaSupport [ cudatoolkit_joined cudnn ];

  propagatedBuildInputs = [
    cffi
    numpy
    pyyaml
  ] ++ lib.optional (pythonOlder "3.5") typing; # typing is stdlib from 3.5 on

  # Excluded suites are known to be flaky/unsupported in the sandbox
  # (e.g. "distributed" needs networking).
  checkPhase = ''
    ${cudaStubEnv}python test/run_test.py --exclude distributed autograd distributions jit sparse torch utils nn
  '';

  meta = {
    # nixpkgs convention: no trailing period in meta.description.
    description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    # Quoted string: unquoted URL literals are deprecated (Nix RFC 45).
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    platforms = lib.platforms.linux;
    maintainers = with lib.maintainers; [ teh ];
  };
}