# nixpkgs mirror (for testing)
# github.com/NixOS/nixpkgs
# nix
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  torch,
  torchvision,
  pytestCheckHook,
  transformers,
}:

# Builds the torchinfo Python package from the v${version} tag on GitHub.
buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.8.0";
  # NOTE(review): current nixpkgs prefers `pyproject = true;` together with
  # `build-system = [ setuptools ];` over `format`. Migrating also requires
  # adding `setuptools` to the argument set above, so it is left as-is here.
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "TylerYep";
    repo = "torchinfo";
    tag = "v${version}";
    hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
  };

  patches = [
    (fetchpatch {
      # Add support for Python 3.11 and pytorch 2.1.
      # Only the torchinfo/model_statistics.py part of the upstream commit is
      # applied (`includes`); changes to any other files in it are dropped.
      url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
      hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
      includes = [ "torchinfo/model_statistics.py" ];
    })
  ];

  # Runtime dependencies; propagated so they are also available to packages
  # that depend on torchinfo.
  propagatedBuildInputs = [
    torch
    torchvision
  ];

  # Inputs needed only while running the test suite.
  nativeCheckInputs = [
    pytestCheckHook
    transformers
  ];

  # Give the test run a fresh, writable HOME directory (presumably something
  # under test writes caches there; the build sandbox's HOME is not writable).
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests = [
    # Skip as it downloads pretrained weights (require network access)
    "test_eval_order_doesnt_matter"
    "test_flan_t5_small"
    # AssertionError in output
    "test_google"
    # "addmm_impl_cpu_" not implemented for 'Half'
    "test_input_size_half_precision"
  ];

  disabledTestPaths = [
    # Test requires network access
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [ "torchinfo" ];

  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    # Release notes for exactly the tag fetched in `src`.
    changelog = "https://github.com/TylerYep/torchinfo/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ petterstorvik ];
  };
}