@@ -32,9 +32,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, ShufflePartitionSpec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -119,13 +120,32 @@ case class CometBroadcastExchangeExec(
 
     val countsAndBytes = child match {
       case c: CometPlan => CometExec.getByteArrayRdd(c).collect()
+      // AQEShuffleReadExec with CometShuffleExchangeExec: use coalesced partition specs
+      case aqe @ AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
+          if s.shuffle.isInstanceOf[CometShuffleExchangeExec] =>
+        CometBroadcastExchangeExec
+          .getByteArrayRddFromCoalescedShuffle(
+            s.shuffle.asInstanceOf[CometShuffleExchangeExec],
+            aqe.partitionSpecs.toArray)
+          .collect()
+      // AQEShuffleReadExec with other CometPlan (fallback to original behavior)
       case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
           if s.plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
       case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
       case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
+      // AQEShuffleReadExec with ReusedExchange containing CometShuffleExchangeExec
+      case aqe @ AQEShuffleReadExec(
+            ShuffleQueryStageExec(
+              _,
+              ReusedExchangeExec(_, shuffle: CometShuffleExchangeExec),
+              _),
+            _) =>
+        CometBroadcastExchangeExec
+          .getByteArrayRddFromCoalescedShuffle(shuffle, aqe.partitionSpecs.toArray)
+          .collect()
       case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
          if plan.isInstanceOf[CometPlan] =>
        CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
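For context on the new match arms: after AQE coalesces a shuffle, `AQEShuffleReadExec.partitionSpecs` describes the merged reads as contiguous reducer-index ranges. A minimal sketch of what such specs look like (illustrative only, not part of this change; the partition counts are made up):

```scala
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec}

// A 200-partition shuffle coalesced by AQE into two reads: each
// CoalescedPartitionSpec covers a reducer range [start, end).
val specs: Seq[ShufflePartitionSpec] = Seq(
  CoalescedPartitionSpec(0, 150), // reducer partitions 0..149 read as one task
  CoalescedPartitionSpec(150, 200)) // reducer partitions 150..199 read as one task
```

The pre-existing arms read the stage's plan with its original output partitioning; the new arms instead hand `aqe.partitionSpecs` to the helper added below, so the broadcast sees the coalesced layout.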
@@ -273,6 +293,27 @@ object CometBroadcastExchangeExec extends CometSink[BroadcastExchangeExec] {
    */
   override def isFfiSafe: Boolean = true
 
+  /**
+   * Gets serialized batches from a CometShuffleExchangeExec using the coalesced partition specs
+   * from AQE. This ensures that when AQE coalesces shuffle partitions (e.g., from 200 to 1), the
+   * broadcast exchange respects this optimization instead of reading all original partitions.
+   *
+   * @param shuffle
+   *   The CometShuffleExchangeExec to read from
+   * @param partitionSpecs
+   *   The coalesced partition specs from AQEShuffleReadExec
+   * @return
+   *   RDD of (rowCount, serializedBytes) tuples
+   */
+  private[comet] def getByteArrayRddFromCoalescedShuffle(
+      shuffle: CometShuffleExchangeExec,
+      partitionSpecs: Array[ShufflePartitionSpec]): RDD[(Long, ChunkedByteBuffer)] = {
+    shuffle.getShuffleRDD(partitionSpecs).asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal {
+      iter =>
+        Utils.serializeBatches(iter)
+    }
+  }
+
   override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
     CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED)
 
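A hedged usage sketch of the new helper, mirroring how the match arms above call it. The `shuffle` and `specs` bindings are hypothetical stand-ins for the values the patterns extract:

```scala
import org.apache.spark.util.io.ChunkedByteBuffer

// Assumed in scope (hypothetical, extracted by the match arms above):
//   shuffle: CometShuffleExchangeExec
//   specs:   Array[ShufflePartitionSpec]
val countsAndBytes: Array[(Long, ChunkedByteBuffer)] =
  CometBroadcastExchangeExec
    .getByteArrayRddFromCoalescedShuffle(shuffle, specs)
    .collect()

// One (rowCount, serializedBytes) tuple per *coalesced* partition: a shuffle
// coalesced from 200 partitions to 1 collects a single tuple, not 200.
val totalRows: Long = countsAndBytes.map(_._1).sum
```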