1{
2 stdenv,
3 lib,
4 fetchFromGitHub,
5 buildPythonPackage,
6 python,
7 config,
8 cudaSupport ? config.cudaSupport,
9 cudaPackages,
10 autoAddDriverRunpath,
11 effectiveMagma ?
12 if cudaSupport then
13 magma-cuda-static
14 else if rocmSupport then
15 magma-hip
16 else
17 magma,
18 magma,
19 magma-hip,
20 magma-cuda-static,
21 # Use the system NCCL as long as we're targeting CUDA on a supported platform.
22 useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
23 MPISupport ? false,
24 mpi,
25 buildDocs ? false,
26
27 # Native build inputs
28 cmake,
29 symlinkJoin,
30 which,
31 pybind11,
32 removeReferencesTo,
33 pythonRelaxDepsHook,
34
35 # Build inputs
36 numactl,
37 Accelerate,
38 CoreServices,
39 libobjc,
40
41 # Propagated build inputs
42 astunparse,
43 fsspec,
44 filelock,
45 jinja2,
46 networkx,
47 sympy,
48 numpy,
49 pyyaml,
50 cffi,
51 click,
52 typing-extensions,
53 # ROCm build and `torch.compile` requires `openai-triton`
54 tritonSupport ? (!stdenv.isDarwin),
55 openai-triton,
56
57 # Unit tests
58 hypothesis,
59 psutil,
60
61 # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
62 # this is also what official pytorch build does
63 mklDnnSupport ? !(stdenv.isDarwin && stdenv.isAarch64),
64
65 # virtual pkg that consistently instantiates blas across nixpkgs
66 # See https://github.com/NixOS/nixpkgs/pull/83888
67 blas,
68
69 # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
70 ninja,
71
72 # dependencies for torch.utils.tensorboard
73 pillow,
74 six,
75 future,
76 tensorboard,
77 protobuf,
78
79 pythonOlder,
80
81 # ROCm dependencies
82 rocmSupport ? config.rocmSupport,
83 rocmPackages_5,
84 gpuTargets ? [ ],
85}:
86
let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  # This derivation is pinned to the ROCm 5 package set.
  rocmPackages = rocmPackages_5;
97
98 setBool = v: if v then "1" else "0";
99
100 # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
101 supportedTorchCudaCapabilities =
102 let
103 real = [
104 "3.5"
105 "3.7"
106 "5.0"
107 "5.2"
108 "5.3"
109 "6.0"
110 "6.1"
111 "6.2"
112 "7.0"
113 "7.2"
114 "7.5"
115 "8.0"
116 "8.6"
117 "8.7"
118 "8.9"
119 "9.0"
120 ];
121 ptx = lists.map (x: "${x}+PTX") real;
122 in
123 real ++ ptx;
124
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA
  # Capabilities that were both requested (cudaFlags.cudaCapabilities) and are known to torch...
  supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
  # ...and the requested capabilities torch does NOT support (used for the error message below).
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;
132
133 # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
134 gpuArchWarner =
135 supported: unsupported:
136 trivial.throwIf (supported == [ ]) (
137 "No supported GPU targets specified. Requested GPU targets: "
138 + strings.concatStringsSep ", " unsupported
139 ) supported;
140
141 # Create the gpuTargetString.
142 gpuTargetString = strings.concatStringsSep ";" (
143 if gpuTargets != [ ] then
144 # If gpuTargets is specified, it always takes priority.
145 gpuTargets
146 else if cudaSupport then
147 gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
148 else if rocmSupport then
149 rocmPackages.clr.gpuTargets
150 else
151 throw "No GPU targets specified"
152 );
153
  # All ROCm components merged under one /opt/rocm-style prefix; the
  # hard-coded ROCm paths in pytorch's CMake files are substituted to this
  # store path in postPatch, and it is exported as ROCM_PATH in preConfigure.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      miopengemm
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hipsolver
      hipblas
      rocminfo
      rocm-thunk
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
190
191 brokenConditions = attrsets.filterAttrs (_: cond: cond) {
192 "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
193 "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
194 "Unsupported CUDA version" =
195 cudaSupport
196 && !(builtins.elem cudaPackages.cudaMajorVersion [
197 "11"
198 "12"
199 ]);
200 "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
201 MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
202 "Magma cudaPackages does not match cudaPackages" =
203 cudaSupport && (effectiveMagma.cudaPackages != cudaPackages);
204 };
205in
206buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.3.0";
  # Build through the PEP 517 interface rather than legacy setup.py install.
  pyproject = true;

  disabled = pythonOlder "3.8.0";

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  # Route CUDA propagation (from the cudaPackages setup hooks) to cxxdev.
  cudaPropagateToOutput = "cxxdev";
221
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "refs/tags/v${version}";
    # Submodules are required: third_party trees (gloo, kineto, pocketfft,
    # pthreadpool, ...) are patched/substituted below.
    fetchSubmodules = true;
    hash = "sha256-UmH4Mv5QL7Mz4Y4pvxn8F1FGBR/UzYZjE2Ys8Oc0FWQ=";
  };
