{ lib, stdenv
, buildPythonPackage
, fetchurl
, python
, pythonAtLeast
, pythonOlder
, addOpenGLRunpath
, cudaPackages
, future
, numpy
, autoPatchelfHook
, patchelf
, pyyaml
, requests
, setuptools
, typing-extensions
, sympy
, jinja2
, networkx
, filelock
, openai-triton
}:

let
  # Keep this in sync with the source-built torch derivation.
  version = "2.0.1";

  # Python version without the dot (e.g. "3.11" -> "311"); together with the
  # platform string this keys into the per-wheel hash table.
  pythonVersionNoDots = lib.replaceStrings [ "." ] [ "" ] python.pythonVersion;

  # Maps "<system>-<pyVersion>" keys to fetchurl argument sets.
  wheelSources = import ./binary-hashes.nix version;

  unsupportedSystem = throw "Unsupported system";
in
buildPythonPackage {
  pname = "torch";
  inherit version;

  # A prebuilt wheel is unpacked; nothing is compiled here.
  format = "wheel";

  # Restricted to the Python versions the wheel hash table covers.
  disabled = pythonOlder "3.8" || pythonAtLeast "3.12";

  # `or` binds tighter than function application, so the fallback throw fires
  # only when no wheel exists for this system/Python combination.
  src = fetchurl (wheelSources."${stdenv.system}-${pythonVersionNoDots}" or unsupportedSystem);

  nativeBuildInputs = lib.optionals stdenv.isLinux [
    addOpenGLRunpath
    autoPatchelfHook
    cudaPackages.autoAddOpenGLRunpathHook
  ];

  buildInputs = lib.optionals stdenv.isLinux [
    # $out/${sitePackages}/nvfuser/_C*.so wants libnvToolsExt.so.1 but torch/lib only ships
    # libnvToolsExt-$hash.so.1
    cudaPackages.cuda_nvtx
  ];

  autoPatchelfIgnoreMissingDeps = lib.optionals stdenv.isLinux [
    # The hardware-dependent userspace driver shipped by the nvidia_x11
    # package. It has to be provided at runtime (via /run/opengl-driver/lib
    # or LD_LIBRARY_PATH) instead of being pinned into the runpath.
    "libcuda.so.1"
  ];

  propagatedBuildInputs = [
    future
    numpy
    pyyaml
    requests
    setuptools
    typing-extensions
    sympy
    jinja2
    networkx
    filelock
  ] ++ lib.optionals stdenv.isx86_64 [
    openai-triton
  ];

  postInstall = ''
    # ONNX conversion
    rm -rf $out/bin
  '';

  postFixup = lib.optionalString stdenv.isLinux ''
    addAutoPatchelfSearchPath "$out/${python.sitePackages}/torch/lib"

    patchelf $out/${python.sitePackages}/torch/lib/libcudnn.so.8 --add-needed libcudnn_cnn_infer.so.8

    pushd $out/${python.sitePackages}/torch/lib || exit 1
    for LIBNVRTC in ./libnvrtc*
    do
      case "$LIBNVRTC" in
        ./libnvrtc-builtins*) true;;
        ./libnvrtc*) patchelf "$LIBNVRTC" --add-needed libnvrtc-builtins* ;;
      esac
    done
    popd || exit 1
  '';

  # The wheel-binary is not stripped to avoid the error of `ImportError: libtorch_cuda_cpp.so: ELF load command address/offset not properly aligned.`.
  dontStrip = true;

  pythonImportsCheck = [ "torch" ];

  meta = {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # Includes CUDA and Intel MKL, but redistributions of the binary are not limited.
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    # torch itself is BSD-3; the bundled CUDA/MKL binaries make the wheel
    # unfreeRedistributable.
    license = with lib.licenses; [ bsd3 issl unfreeRedistributable ];
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    platforms = [ "aarch64-darwin" "aarch64-linux" "x86_64-darwin" "x86_64-linux" ];
    hydraPlatforms = [ ]; # output size 3.2G on 1.11.0
    maintainers = [ lib.maintainers.junjihashimoto ];
  };
}