1{
2 stdenv,
3 lib,
4 fetchFromGitHub,
5 fetchFromGitLab,
6 fetchpatch,
7 git-unroll,
8 buildPythonPackage,
9 python,
10 runCommand,
11 writeShellScript,
12 config,
13 cudaSupport ? config.cudaSupport,
14 cudaPackages,
15 autoAddDriverRunpath,
16 effectiveMagma ?
17 if cudaSupport then
18 magma-cuda-static
19 else if rocmSupport then
20 magma-hip
21 else
22 magma,
23 magma,
24 magma-hip,
25 magma-cuda-static,
26 # Use the system NCCL as long as we're targeting CUDA on a supported platform.
27 useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
28 MPISupport ? false,
29 mpi,
30 buildDocs ? false,
31 targetPackages,
32
33 # tests.cudaAvailable:
34 callPackage,
35
36 # Native build inputs
37 cmake,
38 symlinkJoin,
39 which,
40 pybind11,
41 pkg-config,
42 removeReferencesTo,
43
44 # Build inputs
45 apple-sdk_13,
46 openssl,
47 numactl,
48 llvmPackages,
49
50 # dependencies
51 astunparse,
52 binutils,
53 expecttest,
54 filelock,
55 fsspec,
56 hypothesis,
57 jinja2,
58 networkx,
59 packaging,
60 psutil,
61 pyyaml,
62 requests,
63 sympy,
64 types-dataclasses,
65 typing-extensions,
66 # ROCm build and `torch.compile` requires `triton`
67 tritonSupport ? (!stdenv.hostPlatform.isDarwin),
68 triton,
69
70 # TODO: 1. callPackage needs to learn to distinguish between the task
71 # of "asking for an attribute from the parent scope" and
72 # the task of "exposing a formal parameter in .override".
73 # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #          as they routinely end up consuming the wrong arguments
75 # (dependencies without cuda support).
76 # Instead we should rely on overlays and nixpkgsFun.
77 # (@SomeoneSerge)
78 _tritonEffective ? if cudaSupport then triton-cuda else triton,
79 triton-cuda,
80
81 # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
82 # this is also what official pytorch build does
83 mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),
84
85 # virtual pkg that consistently instantiates blas across nixpkgs
86 # See https://github.com/NixOS/nixpkgs/pull/83888
87 blas,
88
89 # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
90 ninja,
91
92 # dependencies for torch.utils.tensorboard
93 pillow,
94 six,
95 tensorboard,
96 protobuf,
97
98 # ROCm dependencies
99 rocmSupport ? config.rocmSupport,
100 rocmPackages,
101 gpuTargets ? [ ],
102
103 vulkanSupport ? false,
104 vulkan-headers,
105 vulkan-loader,
106 shaderc,
107}:
108
109let
110 inherit (lib)
111 attrsets
112 lists
113 strings
114 trivial
115 ;
116 inherit (cudaPackages) cudnn flags nccl;
117
118 triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";
119
120 setBool = v: if v then "1" else "0";
121
122 # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/utils/cpp_extension.py#L2411-L2414
123 supportedTorchCudaCapabilities =
124 let
125 real = [
126 "3.5"
127 "3.7"
128 "5.0"
129 "5.2"
130 "5.3"
131 "6.0"
132 "6.1"
133 "6.2"
134 "7.0"
135 "7.2"
136 "7.5"
137 "8.0"
138 "8.6"
139 "8.7"
140 "8.9"
141 "9.0"
142 "9.0a"
143 "10.0"
144 "10.0"
145 "10.0a"
146 "10.1"
147 "10.1a"
148 "10.3"
149 "10.3a"
150 "12.0"
151 "12.0a"
152 "12.1"
153 "12.1a"
154 ];
155 ptx = lists.map (x: "${x}+PTX") real;
156 in
157 real ++ ptx;
158
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA
  # Requested capabilities (flags.cudaCapabilities) that torch can actually target.
  supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
  # Requested capabilities torch cannot target; surfaced in gpuArchWarner's error message.
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;

  # Jetson (Tegra) CUDA builds get extra runpath handling (see nativeBuildInputs).
  isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;
