# PyTorch (python bindings + libtorch), built from source, with optional
# CUDA or ROCm acceleration. Produces split outputs: the python package
# (out), libtorch headers (dev), libtorch libraries (lib), and cmake
# consumer propagation (cxxdev).
{
  stdenv,
  lib,
  fetchFromGitHub,
  buildPythonPackage,
  python,
  config,
  cudaSupport ? config.cudaSupport,
  cudaPackages,
  autoAddDriverRunpath,
  # Magma flavour is selected to match the GPU backend in use.
  effectiveMagma ?
    if cudaSupport then
      magma-cuda-static
    else if rocmSupport then
      magma-hip
    else
      magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the system NCCL as long as we're targeting CUDA on a supported platform.
  useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
  MPISupport ? false,
  mpi,
  buildDocs ? false,

  # Native build inputs
  cmake,
  symlinkJoin,
  which,
  pybind11,
  removeReferencesTo,
  pythonRelaxDepsHook,

  # Build inputs
  numactl,
  Accelerate,
  CoreServices,
  libobjc,

  # Propagated build inputs
  astunparse,
  fsspec,
  filelock,
  jinja2,
  networkx,
  sympy,
  numpy,
  pyyaml,
  cffi,
  click,
  typing-extensions,
  # ROCm build and `torch.compile` requires `openai-triton`
  tritonSupport ? (!stdenv.isDarwin),
  openai-triton,

  # Unit tests
  hypothesis,
  psutil,

  # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
  # this is also what official pytorch build does
  mklDnnSupport ? !(stdenv.isDarwin && stdenv.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow,
  six,
  future,
  tensorboard,
  protobuf,

  pythonOlder,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages_5,
  # Explicit GPU targets; when empty, derived from cudaFlags / clr below.
  gpuTargets ? [ ],
}:

let
  inherit (lib)
    attrsets
    lists
    strings
    trivial
    ;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  rocmPackages = rocmPackages_5;

  # Render a Nix boolean as the "1"/"0" strings setup.py expects in env vars.
  setBool = v: if v then "1" else "0";

  # Capabilities this torch release accepts; anything else must be filtered out.
  # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
  supportedTorchCudaCapabilities =
    let
      real = [
        "3.5"
        "3.7"
        "5.0"
        "5.2"
        "5.3"
        "6.0"
        "6.1"
        "6.2"
        "7.0"
        "7.2"
        "7.5"
        "8.0"
        "8.6"
        "8.7"
        "8.9"
        "9.0"
      ];
      ptx = lists.map (x: "${x}+PTX") real;
    in
    real ++ ptx;

  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA
  supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;

  # Use trivial.warnIf to print a warning if any unsupported GPU targets are specified.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;

  # Create the gpuTargetString.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      rocmPackages.clr.gpuTargets
    else
      throw "No GPU targets specified"
  );

  # Single merged tree of the ROCm toolkit, substituted for /opt/rocm below.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      miopengemm
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hipsolver
      hipblas
      rocminfo
      rocm-thunk
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };

  # Invalid configuration combinations; any surviving attr marks the build broken.
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport && (effectiveMagma.cudaPackages != cudaPackages);
  };
