# nixpkgs mirror (for testing)
# github.com/NixOS/nixpkgs
# Language: Nix
{
  stdenv,
  lib,
  fetchFromGitHub,
  fetchFromGitLab,
  git-unroll,
  buildPythonPackage,
  python,
  runCommand,
  writeShellScript,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  # NOTE: default refers to `rocmSupport` and the magma variants declared
  # below; Nix formal-argument defaults may reference any other formal
  # regardless of declaration order.
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && cudaPackages.nccl.meta.available || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,
  targetPackages,

  # tests.cudaAvailable:
  callPackage,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  pkg-config,
  removeReferencesTo,

  # Build inputs
  openssl,
  numactl,
  llvmPackages,

  # dependencies
  astunparse,
  binutils,
  expecttest,
  filelock,
  fsspec,
  hypothesis,
  jinja2,
  networkx,
  packaging,
  psutil,
  pyyaml,
  requests,
  sympy,
  types-dataclasses,
  typing-extensions,
  # ROCm build and `torch.compile` requires `triton`
  tritonSupport ? (!stdenv.hostPlatform.isDarwin),
  triton,

  # TODO: 1. callPackage needs to learn to distinguish between the task
  #          of "asking for an attribute from the parent scope" and
  #          the task of "exposing a formal parameter in .override".
  # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #          as they routinely end up consuming the wrong arguments
  #          (dependencies without cuda support).
  #          Instead we should rely on overlays and nixpkgsFun.
  #          (@SomeoneSerge)
  _tritonEffective ? if cudaSupport then triton-cuda else triton,
  triton-cuda,

  # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
  # this is also what official pytorch build does
  mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  tensorboard,
  protobuf,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  gpuTargets ? [ ],

  vulkanSupport ? false,
  vulkan-headers,
  vulkan-loader,
  shaderc,
}:
106
107let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudnn flags nccl;

  # Shadow the `triton` formal argument: any in-file reference to `triton`
  # (rather than `_tritonEffective`) aborts evaluation with this message.
  triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";

  # Render a boolean as the "1"/"0" strings pytorch's env toggles expect.
  setBool = v: if v then "1" else "0";
119
120 # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/utils/cpp_extension.py#L2411-L2414
121 supportedTorchCudaCapabilities =
122 let
123 real = [
124 "3.5"
125 "3.7"
126 "5.0"
127 "5.2"
128 "5.3"
129 "6.0"
130 "6.1"
131 "6.2"
132 "7.0"
133 "7.2"
134 "7.5"
135 "8.0"
136 "8.6"
137 "8.7"
138 "8.9"
139 "9.0"
140 "9.0a"
141 "10.0"
142 "10.0"
143 "10.0a"
144 "10.1"
145 "10.1a"
146 "10.3"
147 "10.3a"
148 "12.0"
149 "12.0a"
150 "12.1"
151 "12.1a"
152 ];
153 ptx = lists.map (x: "${x}+PTX") real;
154 in
155 real ++ ptx;
156
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA: requested capabilities that torch knows about, and the remainder
  # (requested but unsupported), used for the error message in gpuArchWarner.
  supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;

  # Jetson builds need the cuda-compat runpath hook (see nativeBuildInputs).
  isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;
166
167 # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
168 gpuArchWarner =
169 supported: unsupported:
170 trivial.throwIf (supported == [ ]) (
171 "No supported GPU targets specified. Requested GPU targets: "
172 + strings.concatStringsSep ", " unsupported
173 ) supported;
174
  # Create the gpuTargetString: the ";"-separated arch list exported as
  # TORCH_CUDA_ARCH_LIST (CUDA) or PYTORCH_ROCM_ARCH (ROCm) in preConfigure.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      lib.lists.subtractLists [
        # Remove RDNA1 gfx101x archs from default ROCm support list to avoid
        # use of undeclared identifier 'CK_BUFFER_RESOURCE_3RD_DWORD'
        # TODO: Retest after ROCm 6.4 or torch 2.8
        "gfx1010"
        "gfx1012"
      ] rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets
      # NOTE: `a.b or c` binds tighter than function application, so the whole
      # select-with-default is the second argument to subtractLists.
    else
      throw "No GPU targets specified"
  );
193
  # Use vendored CK as header only dep if rocmPackages' CK doesn't properly support targets
  vendorComposableKernel = rocmSupport && !rocmPackages.composable_kernel.anyMfmaTarget;

  # Single-prefix merge of the ROCm stack; pytorch's build system expects one
  # ROCM_PATH rather than per-package store paths.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths =
      with rocmPackages;
      [
        rocm-core
        clr
        rccl
        miopen
        aotriton
        rocrand
        rocblas
        rocsparse
        hipsparse
        rocthrust
        rocprim
        hipcub
        roctracer
        rocfft
        rocsolver
        hipfft
        hiprand
        hipsolver
        hipblas-common
        hipblas
        hipblaslt
        rocminfo
        rocm-comgr
        rocm-device-libs
        rocm-runtime
        rocm-smi
        clr.icd
        hipify
      ]
      ++ lib.optionals (!vendorComposableKernel) [
        composable_kernel
      ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
241
  # Map of human-readable reason -> bool; any true entry marks meta.broken below.
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
    # In particular, this triggered warnings from cuda's `aliases.nix`
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport
      && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
  };
