# Prebuilt PyTorch, installed from an upstream wheel.
# The wheel is chosen by "<host system>-<python version without dots>"
# from the table in ./binary-hashes.nix.
{ lib, stdenv
, buildPythonPackage
, fetchurl
, isPy37
, isPy38
, isPy39
, isPy310
, python
, addOpenGLRunpath
, future
, numpy
, patchelf
, pyyaml
, requests
, setuptools
, typing-extensions
}:

let
  # e.g. "3.10" -> "310"
  pyVerNoDot = builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
  # Key into the per-version source table. Use the host platform: the fetched
  # binary has to run where the package is installed.
  srcKey = "${stdenv.hostPlatform.system}-${pyVerNoDot}";
  # Attribute set keyed like `srcKey`. Each value is passed to `fetchurl`.
  srcs = import ./binary-hashes.nix version;
  # Evaluated lazily, only when `srcKey` is missing from `srcs`.
  unsupported = throw "Unsupported system: ${srcKey}";
  version = "1.12.1";
in buildPythonPackage {
  inherit version;

  pname = "torch";
  # Keep this version in sync with the source-built `torch` package.

  format = "wheel";

  disabled = !(isPy37 || isPy38 || isPy39 || isPy310);

  src = fetchurl srcs.${srcKey} or unsupported;

  nativeBuildInputs = [
    addOpenGLRunpath
    patchelf
  ];

  propagatedBuildInputs = [
    future
    numpy
    pyyaml
    requests
    setuptools
    typing-extensions
  ];

  postInstall = ''
    # Remove everything the wheel installs into $out/bin
    # (the ONNX conversion entry points).
    rm -rf $out/bin
  '';

  # Point every shared library shipped in torch/lib at the compiler runtime
  # libraries (stdenv.cc.cc.lib) and at torch/lib itself, so the bundled
  # libraries can find each other. Then run addOpenGLRunpath on each one.
  postFixup = let
    rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ];
  in ''
    find $out/${python.sitePackages}/torch/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r lib; do
      echo "setting rpath for $lib..."
      patchelf --set-rpath "${rpath}:$out/${python.sitePackages}/torch/lib" "$lib"
      addOpenGLRunpath "$lib"
    done
  '';

  # Do not strip the wheel's binaries: stripping causes
  # `ImportError: libtorch_cuda_cpp.so: ELF load command address/offset not properly aligned.`
  dontStrip = true;

  pythonImportsCheck = [ "torch" ];

  meta = with lib; {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # The binary bundles CUDA and Intel MKL; their licenses (below) do not
    # restrict redistribution of the binary.
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    license = licenses.bsd3;
    sourceProvenance = with sourceTypes; [ binaryNativeCode ];
    platforms = platforms.linux ++ platforms.darwin;
    # Not built on Hydra: the output is about 3.2G (measured on 1.11.0).
    hydraPlatforms = [];
    maintainers = with maintainers; [ junjihashimoto ];
  };
}