Stop and go

StopAndGoHarness

Bases: ABC

Abstract base class for testing consistency between interrupted and continuous training.

Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and continuous cases.

By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are compared. Users can add additional metrics by adding new callbacks to cls.callbacks and associated test functions.

Stop and go tests proceed as follows:
  • Set up a clean model for a brief training run and register the callbacks that track metadata.
  • Interrupt training via the StopAndGoException raised by the RaiseAfterMetadataCallback callback.
  • Resume training from the checkpoint with the same set of callbacks.
  • Train the model continuously, without interruption, with a new set of the same callbacks.
  • Compare each pair of interrupted and continuous callbacks to check for equality.
Considerations when implementing this class:
  • The derived test class name should start with Test, and test methods should start with test_ to enable pytest discovery.
  • devices, pipeline_model_parallel, and tensor_model_parallel may affect the setup of the DataModule. Certain datasets expect a known global batch size, which depends on the number of devices and on the tensor model parallel and pipeline model parallel settings. By default, tests run on a single device without parallelism.
  • 'mode' is useful in some cases, but not in all. Implement mode-dependent conditions where they help. For example, a test that stops and resumes might vary by:
    • changing callbacks to test metadata integrity (the core feature of stop-and-go tests).
    • changing the model construction to use different hyperparameters.
    • ... etc.
    Each of these test cases may be useful for automated testing of various expected behaviors.
  • stop(), resume(), and continuous(), or collectively run_stop_and_go(), are provided methods that execute the actual tests, using the conditions in the various setup methods and respecting 'mode' where necessary.
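
The stop, resume, and continuous flow described above can be illustrated with a toy loop that uses only the standard library. This is a minimal sketch, not the harness itself: `fresh_state`, `train`, the `checkpoints` dict, and the local `StopAndGoException` class are hypothetical stand-ins for the model, `llm.train`, the saved checkpoint, and the exception from `bionemo.testing.callbacks`.

```python
import copy


class StopAndGoException(Exception):
    """Toy stand-in for the exception that interrupts the 'stop' run."""


def fresh_state():
    """Return a clean toy 'model' state (hypothetical, for illustration only)."""
    return {"step": 0, "weight": 0.0, "lr": 0.5}


def train(state, num_steps, records, stop_at=None, checkpoints=None):
    """Toy training loop that records (step, weight) after every step."""
    while state["step"] < num_steps:
        # Move the weight toward 1.0 by a fraction `lr` of the remaining gap.
        state["weight"] += state["lr"] * (1.0 - state["weight"])
        state["step"] += 1
        records.append((state["step"], state["weight"]))
        if stop_at is not None and state["step"] == stop_at:
            # Save a checkpoint first, then interrupt the run.
            checkpoints["last"] = copy.deepcopy(state)
            raise StopAndGoException()


def run_stop_and_go(num_steps=4, stop_at=2):
    checkpoints = {}
    interrupted, continuous = [], []

    # Stop: train until the interruption and swallow the exception.
    try:
        train(fresh_state(), num_steps, interrupted, stop_at=stop_at, checkpoints=checkpoints)
    except StopAndGoException:
        pass

    # Resume: continue from the checkpoint, appending to the same records.
    train(copy.deepcopy(checkpoints["last"]), num_steps, interrupted)

    # Continuous: a fresh run with a fresh set of records.
    train(fresh_state(), num_steps, continuous)
    return interrupted, continuous


interrupted, continuous = run_stop_and_go()
assert interrupted == continuous
```

The interrupted records are shared by the stop and resume phases, while the continuous run gets its own list, so the final comparison covers the whole training trajectory.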

Attributes:

root_dir: The root directory.
val_check_interval (int): The validation check interval. Stored as an attribute to ensure consistency.
exp_name (str): The experiment name.
extra_metrics_dict (str): A dictionary of metrics and their corresponding functions.

See Also: bionemo.testing.callbacks.

Source code in bionemo/testing/harnesses/stop_and_go.py
class StopAndGoHarness(ABC):
    """Abstract base class for testing consistency between interrupted and continuous training.

    Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata
    are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and
    continuous cases.

    By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are
    compared. Users can add additional metrics by adding new callbacks to `cls.callbacks` and associated test functions.

    Stop and go tests act as follows:
        - setup a clean model for a brief training run, set callbacks to track.
        - interrupt training via the StopAndGoException in the callback Raise.
        - train the model resumed from the checkpoint with the same set of callbacks.
        - train the model continuously without interruption with a new set of the same callbacks.
        - compare each pair of interrupted and continuous callbacks to check for equality.

    Considerations when implementing this class:
        - The derived test name should start with `Test`, and test methods should start with `test_` to enable pytest
          discovery.
        - devices, pipeline_model_parallel, and tensor_model_parallel may impact the setup of DataModule. Certain
            datasets expect a known global batch size, which depends on the number of devices and conditional tensor
            model parallel/ pipeline model parallel settings. By default, we are testing only on single device without
            parallelism.
        - 'mode' is useful in some cases, but not in all cases. Implement conditions based on these when useful. As an
            example, it may be useful to implement a test that stops and resumes.
            - changing callbacks to test metadata integrity (core feature of stop-and-go tests).
            - changing the model construction to use different hyperparameters.
            - ... etc
            Each of the above tests cases may be useful for automated testing of various expected behavior.
        - stop(), resume(), continuous() or collectively run_stop_and_go() are provided methods which execute the actual
          tests, leveraging the conditions in the various setup methods, respecting 'mode' where necessary.

    Attributes:
        root_dir: The root directory.
        val_check_interval: The validation check interval. Stored as an attribute to ensure consistency.
        exp_name: The experiment name.
        extra_metrics_dict: A dictionary of metrics and their corresponding functions.

    See Also: bionemo.testing.callbacks.
    """

    # class variables that need to be overridden
    num_steps: int
    val_check_interval: int
    limit_val_batches: int
    lr: float = 1e-4
    precision: Literal["16-mixed", "bf16-mixed", "32"]

    # class variables that will be setup in setUpClass
    tempdir: tempfile.TemporaryDirectory
    metadata_dir: pathlib.Path
    exp_name: str
    callbacks: CallbackDict
    nemo_logger: NeMoLogger

    @classmethod
    def setup_class(cls) -> None:
        """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
        cls.tempdir = tempfile.TemporaryDirectory()
        cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
        cls.exp_name = cls.__name__

        cls.callbacks = cls.get_default_callbacks()

        cls.nemo_logger = NeMoLogger(
            log_dir=cls.tempdir.name,
            name=cls.exp_name,
            use_datetime_version=False,
            version=None,
            tensorboard=None,
            wandb=None,
            ckpt=None,
        )

    @classmethod
    def teardown_class(cls) -> None:
        """Tears down the class by cleaning up the temporary directory."""
        cls.tempdir.cleanup()

    @classmethod
    @abstractmethod
    def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
        """Constructs the model, data, and optimizer for the test harness.

        Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
        to use the same code path for all three.

        Args:
            mode: The mode indicating whether to stop or go.

        Returns:
            tuple: A tuple containing the model, data, and optimizer.
        """
        raise NotImplementedError()

    @classmethod
    def setup_trainer(
        cls,
        mode: Mode,
    ) -> nl.Trainer:
        """Setup trainer by passing stop, resume, or continuous callbacks according to mode.

        Args:
            mode (Mode): The mode indicating whether to stop, resume, or train continuously.

        Returns:
            (nl.Trainer): NeMo Lightning trainer object.
        """
        strategy = MegatronStrategy(
            ddp="megatron",
            find_unused_parameters=True,
            ckpt_include_optimizer=True,
        )

        trainer = nl.Trainer(
            devices=1,
            max_steps=cls.num_steps,
            accelerator="gpu",
            strategy=strategy,
            limit_val_batches=cls.limit_val_batches,
            val_check_interval=cls.val_check_interval,
            log_every_n_steps=cls.val_check_interval,
            num_nodes=1,
            callbacks=list(cls.callbacks[mode].values()),
            plugins=nl.MegatronMixedPrecision(precision=cls.precision),
        )
        return trainer

    @classmethod
    def get_default_callbacks(cls) -> CallbackDict:
        """Returns a dictionary of callbacks for each mode. Base implementation provides reasonable defaults.

        To extend this method, call the super and add to the callbacks, depending on which mode you are in:

        ```python
        callbacks = super().get_default_callbacks()
        callbacks[mode][MyCustomCallback] = MyCustomCallback()
        return callbacks
        ```

        Returns:
            A dictionary of callbacks keyed by mode, each of which maps a callback type to a callback
            object.
        """
        callbacks: CallbackDict = {}

        def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
            return {
                testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
                testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
                testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
                testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
                testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
                testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
                testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
                testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
                testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
                testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
            }

        interrupted_callbacks = make_callbacks()
        callbacks[Mode.CONTINUOUS] = make_callbacks()

        for mode in [Mode.STOP, Mode.RESUME]:
            consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
            callbacks[mode] = {
                consumed_samples_cls: consumed_samples_cls(mode=mode),
                **interrupted_callbacks,
            }

        callbacks[Mode.STOP].update(
            {
                testing_callbacks.RaiseAfterMetadataCallback: testing_callbacks.RaiseAfterMetadataCallback(),
                nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
                    save_last=True,
                    monitor="reduced_train_loss",
                    save_top_k=2,
                    every_n_train_steps=cls.val_check_interval,
                    always_save_context=True,
                ),
            }
        )

        return callbacks

    # stop() and resume() are provided methods and run the requisite methods with the appropriate mode.
    @classmethod
    def stop(cls) -> None:
        """Runs pre-training and 'stops' after the first checkpoint is saved.

        This method sets up the model, data, and optimizer for the Mode.STOP mode.
        It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
        The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
        If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and the method returns
        normally; the exception does not propagate to the caller.
        """
        logging.info("Running stop()...")

        model, data, opt = cls.setup_model(mode=Mode.STOP)
        trainer = cls.setup_trainer(Mode.STOP)
        with distributed_model_parallel_state():
            try:
                llm.train(
                    model=model,
                    data=data,
                    trainer=trainer,
                    log=cls.nemo_logger,
                    optim=opt,
                    resume=resume.AutoResume(
                        resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
                        resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
                    ),
                )
            except testing_callbacks.StopAndGoException:
                return

    @classmethod
    def resume(cls) -> None:
        """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
        logging.info("Running resume()...")

        model, data, opt = cls.setup_model(mode=Mode.RESUME)
        trainer = cls.setup_trainer(Mode.RESUME)
        with distributed_model_parallel_state():
            llm.train(
                model=model,
                data=data,
                trainer=trainer,
                log=cls.nemo_logger,
                optim=opt,
                resume=resume.AutoResume(
                    resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
                    resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
                ),
            )

    @classmethod
    def continuous(cls) -> None:
        """Trains the model in one continuous path without stopping."""
        logging.info("Running continuous()...")

        model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
        trainer = cls.setup_trainer(Mode.CONTINUOUS)
        with distributed_model_parallel_state():
            llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)

    @classmethod
    def run_stop_and_go(cls):
        """Executes training both continuously and with a checkpoint interruption."""
        # Interrupted model training
        cls.stop()
        cls.resume()

        # Continuous model training.
        cls.continuous()

    @pytest.mark.parametrize(
        "callback_type",
        [
            testing_callbacks.LearningRateCallback,
            testing_callbacks.GlobalStepStateCallback,
            testing_callbacks.ConsumedSamplesCallback,
            testing_callbacks.OptimizerStateCallback,
            testing_callbacks.TrainInputCallback,
            testing_callbacks.TrainOutputCallback,
            testing_callbacks.TrainLossCallback,
        ],
    )
    def test_stop_and_go_consistency(self, callback_type):
        """Tests the consistency of the callback data between the interrupted and continuous checks."""
        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
        assert interrupted_callback.data, f"No data found for {callback_type}"

        if callback_type == testing_callbacks.TrainOutputCallback:
            atol = 1e-3
        else:
            atol = 1e-4

        recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)

    def test_train_val_init_consumed_samples(self):
        """Tests the initial consumed samples in stop-and-go scenario."""
        train_consumed_stop, val_consumed_stop = get_callback(
            self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        ).data
        train_consumed_go, val_consumed_go = get_callback(
            self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        ).data

        assert val_consumed_stop == 0
        assert val_consumed_go == 0
        assert train_consumed_stop == 0
        assert train_consumed_go > 0

    # TODO: For some reason, validation in NeMo runs an extra batch in the case when the training is stopped and
    # resumed. Hopefully we can fix this upstream and remove the indexing based on the length of the continuous
    # validation batches.
    @pytest.mark.xfail(reason="Validation runs an extra batch in the case when training is stopped and resumed.")
    def test_identical_number_of_validation_batches(self):
        """Ensures the validation input batches match in content and count between interrupted and continuous runs."""
        callback_type = testing_callbacks.ValidInputCallback
        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
        assert interrupted_callback.data, f"No data found for {callback_type}"
        recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data)
        assert len(interrupted_callback.data) == len(continuous_callback.data)

    @pytest.mark.parametrize(
        "callback_type",
        [
            testing_callbacks.ValidInputCallback,
            testing_callbacks.ValidOutputCallback,
            testing_callbacks.ValidLossCallback,
        ],
    )
    def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_type):
        """Checks validation inputs, outputs, and losses for consistency, ignoring the extra interrupted batch."""
        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
        assert interrupted_callback.data, f"No data found for {callback_type}"

        # Hack: Validation seems to run an extra batch in the case when training is stopped and resumed, but we can
        # still test the rest of the data to ensure consistency.
        interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]

        if callback_type == testing_callbacks.ValidOutputCallback:
            atol = 1e-3
        else:
            atol = 1e-4

        recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)

continuous() classmethod

Trains the model in one continuous path without stopping.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def continuous(cls) -> None:
    """Trains the model in one continuous path without stopping."""
    logging.info("Running continuous()...")

    model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
    trainer = cls.setup_trainer(Mode.CONTINUOUS)
    with distributed_model_parallel_state():
        llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)

get_default_callbacks() classmethod

Returns a dictionary of callbacks for each mode. The base implementation provides reasonable defaults.

To extend this method, call the parent implementation and add to the callbacks for the relevant mode:

callbacks = super().get_default_callbacks()
callbacks[mode][MyCustomCallback] = MyCustomCallback()
return callbacks

Returns:

CallbackDict: A dictionary keyed by mode; each value maps a callback type to a callback object.
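
A derived class might extend the defaults along the following lines. This is a sketch with toy stand-ins: `Mode`, `RecordingCallback`, `GradNormCallback`, and `ToyHarness` are hypothetical and only mimic the shape of the real objects. It assumes, as in the source listing, that the stop and resume modes share one callback instance while the continuous mode gets a fresh one.

```python
from enum import Enum, auto


class Mode(Enum):
    STOP = auto()
    RESUME = auto()
    CONTINUOUS = auto()


class RecordingCallback:
    """Hypothetical stand-in for a metadata-collecting callback."""

    def __init__(self):
        self.data = []


class GradNormCallback(RecordingCallback):
    """Hypothetical custom callback added by a derived test class."""


class ToyHarness:
    @classmethod
    def get_default_callbacks(cls):
        # One shared instance for stop and resume, a separate one for continuous.
        interrupted = {RecordingCallback: RecordingCallback()}
        return {
            Mode.STOP: dict(interrupted),
            Mode.RESUME: dict(interrupted),
            Mode.CONTINUOUS: {RecordingCallback: RecordingCallback()},
        }


class ToyHarnessWithGradNorm(ToyHarness):
    @classmethod
    def get_default_callbacks(cls):
        callbacks = super().get_default_callbacks()
        # Key each callback by its type so it can be looked up per mode later.
        shared = GradNormCallback()
        for mode in (Mode.STOP, Mode.RESUME):
            callbacks[mode][GradNormCallback] = shared
        callbacks[Mode.CONTINUOUS][GradNormCallback] = GradNormCallback()
        return callbacks


callbacks = ToyHarnessWithGradNorm.get_default_callbacks()
```

Sharing the instance between stop and resume lets the interrupted run accumulate one record across both phases, which is then compared against the record from the continuous run.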

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def get_default_callbacks(cls) -> CallbackDict:
    """Returns a dictionary of callbacks for each mode. Base implementation provides reasonable defaults.

    To extend this method, call the super and add to the callbacks, depending on which mode you are in:

    ```python
    callbacks = super().get_default_callbacks()
    callbacks[mode][MyCustomCallback] = MyCustomCallback()
    return callbacks
    ```

    Returns:
        A dictionary of callbacks keyed by mode, each of which maps a callback type to a callback
        object.
    """
    callbacks: CallbackDict = {}

    def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
        return {
            testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
            testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
            testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
            testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
            testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
            testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
            testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
            testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
            testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
            testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
        }

    interrupted_callbacks = make_callbacks()
    callbacks[Mode.CONTINUOUS] = make_callbacks()

    for mode in [Mode.STOP, Mode.RESUME]:
        consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        callbacks[mode] = {
            consumed_samples_cls: consumed_samples_cls(mode=mode),
            **interrupted_callbacks,
        }

    callbacks[Mode.STOP].update(
        {
            testing_callbacks.RaiseAfterMetadataCallback: testing_callbacks.RaiseAfterMetadataCallback(),
            nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
                save_last=True,
                monitor="reduced_train_loss",
                save_top_k=2,
                every_n_train_steps=cls.val_check_interval,
                always_save_context=True,
            ),
        }
    )

    return callbacks

resume() classmethod

Resumes the model from the checkpoint saved at the end of stop() and verifies the metadata integrity.
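
The two resume flags used by stop() and resume() can be pictured with a small lookup function. This is a sketch of the behavior described by the comments in the source listing, not NeMo's implementation: `find_resume_checkpoint` and the `*-last*` naming pattern are assumptions made for illustration.

```python
import pathlib
import tempfile


def find_resume_checkpoint(ckpt_dir, resume_if_exists, resume_ignore_no_checkpoint):
    """Hypothetical helper: return the checkpoint to resume from, or None."""
    if not resume_if_exists:
        return None  # stop(): always start from scratch
    candidates = sorted(pathlib.Path(ckpt_dir).glob("*-last*"))
    if candidates:
        return candidates[-1]
    if resume_ignore_no_checkpoint:
        return None  # tolerate a missing checkpoint
    raise FileNotFoundError(f"No '-last' checkpoint found in {ckpt_dir}")


with tempfile.TemporaryDirectory() as tmp:
    # stop(): resume_if_exists=False, so nothing is looked up.
    stop_result = find_resume_checkpoint(tmp, False, True)

    # resume() before any checkpoint exists: an error is expected.
    try:
        find_resume_checkpoint(tmp, True, False)
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True

    # resume() after stop() has saved a checkpoint.
    (pathlib.Path(tmp) / "step=2-last").mkdir()
    resumed_name = find_resume_checkpoint(tmp, True, False).name
```

Under these assumptions, resume() fails loudly if stop() did not leave a checkpoint behind, rather than silently training from scratch.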

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def resume(cls) -> None:
    """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
    logging.info("Running resume()...")

    model, data, opt = cls.setup_model(mode=Mode.RESUME)
    trainer = cls.setup_trainer(Mode.RESUME)
    with distributed_model_parallel_state():
        llm.train(
            model=model,
            data=data,
            trainer=trainer,
            log=cls.nemo_logger,
            optim=opt,
            resume=resume.AutoResume(
                resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
                resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
            ),
        )

run_stop_and_go() classmethod

Executes training both continuously and with a checkpoint interruption.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def run_stop_and_go(cls):
    """Executes training both continuously and with a checkpoint interruption."""
    # Interrupted model training
    cls.stop()
    cls.resume()

    # Continuous model training.
    cls.continuous()

setup_class() classmethod

Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks.
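
The temporary-directory lifecycle shared by setup_class() and teardown_class() can be sketched with the standard library alone. `ToyHarness` and `TestToy` below are hypothetical stand-ins that omit the callbacks and the logger; they only show how a derived class might extend setup_class() after calling the parent implementation.

```python
import pathlib
import tempfile


class ToyHarness:
    @classmethod
    def setup_class(cls):
        # Mirrors the source: a temp dir, a metadata path inside it, and a name.
        cls.tempdir = tempfile.TemporaryDirectory()
        cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
        cls.exp_name = cls.__name__

    @classmethod
    def teardown_class(cls):
        cls.tempdir.cleanup()


class TestToy(ToyHarness):
    @classmethod
    def setup_class(cls):
        super().setup_class()
        # Derived classes add their own setup after the parent has run.
        cls.metadata_dir.mkdir()


TestToy.setup_class()
root = pathlib.Path(TestToy.tempdir.name)
existed_during = root.is_dir() and TestToy.metadata_dir.is_dir()
TestToy.teardown_class()
exists_after = root.exists()
```

Because `exp_name` is taken from `cls.__name__`, each derived test class gets its own experiment name without extra configuration.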

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def setup_class(cls) -> None:
    """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
    cls.tempdir = tempfile.TemporaryDirectory()
    cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
    cls.exp_name = cls.__name__

    cls.callbacks = cls.get_default_callbacks()

    cls.nemo_logger = NeMoLogger(
        log_dir=cls.tempdir.name,
        name=cls.exp_name,
        use_datetime_version=False,
        version=None,
        tensorboard=None,
        wandb=None,
        ckpt=None,
    )

setup_model(mode) abstractmethod classmethod

Constructs the model, data, and optimizer for the test harness.

Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged to use the same code path for all three.

Parameters:

mode (Mode, required): The mode indicating whether to stop, resume, or train continuously.

Returns:

tuple[LightningModule, LightningDataModule, MegatronOptimizerModule]: A tuple containing the model, data, and optimizer.
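
One way to follow the single-code-path advice is to make construction depend only on fixed inputs such as a seed, never on the mode. The sketch below uses toy stand-ins: lists and a dict replace the Lightning module, data module, and optimizer, and the `seed` parameter is an assumption.

```python
import random
from enum import Enum, auto


class Mode(Enum):
    STOP = auto()
    RESUME = auto()
    CONTINUOUS = auto()


def setup_model(mode, seed=42):
    """Toy setup: `mode` is accepted but deliberately unused."""
    rng = random.Random(seed)
    model = [rng.uniform(-1.0, 1.0) for _ in range(4)]  # toy 'weights'
    data = [rng.randrange(100) for _ in range(8)]  # toy 'dataset'
    optimizer = {"lr": 1e-4}
    return model, data, optimizer


results = {mode: setup_model(mode) for mode in Mode}
```

Identical construction across modes means any difference found later between the interrupted and continuous runs points at checkpointing or resumption, not at setup.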

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
@abstractmethod
def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
    """Constructs the model, data, and optimizer for the test harness.

    Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
    to use the same code path for all three.

    Args:
        mode: The mode indicating whether to stop or go.

    Returns:
        tuple: A tuple containing the model, data, and optimizer.
    """
    raise NotImplementedError()

setup_trainer(mode) classmethod

Sets up the trainer, passing the stop, resume, or continuous callbacks according to mode.

Parameters:

mode (Mode, required): The mode indicating whether to stop, resume, or train continuously.

Returns:

Trainer: NeMo Lightning trainer object.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def setup_trainer(
    cls,
    mode: Mode,
) -> nl.Trainer:
    """Setup trainer by passing stop, resume, or continuous callbacks according to mode.

    Args:
        mode (Mode): The mode indicating whether to stop, resume, or train continuously.

    Returns:
        (nl.Trainer): NeMo Lightning trainer object.
    """
    strategy = MegatronStrategy(
        ddp="megatron",
        find_unused_parameters=True,
        ckpt_include_optimizer=True,
    )

    trainer = nl.Trainer(
        devices=1,
        max_steps=cls.num_steps,
        accelerator="gpu",
        strategy=strategy,
        limit_val_batches=cls.limit_val_batches,
        val_check_interval=cls.val_check_interval,
        log_every_n_steps=cls.val_check_interval,
        num_nodes=1,
        callbacks=list(cls.callbacks[mode].values()),
        plugins=nl.MegatronMixedPrecision(precision=cls.precision),
    )
    return trainer

stop() classmethod

Runs pre-training and 'stops' after the first checkpoint is saved.

This method sets up the model, data, and optimizer for Mode.STOP, then sets up the trainer and strategy with the Mode.STOP callbacks. Training is executed with the llm.train function, which receives the model, data, trainer, logger, optimizer, and resume options. If a testing_callbacks.StopAndGoException is raised during training, it is caught and the method returns normally; the exception does not propagate to the caller.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def stop(cls) -> None:
    """Runs pre-training and 'stops' after the first checkpoint is saved.

    This method sets up the model, data, and optimizer for the Mode.STOP mode.
    It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
    The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
    If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and the method returns
    normally; the exception does not propagate to the caller.
    """
    logging.info("Running stop()...")

    model, data, opt = cls.setup_model(mode=Mode.STOP)
    trainer = cls.setup_trainer(Mode.STOP)
    with distributed_model_parallel_state():
        try:
            llm.train(
                model=model,
                data=data,
                trainer=trainer,
                log=cls.nemo_logger,
                optim=opt,
                resume=resume.AutoResume(
                    resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
                    resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
                ),
            )
        except testing_callbacks.StopAndGoException:
            return

teardown_class() classmethod

Tears down the class by cleaning up the temporary directory.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def teardown_class(cls) -> None:
    """Tears down the class by cleaning up the temporary directory."""
    cls.tempdir.cleanup()

test_identical_number_of_validation_batches()

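Ensures that the validation input batches match in content and count between the interrupted and continuous runs. The test is marked xfail because validation currently runs an extra batch when training is stopped and resumed.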

Source code in bionemo/testing/harnesses/stop_and_go.py
@pytest.mark.xfail(reason="Validation runs an extra batch in the case when training is stopped and resumed.")
def test_identical_number_of_validation_batches(self):
    """Ensures the validation input batches match in content and count between interrupted and continuous runs."""
    callback_type = testing_callbacks.ValidInputCallback
    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
    assert interrupted_callback.data, f"No data found for {callback_type}"
    recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data)
    assert len(interrupted_callback.data) == len(continuous_callback.data)

test_stop_and_go_consistency(callback_type)

Tests the consistency of the callback data between the interrupted and continuous checks.
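
The comparison relies on a recursive approximate-equality check with an absolute tolerance. The function below is a hypothetical stand-in written for plain Python numbers, lists, tuples, and dicts; it is not the `recursive_assert_approx_equal` helper used in the source, which may handle tensors and other types differently.

```python
import math


def assert_approx_equal_recursive(actual, expected, atol=1e-4):
    """Hypothetical stand-in: compare nested containers of numbers within atol."""
    if isinstance(expected, dict):
        assert isinstance(actual, dict) and actual.keys() == expected.keys()
        for key in expected:
            assert_approx_equal_recursive(actual[key], expected[key], atol)
    elif isinstance(expected, (list, tuple)):
        assert isinstance(actual, (list, tuple)) and len(actual) == len(expected)
        for a, e in zip(actual, expected):
            assert_approx_equal_recursive(a, e, atol)
    else:
        # Leaf values: absolute tolerance only, mirroring the `atol` argument.
        assert math.isclose(actual, expected, rel_tol=0.0, abs_tol=atol), (actual, expected)


# Toy per-step metadata for the interrupted and continuous runs.
interrupted = [{"loss": 0.50004, "lr": 1e-4}, {"loss": 0.25, "lr": 1e-4}]
continuous = [{"loss": 0.5, "lr": 1e-4}, {"loss": 0.25, "lr": 1e-4}]

# Passes: the largest difference (4e-5) is within atol=1e-4.
assert_approx_equal_recursive(interrupted, continuous, atol=1e-4)
```

A looser tolerance for model outputs than for scalar metadata, as in the source listing (1e-3 versus 1e-4), leaves room for small numerical differences in the forward pass.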

Source code in bionemo/testing/harnesses/stop_and_go.py
@pytest.mark.parametrize(
    "callback_type",
    [
        testing_callbacks.LearningRateCallback,
        testing_callbacks.GlobalStepStateCallback,
        testing_callbacks.ConsumedSamplesCallback,
        testing_callbacks.OptimizerStateCallback,
        testing_callbacks.TrainInputCallback,
        testing_callbacks.TrainOutputCallback,
        testing_callbacks.TrainLossCallback,
    ],
)
def test_stop_and_go_consistency(self, callback_type):
    """Tests the consistency of the callback data between the interrupted and continuous checks."""
    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
    assert interrupted_callback.data, f"No data found for {callback_type}"

    if callback_type == testing_callbacks.TrainOutputCallback:
        atol = 1e-3
    else:
        atol = 1e-4

    recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)
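
The effect of the two tolerances can be illustrated with a small stand-in for the comparison. The helper below, `assert_nested_close`, is hypothetical and is not the library's `recursive_assert_approx_equal`; it only sketches the idea of walking nested containers and comparing leaves with an absolute tolerance, using plain floats instead of tensors.

```python
import math
from typing import Any


def assert_nested_close(actual: Any, expected: Any, atol: float = 1e-4) -> None:
    """Hypothetical helper: recursively compare nested dicts/lists of numbers."""
    if isinstance(expected, dict):
        assert isinstance(actual, dict) and actual.keys() == expected.keys()
        for key in expected:
            assert_nested_close(actual[key], expected[key], atol=atol)
    elif isinstance(expected, (list, tuple)):
        assert len(actual) == len(expected)
        for a, e in zip(actual, expected):
            assert_nested_close(a, e, atol=atol)
    else:
        # Leaf values: absolute tolerance only, no relative tolerance.
        assert math.isclose(actual, expected, rel_tol=0.0, abs_tol=atol), (actual, expected)


def passes(actual: Any, expected: Any, atol: float) -> bool:
    """Return True when the comparison succeeds, False when it raises."""
    try:
        assert_nested_close(actual, expected, atol=atol)
        return True
    except AssertionError:
        return False


# Made-up per-step records that differ by 5e-4 in one value.
interrupted = [{"loss": 0.5005, "lr": 1e-3}, {"loss": 0.4, "lr": 1e-3}]
continuous = [{"loss": 0.5, "lr": 1e-3}, {"loss": 0.4, "lr": 1e-3}]

loose_ok = passes(interrupted, continuous, atol=1e-3)
strict_ok = passes(interrupted, continuous, atol=1e-4)
```

In this sketch a difference of `5e-4` is accepted at the looser tolerance used for `TrainOutputCallback` and rejected at the default `1e-4`.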

test_stop_and_go_consistency_with_uneven_validation_sizes(callback_type)

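Checks that validation inputs, outputs, and losses are consistent between the interrupted and continuous tests. Only the last len(continuous_callback.data) entries of the interrupted run are compared, which works around the extra validation batch that appears when training is stopped and resumed.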

Source code in bionemo/testing/harnesses/stop_and_go.py
@pytest.mark.parametrize(
    "callback_type",
    [
        testing_callbacks.ValidInputCallback,
        testing_callbacks.ValidOutputCallback,
        testing_callbacks.ValidLossCallback,
    ],
)
def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_type):
    """Checks validation data consistency between the interrupted and continuous tests, ignoring extra leading batches."""
    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
    assert interrupted_callback.data, f"No data found for {callback_type}"

    # Hack: Validation seems to run an extra batch in the case when training is stopped and resumed, but we can
    # still test the rest of the data to ensure consistency.
    interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]

    if callback_type == testing_callbacks.ValidOutputCallback:
        atol = 1e-3
    else:
        atol = 1e-4

    recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
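
The trailing-slice workaround can be seen on plain lists. This is a minimal sketch with made-up batch labels; `align_trailing` is a hypothetical name used only here.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def align_trailing(interrupted: Sequence[T], continuous: Sequence[T]) -> List[T]:
    """Keep only the last len(continuous) entries of the interrupted record."""
    # Note: when continuous is empty, the slice [-0:] returns the whole sequence,
    # so this alignment only drops entries when continuous is non-empty.
    return list(interrupted[-len(continuous):])


# Made-up labels: the resumed run recorded one extra validation batch up front.
interrupted = ["extra", "batch_0", "batch_1", "batch_2"]
continuous = ["batch_0", "batch_1", "batch_2"]

aligned = align_trailing(interrupted, continuous)
```

After alignment the two records have the same length, so they can be compared element by element.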

test_train_val_init_consumed_samples()

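Tests the initial consumed samples in the stop-and-go scenario: both the training and validation counters start at 0 in the initial (STOP) run, and after resuming, the training counter is greater than 0 while the validation counter is still 0.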

Source code in bionemo/testing/harnesses/stop_and_go.py
def test_train_val_init_consumed_samples(self):
    """Tests the initial consumed samples in stop-and-go scenario."""
    train_consumed_stop, val_consumed_stop = get_callback(
        self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
    ).data
    train_consumed_go, val_consumed_go = get_callback(
        self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
    ).data

    assert val_consumed_stop == 0
    assert val_consumed_go == 0
    assert train_consumed_stop == 0
    assert train_consumed_go > 0

get_callback(callbacks, mode, callback_type)

Returns the callback with the given type and mode.

Convenience function to make type hinting easier.

Parameters:

Name | Type | Description | Default
callbacks | CallbackDict | The dictionary of callbacks, keyed by mode and then by callback type. | required
mode | Mode | The mode whose callbacks to look up (for example Mode.STOP, Mode.RESUME, or Mode.CONTINUOUS). | required
callback_type | Type[Callback] | The type of the callback. | required

Returns:

Type | Description
Callback | pl.Callback: The callback with the given type and mode.

Source code in bionemo/testing/harnesses/stop_and_go.py
def get_callback(callbacks: CallbackDict, mode: Mode, callback_type: Type[Callback]) -> Callback:
    """Returns the callback with the given type and mode.

    Convenience function to make type hinting easier.

    Args:
        callbacks: The dictionary of callbacks, keyed by mode and then by callback type.
        mode: The mode whose callbacks to look up (e.g. Mode.STOP, Mode.RESUME, or Mode.CONTINUOUS).
        callback_type: The type of the callback.

    Returns:
        pl.Callback: The callback with the given type and mode.
    """
    return callbacks[mode][callback_type]  # type: ignore
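
The two-level lookup performed by `get_callback` can be sketched without the real harness. Everything below is a stand-in: the `Mode` enum only mirrors the member names used in the tests above, and `Callback` and `LearningRateCallback` are simplified placeholders rather than the Lightning or `testing_callbacks` classes.

```python
from enum import Enum, auto
from typing import Dict, List, Type


class Mode(Enum):
    """Stand-in for the harness Mode enum (member names taken from the tests above)."""

    STOP = auto()
    RESUME = auto()
    CONTINUOUS = auto()


class Callback:
    """Stand-in base class; the documented return type refers to pl.Callback."""

    def __init__(self) -> None:
        self.data: List[float] = []


class LearningRateCallback(Callback):
    """Placeholder callback type used only for this illustration."""


# Assumed shape of CallbackDict: mode -> callback type -> callback instance.
CallbackDict = Dict[Mode, Dict[Type[Callback], Callback]]


def get_callback(callbacks: CallbackDict, mode: Mode, callback_type: Type[Callback]) -> Callback:
    # Same two-level lookup as the function documented above.
    return callbacks[mode][callback_type]


# Each mode gets its own instance so recorded data is never shared between runs.
callbacks: CallbackDict = {mode: {LearningRateCallback: LearningRateCallback()} for mode in Mode}

resumed = get_callback(callbacks, Mode.RESUME, LearningRateCallback)
continuous = get_callback(callbacks, Mode.CONTINUOUS, LearningRateCallback)
```

Keeping a separate instance per mode is what lets the tests compare `resumed.data` against `continuous.data` afterwards.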