# Source: nixpkgs mirror (for testing), github.com/NixOS/nixpkgs, branch python-updates (Nix).
# Package expression for warp-lang (NVIDIA Warp), built with buildPythonPackage.
#
# Feature flags:
#   standaloneSupport - build the LLVM-based JIT compiler / CPU device support.
#   cudaSupport       - build with the CUDA toolchain / GPU device support.
#   libmathdxSupport  - build against MathDx; asserted below to imply cudaSupport.
{
  autoAddDriverRunpath,
  buildPythonPackage,
  config,
  cudaPackages,
  callPackage,
  fetchFromGitHub,
  jax,
  lib,
  llvmPackages, # TODO: use llvm 21 in 1.10, see python-packages.nix
  numpy,
  pkgsBuildHost,
  python,
  replaceVars,
  runCommand,
  setuptools,
  stdenv,
  torch,
  warp-lang, # Self-reference to this package for passthru.tests
  writableTmpDirAsHomeHook,
  writeShellApplication,

  # Use standalone LLVM-based JIT compiler and CPU device support
  standaloneSupport ? true,

  # Use CUDA toolchain and GPU device support
  cudaSupport ? config.cudaSupport,

  # Build Warp with MathDx support (requires CUDA support)
  # Most linear-algebra tile operations like tile_cholesky(), tile_fft(),
  # and tile_matmul() require Warp to be built with the MathDx library.
  # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
  libmathdxSupport ? cudaSupport,
}@args:

# MathDx cannot be enabled without CUDA.
assert libmathdxSupport -> cudaSupport;

let
  # The stdenv actually used for the build: the CUDA backend stdenv when CUDA is
  # enabled, otherwise the stdenv passed in as an argument.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv;
  # Deliberately shadow the `stdenv` argument so that any accidental reference to
  # it in this file fails at evaluation time; the real value stays reachable as
  # `args.stdenv`.
  stdenv = throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";

  libmathdx = callPackage ./libmathdx.nix { };