260
  # Helper script that regenerates ./src.nix for a given version by unrolling
  # the pytorch git tree (prints the file header, then the git-unroll output).
  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';
273
  # CUDA builds must use the NVCC-compatible compiler stdenv.
  stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
let
  # From here on, `stdenv` shall be `stdenv'`.
  stdenv = stdenv';
in
buildPythonPackage.override { inherit stdenv; } rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.9.1";
  pyproject = true;

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  # src.nix is generated by the unroll-src helper above.
  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };
302
  patches = [
    ./clang19-template-warning.patch
  ]
  ++ lib.optionals cudaSupport [
    ./fix-cmake-cuda-toolkit.patch
    ./nvtx3-hpp-path-fix.patch
  ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [
    # Propagate CUPTI to Kineto by overriding the search path with environment variables.
    # https://github.com/pytorch/pytorch/pull/108847
    ./pytorch-pr-108847.patch
  ]
  ++ lib.optionals (lib.getName blas.provider == "mkl") [
    # The CMake install tries to add some hardcoded rpaths, incompatible
    # with the Nix store, which fails. Simply remove this step to get
    # rpaths that point to the Nix store.
    ./disable-cmake-mkl-rpath.patch
  ];
321
  # Source tree fixups: pin tool paths into the Nix store and soften overly
  # strict checks in vendored CMake modules.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "setuptools>=70.1.0,<80.0" "setuptools"
  ''
  # Provide path to openssl binary for inductor code cache hash
  # InductorError: FileNotFoundError: [Errno 2] No such file or directory: 'openssl'
  + ''
    substituteInPlace torch/_inductor/codecache.py \
      --replace-fail '"openssl"' '"${lib.getExe openssl}"'
  ''
  + ''
    substituteInPlace cmake/public/cuda.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found two conflicting CUDA' \
        'message(WARNING "Found two conflicting CUDA' \
      --replace-warn \
        "set(CUDAToolkit_ROOT" \
        "# Upstream: set(CUDAToolkit_ROOT"
    substituteInPlace third_party/gloo/cmake/Cuda.cmake \
      --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
  ''
  # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
  + ''
    sed -i -e "/from __future__ import/d" **.py
    substituteInPlace third_party/NNPACK/CMakeLists.txt \
      --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
  ''
  # flag from cmakeFlags doesn't work, not clear why
  # setting it at the top of NNPACK's own CMakeLists does
  + ''
    sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt
  ''
  # Ensure that torch profiler unwind uses addr2line from nix
  + ''
    substituteInPlace torch/csrc/profiler/unwind/unwind.cpp \
      --replace-fail 'addr2line_binary_ = "addr2line"' 'addr2line_binary_ = "${lib.getExe' binutils "addr2line"}"'
  ''
  # Ensures torch compile can find and use compilers from nix.
  + ''
    substituteInPlace torch/_inductor/config.py \
      --replace-fail '"clang++" if sys.platform == "darwin" else "g++"' \
        '"${lib.getExe' targetPackages.stdenv.cc "${targetPackages.stdenv.cc.targetPrefix}c++"}"'
  ''
  + lib.optionalString rocmSupport ''
    # https://github.com/facebookincubator/gloo/pull/297
    substituteInPlace third_party/gloo/cmake/Hipify.cmake \
      --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

    # Doesn't pick up the environment variable?
    substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
      --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}"
  ''
  # When possible, composable kernel as dependency, rather than built-in third-party
  + lib.optionalString (rocmSupport && !vendorComposableKernel) ''
    substituteInPlace aten/src/ATen/CMakeLists.txt \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)" "" \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)" ""
  ''
  # Detection of NCCL version doesn't work particularly well when using the static binary.
  + lib.optionalString cudaSupport ''
    substituteInPlace cmake/Modules/FindNCCL.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
  # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
  # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
  # until https://github.com/pytorch/pytorch/issues/76082 is addressed
  + lib.optionalString cudaSupport ''
    rm cmake/Modules/FindCUDAToolkit.cmake
  '';
394
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include"
      python tools/amd_build/build_amd.py
    '';