168
169 # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
170 gpuArchWarner =
171 supported: unsupported:
172 trivial.throwIf (supported == [ ]) (
173 "No supported GPU targets specified. Requested GPU targets: "
174 + strings.concatStringsSep ", " unsupported
175 ) supported;
176
  # Create the gpuTargetString: a ";"-separated list fed to TORCH_CUDA_ARCH_LIST /
  # PYTORCH_ROCM_ARCH in preConfigure. Evaluation order: explicit gpuTargets wins,
  # then CUDA capabilities, then ROCm targets; otherwise evaluation fails.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      lib.lists.subtractLists [
        # Remove RDNA1 gfx101x archs from default ROCm support list to avoid
        # use of undeclared identifier 'CK_BUFFER_RESOURCE_3RD_DWORD'
        # TODO: Retest after ROCm 6.4 or torch 2.8
        "gfx1010"
        "gfx1012"
      ] rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets
      # NOTE: `or` binds tighter than function application, so the argument above is
      # (rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets).
    else
      throw "No GPU targets specified"
  );
195
  # Use vendored CK as header only dep if rocmPackages' CK doesn't properly support targets
  vendorComposableKernel = rocmSupport && !rocmPackages.composable_kernel.anyMfmaTarget;

  # All required ROCm components merged into one prefix; pytorch's build is pointed
  # at this via ROCM_PATH/ROCM_SOURCE_DIR in preConfigure.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths =
      with rocmPackages;
      [
        rocm-core
        clr
        rccl
        miopen
        aotriton
        rocrand
        rocblas
        rocsparse
        hipsparse
        rocthrust
        rocprim
        hipcub
        roctracer
        rocfft
        rocsolver
        hipfft
        hiprand
        hipsolver
        hipblas-common
        hipblas
        hipblaslt
        rocminfo
        rocm-comgr
        rocm-device-libs
        rocm-runtime
        rocm-smi
        clr.icd
        hipify
      ]
      # Only add rocmPackages' composable_kernel when we are not vendoring it.
      ++ lib.optionals (!vendorComposableKernel) [
        composable_kernel
      ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
243
  # Conditions under which the package is marked broken (consumed by meta.broken and
  # exposed via passthru); attribute names double as human-readable explanations.
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
    # In particular, this triggered warnings from cuda's `aliases.nix`
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport
      && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
  };
262
  # Maintainer helper: regenerates ./src.nix for a given version (exposed via passthru).
  # Usage: unroll-src [version]
  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';

  # Compile with the CUDA-compatible toolchain when building with CUDA support.
  stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
277in
278buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.8.0";
  pyproject = true;

  # NOTE: within this rec attrset, `stdenv` refers to the (possibly CUDA-enabled)
  # stdenv' chosen in the let block above, shadowing the function argument.
  stdenv = stdenv';

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  # ./src.nix is generated by the `unroll-src` script (see passthru.unroll-src).
  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };
