{ lib
, buildPythonPackage
, fetchFromGitHub
, pytestCheckHook
, pythonOlder
, torch
, torchvision
}:

# Package expression for the `torchinfo` Python library, built from the
# tagged GitHub release and checked with pytest.
buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.7.2";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TylerYep";
    # Spelled out instead of `repo = pname` so that overriding `pname`
    # cannot silently change which repository is fetched.
    repo = "torchinfo";
    rev = "refs/tags/v${version}";
    hash = "sha256-O+I7BNQ5moV/ZcbbuP/IFoi0LO0WsGHBbSfgPmFu1Ec=";
  };

  propagatedBuildInputs = [
    torch
    torchvision
  ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  disabledTests = [
    # Skipped because it downloads pretrained weights (requires network access)
    "test_eval_order_doesnt_matter"
    # AssertionError in output
    "test_google"
  ];

  disabledTestPaths = [
    # Wants "compressai", which we don't package (2023-03-23)
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [
    "torchinfo"
  ];

  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    # Release notes for the exact tag fetched in `src`.
    changelog = "https://github.com/TylerYep/torchinfo/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ petterstorvik ];
  };
}