Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ qkv_proj: 'remat'
out_proj: 'remat'
mla_q: 'remat'
mla_kv: 'remat'
attention_out: 'remat'

optimizer_memory_host_offload: False
parameter_memory_host_offload: False
Expand Down
6 changes: 6 additions & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,11 @@ class RematAndOffload(BaseModel):
RematLocation.REMAT,
description="Remat policy for the mla's key and value projection.",
)
attention_out: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the attention output.",
)

optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")

Expand Down Expand Up @@ -2060,6 +2065,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"mla_kv",
"mla_q",
"qkv_proj",
"attention_out",
"out_proj",
]
self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,7 @@ def __call__(
# Pass the index_mask to the Attention Op
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)

out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
else:
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ def __call__(
bidirectional_mask,
self.sinks,
)
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
if model_mode == MODEL_MODE_PREFILL:
out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
Expand Down
9 changes: 9 additions & 0 deletions src/MaxText/layers/deepseek_batchsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def mla(
qk_nope_head_dim=qk_nope_head_dim,
mscale=mscale,
)
query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
key, value = kv_projection(
inputs,
positions,
Expand All @@ -355,6 +356,8 @@ def mla(
qk_nope_head_dim=qk_nope_head_dim,
num_query_heads=num_query_heads,
)
key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
out = attention_op_fn(
query,
key,
Expand All @@ -363,7 +366,9 @@ def mla(
model_mode,
cached_values=[None, None],
)
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
out = dot(out, out_weights, axes=2)
out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
return out


Expand Down Expand Up @@ -402,6 +407,7 @@ def query_projection(
epsilon=epsilon,
dtype=dtype,
)
low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
q = dot(low_rank_q, wq_b_weights)

# Split into non-positional and rotary parts.
Expand Down Expand Up @@ -451,6 +457,7 @@ def kv_projection(
epsilon=kv_norm_epsilon,
dtype=dtype,
)
low_rank_main = jax.ad_checkpoint.checkpoint_name(low_rank_main, "mla_kv")
key_rope = jnp.expand_dims(low_rank_rope, axis=2)
key_rope = yarn(
key_rope,
Expand Down Expand Up @@ -690,6 +697,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
)
layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
intermediate_layer *= weights[:, None]
return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
Expand Down
Loading