{
  stdenv,
  lib,
  fetchFromGitHub,
  fetchFromGitLab,
  git-unroll,
  buildPythonPackage,
  python,
  runCommand,
  writeShellScript,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL when targeting CUDA on a platform where NCCL is
  # supported, or the system RCCL when targeting ROCm.
  useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,

  # tests.cudaAvailable:
  callPackage,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  pkg-config,
  removeReferencesTo,

  # Build inputs
  apple-sdk_13,
  numactl,
  llvmPackages,

  # dependencies
  astunparse,
  binutils,
  expecttest,
  filelock,
  fsspec,
  hypothesis,
  jinja2,
  networkx,
  packaging,
  psutil,
  pyyaml,
  requests,
  sympy,
  types-dataclasses,
  typing-extensions,
  # The ROCm build and `torch.compile` require `triton`
  tritonSupport ? (!stdenv.hostPlatform.isDarwin),
  triton,

  # TODO: 1. callPackage needs to learn to distinguish between the task
  #          of "asking for an attribute from the parent scope" and
  #          the task of "exposing a formal parameter in .override".
  # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.),
  #          as they routinely end up consuming the wrong arguments
  #          (dependencies without CUDA support).
  #          Instead we should rely on overlays and nixpkgsFun.
  # (@SomeoneSerge)
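  # A sketch of the overlay/config approach (not specific to this file):
  #   import nixpkgs { config = { allowUnfree = true; cudaSupport = true; }; }
  # builds torch *and* its transitive dependencies with CUDA enabled, whereas
  #   python3Packages.torch.override { cudaSupport = true; }
  # only flips the flag on this one derivation.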
  _tritonEffective ?
    if cudaSupport then
      triton-cuda
    else if rocmSupport then
      rocmPackages.triton
    else
      triton,
  triton-cuda,
  # Disable MKLDNN on aarch64-darwin: it negatively impacts performance there,
  # and the official PyTorch builds disable it as well.
  mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extension tests
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  tensorboard,
  protobuf,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  gpuTargets ? [ ],

  vulkanSupport ? false,
  vulkan-headers,
  vulkan-loader,
  shaderc,
}:

let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudnn flags nccl;

  triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";

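  # e.g. setBool true => "1"; PyTorch's setup.py reads many of its feature
  # toggles as "0"/"1" strings from environment variables.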
  setBool = v: if v then "1" else "0";

  # https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py#L2343-L2345
  supportedTorchCudaCapabilities =
    let
      real = [
        "3.5"
        "3.7"
        "5.0"
        "5.2"
        "5.3"
        "6.0"
        "6.1"
        "6.2"
        "7.0"
        "7.2"
        "7.5"
        "8.0"
        "8.6"
        "8.7"
        "8.9"
        "9.0"
        "9.0a"
146 "10.0"
147 "10.0"
148 "10.0a"
149 "10.1"
150 "10.1a"
151 "12.0"
152 "12.0a"
153 ];
      ptx = lists.map (x: "${x}+PTX") real;
    in
    real ++ ptx;
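  # The "+PTX" variants embed PTX bytecode in addition to SASS, so GPUs newer
  # than those listed can still JIT-compile the kernels.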

  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA
  supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;
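  # A worked example with a hypothetical configuration: if flags.cudaCapabilities
  # is [ "8.6" "12.0" "13.0" ], then supportedCudaCapabilities is [ "8.6" "12.0" ]
  # and unsupportedCudaCapabilities is [ "13.0" ].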

  isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;

  # Use trivial.throwIf to abort evaluation if no supported GPU targets remain,
  # listing the requested targets that are unsupported.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;

  # Create the gpuTargetString.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      # Remove RDNA1 gfx101x archs from default ROCm support list to avoid
      # use of undeclared identifier 'CK_BUFFER_RESOURCE_3RD_DWORD'
      # TODO: Retest after ROCm 6.4 or torch 2.8
      lib.lists.subtractLists [
        "gfx1010"
        "gfx1012"
      ] (rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets)
    else
      throw "No GPU targets specified"
  );
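  # e.g. with CUDA this yields a TORCH_CUDA_ARCH_LIST-style value like
  # "8.6;8.6+PTX"; with ROCm, a PYTORCH_ROCM_ARCH value like "gfx906;gfx90a;gfx1030".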

  # PyTorch's ROCm build expects a single monolithic installation prefix
  # (ROCM_PATH), while nixpkgs ships each ROCm component separately;
  # symlinkJoin emulates the expected layout.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      aotriton
      composable_kernel
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hiprand
      hipsolver
      hipblas-common
      hipblas
      hipblaslt
      rocminfo
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };

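  # Human-readable build guards: any condition that evaluates to true below is
  # folded into meta.broken and also exposed via passthru.brokenConditions for
  # easier debugging.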
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
    # In particular, this triggered warnings from cuda's `aliases.nix`
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport
      && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
  };

  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';
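  # Usage sketch (exposed via passthru.unroll-src; regenerates ./src.nix):
  #   unroll-src 2.7.1 > src.nix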

  # cudaPackages.backendStdenv provides a host compiler compatible with the
  # selected NVCC version.
  stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.7.1";
  pyproject = true;

  stdenv = stdenv';

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };

  patches = [
    ./clang19-template-warning.patch
  ]
  ++ lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [
    # Propagate CUPTI to Kineto by overriding the search path with environment variables.
    # https://github.com/pytorch/pytorch/pull/108847
    ./pytorch-pr-108847.patch
  ]
  ++ lib.optionals (lib.getName blas.provider == "mkl") [
    # The CMake install step tries to add hardcoded rpaths that are
    # incompatible with the Nix store, which makes it fail. Simply remove this
    # step to get rpaths that point to the Nix store.
    ./disable-cmake-mkl-rpath.patch
  ];

  postPatch = ''
    # Prevent NCCL from being cloned during the configure phase
    # TODO: remove when updating to the next release as it will not be needed anymore
    substituteInPlace tools/build_pytorch_libs.py \
      --replace-fail " checkout_nccl()" " "

    substituteInPlace cmake/public/cuda.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found two conflicting CUDA' \
        'message(WARNING "Found two conflicting CUDA' \
      --replace-warn \
        "set(CUDAToolkit_ROOT" \
        "# Upstream: set(CUDAToolkit_ROOT"
    substituteInPlace third_party/gloo/cmake/Cuda.cmake \
      --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"

    # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
    sed -i -e "/from __future__ import/d" **.py
    substituteInPlace third_party/NNPACK/CMakeLists.txt \
      --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
    # The flag from cmakeFlags doesn't take effect (it's unclear why);
    # setting it at the top of NNPACK's own CMakeLists does.
    sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt

    # Ensure that torch profiler unwind uses addr2line from nix
    substituteInPlace torch/csrc/profiler/unwind/unwind.cpp \
      --replace-fail 'addr2line_binary_ = "addr2line"' 'addr2line_binary_ = "${lib.getExe' binutils "addr2line"}"'
  ''
  + lib.optionalString rocmSupport ''
    # https://github.com/facebookincubator/gloo/pull/297
    substituteInPlace third_party/gloo/cmake/Hipify.cmake \
      --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

    # Replace hard-coded rocm paths
    substituteInPlace caffe2/CMakeLists.txt \
      --replace-fail "hcc/include" "hip/include" \
      --replace-fail "rocblas/include" "include/rocblas" \
      --replace-fail "hipsparse/include" "include/hipsparse"

    # Doesn't pick up the environment variable?
    substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
      --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}"

    # Strangely, this is never set in cmake
    substituteInPlace cmake/public/LoadHIP.cmake \
      --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
        "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"

    # Use composable kernel as dependency, rather than built-in third-party
    substituteInPlace aten/src/ATen/CMakeLists.txt \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)" "" \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)" ""
  ''
  # Detection of NCCL version doesn't work particularly well when using the static binary.
  + lib.optionalString cudaSupport ''
    substituteInPlace cmake/Modules/FindNCCL.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
  # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
  # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
  # until https://github.com/pytorch/pytorch/issues/76082 is addressed
  + lib.optionalString cudaSupport ''
    rm cmake/Modules/FindCUDAToolkit.cmake
  '';

  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';

  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # The ninja setup hook doesn't automatically enable Ninja here, because
  # pytorch's setup.py is responsible for driving CMake.
  CMAKE_GENERATOR = "Ninja";

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # Multicore CPU convnet support
  USE_NNPACK = 1;

  # Explicitly enable MPS for Darwin
  USE_MPS = setBool stdenv.hostPlatform.isDarwin;

  # building torch.distributed on Darwin is disabled by default
  # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
  USE_DISTRIBUTED = setBool true;

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];

  preBuild = ''
    # MAX_JOBS is PyTorch's own knob for build parallelism.
    export MAX_JOBS=$NIX_BUILD_CORES
    # Run only the CMake configure step (--cmake-only), then finish the
    # generation in ./build with the cmake from nixpkgs.
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    # Drop the first two entries from an ELF file's rpath and prepend $ORIGIN,
    # so the libraries resolve relative to their own location in the store.
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS and uses CFLAGS for both C and C++:
    # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
    NIX_CFLAGS_COMPILE = toString (
      [
        "-Wno-error"
      ]
      # fix aarch64-linux build failure with GCC 14
      ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
        "-Wno-error=incompatible-pointer-types"
      ]
    );
    USE_VULKAN = setBool vulkanSupport;
  }
  // lib.optionalAttrs vulkanSupport {
    VULKAN_SDK = shaderc.bin;
  }
  // lib.optionalAttrs rocmSupport {
    AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
  };

  nativeBuildInputs = [
    cmake
    which
    ninja
    pybind11
    pkg-config
    removeReferencesTo
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      autoAddDriverRunpath
      cuda_nvcc
    ]
  )
  ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  buildInputs = [
    blas
    blas.provider
  ]
  # Including openmp leads to two copies being used. This segfaults on ARM.
  # https://github.com/pytorch/pytorch/issues/149201#issuecomment-2776842320
  # ++ lib.optionals stdenv.cc.isClang [ llvmPackages.openmp ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cccl # <thrust/*>
      cuda_cudart # cuda_runtime.h and libraries
      cuda_cupti # For kineto
      cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
      cuda_nvml_dev # <nvml.h>
      cuda_nvrtc
      cuda_nvtx # -llibNVToolsExt
      cusparselt
      libcublas
      libcufft
      libcufile
      libcurand
      libcusolver
      libcusparse
    ]
    ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
    ++ lists.optionals useSystemNccl [
      # Some platforms do not support NCCL (e.g., Jetson)
      nccl # Provides nccl.h AND a static copy of NCCL!
    ]
    ++ lists.optionals (cudaOlder "11.8") [
      cuda_nvprof # <cuda_profiler_api.h>
    ]
    ++ lists.optionals (cudaAtLeast "11.8") [
      cuda_profiler_api # <cuda_profiler_api.h>
    ]
  )
  ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
  ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    apple-sdk_13
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [
    rocmtoolkit_joined
    rocmPackages.clr # Added separately so setup hook applies
  ];

  pythonRelaxDeps = [
    "sympy"
  ];
  dependencies = [
    astunparse
    expecttest
    filelock
    fsspec
    hypothesis
    jinja2
    networkx
    ninja
    packaging
    psutil
    pyyaml
    requests
    sympy
    types-dataclasses
    typing-extensions

    # the following are required for tensorboard support
    pillow
    six
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals vulkanSupport [
    vulkan-headers
    vulkan-loader
  ];

  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];

  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutil correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];

  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

    mkdir $dev

    # CppExtension requires that include files are packaged with the main
    # python library output, which is why they are copied here.
    cp -r $out/${python.sitePackages}/torch/include $dev/include

    # Cmake files under /share are different and can be safely moved. This
    # avoids unnecessary closure blow-up due to apple sdk references when
    # USE_DISTRIBUTED is enabled.
    mv $out/${python.sitePackages}/torch/share $dev/share

    # Fix up library paths for split outputs
    substituteInPlace \
      $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"

    mkdir $lib
    mv $out/${python.sitePackages}/torch/lib $lib/lib
    ln -s $lib/lib $out/${python.sitePackages}/torch/lib
  ''
  + lib.optionalString rocmSupport ''
    substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
      --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace-fail "/build/${src.name}/torch/include" "$dev/include"
  '';

  postFixup = ''
    mkdir -p "$cxxdev/nix-support"
    printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
  ''
  + lib.optionalString stdenv.hostPlatform.isDarwin ''
    for f in $(ls $lib/lib/*.dylib); do
      install_name_tool -id $lib/lib/$(basename $f) $f || true
    done

    install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
  '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';

  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless the BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    # Integration tests (e.g. tests.cudaAvailable; see the `callPackage`
    # argument above). These are not run as part of this build.
    tests = callPackage ../tests { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      GaetanLepage
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms =
      lib.platforms.linux ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}