Skip to content

Attention

ESM2DotProductAttention

Bases: DotProductAttention

ESM2-Specific core attention.

Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.

We use the following notation:

h: hidden size; n: number of attention heads; p: number of tensor model parallel partitions; b: batch size; s: sequence length

Source code in bionemo/esm2/model/attention.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class ESM2DotProductAttention(DotProductAttention):
    """Core dot-product attention with ESM2-specific masking and softmax.

    This is the region where selective activation recomputation is applied.
    It is memory intensive but comparatively cheap to compute, which makes
    activation checkpointing more efficient for LLMs (20B+). See
    "Reducing Activation Recomputation in Large Transformer Models"
    (https://arxiv.org/abs/2205.05198) for details.

    Notation used in the comments below:
     h: hidden size
     n: number of attention heads
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: Optional[float] = None,
    ) -> None:
        """Build the attention module by delegating to the parent initializer.

        Args:
            config: Configuration object for the transformer.
            layer_number: Layer number of this attention module.
            attn_mask_type: Type of attention mask to use.
            attention_type: Type of attention mechanism.
            attention_dropout: Dropout rate for the attention weights. Defaults to None.
        """
        # Nothing is customised here: every argument is forwarded unchanged, by keyword.
        super().__init__(
            config=config, layer_number=layer_number, attn_mask_type=attn_mask_type,
            attention_type=attention_type, attention_dropout=attention_dropout,
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: Optional[AttnMaskType] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
    ):
        """Compute the attention context for the given query, key and value.

        Args:
            query: Query tensor of shape [sq, b, np, hn].
            key: Key tensor of shape [sk, b, ng, hn].
            value: Value tensor of shape [sk, b, ng, hn].
            attention_mask: Attention mask tensor of shape [b, np, sq, sk].
            attn_mask_type: Attention mask type; accepted but not used by this method. Defaults to None.
            packed_seq_params: Packed sequence parameters. Not supported here; must be None.

        Returns:
            Tensor: The context tensor of shape [sq, b, hp].

        Raises:
            ValueError: If ``packed_seq_params`` is not None.
        """
        if packed_seq_params is not None:
            raise ValueError(
                "Packed sequence is not supported by DotProductAttention. Please use TEDotProductAttention instead."
            )

        # ------------------------------------------------------------------
        # Raw attention scores: [b, n/p, s, s]
        # ------------------------------------------------------------------

        # With grouped-query attention (ng < np) the keys and values are repeated along the
        # head dimension so that they line up with the queries: [sk, b, ng, hn] -> [sk, b, np, hn].
        # When ng == np the ratio is 1 and nothing is repeated.
        heads_per_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition
        if heads_per_group > 1:
            key = key.repeat_interleave(heads_per_group, dim=2)
            value = value.repeat_interleave(heads_per_group, dim=2)

        q_len, batch, heads = query.size(0), query.size(1), query.size(2)
        k_len = key.size(0)
        flat_heads = batch * heads

        # Fold batch and head dimensions together for the batched matmul.
        # [sq, b, np, hn] -> [sq, b * np, hn]; reshape (rather than view) is used for the query.
        query = query.reshape(q_len, flat_heads, -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key = key.view(k_len, flat_heads, -1)

        # Pre-allocated buffer passed as the baddbmm input: [b * np, sq, sk].
        score_buffer = parallel_state.get_global_memory_buffer().get_tensor(
            (flat_heads, q_len, k_len), query.dtype, "mpu"
        )

        # Scores are divided by norm_factor only when the config asks for normalisation.
        if self.config.normalize_attention_scores:
            score_scale = 1.0 / self.norm_factor
        else:
            score_scale = 1.0

        # beta=0.0 means the buffer's contents do not contribute to the result.
        raw_scores = torch.baddbmm(
            score_buffer,
            query.transpose(0, 1),  # [b * np, sq, hn]
            key.permute(1, 2, 0),  # [b * np, hn, sk]
            beta=0.0,
            alpha=score_scale,
        )

        # [b * np, sq, sk] -> [b, np, sq, sk]
        attention_scores = raw_scores.view(batch, heads, q_len, k_len)

        # ------------------------------------------------------------------
        # Attention probabilities and dropout
        # ------------------------------------------------------------------

        if self.config.use_esm_attention:
            # ESM2 customisation. Only the first query row of the mask is kept ([b, np, 1, sk]) and
            # broadcast over all query positions, mirroring ESM2's extended attention mask. Masked
            # scores are filled with the most negative finite value of the score dtype so that their
            # softmax weight is close to 0; a finite value (instead of -inf) stays stable in the
            # special case where the whole sequence is masked.
            fill_value = torch.finfo(attention_scores.dtype).min
            first_row_mask = attention_mask[:, :, 0:1, :].to(bool)
            attention_probs: Tensor = self.esm2_scale_mask_softmax(
                attention_scores.masked_fill(first_row_mask, fill_value)
            )
        else:
            attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # Dropout here removes entire tokens to attend to, which might seem a bit unusual,
        # but is taken from the original Transformer paper.
        if self.config.sequence_parallel:
            attention_probs = self.attention_dropout(attention_probs)
        else:
            # Without sequence parallelism, dropout runs inside a fork of the CUDA RNG tracker.
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

        # ------------------------------------------------------------------
        # Context layer: [sq, b, hp]
        # ------------------------------------------------------------------

        # Sizes are re-read from value ([sk, b, np, hn]); the query length is unchanged.
        batch, heads, head_dim = value.size(1), value.size(2), value.size(3)

        # [sk, b, np, hn] -> [sk, b * np, hn]
        value = value.view(value.size(0), batch * heads, -1)
        # [b, np, sq, sk] -> [b * np, sq, sk]
        attention_probs = attention_probs.view(batch * heads, q_len, -1)

        # [b * np, sq, sk] x [b * np, sk, hn] -> [b * np, sq, hn]
        context = torch.bmm(attention_probs, value.transpose(0, 1))

        # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
        context = context.view(batch, heads, q_len, head_dim).permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] -> [sq, b, hp]
        return context.view(q_len, batch, self.hidden_size_per_partition)

    def esm2_scale_mask_softmax(
        self,
        input: Tensor,
        mask: Optional[Tensor] = None,
        scale: Optional[Union[float, int]] = None,
        mask_func: Optional[Callable] = None,
    ) -> Tensor:
        """Optionally scale and mask the scores, then softmax over the last dimension.

        Args:
            input: Score tensor; ``forward`` passes it with shape (Batch, NP, SQ, SK). It may or
                may not already have a mask applied.
            mask: Mask handed to ``mask_func``. It is ignored when ``mask_func`` is None.
            scale: Optional factor the scores are multiplied by before the softmax.
            mask_func: Optional callable invoked as ``mask_func(input, mask)``. If None, the input
                is assumed to be masked already.

        Returns:
            probs: Normalized probabilities with the same shape as ``input``. Half-precision inputs
                are returned in their original dtype even when the softmax ran in float32.

        Raises:
            ValueError: If ``self.attn_mask_type`` is not the 'padding' type.
        """
        if self.attn_mask_type.name != "padding":
            raise ValueError(
                f"self.attn_mask_type: {self.attn_mask_type} is not 'padding'. "
                "Only 'padding' type is supported currently."
            )

        input_dtype = input.dtype
        softmax_in_fp32 = self.config.attention_softmax_in_fp32
        # Upcast only half-precision inputs, and only when the config requests an fp32 softmax.
        upcast = softmax_in_fp32 and input_dtype in (torch.float16, torch.bfloat16)

        scores = input.float() if upcast else input

        if scale is not None:
            scores = scores * scale

        # The mask is applied only when both the mask and the masking function are given.
        if not (mask is None or mask_func is None):
            scores = mask_func(scores, mask)

        probs = torch.nn.functional.softmax(scores, dim=-1)

        # Undo the upcast so callers get back the dtype they passed in.
        return probs.to(input_dtype) if upcast else probs

__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None)

Initializes the Attention class.

Parameters:

Name Type Description Default
config TransformerConfig

The configuration object for the transformer.

required
layer_number int

The layer number of the attention module.

required
attn_mask_type AttnMaskType

The type of attention mask to be used.

required
attention_type str

The type of attention mechanism.

required
attention_dropout Optional[float]

The dropout rate for attention weights. Defaults to None.

None
Source code in bionemo/esm2/model/attention.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def __init__(
    self,
    config: TransformerConfig,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    attention_dropout: Optional[float] = None,
) -> None:
    """Build the attention module by delegating to the parent initializer.

    Args:
        config: Configuration object for the transformer.
        layer_number: Layer number of this attention module.
        attn_mask_type: Type of attention mask to use.
        attention_type: Type of attention mechanism.
        attention_dropout: Dropout rate for the attention weights. Defaults to None.
    """
    # Nothing is customised here: every argument is forwarded unchanged, by keyword.
    super().__init__(
        config=config, layer_number=layer_number, attn_mask_type=attn_mask_type,
        attention_type=attention_type, attention_dropout=attention_dropout,
    )

esm2_scale_mask_softmax(input, mask=None, scale=None, mask_func=None)

Scale Mask Softmax function.

Parameters:

Name Type Description Default
input Tensor

Tensor of shape (Batch, NP, SQ, SK). The input may or may not have already had a mask applied to it.

required
mask Optional[Tensor]

If a mask is to be applied, it will go here.

None
scale Optional[Union[float, int]]

A scale factor that will be applied before the softmax.

None
mask_func Optional[Callable]

An optional function that is called with the input and the mask to apply the mask. If None, it is assumed that the input already had the mask applied to it.

None

Returns:

Name Type Description
probs Tensor

Tensor of normalized probabilities after the softmax has been applied, of shape (Batch, NP, SQ, SK).

Source code in bionemo/esm2/model/attention.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def esm2_scale_mask_softmax(
    self,
    input: Tensor,
    mask: Optional[Tensor] = None,
    scale: Optional[Union[float, int]] = None,
    mask_func: Optional[Callable] = None,
) -> Tensor:
    """Optionally scale and mask the scores, then softmax over the last dimension.

    Args:
        input: Score tensor. It may or may not already have a mask applied.
        mask: Mask handed to ``mask_func``. It is ignored when ``mask_func`` is None.
        scale: Optional factor the scores are multiplied by before the softmax.
        mask_func: Optional callable invoked as ``mask_func(input, mask)``. If None, the input
            is assumed to be masked already.

    Returns:
        probs: Normalized probabilities with the same shape as ``input``. Half-precision inputs
            are returned in their original dtype even when the softmax ran in float32.

    Raises:
        ValueError: If ``self.attn_mask_type`` is not the 'padding' type.
    """
    if self.attn_mask_type.name != "padding":
        raise ValueError(
            f"self.attn_mask_type: {self.attn_mask_type} is not 'padding'. "
            "Only 'padding' type is supported currently."
        )

    input_dtype = input.dtype
    softmax_in_fp32 = self.config.attention_softmax_in_fp32
    # Upcast only half-precision inputs, and only when the config requests an fp32 softmax.
    upcast = softmax_in_fp32 and input_dtype in (torch.float16, torch.bfloat16)

    scores = input.float() if upcast else input

    if scale is not None:
        scores = scores * scale

    # The mask is applied only when both the mask and the masking function are given.
    if not (mask is None or mask_func is None):
        scores = mask_func(scores, mask)

    probs = torch.nn.functional.softmax(scores, dim=-1)

    # Undo the upcast so callers get back the dtype they passed in.
    return probs.to(input_dtype) if upcast else probs

forward(query, key, value, attention_mask, attn_mask_type=None, packed_seq_params=None)

Forward pass of the ESM2DotProductAttention module.

Parameters:

Name Type Description Default
query Tensor

The query tensor of shape [sq, b, np, hn].

required
key Tensor

The key tensor of shape [sk, b, ng, hn].

required
value Tensor

The value tensor of shape [sk, b, ng, hn].

required
attention_mask Tensor

The attention mask tensor of shape [b, np, sq, sk].

required
attn_mask_type Optional[AttnMaskType]

The attention mask type, currently unused. Defaults to None.

None
packed_seq_params Optional[PackedSeqParams]

The packed sequence parameters. These are used for context parallelism, so support for them would need to be implemented here before context parallelism can be used; passing a non-None value currently raises a ValueError. Defaults to None.

None

Returns:

Name Type Description
Tensor

The context tensor of shape [sq, b, hp].

Source code in bionemo/esm2/model/attention.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
def forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Tensor,
    attn_mask_type: Optional[AttnMaskType] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
):
    """Compute the attention context for the given query, key and value.

    Args:
        query: Query tensor of shape [sq, b, np, hn].
        key: Key tensor of shape [sk, b, ng, hn].
        value: Value tensor of shape [sk, b, ng, hn].
        attention_mask: Attention mask tensor of shape [b, np, sq, sk].
        attn_mask_type: Attention mask type; accepted but not used by this method. Defaults to None.
        packed_seq_params: Packed sequence parameters. Not supported here; must be None.

    Returns:
        Tensor: The context tensor of shape [sq, b, hp].

    Raises:
        ValueError: If ``packed_seq_params`` is not None.
    """
    if packed_seq_params is not None:
        raise ValueError(
            "Packed sequence is not supported by DotProductAttention. Please use TEDotProductAttention instead."
        )

    # ------------------------------------------------------------------
    # Raw attention scores: [b, n/p, s, s]
    # ------------------------------------------------------------------

    # With grouped-query attention (ng < np) the keys and values are repeated along the
    # head dimension so that they line up with the queries: [sk, b, ng, hn] -> [sk, b, np, hn].
    # When ng == np the ratio is 1 and nothing is repeated.
    heads_per_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition
    if heads_per_group > 1:
        key = key.repeat_interleave(heads_per_group, dim=2)
        value = value.repeat_interleave(heads_per_group, dim=2)

    q_len, batch, heads = query.size(0), query.size(1), query.size(2)
    k_len = key.size(0)
    flat_heads = batch * heads

    # Fold batch and head dimensions together for the batched matmul.
    # [sq, b, np, hn] -> [sq, b * np, hn]; reshape (rather than view) is used for the query.
    query = query.reshape(q_len, flat_heads, -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key = key.view(k_len, flat_heads, -1)

    # Pre-allocated buffer passed as the baddbmm input: [b * np, sq, sk].
    score_buffer = parallel_state.get_global_memory_buffer().get_tensor(
        (flat_heads, q_len, k_len), query.dtype, "mpu"
    )

    # Scores are divided by norm_factor only when the config asks for normalisation.
    if self.config.normalize_attention_scores:
        score_scale = 1.0 / self.norm_factor
    else:
        score_scale = 1.0

    # beta=0.0 means the buffer's contents do not contribute to the result.
    raw_scores = torch.baddbmm(
        score_buffer,
        query.transpose(0, 1),  # [b * np, sq, hn]
        key.permute(1, 2, 0),  # [b * np, hn, sk]
        beta=0.0,
        alpha=score_scale,
    )

    # [b * np, sq, sk] -> [b, np, sq, sk]
    attention_scores = raw_scores.view(batch, heads, q_len, k_len)

    # ------------------------------------------------------------------
    # Attention probabilities and dropout
    # ------------------------------------------------------------------

    if self.config.use_esm_attention:
        # ESM2 customisation. Only the first query row of the mask is kept ([b, np, 1, sk]) and
        # broadcast over all query positions, mirroring ESM2's extended attention mask. Masked
        # scores are filled with the most negative finite value of the score dtype so that their
        # softmax weight is close to 0; a finite value (instead of -inf) stays stable in the
        # special case where the whole sequence is masked.
        fill_value = torch.finfo(attention_scores.dtype).min
        first_row_mask = attention_mask[:, :, 0:1, :].to(bool)
        attention_probs: Tensor = self.esm2_scale_mask_softmax(
            attention_scores.masked_fill(first_row_mask, fill_value)
        )
    else:
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

    # Dropout here removes entire tokens to attend to, which might seem a bit unusual,
    # but is taken from the original Transformer paper.
    if self.config.sequence_parallel:
        attention_probs = self.attention_dropout(attention_probs)
    else:
        # Without sequence parallelism, dropout runs inside a fork of the CUDA RNG tracker.
        with tensor_parallel.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

    # ------------------------------------------------------------------
    # Context layer: [sq, b, hp]
    # ------------------------------------------------------------------

    # Sizes are re-read from value ([sk, b, np, hn]); the query length is unchanged.
    batch, heads, head_dim = value.size(1), value.size(2), value.size(3)

    # [sk, b, np, hn] -> [sk, b * np, hn]
    value = value.view(value.size(0), batch * heads, -1)
    # [b, np, sq, sk] -> [b * np, sq, sk]
    attention_probs = attention_probs.view(batch * heads, q_len, -1)

    # [b * np, sq, sk] x [b * np, sk, hn] -> [b * np, sq, hn]
    context = torch.bmm(attention_probs, value.transpose(0, 1))

    # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
    context = context.view(batch, heads, q_len, head_dim).permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] -> [sq, b, hp]
    return context.view(q_len, batch, self.hidden_size_per_partition)

ESM2TEDotProductAttention

Bases: TEDotProductAttention

ESM2-Specific transformer engine core attention.

Override the softmax_scale to 1.0 to match the ESM2 implementation while keeping the rest from the original TEDotProductAttention.

Source code in bionemo/esm2/model/attention.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class ESM2TEDotProductAttention(TEDotProductAttention):
    """ESM2-Specific transformer engine core attention.

    Override the softmax_scale to 1.0 to match the ESM2 implementation while keeping the rest from the original
    TEDotProductAttention.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: float | None = None,
    ):
        """Initialize ESM2TEDotProductAttention.

        Args:
            config: The configuration object for the transformer.
            layer_number: The layer number of the attention module.
            attn_mask_type: The type of attention mask; its ``name`` is passed to the parent initializer.
            attention_type: The type of attention mechanism; only forwarded for Transformer Engine >= 0.10.0.
            attention_dropout: The dropout rate for attention weights. When None, ``config.attention_dropout``
                is used instead.

        Raises:
            ValueError: If ``config.apply_query_key_layer_scaling`` disagrees with the
                ``NVTE_APPLY_QK_LAYER_SCALING`` environment variable, or if grouped query attention is
                requested with Transformer Engine < 0.11.0.
            RuntimeError: If ``config.deterministic_mode`` is on while ``NVTE_ALLOW_NONDETERMINISTIC_ALGO``
                is not 0.
            AssertionError: If context parallelism is requested with Transformer Engine < 1.0.0, or a
                sliding window is requested with Transformer Engine < 1.2.0.
        """
        self.config = config
        self.te_forward_mask_type = False
        self.qkv_format: str = "sbhd"

        if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
            raise ValueError(
                f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
                f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
                f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
                f"setting query key layer scaling via argument, so these two must match."
            )

        # Keyword arguments that are only passed when the installed Transformer Engine version is new enough.
        extra_kwargs = {}
        if _te_version >= packaging.version.Version("0.11.0"):
            extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
        elif self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(
                f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
                f"use a newer version of Transformer Engine. "
                f"(num_query_groups ({self.config.num_query_groups}) != "
                f"num_attention_heads ({self.config.num_attention_heads}))"
            )

        if _te_version >= packaging.version.Version("0.10.0"):
            extra_kwargs["attention_type"] = attention_type
            # older version don't need attention_type

        if _te_version > packaging.version.Version("0.12.0"):
            self.te_forward_mask_type = True

        # Only Transformer-Engine version >= 1.0.0 supports context parallelism
        if _te_version >= packaging.version.Version("1.0.0"):
            # The CUDA stream is stored on the TEDotProductAttention class and created only once.
            if TEDotProductAttention.cp_stream is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
        else:
            assert (
                self.config.context_parallel_size == 1
            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

        # The environment variable defaults to "1" (non-deterministic algorithms allowed) when unset.
        if self.config.deterministic_mode and int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
            raise RuntimeError(
                "deterministic_mode is on and we are using DotProductAttention from "
                "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
                f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
            )

        if config.window_size is not None:
            # Check version
            assert _te_version >= packaging.version.Version("1.2.0"), (
                f"Transformer-Engine version ({_te_version}) must be >= 1.2.0 to support "
                "sliding window attention."
            )
            extra_kwargs["window_size"] = config.window_size

        # super(TEDotProductAttention, self) starts the method lookup *after* TEDotProductAttention in the
        # MRO, so TEDotProductAttention.__init__ is deliberately skipped and its parent's initializer is
        # called directly. This is what allows softmax_scale to be passed explicitly.
        super(TEDotProductAttention, self).__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.kv_channels,
            attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
            attn_mask_type=attn_mask_type.name,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            layer_number=layer_number,
            softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
            **extra_kwargs,
        )

