{
  stdenv,
  lib,
  fetchFromGitHub,
  buildPythonPackage,
  python,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  # Magma variant matching the selected GPU backend (CUDA, ROCm, or CPU-only).
  # NOTE: attrset-pattern arguments are order-independent, so referencing
  # rocmSupport (declared further down) here is valid Nix.
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,

  # tests.cudaAvailable:
  callPackage,
  torchWithCuda,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  removeReferencesTo,

  # Build inputs
  numactl,
  Accelerate,
  CoreServices,
  libobjc,

  # Propagated build inputs
  astunparse,
  fsspec,
  filelock,
  jinja2,
  networkx,
  sympy,
  numpy,
  pyyaml,
  cffi,
  click,
  typing-extensions,
  # ROCm build and `torch.compile` requires `triton`
  tritonSupport ? (!stdenv.isDarwin),
  triton,

  # Unit tests
  hypothesis,
  psutil,

  # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
  # this is also what official pytorch build does
  mklDnnSupport ? !(stdenv.isDarwin && stdenv.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  future,
  tensorboard,
  protobuf,

  pythonOlder,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages_5,
  gpuTargets ? [ ],
}:
89
let
  # Commonly-used lib namespaces, brought into scope for brevity below.
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  # Only the ROCm 5.x package set is wired up here (see rocmPackages_5 argument).
  rocmPackages = rocmPackages_5;

  # Render a boolean as the "1"/"0" string PyTorch's build expects in env vars.
  setBool = v: if v then "1" else "0";
102
  # CUDA compute capabilities this PyTorch release can compile for, both as real
  # architectures and as "+PTX" (JIT forward-compatible) variants.
  # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
  supportedTorchCudaCapabilities =
    let
      real = [
        "3.5"
        "3.7"
        "5.0"
        "5.2"
        "5.3"
        "6.0"
        "6.1"
        "6.2"
        "7.0"
        "7.2"
        "7.5"
        "8.0"
        "8.6"
        "8.7"
        "8.9"
        "9.0"
      ];
      # Every real architecture also gets a PTX variant.
      ptx = lists.map (x: "${x}+PTX") real;
    in
    real ++ ptx;
127
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA
  supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;

  # Abort evaluation (trivial.throwIf) when none of the requested GPU targets is
  # supported by this Torch release; the error message lists the unsupported
  # requests. Otherwise returns the supported list unchanged.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;

  # Create the gpuTargetString (";"-separated list consumed by the build via
  # TORCH_CUDA_ARCH_LIST / PYTORCH_ROCM_ARCH in preConfigure).
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      rocmPackages.clr.gpuTargets
    else
      throw "No GPU targets specified"
  );
156
  # Merge all required ROCm components under one prefix; postPatch rewrites the
  # hard-coded "/opt/rocm" paths in the PyTorch sources to point here.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      miopengemm
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hipsolver
      hipblas
      rocminfo
      rocm-thunk
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
193
  # Conditions that mark this package as broken. filterAttrs keeps only the
  # entries whose value is true; the result feeds meta.broken and is exposed via
  # passthru.brokenConditions for easier debugging.
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport && (effectiveMagma.cudaPackages != cudaPackages);
    "Rocm support is currently broken because `rocmPackages.hipblaslt` is unpackaged. (2024-06-09)" = rocmSupport;
  };
209in
buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.3.1";
  # Build via the PEP 517 pyproject flow rather than legacy setup.py install.
  pyproject = true;

  disabled = pythonOlder "3.8.0";

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  # Tell the cudaPackages setup hook which output receives propagated CUDA deps.
  cudaPropagateToOutput = "cxxdev";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "refs/tags/v${version}";
    # Submodules are required: many third_party deps are vendored this way.
    fetchSubmodules = true;
    hash = "sha256-vpgtOqzIDKgRuqdT8lB/g6j+oMIH1RPxdbjtlzZFjV8=";
  };
233
  patches =
    lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
      # pthreadpool added support for Grand Central Dispatch in April
      # 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
      # that is available starting with macOS 10.13. However, our current
      # base is 10.12. Until we upgrade, we can fall back on the older
      # pthread support.
      ./pthreadpool-disable-gcd.diff
    ]
    ++ lib.optionals stdenv.isLinux [
      # Propagate CUPTI to Kineto by overriding the search path with environment variables.
      # https://github.com/pytorch/pytorch/pull/108847
      ./pytorch-pr-108847.patch
    ];
