Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions duckdb/experimental/spark/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from collections.abc import Callable, Iterable, Sized
from typing import Literal, TypeVar

import pyarrow
from numpy import float32, float64, int32, int64, ndarray
from pandas import DataFrame as PandasDataFrame
Comment on lines +22 to +24
from typing_extensions import Protocol, Self

F = TypeVar("F", bound=Callable)
Expand All @@ -30,6 +32,13 @@
NonUDFType = Literal[0]


# Batches exchanged with ``DataFrame.mapInPandas`` user functions are plain
# pandas DataFrames.
DataFrameLike = PandasDataFrame

# Signature of the user function passed to ``DataFrame.mapInPandas``: it
# consumes an iterator of pandas DataFrames and yields pandas DataFrames.
PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]

# Signature of the user function passed to ``DataFrame.mapInArrow``: it
# consumes an iterator of Arrow record batches and yields record batches.
ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]


class SupportsIAdd(Protocol):
    """Structural type for objects that support in-place addition (``+=``)."""

    def __iadd__(self, other: "SupportsIAdd") -> Self: ...

Expand Down
200 changes: 199 additions & 1 deletion duckdb/experimental/spark/sql/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from collections.abc import Callable
from collections.abc import Callable, Iterable
from functools import reduce
from keyword import iskeyword
from typing import (
Expand All @@ -12,6 +12,7 @@

import duckdb
from duckdb import ColumnExpression, Expression, StarExpression
from duckdb.experimental.spark.exception import ContributionsAcceptedError

from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
from .column import Column
Expand All @@ -23,6 +24,7 @@
import pyarrow as pa
from pandas.core.frame import DataFrame as PandasDataFrame

from .._typing import ArrowMapIterFunction, PandasMapIterFunction
from ._typing import ColumnOrName
from .group import GroupedData
from .session import SparkSession
Expand Down Expand Up @@ -1430,5 +1432,201 @@ def cache(self) -> "DataFrame":
cached_relation = self.relation.execute()
return DataFrame(cached_relation, self.session)

def mapInArrow(
    self,
    func: "ArrowMapIterFunction",
    schema: StructType | str,
    barrier: bool = False,
    profile: object | None = None,
) -> "DataFrame":
    """Map an iterator of ``pyarrow.RecordBatch`` objects over this DataFrame.

    Applies ``func`` to an iterator of record batches, each holding a batch of
    rows from the original DataFrame, and combines the yielded record batches
    into a new :class:`DataFrame`. The input and output batch counts and sizes
    may differ.

    .. versionadded:: 3.3.0

    Parameters
    ----------
    func : function
        A Python function that takes an iterator of ``pyarrow.RecordBatch``
        objects and returns an iterator of ``pyarrow.RecordBatch`` objects.
    schema : :class:`pyspark.sql.types.DataType` or str
        The return type of ``func``. DDL-formatted type strings are not
        supported yet; pass a :class:`StructType`.
    barrier : bool, optional, default False
        Ignored: duckdb runs on a single node and has no barrier execution
        mode.

        .. versionadded: 3.5.0

    profile : optional ResourceProfile
        Not supported yet; must be ``None``.

        .. versionadded: 4.0.0

    Raises
    ------
    ContributionsAcceptedError
        If ``schema`` is a DDL string or ``profile`` is not ``None``.

    Examples
    --------
    >>> import pyarrow as pa
    >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
    >>> def filter_func(iterator):
    ...     for batch in iterator:
    ...         pdf = batch.to_pandas()
    ...         yield pa.RecordBatch.from_pandas(pdf[pdf.id == 1])
    >>> df.mapInArrow(filter_func, df.schema).show()
    +---+---+
    | id|age|
    +---+---+
    |  1| 21|
    +---+---+

    See Also
    --------
    pyspark.sql.functions.pandas_udf
    DataFrame.mapInPandas
    """
    if isinstance(schema, str):
        msg = "DDL-formatted type string is not supported yet for the 'schema' parameter."
        raise ContributionsAcceptedError(msg)

    if profile is not None:
        msg = "ResourceProfile is not supported yet for the 'profile' parameter."
        raise ContributionsAcceptedError(msg)

    # Ignored because duckdb works on a single node and has no barrier
    # execution mode.
    del barrier

    import pyarrow as pa
    from pyarrow.dataset import dataset

    # Derive the Arrow schema for the declared return type by round-tripping
    # an empty DataFrame through Arrow.
    arrow_schema = self.session.createDataFrame([], schema=schema).toArrow().schema
    record_batches = self.relation.fetch_record_batch()
    batch_generator = func(record_batches)
    reader = pa.RecordBatchReader.from_batches(arrow_schema, batch_generator)
    # 'ds' is looked up by name in the SQL below via duckdb's replacement
    # scan of local Python variables, hence the F841 suppression.
    ds = dataset(reader)  # noqa: F841
    return DataFrame(self.session.conn.sql("SELECT * FROM ds"), self.session)

def mapInPandas(
    self,
    func: "PandasMapIterFunction",
    schema: StructType | str,
    barrier: bool = False,
    profile: object | None = None,
) -> "DataFrame":
    """Map an iterator of ``pandas.DataFrame`` batches over this DataFrame.

    Applies ``func`` to an iterator of pandas DataFrames, each holding a
    batch of rows from the original DataFrame, and combines the yielded
    pandas DataFrames into a new :class:`DataFrame`. The input and output
    batch counts and sizes may differ.

    .. versionadded:: 3.0.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    func : function
        A Python function that takes an iterator of ``pandas.DataFrame``
        objects and returns an iterator of ``pandas.DataFrame`` objects.
    schema : :class:`pyspark.sql.types.DataType` or str
        The return type of ``func``. DDL-formatted type strings are not
        supported yet; pass a :class:`StructType`.
    barrier : bool, optional, default False
        Ignored: duckdb runs on a single node and has no barrier execution
        mode.

        .. versionadded: 3.5.0

    profile : optional ResourceProfile
        Not supported yet; must be ``None``.

        .. versionadded: 4.0.0

    Examples
    --------
    >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

    Filter rows with id equal to 1:

    >>> def filter_func(iterator):
    ...     for pdf in iterator:
    ...         yield pdf[pdf.id == 1]
    >>> df.mapInPandas(filter_func, df.schema).show()
    +---+---+
    | id|age|
    +---+---+
    |  1| 21|
    +---+---+

    Add a new column with the double of the age:

    >>> def double_age(iterator):
    ...     for pdf in iterator:
    ...         pdf["double_age"] = pdf["age"] * 2
    ...         yield pdf
    >>> df.mapInPandas(double_age, "id: bigint, age: bigint, double_age: bigint").show()
    +---+---+----------+
    | id|age|double_age|
    +---+---+----------+
    |  1| 21|        42|
    |  2| 30|        60|
    +---+---+----------+

    See Also
    --------
    pyspark.sql.functions.pandas_udf
    DataFrame.mapInArrow
    """
    import pyarrow as pa

    def _arrow_adapter(record_batches: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
        # Bridge pandas-land to Arrow-land: convert each incoming record
        # batch to pandas, run the user function over that stream, and
        # convert each resulting pandas DataFrame back to a record batch.
        pandas_batches = (batch.to_pandas() for batch in record_batches)
        for pdf in func(pandas_batches):
            yield pa.RecordBatch.from_pandas(pdf)

    # Delegate all validation (schema/profile checks) and execution to
    # mapInArrow; only the batch representation differs.
    return self.mapInArrow(_arrow_adapter, schema, barrier, profile)


__all__ = ["DataFrame"]
2 changes: 1 addition & 1 deletion external/duckdb
Submodule duckdb updated 2835 files
88 changes: 88 additions & 0 deletions tests/fast/spark/test_spark_dataframe_map_in.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")

from spark_namespace.sql import functions as F
from spark_namespace.sql.types import Row


class TestDataFrameMapInMethods:
    """Tests for DataFrame.mapInPandas and DataFrame.mapInArrow."""

    # Shared fixture data: (age, name) tuples with duplicate ages to
    # exercise filtering and sorting.
    data = ((56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben"))

    def test_map_in_pandas(self, spark):
        def filter_func(iterator):
            for pdf in iterator:
                yield pdf[pdf.age == 3]

        df = spark.createDataFrame(self.data, ["age", "name"])
        df = df.mapInPandas(filter_func, schema=df.schema)
        df = df.sort(["age", "name"])

        expected = [
            Row(age=3, name="Anna"),
            Row(age=3, name="Dave"),
        ]

        assert df.collect() == expected

    def test_map_in_pandas_empty_result(self, spark):
        def filter_func(iterator):
            for pdf in iterator:
                yield pdf[pdf.age > 100]

        df = spark.createDataFrame(self.data, ["age", "name"])
        df = df.mapInPandas(filter_func, schema=df.schema)

        expected = []

        assert df.collect() == expected
        # The declared schema must survive even when no rows are produced.
        assert df.schema == spark.createDataFrame([], schema=df.schema).schema

    def test_map_in_pandas_large_dataset_ensure_no_data_loss(self, spark):
        def identity_func(iterator):
            for pdf in iterator:
                pdf = pdf[pdf.id >= 0]  # Apply a filter to ensure the DataFrame is evaluated
                yield pdf

        n = 10_000_000

        pandas_df = pd.DataFrame(
            {
                "id": np.arange(n, dtype=np.int64),
                "value_float": np.random.rand(n).astype(np.float32),
                "value_int": np.random.randint(0, 1000, size=n, dtype=np.int32),
                "category": np.random.randint(0, 10, size=n, dtype=np.int8),
            }
        )

        df = spark.createDataFrame(pandas_df)
        df = df.mapInPandas(identity_func, schema=df.schema)
        # Apply filters to evaluate all dataframe
        df = df.filter(F.col("id") <= n).filter(F.col("id") >= 0).filter(F.col("category") >= 0)

        generated_pandas_df = df.toPandas()
        total_records = df.count()

        assert total_records == n
        assert pandas_df["id"].equals(generated_pandas_df["id"])

    def test_map_in_arrow(self, spark):
        def filter_func(iterator):
            for batch in iterator:
                df = batch.to_pandas()
                df = df[df.age == 3]
                yield pa.RecordBatch.from_pandas(df)

        df = spark.createDataFrame(self.data, ["age", "name"])
        df = df.mapInArrow(filter_func, schema=df.schema)
        df = df.sort(["age", "name"])

        expected = [
            Row(age=3, name="Anna"),
            Row(age=3, name="Dave"),
        ]

        assert df.collect() == expected
Loading