# Nix package expression for functorch: JAX-like composable function
# transforms for PyTorch, built from the upstream GitHub release tag.
{ buildPythonPackage
, expecttest
, fetchFromGitHub
, lib
, ninja
, pytestCheckHook
  # NOTE(review): `python` is accepted but not referenced anywhere in this
  # expression; kept so existing callers that pass it keep working.
, python
, torch
, pybind11
, which
}:

buildPythonPackage rec {
  pname = "functorch";
  version = "0.2.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-33skKk5aAIHn+1149ifolXPA+tpQ+WROAZvwPeGBbrA=";
  };

  # Somewhat surprisingly pytorch is actually necessary for the build process.
  # `setup.py` imports `torch.utils.cpp_extension`.
  nativeBuildInputs = [
    ninja
    torch
    which
  ];

  buildInputs = [
    pybind11
  ];

  # Delete the in-tree `functorch/` source directory before the test phase.
  # NOTE(review): presumably this makes the tests import the installed
  # package (which contains the compiled extension) rather than the
  # uncompiled source tree in the working directory — confirm.
  preCheck = ''
    rm -rf functorch/
  '';

  checkInputs = [
    expecttest
    pytestCheckHook
  ];

  # See https://github.com/pytorch/functorch/issues/835.
  disabledTests = [
    # RuntimeError: ("('...', '') is in PyTorch's OpInfo db ", "but is not in functorch's OpInfo db. Please regenerate ", '... and add the new tests to ', 'denylists if necessary.')
    "test_coverage_bernoulli_cpu_float32"
    "test_coverage_column_stack_cpu_float32"
    "test_coverage_diagflat_cpu_float32"
    "test_coverage_flatten_cpu_float32"
    "test_coverage_linalg_lu_factor_cpu_float32"
    "test_coverage_linalg_lu_factor_ex_cpu_float32"
    "test_coverage_multinomial_cpu_float32"
    "test_coverage_nn_functional_dropout2d_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_with_train_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_without_train_cpu_float32"
    "test_coverage_nn_functional_kl_div_cpu_float32"
    "test_coverage_normal_cpu_float32"
    "test_coverage_normal_number_mean_cpu_float32"
    "test_coverage_pca_lowrank_cpu_float32"
    "test_coverage_round_decimals_0_cpu_float32"
    "test_coverage_round_decimals_3_cpu_float32"
    "test_coverage_round_decimals_neg_3_cpu_float32"
    "test_coverage_scatter_reduce_cpu_float32"
    "test_coverage_svd_lowrank_cpu_float32"

    # > self.assertEqual(len(functorch_lagging_op_db), len(op_db))
    # E AssertionError: Scalars are not equal!
    # E
    # E Absolute difference: 19
    # E Relative difference: 0.03525046382189239
    "test_functorch_lagging_op_db_has_opinfos_cpu"

    # RuntimeError: PyTorch not compiled with LLVM support!
    "test_bias_gelu"
    "test_binary_ops"
    "test_broadcast1"
    "test_broadcast2"
    "test_float_double"
    "test_float_int"
    "test_fx_trace"
    "test_int_long"
    "test_issue57611"
    "test_slice1"
    "test_slice2"
    "test_transposed1"
    "test_transposed2"
    "test_unary_ops"
  ];

  pythonImportsCheck = [ "functorch" ];

  meta = with lib; {
    description = "JAX-like composable function transforms for PyTorch";
    homepage = "https://pytorch.org/functorch";
    license = licenses.bsd3;
    maintainers = with maintainers; [ samuela ];
    # See https://github.com/NixOS/nixpkgs/pull/174248#issuecomment-1139895064.
    platforms = platforms.x86_64;
  };
}