Skip to content

Config

IOMixinProto

Bases: Protocol

A Protocol for the get/set hparam functions of the IOMixin class from NeMo.

Source code in bionemo/llm/model/config.py
118
119
120
121
122
123
124
125
126
127
class IOMixinProto(Protocol):
    """A Protocol for the get/set hparam functions of the IOMixin class from NeMo."""

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        """Set the value of an attribute in the config attached to the class by the IOMixin."""
        ...

    def get_hparam(self, attribute: str) -> Any:
        """Get the value of an attribute in the config attached to the class by the IOMixin."""
        ...

get_hparam(attribute)

Get the value of an attribute in the config attached to the class by the IOMixin.

Source code in bionemo/llm/model/config.py
125
126
127
def get_hparam(self, attribute: str) -> Any:
    """Get the value of an attribute in the config attached to the class by the IOMixin."""
    ...

set_hparam(attribute, value, also_change_value=True)

Set the value of an attribute in the config attached to the class by the IOMixin.

Source code in bionemo/llm/model/config.py
121
122
123
def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
    """Set the value of an attribute in the config attached to the class by the IOMixin."""
    ...

MegatronBioNeMoModelConfig

Bases: BionemoModelConfig[MegatronModelType], TransformerConfig, WillHaveGetSetHparam

A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

Source code in bionemo/llm/model/config.py
54
55
56
57
class MegatronBioNeMoModelConfig(BionemoModelConfig[MegatronModelType], TransformerConfig, iom.WillHaveGetSetHparam):
    """A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

    model_cls: Type[MegatronModelType]

MegatronBioNeMoTrainableModelConfig dataclass

Bases: MegatronBioNeMoModelConfig[MegatronModelType], BionemoTrainableModelConfig[MegatronModelType, MegatronLossType], Generic[MegatronModelType, MegatronLossType]

A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

Source code in bionemo/llm/model/config.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@dataclass
class MegatronBioNeMoTrainableModelConfig(
    MegatronBioNeMoModelConfig[MegatronModelType],
    BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    Generic[MegatronModelType, MegatronLossType],
):
    """A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

    initial_ckpt_path: str | None = None
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
    override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS)

    def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
        """Load settings into self from the checkpoint saved in self.

        Any setting in self.override_parent_fields is not overriden. Note that this function will also update the hyper
        parameters in this config, as well as the associated attributes in self in case they were modified post-init.

        Args:
            initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
                other than the settings in self.override_parent_fields.

        Returns:
            None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
                a checkpoint are updated.
        """
        logger.warn(f"Loading {self.initial_ckpt_path}")
        # 1. get the config
        # TODO type(self) is probably not correct, maybe make the class name of the config to load an argument?
        cfg_trainer_ctx: TrainerContext = io.load_context(Path(initial_ckpt_path) / "context")
        initial_config: MegatronBioNeMoTrainableModelConfig = cfg_trainer_ctx.model.config
        initial_fields = {f.name for f in fields(initial_config)}
        my_fields = [f.name for f in fields(self)]
        skip_fields = set(self.override_parent_fields)
        override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
        override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)

    def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
        """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

        Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
            self.initial_ckpt_skip_keys_with_these_prefixes.

        Args:
            model: The Megatron model to update.
            initial_ckpt_path: The path to the megatron checkpoint to load.

        Returns:
            None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
                any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
        """
        load_weights_sharded_inplace_nemo2_to_mcore(
            model=model,  # type: ignore
            distributed_checkpoint_dir=initial_ckpt_path,
            skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
        )

load_settings_from_checkpoint(initial_ckpt_path)

Load settings into self from the checkpoint saved in self.

Any setting in self.override_parent_fields is not overriden. Note that this function will also update the hyper parameters in this config, as well as the associated attributes in self in case they were modified post-init.

Parameters:

Name Type Description Default
initial_ckpt_path str

The path to the checkpoint to load, note that everything is loaded from this checkpoint other than the settings in self.override_parent_fields.

required

Returns:

Type Description
None

None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into a checkpoint are updated.

Source code in bionemo/llm/model/config.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
    """Load settings into self from the checkpoint saved in self.

    Any setting in self.override_parent_fields is not overriden. Note that this function will also update the hyper
    parameters in this config, as well as the associated attributes in self in case they were modified post-init.

    Args:
        initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
            other than the settings in self.override_parent_fields.

    Returns:
        None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
            a checkpoint are updated.
    """
    logger.warn(f"Loading {self.initial_ckpt_path}")
    # 1. get the config
    # TODO type(self) is probably not correct, maybe make the class name of the config to load an argument?
    cfg_trainer_ctx: TrainerContext = io.load_context(Path(initial_ckpt_path) / "context")
    initial_config: MegatronBioNeMoTrainableModelConfig = cfg_trainer_ctx.model.config
    initial_fields = {f.name for f in fields(initial_config)}
    my_fields = [f.name for f in fields(self)]
    skip_fields = set(self.override_parent_fields)
    override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
    override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)

update_model_from_checkpoint(model, initial_ckpt_path)

Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in self.initial_ckpt_skip_keys_with_these_prefixes.

Parameters:

Name Type Description Default
model MegatronModelType

The Megatron model to update.

required
initial_ckpt_path str

The path to the megatron checkpoint to load.

required

Returns:

Type Description
None

None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.

Source code in bionemo/llm/model/config.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
    """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

    Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
        self.initial_ckpt_skip_keys_with_these_prefixes.

    Args:
        model: The Megatron model to update.
        initial_ckpt_path: The path to the megatron checkpoint to load.

    Returns:
        None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
            any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
    """
    load_weights_sharded_inplace_nemo2_to_mcore(
        model=model,  # type: ignore
        distributed_checkpoint_dir=initial_ckpt_path,
        skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
    )

override_mutate_possibly_extra_mutated_fiddle(target_cfg, source_cfg, maybe_mutated_elements_to_clone)

Override the values of the target config with the values of the source config for the given elements.

This will modify the tracked init hyper-parameter values, as well as modifying the associated attributes in self incase they were modified later by post_init code.

Parameters:

Name Type Description Default
target_cfg IOMixinProto

The config to update.

required
source_cfg IOMixinProto

The config to copy values from.

required
maybe_mutated_elements_to_clone List[str]

The list of elements to copy from the source config to the target config.

required

Returns:

Type Description
None

None, the target config is updated in place.

Source code in bionemo/llm/model/config.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def override_mutate_possibly_extra_mutated_fiddle(
    target_cfg: IOMixinProto, source_cfg: IOMixinProto, maybe_mutated_elements_to_clone: List[str]
) -> None:
    """Override the values of the target config with the values of the source config for the given elements.

    This will modify the tracked init hyper-parameter values, as well as modifying the associated attributes in
        self incase they were modified later by post_init code.

    Args:
        target_cfg: The config to update.
        source_cfg: The config to copy values from.
        maybe_mutated_elements_to_clone: The list of elements to copy from the source config to the target config.

    Returns:
        None, the target config is updated in place.
    """
    for f in maybe_mutated_elements_to_clone:
        # 1. Update the tracked config values. Note that the associated attribute in self may have been modified
        #  post-init, so we don't want to change the value in self here. We do that separately next.
        target_cfg.set_hparam(f, source_cfg.get_hparam(f), also_change_value=False)
        # 2. Update the lazily untracked values (if the same variable name is used post-init)
        setattr(target_cfg, f, getattr(source_cfg, f))