# Packages the prebuilt PyTorch wheel (pname "torch", format "wheel").
#
# The wheel to fetch is selected from ./binary-hashes.nix by the key
# "<host system>-<python version without dots>". On Linux the shipped ELF
# files are patched against the CUDA libraries listed in `buildInputs`.
{
  lib,
  stdenv,
  buildPythonPackage,
  autoAddDriverRunpath,
  fetchurl,
  python,
  pythonAtLeast,
  pythonOlder,
  addDriverRunpath,
  callPackage,
  cudaPackages,
  future,
  numpy,
  autoPatchelfHook,
  pyyaml,
  requests,
  setuptools,
  # Not referenced in this expression; kept so the argument set (and any
  # existing override that passes it) stays compatible.
  torch-bin,
  typing-extensions,
  sympy,
  jinja2,
  networkx,
  filelock,
  triton,
}:

let
  # Don't forget to update torch to the same version.
  version = "2.3.1";

  # Platform predicates are read from hostPlatform in one place so every
  # conditional below agrees on what "Linux" means.
  inherit (stdenv.hostPlatform) isLinux isx86_64;

  # e.g. python.pythonVersion "3.11" -> "311"
  pyVerNoDot = builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;

  # Attribute name under which the wheel for this system/interpreter is
  # looked up in `srcs`.
  wheelKey = "${stdenv.hostPlatform.system}-${pyVerNoDot}";

  # ./binary-hashes.nix is a function of the version; its result is indexed
  # by `wheelKey` below.
  srcs = import ./binary-hashes.nix version;

  # Only forced when `srcs` has no entry for `wheelKey`; names the missing
  # key so the failure is actionable.
  unsupported = throw "Unsupported system: ${wheelKey}";
in
buildPythonPackage {
  inherit version;

  pname = "torch";

  format = "wheel";

  disabled = (pythonOlder "3.8") || (pythonAtLeast "3.13");

  # `or` binds tighter than function application; the parentheses only make
  # that explicit.
  src = fetchurl (srcs.${wheelKey} or unsupported);

  nativeBuildInputs = lib.optionals isLinux [
    addDriverRunpath
    autoPatchelfHook
    autoAddDriverRunpath
  ];

  buildInputs = lib.optionals isLinux (
    with cudaPackages;
    [
      # $out/${sitePackages}/nvfuser/_C*.so wants libnvToolsExt.so.1 but torch/lib only ships
      # libnvToolsExt-$hash.so.1
      cuda_nvtx

      cuda_cudart
      cuda_cupti
      cuda_nvrtc
      cudnn
      libcublas
      libcufft
      libcurand
      libcusolver
      libcusparse
      nccl
    ]
  );

  autoPatchelfIgnoreMissingDeps = lib.optionals isLinux [
    # This is the hardware-dependent userspace driver that comes from
    # nvidia_x11 package. It must be deployed at runtime in
    # /run/opengl-driver/lib or pointed at by LD_LIBRARY_PATH variable, rather
    # than pinned in runpath
    "libcuda.so.1"
  ];

  dependencies =
    [
      future
      numpy
      pyyaml
      requests
      setuptools
      typing-extensions
      sympy
      jinja2
      networkx
      filelock
    ]
    # triton is only added on x86_64 Linux.
    ++ lib.optionals (isLinux && isx86_64) [ triton ];

  postInstall = ''
    # ONNX conversion
    rm -rf $out/bin
  '';

  # Register the wheel's own torch/lib directory as an autoPatchelf search
  # path.
  postFixup = lib.optionalString isLinux ''
    addAutoPatchelfSearchPath "$out/${python.sitePackages}/torch/lib"
  '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals isLinux [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals isLinux [ "postPatchelfPhase" ];
  # Appends every entry of $extraRunpaths to the rpath of each *.so file found
  # under the lib output and $out.
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';

  # The wheel-binary is not stripped to avoid the error of `ImportError: libtorch_cuda_cpp.so: ELF load command address/offset not properly aligned.`.
  dontStrip = true;

  pythonImportsCheck = [ "torch" ];

  passthru.tests = callPackage ./tests.nix { };

  meta = {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # Includes CUDA and Intel MKL, but redistributions of the binary are not limited.
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    # torch's license is BSD3.
    # torch-bin used to vendor CUDA. It still links against CUDA and MKL.
    license = with lib.licenses; [
      bsd3
      issl
      unfreeRedistributable
    ];
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    platforms = [
      "aarch64-darwin"
      "aarch64-linux"
      "x86_64-linux"
    ];
    hydraPlatforms = [ ]; # output size 3.2G on 1.11.0
    maintainers = with lib.maintainers; [ junjihashimoto ];
  };
}