__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None)

Initialize ESM2TEDotProductAttention.

Source code in bionemo/esm2/model/attention.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def __init__(
    self,
    config: TransformerConfig,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    attention_dropout: float | None = None,
):
    """Initialize ESM2TEDotProductAttention.

    Args:
        config: The configuration object for the transformer.
        layer_number: The layer number of the attention module.
        attn_mask_type: The type of attention mask; its ``name`` is passed to the parent initializer.
        attention_type: The type of attention mechanism; only forwarded for Transformer Engine >= 0.10.0.
        attention_dropout: The dropout rate for attention weights. When None, ``config.attention_dropout``
            is used instead.

    Raises:
        ValueError: If ``config.apply_query_key_layer_scaling`` disagrees with the
            ``NVTE_APPLY_QK_LAYER_SCALING`` environment variable, or if grouped query attention is
            requested with Transformer Engine < 0.11.0.
        RuntimeError: If ``config.deterministic_mode`` is on while ``NVTE_ALLOW_NONDETERMINISTIC_ALGO``
            is not 0.
        AssertionError: If context parallelism is requested with Transformer Engine < 1.0.0, or a
            sliding window is requested with Transformer Engine < 1.2.0.
    """
    self.config = config
    self.te_forward_mask_type = False
    self.qkv_format: str = "sbhd"

    if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
        raise ValueError(
            f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
            f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
            f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
            f"setting query key layer scaling via argument, so these two must match."
        )

    # Keyword arguments that are only passed when the installed Transformer Engine version is new enough.
    extra_kwargs = {}
    if _te_version >= packaging.version.Version("0.11.0"):
        extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
    elif self.config.num_query_groups != self.config.num_attention_heads:
        raise ValueError(
            f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
            f"use a newer version of Transformer Engine. "
            f"(num_query_groups ({self.config.num_query_groups}) != "
            f"num_attention_heads ({self.config.num_attention_heads}))"
        )

    if _te_version >= packaging.version.Version("0.10.0"):
        extra_kwargs["attention_type"] = attention_type
        # older version don't need attention_type

    if _te_version > packaging.version.Version("0.12.0"):
        self.te_forward_mask_type = True

    # Only Transformer-Engine version >= 1.0.0 supports context parallelism
    if _te_version >= packaging.version.Version("1.0.0"):
        # The CUDA stream is stored on the TEDotProductAttention class and created only once.
        if TEDotProductAttention.cp_stream is None:
            TEDotProductAttention.cp_stream = torch.cuda.Stream()
        extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
        extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
        extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
    else:
        assert (
            self.config.context_parallel_size == 1
        ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

    # The environment variable defaults to "1" (non-deterministic algorithms allowed) when unset.
    if self.config.deterministic_mode and int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
        raise RuntimeError(
            "deterministic_mode is on and we are using DotProductAttention from "
            "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
            f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
        )

    if config.window_size is not None:
        # Check version
        assert _te_version >= packaging.version.Version("1.2.0"), (
            f"Transformer-Engine version ({_te_version}) must be >= 1.2.0 to support "
            "sliding window attention."
        )
        extra_kwargs["window_size"] = config.window_size

    # super(TEDotProductAttention, self) starts the method lookup *after* TEDotProductAttention in the
    # MRO, so TEDotProductAttention.__init__ is deliberately skipped and its parent's initializer is
    # called directly. This is what allows softmax_scale to be passed explicitly.
    super(TEDotProductAttention, self).__init__(
        num_attention_heads=self.config.num_attention_heads,
        kv_channels=self.config.kv_channels,
        attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
        attn_mask_type=attn_mask_type.name,
        sequence_parallel=self.config.sequence_parallel,
        tp_size=self.config.tensor_model_parallel_size,
        get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
        tp_group=get_tensor_model_parallel_group(check_initialized=False),
        layer_number=layer_number,
        softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
        **extra_kwargs,
    )