# PyTorch (`torch`): builds libtorch and the torch Python bindings from
# source, with optional CUDA / ROCm acceleration, MPI, Vulkan, and triton
# (`torch.compile`) support selected via the flags below.
{
  stdenv,
  lib,
  fetchFromGitHub,
  fetchFromGitLab,
  git-unroll,
  buildPythonPackage,
  python,
  runCommand,
  writeShellScript,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  # Magma variant matching the selected GPU backend (CUDA / ROCm / CPU).
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,

  # tests.cudaAvailable:
  callPackage,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  pkg-config,
  removeReferencesTo,

  # Build inputs
  apple-sdk_13,
  numactl,

  # dependencies
  astunparse,
  expecttest,
  filelock,
  fsspec,
  hypothesis,
  jinja2,
  networkx,
  packaging,
  psutil,
  pyyaml,
  requests,
  sympy,
  types-dataclasses,
  typing-extensions,
  # ROCm build and `torch.compile` requires `triton`
  tritonSupport ? (!stdenv.hostPlatform.isDarwin),
  triton,

  # TODO: 1. callPackage needs to learn to distinguish between the task
  #       of "asking for an attribute from the parent scope" and
  #       the task of "exposing a formal parameter in .override".
  # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #       as they routinely end up consuming the wrong arguments
  #       (dependencies without cuda support).
  #       Instead we should rely on overlays and nixpkgsFun.
  #       (@SomeoneSerge)
  _tritonEffective ?
    if cudaSupport then
      triton-cuda
    else if rocmSupport then
      rocmPackages.triton
    else
      triton,
  triton-cuda,

  # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
  # this is also what official pytorch build does
  mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  tensorboard,
  protobuf,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  gpuTargets ? [ ],

  vulkanSupport ? false,
  vulkan-headers,
  vulkan-loader,
  shaderc,
}:
109
let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudnn flags nccl;

  # Deliberately shadow the `triton` formal argument: all code below must go
  # through `_tritonEffective`, so a CUDA/ROCm build cannot silently pick up
  # the wrong triton flavour.
  triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";
120
121 setBool = v: if v then "1" else "0";
122
123 # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/utils/cpp_extension.py#L2046-L2048
124 supportedTorchCudaCapabilities =
125 let
126 real = [
127 "3.5"
128 "3.7"
129 "5.0"
130 "5.2"
131 "5.3"
132 "6.0"
133 "6.1"
134 "6.2"
135 "7.0"
136 "7.2"
137 "7.5"
138 "8.0"
139 "8.6"
140 "8.7"
141 "8.9"
142 "9.0"
143 "9.0a"
144 "10.0"
145 ];
146 ptx = lists.map (x: "${x}+PTX") real;
147 in
148 real ++ ptx;
149
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  #   lists.subtractLists a b = b - a

  # For CUDA: split the user's requested capabilities into those torch can
  # build for and those it cannot (used for the error message below).
  supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;

  # Jetson builds need extra handling (see autoAddCudaCompatRunpath and the
  # NCCL exclusion below).
  isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;
159
  # Abort evaluation (trivial.throwIf — not warnIf, despite the old comment)
  # when none of the requested GPU targets is supported by this torch release;
  # otherwise return `supported` unchanged.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;
167
168 # Create the gpuTargetString.
169 gpuTargetString = strings.concatStringsSep ";" (
170 if gpuTargets != [ ] then
171 # If gpuTargets is specified, it always takes priority.
172 gpuTargets
173 else if cudaSupport then
174 gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
175 else if rocmSupport then
176 rocmPackages.clr.gpuTargets
177 else
178 throw "No GPU targets specified"
179 );
180
  # Merge every ROCm component torch needs into one tree: upstream cmake
  # expects a monolithic /opt/rocm-style prefix rather than per-package
  # store paths (see the /opt/rocm substitutions in postPatch).
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      aotriton
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hiprand
      hipsolver
      hipblas-common
      hipblas
      hipblaslt
      rocminfo
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found: drop the merged nix-support dir so
    # hooks/propagated inputs of the joined packages do not leak to consumers.
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
219
  # Attrset of human-readable reason -> bool; any true entry marks the package
  # broken (consumed in meta.broken and exposed via passthru for debugging).
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
    # In particular, this triggered warnings from cuda's `aliases.nix`
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport
      && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
  };
238
  # Updater script for ./src.nix: prints a Nix expression pinning the pytorch
  # source tree (unrolled submodules via git-unroll) for the requested tag.
  # NOTE: the `"'"'$1'"'"` run is shell quoting that splices $1 into the
  # emitted Nix string — take care when touching it.
  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';
251
  # CUDA builds must use the NVCC-compatible compiler wrapper from cudaPackages.
  stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
