{ lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, pythonOlder
, torch
, torchvision
, pytestCheckHook
, transformers
}:

# Nix package for torchinfo, a library that prints summaries of PyTorch models.
buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.8.0";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TylerYep";
    repo = "torchinfo";
    rev = "refs/tags/v${version}";
    hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
  };

  patches = [
    # Add support for Python 3.11 and PyTorch 2.1 (upstream commit c74784c).
    # Only the hunks touching torchinfo/model_statistics.py are applied.
    (fetchpatch {
      url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
      hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
      includes = [
        "torchinfo/model_statistics.py"
      ];
    })
  ];

  propagatedBuildInputs = [
    torch
    torchvision
  ];

  nativeCheckInputs = [
    pytestCheckHook
    transformers
  ];

  # Point HOME at a fresh temporary directory so anything written under
  # $HOME while the checks run lands somewhere writable.
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests = [
    # Skipped because they download pretrained weights (they require network access)
    "test_eval_order_doesnt_matter"
    "test_flan_t5_small"
    # AssertionError in output
    "test_google"
    # "addmm_impl_cpu_" not implemented for 'Half'
    "test_input_size_half_precision"
  ];

  disabledTestPaths = [
    # Test requires network access
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [
    "torchinfo"
  ];

  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ petterstorvik ];
  };
}