diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index c803fc10e..e0349f472 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -310,6 +310,7 @@
 qkv_proj: 'remat'
 out_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
+attention_out: 'remat'
 optimizer_memory_host_offload: False
 parameter_memory_host_offload: False
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 40c2404b7..d7a683f71 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -872,6 +872,11 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the mla's key and value projection.",
   )
+  attention_out: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the attention output.",
+  )
+
   optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
   parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")
@@ -2060,6 +2065,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
        "mla_kv",
        "mla_q",
        "qkv_proj",
+       "attention_out",
        "out_proj",
    ]
    self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index d270dc8e5..7d28d45fd 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -1038,6 +1038,7 @@ def __call__(
     # Pass the index_mask to the Attention Op
     out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
       out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
     else:
diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py
index a7afbd465..8f7c63fa4 100644
--- a/src/MaxText/layers/attentions.py
+++ b/src/MaxText/layers/attentions.py
@@ -1132,6 +1132,7 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py
index ef75bfa1b..ec79fed97 100644
--- a/src/MaxText/layers/deepseek_batchsplit.py
+++ b/src/MaxText/layers/deepseek_batchsplit.py
@@ -336,6 +336,7 @@ def mla(
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
   )
+  query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
   key, value = kv_projection(
       inputs,
       positions,
@@ -355,6 +356,8 @@ def mla(
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
   )
+  key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
+  value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
   out = attention_op_fn(
       query,
       key,
@@ -363,7 +366,9 @@ def mla(
       model_mode,
       cached_values=[None, None],
   )
+  out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
   out = dot(out, out_weights, axes=2)
+  out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
   return out
@@ -402,6 +407,7 @@ def query_projection(
       epsilon=epsilon,
       dtype=dtype,
   )
+  low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
   q = dot(low_rank_q, wq_b_weights)
   # Split into non-positional and rotary parts.
@@ -451,6 +457,7 @@ def kv_projection(
       epsilon=kv_norm_epsilon,
       dtype=dtype,
   )
+  low_rank_main = jax.ad_checkpoint.checkpoint_name(low_rank_main, "mla_kv")
   key_rope = jnp.expand_dims(low_rank_rope, axis=2)
   key_rope = yarn(
       key_rope,
@@ -690,6 +697,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
   )
   layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
   layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
+  layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
+  layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
   return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
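Context for reviewers: `jax.ad_checkpoint.checkpoint_name` only tags an intermediate value; whether that value is saved, offloaded, or rematerialized is decided by the remat policy built from config entries such as `attention_out: 'remat'`. The sketch below is not MaxText code (the model function, shapes, and names are made up purely for illustration), but under those assumptions it shows how a named activation is consumed by a `jax.checkpoint` policy.

```python
import jax
import jax.numpy as jnp


def attention_block(x, w):
  """Toy stand-in for an attention layer that tags its output."""
  out = jnp.dot(x, w)
  # Tag the activation so remat policies can refer to it by name.
  out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
  return jax.nn.relu(out)


# Save only tensors tagged "attention_out" from the forward pass; every other
# intermediate is rematerialized during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("attention_out")
remat_block = jax.checkpoint(attention_block, policy=policy)


def loss(w, x):
  return jnp.sum(remat_block(x, w))


grads = jax.grad(loss)(jnp.ones((4, 2)), jnp.ones((8, 4)))
```

Presumably a config value of `'device'` for one of these tensors maps to a policy that keeps that name resident (e.g. via a variant such as `jax.checkpoint_policies.save_and_offload_only_these_names`); either way, the tag string is the contract between the layer code and the policy, which is why the new `attention_out` name is threaded through both the config and the attention layers.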