1{
2 stdenv,
3 lib,
4 fetchFromGitHub,
5 fetchFromGitLab,
6 git-unroll,
7 buildPythonPackage,
8 python,
9 runCommand,
10 writeShellScript,
11 config,
12 cudaSupport ? config.cudaSupport,
13 cudaPackages,
14 autoAddDriverRunpath,
15 effectiveMagma ?
16 if cudaSupport then
17 magma-cuda-static
18 else if rocmSupport then
19 magma-hip
20 else
21 magma,
22 magma,
23 magma-hip,
24 magma-cuda-static,
25 # Use the system NCCL as long as we're targeting CUDA on a supported platform.
26 useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
27 MPISupport ? false,
28 mpi,
29 buildDocs ? false,
30
31 # tests.cudaAvailable:
32 callPackage,
33
34 # Native build inputs
35 cmake,
36 symlinkJoin,
37 which,
38 pybind11,
39 pkg-config,
40 removeReferencesTo,
41
42 # Build inputs
43 apple-sdk_13,
44 numactl,
45 llvmPackages,
46
47 # dependencies
48 astunparse,
49 expecttest,
50 filelock,
51 fsspec,
52 hypothesis,
53 jinja2,
54 networkx,
55 packaging,
56 psutil,
57 pyyaml,
58 requests,
59 sympy,
60 types-dataclasses,
61 typing-extensions,
62 # ROCm build and `torch.compile` requires `triton`
63 tritonSupport ? (!stdenv.hostPlatform.isDarwin),
64 triton,
65
66 # TODO: 1. callPackage needs to learn to distinguish between the task
67 # of "asking for an attribute from the parent scope" and
68 # the task of "exposing a formal parameter in .override".
69 # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
70 # as they routinely end up consuming the wrong arguments\
71 # (dependencies without cuda support).
72 # Instead we should rely on overlays and nixpkgsFun.
73 # (@SomeoneSerge)
74 _tritonEffective ?
75 if cudaSupport then
76 triton-cuda
77 else if rocmSupport then
78 rocmPackages.triton
79 else
80 triton,
81 triton-cuda,
82
83 # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
84 # this is also what official pytorch build does
85 mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),
86
87 # virtual pkg that consistently instantiates blas across nixpkgs
88 # See https://github.com/NixOS/nixpkgs/pull/83888
89 blas,
90
91 # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
92 ninja,
93
94 # dependencies for torch.utils.tensorboard
95 pillow,
96 six,
97 tensorboard,
98 protobuf,
99
100 # ROCm dependencies
101 rocmSupport ? config.rocmSupport,
102 rocmPackages,
103 gpuTargets ? [ ],
104
105 vulkanSupport ? false,
106 vulkan-headers,
107 vulkan-loader,
108 shaderc,
109}:
110
111let
112 inherit (lib)
113 attrsets
114 lists
115 strings
116 trivial
117 ;
118 inherit (cudaPackages) cudnn flags nccl;
119
120 triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";
121
122 setBool = v: if v then "1" else "0";
123
124 # https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py#L2343-L2345
125 supportedTorchCudaCapabilities =
126 let
127 real = [
128 "3.5"
129 "3.7"
130 "5.0"
131 "5.2"
132 "5.3"
133 "6.0"
134 "6.1"
135 "6.2"
136 "7.0"
137 "7.2"
138 "7.5"
139 "8.0"
140 "8.6"
141 "8.7"
142 "8.9"
143 "9.0"
144 "9.0a"
145 "10.0"
146 "10.0"
147 "10.0a"
148 "10.1"
149 "10.1a"
150 "12.0"
151 "12.0a"
152 ];
153 ptx = lists.map (x: "${x}+PTX") real;
154 in
155 real ++ ptx;
156
157 # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
158 # of the first list *from* the second list. That means:
159 # lists.subtractLists a b = b - a
160
161 # For CUDA
162 supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
163 unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;
164
165 isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;
166
167 # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
168 gpuArchWarner =
169 supported: unsupported:
170 trivial.throwIf (supported == [ ]) (
171 "No supported GPU targets specified. Requested GPU targets: "
172 + strings.concatStringsSep ", " unsupported
173 ) supported;
174
175 # Create the gpuTargetString.
176 gpuTargetString = strings.concatStringsSep ";" (
177 if gpuTargets != [ ] then
178 # If gpuTargets is specified, it always takes priority.
179 gpuTargets
180 else if cudaSupport then
181 gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
182 else if rocmSupport then
183 # Remove RDNA1 gfx101x archs from default ROCm support list to avoid
184 # use of undeclared identifier 'CK_BUFFER_RESOURCE_3RD_DWORD'
185 # TODO: Retest after ROCm 6.4 or torch 2.8
186 lib.lists.subtractLists [
187 "gfx1010"
188 "gfx1012"
189 ] (rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets)
190 else
191 throw "No GPU targets specified"
192 );
193
194 rocmtoolkit_joined = symlinkJoin {
195 name = "rocm-merged";
196
197 paths = with rocmPackages; [
198 rocm-core
199 clr
200 rccl
201 miopen
202 aotriton
203 composable_kernel
204 rocrand
205 rocblas
206 rocsparse
207 hipsparse
208 rocthrust
209 rocprim
210 hipcub
211 roctracer
212 rocfft
213 rocsolver
214 hipfft
215 hiprand
216 hipsolver
217 hipblas-common
218 hipblas
219 hipblaslt
220 rocminfo
221 rocm-comgr
222 rocm-device-libs
223 rocm-runtime
224 clr.icd
225 hipify
226 ];
227
228 # Fix `setuptools` not being found
229 postBuild = ''
230 rm -rf $out/nix-support
231 '';
232 };
233
234 brokenConditions = attrsets.filterAttrs (_: cond: cond) {
235 "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
236 "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
237 "Unsupported CUDA version" =
238 cudaSupport
239 && !(builtins.elem cudaPackages.cudaMajorVersion [
240 "11"
241 "12"
242 ]);
243 "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
244 MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
245 # This used to be a deep package set comparison between cudaPackages and
246 # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
247 # In particular, this triggered warnings from cuda's `aliases.nix`
248 "Magma cudaPackages does not match cudaPackages" =
249 cudaSupport
250 && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
251 };
252
253 unroll-src = writeShellScript "unroll-src" ''
254 echo "{
255 version,
256 fetchFromGitLab,
257 fetchFromGitHub,
258 runCommand,
259 }:
260 assert version == "'"'$1'"'";"
261 ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
262 echo
263 echo "# Update using: unroll-src [version]"
264 '';
265
266 stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
267in
268buildPythonPackage rec {
269 pname = "torch";
270 # Don't forget to update torch-bin to the same version.
271 version = "2.7.0";
272 pyproject = true;
273
274 stdenv = stdenv';
275
276 outputs = [
277 "out" # output standard python package
278 "dev" # output libtorch headers
279 "lib" # output libtorch libraries
280 "cxxdev" # propagated deps for the cmake consumers of torch
281 ];
282 cudaPropagateToOutput = "cxxdev";
283
284 src = callPackage ./src.nix {
285 inherit
286 version
287 fetchFromGitHub
288 fetchFromGitLab
289 runCommand
290 ;
291 };
292
293 patches =
294 [
295 ./clang19-template-warning.patch
296 ]
297 ++ lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
298 ++ lib.optionals stdenv.hostPlatform.isLinux [
299 # Propagate CUPTI to Kineto by overriding the search path with environment variables.
300 # https://github.com/pytorch/pytorch/pull/108847
301 ./pytorch-pr-108847.patch
302 ]
303 ++ lib.optionals (lib.getName blas.provider == "mkl") [
304 # The CMake install tries to add some hardcoded rpaths, incompatible
305 # with the Nix store, which fails. Simply remove this step to get
306 # rpaths that point to the Nix store.
307 ./disable-cmake-mkl-rpath.patch
308 ];
309
310 postPatch =
311 ''
312 # Prevent NCCL from being cloned during the configure phase
313 # TODO: remove when updating to the next release as it will not be needed anymore
314 substituteInPlace tools/build_pytorch_libs.py \
315 --replace-fail " checkout_nccl()" " "
316
317 substituteInPlace cmake/public/cuda.cmake \
318 --replace-fail \
319 'message(FATAL_ERROR "Found two conflicting CUDA' \
320 'message(WARNING "Found two conflicting CUDA' \
321 --replace-warn \
322 "set(CUDAToolkit_ROOT" \
323 "# Upstream: set(CUDAToolkit_ROOT"
324 substituteInPlace third_party/gloo/cmake/Cuda.cmake \
325 --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
326
327 # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
328 sed -i -e "/from __future__ import/d" **.py
329 substituteInPlace third_party/NNPACK/CMakeLists.txt \
330 --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
331 # flag from cmakeFlags doesn't work, not clear why
332 # setting it at the top of NNPACK's own CMakeLists does
333 sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt
334 ''
335 + lib.optionalString rocmSupport ''
336 # https://github.com/facebookincubator/gloo/pull/297
337 substituteInPlace third_party/gloo/cmake/Hipify.cmake \
338 --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"
339
340 # Replace hard-coded rocm paths
341 substituteInPlace caffe2/CMakeLists.txt \
342 --replace-fail "hcc/include" "hip/include" \
343 --replace-fail "rocblas/include" "include/rocblas" \
344 --replace-fail "hipsparse/include" "include/hipsparse"
345
346 # Doesn't pick up the environment variable?
347 substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
348 --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}"
349
350 # Strangely, this is never set in cmake
351 substituteInPlace cmake/public/LoadHIP.cmake \
352 --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
353 "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
354
355 # Use composable kernel as dependency, rather than built-in third-party
356 substituteInPlace aten/src/ATen/CMakeLists.txt \
357 --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)" "" \
358 --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)" ""
359 ''
360 # Detection of NCCL version doesn't work particularly well when using the static binary.
361 + lib.optionalString cudaSupport ''
362 substituteInPlace cmake/Modules/FindNCCL.cmake \
363 --replace-fail \
364 'message(FATAL_ERROR "Found NCCL header version and library version' \
365 'message(WARNING "Found NCCL header version and library version'
366 ''
367 # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
368 # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
369 # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
370 # until https://github.com/pytorch/pytorch/issues/76082 is addressed
371 + lib.optionalString cudaSupport ''
372 rm cmake/Modules/FindCUDAToolkit.cmake
373 '';
374
375 # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
376 # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
377 # without extreme care to ensure they don't lock each other out of shared resources.
378 # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
379 preConfigure =
380 lib.optionalString cudaSupport ''
381 export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
382 export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
383 export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
384 ''
385 + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
386 export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
387 export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
388 ''
389 + lib.optionalString rocmSupport ''
390 export ROCM_PATH=${rocmtoolkit_joined}
391 export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
392 export PYTORCH_ROCM_ARCH="${gpuTargetString}"
393 export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
394 python tools/amd_build/build_amd.py
395 '';
396
397 # Use pytorch's custom configurations
398 dontUseCmakeConfigure = true;
399
400 # causes possible redefinition of _FORTIFY_SOURCE
401 hardeningDisable = [ "fortify3" ];
402
403 BUILD_NAMEDTENSOR = setBool true;
404 BUILD_DOCS = setBool buildDocs;
405
406 # We only do an imports check, so do not build tests either.
407 BUILD_TEST = setBool false;
408
409 # ninja hook doesn't automatically turn on ninja
410 # because pytorch setup.py is responsible for this
411 CMAKE_GENERATOR = "Ninja";
412
413 # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
414 # it by default. PyTorch currently uses its own vendored version
415 # of oneDNN through Intel iDeep.
416 USE_MKLDNN = setBool mklDnnSupport;
417 USE_MKLDNN_CBLAS = setBool mklDnnSupport;
418
419 # Avoid using pybind11 from git submodule
420 # Also avoids pytorch exporting the headers of pybind11
421 USE_SYSTEM_PYBIND11 = true;
422
423 # Multicore CPU convnet support
424 USE_NNPACK = 1;
425
426 # Explicitly enable MPS for Darwin
427 USE_MPS = setBool stdenv.hostPlatform.isDarwin;
428
429 # building torch.distributed on Darwin is disabled by default
430 # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
431 USE_DISTRIBUTED = setBool true;
432
433 cmakeFlags =
434 [
435 (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
436 # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
437 (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
438 ]
439 ++ lib.optionals cudaSupport [
440 # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
441 # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
442 (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
443 ];
444
445 preBuild = ''
446 export MAX_JOBS=$NIX_BUILD_CORES
447 ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
448 ${cmake}/bin/cmake build
449 '';
450
451 preFixup = ''
452 function join_by { local IFS="$1"; shift; echo "$*"; }
453 function strip2 {
454 IFS=':'
455 read -ra RP <<< $(patchelf --print-rpath $1)
456 IFS=' '
457 RP_NEW=$(join_by : ''${RP[@]:2})
458 patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
459 }
460 for f in $(find ''${out} -name 'libcaffe2*.so')
461 do
462 strip2 $f
463 done
464 '';
465
466 # Override the (weirdly) wrong version set by default. See
467 # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
468 # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
469 PYTORCH_BUILD_VERSION = version;
470 PYTORCH_BUILD_NUMBER = 0;
471
472 # In-tree builds of NCCL are not supported.
473 # Use NCCL when cudaSupport is enabled and nccl is available.
474 USE_NCCL = setBool useSystemNccl;
475 USE_SYSTEM_NCCL = USE_NCCL;
476 USE_STATIC_NCCL = USE_NCCL;
477
478 # Set the correct Python library path, broken since
479 # https://github.com/pytorch/pytorch/commit/3d617333e
480 PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";
481
482 env =
483 {
484 # disable warnings as errors as they break the build on every compiler
485 # bump, among other things.
486 # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
487 # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
488 NIX_CFLAGS_COMPILE = toString (
489 [
490 "-Wno-error"
491 ]
492 # fix build aarch64-linux build failure with GCC14
493 ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
494 "-Wno-error=incompatible-pointer-types"
495 ]
496 );
497 USE_VULKAN = setBool vulkanSupport;
498 }
499 // lib.optionalAttrs vulkanSupport {
500 VULKAN_SDK = shaderc.bin;
501 }
502 // lib.optionalAttrs rocmSupport {
503 AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
504 };
505
506 nativeBuildInputs =
507 [
508 cmake
509 which
510 ninja
511 pybind11
512 pkg-config
513 removeReferencesTo
514 ]
515 ++ lib.optionals cudaSupport (
516 with cudaPackages;
517 [
518 autoAddDriverRunpath
519 cuda_nvcc
520 ]
521 )
522 ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
523 ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
524
525 buildInputs =
526 [
527 blas
528 blas.provider
529 ]
530 ++ lib.optionals stdenv.cc.isClang [ llvmPackages.openmp ]
531 ++ lib.optionals cudaSupport (
532 with cudaPackages;
533 [
534 cuda_cccl # <thrust/*>
535 cuda_cudart # cuda_runtime.h and libraries
536 cuda_cupti # For kineto
537 cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
538 cuda_nvml_dev # <nvml.h>
539 cuda_nvrtc
540 cuda_nvtx # -llibNVToolsExt
541 cusparselt
542 libcublas
543 libcufft
544 libcufile
545 libcurand
546 libcusolver
547 libcusparse
548 ]
549 ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
550 ++ lists.optionals useSystemNccl [
551 # Some platforms do not support NCCL (i.e., Jetson)
552 nccl # Provides nccl.h AND a static copy of NCCL!
553 ]
554 ++ lists.optionals (cudaOlder "11.8") [
555 cuda_nvprof # <cuda_profiler_api.h>
556 ]
557 ++ lists.optionals (cudaAtLeast "11.8") [
558 cuda_profiler_api # <cuda_profiler_api.h>
559 ]
560 )
561 ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
562 ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
563 ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
564 ++ lib.optionals stdenv.hostPlatform.isDarwin [
565 apple-sdk_13
566 ]
567 ++ lib.optionals tritonSupport [ _tritonEffective ]
568 ++ lib.optionals MPISupport [ mpi ]
569 ++ lib.optionals rocmSupport [
570 rocmtoolkit_joined
571 rocmPackages.clr # Added separately so setup hook applies
572 ];
573
574 pythonRelaxDeps = [
575 "sympy"
576 ];
577 dependencies =
578 [
579 astunparse
580 expecttest
581 filelock
582 fsspec
583 hypothesis
584 jinja2
585 networkx
586 ninja
587 packaging
588 psutil
589 pyyaml
590 requests
591 sympy
592 types-dataclasses
593 typing-extensions
594
595 # the following are required for tensorboard support
596 pillow
597 six
598 tensorboard
599 protobuf
600
601 # torch/csrc requires `pybind11` at runtime
602 pybind11
603 ]
604 ++ lib.optionals tritonSupport [ _tritonEffective ]
605 ++ lib.optionals vulkanSupport [
606 vulkan-headers
607 vulkan-loader
608 ];
609
610 propagatedCxxBuildInputs =
611 [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
612
613 # Tests take a long time and may be flaky, so just sanity-check imports
614 doCheck = false;
615
616 pythonImportsCheck = [ "torch" ];
617
618 nativeCheckInputs = [
619 hypothesis
620 ninja
621 psutil
622 ];
623
624 checkPhase =
625 with lib.versions;
626 with lib.strings;
627 concatStringsSep " " [
628 "runHook preCheck"
629 "${python.interpreter} test/run_test.py"
630 "--exclude"
631 (concatStringsSep " " [
632 "utils" # utils requires git, which is not allowed in the check phase
633
634 # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
635 # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build
636
637 # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
638 (optionalString (majorMinor version == "1.3") "tensorboard")
639 ])
640 "runHook postCheck"
641 ];
642
643 pythonRemoveDeps = [
644 # In our dist-info the name is just "triton"
645 "pytorch-triton-rocm"
646 ];
647
648 postInstall =
649 ''
650 find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +
651
652 mkdir $dev
653
654 # CppExtension requires that include files are packaged with the main
655 # python library output; which is why they are copied here.
656 cp -r $out/${python.sitePackages}/torch/include $dev/include
657
658 # Cmake files under /share are different and can be safely moved. This
659 # avoids unnecessary closure blow-up due to apple sdk references when
660 # USE_DISTRIBUTED is enabled.
661 mv $out/${python.sitePackages}/torch/share $dev/share
662
663 # Fix up library paths for split outputs
664 substituteInPlace \
665 $dev/share/cmake/Torch/TorchConfig.cmake \
666 --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"
667
668 substituteInPlace \
669 $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
670 --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"
671
672 mkdir $lib
673 mv $out/${python.sitePackages}/torch/lib $lib/lib
674 ln -s $lib/lib $out/${python.sitePackages}/torch/lib
675 ''
676 + lib.optionalString rocmSupport ''
677 substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
678 --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"
679
680 substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
681 --replace-fail "/build/${src.name}/torch/include" "$dev/include"
682 '';
683
684 postFixup =
685 ''
686 mkdir -p "$cxxdev/nix-support"
687 printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
688 ''
689 + lib.optionalString stdenv.hostPlatform.isDarwin ''
690 for f in $(ls $lib/lib/*.dylib); do
691 install_name_tool -id $lib/lib/$(basename $f) $f || true
692 done
693
694 install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
695 install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
696 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib
697
698 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib
699
700 install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
701 install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
702 '';
703
704 # See https://github.com/NixOS/nixpkgs/issues/296179
705 #
706 # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
707 # it when it is needed at runtime.
708 extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
709 postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
710 postPatchelfPhase = ''
711 while IFS= read -r -d $'\0' elf ; do
712 for extra in $extraRunpaths ; do
713 echo patchelf "$elf" --add-rpath "$extra" >&2
714 patchelf "$elf" --add-rpath "$extra"
715 done
716 done < <(
717 find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
718 )
719 '';
720
721 # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
722 requiredSystemFeatures = [ "big-parallel" ];
723
724 passthru = {
725 inherit
726 cudaSupport
727 cudaPackages
728 rocmSupport
729 rocmPackages
730 unroll-src
731 ;
732 cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
733 # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
734 blasProvider = blas.provider;
735 # To help debug when a package is broken due to CUDA support
736 inherit brokenConditions;
737 tests = callPackage ../tests { };
738 };
739
740 meta = {
741 changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
742 # keep PyTorch in the description so the package can be found under that name on search.nixos.org
743 description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
744 homepage = "https://pytorch.org/";
745 license = lib.licenses.bsd3;
746 maintainers = with lib.maintainers; [
747 GaetanLepage
748 teh
749 thoughtpolice
750 tscholak
751 ]; # tscholak esp. for darwin-related builds
752 platforms =
753 lib.platforms.linux
754 ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
755 broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
756 };
757}