Source code for mednet.engine.loggers
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Custom lightning loggers."""
import os
import typing
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import TensorBoardLogger
[docs]
class CustomTensorboardLogger(TensorBoardLogger):
r"""Custom implementation implementation of lightning's TensorboardLogger.
This implementation puts all logs inside the same directory, instead of a
separate "version_n" directories, which is the default lightning behaviour.
Parameters
----------
save_dir
Directory where to save the logs to.
name
Experiment name. Defaults to ``default``. If it is the empty string
then no per-experiment subdirectory is used.
version
Experiment version. If version is not specified the logger inspects the
save directory for existing versions, then automatically assigns the
next available version. If it is a string then it is used as the
run-specific subdirectory name, otherwise ``version_${version}`` is
used.
log_graph
Adds the computational graph to tensorboard. This requires that the
user has defined the `self.example_input_array` attribute in their
model.
default_hp_metric
Enables a placeholder metric with key `hp_metric` when
`log_hyperparams` is called without a metric (otherwise calls to
log_hyperparams without a metric are ignored).
prefix
A string to put at the beginning of metric keys.
sub_dir
Sub-directory to group TensorBoard logs. If a sub_dir argument is
passed then logs are saved in ``/save_dir/name/version/sub_dir/``.
Defaults to ``None`` in which logs are saved in
``/save_dir/name/version/``.
\**kwargs
Additional arguments used by :py:class:`tensorboardX.SummaryWriter` can
be passed as keyword arguments in this logger. To automatically flush
to disk, ``max_queue`` sets the size of the queue for pending logs before
flushing. ``flush_secs`` determines how many seconds elapses before
flushing.
"""
def __init__(
self,
save_dir: _PATH,
name: str = "lightning-logs",
version: int | str | None = None,
log_graph: bool = False,
default_hp_metric: bool = True,
prefix: str = "",
sub_dir: _PATH | None = None,
**kwargs: dict[str, typing.Any],
):
super().__init__(
save_dir,
name,
version,
log_graph,
default_hp_metric,
prefix,
sub_dir,
**kwargs,
)
@property
def log_dir(self) -> str:
return os.path.join(self.save_dir, self.name) # noqa: PTH118