in
buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.3.0";
  pyproject = true;

  disabled = pythonOlder "3.8.0";

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "refs/tags/v${version}";
    fetchSubmodules = true;
    hash = "sha256-UmH4Mv5QL7Mz4Y4pvxn8F1FGBR/UzYZjE2Ys8Oc0FWQ=";
  };

  patches =
    lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
    ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
      # pthreadpool added support for Grand Central Dispatch in April
      # 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
      # that is available starting with macOS 10.13. However, our current
      # base is 10.12. Until we upgrade, we can fall back on the older
      # pthread support.
      ./pthreadpool-disable-gcd.diff
    ]
    ++ lib.optionals stdenv.isLinux [
      # Propagate CUPTI to Kineto by overriding the search path with environment variables.
      # https://github.com/pytorch/pytorch/pull/108847
      ./pytorch-pr-108847.patch
    ];

  postPatch =
    lib.optionalString rocmSupport ''
      # https://github.com/facebookincubator/gloo/pull/297
      substituteInPlace third_party/gloo/cmake/Hipify.cmake \
        --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

      # Replace hard-coded rocm paths
      substituteInPlace caffe2/CMakeLists.txt \
        --replace "/opt/rocm" "${rocmtoolkit_joined}" \
        --replace "hcc/include" "hip/include" \
        --replace "rocblas/include" "include/rocblas" \
        --replace "hipsparse/include" "include/hipsparse"

      # Doesn't pick up the environment variable?
      substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
        --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
        --replace "/opt/rocm" "${rocmtoolkit_joined}"

      # Strangely, this is never set in cmake
      substituteInPlace cmake/public/LoadHIP.cmake \
        --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
          "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitVersion rocmPackages.clr.version))})"
    ''
    # Detection of NCCL version doesn't work particularly well when using the static binary.
    + lib.optionalString cudaSupport ''
      substituteInPlace cmake/Modules/FindNCCL.cmake \
        --replace \
          'message(FATAL_ERROR "Found NCCL header version and library version' \
          'message(WARNING "Found NCCL header version and library version'
    ''
    # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
    # We do not remove the entirety of cmake/Modules_CUDA_fix because we need FindCUDNN.cmake.
    + lib.optionalString cudaSupport ''
      rm cmake/Modules/FindCUDAToolkit.cmake
      rm -rf cmake/Modules_CUDA_fix/{upstream,FindCUDA.cmake}
    ''
    # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
    # This lib overrides aligned_alloc hence the error message. TL;DR: this function is linkable but not in header.
    +
      lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0")
        ''
          substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace-fail '#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))
          inline void *aligned_alloc(size_t align, size_t size)' '#if 0
          inline void *aligned_alloc(size_t align, size_t size)'
        '';

  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
      export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
      export CUDNN_LIB_DIR=${cudnn.lib}/lib
    ''
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';

  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # NB technical debt: building without NNPACK as workaround for missing `six`
  USE_NNPACK = 0;

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  # Drop the first two RPATH entries from the caffe2 libraries and re-anchor
  # them on $ORIGIN so the split outputs resolve correctly.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
  # (upstream seems to have fixed this in the wrong place?)
  # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
  # https://github.com/pytorch/pytorch/issues/22346
  #
  # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
  # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
  env.NIX_CFLAGS_COMPILE = toString (
    (
      lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
      # Suppress gcc regression: avx512 math function raises uninitialized variable warning
      # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
      # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
      ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
        "-Wno-error=maybe-uninitialized"
        "-Wno-error=uninitialized"
      ]
      # Since pytorch 2.0:
      # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
      # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
      ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12") [
        "-Wno-error=free-nonheap-object"
      ]
      # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
      # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
      ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
        "-Wno-error=cast-function-type-strict"
        # Suppresses the most spammy warnings.
        # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
      ]
      ++ lib.optionals rocmSupport [
        "-Wno-#warnings"
        "-Wno-cpp"
        "-Wno-unknown-warning-option"
        "-Wno-ignored-attributes"
        "-Wno-deprecated-declarations"
        "-Wno-defaulted-function-deleted"
        "-Wno-pass-failed"
      ]
      ++ [
        "-Wno-unused-command-line-argument"
        "-Wno-uninitialized"
        "-Wno-array-bounds"
        "-Wno-free-nonheap-object"
        "-Wno-unused-result"
      ]
      ++ lib.optionals stdenv.cc.isGNU [
        "-Wno-maybe-uninitialized"
        "-Wno-stringop-overflow"
      ]
    )
  );

  nativeBuildInputs =
    [
      cmake
      which
      ninja
      pybind11
      pythonRelaxDepsHook
      removeReferencesTo
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        autoAddDriverRunpath
        cuda_nvcc
      ]
    )
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  buildInputs =
    [
      blas
      blas.provider
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cccl.dev # <thrust/*>
        cuda_cudart.dev # cuda_runtime.h and libraries
        cuda_cudart.lib
        cuda_cudart.static
        cuda_cupti.dev # For kineto
        cuda_cupti.lib # For kineto
        cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
        cuda_nvml_dev.dev # <nvml.h>
        cuda_nvrtc.dev
        cuda_nvrtc.lib
        cuda_nvtx.dev
        cuda_nvtx.lib # -llibNVToolsExt
        libcublas.dev
        libcublas.lib
        libcufft.dev
        libcufft.lib
        libcurand.dev
        libcurand.lib
        libcusolver.dev
        libcusolver.lib
        libcusparse.dev
        libcusparse.lib
      ]
      ++ lists.optionals (cudaPackages ? cudnn) [
        cudnn.dev
        cudnn.lib
      ]
      ++ lists.optionals useSystemNccl [
        # Some platforms do not support NCCL (i.e., Jetson)
        nccl.dev # Provides nccl.h AND a static copy of NCCL!
      ]
      ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
        cuda_nvprof.dev # <cuda_profiler_api.h>
      ]
      ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
        cuda_profiler_api.dev # <cuda_profiler_api.h>
      ]
    )
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.isLinux [ numactl ]
    ++ lib.optionals stdenv.isDarwin [
      Accelerate
      CoreServices
      libobjc
    ]
    ++ lib.optionals tritonSupport [ openai-triton ]
    ++ lib.optionals MPISupport [ mpi ]
    ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  propagatedBuildInputs = [
    astunparse
    cffi
    click
    numpy
    pyyaml

    # From install_requires:
    fsspec
    filelock
    typing-extensions
    sympy
    networkx
    jinja2

    # the following are required for tensorboard support
    pillow
    six
    future
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ] ++ lib.optionals tritonSupport [ openai-triton ];

  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];

  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];

  postInstall =
    ''
      find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

      mkdir $dev
      cp -r $out/${python.sitePackages}/torch/include $dev/include
      cp -r $out/${python.sitePackages}/torch/share $dev/share

      # Fix up library paths for split outputs
      substituteInPlace \
        $dev/share/cmake/Torch/TorchConfig.cmake \
        --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

      substituteInPlace \
        $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
        --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"

      mkdir $lib
      mv $out/${python.sitePackages}/torch/lib $lib/lib
      ln -s $lib/lib $out/${python.sitePackages}/torch/lib
    ''
    + lib.optionalString rocmSupport ''
      substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
        --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

      substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
        --replace "/build/source/torch/include" "$dev/include"
    '';

  postFixup =
    ''
      mkdir -p "$cxxdev/nix-support"
      printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
    ''
    + lib.optionalString stdenv.isDarwin ''
      for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
      done

      install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

      install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
      install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
    '';

  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      ;
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
  };

  meta = with lib; {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = licenses.bsd3;
    maintainers = with maintainers; [
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}