diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f40e05ea0c..6ed5b9d0de 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -32,9 +32,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, ShufflePartitionSpec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -119,6 +120,15 @@ case class CometBroadcastExchangeExec(
         val countsAndBytes = child match {
           case c: CometPlan => CometExec.getByteArrayRdd(c).collect()
+          // AQEShuffleReadExec with CometShuffleExchangeExec: use coalesced partition specs
+          case aqe @ AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
+              if s.shuffle.isInstanceOf[CometShuffleExchangeExec] =>
+            CometBroadcastExchangeExec
+              .getByteArrayRddFromCoalescedShuffle(
+                s.shuffle.asInstanceOf[CometShuffleExchangeExec],
+                aqe.partitionSpecs.toArray)
+              .collect()
+          // AQEShuffleReadExec with other CometPlan (fallback to original behavior)
           case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) if s.plan.isInstanceOf[CometPlan] =>
             CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
@@ -126,6 +136,16 @@ case class CometBroadcastExchangeExec(
             CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
           case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
             CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
+          // AQEShuffleReadExec with ReusedExchange containing CometShuffleExchangeExec
+          case aqe @ AQEShuffleReadExec(
+                ShuffleQueryStageExec(
+                  _,
+                  ReusedExchangeExec(_, shuffle: CometShuffleExchangeExec),
+                  _),
+                _) =>
+            CometBroadcastExchangeExec
+              .getByteArrayRddFromCoalescedShuffle(shuffle, aqe.partitionSpecs.toArray)
+              .collect()
           case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
               if plan.isInstanceOf[CometPlan] =>
             CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
@@ -273,6 +293,27 @@ object CometBroadcastExchangeExec extends CometSink[BroadcastExchangeExec] {
    */
   override def isFfiSafe: Boolean = true
 
+  /**
+   * Gets serialized batches from a CometShuffleExchangeExec using the coalesced partition specs
+   * from AQE. This ensures that when AQE coalesces shuffle partitions (e.g., from 200 to 1), the
+   * broadcast exchange respects this optimization instead of reading all original partitions.
+   *
+   * @param shuffle
+   *   The CometShuffleExchangeExec to read from
+   * @param partitionSpecs
+   *   The coalesced partition specs from AQEShuffleReadExec
+   * @return
+   *   RDD of (rowCount, serializedBytes) tuples
+   */
+  private[comet] def getByteArrayRddFromCoalescedShuffle(
+      shuffle: CometShuffleExchangeExec,
+      partitionSpecs: Array[ShufflePartitionSpec]): RDD[(Long, ChunkedByteBuffer)] = {
+    shuffle.getShuffleRDD(partitionSpecs).asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal {
+      iter =>
+        Utils.serializeBatches(iter)
+    }
+  }
+
   override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
     CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED)
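
Side note, not part of the patch: a minimal sketch of the partition specs the new read path consumes. It assumes Spark's CoalescedPartitionSpec, the ShufflePartitionSpec subclass that AQE emits when it merges contiguous reducer partitions; the specs value here is hypothetical, for illustration only.

    import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec}

    // If AQE coalesces 200 shuffle partitions down to 1, AQEShuffleReadExec.partitionSpecs
    // holds a single spec covering the whole reducer range [0, 200). Handing it to
    // getByteArrayRddFromCoalescedShuffle makes the broadcast side read one coalesced
    // partition rather than the 200 original ones.
    val specs: Array[ShufflePartitionSpec] =
      Array(CoalescedPartitionSpec(startReducerIndex = 0, endReducerIndex = 200))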