229
  patches =
    # presumably makes pytorch's CMake find the splayed cudaPackages toolkit
    # instead of a monolithic one — TODO confirm against the patch contents
    lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
      # pthreadpool added support for Grand Central Dispatch in April
      # 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
      # that is available starting with macOS 10.13. However, our current
      # base is 10.12. Until we upgrade, we can fall back on the older
      # pthread support.
      ./pthreadpool-disable-gcd.diff
    ]
    ++ lib.optionals stdenv.isLinux [
      # Propagate CUPTI to Kineto by overriding the search path with environment variables.
      # https://github.com/pytorch/pytorch/pull/108847
      ./pytorch-pr-108847.patch
    ];
245
  postPatch =
    lib.optionalString rocmSupport ''
      # https://github.com/facebookincubator/gloo/pull/297
      substituteInPlace third_party/gloo/cmake/Hipify.cmake \
        --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

      # Replace hard-coded rocm paths
      substituteInPlace caffe2/CMakeLists.txt \
        --replace "/opt/rocm" "${rocmtoolkit_joined}" \
        --replace "hcc/include" "hip/include" \
        --replace "rocblas/include" "include/rocblas" \
        --replace "hipsparse/include" "include/hipsparse"

      # Doesn't pick up the environment variable?
      substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
        --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
        --replace "/opt/rocm" "${rocmtoolkit_joined}"

      # Strangely, this is never set in cmake
      substituteInPlace cmake/public/LoadHIP.cmake \
        --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
          "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
    ''
    # Detection of NCCL version doesn't work particularly well when using the static binary.
    + lib.optionalString cudaSupport ''
      substituteInPlace cmake/Modules/FindNCCL.cmake \
        --replace \
          'message(FATAL_ERROR "Found NCCL header version and library version' \
          'message(WARNING "Found NCCL header version and library version'
    ''
    # Remove PyTorch's FindCUDAToolkit.cmake so that CMake's default one is used.
    # We do not remove the entirety of cmake/Modules_CUDA_fix because we need FindCUDNN.cmake.
    + lib.optionalString cudaSupport ''
      rm cmake/Modules/FindCUDAToolkit.cmake
      rm -rf cmake/Modules_CUDA_fix/{upstream,FindCUDA.cmake}
    ''
    # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
    # This header overrides aligned_alloc, hence the error message. TL;DR: the function is linkable but not declared in the header.
    +
      lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0")
        ''
          substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace-fail '#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))
          inline void *aligned_alloc(size_t align, size_t size)' '#if 0
          inline void *aligned_alloc(size_t align, size_t size)'
        '';
291
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
      export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
    ''
    # cuDNN is optional in the cudaPackages set (unsupported on some platforms).
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
      export CUDNN_LIB_DIR=${cudnn.lib}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';
313
  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # NB technical debt: building without NNPACK as workaround for missing `six`
  USE_NNPACK = 0;

  # Configure via setup.py first (--cmake-only stops before compiling), then
  # invoke cmake on the generated build tree so MAX_JOBS is honored.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';
344
  # Rewrite the RUNPATH of every libcaffe2*.so in $out: drop the first two
  # entries and prepend $ORIGIN so siblings are resolved relative to the
  # library's own directory.
  # NOTE(review): assumes the first two rpath entries are the unwanted ones —
  # TODO confirm.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
359
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  # The two flags below deliberately mirror USE_NCCL: when NCCL is enabled at
  # all, it is always the system-provided static copy.
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;
371
  # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
  # (upstream seems to have fixed this in the wrong place?)
  # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
  # https://github.com/pytorch/pytorch/issues/22346
  #
  # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
  # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
  #
  # The flag groups below are compiler/backend-specific except the last two,
  # which are unconditional blanket -Wno-* suppressions; order is preserved
  # as written since it determines the final flag string.
  env.NIX_CFLAGS_COMPILE = toString (
    (
      lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
      # Suppress gcc regression: avx512 math function raises uninitialized variable warning
      # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
      # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
      ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
        "-Wno-error=maybe-uninitialized"
        "-Wno-error=uninitialized"
      ]
      # Since pytorch 2.0:
      # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
      # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
      ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12") [
        "-Wno-error=free-nonheap-object"
      ]
      # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
      # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
      ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
        "-Wno-error=cast-function-type-strict"
        # Suppresses the most spammy warnings.
        # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
      ]
      ++ lib.optionals rocmSupport [
        "-Wno-#warnings"
        "-Wno-cpp"
        "-Wno-unknown-warning-option"
        "-Wno-ignored-attributes"
        "-Wno-deprecated-declarations"
        "-Wno-defaulted-function-deleted"
        "-Wno-pass-failed"
      ]
      ++ [
        "-Wno-unused-command-line-argument"
        "-Wno-uninitialized"
        "-Wno-array-bounds"
        "-Wno-free-nonheap-object"
        "-Wno-unused-result"
      ]
      ++ lib.optionals stdenv.cc.isGNU [
        "-Wno-maybe-uninitialized"
        "-Wno-stringop-overflow"
      ]
    )
  );