249
  postPatch =
    lib.optionalString rocmSupport ''
      # https://github.com/facebookincubator/gloo/pull/297
      substituteInPlace third_party/gloo/cmake/Hipify.cmake \
        --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

      # Replace hard-coded rocm paths
      substituteInPlace caffe2/CMakeLists.txt \
        --replace "/opt/rocm" "${rocmtoolkit_joined}" \
        --replace "hcc/include" "hip/include" \
        --replace "rocblas/include" "include/rocblas" \
        --replace "hipsparse/include" "include/hipsparse"

      # Doesn't pick up the environment variable?
      substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
        --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
        --replace "/opt/rocm" "${rocmtoolkit_joined}"

      # Strangely, this is never set in cmake
      substituteInPlace cmake/public/LoadHIP.cmake \
        --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
        "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
    ''
    # Detection of NCCL version doesn't work particularly well when using the static binary.
    + lib.optionalString cudaSupport ''
      substituteInPlace cmake/Modules/FindNCCL.cmake \
        --replace \
          'message(FATAL_ERROR "Found NCCL header version and library version' \
          'message(WARNING "Found NCCL header version and library version'
    ''
    # Remove PyTorch's FindCUDAToolkit.cmake so that CMake's default is used instead.
    # We do not remove the entirety of cmake/Modules_CUDA_fix because we need FindCUDNN.cmake.
    + lib.optionalString cudaSupport ''
      rm cmake/Modules/FindCUDAToolkit.cmake
      rm -rf cmake/Modules_CUDA_fix/{upstream,FindCUDA.cmake}
    ''
    # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
    # This lib overrides aligned_alloc, hence the error message. TL;DR: the function is linkable but not declared in the header.
    +
      lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0")
        ''
          substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace-fail '#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))
          inline void *aligned_alloc(size_t align, size_t size)' '#if 0
          inline void *aligned_alloc(size_t align, size_t size)'
        '';
295
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  #
  # Export the arch list and CUPTI/cuDNN/ROCm locations before configure; for
  # ROCm this also runs the hipify conversion over the sources.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${cudnn.lib}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';
317
  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # NB technical debt: building without NNPACK as workaround for missing `six`
  USE_NNPACK = 0;

  # Configure through setup.py first (--cmake-only stops before compiling),
  # then drive the actual compilation with cmake directly.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';
348
  # Rewrite the RPATH of every libcaffe2*.so in $out: drop the first two
  # entries of the existing RPATH (the ''${RP[@]:2} slice) and prepend $ORIGIN
  # so sibling libraries are resolved relative to the library itself.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
363
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
  # (upstream seems to have fixed this in the wrong place?)
  # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
  # https://github.com/pytorch/pytorch/issues/22346
  #
  # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
  # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
  env.NIX_CFLAGS_COMPILE = toString (
    (
      lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
      # Suppress gcc regression: avx512 math function raises uninitialized variable warning
      # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
      # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
      ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
        "-Wno-error=maybe-uninitialized"
        "-Wno-error=uninitialized"
      ]
      # Since pytorch 2.0:
      # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
      # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
      ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12") [
        "-Wno-error=free-nonheap-object"
      ]
      # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
      # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
      ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
        "-Wno-error=cast-function-type-strict"
        # Suppresses the most spammy warnings.
        # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
      ]
      ++ lib.optionals rocmSupport [
        "-Wno-#warnings"
        "-Wno-cpp"
        "-Wno-unknown-warning-option"
        "-Wno-ignored-attributes"
        "-Wno-deprecated-declarations"
        "-Wno-defaulted-function-deleted"
        "-Wno-pass-failed"
      ]
      ++ [
        "-Wno-unused-command-line-argument"
        "-Wno-uninitialized"
        "-Wno-array-bounds"
        "-Wno-free-nonheap-object"
        "-Wno-unused-result"
      ]
      ++ lib.optionals stdenv.cc.isGNU [
        "-Wno-maybe-uninitialized"
        "-Wno-stringop-overflow"
      ]
    )
  );