in
buildPythonPackage.override { stdenv = effectiveStdenv; } (finalAttrs: {
  pname = "warp-lang";
  version = "1.11.0";
  pyproject = true;

  # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
  # causing a bunch of missing math symbols (like expf) when linking against the static library
  # provided by NVCC.
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "warp";
    tag = "v${finalAttrs.version}";
    hash = "sha256-wV4F6E4l0lfPB8zk/XhmdMNk649j5aJelW/DVu2R5mM=";
  };

  # The dynamic-link patch is only applied for standalone builds; its LLVM_LIB and
  # LIBCLANG_LIB placeholders are filled with the corresponding `lib` outputs.
  patches = lib.optionals standaloneSupport [
    (replaceVars ./dynamic-link.patch {
      LLVM_LIB = llvmPackages.llvm.lib;
      LIBCLANG_LIB = llvmPackages.libclang.lib;
    })
  ];

  # postPatch is assembled from independent shell fragments, each guarded by the
  # platform/feature it applies to. Every substitution uses --replace-fail, so a
  # pattern that no longer matches upstream aborts the build instead of being
  # silently skipped.
  # NOTE(review): the --replace-fail patterns must match upstream byte-for-byte,
  # including internal whitespace — confirm against the upstream sources when
  # bumping the version.
  postPatch = ''
    nixLog "patching $PWD/build_llvm.py to remove pre-C++11 ABI flag"
    substituteInPlace "$PWD/build_llvm.py" \
      --replace-fail \
        '"-D", f"CMAKE_CXX_FLAGS=-D_GLIBCXX_USE_CXX11_ABI=0 {abi_version}", # The pre-C++11 ABI is still the default on the CentOS 7 toolchain' \
        ""

    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail " -D_GLIBCXX_USE_CXX11_ABI=0" ""
  ''
  + lib.optionalString effectiveStdenv.hostPlatform.isDarwin (
    ''
      nixLog "patching $PWD/warp/_src/build_dll.py to remove macOS target flag and link against libc++"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail "--target={arch}-apple-macos11" "" \
        --replace-fail 'ld_inputs = []' "ld_inputs = ['-L\"${llvmPackages.libcxx}/lib\" -lc++']"
    ''
    # AssertionError: 0.4082476496696472 != 0.40824246406555176 within 5 places
    + ''
      nixLog "patching $PWD/warp/tests/test_codegen.py to relax float comparison precision on darwin"
      substituteInPlace "$PWD/warp/tests/test_codegen.py" \
        --replace-fail 'places=5' 'places=4'
    ''
  )
  # Point the build script at the stdenv's compiler wrapper instead of a bare `clang++`.
  + lib.optionalString effectiveStdenv.cc.isClang ''
    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail "clang++" "${effectiveStdenv.cc}/bin/cc"
  ''
  # Replace the include paths of the in-tree LLVM build with the LLVM and libclang
  # `dev` outputs.
  + lib.optionalString standaloneSupport ''
    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail \
        '-I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"' \
        '-I"${llvmPackages.llvm.dev}/include"' \
      --replace-fail \
        '-I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"' \
        '-I"${llvmPackages.libclang.dev}/include"'
  ''
  # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
  + lib.optionalString cudaSupport (
    let
      # Comma-separated, double-quoted items spliced in place of the upstream
      # `*gencode_opts` / `*clang_arch_flags` entries.
      gencodeOpts = lib.concatMapStringsSep ", " (
        gencodeString: ''"${gencodeString}"''
      ) cudaPackages.flags.gencode;

      clangArchFlags = lib.concatMapStringsSep ", " (
        realArch: ''"--cuda-gpu-arch=${realArch}"''
      ) cudaPackages.flags.realArches;
    in
    ''
      nixLog "patching $PWD/warp/_src/build_dll.py to use our gencode flags"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail '*gencode_opts,' '${gencodeOpts},' \
        --replace-fail '*clang_arch_flags,' '${clangArchFlags},'
    ''
    # Patch build_dll.py to use dynamic libraries rather than static ones.
    # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
    + ''
      nixLog "patching $PWD/warp/_src/build_dll.py to use dynamic libraries"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail '-lcudart_static' '-lcudart' \
        --replace-fail '-lnvrtc_static' '-lnvrtc' \
        --replace-fail '-lnvrtc-builtins_static' '-lnvrtc-builtins' \
        --replace-fail '-lnvJitLink_static' '-lnvJitLink' \
        --replace-fail '-lmathdx_static' '-lmathdx'
    ''
  )
  # These tests fail on CPU and CUDA.
  + ''
    nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
    substituteInPlace "$PWD/warp/tests/test_reload.py" \
      --replace-fail \
        'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
        "" \
      --replace-fail \
        'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
        ""
  '';

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
  ];

  nativeBuildInputs = lib.optionals cudaSupport [
    # NOTE: While normally we wouldn't include autoAddDriverRunpath for packages built from source, since Warp
    # will be loading GPU drivers at runtime, we need to inject the path to our video drivers.
    autoAddDriverRunpath
  ];

  # Each group of inputs is only added when the matching feature flag is enabled.
  buildInputs =
    lib.optionals standaloneSupport [
      llvmPackages.llvm
      llvmPackages.clang
      llvmPackages.libcxx
    ]
    ++ lib.optionals cudaSupport [
      (lib.getStatic cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
      cudaPackages.cuda_nvcc
      cudaPackages.cuda_nvrtc
    ]
    ++ lib.optionals libmathdxSupport [
      libmathdx
      cudaPackages.libcublas
      cudaPackages.libcufft
      cudaPackages.libcusolver
      cudaPackages.libnvjitlink
    ];

  # Run upstream's build_lib.py before the wheel build, with command-line options
  # derived from the feature flags.
  preBuild =
    let
      buildOptions =
        lib.optionals effectiveStdenv.cc.isClang [
          "--clang_build_toolchain"
        ]
        ++ lib.optionals (!standaloneSupport) [
          "--no_standalone"
        ]
        ++ lib.optionals cudaSupport [
          # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
          "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
        ]
        ++ lib.optionals libmathdxSupport [
          "--libmathdx"
          "--libmathdx_path=${libmathdx}"
        ]
        ++ lib.optionals (!libmathdxSupport) [
          "--no_libmathdx"
        ];

      buildOptionString = lib.concatStringsSep " " buildOptions;
    in
    ''
      nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
      "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
    '';

  pythonImportsCheck = [
    "warp"
  ];

  # See passthru.tests.
  doCheck = false;

  passthru = {
    # Make libmathdx available for introspection.
    inherit libmathdx;

    # Scripts which provide test packages and implement test logic.
    testers.unit-tests =
      let
        # Use the references from args
        python' = python.withPackages (_: [
          warp-lang
          jax
          torch
        ]);
        # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
        # (paddlepaddle.override { inherit cudaSupport; })
      in
      writeShellApplication {
        name = "warp-lang-unit-tests";
        runtimeInputs = [ python' ];
        text = ''
          ${python'}/bin/python3 -m warp.tests
        '';
      };

    # Tests run within the Nix sandbox.
    tests =
      let
        # Build a derivation that runs the `warp-lang-unit-tests` script against a
        # warp-lang overridden with the given feature flags; `$out` is only
        # created when the script exits successfully.
        mkUnitTests =
          {
            cudaSupport,
            libmathdxSupport,
          }:
          let
            name =
              "warp-lang-unit-tests-cpu" # CPU is baseline
              + lib.optionalString cudaSupport "-cuda"
              + lib.optionalString libmathdxSupport "-libmathdx";

            warp-lang' = warp-lang.override {
              inherit cudaSupport libmathdxSupport;
              # Make sure the warp-lang provided through callPackage is replaced with the override we're making.
              warp-lang = warp-lang';
            };
          in
          runCommand name
            {
              nativeBuildInputs = [
                warp-lang'.passthru.testers.unit-tests
                writableTmpDirAsHomeHook
              ];
              requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
            }
            ''
              nixLog "running ${name}"

              if warp-lang-unit-tests; then
                nixLog "${name} passed"
                touch "$out"
              else
                nixErrorLog "${name} failed"
                exit 1
              fi
            '';
      in
      {
        cpu = mkUnitTests {
          cudaSupport = false;
          libmathdxSupport = false;
        };
        cuda = {
          cudaOnly = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = false;
          };
          cudaWithLibmathDx = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = true;
          };
        };
      };
  };

  meta = {
    description = "Python framework for high performance GPU simulation and graphics";
    longDescription = ''
      Warp is a Python framework for writing high-performance simulation
      and graphics code. Warp takes regular Python functions and JIT
      compiles them to efficient kernel code that can run on the CPU or
      GPU.

      Warp is designed for spatial computing and comes with a rich set
      of primitives that make it easy to write programs for physics
      simulation, perception, robotics, and geometry processing. In
      addition, Warp kernels are differentiable and can be used as part
      of machine-learning pipelines with frameworks such as PyTorch,
      JAX and Paddle.
    '';
    homepage = "https://github.com/NVIDIA/warp";
    changelog = "https://github.com/NVIDIA/warp/blob/${finalAttrs.src.tag}/CHANGELOG.md";
    license = lib.licenses.asl20;
    platforms = lib.platforms.linux ++ [ "aarch64-darwin" ];
    maintainers = with lib.maintainers; [ yzx9 ];
  };
})