API

ESM2Config dataclass

Bases: ESM2GenericConfig, IOMixinWithGettersSetters

Configuration class for ESM2 model.

Source code in bionemo/esm2/model/model.py
@dataclass
class ESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
    """Configuration class for ESM2 model."""

    model_cls: Type[ESM2Model] = ESM2Model
    num_layers: int = 33  # 650M
    hidden_size: int = 1280  # 650M
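
ESM2Config re-declares model_cls, num_layers, and hidden_size with explicit type annotations. The following minimal sketch shows why the annotation matters when a dataclass overrides an inherited field. BaseConfig, AnnotatedOverride, and UnannotatedOverride are hypothetical stand-ins rather than BioNeMo classes, and the values 6 and 320 are only illustrative.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    # Hypothetical stand-in for a parent config class.
    num_layers: int = 33
    hidden_size: int = 1280


@dataclass
class AnnotatedOverride(BaseConfig):
    # Annotated: the dataclass machinery registers new field defaults.
    num_layers: int = 6
    hidden_size: int = 320


@dataclass
class UnannotatedOverride(BaseConfig):
    # Not annotated: this is a plain class attribute, so the generated
    # __init__ keeps using the default of the inherited field.
    num_layers = 6


annotated = AnnotatedOverride()
unannotated = UnannotatedOverride()

print(annotated.num_layers, annotated.hidden_size)  # 6 320
print(unannotated.num_layers)  # 33, the override is silently ignored
```

The unannotated override does not fail; it simply has no effect on instances, which is why the annotation is easy to forget.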

ESM2GenericConfig dataclass

Bases: BioBertConfig[ESM2ModelT, MegatronLossType]

Configuration class for ESM2 model.

Attributes:

num_layers (int): Number of layers in the model.
hidden_size (int): Hidden size of the model.
num_attention_heads (int): Number of attention heads in the model.
ffn_hidden_size (int): Hidden size of the feed-forward network.
hidden_dropout (float): Dropout rate for hidden layers.
attention_dropout (float): Dropout rate for attention layers.
apply_residual_connection_post_layernorm (bool): Whether to apply residual connection after layer normalization.
layernorm_epsilon (float): Epsilon value for layer normalization.
layernorm_zero_centered_gamma (bool): Whether to zero-center the gamma parameter in layer normalization.
activation_func (Callable): Activation function used in the model.
init_method_std (float): Standard deviation for weight initialization.
apply_query_key_layer_scaling (bool): Whether to apply scaling to query and key layers.
masked_softmax_fusion (bool): Whether to use a kernel that fuses attention softmax with its mask.
fp16_lm_cross_entropy (bool): Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
share_embeddings_and_output_weights (bool): Whether to share embeddings and output weights.
enable_autocast (bool): Whether to enable autocast for mixed precision.
biobert_spec_option (BiobertSpecOption): BiobertSpecOption for the model.
position_embedding_type (PositionEmbeddingKinds): Type of position embedding used in the model.
seq_length (int): Length of the input sequence.
make_vocab_size_divisible_by (int): Make the vocabulary size divisible by this value.
token_dropout (bool): Whether to apply token dropout.
use_attention_mask (bool): Whether to use attention mask.
use_esm_attention (bool): Whether to use ESM attention.
attention_softmax_in_fp32 (bool): Whether to use fp32 for attention softmax.
optimizer_fn (Optional[Callable[[MegatronBioBertModel], Optimizer]]): Optional optimizer function for the model.
parallel_output (bool): Whether to use parallel output.
rotary_base (int): Base value for rotary positional encoding.
rotary_percent (float): Fraction of the rotary dimension used for rotary positional encoding.
seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length.
get_attention_mask_from_fusion (bool): Whether to get attention mask from fusion.
nemo1_ckpt_path (str | None): Path to NEMO1 checkpoint.
return_only_hidden_states (bool): Whether to return only hidden states.
loss_reduction_class (Type[MegatronLossType]): Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.
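
Several of these attributes are tied together by simple arithmetic. The sketch below works through it with the 650M defaults from the class source on this page (hidden_size 1280, 20 attention heads, make_vocab_size_divisible_by 128). The raw vocabulary size of 33 and the helper pad_vocab_size are assumptions made for illustration; the padding rule shown (round up to the next multiple) is the usual Megatron-style convention, not something this page confirms.

```python
import math

# Defaults taken from the ESM2GenericConfig source on this page.
hidden_size = 1280
num_attention_heads = 20
make_vocab_size_divisible_by = 128

# Feed-forward width follows the "4 * hidden_size" convention.
ffn_hidden_size = 4 * hidden_size

# Each attention head works on an equal slice of the hidden dimension.
head_dim = hidden_size // num_attention_heads


def pad_vocab_size(vocab_size: int, multiple: int) -> int:
    """Round vocab_size up to the next multiple (hypothetical helper)."""
    return math.ceil(vocab_size / multiple) * multiple


raw_vocab_size = 33  # assumption: size of the ESM-2 amino-acid vocabulary
padded_vocab_size = pad_vocab_size(raw_vocab_size, make_vocab_size_divisible_by)

print(ffn_hidden_size)    # 5120
print(head_dim)           # 64
print(padded_vocab_size)  # 128
```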

Source code in bionemo/esm2/model/model.py
@dataclass
class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
    """Configuration class for ESM2 model.

    Attributes:
        num_layers: Number of layers in the model.
        hidden_size: Hidden size of the model.
        num_attention_heads: Number of attention heads in the model.
        ffn_hidden_size: Hidden size of the feed-forward network.
        hidden_dropout: Dropout rate for hidden layers.
        attention_dropout: Dropout rate for attention layers.
        apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization.
        layernorm_epsilon: Epsilon value for layer normalization.
        layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization.
        activation_func: Activation function used in the model.
        init_method_std: Standard deviation for weight initialization.
        apply_query_key_layer_scaling: Whether to apply scaling to query and key layers.
        masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask.
        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
        enable_autocast: Whether to enable autocast for mixed precision.
        biobert_spec_option: BiobertSpecOption for the model.
        position_embedding_type: Type of position embedding used in the model.
        seq_length: Length of the input sequence.
        make_vocab_size_divisible_by: Make the vocabulary size divisible by this value.
        token_dropout: Whether to apply token dropout.
        use_attention_mask: Whether to use attention mask.
        use_esm_attention: Whether to use ESM attention.
        attention_softmax_in_fp32: Whether to use fp32 for attention softmax.
        optimizer_fn: Optional optimizer function for the model.
        parallel_output: Whether to use parallel output.
        rotary_base: Base value for rotary positional encoding.
        rotary_percent: Percentage of rotary positional encoding.
        seq_len_interpolation_factor: Interpolation factor for sequence length.
        get_attention_mask_from_fusion: Whether to get attention mask from fusion.
        nemo1_ckpt_path: Path to NEMO1 checkpoint.
        return_only_hidden_states: Whether to return only hidden states.
        loss_reduction_class: Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.
    """

    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
    model_cls: Type[ESM2ModelT] = ESM2Model
    num_layers: int = 33  # 650M
    hidden_size: int = 1280  # 650M
    num_attention_heads: int = 20
    ffn_hidden_size: int = 4 * 1280  # Transformer FFN hidden size. Usually 4 * hidden_size.
    hidden_dropout: float = 0  # ESM2 removes dropout from hidden layers and attention
    attention_dropout: float = 0.0  # ESM2 does not use attention dropout
    apply_residual_connection_post_layernorm: bool = False  # TODO: farhadr False is new default, True was BERT pub.
    layernorm_epsilon: float = 1.0e-5
    bias_activation_fusion: bool = True  # True degrades accuracy slightly, but is faster.
    activation_func: Callable = F.gelu  # esm_gelu_func  # ESM2 MLP
    init_method_std: float = 0.02

    # embedding
    token_dropout: bool = True
    use_attention_mask: bool = True

    # core attention
    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
    attention_softmax_in_fp32: bool = False
    normalize_attention_scores: bool = False

    # From megatron.core.models.gpt.bert_model.GPTModel
    fp16_lm_cross_entropy: bool = False  # Move the cross entropy unreduced loss calculation for lm head to fp16
    parallel_output: bool = True
    share_embeddings_and_output_weights: bool = True
    make_vocab_size_divisible_by: int = 128
    position_embedding_type: PositionEmbeddingKinds = "rope"  # ESM2 uses relative positional encoding 'ROPE' to extrapolate to longer sequences unseen during training
    rotary_base: int = 10000
    rotary_percent: float = 1.0
    seq_len_interpolation_factor: Optional[float] = None
    seq_length: int = 1024
    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec

    optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None
    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
    #  support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
    nemo1_ckpt_path: str | None = None
    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
    #  self.override_parent_fields will be loaded from the checkpoint and override those values here.
    initial_ckpt_path: str | None = None
    # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
    #  things as part of the workflow for inference and fine-tuning.
    return_embeddings: bool = False
    include_embeddings: bool = False
    skip_logits: bool = False
    return_only_hidden_states: bool = False  # return logits

    def __post_init__(self):
        # TODO, as a validator?
        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
        super().__post_init__()
        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
            self.apply_query_key_layer_scaling = False
            self.core_attention_override = ESM2TEDotProductAttention
        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
            logging.warning(
                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
            )
            self.apply_query_key_layer_scaling = True
            self.core_attention_override = ESM2DotProductAttention
        else:
            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")

__post_init__()

Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.

Source code in bionemo/esm2/model/model.py
def __post_init__(self):
    # TODO, as a validator?
    """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
    super().__post_init__()
    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
        self.apply_query_key_layer_scaling = False
        self.core_attention_override = ESM2TEDotProductAttention
    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
        logging.warning(
            "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
        )
        self.apply_query_key_layer_scaling = True
        self.core_attention_override = ESM2DotProductAttention
    else:
        raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
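
The pattern used by __post_init__ (derive dependent settings from one option, and reject unknown options) can be sketched without any BioNeMo imports. SpecOption and SketchConfig below are hypothetical stand-ins, and the string values stored in core_attention_override are placeholders for the real attention classes.

```python
import logging
from dataclasses import dataclass, field
from enum import Enum


class SpecOption(Enum):
    """Hypothetical stand-in for BiobertSpecOption (names shortened)."""

    transformer_engine = "transformer_engine"
    local = "local"


@dataclass
class SketchConfig:
    """Hypothetical config that derives two attributes from one option."""

    spec_option: object = SpecOption.transformer_engine
    # Derived in __post_init__, so they are excluded from __init__.
    apply_query_key_layer_scaling: bool = field(init=False, default=False)
    core_attention_override: str = field(init=False, default="")

    def __post_init__(self) -> None:
        if self.spec_option == SpecOption.transformer_engine:
            self.apply_query_key_layer_scaling = False
            self.core_attention_override = "te_attention"
        elif self.spec_option == SpecOption.local:
            logging.warning("The local spec is deprecated; prefer the Transformer Engine spec.")
            self.apply_query_key_layer_scaling = True
            self.core_attention_override = "local_attention"
        else:
            raise ValueError(f"Unknown spec_option: {self.spec_option}")


te_config = SketchConfig()
local_config = SketchConfig(spec_option=SpecOption.local)
print(te_config.apply_query_key_layer_scaling)     # False
print(local_config.apply_query_key_layer_scaling)  # True
```

Because the derived values are assigned after initialization, whatever the caller might have wanted for apply_query_key_layer_scaling is overwritten by the chosen spec option.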

ESM2Model

Bases: MegatronBioBertModel

ESM2 Transformer language model.

Source code in bionemo/esm2/model/model.py
class ESM2Model(MegatronBioBertModel):
    """ESM2 Transformer language model."""

    def __init__(
        self,
        config: TransformerConfig,
        num_tokentypes: int,
        transformer_layer_spec: spec_utils.ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        tokenizer: Optional[BioNeMoESMTokenizer] = None,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        add_binary_head: bool = True,
        return_embeddings: bool = False,
        include_embeddings: bool = False,
        use_full_attention_mask: bool = False,
        include_hiddens: bool = False,
        skip_logits: bool = False,
    ) -> None:
        """Initialize the ESM2 model.

        Args:
            config (TransformerConfig): transformer config
            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise.
            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
            vocab_size (int): vocabulary size
            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
            tokenizer (Optional[BioNeMoESMTokenizer]): optional tokenizer object (currently only used in the constructor of ESM2Model). Its mask_token_id is read when pre_process is True, so it must be provided in that case.
            pre_process (bool): Include embedding layer (used with pipeline parallelism)
            post_process (bool): Include an output layer (used with pipeline parallelism)
            fp16_lm_cross_entropy (bool): Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
                Default is 'learned_absolute'.
            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
            add_binary_head (bool): Whether to add a binary head. Defaults to True.
            return_embeddings (bool): Whether to return embeddings. Defaults to False.
            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
            skip_logits (bool): Skip writing the token logits in output dict
        """
        super(MegatronBioBertModel, self).__init__(config=config)
        self.post_process = post_process
        self.add_binary_head = add_binary_head
        if return_embeddings:
            assert self.post_process, "only return embeddings on the last pipeline stage"
        # `b` = batch, `s` = sequence.
        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
        #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
        self.use_full_attention_mask = use_full_attention_mask
        self.config: TransformerConfig = config
        self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type
        self.add_binary_head = add_binary_head
        self.return_embeddings = return_embeddings
        self.include_embeddings = include_embeddings
        self.include_hiddens = include_hiddens
        self.skip_logits = skip_logits

        # megatron core pipelining currently depends on model type
        self.model_type = ModelType.encoder_or_decoder

        # Embeddings.
        if self.pre_process:
            # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
            # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
            # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
            self.embedding = ESM2Embedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                num_tokentypes=num_tokentypes,
                # ESM2 NEW ARGS
                token_dropout=self.config.token_dropout,
                use_attention_mask=self.config.use_attention_mask,
                mask_token_id=tokenizer.mask_token_id,
            )

        if self.position_embedding_type == "rope":
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
            )

        # Transformer.
        self.encoder = TransformerBlock(
            config=self.config,
            spec=self.transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            # TODO: Make sure you are passing in the mpu_vocab_size properly
            self.lm_head = BertLMHead(
                config.hidden_size,
                config,
            )

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=True,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
            )

            self.binary_head = None
            if self.add_binary_head:
                # TODO: Should switch this to TE ?
                self.binary_head = get_linear_layer(
                    config.hidden_size, 2, config.init_method, config.perform_initialization
                )

                self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

    def embedding_forward(
        self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
    ):
        """Forward pass of the embedding layer.

        Args:
            input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
            position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
            tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
            attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

        Returns:
            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
        """
        # ESM2 Customization: ESM2Embedding forward takes attention_mask
        # in addition to the args required by LanguageModelEmbedding
        return self.embedding(
            input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
        )
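
The constructor above builds different submodules depending on pre_process, post_process, add_binary_head, and position_embedding_type. The sketch below mirrors only that branching and returns attribute names as strings. modules_for_stage and stage_flags are hypothetical helpers, and mapping the first and last pipeline stage to pre_process and post_process is an assumption based on the parameter descriptions.

```python
from typing import List, Tuple


def stage_flags(stage: int, num_stages: int) -> Tuple[bool, bool]:
    """Assumed convention: first stage embeds, last stage produces outputs."""
    return stage == 0, stage == num_stages - 1


def modules_for_stage(
    pre_process: bool,
    post_process: bool,
    add_binary_head: bool = True,
    position_embedding_type: str = "learned_absolute",
) -> List[str]:
    """List the submodules the constructor would build for these flags."""
    modules = []
    if pre_process:
        modules.append("embedding")
    if position_embedding_type == "rope":
        modules.append("rotary_pos_emb")
    # The transformer block is built on every stage.
    modules.append("encoder")
    if post_process:
        modules.extend(["lm_head", "output_layer"])
        if add_binary_head:
            modules.extend(["binary_head", "pooler"])
    return modules


# Two pipeline stages, rotary embeddings, no binary head.
for stage in range(2):
    pre, post = stage_flags(stage, num_stages=2)
    print(stage, modules_for_stage(pre, post, add_binary_head=False, position_embedding_type="rope"))
```

With a single stage both flags are True, so one model instance holds the embedding, the encoder, and the output heads.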

__init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False)

Initialize the ESM2 model.

Parameters:

config (TransformerConfig): Transformer config. Required.
num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Required.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. Required.
vocab_size (int): Vocabulary size. Required.
max_sequence_length (int): Maximum size of sequence. This is used for positional embedding. Required.
tokenizer (Optional[BioNeMoESMTokenizer]): Optional tokenizer object (currently only used in the constructor of ESM2Model). Its mask_token_id is read when pre_process is True, so it must be provided in that case. Default: None.
pre_process (bool): Include embedding layer (used with pipeline parallelism). Default: True.
post_process (bool): Include an output layer (used with pipeline parallelism). Default: True.
fp16_lm_cross_entropy (bool): Whether to move the cross entropy unreduced loss calculation for lm head to fp16. Default: False.
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. Default: True.
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Default: False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. Default: 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Default: 1.0 (100%).
seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Default: None.
add_binary_head (bool): Whether to add a binary head. Default: True.
return_embeddings (bool): Whether to return embeddings. Default: False.
include_embeddings (bool): Whether to include embeddings in the output dictionary. Default: False.
use_full_attention_mask (bool): Whether to use full attention mask. Default: False.
include_hiddens (bool): Whether to include hidden states in the output dictionary. Default: False.
skip_logits (bool): Skip writing the token logits in output dict. Default: False.

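To make rotary_percent and seq_len_interpolation_factor more concrete, here is a minimal sketch of the standard rotary-embedding frequency computation in plain Python. It is not the RotaryEmbedding class used by the constructor: the function names are hypothetical, and the base of 10000 and head dimension of 64 are taken from the ESM2GenericConfig defaults (rotary_base, and hidden_size divided by num_attention_heads).

```python
from typing import List, Optional


def rotary_inverse_frequencies(
    kv_channels: int, rotary_percent: float = 1.0, rotary_base: int = 10000
) -> List[float]:
    """Inverse frequencies for the rotated part of each attention head."""
    dim = kv_channels
    if rotary_percent < 1.0:
        # Only a fraction of the head dimension receives rotary encoding.
        dim = int(dim * rotary_percent)
    return [1.0 / (rotary_base ** (i / dim)) for i in range(0, dim, 2)]


def rotary_angles(
    position: float,
    inv_freq: List[float],
    seq_len_interpolation_factor: Optional[float] = None,
) -> List[float]:
    """Rotation angles for one position, optionally with interpolation."""
    if seq_len_interpolation_factor is not None:
        # Interpolation compresses positions into the trained range.
        position = position / seq_len_interpolation_factor
    return [position * f for f in inv_freq]


full = rotary_inverse_frequencies(kv_channels=64)
half = rotary_inverse_frequencies(kv_channels=64, rotary_percent=0.5)
print(len(full), len(half))  # 32 16
print(rotary_angles(4, [1.0, 0.5], seq_len_interpolation_factor=2.0))  # [2.0, 1.0]
```

Each inverse frequency drives one pair of channels, so a head dimension of 64 yields 32 frequencies when rotary_percent is 1.0.
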
Source code in bionemo/esm2/model/model.py
def __init__(
    self,
    config: TransformerConfig,
    num_tokentypes: int,
    transformer_layer_spec: spec_utils.ModuleSpec,
    vocab_size: int,
    max_sequence_length: int,
    tokenizer: Optional[BioNeMoESMTokenizer] = None,
    pre_process: bool = True,
    post_process: bool = True,
    fp16_lm_cross_entropy: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
    rotary_percent: float = 1.0,
    seq_len_interpolation_factor: Optional[float] = None,
    add_binary_head: bool = True,
    return_embeddings: bool = False,
    include_embeddings: bool = False,
    use_full_attention_mask: bool = False,
    include_hiddens: bool = False,
    skip_logits: bool = False,
) -> None:
    """Initialize the ESM2 model.

    Args:
        config (TransformerConfig): transformer config
        num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise.
        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
        vocab_size (int): vocabulary size
        max_sequence_length (int): maximum size of sequence. This is used for positional embedding
        tokenizer (Optional[BioNeMoESMTokenizer]): optional tokenizer object (currently only used in the constructor of ESM2Model). Its mask_token_id is read when pre_process is True, so it must be provided in that case.
        pre_process (bool): Include embedding layer (used with pipeline parallelism)
        post_process (bool): Include an output layer (used with pipeline parallelism)
        fp16_lm_cross_entropy (bool): Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
        share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
            Default is 'learned_absolute'.
        rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
            Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
        seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
        add_binary_head (bool): Whether to add a binary head. Defaults to True.
        return_embeddings (bool): Whether to return embeddings. Defaults to False.
        include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
        use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
        include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
        skip_logits (bool): Skip writing the token logits in output dict
    """
    super(MegatronBioBertModel, self).__init__(config=config)
    self.post_process = post_process
    self.add_binary_head = add_binary_head
    if return_embeddings:
        assert self.post_process, "only return embeddings on the last pipeline stage"
    # `b` = batch, `s` = sequence.
    # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
    #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
    self.use_full_attention_mask = use_full_attention_mask
    self.config: TransformerConfig = config
    self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
    self.vocab_size = vocab_size
    self.max_sequence_length = max_sequence_length
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
    self.parallel_output = parallel_output
    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
    self.position_embedding_type = position_embedding_type
    self.add_binary_head = add_binary_head
    self.return_embeddings = return_embeddings
    self.include_embeddings = include_embeddings
    self.include_hiddens = include_hiddens
    self.skip_logits = skip_logits

    # megatron core pipelining currently depends on model type
    self.model_type = ModelType.encoder_or_decoder

    # Embeddings.
    if self.pre_process:
        # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
        # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
        # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
        self.embedding = ESM2Embedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=position_embedding_type,
            num_tokentypes=num_tokentypes,
            # ESM2 NEW ARGS
            token_dropout=self.config.token_dropout,
            use_attention_mask=self.config.use_attention_mask,
            mask_token_id=tokenizer.mask_token_id,
        )

    if self.position_embedding_type == "rope":
        self.rotary_pos_emb = RotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
        )

    # Transformer.
    self.encoder = TransformerBlock(
        config=self.config,
        spec=self.transformer_layer_spec,
        pre_process=self.pre_process,
        post_process=self.post_process,
    )

    # Output
    if post_process:
        # TODO: Make sure you are passing in the mpu_vocab_size properly
        self.lm_head = BertLMHead(
            config.hidden_size,
            config,
        )

        self.output_layer = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            self.vocab_size,
            config=config,
            init_method=config.init_method,
            bias=True,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
        )

        self.binary_head = None
        if self.add_binary_head:
            # TODO: Should switch this to TE ?
            self.binary_head = get_linear_layer(
                config.hidden_size, 2, config.init_method, config.perform_initialization
            )

            self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
    if self.pre_process or self.post_process:
        self.setup_embeddings_and_output_layer()

embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)

Forward pass of the embedding layer.

Parameters:

input_ids (Tensor): The input tensor of shape (batch_size, sequence_length) containing the input IDs. Required.
position_ids (Tensor): The tensor of shape (batch_size, sequence_length) containing the position IDs. Required.
tokentype_ids (Tensor): The tensor of shape (batch_size, sequence_length) containing the token type IDs. Default: None.
attention_mask (Tensor): The tensor of shape (batch_size, sequence_length) containing the attention mask. Default: None.

Returns:

Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
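
The embedding layer receives attention_mask because ESM2-style token dropout rescales embeddings by the fraction of masked tokens among the real (non-padding) tokens. This page does not show ESM2Embedding itself, so the sketch below is an assumption modeled on the original ESM-2 reference behavior: embeddings of mask tokens are zeroed and the rest are scaled by (1 - 0.15 * 0.8) / (1 - observed mask ratio). It handles a single sequence with plain lists, and the convention that 1 marks a real token and 0 marks padding is also assumed.

```python
from typing import List


def apply_token_dropout(
    embeddings: List[List[float]],
    token_ids: List[int],
    attention_mask: List[int],
    mask_token_id: int,
    mask_ratio_train: float = 0.15 * 0.8,
) -> List[List[float]]:
    """Zero mask-token embeddings and rescale the others (one sequence)."""
    num_real = sum(attention_mask)
    num_masked = sum(
        1 for tok, keep in zip(token_ids, attention_mask) if keep and tok == mask_token_id
    )
    # Assumes at least one real token that is not a mask token.
    observed_ratio = num_masked / num_real
    scale = (1 - mask_ratio_train) / (1 - observed_ratio)
    output = []
    for vector, tok in zip(embeddings, token_ids):
        if tok == mask_token_id:
            output.append([0.0] * len(vector))
        else:
            output.append([value * scale for value in vector])
    return output


# Three real tokens (one of them a mask token) followed by one padding token.
token_ids = [5, 32, 7, 1]
attention_mask = [1, 1, 1, 0]
embeddings = [[1.0, 2.0] for _ in token_ids]
dropped = apply_token_dropout(embeddings, token_ids, attention_mask, mask_token_id=32)
print(dropped[1])  # [0.0, 0.0]
```

Without the attention mask, padding positions would be counted as real tokens and the observed mask ratio would be underestimated for short sequences.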

Source code in bionemo/esm2/model/model.py
def embedding_forward(
    self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
):
    """Forward pass of the embedding layer.

    Args:
        input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
        position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
        tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
        attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

    Returns:
        Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
    """
    # ESM2 Customization: ESM2Embedding forward takes attention_mask
    # in addition to the args required by LanguageModelEmbedding
    return self.embedding(
        input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
    )