428
429 nativeBuildInputs =
430 [
431 cmake
432 which
433 ninja
434 pybind11
435 removeReferencesTo
436 ]
437 ++ lib.optionals cudaSupport (
438 with cudaPackages;
439 [
440 autoAddDriverRunpath
441 cuda_nvcc
442 ]
443 )
444 ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
445
  # NOTE: inside the `with cudaPackages;` below, `cudnn` and `nccl` actually
  # resolve to the lexical `inherit (cudaPackages) ...` bindings from the let
  # block (lexical scope shadows `with`) — same attrset either way.
  buildInputs =
    [
      blas
      blas.provider
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl # <thrust/*>
        cuda_cudart # cuda_runtime.h and libraries
        cuda_cupti # For kineto
        cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
        cuda_nvml_dev # <nvml.h>
        cuda_nvrtc
        cuda_nvtx # -llibNVToolsExt
        libcublas
        libcufft
        libcurand
        libcusolver
        libcusparse
      ]
      ++ lists.optionals (cudaPackages ? cudnn) [
        cudnn
      ]
      ++ lists.optionals useSystemNccl [
        # Some platforms do not support NCCL (i.e., Jetson)
        nccl # Provides nccl.h AND a static copy of NCCL!
      ]
      ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
        cuda_nvprof # <cuda_profiler_api.h>
      ]
      ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
        cuda_profiler_api # <cuda_profiler_api.h>
      ]
    )
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.isLinux [ numactl ]
    ++ lib.optionals stdenv.isDarwin [
      Accelerate
      CoreServices
      libobjc
    ]
    ++ lib.optionals tritonSupport [ triton ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
492
  # Python runtime dependencies of the torch package itself.
  dependencies = [
    astunparse
    cffi
    click
    numpy
    pyyaml

    # From install_requires:
    fsspec
    filelock
    typing-extensions
    sympy
    networkx
    jinja2

    # the following are required for tensorboard support
    pillow
    six
    future
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ] ++ lib.optionals tritonSupport [ triton ];

  # Extra deps exposed to C++/CMake consumers through the cxxdev output; see
  # the postFixup below, which records them in nix-support/propagated-build-inputs.
  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
521
522 # Tests take a long time and may be flaky, so just sanity-check imports
523 doCheck = false;
524
525 pythonImportsCheck = [ "torch" ];
526
527 nativeCheckInputs = [
528 hypothesis
529 ninja
530 psutil
531 ];
532
533 checkPhase =
534 with lib.versions;
535 with lib.strings;
536 concatStringsSep " " [
537 "runHook preCheck"
538 "${python.interpreter} test/run_test.py"
539 "--exclude"
540 (concatStringsSep " " [
541 "utils" # utils requires git, which is not allowed in the check phase
542
543 # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
544 # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build
545
546 # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
547 (optionalString (majorMinor version == "1.3") "tensorboard")
548 ])
549 "runHook postCheck"
550 ];
551
552 pythonRemoveDeps = [
553 # In our dist-info the name is just "triton"
554 "pytorch-triton-rocm"
555 ];
556
  # Split headers into $dev and libraries into $lib, leaving a symlink in $out,
  # and patch the installed CMake config files so downstream consumers find the
  # libraries in their new location.
  postInstall =
    ''
      find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

      mkdir $dev
      cp -r $out/${python.sitePackages}/torch/include $dev/include
      cp -r $out/${python.sitePackages}/torch/share $dev/share

      # Fix up library paths for split outputs
      substituteInPlace \
        $dev/share/cmake/Torch/TorchConfig.cmake \
        --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

      substituteInPlace \
        $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
        --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"

      mkdir $lib
      mv $out/${python.sitePackages}/torch/lib $lib/lib
      ln -s $lib/lib $out/${python.sitePackages}/torch/lib
    ''
    + lib.optionalString rocmSupport ''
      substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
        --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

      substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
        --replace "/build/source/torch/include" "$dev/include"
    '';
585
  # Record propagatedCxxBuildInputs for CMake consumers of the cxxdev output,
  # and on Darwin rewrite the install names so the split $lib output's dylibs
  # reference each other by absolute path instead of @rpath.
  postFixup =
    ''
      mkdir -p "$cxxdev/nix-support"
      printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
    ''
    + lib.optionalString stdenv.isDarwin ''
      for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
      done

      install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
    '';
605
  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  # Runs after the fixupPhase's patchelf shrinking, so the added runpaths survive.
  postPhases = lib.optionals stdenv.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';

  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];
625
  # Expose build configuration for downstream packages (e.g. torchvision needs
  # to know which cudaPackages this torch was built against).
  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ./tests.nix { };
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
    # Marked broken when any brokenConditions entry is true (all values are bools).
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}