{ lib, stdenv
, buildPythonPackage
, fetchurl
, isPy37
, isPy38
, isPy39
, isPy310
, python
, addOpenGLRunpath
, future
, numpy
, patchelf
, pyyaml
, requests
, setuptools
, typing-extensions
}:

# Installs the prebuilt PyTorch wheel selected from ./binary-hashes.nix
# for the current platform and Python interpreter.
let
  # Python version with the dot removed (e.g. "3.10" -> "310"); it forms the
  # second half of the attribute names looked up in `srcs`.
  pyVerNoDot = builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
  # Per-version table whose entries are passed straight to `fetchurl`
  # (presumably { url, hash/sha256 } sets — see ./binary-hashes.nix).
  srcs = import ./binary-hashes.nix version;
  # Attribute looked up in `srcs`; it encodes BOTH the platform and the Python
  # version, so a miss can mean either one is unsupported.
  srcKey = "${stdenv.hostPlatform.system}-${pyVerNoDot}";
  unsupported = throw "Unsupported system: no prebuilt wheel for ${srcKey} in binary-hashes.nix";
  version = "1.12.1";
in buildPythonPackage {
  inherit version;

  pname = "torch";
  # When bumping `version`, also update the source-built torch package to the
  # same version (binary-hashes.nix above is selected by `version`).

  format = "wheel";

  disabled = !(isPy37 || isPy38 || isPy39 || isPy310);

  # `or` binds tighter than function application, so this is
  # fetchurl (srcs.${srcKey} or unsupported).
  src = fetchurl srcs.${srcKey} or unsupported;

  # Only needed for the ELF fixup in postFixup, which runs on Linux only:
  # patchelf cannot operate on the Mach-O libraries shipped in Darwin wheels.
  nativeBuildInputs = lib.optionals stdenv.hostPlatform.isLinux [
    addOpenGLRunpath
    patchelf
  ];

  propagatedBuildInputs = [
    future
    numpy
    pyyaml
    requests
    setuptools
    typing-extensions
  ];

  # Drop the console scripts the wheel installs into $out/bin
  # (ONNX conversion entry points); only the Python module is kept.
  postInstall = ''
    rm -rf $out/bin
  '';

  # For every shared object under torch/lib, replace its RPATH with the C++
  # runtime (stdenv.cc.cc.lib) plus torch/lib itself, so the bundled libraries
  # find each other. Then let addOpenGLRunpath add the host graphics/CUDA
  # driver directory. Linux only; see nativeBuildInputs.
  postFixup = let
    rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ];
  in lib.optionalString stdenv.hostPlatform.isLinux ''
    find $out/${python.sitePackages}/torch/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r lib; do
      echo "setting rpath for $lib..."
      patchelf --set-rpath "${rpath}:$out/${python.sitePackages}/torch/lib" "$lib"
      addOpenGLRunpath "$lib"
    done
  '';

  # The wheel-binary is not stripped to avoid the error of `ImportError: libtorch_cuda_cpp.so: ELF load command address/offset not properly aligned.`.
  dontStrip = true;

  pythonImportsCheck = [ "torch" ];

  meta = with lib; {
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # The wheel bundles CUDA and Intel MKL, whose licenses do not restrict
    # redistribution of the binary:
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    license = licenses.bsd3;
    sourceProvenance = with sourceTypes; [ binaryNativeCode ];
    platforms = platforms.linux ++ platforms.darwin;
    # Not built by Hydra: the output was 3.2G as of 1.11.0.
    hydraPlatforms = [];
    maintainers = with maintainers; [ junjihashimoto ];
  };
}