302
303 patches = [
304 ./clang19-template-warning.patch
305
306 # Do not override PYTHONPATH, otherwise, the build fails with:
307 # ModuleNotFoundError: No module named 'typing_extensions'
308 (fetchpatch {
309 name = "cmake-build-preserve-PYTHONPATH";
310 url = "https://github.com/pytorch/pytorch/commit/231c72240d80091f099c95e326d3600cba866eee.patch";
311 hash = "sha256-BBCjxzz2TUkx4nXRyRILA82kMwyb/4+C3eOtYqf5dhk=";
312 })
313
314 # Fixes GCC-14 compatibility on ARM
315 # Adapted from https://github.com/pytorch/pytorch/pull/157867
316 # TODO: remove at the next release
317 ./gcc-14-arm-compat.path
318 ]
319 ++ lib.optionals cudaSupport [
320 ./fix-cmake-cuda-toolkit.patch
321 ./nvtx3-hpp-path-fix.patch
322 ]
323 ++ lib.optionals stdenv.hostPlatform.isLinux [
324 # Propagate CUPTI to Kineto by overriding the search path with environment variables.
325 # https://github.com/pytorch/pytorch/pull/108847
326 ./pytorch-pr-108847.patch
327 ]
328 ++ lib.optionals (lib.getName blas.provider == "mkl") [
329 # The CMake install tries to add some hardcoded rpaths, incompatible
330 # with the Nix store, which fails. Simply remove this step to get
331 # rpaths that point to the Nix store.
332 ./disable-cmake-mkl-rpath.patch
333 ];
334
  # In-tree source fixups applied after patching; each concatenated segment is
  # documented by the Nix comment directly above it.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "setuptools>=62.3.0,<80.0" "setuptools"
  ''
  # Provide path to openssl binary for inductor code cache hash
  # InductorError: FileNotFoundError: [Errno 2] No such file or directory: 'openssl'
  + ''
    substituteInPlace torch/_inductor/codecache.py \
      --replace-fail '"openssl"' '"${lib.getExe openssl}"'
  ''
  # Downgrade the conflicting-CUDA fatal error to a warning and stop pytorch from
  # forcing its own CUDAToolkit_ROOT.
  + ''
    substituteInPlace cmake/public/cuda.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found two conflicting CUDA' \
        'message(WARNING "Found two conflicting CUDA' \
      --replace-warn \
        "set(CUDAToolkit_ROOT" \
        "# Upstream: set(CUDAToolkit_ROOT"
    substituteInPlace third_party/gloo/cmake/Cuda.cmake \
      --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
  ''
  # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
  + ''
    sed -i -e "/from __future__ import/d" **.py
    substituteInPlace third_party/NNPACK/CMakeLists.txt \
      --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
  ''
  # flag from cmakeFlags doesn't work, not clear why
  # setting it at the top of NNPACK's own CMakeLists does
  + ''
    sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt
  ''
  # Ensure that torch profiler unwind uses addr2line from nix
  + ''
    substituteInPlace torch/csrc/profiler/unwind/unwind.cpp \
      --replace-fail 'addr2line_binary_ = "addr2line"' 'addr2line_binary_ = "${lib.getExe' binutils "addr2line"}"'
  ''
  # Ensures torch compile can find and use compilers from nix.
  + ''
    substituteInPlace torch/_inductor/config.py \
      --replace-fail '"clang++" if sys.platform == "darwin" else "g++"' \
        '"${lib.getExe' targetPackages.stdenv.cc "${targetPackages.stdenv.cc.targetPrefix}c++"}"'
  ''
  + lib.optionalString rocmSupport ''
    # https://github.com/facebookincubator/gloo/pull/297
    substituteInPlace third_party/gloo/cmake/Hipify.cmake \
      --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

    # Doesn't pick up the environment variable?
    substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
      --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}"
  ''
  # When possible, composable kernel as dependency, rather than built-in third-party
  + lib.optionalString (rocmSupport && !vendorComposableKernel) ''
    substituteInPlace aten/src/ATen/CMakeLists.txt \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)" "" \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)" ""
  ''
  # Detection of NCCL version doesn't work particularly well when using the static binary.
  + lib.optionalString cudaSupport ''
    substituteInPlace cmake/Modules/FindNCCL.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
  # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
  # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
  # until https://github.com/pytorch/pytorch/issues/76082 is addressed
  + lib.optionalString cudaSupport ''
    rm cmake/Modules/FindCUDAToolkit.cmake
  '';
407
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    # Target architectures and CUPTI paths for the CUDA build.
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    # cudnn is optional in cudaPackages; only set paths when present.
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
    ''
    # ROCm: point the build at the merged ROCm tree and HIPify the sources.
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include"
      python tools/amd_build/build_amd.py
    '';
429
  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # ninja hook doesn't automatically turn on ninja
  # because pytorch setup.py is responsible for this
  CMAKE_GENERATOR = "Ninja";

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # Multicore CPU convnet support
  USE_NNPACK = 1;

  # Explicitly enable MPS for Darwin
  USE_MPS = setBool stdenv.hostPlatform.isDarwin;

  # building torch.distributed on Darwin is disabled by default
  # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
  USE_DISTRIBUTED = setBool true;
465
466 cmakeFlags = [
467 (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
468 # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
469 (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
470 ]
471 ++ lib.optionals cudaSupport [
472 # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
473 # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
474 (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
475 ];
476
  # Configure via setup.py (cmake only), then drive the build through cmake directly.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  # For every libcaffe2*.so in $out: drop the first two rpath entries and
  # prepend $ORIGIN, so siblings are resolved relative to the library itself.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
497
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";
513
  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
    # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
    NIX_CFLAGS_COMPILE = toString (
      [
        "-Wno-error"
      ]
      # fix build aarch64-linux build failure with GCC14
      ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
        "-Wno-error=incompatible-pointer-types"
      ]
    );
    USE_VULKAN = setBool vulkanSupport;
  }
  # shaderc provides the Vulkan SDK layout pytorch's build looks for.
  // lib.optionalAttrs vulkanSupport {
    VULKAN_SDK = shaderc.bin;
  }
  // lib.optionalAttrs rocmSupport {
    AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
  };
