From 3d3927b93f7040567b7fb6129b5e35e1e6f61472 Mon Sep 17 00:00:00 2001
From: Luke Baumann
Date: Tue, 10 Feb 2026 10:34:40 -0800
Subject: [PATCH] This change introduces reshard_with_intermediate_sharding
 which will first look for intermediate shardings, perform all intermediate
 resharding, and then perform the final reshard into the out sharding.

PiperOrigin-RevId: 868214512
---
 pathwaysutils/experimental/reshard.py | 58 +++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py
index ec170a9..7e01350 100644
--- a/pathwaysutils/experimental/reshard.py
+++ b/pathwaysutils/experimental/reshard.py
@@ -26,6 +26,7 @@
 from pathwaysutils import jax as pw_jax
 from pathwaysutils import lru_cache
 from pathwaysutils import plugin_executable
+from pathwaysutils.experimental import split_by_mesh_axis
 
 _logger = logging.getLogger(__name__)
 
@@ -34,6 +35,11 @@
 INTERMEDIATE_REPLICA_SUFFIX = "_intermediate_replica"
 
 
+def _identity(x: Any) -> Any:
+  """Returns `x` unchanged; jitted with `out_shardings` to change sharding."""
+  return x
+
+
 class ReshardingPlanWrapper:
   """Wrapper around PluginProgram(reshard_request)."""
 
@@ -556,3 +562,55 @@ def find_intermediate_sharding(
       memory_kind=in_sharding.memory_kind,
   )
   return intermediate_sharding, replicated_axes
+
+
+def reshard_with_intermediate_sharding(
+    x: Any,
+    in_sharding: jax.sharding.Sharding,
+    out_sharding: jax.sharding.Sharding,
+    *,
+    donate: bool = False,
+    may_alias: bool | None = None,  # pylint: disable=unused-argument
+) -> Any:
+  """Reshards `x` to `out_sharding`, using an intermediate sharding if possible.
+
+  This function is an alternative to `reshard` that may be faster and sometimes
+  essential for certain sharding combinations by using an intermediate sharding
+  to avoid expensive all-gathers. If no beneficial intermediate sharding is
+  found, it falls back to standard resharding. See `find_intermediate_sharding`
+  for more details on when an intermediate sharding is used.
+
+  Args:
+    x: An array, scalar, or (nested) standard Python container thereof.
+    in_sharding: The source sharding of `x`.
+    out_sharding: The target sharding for `x`.
+    donate: If `True`, donate all input arrays, which may reduce the amount of
+      memory needed for resharding. Buffers donated to resharding should not be
+      reused.
+    may_alias: Forwarded to `reshard`. If `True`, the input array may be aliased
+      with the output array to save memory. Not used at the moment.
+
+  Returns:
+    A copy of `x` whose sharding is `out_sharding`.
+  """
+
+  try:
+    intermediate_sharding, replicated_axes_names = find_intermediate_sharding(
+        in_sharding, out_sharding
+    )
+  except NoIntermediateShardingError as e:
+    _logger.debug("No intermediate sharding needed or found. %s", e)
+    x_to_reshard = x  # No staging needed; reshard directly below.
+  else:  # Stage through the intermediate sharding first.
+    x_to_reshard = jax.jit(
+        _identity,
+        out_shardings=intermediate_sharding,
+    )(x)  # NOTE(review): `donate` is not applied to this jit step -- confirm.
+    for split_axis in replicated_axes_names:  # One split per replicated axis.
+      x_to_reshard, *_ = split_by_mesh_axis.split_by_mesh_axis(
+          x_to_reshard,
+          split_axis,
+          donate=donate,
+      )  # Only the first returned element is kept.
+
+  return reshard(x_to_reshard, out_sharding, donate=donate, may_alias=may_alias)