mednet.engine.classify.predictor
Prediction engine for classification tasks.
Functions
- run(model, datamodule, device_manager) – Run inference on input data and output predictions.
- mednet.engine.classify.predictor.run(model, datamodule, device_manager)
Run inference on input data and output predictions.
- Parameters:
  - model (LightningModule) – Neural network model (e.g. pasa).
  - datamodule (LightningDataModule) – The Lightning DataModule whose predict_dataloader() method supplies the input data for inference.
  - device_manager (DeviceManager) – An internal device representation, to be used for inference. This representation can be converted into a PyTorch device or a Lightning accelerator setup.
- Return type:
  Union[list[tuple[str, Sequence[int], Sequence[float]]], list[list[tuple[str, Sequence[int], Sequence[float]]]], Mapping[str, Sequence[tuple[str, Sequence[int], Sequence[float]]]], None]
- Returns:
  Depending on the return type of the DataModule's predict_dataloader() method:
  - if torch.utils.data.DataLoader, returns a list of predictions;
  - if a list of torch.utils.data.DataLoader, returns a list of lists of predictions, each inner list corresponding to the iteration over one of the dataloaders;
  - if a dict of str to torch.utils.data.DataLoader, returns a dictionary mapping names to lists of predictions;
  - if None, returns None.
- Raises:
  TypeError – If the DataModule's predict_dataloader() method does not return any of the types described above.
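The shape-dependent behaviour described under Returns and Raises can be sketched without Lightning or PyTorch as follows. This is a hypothetical illustration, not mednet's implementation: StandInLoader replaces torch.utils.data.DataLoader, and dispatch_predictions, predict_one and iter_predictions are invented names. The sketch also assumes that list and dict inputs must contain only dataloaders (and str keys); the real run() may validate its input differently.

```python
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import Any, Optional, Union

# One prediction, following the documented return type:
# (sample name, targets, probabilities).
Prediction = tuple[str, Sequence[int], Sequence[float]]
PredictorOutput = Optional[
    Union[
        list[Prediction],
        list[list[Prediction]],
        Mapping[str, Sequence[Prediction]],
    ]
]


class StandInLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader (no torch needed)."""

    def __init__(self, samples: Sequence[str]) -> None:
        self.samples = list(samples)

    def __iter__(self) -> Iterator[str]:
        return iter(self.samples)


def dispatch_predictions(
    dataloaders: Any,
    predict_one: Callable[[Any], list[Prediction]],
    loader_type: type = StandInLoader,
) -> PredictorOutput:
    """Shape the output after the predict_dataloader() return value (hypothetical)."""
    if dataloaders is None:
        return None
    if isinstance(dataloaders, loader_type):
        # Single dataloader -> flat list of predictions.
        return predict_one(dataloaders)
    if isinstance(dataloaders, list) and all(
        isinstance(d, loader_type) for d in dataloaders
    ):
        # List of dataloaders -> one list of predictions per dataloader.
        return [predict_one(d) for d in dataloaders]
    if isinstance(dataloaders, Mapping) and all(
        isinstance(k, str) and isinstance(v, loader_type)
        for k, v in dataloaders.items()
    ):
        # Named dataloaders -> dictionary keyed by the same names.
        return {k: predict_one(v) for k, v in dataloaders.items()}
    raise TypeError(
        f"unsupported predict_dataloader() return type: {type(dataloaders).__name__}"
    )


def iter_predictions(result: PredictorOutput) -> Iterator[tuple[str, Prediction]]:
    """Flatten any documented output shape into (group, prediction) pairs."""
    if result is None:
        return
    if isinstance(result, Mapping):
        # Named dataloaders: group by dictionary key.
        for name, predictions in result.items():
            for p in predictions:
                yield name, p
    elif result and isinstance(result[0], list):
        # List of lists: group by dataloader index.
        for index, predictions in enumerate(result):
            for p in predictions:
                yield str(index), p
    else:
        # Flat list (single dataloader): empty group name.
        for p in result:
            yield "", p
```

Passing loader_type as a parameter keeps the sketch testable without PyTorch; in real code one would pass torch.utils.data.DataLoader instead. The iter_predictions helper shows one way a caller might consume the result uniformly, since run() can return any of four shapes.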