{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  pythonOlder,
  torch,
  torchvision,
  pytestCheckHook,
  transformers,
}:

# Nix package definition for torchinfo, a PyTorch model-summary library,
# built from the upstream GitHub tag with the setuptools builder.
buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.8.0";
  format = "setuptools";

  # Refuse to build for interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TylerYep";
    repo = "torchinfo";
    # Fully qualified tag ref, so the fetch cannot resolve to a branch of the same name.
    rev = "refs/tags/v${version}";
    hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
  };

  patches = [
    # Upstream commit adding Python 3.11 and PyTorch 2.1 support.
    # Only its changes to torchinfo/model_statistics.py are applied.
    (fetchpatch {
      url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
      hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
      includes = [ "torchinfo/model_statistics.py" ];
    })
  ];

  # Runtime dependencies that are also propagated to dependents.
  propagatedBuildInputs = [
    torch
    torchvision
  ];

  # Needed only while running the test suite.
  nativeCheckInputs = [
    pytestCheckHook
    transformers
  ];

  # Give the tests a fresh, writable HOME.
  # NOTE(review): presumably some tests or libraries write under $HOME — confirm.
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests = [
    # These tests fetch pretrained weights, which requires network access.
    "test_eval_order_doesnt_matter"
    "test_flan_t5_small"
    # Fails with an AssertionError when comparing the printed output.
    "test_google"
    # Fails with: "addmm_impl_cpu_" not implemented for 'Half'.
    "test_input_size_half_precision"
  ];

  disabledTestPaths = [
    # This whole test module needs network access.
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [ "torchinfo" ];

  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.petterstorvik ];
  };
}