Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pathwaysutils import jax as pw_jax
from pathwaysutils import lru_cache
from pathwaysutils import plugin_executable
from pathwaysutils.experimental import split_by_mesh_axis


_logger = logging.getLogger(__name__)
Expand All @@ -34,6 +35,11 @@
INTERMEDIATE_REPLICA_SUFFIX = "_intermediate_replica"


def _identity(x: Any) -> Any:
"""A helper function that returns its input."""
return x


class ReshardingPlanWrapper:
"""Wrapper around PluginProgram(reshard_request)."""

Expand Down Expand Up @@ -556,3 +562,55 @@ def find_intermediate_sharding(
memory_kind=in_sharding.memory_kind,
)
return intermediate_sharding, replicated_axes


def reshard_with_intermediate_sharding(
    x: Any,
    in_sharding: jax.sharding.Sharding,
    out_sharding: jax.sharding.Sharding,
    *,
    donate: bool = False,
    may_alias: bool | None = None,
) -> Any:
  """Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

  This function is an alternative to `reshard` that may be faster, and is
  sometimes essential, for certain sharding combinations: it uses an
  intermediate sharding to avoid expensive all-gathers. If no beneficial
  intermediate sharding is found, it falls back to standard resharding. See
  `find_intermediate_sharding` for more details on when an intermediate
  sharding is used.

  When an intermediate sharding is found, the steps are:
    1. `x` is placed under the intermediate sharding via a jitted identity.
    2. For every replicated mesh axis reported by `find_intermediate_sharding`,
       the result is split with `split_by_mesh_axis` and only the first
       returned piece is kept.
    3. The result is passed to `reshard` to reach `out_sharding`.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    in_sharding: The source sharding of `x`.
    out_sharding: The target sharding for `x`.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. May
      reduce the amount of memory needed for resharding. Forwarded to `reshard`
      as-is; not used at the moment.

  Returns:
    A copy of `x` whose sharding is `out_sharding`.
  """
  try:
    intermediate_sharding, replicated_axes_names = find_intermediate_sharding(
        in_sharding, out_sharding
    )
  except NoIntermediateShardingError as e:
    # No intermediate step applies: reshard `x` directly below.
    _logger.debug("No intermediate sharding needed or found. %s", e)
    x_to_reshard = x
  else:
    # NOTE(review): `donate` is not applied to this jitted identity, so `x`
    # itself is not donated on this path; only the intermediate arrays below
    # are passed on with `donate`. Confirm whether that is intentional.
    x_to_reshard = jax.jit(
        _identity,
        out_shardings=intermediate_sharding,
    )(x)
    for split_axis in replicated_axes_names:
      # Only the first element returned by `split_by_mesh_axis` is kept; the
      # remaining ones are discarded.
      x_to_reshard, *_ = split_by_mesh_axis.split_by_mesh_axis(
          x_to_reshard,
          split_axis,
          donate=donate,
      )

  return reshard(x_to_reshard, out_sharding, donate=donate, may_alias=may_alias)
Loading