253in
buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.6.0";
  pyproject = true;

  # The stdenv selected above (cudaPackages.backendStdenv when cudaSupport);
  # within this rec set, `stdenv` now refers to this attribute.
  stdenv = stdenv';

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  # Pinned source (pytorch + unrolled submodules); regenerate src.nix with
  # the passthru.unroll-src script.
  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };

  patches =
    [
      ./clang19-template-warning.patch
      # fix invalid static cast in XNNPACK
      # https://github.com/google/XNNPACK/issues/7489
      ./xnnpack-bfloat16.patch
    ]
    ++ lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals stdenv.hostPlatform.isLinux [
      # Propagate CUPTI to Kineto by overriding the search path with environment variables.
      # https://github.com/pytorch/pytorch/pull/108847
      ./pytorch-pr-108847.patch
    ]
    ++ lib.optionals (lib.getName blas.provider == "mkl") [
      # The CMake install tries to add some hardcoded rpaths, incompatible
      # with the Nix store, which fails. Simply remove this step to get
      # rpaths that point to the Nix store.
      ./disable-cmake-mkl-rpath.patch
    ];
298
299 postPatch =
300 ''
301 substituteInPlace cmake/public/cuda.cmake \
302 --replace-fail \
303 'message(FATAL_ERROR "Found two conflicting CUDA' \
304 'message(WARNING "Found two conflicting CUDA' \
305 --replace-warn \
306 "set(CUDAToolkit_ROOT" \
307 "# Upstream: set(CUDAToolkit_ROOT"
308 substituteInPlace third_party/gloo/cmake/Cuda.cmake \
309 --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
310
311 # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
312 sed -i -e "/from __future__ import/d" **.py
313 substituteInPlace third_party/NNPACK/CMakeLists.txt \
314 --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
315 # flag from cmakeFlags doesn't work, not clear why
316 # setting it at the top of NNPACK's own CMakeLists does
317 sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt
318 ''
319 + lib.optionalString rocmSupport ''
320 # https://github.com/facebookincubator/gloo/pull/297
321 substituteInPlace third_party/gloo/cmake/Hipify.cmake \
322 --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"
323
324 # Replace hard-coded rocm paths
325 substituteInPlace caffe2/CMakeLists.txt \
326 --replace-fail "/opt/rocm" "${rocmtoolkit_joined}" \
327 --replace-fail "hcc/include" "hip/include" \
328 --replace-fail "rocblas/include" "include/rocblas" \
329 --replace-fail "hipsparse/include" "include/hipsparse"
330
331 # Doesn't pick up the environment variable?
332 substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
333 --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
334 --replace-fail "/opt/rocm" "${rocmtoolkit_joined}"
335
336 # Strangely, this is never set in cmake
337 substituteInPlace cmake/public/LoadHIP.cmake \
338 --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
339 "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
340 ''
341 # Detection of NCCL version doesn't work particularly well when using the static binary.
342 + lib.optionalString cudaSupport ''
343 substituteInPlace cmake/Modules/FindNCCL.cmake \
344 --replace-fail \
345 'message(FATAL_ERROR "Found NCCL header version and library version' \
346 'message(WARNING "Found NCCL header version and library version'
347 ''
348 # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
349 # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
350 # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
351 # until https://github.com/pytorch/pytorch/issues/76082 is addressed
352 + lib.optionalString cudaSupport ''
353 rm cmake/Modules/FindCUDAToolkit.cmake
354 '';
355
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      # Point torch's build at the selected capabilities and CUPTI (Kineto).
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      # Run pytorch's CUDA->HIP source translation before configuring.
      python tools/amd_build/build_amd.py
    '';
377
378 # Use pytorch's custom configurations
379 dontUseCmakeConfigure = true;
380
381 # causes possible redefinition of _FORTIFY_SOURCE
382 hardeningDisable = [ "fortify3" ];
383
384 BUILD_NAMEDTENSOR = setBool true;
385 BUILD_DOCS = setBool buildDocs;
386
387 # We only do an imports check, so do not build tests either.
388 BUILD_TEST = setBool false;
389
390 # ninja hook doesn't automatically turn on ninja
391 # because pytorch setup.py is responsible for this
392 CMAKE_GENERATOR = "Ninja";
393
394 # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
395 # it by default. PyTorch currently uses its own vendored version
396 # of oneDNN through Intel iDeep.
397 USE_MKLDNN = setBool mklDnnSupport;
398 USE_MKLDNN_CBLAS = setBool mklDnnSupport;
399
400 # Avoid using pybind11 from git submodule
401 # Also avoids pytorch exporting the headers of pybind11
402 USE_SYSTEM_PYBIND11 = true;
403
404 # Multicore CPU convnet support
405 USE_NNPACK = 1;
406
407 # Explicitly enable MPS for Darwin
408 USE_MPS = setBool stdenv.hostPlatform.isDarwin;
409
410 # building torch.distributed on Darwin is disabled by default
411 # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
412 USE_DISTRIBUTED = setBool true;
413
  cmakeFlags =
    [
      (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
      # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
      # NOTE(review): set unconditionally, even for non-CUDA builds — presumably
      # harmless there, but confirm before relying on it.
      (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
    ]
    ++ lib.optionals cudaSupport [
      # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
      # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
      (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
    ];

  # Configure via setup.py (--cmake-only), then drive the cmake build directly.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';
431
  # For every libcaffe2*.so in $out: drop the first two entries of its runpath
  # and prepend $ORIGIN, so sibling libraries resolve relative to the library
  # itself rather than build-time paths.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
446
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  # All three flags are driven by the same useSystemNccl decision.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";
462
  env =
    {
      # disable warnings as errors as they break the build on every compiler
      # bump, among other things.
      # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
      # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
      NIX_CFLAGS_COMPILE = toString (
        [
          "-Wno-error"
        ]
        # fix build aarch64-linux build failure with GCC14
        ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
          "-Wno-error=incompatible-pointer-types"
        ]
      );
      USE_VULKAN = setBool vulkanSupport;
    }
    // lib.optionalAttrs vulkanSupport {
      # Point the Vulkan build at shaderc's tool output (glslc etc.).
      VULKAN_SDK = shaderc.bin;
    }
    // lib.optionalAttrs rocmSupport {
      # Point the build at the prebuilt aotriton from rocmPackages.
      AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
    };
