{
  autoAddDriverRunpath,
  buildPythonPackage,
  config,
  cudaPackages,
  fetchFromGitHub,
  fetchurl,
  jax,
  lib,
  llvmPackages,
  numpy,
  pkgsBuildHost,
  python,
  replaceVars,
  runCommand,
  setuptools,
  stdenv,
  torch,
  warp-lang, # Self-reference to this package for passthru.tests
  writableTmpDirAsHomeHook,
  writeShellApplication,

  # Use standalone LLVM-based JIT compiler and CPU device support
  standaloneSupport ? true,

  # Use CUDA toolchain and GPU device support
  cudaSupport ? config.cudaSupport,

  # Build Warp with MathDx support (requires CUDA support)
  # Most linear-algebra tile operations like tile_cholesky(), tile_fft(),
  # and tile_matmul() require Warp to be built with the MathDx library.
  # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
  libmathdxSupport ? cudaSupport,
}@args:
assert libmathdxSupport -> cudaSupport;
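# Downstream users can toggle these feature flags with an override, e.g. (illustrative):
#   warp-lang.override { cudaSupport = true; libmathdxSupport = true; }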
let
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv;
  stdenv = builtins.throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";

  version = "1.8.0";

  libmathdx = effectiveStdenv.mkDerivation (finalAttrs: {
    # NOTE: The version used should match the version Warp requires:
    # https://github.com/NVIDIA/warp/blob/${version}/deps/libmathdx-deps.packman.xml
    pname = "libmathdx";
    version = "0.2.1";

    outputs = [
      "out"
      "static"
    ];

    src =
      let
        baseURL = "https://developer.download.nvidia.com/compute/cublasdx/redist/cublasdx";
        name = lib.concatStringsSep "-" [
          finalAttrs.pname
          "Linux"
          effectiveStdenv.hostPlatform.parsed.cpu.name
          finalAttrs.version
        ];

        # nix-hash --type sha256 --to-sri $(nix-prefetch-url "https://...")
        hashes = {
          aarch64-linux = "sha256-smB13xev2TG1xUx4+06KRgYEnPMczpjBOOX7uC1APbE=";
          x86_64-linux = "sha256-+3TbLuL5Y2flLRicQgPVLs8KZQBqNYJYJ8P3etgX7g0=";
        };
      in
      lib.mapNullable (
        hash:
        fetchurl {
          inherit hash name;
          url = "${baseURL}/${name}.tar.gz";
        }
      ) (hashes.${effectiveStdenv.hostPlatform.system} or null);
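    # NOTE: src is null on platforms without a pinned hash; those systems are excluded via meta.platforms below.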

    dontUnpack = true;
    dontConfigure = true;
    dontBuild = true;

    installPhase = ''
      runHook preInstall

      mkdir -p "$out"
      tar -xzf "$src" -C "$out"

      mkdir -p "$static"
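      # moveToOutput is provided by stdenv's multiple-outputs setup hook; it relocates the static archive into the "static" output.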
      moveToOutput "lib/libmathdx_static.a" "$static"

      runHook postInstall
    '';

    meta = {
      description = "Library used to integrate cuBLASDx and cuFFTDx into Warp";
      homepage = "https://developer.nvidia.com/cublasdx-downloads";
      sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
      license = with lib.licenses; [
        # By downloading and using the software, you agree to fully
        # comply with the terms and conditions of the NVIDIA Software
        # License Agreement.
        (
          nvidiaCudaRedist
          // {
            url = "https://developer.download.nvidia.cn/compute/mathdx/License.txt";
          }
        )

        # Some of the libmathdx routines were written by or derived
        # from code written by Meta Platforms, Inc. and affiliates and
        # are subject to the BSD License.
        bsd3

        # Some of the libmathdx routines were written by or derived from
        # code written by Victor Zverovich and are subject to the MIT License.
        mit
      ];
      platforms = [
        "aarch64-linux"
        "x86_64-linux"
      ];
      maintainers = with lib.maintainers; [ yzx9 ];
    };
  });
in
buildPythonPackage {
  pname = "warp-lang";
  inherit version;
  pyproject = true;

  # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
  # causing a bunch of missing math symbols (like expf) when linking against the static library
  # provided by NVCC.
  __structuredAttrs = true;

  stdenv = effectiveStdenv;

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "warp";
    tag = "v${version}";
    hash = "sha256-zCRB92acxOiIFGjfRh2Cr1qq8pbhm+Rd011quMP/D88=";
  };

  patches =
    lib.optionals effectiveStdenv.hostPlatform.isDarwin [
      (replaceVars ./darwin-libcxx.patch {
        LIBCXX_DEV = llvmPackages.libcxx.dev;
        LIBCXX_LIB = llvmPackages.libcxx;
      })
      ./darwin-single-target.patch
    ]
    ++ lib.optionals standaloneSupport [
      (replaceVars ./standalone-llvm.patch {
        LLVM_DEV = llvmPackages.llvm.dev;
        LLVM_LIB = llvmPackages.llvm.lib;
        LIBCLANG_DEV = llvmPackages.libclang.dev;
        LIBCLANG_LIB = llvmPackages.libclang.lib;
      })
      ./standalone-cxx11-abi.patch
    ];

  postPatch =
    # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
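    # (cudaPackages.flags.gencode entries are derived from the configured CUDA capabilities and are
    # typically of the form "-gencode=arch=compute_XX,code=sm_XX".)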
    lib.optionalString cudaSupport ''
      nixLog "patching $PWD/warp/build_dll.py to use our gencode flags"
      substituteInPlace "$PWD/warp/build_dll.py" \
        --replace-fail \
          '*gencode_opts,' \
          '${
            lib.concatMapStringsSep ", " (gencodeString: ''"${gencodeString}"'') cudaPackages.flags.gencode
          },' \
        --replace-fail \
          '*clang_arch_flags,' \
          '${
            lib.concatMapStringsSep ", " (
              realArch: ''"--cuda-gpu-arch=${realArch}"''
            ) cudaPackages.flags.realArches
          },'
    ''
    # Patch build_dll.py to use dynamic libraries rather than static ones.
    # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
    + lib.optionalString cudaSupport ''
      nixLog "patching $PWD/warp/build_dll.py to use dynamic libraries"
      substituteInPlace "$PWD/warp/build_dll.py" \
        --replace-fail \
          '-lcudart_static' \
          '-lcudart' \
        --replace-fail \
          '-lnvrtc_static' \
          '-lnvrtc' \
        --replace-fail \
          '-lnvrtc-builtins_static' \
          '-lnvrtc-builtins' \
        --replace-fail \
          '-lnvJitLink_static' \
          '-lnvJitLink' \
        --replace-fail \
          '-lmathdx_static' \
          '-lmathdx'
    ''
    # Broken tests on aarch64. Since unittest doesn't support disabling a
    # single test, and pytest isn't compatible, we patch the test file directly
    # instead.
    #
    # See: https://github.com/NVIDIA/warp/issues/552
    + lib.optionalString effectiveStdenv.hostPlatform.isAarch64 ''
      nixLog "patching $PWD/warp/tests/test_fem.py to disable broken tests on aarch64"
      substituteInPlace "$PWD/warp/tests/test_fem.py" \
        --replace-fail \
          'add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)' \
          ""
    ''
    # AssertionError: 0.4082476496696472 != 0.40824246406555176 within 5 places
    + lib.optionalString effectiveStdenv.hostPlatform.isDarwin ''
      nixLog "patching $PWD/warp/tests/test_codegen.py to relax comparison precision on darwin"
      substituteInPlace "$PWD/warp/tests/test_codegen.py" \
        --replace-fail \
          'places=5' \
          'places=4'
    ''
    # These tests fail on CPU and CUDA.
    + ''
      nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
      substituteInPlace "$PWD/warp/tests/test_reload.py" \
        --replace-fail \
          'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
          "" \
        --replace-fail \
          'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
          ""
    '';

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
  ];

  # NOTE: We normally wouldn't include autoAddDriverRunpath for packages built from source, but Warp
  # loads GPU drivers at runtime, so we need to inject the path to our video drivers.
  nativeBuildInputs = lib.optionals cudaSupport [
    autoAddDriverRunpath
    cudaPackages.cuda_nvcc
  ];

  buildInputs =
    lib.optionals standaloneSupport [
      llvmPackages.llvm
      llvmPackages.clang
      llvmPackages.libcxx
    ]
    ++ lib.optionals cudaSupport [
      (lib.getOutput "static" cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
      cudaPackages.cuda_nvcc
      cudaPackages.cuda_nvrtc
    ]
    ++ lib.optionals libmathdxSupport [
      libmathdx
      cudaPackages.libcublas
      cudaPackages.libcufft
      cudaPackages.libcusolver
      cudaPackages.libnvjitlink
    ];

  preBuild =
    let
      buildOptions =
        lib.optionals effectiveStdenv.cc.isClang [
          "--clang_build_toolchain"
        ]
        ++ lib.optionals (!standaloneSupport) [
          "--no_standalone"
        ]
        ++ lib.optionals cudaSupport [
          # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
          "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
        ]
        ++ lib.optionals libmathdxSupport [
          "--libmathdx"
          "--libmathdx_path=${libmathdx}"
        ]
        ++ lib.optionals (!libmathdxSupport) [
          "--no_libmathdx"
        ];

      buildOptionString = lib.concatStringsSep " " buildOptions;
    in
    ''
      nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
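      # The resulting invocation is roughly: python build_lib.py [--clang_build_toolchain] [--no_standalone] [--cuda_path=...] [--libmathdx --libmathdx_path=... | --no_libmathdx]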
295 "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
296 '';
297
298 pythonImportsCheck = [
299 "warp"
300 ];
301
302 # See passthru.tests.
303 doCheck = false;
304
305 passthru = {
306 # Make libmathdx available for introspection.
307 inherit libmathdx;
308
309 # Scripts which provide test packages and implement test logic.
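    # NOTE: When warp-lang is built with cudaSupport, running this tester requires a working NVIDIA driver on the host.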
    testers.unit-tests = writeShellApplication {
      name = "warp-lang-unit-tests";
      runtimeInputs = [
        # Use the references from args
        (python.withPackages (_: [
          warp-lang
          jax
          torch
        ]))
        # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
        # (paddlepaddle.override { inherit cudaSupport; })
      ];
      text = ''
        python3 -m warp.tests
      '';
    };

    # Tests run within the Nix sandbox.
    tests =
      let
        mkUnitTests =
          {
            cudaSupport,
            libmathdxSupport,
          }:
          let
            name =
              "warp-lang-unit-tests-cpu" # CPU is baseline
              + lib.optionalString cudaSupport "-cuda"
              + lib.optionalString libmathdxSupport "-libmathdx";

            warp-lang' = warp-lang.override {
              inherit cudaSupport libmathdxSupport;
              # Make sure the warp-lang provided through callPackage is replaced with the override we're making.
              warp-lang = warp-lang';
            };
          in
          runCommand name
            {
              nativeBuildInputs = [
                warp-lang'.passthru.testers.unit-tests
                writableTmpDirAsHomeHook
              ];
              requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
              # Many unit tests fail with segfaults on aarch64-linux, especially in the sim
              # and grad modules. However, other functionality generally works, so we don't
              # mark the package as broken.
              #
              # See: https://github.com/NVIDIA/warp/issues/{356,372,552}
              meta.broken = effectiveStdenv.hostPlatform.isAarch64 && effectiveStdenv.hostPlatform.isLinux;
            }
            ''
              nixLog "running ${name}"

              if warp-lang-unit-tests; then
                nixLog "${name} passed"
                touch "$out"
              else
                nixErrorLog "${name} failed"
                exit 1
              fi
            '';
      in
      {
        cpu = mkUnitTests {
          cudaSupport = false;
          libmathdxSupport = false;
        };
        cuda = {
          cudaOnly = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = false;
          };
          cudaWithLibmathDx = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = true;
          };
        };
      };
  };

  meta = {
    description = "Python framework for high performance GPU simulation and graphics";
    longDescription = ''
      Warp is a Python framework for writing high-performance simulation
      and graphics code. Warp takes regular Python functions and JIT
      compiles them to efficient kernel code that can run on the CPU or
      GPU.

      Warp is designed for spatial computing and comes with a rich set
      of primitives that make it easy to write programs for physics
      simulation, perception, robotics, and geometry processing. In
      addition, Warp kernels are differentiable and can be used as part
      of machine-learning pipelines with frameworks such as PyTorch,
      JAX and Paddle.
    '';
    homepage = "https://github.com/NVIDIA/warp";
    changelog = "https://github.com/NVIDIA/warp/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.asl20;
    platforms = with lib.platforms; linux ++ darwin;
    maintainers = with lib.maintainers; [ yzx9 ];
  };
}