536
  nativeBuildInputs = [
    cmake
    which
    ninja
    pybind11
    pkg-config
    removeReferencesTo
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      autoAddDriverRunpath
      cuda_nvcc
    ]
  )
  # Jetson devices additionally need the cuda-compat runpath hook.
  ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
554
  buildInputs = [
    blas
    blas.provider
  ]
  # Including openmp leads to two copies being used on ARM, which segfaults.
  # https://github.com/pytorch/pytorch/issues/149201#issuecomment-2776842320
  ++ lib.optionals (stdenv.cc.isClang && !stdenv.hostPlatform.isAarch64) [ llvmPackages.openmp ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cccl # <thrust/*>
      cuda_cudart # cuda_runtime.h and libraries
      cuda_cupti # For kineto
      cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
      cuda_nvml_dev # <nvml.h>
      cuda_nvrtc
      cuda_nvtx # -llibNVToolsExt
      cusparselt
      libcublas
      libcufft
      libcufile
      libcurand
      libcusolver
      libcusparse
    ]
    ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
    ++ lists.optionals useSystemNccl [
      # Some platforms do not support NCCL (i.e., Jetson)
      nccl # Provides nccl.h AND a static copy of NCCL!
    ]
    ++ [
      cuda_profiler_api # <cuda_profiler_api.h>
    ]
  )
  ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
  # effectiveMagma is magma-cuda-static / magma-hip / magma per the argument default.
  ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    apple-sdk_13
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [
    rocmtoolkit_joined
    rocmPackages.clr # Added separately so setup hook applies
  ];
601
  # Accept any sympy version rather than the exact pin in pytorch's metadata.
  pythonRelaxDeps = [
    "sympy"
  ];
  dependencies = [
    astunparse
    expecttest
    filelock
    fsspec
    hypothesis
    jinja2
    networkx
    ninja
    packaging
    psutil
    pyyaml
    requests
    sympy
    types-dataclasses
    typing-extensions

    # the following are required for tensorboard support
    pillow
    six
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals vulkanSupport [
    vulkan-headers
    vulkan-loader
  ];
636
  # Extra deps propagated to cmake consumers via the cxxdev output (see postFixup).
  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];
650
  # Full test-runner invocation; unused by default since doCheck = false above.
  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];
674
  # Split the installed tree across the out/dev/lib outputs and rewrite the
  # installed cmake config files accordingly.
  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

    mkdir $dev

    # CppExtension requires that include files are packaged with the main
    # python library output; which is why they are copied here.
    cp -r $out/${python.sitePackages}/torch/include $dev/include

    # Cmake files under /share are different and can be safely moved. This
    # avoids unnecessary closure blow-up due to apple sdk references when
    # USE_DISTRIBUTED is enabled.
    mv $out/${python.sitePackages}/torch/share $dev/share

    # Fix up library paths for split outputs
    substituteInPlace \
      $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"

    mkdir $lib
    mv $out/${python.sitePackages}/torch/lib $lib/lib
    ln -s $lib/lib $out/${python.sitePackages}/torch/lib
  ''
  + lib.optionalString rocmSupport ''
    substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
      --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace-fail "/build/${src.name}/torch/include" "$dev/include"
  '';
709
  # Record propagatedCxxBuildInputs in the cxxdev output's nix-support so cmake
  # consumers of torch pick them up; on Darwin also rewrite dylib install names.
  postFixup = ''
    mkdir -p "$cxxdev/nix-support"
    printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
  ''
  + lib.optionalString stdenv.hostPlatform.isDarwin ''
    for f in $(ls $lib/lib/*.dylib); do
      install_name_tool -id $lib/lib/$(basename $f) $f || true
    done

    install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
  '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';
745
  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ../tests { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      GaetanLepage
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms =
      lib.platforms.linux ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    # broken when any of the named brokenConditions above holds.
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
781}