{
  stdenv,
  lib,
  fetchFromGitHub,
  fetchFromGitLab,
  fetchFromGitea,
  buildPythonPackage,
  python,
  runCommand,
  writeShellScript,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,

  # tests.cudaAvailable:
  callPackage,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  removeReferencesTo,

  # Build inputs
  apple-sdk_13,
  numactl,

  # dependencies
  astunparse,
  fsspec,
  filelock,
  jinja2,
  networkx,
  sympy,
  numpy,
  pyyaml,
  cffi,
  click,
  typing-extensions,
  # A ROCm build and `torch.compile` require `triton`.
  tritonSupport ? (!stdenv.hostPlatform.isDarwin),
  triton,

  # TODO: 1. callPackage needs to learn to distinguish between the task
  #          of "asking for an attribute from the parent scope" and
  #          the task of "exposing a formal parameter in .override".
  # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #          as they routinely end up consuming the wrong arguments
  #          (dependencies without CUDA support).
  #          Instead we should rely on overlays and nixpkgsFun.
  #          (@SomeoneSerge)
  _tritonEffective ? if cudaSupport then triton-cuda else triton,
  triton-cuda,

  # Unit tests
  hypothesis,
  psutil,

  # Disable MKLDNN on aarch64-darwin: it negatively impacts performance,
  # which is also what the official PyTorch build does.
  mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extension tests.
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  future,
  tensorboard,
  protobuf,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages_5,
  gpuTargets ? [ ],

  vulkanSupport ? false,
  vulkan-headers,
  vulkan-loader,
  shaderc,
}:

let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";

  rocmPackages = rocmPackages_5;

  setBool = v: if v then "1" else "0";
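  # (For example, setBool true == "1"; the USE_*/BUILD_* derivation attributes
  # below are set from these "1"/"0" strings.)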

  # https://github.com/pytorch/pytorch/blob/v2.4.0/torch/utils/cpp_extension.py#L1953
  supportedTorchCudaCapabilities =
    let
      real = [
        "3.5"
        "3.7"
        "5.0"
        "5.2"
        "5.3"
        "6.0"
        "6.1"
        "6.2"
        "7.0"
        "7.2"
        "7.5"
        "8.0"
        "8.6"
        "8.7"
        "8.9"
        "9.0"
        "9.0a"
      ];
      ptx = lists.map (x: "${x}+PTX") real;
    in
    real ++ ptx;

  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a
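  # For example (hypothetical values):
  #   lists.subtractLists [ "8.6" ] [ "8.0" "8.6" "9.0" ] == [ "8.0" "9.0" ]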

  # For CUDA
  supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;

  # Use trivial.throwIf to abort if no supported GPU targets remain, listing the unsupported requests.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;

  # Create the gpuTargetString.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      rocmPackages.clr.gpuTargets
    else
      throw "No GPU targets specified"
  );
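  # (Illustration, hypothetical config: with cudaSupport enabled and
  # cudaFlags.cudaCapabilities = [ "8.6" "8.9" ], both of which torch supports,
  # gpuTargetString evaluates to "8.6;8.9".)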

  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      miopengemm
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hipsolver
      hipblas
      rocminfo
      rocm-thunk
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };

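  # Each attribute below names a failure mode; any condition that evaluates to
  # true marks the build as broken (see meta.broken and passthru.brokenConditions
  # at the bottom of this file).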
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package-set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, which made torch too strict about cudaPackages.
    # In particular, it triggered warnings from cuda's `aliases.nix`.
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport && (effectiveMagma.cudaPackages.cudaVersion != cudaPackages.cudaVersion);
    "ROCm support is currently broken because `rocmPackages.hipblaslt` is unpackaged. (2024-06-09)" =
      rocmSupport;
  };

  git-unroll = fetchFromGitea {
    domain = "codeberg.org";
    owner = "gm6k";
    repo = "git-unroll";
    rev = "96bf24f2af153310ec59979c123a8cefda8636db";
    hash = "sha256-BTlq2Pm4l/oypBzKKpxExVPyQ0CcAP8llUnl/fd3DUU=";
  };

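  # Regenerates ./src.nix; hypothetical usage:
  #   $(nix-build -A python3Packages.torch.unroll-src) 2.5.1 > src.nix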
  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${git-unroll}/unroll https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';
in
buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.5.1";
  pyproject = true;

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };

  patches =
    lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
      # pthreadpool added support for Grand Central Dispatch in April 2020.
      # However, this relies on functionality (DISPATCH_APPLY_AUTO) that is
      # only available starting with macOS 10.13, while our current base is
      # 10.12. Until we upgrade, we fall back on the older pthread support.
      ./pthreadpool-disable-gcd.diff
    ]
    ++ lib.optionals stdenv.hostPlatform.isLinux [
      # Propagate CUPTI to Kineto by overriding the search path with environment variables.
      # https://github.com/pytorch/pytorch/pull/108847
      ./pytorch-pr-108847.patch
    ]
    ++ lib.optionals (lib.getName blas.provider == "mkl") [
      # The CMake install tries to add some hardcoded rpaths, incompatible
      # with the Nix store, which fails. Simply remove this step to get
      # rpaths that point to the Nix store.
      ./disable-cmake-mkl-rpath.patch
    ];

  postPatch =
    ''
      substituteInPlace cmake/public/cuda.cmake \
        --replace-fail \
          'message(FATAL_ERROR "Found two conflicting CUDA' \
          'message(WARNING "Found two conflicting CUDA' \
        --replace-warn \
          "set(CUDAToolkit_ROOT" \
          "# Upstream: set(CUDAToolkit_ROOT"
      substituteInPlace third_party/gloo/cmake/Cuda.cmake \
        --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
    ''
    + lib.optionalString rocmSupport ''
      # https://github.com/facebookincubator/gloo/pull/297
      substituteInPlace third_party/gloo/cmake/Hipify.cmake \
        --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

      # Replace hard-coded rocm paths
      substituteInPlace caffe2/CMakeLists.txt \
        --replace "/opt/rocm" "${rocmtoolkit_joined}" \
        --replace "hcc/include" "hip/include" \
        --replace "rocblas/include" "include/rocblas" \
        --replace "hipsparse/include" "include/hipsparse"

      # Doesn't pick up the environment variable?
      substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
        --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
        --replace "/opt/rocm" "${rocmtoolkit_joined}"

      # Strangely, this is never set in cmake
      substituteInPlace cmake/public/LoadHIP.cmake \
        --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
          "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
    ''
    # Detection of NCCL version doesn't work particularly well when using the static binary.
    + lib.optionalString cudaSupport ''
      substituteInPlace cmake/Modules/FindNCCL.cmake \
        --replace \
          'message(FATAL_ERROR "Found NCCL header version and library version' \
          'message(WARNING "Found NCCL header version and library version'
    ''
    # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
    # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
    # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
    # until https://github.com/pytorch/pytorch/issues/76082 is addressed
    + lib.optionalString cudaSupport ''
      rm cmake/Modules/FindCUDAToolkit.cmake
    ''
    # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'?
    # This lib overrides aligned_alloc, hence the error message. Tl;dr: the function is linkable but not declared in the header.
    +
      lib.optionalString
        (stdenv.hostPlatform.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0")
        ''
          substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace-fail '#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))
          inline void *aligned_alloc(size_t align, size_t size)' '#if 0
          inline void *aligned_alloc(size_t align, size_t size)'
        '';

  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${cudnn.lib}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';

  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from the git submodule.
  # This also avoids pytorch exporting the headers of pybind11.
  USE_SYSTEM_PYBIND11 = true;

  # NB technical debt: building without NNPACK as a workaround for missing `six`
  USE_NNPACK = 0;

  # Explicitly enable MPS for Darwin
  USE_MPS = setBool stdenv.hostPlatform.isDarwin;

  cmakeFlags =
    [
      # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
      (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaVersion)
    ]
    ++ lib.optionals cudaSupport [
      # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache.
      # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
      (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaVersion)
    ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

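  # strip2 drops the first two (build-tree) entries from each library's runpath
  # and prepends $ORIGIN so the relocated libraries resolve their siblings in place.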
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env =
    {
      # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
      # (upstream seems to have fixed this in the wrong place?)
      # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
      # https://github.com/pytorch/pytorch/issues/22346
      #
      # Also of interest: pytorch ignores CXXFLAGS and uses CFLAGS for both C and C++:
      # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
      NIX_CFLAGS_COMPILE = toString (
        (
          lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
          # Suppress a gcc regression: avx512 math functions raise an uninitialized-variable warning.
          # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
          # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
          ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
            "-Wno-error=maybe-uninitialized"
            "-Wno-error=uninitialized"
          ]
          # Since pytorch 2.0:
          # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
          # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
          ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12") [
            "-Wno-error=free-nonheap-object"
          ]
          # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
          # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
          ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
            "-Wno-error=cast-function-type-strict"
            # Suppresses the most spammy warnings.
            # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
          ]
          ++ lib.optionals rocmSupport [
            "-Wno-#warnings"
            "-Wno-cpp"
            "-Wno-unknown-warning-option"
            "-Wno-ignored-attributes"
            "-Wno-deprecated-declarations"
            "-Wno-defaulted-function-deleted"
            "-Wno-pass-failed"
          ]
          ++ [
            "-Wno-unused-command-line-argument"
            "-Wno-uninitialized"
            "-Wno-array-bounds"
            "-Wno-free-nonheap-object"
            "-Wno-unused-result"
          ]
          ++ lib.optionals stdenv.cc.isGNU [
            "-Wno-maybe-uninitialized"
            "-Wno-stringop-overflow"
          ]
        )
      );

      USE_VULKAN = setBool vulkanSupport;
    }
    // lib.optionalAttrs vulkanSupport {
      VULKAN_SDK = shaderc.bin;
    };

  nativeBuildInputs =
    [
      cmake
      which
      ninja
      pybind11
      removeReferencesTo
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        autoAddDriverRunpath
        cuda_nvcc
      ]
    )
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  buildInputs =
    [
      blas
      blas.provider
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl # <thrust/*>
        cuda_cudart # cuda_runtime.h and libraries
        cuda_cupti # For kineto
        cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
        cuda_nvml_dev # <nvml.h>
        cuda_nvrtc
        cuda_nvtx # -llibNVToolsExt
        libcublas
        libcufft
        libcurand
        libcusolver
        libcusparse
      ]
      ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
      ++ lists.optionals useSystemNccl [
        # Some platforms do not support NCCL (i.e., Jetson)
        nccl # Provides nccl.h AND a static copy of NCCL!
      ]
      ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
        cuda_nvprof # <cuda_profiler_api.h>
      ]
      ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
        cuda_profiler_api # <cuda_profiler_api.h>
      ]
    )
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      apple-sdk_13
    ]
    ++ lib.optionals tritonSupport [ _tritonEffective ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  pythonRelaxDeps = [
    "sympy"
  ];
  dependencies =
    [
      astunparse
      cffi
      click
      numpy
      pyyaml

      # From install_requires:
      fsspec
      filelock
      typing-extensions
      sympy
      networkx
      jinja2

      # the following are required for tensorboard support
      pillow
      six
      future
      tensorboard
      protobuf

      # torch/csrc requires `pybind11` at runtime
      pybind11
    ]
    ++ lib.optionals tritonSupport [ _tritonEffective ]
    ++ lib.optionals vulkanSupport [
      vulkan-headers
      vulkan-loader
    ];

  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports.
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];

  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutil correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];

  postInstall =
    ''
      find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

      mkdir $dev
      cp -r $out/${python.sitePackages}/torch/include $dev/include
      cp -r $out/${python.sitePackages}/torch/share $dev/share

      # Fix up library paths for split outputs
      substituteInPlace \
        $dev/share/cmake/Torch/TorchConfig.cmake \
        --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

      substituteInPlace \
        $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
        --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"

      mkdir $lib
      mv $out/${python.sitePackages}/torch/lib $lib/lib
      ln -s $lib/lib $out/${python.sitePackages}/torch/lib
    ''
    + lib.optionalString rocmSupport ''
      substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
        --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

      substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
        --replace "/build/source/torch/include" "$dev/include"
    '';

  postFixup =
    ''
      mkdir -p "$cxxdev/nix-support"
      printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
    ''
    + lib.optionalString stdenv.hostPlatform.isDarwin ''
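      # Give each dylib an absolute install name, then rewire the @rpath
      # cross-references among libtorch_python, libtorch, libshm, and libc10
      # so the split $lib output is self-contained.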
      for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
      done

      install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
    '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';

  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2, `torch.fft` is unavailable unless the BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ./tests.nix { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # Keep PyTorch in the description so the package can be found under that name on search.nixos.org.
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms =
      lib.platforms.linux
      ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}