486
  nativeBuildInputs =
    [
      cmake
      which
      ninja
      pybind11
      pkg-config
      removeReferencesTo
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        autoAddDriverRunpath
        cuda_nvcc
      ]
    )
    ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  buildInputs =
    [
      blas
      blas.provider
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl # <thrust/*>
        cuda_cudart # cuda_runtime.h and libraries
        cuda_cupti # For kineto
        cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
        cuda_nvml_dev # <nvml.h>
        cuda_nvrtc
        cuda_nvtx # -llibNVToolsExt
        cusparselt
        libcublas
        libcufft
        libcurand
        libcusolver
        libcusparse
      ]
      ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
      ++ lists.optionals useSystemNccl [
        # Some platforms do not support NCCL (i.e., Jetson)
        nccl # Provides nccl.h AND a static copy of NCCL!
      ]
      ++ lists.optionals (cudaOlder "11.8") [
        cuda_nvprof # <cuda_profiler_api.h>
      ]
      ++ lists.optionals (cudaAtLeast "11.8") [
        cuda_profiler_api # <cuda_profiler_api.h>
      ]
    )
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      apple-sdk_13
    ]
    ++ lib.optionals tritonSupport [ _tritonEffective ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [
      rocmtoolkit_joined
      rocmPackages.clr # Added separately so setup hook applies
    ];
552
  # pytorch's sympy pin is stricter than what we ship; relax it.
  pythonRelaxDeps = [
    "sympy"
  ];
  dependencies =
    [
      astunparse
      expecttest
      filelock
      fsspec
      hypothesis
      jinja2
      networkx
      ninja
      packaging
      psutil
      pyyaml
      requests
      sympy
      types-dataclasses
      typing-extensions

      # the following are required for tensorboard support
      pillow
      six
      tensorboard
      protobuf

      # torch/csrc requires `pybind11` at runtime
      pybind11
    ]
    ++ lib.optionals tritonSupport [ _tritonEffective ]
    ++ lib.optionals vulkanSupport [
      vulkan-headers
      vulkan-loader
    ];

  # Written to $cxxdev/nix-support/propagated-build-inputs in postFixup.
  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];
602
  # Only runs when doCheck is explicitly enabled (it is false above);
  # drives pytorch's own run_test.py with a small exclusion list.
  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];
626
  # Split the install across the out/dev/lib outputs and scrub store
  # references to the compiler to keep closures small.
  postInstall =
    ''
      find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

      mkdir $dev

      # CppExtension requires that include files are packaged with the main
      # python library output; which is why they are copied here.
      cp -r $out/${python.sitePackages}/torch/include $dev/include

      # Cmake files under /share are different and can be safely moved. This
      # avoids unnecessary closure blow-up due to apple sdk references when
      # USE_DISTRIBUTED is enabled.
      mv $out/${python.sitePackages}/torch/share $dev/share

      # Fix up library paths for split outputs
      substituteInPlace \
        $dev/share/cmake/Torch/TorchConfig.cmake \
        --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

      substituteInPlace \
        $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
        --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"

      # Libraries live in $lib; leave a symlink so the python package still
      # finds them at the expected in-tree location.
      mkdir $lib
      mv $out/${python.sitePackages}/torch/lib $lib/lib
      ln -s $lib/lib $out/${python.sitePackages}/torch/lib
    ''
    + lib.optionalString rocmSupport ''
      substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
        --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

      substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
        --replace-fail "/build/source/torch/include" "$dev/include"
    '';
662
  postFixup =
    ''
      # Record propagatedCxxBuildInputs for the cxxdev output (cmake consumers).
      mkdir -p "$cxxdev/nix-support"
      printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
    ''
    + lib.optionalString stdenv.hostPlatform.isDarwin ''
      # Rewrite dylib install names and @rpath references so the split $lib
      # output is found by absolute path at load time.
      for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
      done

      install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
    '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  # Runs after the patchelf fixup: append each extra runpath to every .so in
  # the lib and out outputs.
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';
699
  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    # Expose the build configuration so dependent packages can match it.
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ../tests { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms =
      lib.platforms.linux
      ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    # Broken iff any brokenConditions entry evaluated to true.
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}