424
425 nativeBuildInputs =
426 [
427 cmake
428 which
429 ninja
430 pybind11
431 pythonRelaxDepsHook
432 removeReferencesTo
433 ]
434 ++ lib.optionals cudaSupport (
435 with cudaPackages;
436 [
437 autoAddDriverRunpath
438 cuda_nvcc
439 ]
440 )
441 ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
442
  buildInputs =
    [
      blas
      blas.provider
    ]
    # NOTE: `cudaVersion`, `cuda_nvprof` and `cuda_profiler_api` below come
    # from the surrounding `with cudaPackages;` scope.
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl.dev # <thrust/*>
        cuda_cudart.dev # cuda_runtime.h and libraries
        cuda_cudart.lib
        cuda_cudart.static
        cuda_cupti.dev # For kineto
        cuda_cupti.lib # For kineto
        cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
        cuda_nvml_dev.dev # <nvml.h>
        cuda_nvrtc.dev
        cuda_nvrtc.lib
        cuda_nvtx.dev
        cuda_nvtx.lib # -llibNVToolsExt
        libcublas.dev
        libcublas.lib
        libcufft.dev
        libcufft.lib
        libcurand.dev
        libcurand.lib
        libcusolver.dev
        libcusolver.lib
        libcusparse.dev
        libcusparse.lib
      ]
      ++ lists.optionals (cudaPackages ? cudnn) [
        cudnn.dev
        cudnn.lib
      ]
      ++ lists.optionals useSystemNccl [
        # Some platforms do not support NCCL (i.e., Jetson)
        nccl.dev # Provides nccl.h AND a static copy of NCCL!
      ]
      ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
        cuda_nvprof.dev # <cuda_profiler_api.h>
      ]
      ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
        cuda_profiler_api.dev # <cuda_profiler_api.h>
      ]
    )
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.isLinux [ numactl ]
    ++ lib.optionals stdenv.isDarwin [
      Accelerate
      CoreServices
      libobjc
    ]
    ++ lib.optionals tritonSupport [ openai-triton ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
500
  # Python runtime dependencies of the installed torch package.
  propagatedBuildInputs = [
    astunparse
    cffi
    click
    numpy
    pyyaml

    # From install_requires:
    fsspec
    filelock
    typing-extensions
    sympy
    networkx
    jinja2

    # the following are required for tensorboard support
    pillow
    six
    future
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ] ++ lib.optionals tritonSupport [ openai-triton ];
526
527 propagatedCxxBuildInputs =
528 [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
529
  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  # Only used when doCheck is enabled.
  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];
540
541 checkPhase =
542 with lib.versions;
543 with lib.strings;
544 concatStringsSep " " [
545 "runHook preCheck"
546 "${python.interpreter} test/run_test.py"
547 "--exclude"
548 (concatStringsSep " " [
549 "utils" # utils requires git, which is not allowed in the check phase
550
551 # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
552 # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build
553
554 # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
555 (optionalString (majorMinor version == "1.3") "tensorboard")
556 ])
557 "runHook postCheck"
558 ];
559
  # Consumed by pythonRelaxDepsHook (in nativeBuildInputs) to strip these
  # requirements from the generated dist-info metadata.
  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];
564
  # Split the installed tree across the dev/lib outputs and scrub compiler
  # store-path references from headers and libraries.
  postInstall =
    ''
      find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

      mkdir $dev
      cp -r $out/${python.sitePackages}/torch/include $dev/include
      cp -r $out/${python.sitePackages}/torch/share $dev/share

      # Fix up library paths for split outputs
      substituteInPlace \
        $dev/share/cmake/Torch/TorchConfig.cmake \
        --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

      substituteInPlace \
        $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
        --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"

      mkdir $lib
      mv $out/${python.sitePackages}/torch/lib $lib/lib
      ln -s $lib/lib $out/${python.sitePackages}/torch/lib
    ''
    + lib.optionalString rocmSupport ''
      substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
        --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

      substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
        --replace "/build/source/torch/include" "$dev/include"
    '';
593
  # Wire propagatedCxxBuildInputs into the cxxdev output, and on Darwin
  # rewrite dylib install names/references to their final $lib locations.
  postFixup =
    ''
      mkdir -p "$cxxdev/nix-support"
      printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
    ''
    + lib.optionalString stdenv.isDarwin ''
      for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
      done

      install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
    '';
613
  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    # Expose the flags/package sets this torch was built against so dependent
    # packages (e.g. torchvision) can stay consistent with them.
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      ;
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
  };
630
631 meta = with lib; {
632 changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
633 # keep PyTorch in the description so the package can be found under that name on search.nixos.org
634 description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
635 homepage = "https://pytorch.org/";
636 license = licenses.bsd3;
637 maintainers = with maintainers; [
638 teh
639 thoughtpolice
640 tscholak
641 ]; # tscholak esp. for darwin-related builds
642 platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
643 broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
644 };
645}