{ stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python,
  config, cudaSupport ? config.cudaSupport, cudaPackages,
  # The magma build matching the enabled GPU backend (CUDA, ROCm, or CPU-only).
  effectiveMagma ?
    if cudaSupport then magma-cuda-static
    else if rocmSupport then magma-hip
    else magma,
  magma,
  magma-hip,
  magma-cuda-static,
  # Use the NCCL from cudaPackages instead of pytorch's vendored copy.
  useSystemNccl ? true,
  MPISupport ? false, mpi,
  buildDocs ? false,

  # Native build inputs
  cmake, linkFarm, symlinkJoin, which, pybind11, removeReferencesTo,
  pythonRelaxDepsHook,

  # Build inputs
  numactl,
  Accelerate, CoreServices, libobjc,

  # Propagated build inputs
  filelock,
  jinja2,
  networkx,
  sympy,
  numpy, pyyaml, cffi, click, typing-extensions,
  # ROCm build and `torch.compile` requires `openai-triton`
  tritonSupport ? (!stdenv.isDarwin), openai-triton,

  # Unit tests
  hypothesis, psutil,

  # Disable MKLDNN on aarch64-darwin; it negatively impacts performance,
  # and this is also what the official pytorch build does.
  mklDnnSupport ? !(stdenv.isDarwin && stdenv.isAarch64),

  # virtual pkg that consistently instantiates blas across nixpkgs
  # See https://github.com/NixOS/nixpkgs/pull/83888
  blas,

  # ninja (https://ninja-build.org) must be available to run C++ extensions tests
  ninja,

  # dependencies for torch.utils.tensorboard
  pillow, six, future, tensorboard, protobuf,

  pythonOlder,

  # ROCm dependencies
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  gpuTargets ? [ ]
}:
55
let
  inherit (lib) attrsets lists strings trivial;
  inherit (cudaPackages) cudaFlags cudnn;

  # Some packages are not available on all platforms
  nccl = cudaPackages.nccl or null;

  # Render a Nix boolean as the "1"/"0" strings expected by pytorch's build scripts.
  setBool = v: if v then "1" else "0";

  # CUDA compute capabilities this pytorch release can target, both as real
  # architectures and as forward-compatible PTX variants.
  # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
  supportedTorchCudaCapabilities =
    let
      real = ["3.5" "3.7" "5.0" "5.2" "5.3" "6.0" "6.1" "6.2" "7.0" "7.2" "7.5" "8.0" "8.6" "8.7" "8.9" "9.0"];
      ptx = lists.map (x: "${x}+PTX") real;
    in
    real ++ ptx;

  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  #   of the first list *from* the second list. That means:
  #   lists.subtractLists a b = b - a

  # For CUDA
  supportedCudaCapabilities = lists.intersectLists cudaFlags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities cudaFlags.cudaCapabilities;

  # Uses trivial.throwIf to abort evaluation when no supported GPU target remains,
  # listing the requested-but-unsupported targets in the error message.
  # Returns the supported list unchanged otherwise.
  gpuArchWarner = supported: unsupported:
    trivial.throwIf (supported == [ ])
      (
        "No supported GPU targets specified. Requested GPU targets: "
        + strings.concatStringsSep ", " unsupported
      )
      supported;

  # Create the gpuTargetString.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      rocmPackages.clr.gpuTargets
    else
      throw "No GPU targets specified"
  );

  # Merge the required ROCm packages into one tree so pytorch's build system can
  # treat it like a monolithic /opt/rocm installation.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core clr rccl miopen miopengemm rocrand rocblas
      rocsparse hipsparse rocthrust rocprim hipcub roctracer
      rocfft rocsolver hipfft hipsolver hipblas
      rocminfo rocm-thunk rocm-comgr rocm-device-libs
      rocm-runtime clr.icd hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };

  # Unsupported configurations; each entry that evaluates to true marks the
  # package broken, with the attribute name as the human-readable reason
  # (see meta.broken and passthru.brokenConditions below).
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
    "Unsupported CUDA version" = cudaSupport && !(builtins.elem cudaPackages.cudaMajorVersion [ "11" "12" ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" = MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    "Magma cudaPackages does not match cudaPackages" = cudaSupport && (effectiveMagma.cudaPackages != cudaPackages);
  };
in buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.0.1";
  format = "setuptools";

  disabled = pythonOlder "3.8.0";

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
  ];

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "refs/tags/v${version}";
    fetchSubmodules = true;
    hash = "sha256-xUj77yKz3IQ3gd/G32pI4OhL3LoN1zS7eFg0/0nZp5I=";
  };

  patches = lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
    # pthreadpool added support for Grand Central Dispatch in April
    # 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
    # that is available starting with macOS 10.13. However, our current
    # base is 10.12. Until we upgrade, we can fall back on the older
    # pthread support.
    ./pthreadpool-disable-gcd.diff
  ] ++ lib.optionals stdenv.isLinux [
    # Propagate CUPTI to Kineto by overriding the search path with environment variables.
    # https://github.com/pytorch/pytorch/pull/108847
    ./pytorch-pr-108847.patch
  ];

  postPatch = lib.optionalString rocmSupport ''
    # https://github.com/facebookincubator/gloo/pull/297
    substituteInPlace third_party/gloo/cmake/Hipify.cmake \
      --replace "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

    # Replace hard-coded rocm paths
    substituteInPlace caffe2/CMakeLists.txt \
      --replace "/opt/rocm" "${rocmtoolkit_joined}" \
      --replace "hcc/include" "hip/include" \
      --replace "rocblas/include" "include/rocblas" \
      --replace "hipsparse/include" "include/hipsparse"

    # Doesn't pick up the environment variable?
    substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
      --replace "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}" \
      --replace "/opt/rocm" "${rocmtoolkit_joined}"

    # Strangely, this is never set in cmake
    substituteInPlace cmake/public/LoadHIP.cmake \
      --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
        "set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." rocmPackages.clr.version))})"
  ''
  # Detection of NCCL version doesn't work particularly well when using the static binary.
  + lib.optionalString cudaSupport ''
    substituteInPlace cmake/Modules/FindNCCL.cmake \
      --replace \
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # TODO(@connorbaker): Remove this patch after 2.1.0 lands.
  + lib.optionalString cudaSupport ''
    substituteInPlace torch/utils/cpp_extension.py \
      --replace \
        "'8.6', '8.9'" \
        "'8.6', '8.7', '8.9'"
  ''
  # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
  # This lib overrides aligned_alloc, hence the error message. Tl;dr: this function is linkable but not in the header.
  + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.hostPlatform.darwinSdkVersion "11.0") ''
    substituteInPlace third_party/pocketfft/pocketfft_hdronly.h --replace '#if __cplusplus >= 201703L
    inline void *aligned_alloc(size_t align, size_t size)' '#if __cplusplus >= 201703L && 0
    inline void *aligned_alloc(size_t align, size_t size)'
  '';

  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure = lib.optionalString cudaSupport ''
    export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
    export CUDNN_LIB_DIR=${cudnn.lib}/lib
    export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
    export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
  '' + lib.optionalString rocmSupport ''
    export ROCM_PATH=${rocmtoolkit_joined}
    export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
    export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
    python tools/amd_build/build_amd.py
  '';

  # Use pytorch's custom configurations
  dontUseCmakeConfigure = true;

  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  BUILD_NAMEDTENSOR = setBool true;
  BUILD_DOCS = setBool buildDocs;

  # We only do an imports check, so do not build tests either.
  BUILD_TEST = setBool false;

  # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
  # it by default. PyTorch currently uses its own vendored version
  # of oneDNN through Intel iDeep.
  USE_MKLDNN = setBool mklDnnSupport;
  USE_MKLDNN_CBLAS = setBool mklDnnSupport;

  # Avoid using pybind11 from git submodule
  # Also avoids pytorch exporting the headers of pybind11
  USE_SYSTEM_PYBIND11 = true;

  # Configure with cmake only, then drive the actual compile with cmake/ninja
  # directly so MAX_JOBS (read by pytorch's build) respects NIX_BUILD_CORES.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  # Drop the first two RPATH entries of every libcaffe2*.so and prepend $ORIGIN,
  # so the libraries resolve their siblings relative to their own location.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  USE_NCCL = setBool (nccl != null);
  USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
  USE_STATIC_NCCL = setBool useSystemNccl;

  # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
  # (upstream seems to have fixed this in the wrong place?)
  # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
  # https://github.com/pytorch/pytorch/issues/22346
  #
  # Also of interest: pytorch ignores CXXFLAGS and uses CFLAGS for both C and C++:
  # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
  env.NIX_CFLAGS_COMPILE = toString ((lib.optionals (blas.implementation == "mkl") [ "-Wno-error=array-bounds" ]
    # Suppress gcc regression: avx512 math function raises uninitialized variable warning
    # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593
    # See also: Fails to compile with GCC 12.1.0 https://github.com/pytorch/pytorch/issues/77939
    ++ lib.optionals (stdenv.cc.isGNU && lib.versionAtLeast stdenv.cc.version "12.0.0") [
      "-Wno-error=maybe-uninitialized"
      "-Wno-error=uninitialized"
    ]
    # Since pytorch 2.0:
    # gcc-12.2.0/include/c++/12.2.0/bits/new_allocator.h:158:33: error: ‘void operator delete(void*, std::size_t)’
    # ... called on pointer ‘<unknown>’ with nonzero offset [1, 9223372036854775800] [-Werror=free-nonheap-object]
    ++ lib.optionals (stdenv.cc.isGNU && lib.versions.major stdenv.cc.version == "12" ) [
      "-Wno-error=free-nonheap-object"
    ]
    # .../source/torch/csrc/autograd/generated/python_functions_0.cpp:85:3:
    # error: cast from ... to ... converts to incompatible function type [-Werror,-Wcast-function-type-strict]
    ++ lib.optionals (stdenv.cc.isClang && lib.versionAtLeast stdenv.cc.version "16") [
      "-Wno-error=cast-function-type-strict"
    # Suppresses the most spammy warnings.
    # This is mainly to fix https://github.com/NixOS/nixpkgs/issues/266895.
    ] ++ lib.optionals rocmSupport [
      "-Wno-#warnings"
      "-Wno-cpp"
      "-Wno-unknown-warning-option"
      "-Wno-ignored-attributes"
      "-Wno-deprecated-declarations"
      "-Wno-defaulted-function-deleted"
      "-Wno-pass-failed"
    ] ++ [
      "-Wno-unused-command-line-argument"
      "-Wno-uninitialized"
      "-Wno-array-bounds"
      "-Wno-free-nonheap-object"
      "-Wno-unused-result"
    ] ++ lib.optionals stdenv.cc.isGNU [
      "-Wno-maybe-uninitialized"
      "-Wno-stringop-overflow"
    ]));

  nativeBuildInputs = [
    cmake
    which
    ninja
    pybind11
    pythonRelaxDepsHook
    removeReferencesTo
  ] ++ lib.optionals cudaSupport (with cudaPackages; [
    autoAddOpenGLRunpathHook
    cuda_nvcc
  ])
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  buildInputs = [ blas blas.provider ]
    ++ lib.optionals cudaSupport (with cudaPackages; [
      cuda_cccl.dev # <thrust/*>
      cuda_cudart # cuda_runtime.h and libraries
      cuda_cupti.dev # For kineto
      cuda_cupti.lib # For kineto
      cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
      cuda_nvml_dev.dev # <nvml.h>
      cuda_nvrtc.dev
      cuda_nvrtc.lib
      cuda_nvtx.dev
      cuda_nvtx.lib # -llibNVToolsExt
      cudnn.dev
      cudnn.lib
      libcublas.dev
      libcublas.lib
      libcufft.dev
      libcufft.lib
      libcurand.dev
      libcurand.lib
      libcusolver.dev
      libcusolver.lib
      libcusparse.dev
      libcusparse.lib
    ] ++ lists.optionals (nccl != null) [
      # Some platforms do not support NCCL (i.e., Jetson)
      nccl.dev # Provides nccl.h AND a static copy of NCCL!
    ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
      cuda_nvprof.dev # <cuda_profiler_api.h>
    ] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
      cuda_profiler_api.dev # <cuda_profiler_api.h>
    ])
    ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
    ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
    ++ lib.optionals stdenv.isLinux [ numactl ]
    ++ lib.optionals stdenv.isDarwin [ Accelerate CoreServices libobjc ];

  propagatedBuildInputs = [
    cffi
    click
    numpy
    pyyaml

    # From install_requires:
    filelock
    typing-extensions
    sympy
    networkx
    jinja2

    # the following are required for tensorboard support
    pillow six future tensorboard protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ]
  ++ lib.optionals tritonSupport [ openai-triton ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [
    "torch"
  ];

  nativeCheckInputs = [ hypothesis ninja psutil ];

  # Kept for manual use even though doCheck = false above.
  checkPhase = with lib.versions; with lib.strings; concatStringsSep " " [
    "runHook preCheck"
    "${python.interpreter} test/run_test.py"
    "--exclude"
    (concatStringsSep " " [
      "utils" # utils requires git, which is not allowed in the check phase

      # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
      # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

      # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
      (optionalString (majorMinor version == "1.3" ) "tensorboard")
    ])
    "runHook postCheck"
  ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];

  # Scrub compiler references, then split headers/cmake files into $dev and
  # shared libraries into $lib, leaving a symlink behind in $out.
  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

    mkdir $dev
    cp -r $out/${python.sitePackages}/torch/include $dev/include
    cp -r $out/${python.sitePackages}/torch/share $dev/share

    # Fix up library paths for split outputs
    substituteInPlace \
      $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace \''${_IMPORT_PREFIX}/lib "$lib/lib"

    mkdir $lib
    mv $out/${python.sitePackages}/torch/lib $lib/lib
    ln -s $lib/lib $out/${python.sitePackages}/torch/lib
  '' + lib.optionalString rocmSupport ''
    substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
      --replace "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace "/build/source/torch/include" "$dev/include"
  '';

  # Rewrite Mach-O install names / @rpath references so the split $lib output's
  # dylibs can find each other at their final store paths.
  postFixup = lib.optionalString stdenv.isDarwin ''
    for f in $(ls $lib/lib/*.dylib); do
        install_name_tool -id $lib/lib/$(basename $f) $f || true
    done

    install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
  '';

  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit cudaSupport cudaPackages;
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
  };

  meta = with lib; {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = licenses.bsd3;
    maintainers = with maintainers; [ teh thoughtpolice tscholak ]; # tscholak esp. for darwin-related builds
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}