# Formal arguments of this package function (presumably supplied by
# callPackage; confirm at the call site). No defaults, no `...`.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchurl,
  isPy37,
  isPy38,
  isPy39,
  python,
  addOpenGLRunpath,
  future,
  numpy,
  patchelf,
  pyyaml,
  requests,
  typing-extensions,
}:
16
let
  # Upstream PyTorch release whose prebuilt wheels are fetched.
  version = "1.8.1";

  # Python version with the dot removed (e.g. "3.8" -> "38"); used as the
  # Python part of the wheel lookup key.
  pyVerNoDot = builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;

  # Wheel fetch arguments for this version, keyed (per the lookup in `src`)
  # by "<system>-<pyVerNoDot>"; confirm the exact key set in binary-hashes.nix.
  srcs = import ./binary-hashes.nix version;

  # Fallback for a missing wheel. The lookup key depends on BOTH the platform
  # and the Python version, so report both instead of blaming only the system.
  # `throw` is lazy: this only fires if the fallback is actually evaluated.
  unsupported = throw "Unsupported system: no prebuilt pytorch ${version} wheel for ${stdenv.hostPlatform.system} with Python ${python.pythonVersion}";
in buildPythonPackage {
  inherit version;

  pname = "pytorch";
  # Don't forget to update the source-built pytorch package to the same version.

  format = "wheel";

  # Only Python 3.7, 3.8 and 3.9 are accepted; presumably this mirrors the
  # wheels listed in binary-hashes.nix — keep the two in sync.
  disabled = !(isPy37 || isPy38 || isPy39);

  # Select the wheel for the platform the package will run on (hostPlatform;
  # `stdenv.system` is the deprecated alias) and the current Python version.
  # `or` is part of attribute selection, so this is
  # fetchurl (srcs.<key> or unsupported), and `unsupported` throws.
  src = fetchurl srcs."${stdenv.hostPlatform.system}-${pyVerNoDot}" or unsupported;

  nativeBuildInputs = [
    addOpenGLRunpath
    patchelf
  ];

  propagatedBuildInputs = [
    future
    numpy
    pyyaml
    requests
    typing-extensions
  ];

  # Delete everything the wheel installed into $out/bin. Per the original
  # packaging note these are ONNX conversion scripts; NOTE(review): confirm
  # why they are dropped rather than kept.
  postInstall = ''
    rm -rf $out/bin
  '';

  # For every shared object shipped in torch/lib: replace its RUNPATH with the
  # C++ runtime (stdenv.cc.cc.lib) plus torch/lib itself (so sibling libraries
  # resolve), then run addOpenGLRunpath on it (presumably adds the driver
  # library directory needed for CUDA/OpenGL at runtime).
  postFixup = let
    rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ];
    # `$out` stays literal here and is expanded by the shell at build time.
    torchLib = "$out/${python.sitePackages}/torch/lib";
  in ''
    find "${torchLib}" -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r so; do
      echo "setting rpath for $so..."
      patchelf --set-rpath "${rpath}:${torchLib}" "$so"
      addOpenGLRunpath "$so"
    done
  '';

  pythonImportsCheck = [ "torch" ];

  meta = with lib; {
    description = "Open source, prototype-to-production deep learning platform";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    license = licenses.unfree; # Includes CUDA and Intel MKL.
    platforms = platforms.linux;
    maintainers = with maintainers; [ danieldk ];
  };
}