mednet.engine.trainer

Functions supporting the training of PyTorch models.

Functions

  • get_checkpoint_file(results_dir) – Return the path of the latest checkpoint if it exists.

  • load_checkpoint(checkpoint_file) – Load the checkpoint.

  • run(model, datamodule, validation_period, ...) – Fit a CNN model using supervised learning and save it to disk.

  • setup_datamodule(datamodule, model, ...) – Configure and set up the datamodule.

  • validate_model_datamodule(model, datamodule) – Validate the use of a model and datamodule together.

mednet.engine.trainer.get_checkpoint_file(results_dir)[source]

Return the path of the latest checkpoint if it exists.

Parameters:

results_dir (Path) – Directory in which results are saved.

Return type:

Path | None

Returns:

Path to the latest checkpoint, or None if no checkpoint exists.

mednet.engine.trainer.load_checkpoint(checkpoint_file)[source]

Load the checkpoint.

Parameters:

checkpoint_file (Path) – Path to the checkpoint.
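
A minimal usage sketch combining the two helpers above, assuming training results were previously written to a hypothetical results/ directory:

    from pathlib import Path

    from mednet.engine.trainer import get_checkpoint_file, load_checkpoint

    results_dir = Path("results")  # hypothetical output directory

    # get_checkpoint_file() returns None when no checkpoint exists yet.
    checkpoint_file = get_checkpoint_file(results_dir)
    if checkpoint_file is not None:
        load_checkpoint(checkpoint_file)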

mednet.engine.trainer.setup_datamodule(datamodule, model, batch_size, drop_incomplete_batch, cache_samples, parallel)[source]

Configure and set up the datamodule.

Return type:

None
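
The individual parameters are not documented here; the call below is an illustrative sketch only, assuming model and datamodule instances were created beforehand (e.g. from mednet configuration presets) and using made-up values for the remaining arguments:

    from mednet.engine.trainer import setup_datamodule

    # All keyword values below are illustrative, not recommended defaults.
    setup_datamodule(
        datamodule,                   # assumed ConcatDataModule instance
        model,                        # assumed Model instance
        batch_size=16,
        drop_incomplete_batch=False,  # keep batches smaller than batch_size
        cache_samples=False,          # do not cache loaded samples in memory
        parallel=-1,                  # data-loading parallelism setting
    )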

mednet.engine.trainer.validate_model_datamodule(model, datamodule)[source]

Validate the use of a model and datamodule together.

Parameters:
  • model (Model) – The model to be validated.

  • datamodule (ConcatDataModule) – The datamodule to be validated.

Raises:

TypeError – If the types of the two objects are not compatible.
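
A short sketch of using this check to fail fast before training starts; the error handling shown is illustrative:

    from mednet.engine.trainer import validate_model_datamodule

    try:
        # `model` and `datamodule` are assumed to exist already.
        validate_model_datamodule(model, datamodule)
    except TypeError as err:
        raise SystemExit(f"Incompatible model/datamodule pair: {err}")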

mednet.engine.trainer.run(model, datamodule, validation_period, device_manager, max_epochs, output_folder, monitoring_interval, accumulate_grad_batches, checkpoint)[source]

Fit a CNN model using supervised learning and save it to disk.

This method supports periodic checkpointing and outputs a TensorBoard-formatted log tracking the evolution of selected figures of merit during training.

Parameters:
  • model (LightningModule) – Neural network model (e.g. pasa).

  • datamodule (LightningDataModule) – The Lightning DataModule to use for training and validation.

  • validation_period (int) – Number of epochs after which validation happens. By default, validation runs after every training epoch (period=1); increase the period to make validation sparser. Note that this setting affects checkpoint saving: while checkpoints are created after every training step independently of validation runs (the last training step always overwrites the latest checkpoint), the evaluation of the ‘best’ model obtained so far is based on validation results and is therefore influenced by this setting.

  • device_manager (DeviceManager) – An internal device representation, to be used for training and validation. This representation can be converted into a PyTorch device or a Lightning accelerator setup.

  • max_epochs (int) – The maximum number of epochs to train for.

  • output_folder (Path) – Folder in which the results will be saved.

  • monitoring_interval (int | float) – Interval, in seconds (or fractions of a second), at which resources are monitored during training.

  • accumulate_grad_batches (int) – Number of batches over which to accumulate gradients before stepping the optimizer. The default of 1 forces the whole batch to be processed at once. With larger values, gradients are accumulated over accumulate_grad_batches batches before each optimizer step, effectively multiplying the batch size by this factor. This is especially useful when training on GPUs with a limited amount of onboard RAM (see the usage sketch after this list).

  • checkpoint (Path | None) – Path to an optional checkpoint file to load.
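
An end-to-end sketch tying the helpers on this page together. All concrete values (paths, epoch counts, batch settings) are illustrative, the model and datamodule objects are assumed to have been created beforehand, and the DeviceManager import path and constructor are assumptions about mednet's device API:

    from pathlib import Path

    from mednet.engine.device import DeviceManager  # assumed import path
    from mednet.engine.trainer import (
        get_checkpoint_file,
        run,
        setup_datamodule,
        validate_model_datamodule,
    )

    output_folder = Path("results")  # illustrative output location

    # Fail fast if the model/datamodule pair is incompatible.
    validate_model_datamodule(model, datamodule)

    # Illustrative batching configuration.
    setup_datamodule(datamodule, model, batch_size=16,
                     drop_incomplete_batch=False, cache_samples=False,
                     parallel=-1)

    run(
        model=model,
        datamodule=datamodule,
        validation_period=1,                  # validate after every epoch
        device_manager=DeviceManager("cpu"),  # assumed constructor
        max_epochs=10,
        output_folder=output_folder,
        monitoring_interval=5.0,              # probe resources every 5 seconds
        accumulate_grad_batches=4,            # effective batch size: 16 * 4 = 64
        checkpoint=get_checkpoint_file(output_folder),  # resume if available
    )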