# Source code for mednet.utils.summary
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
from functools import reduce
import torch
from torch.nn.modules.module import _addindent
# ignore this space!
def _repr(model: torch.nn.Module) -> tuple[str, int]:
    """Build an annotated ``repr`` of ``model`` and count its parameters.

    The layout follows :py:meth:`torch.nn.Module.__repr__`. In addition, each
    module's closing parenthesis is followed by the number of parameters held
    by that module and all of its sub-modules.

    Parameters
    ----------
    model
        Module to represent.

    Returns
    -------
    tuple[str, int]
        The (possibly multiline) string representation of ``model`` and the
        total number of parameters held by it and its sub-modules.
    """
    # We treat the extra repr like the sub-module, one item per line.  An
    # empty string would split into [''], so only split non-empty reprs.
    extra_repr = model.extra_repr()
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    total_params = 0
    for key, module in model._modules.items():  # noqa: SLF001
        if module is None:
            # Child slots may be registered as ``None``: they hold no
            # parameters and are rendered like torch's own Module.__repr__.
            child_lines.append(f"({key}): None")
            continue
        mod_str, num_params = _repr(module)
        # Indent the child's continuation lines so they nest under its key.
        mod_str = _addindent(mod_str, 2)
        child_lines.append("(" + key + "): " + mod_str)
        total_params += num_params

    for p in model._parameters.values():  # noqa: SLF001
        # Parameters may be registered as ``None``.  ``numel()`` is also
        # correct for 0-dim (scalar) parameters, whose shape is empty.
        if p is not None:
            total_params += p.numel()

    lines = extra_lines + child_lines
    main_str = model._get_name() + "("  # noqa: SLF001
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            # Two-space indent, matching the _addindent(..., 2) applied to
            # children, so nested closing parentheses line up.
            main_str += "\n  " + "\n  ".join(lines) + "\n"
    main_str += ")"
    main_str += f", {total_params:,} params"
    return main_str, total_params
# [docs]
def summary(model: torch.nn.Module) -> tuple[str, int]:
    """Count the number of parameters in each model layer.

    Parameters
    ----------
    model
        Model to summarize.

    Returns
    -------
    tuple[str, int]
        A tuple containing a multiline string representation of the network,
        where each module is followed by its parameter count, and the total
        number of parameters.
    """
    return _repr(model)