{ buildPythonPackage
, expecttest
, fetchFromGitHub
, lib
, ninja
, pytestCheckHook
, python # not referenced below; kept so callers passing it explicitly keep working
, torch
, pybind11
, which
}:

# functorch: JAX-like composable function transforms for PyTorch, built from the
# upstream GitHub release tag as a setuptools package with a C++ extension.
buildPythonPackage rec {
  pname = "functorch";
  version = "0.2.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-33skKk5aAIHn+1149ifolXPA+tpQ+WROAZvwPeGBbrA=";
  };

  # Somewhat surprisingly, PyTorch is needed by the build process itself:
  # `setup.py` imports `torch.utils.cpp_extension` to compile the extension.
  nativeBuildInputs = [
    ninja
    torch
    which
  ];

  buildInputs = [
    pybind11
  ];

  # The extension is compiled against torch (via `torch.utils.cpp_extension`),
  # so torch is also a runtime dependency. Having it only in nativeBuildInputs
  # makes it available during the build but does not propagate it to users of
  # this package; propagating it keeps `import functorch` working for them.
  propagatedBuildInputs = [
    torch
  ];

  # Delete the in-tree source directory before running the tests so that
  # `import functorch` resolves to the installed package (which contains the
  # compiled extension) rather than the uncompiled sources in the build dir.
  preCheck = ''
    rm -rf functorch/
  '';

  checkInputs = [
    expecttest
    pytestCheckHook
  ];

  # See https://github.com/pytorch/functorch/issues/835.
  disabledTests = [
    # functorch's lagging OpInfo database does not match the packaged PyTorch:
    # RuntimeError: ("('...', '') is in PyTorch's OpInfo db ", "but is not in functorch's OpInfo db. Please regenerate ", '... and add the new tests to ', 'denylists if necessary.')
    "test_coverage_bernoulli_cpu_float32"
    "test_coverage_column_stack_cpu_float32"
    "test_coverage_diagflat_cpu_float32"
    "test_coverage_flatten_cpu_float32"
    "test_coverage_linalg_lu_factor_cpu_float32"
    "test_coverage_linalg_lu_factor_ex_cpu_float32"
    "test_coverage_multinomial_cpu_float32"
    "test_coverage_nn_functional_dropout2d_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_with_train_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_without_train_cpu_float32"
    "test_coverage_nn_functional_kl_div_cpu_float32"
    "test_coverage_normal_cpu_float32"
    "test_coverage_normal_number_mean_cpu_float32"
    "test_coverage_pca_lowrank_cpu_float32"
    "test_coverage_round_decimals_0_cpu_float32"
    "test_coverage_round_decimals_3_cpu_float32"
    "test_coverage_round_decimals_neg_3_cpu_float32"
    "test_coverage_scatter_reduce_cpu_float32"
    "test_coverage_svd_lowrank_cpu_float32"

    # Same OpInfo mismatch, reported as a count difference:
    # > self.assertEqual(len(functorch_lagging_op_db), len(op_db))
    # E AssertionError: Scalars are not equal!
    # E
    # E Absolute difference: 19
    # E Relative difference: 0.03525046382189239
    "test_functorch_lagging_op_db_has_opinfos_cpu"

    # RuntimeError: PyTorch not compiled with LLVM support!
    "test_bias_gelu"
    "test_binary_ops"
    "test_broadcast1"
    "test_broadcast2"
    "test_float_double"
    "test_float_int"
    "test_fx_trace"
    "test_int_long"
    "test_issue57611"
    "test_slice1"
    "test_slice2"
    "test_transposed1"
    "test_transposed2"
    "test_unary_ops"
  ];

  pythonImportsCheck = [ "functorch" ];

  meta = with lib; {
    description = "JAX-like composable function transforms for PyTorch";
    homepage = "https://pytorch.org/functorch";
    license = licenses.bsd3;
    maintainers = with maintainers; [ samuela ];
    # Restricted to x86_64; see
    # https://github.com/NixOS/nixpkgs/pull/174248#issuecomment-1139895064.
    platforms = platforms.x86_64;
  };
}