13 changes: 13 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
Expand Up @@ -621,6 +621,19 @@ def repartitionById(
self.sparkSession,
)

def optimizePartitions(self, targetMB: Optional[int] = None) -> "DataFrame":
"""
Optimizes the partition count based on dataset size.
"""
target_size = targetMB if targetMB is not None else 128
if target_size <= 0:
raise PySparkValueError(
errorClass="VALUE_NOT_POSITIVE",
messageParameters={"arg_name": "targetMB", "arg_value": str(target_size)},
)
jdf = self._jdf.optimizePartitions(int(target_size))
return DataFrame(jdf, self.sparkSession)

def distinct(self) -> ParentDataFrame:
return DataFrame(self._jdf.distinct(), self.sparkSession)

Expand Down
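The classic-mode method above validates targetMB and delegates to the JVM Dataset. A minimal usage sketch, assuming a local non-Connect SparkSession; the partition counts shown are illustrative, not guaranteed:

from pyspark.errors import PySparkValueError
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").appName("optimize-partitions-demo").getOrCreate()

df = spark.range(1_000_000).repartition(8)
print(df.rdd.getNumPartitions())         # 8

optimized = df.optimizePartitions()      # default 128MB target
print(optimized.rdd.getNumPartitions())  # likely 1 for this tiny dataset

try:
    df.optimizePartitions(targetMB=0)    # non-positive targets are rejected
except PySparkValueError:
    print("targetMB must be positive")

spark.stop()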
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
Expand Up @@ -475,6 +475,12 @@ def repartitionById(
res._cached_schema = self._cached_schema
return res

def optimizePartitions(self, targetMB: Optional[int] = None) -> "DataFrame":
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={"feature": "optimizePartitions for Spark Connect"},
)

def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
if subset is not None and not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
Expand Down Expand Up @@ -2363,6 +2369,7 @@ def _test() -> None:
if not is_remote_only():
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
del pyspark.sql.dataframe.DataFrame.optimizePartitions.__doc__

if not have_pandas or not have_pyarrow:
del pyspark.sql.dataframe.DataFrame.toPandas.__doc__
Expand Down
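The Spark Connect stub above fails fast instead of silently degrading. A hedged sketch of what a caller sees, assuming a Connect session (e.g. created via SparkSession.builder.remote("sc://localhost")):

from pyspark.errors import PySparkNotImplementedError

df = spark.range(10)
try:
    df.optimizePartitions()
except PySparkNotImplementedError:
    # Classic-only for now; fall back to an explicit repartition on Connect.
    df = df.repartition(1)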
39 changes: 39 additions & 0 deletions python/pyspark/sql/dataframe.py
Expand Up @@ -1954,6 +1954,45 @@ def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") ->
"""
...

@dispatch_df_method
def optimizePartitions(self, targetMB: Optional[int] = None) -> "DataFrame":
"""
Proactively optimizes the partition count of this DataFrame based on its estimated size.

Best used immediately after reading a dataset ("use on ingest") so that the initial
parallelism matches the data size. This prevents "Small File" issues (too many partitions)
or "Giant Partition" issues (too few partitions) before heavy transformations begin.

.. versionadded:: 4.2.0

Parameters
----------
targetMB : int, optional
    The target partition size in megabytes. Defaults to 128 MB.

Returns
-------
:class:`DataFrame`
    Repartitioned DataFrame.

Notes
-----
When the partition count must grow, this method uses round-robin repartitioning
(a full shuffle) to balance partition sizes; when it must shrink, partitions are
coalesced without a shuffle. If used immediately before writing to a partitioned
table, the shuffle may degrade performance by breaking data locality.

Examples
--------
>>> df = spark.range(1000000).repartition(8)
>>> df.rdd.getNumPartitions()
8
>>> df_opt = df.optimizePartitions(64)
>>> df_opt.rdd.getNumPartitions()  # e.g., 1 (depending on data size)
1
"""
...

@dispatch_df_method
def distinct(self) -> "DataFrame":
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
Expand Down
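The Notes section above warns against calling this right before a partitioned write. A hedged sketch of that caveat; the paths and column names are illustrative, and repartition("city") is shown only as a common alternative, not something this change introduces:

df = spark.createDataFrame([(1, "NY"), (2, "SF"), (3, "NY")], ["id", "city"])

# Risky for large inputs: if the size-based shuffle kicks in, each city's rows
# are scattered across partitions, so partitionBy("city") can emit many small
# files per directory.
df.optimizePartitions().write.partitionBy("city").parquet("/tmp/out_scattered")

# Common alternative before a partitioned write: co-locate rows by the write key.
df.repartition("city").write.partitionBy("city").parquet("/tmp/out_clustered")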
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
Expand Up @@ -34,6 +34,10 @@ def test_toDF_with_schema_string(self):
def test_query_execution_unsupported_in_classic(self):
pass

@unittest.skip("optimizePartitions is not implemented in Spark Connect")
def test_optimize_partitions(self):
super().test_optimize_partitions()


if __name__ == "__main__":
import unittest
Expand Down
71 changes: 62 additions & 9 deletions python/pyspark/sql/tests/test_dataframe.py
Expand Up @@ -404,10 +404,12 @@ def test_drop_column_name_with_dot(self):
self.assertEqual(df.drop("city.name").columns, ["id", "first.name", "state"])
self.assertEqual(df.drop("first.name", "city.name").columns, ["id", "state"])
self.assertEqual(
df.drop("first.name", "city.name", "unknown.unknown").columns, ["id", "state"]
df.drop("first.name", "city.name", "unknown.unknown").columns,
["id", "state"],
)
self.assertEqual(
df.drop("unknown.unknown").columns, ["id", "first.name", "city.name", "state"]
df.drop("unknown.unknown").columns,
["id", "first.name", "city.name", "state"],
)

def test_with_column_with_existing_name(self):
Expand Down Expand Up @@ -442,7 +444,8 @@ def test_with_columns(self):
.collect()
)
self.assertEqual(
[(r.key_alias, r.value_alias) for r in kvs], [(i, str(i)) for i in range(100)]
[(r.key_alias, r.value_alias) for r in kvs],
[(i, str(i)) for i in range(100)],
)

# Type check
Expand Down Expand Up @@ -498,6 +501,40 @@ def test_coalesce_hints_with_string_parameter(self):
self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_NONE"), 1)
self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_COL"), 3)

def test_optimize_partitions(self):
# Setup: Create a small DataFrame with an intentionally inefficient number of partitions
# range(10000) is very small data (~80KB), but we force 50 partitions.
initial_partitions = 50
df = self.spark.range(10000).repartition(initial_partitions)

self.assertEqual(df.rdd.getNumPartitions(), initial_partitions)

# 1. Test Default Execution (Downscaling)
# Since data is small and default target is 128MB, this should coalesce to 1 partition.
result_default = df.optimizePartitions()

# Assertions
self.assertEqual(
result_default.rdd.getNumPartitions(),
1,
"Expected tiny dataset to coalesce to 1 partition",
)
self.assertEqual(result_default.count(), 10000, "Data count mismatch after optimization")

# 2. Test Custom Target
# Even with a 1MB target, ~80KB of data still fits in a single partition.
result_custom = df.optimizePartitions(targetMB=1)
self.assertEqual(result_custom.rdd.getNumPartitions(), 1)

# 3. Test Validation
# optimizePartitions should raise PySparkValueError when targetMB is <= 0.
with self.assertRaisesRegex(
PySparkValueError, "Value for `targetMB` must be positive, got '-1'"
):
df.optimizePartitions(targetMB=-1)

with self.assertRaisesRegex(
PySparkValueError, "Value for `targetMB` must be positive, got '0'"
):
df.optimizePartitions(targetMB=0)

# add tests for SPARK-23647 (test more types for hint)
def test_extended_hint_types(self):
df = self.spark.range(10e10).toDF("id")
Expand Down Expand Up @@ -952,7 +989,10 @@ def test_pandas_api(self):

def test_to(self):
schema = StructType(
[StructField("i", StringType(), True), StructField("j", IntegerType(), True)]
[
StructField("i", StringType(), True),
StructField("j", IntegerType(), True),
]
)
df = self.spark.createDataFrame([("a", 1)], schema)

Expand All @@ -974,13 +1014,17 @@ def test_to(self):
# incompatible field nullability
schema4 = StructType([StructField("j", LongType(), False)])
self.assertRaisesRegex(
AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4).count()
AnalysisException,
"NULLABLE_COLUMN_OR_FIELD",
lambda: df.to(schema4).count(),
)

# field cannot upcast
schema5 = StructType([StructField("i", LongType())])
self.assertRaisesRegex(
AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5).count()
AnalysisException,
"INVALID_COLUMN_OR_FIELD_DATA_TYPE",
lambda: df.to(schema5).count(),
)

def test_colregex(self):
Expand Down Expand Up @@ -1063,7 +1107,10 @@ def test_transpose(self):
# default index column
transposed_df = df.transpose()
expected_schema = StructType(
[StructField("key", StringType(), False), StructField("x", StringType(), True)]
[
StructField("key", StringType(), False),
StructField("x", StringType(), True),
]
)
expected_data = [Row(key="b", x="y"), Row(key="c", x="z")]
expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
Expand All @@ -1072,7 +1119,10 @@ def test_transpose(self):
# specified index column
transposed_df = df.transpose("c")
expected_schema = StructType(
[StructField("key", StringType(), False), StructField("z", StringType(), True)]
[
StructField("key", StringType(), False),
StructField("z", StringType(), True),
]
)
expected_data = [Row(key="a", z="x"), Row(key="b", z="y")]
expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
Expand All @@ -1085,7 +1135,10 @@ def test_transpose(self):
self.check_error(
exception=pe.exception,
errorClass="TRANSPOSE_EXCEED_ROW_LIMIT",
messageParameters={"maxValues": "0", "config": "spark.sql.transposeMaxValues"},
messageParameters={
"maxValues": "0",
"config": "spark.sql.transposeMaxValues",
},
)

# enforce ascending order based on index column values for transposed columns
Expand Down
31 changes: 31 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
Expand Up @@ -2981,6 +2981,37 @@ abstract class Dataset[T] extends Serializable {
*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T]

/**
* Proactively optimizes the partition count of this Dataset based on its estimated size.
*
* ==Best Practice: Use on Ingest==
* This method is best used immediately after reading a dataset to ensure the initial
* parallelism matches the data size. This prevents "Small File" issues (too many partitions) or
* "Giant Partition" issues (too few partitions) before heavy transformations begin.
*
* {{{
* val raw = spark.read.parquet("...")
* val optimized = raw.optimizePartitions() // Perfect start for transformations
* optimized.filter(...).groupBy(...)
* }}}
*
* ==Warning: Use on Write==
* When the partition count must grow, this method uses round-robin repartitioning (a full
* shuffle) to balance partition sizes; when it must shrink, partitions are coalesced without a
* shuffle. If the shuffle happens immediately before writing to a partitioned table (e.g.,
* `write.partitionBy("city")`), it may degrade performance by breaking data locality, causing
* the writer to create many small files across directories.
*
* @param targetMB
* The target partition size in Megabytes. Defaults to 128MB.
* @group typedrel
* @since 4.2.0
*/
def optimizePartitions(targetMB: Int = 128): Dataset[T] = {
  throw new UnsupportedOperationException(
    "This method is implemented in the concrete Dataset classes")
}

/**
* Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OptimizePartitionsCommand, Repartition}
import org.apache.spark.sql.catalyst.rules.Rule

/**
* Proactively optimizes the partition count of a Dataset based on its estimated size.
* This rule transforms the custom OptimizePartitionsCommand into standard Spark operations.
*/
object OptimizePartitionsRule extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
case OptimizePartitionsCommand(child, targetMB, currentPartitions) =>

require(targetMB > 0, s"targetMB must be positive. Got $targetMB")
val targetBytes = targetMB.toLong * 1024L * 1024L

// Get the estimated size from Catalyst Statistics
val sizeInBytes: BigInt = child.stats.sizeInBytes

// Calculate Optimal Partition Count (N)
val count = math.ceil(sizeInBytes.toDouble / targetBytes).toInt
val calculatedN: Int = if (count <= 1) 1 else count

// Smart Switch: Coalesce vs Repartition
if (calculatedN < currentPartitions) {
// DOWNSCALING: Use Coalesce (shuffle = false)
Repartition(calculatedN, shuffle = false, child)
} else if (calculatedN > currentPartitions) {
// UPSCALING: Use Repartition (shuffle = true)
Repartition(calculatedN, shuffle = true, child)
} else {
// OPTIMAL
child
}
}
}
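A standalone Python sketch of the sizing arithmetic the rule performs, with worked numbers; the function name and example sizes are illustrative only:

import math

def target_partition_count(size_in_bytes: int, target_mb: int) -> int:
    # Mirrors the rule: ceil(size / (targetMB * 1024 * 1024)), floored at 1.
    assert target_mb > 0, "targetMB must be positive"
    target_bytes = target_mb * 1024 * 1024
    return max(math.ceil(size_in_bytes / target_bytes), 1)

# ~10 GiB at the default 128MB target -> 80 partitions (repartition if fewer today).
print(target_partition_count(10 * 1024**3, 128))  # 80

# ~80KB at the default target -> 1 partition (coalesce if more today).
print(target_partition_count(80 * 1024, 128))     # 1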
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute

/**
* A logical plan node that hints to the optimizer that the data should be
* automatically repartitioned based on statistics about its estimated size.
*/
case class OptimizePartitionsCommand(
    child: LogicalPlan,
    targetMB: Int,
    currentPartitions: Int) extends UnaryNode {

override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: LogicalPlan): OptimizePartitionsCommand =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,14 @@ class Dataset[T] private[sql](
}
}

/** @inheritdoc */
override def optimizePartitions(targetMB: Int): Dataset[T] = {
  // Reading rdd.getNumPartitions triggers analysis and physical planning of this
  // Dataset so the rule can compare against the current partition count.
  val currentPartitions = rdd.getNumPartitions

  withTypedPlan {
    OptimizePartitionsCommand(logicalPlan, targetMB, currentPartitions)
  }
}

/** @inheritdoc */
def coalesce(numPartitions: Int): Dataset[T] = withSameTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class SparkOptimizer(
ConstantFolding,
EliminateLimits),
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
Batch("Optimize Partitions", Once, OptimizePartitionsRule)))

override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
Seq(
Expand Down
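Because the new batch is not added to nonExcludableRules above, the rule can presumably be disabled per session through the standard exclusion conf; a hedged sketch (the fully-qualified rule name is the expected one for the object defined in this PR):

spark.conf.set(
    "spark.sql.optimizer.excludedRules",
    "org.apache.spark.sql.catalyst.optimizer.OptimizePartitionsRule",
)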