# pytorch3d — FAIR's library of reusable components for deep learning with 3D data.
# Builds the Python package from the upstream release tag. When `cudaSupport` is
# enabled, nvcc is added and the CUDA kernels are compiled for torch's CUDA
# capabilities.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  setuptools,
  wheel,
  torch,
  iopath,
  cudaPackages,
  config,
  cudaSupport ? config.cudaSupport,
}:

# A CUDA build of pytorch3d is only allowed against a CUDA-enabled torch.
assert cudaSupport -> torch.cudaSupport;

buildPythonPackage rec {
  pname = "pytorch3d";
  version = "0.7.8";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "pytorch3d";
    # Release tags are spelled `V<version>` (upper-case V).
    rev = "V${version}";
    hash = "sha256-DEEWWfjwjuXGc0WQInDTmtnWSIDUifyByxdg7hpdHlo=";
  };

  # nvcc is only needed when the CUDA kernels are being compiled.
  nativeBuildInputs = lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  build-system = [
    setuptools
    wheel
  ];

  dependencies = [
    torch
    iopath
  ];

  # NOTE(review): torch's "cxxdev" output presumably carries the C++ headers /
  # CMake files required to compile the native extension — confirm.
  buildInputs = [ (lib.getOutput "cxxdev" torch) ];

  env =
    {
      # Nix passes booleans to the builder as "1" (true) or "" (false).
      # Presumably read by pytorch3d's setup.py to force building the CUDA
      # kernels even though no GPU is visible in the build sandbox — confirm
      # against upstream setup.py.
      FORCE_CUDA = cudaSupport;
    }
    // lib.optionalAttrs cudaSupport {
      # Semicolon-separated list of the CUDA capabilities torch was built for.
      TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
    };

  pythonImportsCheck = [ "pytorch3d" ];

  # GPU smoke test: samples random rotation matrices on the CUDA device, converts
  # them to XYZ Euler angles, and checks that results stay on CUDA with shape (10, 3).
  # NOTE(review): this uses `ps.pytorch3d` from the Python package set, which is
  # not necessarily this exact (possibly overridden) derivation.
  passthru.tests.rotations-cuda =
    cudaPackages.writeGpuTestPython { libraries = ps: [ ps.pytorch3d ]; }
      ''
        import pytorch3d.transforms as p3dt

        M = p3dt.random_rotations(n=10, device="cuda")
        assert "cuda" in M.device.type
        angles = p3dt.matrix_to_euler_angles(M, "XYZ")
        assert "cuda" in angles.device.type
        assert angles.shape == (10, 3), angles.shape
        print(angles)
      '';

  meta = {
    description = "FAIR's library of reusable components for deep learning with 3D data";
    homepage = "https://github.com/facebookresearch/pytorch3d";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      pbsds
      SomeoneSerge
    ];
  };
}