{ lib
, stdenv
, buildPythonPackage
, fetchurl
, python
, pythonAtLeast
, pythonOlder
, addOpenGLRunpath
, cudaPackages
, future
, numpy
, autoPatchelfHook
, patchelf
, pyyaml
, requests
, setuptools
, typing-extensions
, sympy
, jinja2
, networkx
, filelock
, openai-triton
}:

let
  version = "2.0.1";

  # Python version with the dot removed (e.g. "3.10" -> "310"); this is the
  # suffix used in the per-system keys of binary-hashes.nix.
  pythonVersionNoDots = lib.replaceStrings [ "." ] [ "" ] python.pythonVersion;

  # Attribute set of fetchurl arguments (url + hash) keyed by
  # "<system>-<pythonVersionNoDots>", produced for the given release version.
  wheelSources = import ./binary-hashes.nix version;

  # Lazily-thrown error used as the fallback when no wheel exists for the
  # current system/Python combination.
  unsupportedSystem = throw "Unsupported system";
in
buildPythonPackage {
  pname = "torch";
  # Don't forget to update torch to the same version.
  inherit version;

  # This derivation repackages the upstream wheel instead of building from source.
  format = "wheel";

  # Upstream only publishes wheels for CPython 3.8 through 3.11.
  disabled = pythonOlder "3.8" || pythonAtLeast "3.12";

  src = fetchurl (wheelSources."${stdenv.system}-${pythonVersionNoDots}" or unsupportedSystem);

  nativeBuildInputs = [
    addOpenGLRunpath
    autoPatchelfHook
    cudaPackages.autoAddOpenGLRunpathHook
    patchelf
  ];

  buildInputs = [
    # $out/${sitePackages}/nvfuser/_C*.so wants libnvToolsExt.so.1 but torch/lib only ships
    # libnvToolsExt-$hash.so.1
    cudaPackages.cuda_nvtx
  ];

  autoPatchelfIgnoreMissingDeps = [
    # This is the hardware-dependent userspace driver that comes from
    # nvidia_x11 package. It must be deployed at runtime in
    # /run/opengl-driver/lib or pointed at by LD_LIBRARY_PATH variable, rather
    # than pinned in runpath
    "libcuda.so.1"
  ];

  propagatedBuildInputs = [
    future
    numpy
    pyyaml
    requests
    setuptools
    typing-extensions
    sympy
    jinja2
    networkx
    filelock
  ]
  # The triton wheel is only available for x86_64 Linux.
  ++ lib.optionals stdenv.isx86_64 [
    openai-triton
  ];

  postInstall = ''
    # ONNX conversion
    rm -rf $out/bin
  '';

  postFixup = ''
    addAutoPatchelfSearchPath "$out/${python.sitePackages}/torch/lib"

    patchelf $out/${python.sitePackages}/torch/lib/libcudnn.so.8 --add-needed libcudnn_cnn_infer.so.8

    pushd $out/${python.sitePackages}/torch/lib || exit 1
    for LIBNVRTC in ./libnvrtc*
    do
      case "$LIBNVRTC" in
        ./libnvrtc-builtins*) true;;
        ./libnvrtc*) patchelf "$LIBNVRTC" --add-needed libnvrtc-builtins* ;;
      esac
    done
    popd || exit 1
  '';

  # The wheel-binary is not stripped to avoid the error of `ImportError: libtorch_cuda_cpp.so: ELF load command address/offset not properly aligned.`.
  dontStrip = true;

  pythonImportsCheck = [ "torch" ];

  meta = {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # Includes CUDA and Intel MKL, but redistributions of the binary are not limited.
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    # torch's license is BSD3.
    # torch-bin includes CUDA and MKL binaries, therefore unfreeRedistributable is set.
    license = with lib.licenses; [ bsd3 issl unfreeRedistributable ];
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    platforms = [ "aarch64-darwin" "aarch64-linux" "x86_64-darwin" "x86_64-linux" ];
    hydraPlatforms = [ ]; # output size 3.2G on 1.11.0
    maintainers = [ lib.maintainers.junjihashimoto ];
  };
}