mednet.engine.classify.predictor

Prediction engine for classification tasks.

Functions

run(model, datamodule, device_manager)

Run inference on input data and output predictions.

mednet.engine.classify.predictor.run(model, datamodule, device_manager)

Run inference on input data and output predictions.

Parameters:
  • model (LightningModule) – Neural network model (e.g. pasa).

  • datamodule (LightningDataModule) – The Lightning DataModule to run predictions on.

  • device_manager (DeviceManager) – An internal device representation, to be used for running inference. This representation can be converted into a PyTorch device or a Lightning accelerator setup.

Return type:

Union[list[tuple[str, Sequence[int], Sequence[float]]], list[list[tuple[str, Sequence[int], Sequence[float]]]], Mapping[str, Sequence[tuple[str, Sequence[int], Sequence[float]]]], None]
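The union above is easier to read with named aliases. The sketch below is not part of mednet: the aliases Prediction, PredictionList and RunResult and the helper as_named_lists are hypothetical names, and the keys it produces ("predictions" for a single flat list, "0", "1", ... for a list of lists) are an arbitrary choice. It only assumes that each prediction is a (str, Sequence[int], Sequence[float]) tuple, as stated in the return type, and shows one way a caller might flatten every possible return shape into a single mapping.

```python
from collections.abc import Mapping, Sequence
from typing import Union

# Hypothetical aliases (not defined by mednet) spelling out the return type.
Prediction = tuple[str, Sequence[int], Sequence[float]]
PredictionList = list[Prediction]
RunResult = Union[
    PredictionList,
    list[PredictionList],
    Mapping[str, Sequence[Prediction]],
    None,
]


def as_named_lists(result: RunResult) -> dict[str, list[Prediction]]:
    """Normalise any result shape into a dict of prediction lists (hypothetical helper)."""
    if result is None:
        # predict_dataloader() returned None: nothing to collect
        return {}
    if isinstance(result, Mapping):
        # already keyed by dataloader name
        return {name: list(preds) for name, preds in result.items()}
    if result and all(isinstance(item, list) for item in result):
        # one inner list per dataloader: key each by its position
        return {str(i): list(preds) for i, preds in enumerate(result)}
    # a single flat list of prediction tuples (or an empty list)
    return {"predictions": list(result)}
```

Because prediction entries are tuples rather than lists, checking whether every item is a list is enough to tell a list of lists apart from a flat list of predictions.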

Returns:

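Depending on the return type of the DataModule’s predict_dataloader() method:

  • if a torch.utils.data.DataLoader, a list of predictions;

  • if a list of DataLoaders, a list of lists of predictions, one inner list per dataloader;

  • if a dict mapping str to DataLoader, a dictionary mapping each name to its list of predictions;

  • if None, None.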

Raises:

TypeError – If the DataModule’s predict_dataloader() method does not return any of the types described above.
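The dispatch on the return value of predict_dataloader(), together with the TypeError for unsupported types, can be pictured with the simplified sketch below. It is not mednet's implementation: FakeLoader is a hypothetical stand-in for torch.utils.data.DataLoader, and predict_one is a hypothetical callable that turns one loader into a list of prediction tuples with a constant toy score.

```python
from collections.abc import Callable, Mapping


class FakeLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, samples: list[str]) -> None:
        self.samples = samples


def dispatch(loaders, predict_one: Callable[[FakeLoader], list]):
    """Mirror the documented return shapes of run() (simplified sketch)."""
    if isinstance(loaders, FakeLoader):
        # single loader -> flat list of predictions
        return predict_one(loaders)
    if isinstance(loaders, list):
        # list of loaders -> one list of predictions per loader
        return [predict_one(d) for d in loaders]
    if isinstance(loaders, Mapping):
        # dict of loaders -> mapping from name to list of predictions
        return {name: predict_one(d) for name, d in loaders.items()}
    if loaders is None:
        # no prediction dataloaders were provided
        return None
    # anything else is unsupported
    raise TypeError(f"unsupported predict_dataloader() return type: {type(loaders)!r}")


def predict_one(loader: FakeLoader) -> list:
    # toy predictor: target 0 and score 0.5 for every sample name
    return [(name, [0], [0.5]) for name in loader.samples]
```

For example, dispatch({"val": FakeLoader(["c"])}, predict_one) yields {"val": [("c", [0], [0.5])]}, while dispatch(42, predict_one) raises TypeError.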