4 changes: 3 additions & 1 deletion python/datafusion/__init__.py
@@ -55,7 +55,7 @@
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
 from .options import CsvReadOptions
-from .plan import ExecutionPlan, LogicalPlan
+from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet
 from .record_batch import RecordBatch, RecordBatchStream
 from .user_defined import (
     Accumulator,
@@ -85,6 +85,8 @@
     "Expr",
     "InsertOp",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
     "ParquetColumnOptions",
     "ParquetWriterOptions",
     "RecordBatch",
106 changes: 106 additions & 0 deletions python/datafusion/plan.py
@@ -29,6 +29,8 @@
 __all__ = [
     "ExecutionPlan",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
 ]
 
 
@@ -151,3 +153,107 @@ def to_proto(self) -> bytes:
         Tables created in memory from record batches are currently not supported.
         """
         return self._raw_plan.to_proto()
+
+    def metrics(self) -> MetricsSet | None:
+        """Return metrics for this plan node after execution, or None if unavailable."""
+        raw = self._raw_plan.metrics()
+        if raw is None:
+            return None
+        return MetricsSet(raw)
+
+    def collect_metrics(self) -> list[tuple[str, MetricsSet]]:
+        """Walk the plan tree and collect metrics from all operators.
+
+        Returns a list of ``(operator_name, MetricsSet)`` tuples.
+        """
+        result: list[tuple[str, MetricsSet]] = []
+
+        def _walk(node: ExecutionPlan) -> None:
+            ms = node.metrics()
+            if ms is not None:
+                result.append((node.display(), ms))
+            for child in node.children():
+                _walk(child)
+
+        _walk(self)
+        return result
+
+
+class MetricsSet:
+    """A set of metrics for a single execution plan operator.
+
+    Provides both individual metric access and convenience aggregations
+    across partitions.
+    """
+
+    def __init__(self, raw: df_internal.MetricsSet) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    def metrics(self) -> list[Metric]:
+        """Return all individual metrics in this set."""
+        return [Metric(m) for m in self._raw.metrics()]
+
+    @property
+    def output_rows(self) -> int | None:
+        """Sum of output_rows across all partitions."""
+        return self._raw.output_rows()
+
+    @property
+    def elapsed_compute(self) -> int | None:
+        """Sum of elapsed_compute across all partitions, in nanoseconds."""
+        return self._raw.elapsed_compute()
+
+    @property
+    def spill_count(self) -> int | None:
+        """Sum of spill_count across all partitions."""
+        return self._raw.spill_count()
+
+    @property
+    def spilled_bytes(self) -> int | None:
+        """Sum of spilled_bytes across all partitions."""
+        return self._raw.spilled_bytes()
+
+    @property
+    def spilled_rows(self) -> int | None:
+        """Sum of spilled_rows across all partitions."""
+        return self._raw.spilled_rows()
+
+    def sum_by_name(self, name: str) -> int | None:
+        """Return the sum of metrics matching the given name."""
+        return self._raw.sum_by_name(name)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metrics set."""
+        return repr(self._raw)
+
+
+class Metric:
+    """A single execution metric with name, value, partition, and labels."""
+
+    def __init__(self, raw: df_internal.Metric) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    @property
+    def name(self) -> str:
+        """The name of this metric (e.g. ``output_rows``)."""
+        return self._raw.name
+
+    @property
+    def value(self) -> int | None:
+        """The numeric value of this metric, or None for non-numeric types."""
+        return self._raw.value
+
+    @property
+    def partition(self) -> int | None:
+        """The partition this metric applies to, or None if global."""
+        return self._raw.partition
+
+    def labels(self) -> dict[str, str]:
+        """Return the labels associated with this metric."""
+        return self._raw.labels()
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metric."""
+        return repr(self._raw)
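
Taken together, the new Python API can be exercised as in the following sketch (assuming a small test table like the ones in the tests below; metrics are only populated once the query has actually executed):

    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
    df.collect()  # execution populates the operator metrics

    plan = df.execution_plan()
    for name, ms in plan.collect_metrics():
        # Convenience aggregates, summed across partitions.
        print(name, ms.output_rows, ms.elapsed_compute)
        # Sum any named metric across partitions.
        print(ms.sum_by_name("output_rows"))
        # Or inspect each metric individually.
        for metric in ms.metrics():
            print(metric.name, metric.value, metric.partition, metric.labels())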
106 changes: 105 additions & 1 deletion python/tests/test_plans.py
@@ -16,7 +16,13 @@
 # under the License.
 
 import pytest
-from datafusion import ExecutionPlan, LogicalPlan, SessionContext
+from datafusion import (
+    ExecutionPlan,
+    LogicalPlan,
+    Metric,
+    MetricsSet,
+    SessionContext,
+)
 
 
 # Note: We must use CSV because memory tables are currently not supported for
@@ -40,3 +46,101 @@ def test_logical_plan_to_proto(ctx, df) -> None:
     execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes)
 
     assert str(original_execution_plan) == str(execution_plan)
+
+
+def test_metrics_tree_walk() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+    df.collect()
+    plan = df.execution_plan()
+
+    results = plan.collect_metrics()
+    assert len(results) >= 1
+    found_metrics = False
+    for name, ms in results:
+        assert isinstance(name, str)
+        assert isinstance(ms, MetricsSet)
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_metric_properties() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+    df.collect()
+    plan = df.execution_plan()
+
+    for _, ms in plan.collect_metrics():
+        r = repr(ms)
+        assert isinstance(r, str)
+        for metric in ms.metrics():
+            assert isinstance(metric, Metric)
+            assert isinstance(metric.name, str)
+            assert len(metric.name) > 0
+            assert metric.partition is None or isinstance(metric.partition, int)
+            assert isinstance(metric.labels(), dict)
+            mr = repr(metric)
+            assert isinstance(mr, str)
+            assert len(mr) > 0
+        return
+    pytest.skip("No metrics found")
+
+
+def test_no_metrics_before_execution() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
+    df = ctx.sql("SELECT * FROM t")
+    plan = df.execution_plan()
+    ms = plan.metrics()
+    assert ms is None or ms.output_rows is None or ms.output_rows == 0
+
+
+def test_collect_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    df.collect_partitioned()
+    plan = df.execution_plan()
+
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    for _ in df.execute_stream():
+        pass
+
+    plan = df.execution_plan()
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    for stream in df.execute_stream_partitioned():
+        for _ in stream:
+            pass
+
+    plan = df.execution_plan()
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
45 changes: 41 additions & 4 deletions src/dataframe.rs
@@ -48,6 +48,14 @@ use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 
+use datafusion::physical_plan::{
+    ExecutionPlan as DFExecutionPlan,
+    collect as df_collect,
+    collect_partitioned as df_collect_partitioned,
+    execute_stream as df_execute_stream,
+    execute_stream_partitioned as df_execute_stream_partitioned,
+};
+
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
 use crate::expr::PyExpr;
 use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
@@ -290,6 +298,9 @@ pub struct PyDataFrame {
 
     // In IPython environment cache batches between __repr__ and _repr_html_ calls.
     batches: SharedCachedBatches,
+
+    // Cache the last physical plan so that metrics are available after execution.
+    last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>,
 }
 
 impl PyDataFrame {
@@ -298,6 +309,7 @@ impl PyDataFrame {
         Self {
             df: Arc::new(df),
             batches: Arc::new(Mutex::new(None)),
+            last_plan: Arc::new(Mutex::new(None)),
         }
     }
 
@@ -627,7 +639,12 @@
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
@@ -643,7 +660,12 @@
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
 
         batches
@@ -803,7 +825,13 @@
     }
 
     /// Get the execution plan for this `DataFrame`
+    ///
+    /// If the DataFrame has already been executed (e.g. via `collect()`),
+    /// returns the cached plan, which includes populated metrics.
     fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
+        if let Some(plan) = self.last_plan.lock().as_ref() {
+            return Ok(PyExecutionPlan::new(Arc::clone(plan)));
+        }
         let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
         Ok(plan.into())
     }
@@ -1128,13 +1156,22 @@
 
     fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
         let df = self.df.as_ref().clone();
-        let stream = spawn_future(py, async move { df.execute_stream().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())??;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?;
         Ok(PyRecordBatchStream::new(stream))
     }
 
     fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
         let df = self.df.as_ref().clone();
-        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let streams = spawn_future(py, async move {
+            df_execute_stream_partitioned(plan, task_ctx)
+        })?;
         Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
     }
 
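
With the cached plan, `execution_plan()` changes behavior observably: before execution it builds a fresh plan with empty metrics, while after `collect()`, `collect_partitioned()`, `execute_stream()`, or `execute_stream_partitioned()` it returns the plan that actually ran. A sketch of the resulting contract, mirroring the tests above (assuming `df` is a DataFrame built as in the earlier example):

    # Before execution: a fresh plan, so no populated metrics yet.
    plan = df.execution_plan()
    assert plan.metrics() is None or not plan.metrics().output_rows

    df.collect()

    # After execution: the cached, executed plan with populated metrics.
    plan = df.execution_plan()
    assert any(ms.output_rows for _, ms in plan.collect_metrics())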
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -43,6 +43,7 @@ pub mod errors;
 pub mod expr;
 #[allow(clippy::borrow_deref_ref)]
 mod functions;
+pub mod metrics;
 mod options;
 pub mod physical_plan;
 mod pyarrow_filter_expression;
@@ -96,6 +97,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<udtf::PyTableFunction>()?;
     m.add_class::<config::PyConfig>()?;
     m.add_class::<sql::logical::PyLogicalPlan>()?;
+    m.add_class::<metrics::PyMetricsSet>()?;
+    m.add_class::<metrics::PyMetric>()?;
     m.add_class::<physical_plan::PyExecutionPlan>()?;
     m.add_class::<record_batch::PyRecordBatch>()?;
     m.add_class::<record_batch::PyRecordBatchStream>()?;