Callbacks

PredictionWriter

Bases: BasePredictionWriter, Callback

A callback that writes predictions to disk at specified intervals during inference.

Logits, embeddings, hiddens, input IDs, and labels may all be saved to disk, depending on trainer configuration. A batch index is stored with each prediction in the same dictionary; these indices must be used to maintain ordering between multi-device and single-device predictions.

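For orientation, a minimal usage sketch follows, assuming a standard Lightning Trainer; the model and datamodule are placeholders, not part of this API.

import lightning.pytorch as pl

from bionemo.llm.utils.callbacks import PredictionWriter

# Write one predictions__rank_{rank}__batch_{batch_idx}.pt file per batch.
writer = PredictionWriter(output_dir="./results", write_interval="batch")

# model / datamodule: your LightningModule and DataModule (placeholders).
trainer = pl.Trainer(devices=1, callbacks=[writer])
# return_predictions=False avoids accumulating predictions in memory;
# the callback has already written each batch to disk.
trainer.predict(model, datamodule=datamodule, return_predictions=False)
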
Source code in bionemo/llm/utils/callbacks.py
class PredictionWriter(BasePredictionWriter, pl.Callback):
    """A callback that writes predictions to disk at specified intervals during training.

    Logits, Embeddings, Hiddens, Input IDs, and Labels may all be saved to the disk depending on trainer configuration.
    Batch Idxs are provided for each prediction in the same dictionary. These must be used to maintain order between
    multi device predictions and single device predictions.
    """

    def __init__(
        self,
        output_dir: str | os.PathLike,
        write_interval: IntervalT,
        batch_dim_key_defaults: dict[str, int] | None = None,
        seq_dim_key_defaults: dict[str, int] | None = None,
    ):
        """Initializes the callback.

        Args:
            output_dir: The directory where predictions will be written.
            write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
            batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
            seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
        """
        super().__init__(write_interval)
        self.write_interval = write_interval
        self.output_dir = str(output_dir)
        self.batch_dim_key_defaults = batch_dim_key_defaults
        self.seq_dim_key_defaults = seq_dim_key_defaults

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:  # noqa: D417
        """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
        """
        if trainer.num_devices > 1 and self.write_interval == "epoch":
            raise ValueError(
                "Multi-GPU predictions are not permitted as outputs are not ordered and batch indices are lost."
            )

    def write_on_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Writes predictions to disk at the end of each batch.

        Prediction files follow the naming pattern predictions__rank_{rank}__batch_{batch_idx}.pt, where rank is the
        global rank of the device on which the predictions were made.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            prediction: The prediction made by the model.
            batch_indices: The indices of the batch.
            batch: The batch data.
            batch_idx: The index of the batch.
            dataloader_idx: The index of the dataloader.
        """
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input.

        # NOTE store the batch_idx so we do not need to rely on filenames for reconstruction of inputs. This is wrapped
        # in a tensor and list container to ensure compatibility with batch_collator.
        prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)

        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Sequence[int],
    ) -> None:
        """Writes predictions to disk at the end of each epoch.

        Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
        large predictions.

        Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            predictions: The predictions made by the model.
            batch_indices: The indices of the batch.

        """
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank

        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

        # collate multiple batches / ignore empty ones
        collate_kwargs = {}
        if self.batch_dim_key_defaults is not None:
            collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
        if self.seq_dim_key_defaults is not None:
            collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults

        prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        if isinstance(prediction, dict):
            keys = prediction.keys()
        else:
            keys = "tensor"
        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{keys}")

__init__(output_dir, write_interval, batch_dim_key_defaults=None, seq_dim_key_defaults=None)

Initializes the callback.

Parameters:

    output_dir (str | PathLike): The directory where predictions will be written. Required.
    write_interval (IntervalT): The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers. Required.
    batch_dim_key_defaults (dict[str, int] | None): The default batch dimension for each key, if different from the standard 0. Default: None.
    seq_dim_key_defaults (dict[str, int] | None): The default sequence dimension for each key, if different from the standard 1. Default: None.
Source code in bionemo/llm/utils/callbacks.py
def __init__(
    self,
    output_dir: str | os.PathLike,
    write_interval: IntervalT,
    batch_dim_key_defaults: dict[str, int] | None = None,
    seq_dim_key_defaults: dict[str, int] | None = None,
):
    """Initializes the callback.

    Args:
        output_dir: The directory where predictions will be written.
        write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
        batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
        seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
    """
    super().__init__(write_interval)
    self.write_interval = write_interval
    self.output_dir = str(output_dir)
    self.batch_dim_key_defaults = batch_dim_key_defaults
    self.seq_dim_key_defaults = seq_dim_key_defaults

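Where a prediction key's tensor layout differs from the standard (batch at dim 0, sequence at dim 1), the per-key defaults can be overridden so that epoch-end collation stacks along the correct dimensions. A hedged sketch; the "hiddens" key and its (seq, batch, ...) layout are illustrative assumptions, not prescribed by the API:

from bionemo.llm.utils.callbacks import PredictionWriter

writer = PredictionWriter(
    output_dir="./results",
    write_interval="epoch",  # 'epoch' requires a single-device trainer; see setup() below
    # Illustrative overrides for a key laid out as (seq, batch, ...):
    batch_dim_key_defaults={"hiddens": 1},
    seq_dim_key_defaults={"hiddens": 0},
)
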
setup(trainer, pl_module, *args, **kwargs)

Invoked when Trainer.fit, validate, test, or predict is called. Fails immediately when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.

Parameters:

    trainer (Trainer): The Trainer instance. Required.
    pl_module (LightningModule): The LightningModule instance. Required.
Source code in bionemo/llm/utils/callbacks.py
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:  # noqa: D417
    """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
    """
    if trainer.num_devices > 1 and self.write_interval == "epoch":
        raise ValueError(
            "Multi-GPU predictions are not permitted as outputs are not ordered and batch indices are lost."
        )

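The practical consequence of this check is that epoch-interval writing is single-device only, and a multi-device run fails before any prediction is made. A hedged sketch of the failing configuration (Trainer wiring as in the earlier sketch):

import lightning.pytorch as pl

from bionemo.llm.utils.callbacks import PredictionWriter

writer = PredictionWriter(output_dir="./results", write_interval="epoch")
trainer = pl.Trainer(devices=2, callbacks=[writer])
# Any of trainer.fit/validate/test/predict now raises ValueError from
# PredictionWriter.setup, because outputs would be unordered across ranks.
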
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Writes predictions to disk at the end of each batch.

Prediction files follow the naming pattern predictions__rank_{rank}__batch_{batch_idx}.pt, where rank is the global rank of the device on which the predictions were made.

Parameters:

    trainer (Trainer): The Trainer instance. Required.
    pl_module (LightningModule): The LightningModule instance. Required.
    prediction (Any): The prediction made by the model. Required.
    batch_indices (Sequence[int]): The indices of the batch. Required.
    batch (Any): The batch data. Required.
    batch_idx (int): The index of the batch. Required.
    dataloader_idx (int): The index of the dataloader. Required.
Source code in bionemo/llm/utils/callbacks.py
def write_on_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    prediction: Any,
    batch_indices: Sequence[int],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Writes predictions to disk at the end of each batch.

    Prediction files follow the naming pattern predictions__rank_{rank}__batch_{batch_idx}.pt, where rank is the
    global rank of the device on which the predictions were made.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        prediction: The prediction made by the model.
        batch_indices: The indices of the batch.
        batch: The batch data.
        batch_idx: The index of the batch.
        dataloader_idx: The index of the dataloader.
    """
    # this will create N (num processes) files in `output_dir` each containing
    # the predictions of its respective rank
    result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

    # batch_indices is not captured due to a lightning bug when return_predictions = False
    # we use input IDs in the prediction to map the result to input.

    # NOTE store the batch_idx so we do not need to rely on filenames for reconstruction of inputs. This is wrapped
    # in a tensor and list container to ensure compatibility with batch_collator.
    prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)

    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

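Because each rank writes one file per batch, downstream code typically globs the output directory and restores ordering from the stored batch_idx tensor rather than from filenames. A minimal reading sketch, assuming the default naming pattern and dictionary-style predictions:

import glob

import torch

files = glob.glob("./results/predictions__rank_*__batch_*.pt")
preds = [torch.load(f) for f in files]
# Each saved dict carries its own batch index, so ordering does not depend
# on filesystem listing order or filename parsing.
preds.sort(key=lambda p: int(p["batch_idx"].item()))
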
write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Writes predictions to disk at the end of each epoch.

Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for large predictions.

Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.

Parameters:

    trainer (Trainer): The Trainer instance. Required.
    pl_module (LightningModule): The LightningModule instance. Required.
    predictions (Any): The predictions made by the model. Required.
    batch_indices (Sequence[int]): The indices of the batch. Required.
Source code in bionemo/llm/utils/callbacks.py
def write_on_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    predictions: Any,
    batch_indices: Sequence[int],
) -> None:
    """Writes predictions to disk at the end of each epoch.

    Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
    large predictions.

    Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        predictions: The predictions made by the model.
        batch_indices: The indices of the batch.

    """
    # this will create N (num processes) files in `output_dir` each containing
    # the predictions of its respective rank

    result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

    # collate multiple batches / ignore empty ones
    collate_kwargs = {}
    if self.batch_dim_key_defaults is not None:
        collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
    if self.seq_dim_key_defaults is not None:
        collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults

    prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)

    # batch_indices is not captured due to a lightning bug when return_predictions = False
    # we use input IDs in the prediction to map the result to input
    if isinstance(prediction, dict):
        keys = prediction.keys()
    else:
        keys = "tensor"
    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{keys}")