Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for mednet.utils.summary
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
from functools import reduce
import torch
from torch.nn.modules.module import _addindent
# ignore this space!
def _repr ( model : torch . nn . Module ) -> tuple [ str , int ]:
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = model . extra_repr ()
# empty string will be split into list ['']
if extra_repr :
extra_lines = extra_repr . split ( " \n " )
child_lines = []
total_params = 0
for key , module in model . _modules . items (): # noqa: SLF001
mod_str , num_params = _repr ( module )
mod_str = _addindent ( mod_str , 2 )
child_lines . append ( "(" + key + "): " + mod_str )
total_params += num_params
lines = extra_lines + child_lines
for _ , p in model . _parameters . items (): # noqa: SLF001
if hasattr ( p , "dtype" ):
total_params += reduce ( lambda x , y : x * y , p . shape )
main_str = model . _get_name () + "(" # noqa: SLF001
if lines :
# simple one-liner info, which most builtin Modules will use
if len ( extra_lines ) == 1 and not child_lines :
main_str += extra_lines [ 0 ]
else :
main_str += " \n " + " \n " . join ( lines ) + " \n "
main_str += ")"
main_str += f ", { total_params : , } params"
return main_str , total_params
[docs]
def summary(model: torch.nn.Module) -> tuple[str, int]:
    """Count the number of parameters in each model layer.

    Parameters
    ----------
    model
        Model to summarize.

    Returns
    -------
    tuple[str, int]
        A tuple containing a multiline string representation of the network,
        with each module annotated with its parameter count, and the total
        number of parameter elements in the model.
    """
    return _repr(model)