416
  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # ninja hook doesn't automatically turn on ninja
  # because pytorch setup.py is responsible for this
  CMAKE_GENERATOR = "Ninja";

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # Multicore CPU convnet support
  USE_NNPACK = 1;

  # Explicitly enable MPS for Darwin
  USE_MPS = setBool stdenv.hostPlatform.isDarwin;

  # building torch.distributed on Darwin is disabled by default
  # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
  USE_DISTRIBUTED = setBool true;

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];
463
  # Run the cmake configure step only; ninja then builds from the cmake tree.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  # Drop the first two rpath entries from libcaffe2* and prepend $ORIGIN so
  # the libraries resolve siblings relative to their own location.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
484
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
    # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
    NIX_CFLAGS_COMPILE = toString (
      [
        "-Wno-error"
      ]
      # fix build aarch64-linux build failure with GCC14
      ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
        "-Wno-error=incompatible-pointer-types"
      ]
    );
    USE_VULKAN = setBool vulkanSupport;
  }
  // lib.optionalAttrs vulkanSupport {
    VULKAN_SDK = shaderc.bin;
  }
  // lib.optionalAttrs rocmSupport {
    AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
    # Broken HIP flag setup, fails to compile due to not finding rocthrust
    # Only supports gfx942 so let's turn it off for now
    USE_FBGEMM_GENAI = setBool false;
  };
526
  nativeBuildInputs = [
    cmake
    which
    ninja
    pybind11
    pkg-config
    removeReferencesTo
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      autoAddDriverRunpath
      cuda_nvcc
    ]
  )
  ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
544
  buildInputs = [
    blas
    blas.provider
  ]
  # Including openmp leads to two copies being used on ARM, which segfaults.
  # https://github.com/pytorch/pytorch/issues/149201#issuecomment-2776842320
  ++ lib.optionals (stdenv.cc.isClang && !stdenv.hostPlatform.isAarch64) [ llvmPackages.openmp ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cccl # <thrust/*>
      cuda_cudart # cuda_runtime.h and libraries
      cuda_cupti # For kineto
      cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
      cuda_nvml_dev # <nvml.h>
      cuda_nvrtc
      cuda_nvtx # -llibNVToolsExt
      libcublas
      libcufft
      libcufile
      libcurand
      libcusolver
      libcusparse
      libcusparse_lt
    ]
    ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
    ++ lists.optionals useSystemNccl [
      # Some platforms do not support NCCL (i.e., Jetson)
      (lib.getDev nccl) # Provides nccl.h
      (lib.getOutput "static" nccl) # Provides static library
    ]
    ++ [
      cuda_profiler_api # <cuda_profiler_api.h>
    ]
  )
  ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
  ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [
    rocmtoolkit_joined
    rocmPackages.clr # Added separately so setup hook applies
  ];
589
  # Accept whatever sympy version nixpkgs carries instead of upstream's pin.
  pythonRelaxDeps = [
    "sympy"
  ];
  dependencies = [
    astunparse
    expecttest
    filelock
    fsspec
    hypothesis
    jinja2
    networkx
    ninja
    packaging
    psutil
    pyyaml
    requests
    sympy
    types-dataclasses
    typing-extensions

    # the following are required for tensorboard support
    pillow
    six
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals vulkanSupport [
    vulkan-headers
    vulkan-loader
  ];
624
  # Forwarded to the cxxdev output's nix-support in postFixup below.
  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];
638
  # Kept for anyone enabling doCheck manually; runs upstream's test driver.
  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];
662
  # Split outputs: headers -> $dev, shared libraries -> $lib (symlinked back
  # into $out so the python package keeps working).
  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

    mkdir $dev

    # CppExtension requires that include files are packaged with the main
    # python library output; which is why they are copied here.
    cp -r $out/${python.sitePackages}/torch/include $dev/include

    # Cmake files under /share are different and can be safely moved. This
    # avoids unnecessary closure blow-up due to apple sdk references when
    # USE_DISTRIBUTED is enabled.
    mv $out/${python.sitePackages}/torch/share $dev/share

    # Fix up library paths for split outputs
    substituteInPlace \
      $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"

    mkdir $lib
    mv $out/${python.sitePackages}/torch/lib $lib/lib
    ln -s $lib/lib $out/${python.sitePackages}/torch/lib
  ''
  + lib.optionalString rocmSupport ''
    substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
      --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace-fail "/build/${src.name}/torch/include" "$dev/include"
  '';
697
  # Publish propagatedCxxBuildInputs on the cxxdev output; on Darwin, rewrite
  # dylib install names so the split-out $lib paths resolve.
  postFixup = ''
    mkdir -p "$cxxdev/nix-support"
    printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
  ''
  + lib.optionalString stdenv.hostPlatform.isDarwin ''
    for f in $(ls $lib/lib/*.dylib); do
      install_name_tool -id $lib/lib/$(basename $f) $f || true
    done

    install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
  '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';
733
  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ../tests { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      GaetanLepage
      LunNova # esp. for ROCm
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms =
      lib.platforms.linux ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
770}