diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 2e6f81166..f0314d8dd 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -55,7 +55,7 @@
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
 from .options import CsvReadOptions
-from .plan import ExecutionPlan, LogicalPlan
+from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet
 from .record_batch import RecordBatch, RecordBatchStream
 from .user_defined import (
     Accumulator,
@@ -85,6 +85,8 @@
     "Expr",
     "InsertOp",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
     "ParquetColumnOptions",
     "ParquetWriterOptions",
     "RecordBatch",
diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py
index fb54fd624..d46ff1a00 100644
--- a/python/datafusion/plan.py
+++ b/python/datafusion/plan.py
@@ -29,6 +29,8 @@
 __all__ = [
     "ExecutionPlan",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
 ]
@@ -151,3 +153,107 @@ def to_proto(self) -> bytes:
         Tables created in memory from record batches are currently not supported.
         """
         return self._raw_plan.to_proto()
+
+    def metrics(self) -> MetricsSet | None:
+        """Return metrics for this plan node after execution, or None if unavailable."""
+        raw = self._raw_plan.metrics()
+        if raw is None:
+            return None
+        return MetricsSet(raw)
+
+    def collect_metrics(self) -> list[tuple[str, MetricsSet]]:
+        """Walk the plan tree and collect metrics from all operators.
+
+        Returns a list of ``(operator_name, MetricsSet)`` tuples, where the name is each node's one-line ``display()`` string.
+        """
+        result: list[tuple[str, MetricsSet]] = []
+
+        def _walk(node: ExecutionPlan) -> None:
+            ms = node.metrics()
+            if ms is not None:
+                result.append((node.display(), ms))
+            for child in node.children():
+                _walk(child)
+
+        _walk(self)
+        return result
+
+
+class MetricsSet:
+    """A set of metrics for a single execution plan operator.
+
+    Provides both individual metric access and convenience aggregations
+    across partitions.
+    """
+
+    def __init__(self, raw: df_internal.MetricsSet) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    def metrics(self) -> list[Metric]:
+        """Return all individual metrics in this set."""
+        return [Metric(m) for m in self._raw.metrics()]
+
+    @property
+    def output_rows(self) -> int | None:
+        """Sum of ``output_rows`` across all partitions, or None if not present."""
+        return self._raw.output_rows()
+
+    @property
+    def elapsed_compute(self) -> int | None:
+        """Sum of ``elapsed_compute`` across all partitions, in nanoseconds, or None if not present."""
+        return self._raw.elapsed_compute()
+
+    @property
+    def spill_count(self) -> int | None:
+        """Sum of ``spill_count`` across all partitions, or None if not present."""
+        return self._raw.spill_count()
+
+    @property
+    def spilled_bytes(self) -> int | None:
+        """Sum of ``spilled_bytes`` across all partitions, or None if not present."""
+        return self._raw.spilled_bytes()
+
+    @property
+    def spilled_rows(self) -> int | None:
+        """Sum of ``spilled_rows`` across all partitions, or None if not present."""
+        return self._raw.spilled_rows()
+
+    def sum_by_name(self, name: str) -> int | None:
+        """Return the sum of metrics matching the given name, or None if none match."""
+        return self._raw.sum_by_name(name)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metrics set."""
+        return repr(self._raw)
+
+
+class Metric:
+    """A single execution metric with name, value, partition, and labels."""
+
+    def __init__(self, raw: df_internal.Metric) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    @property
+    def name(self) -> str:
+        """The name of this metric (e.g. ``output_rows``)."""
+        return self._raw.name
+
+    @property
+    def value(self) -> int | None:
+        """The numeric value of this metric, or None for non-numeric types."""
+        return self._raw.value
+
+    @property
+    def partition(self) -> int | None:
+        """The partition this metric applies to, or None if global."""
+        return self._raw.partition
+
+    def labels(self) -> dict[str, str]:
+        """Return the labels associated with this metric."""
+        return self._raw.labels()
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metric."""
+        return repr(self._raw)
diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py
index 396acbe97..d3525b08c 100644
--- a/python/tests/test_plans.py
+++ b/python/tests/test_plans.py
@@ -16,7 +16,13 @@
 # under the License.
 
 import pytest
-from datafusion import ExecutionPlan, LogicalPlan, SessionContext
+from datafusion import (
+    ExecutionPlan,
+    LogicalPlan,
+    Metric,
+    MetricsSet,
+    SessionContext,
+)
 
 
 # Note: We must use CSV because memory tables are currently not supported for
@@ -40,3 +46,101 @@ def test_logical_plan_to_proto(ctx, df) -> None:
     execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes)
 
     assert str(original_execution_plan) == str(execution_plan)
+
+
+def test_metrics_tree_walk() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+    df.collect()
+    plan = df.execution_plan()
+
+    results = plan.collect_metrics()
+    assert len(results) >= 1
+    found_metrics = False
+    for name, ms in results:
+        assert isinstance(name, str)
+        assert isinstance(ms, MetricsSet)
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_metric_properties() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+    df.collect()
+    plan = df.execution_plan()
+
+    for _, ms in plan.collect_metrics():
+        r = repr(ms)
+        assert isinstance(r, str)
+        for metric in ms.metrics():
+            assert isinstance(metric, Metric)
+            assert isinstance(metric.name, str)
+            assert len(metric.name) > 0
+            assert metric.partition is None or isinstance(metric.partition, int)
+            assert isinstance(metric.labels(), dict)
+            mr = repr(metric)
+            assert isinstance(mr, str)
+            assert len(mr) > 0
+            return
+    pytest.skip("No metrics found")
+
+
+def test_no_metrics_before_execution() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
+    df = ctx.sql("SELECT * FROM t")
+    plan = df.execution_plan()
+    ms = plan.metrics()
+    assert ms is None or ms.output_rows is None or ms.output_rows == 0
+
+
+def test_collect_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    df.collect_partitioned()
+    plan = df.execution_plan()
+
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    for _ in df.execute_stream():
+        pass
+
+    plan = df.execution_plan()
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    for stream in df.execute_stream_partitioned():
+        for _ in stream:
+            pass
+
+    plan = df.execution_plan()
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
diff --git a/src/dataframe.rs b/src/dataframe.rs
index fe039593d..06e9ef2af 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -48,6 +48,14 @@
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 
+use datafusion::physical_plan::{
+    ExecutionPlan as DFExecutionPlan,
+    collect as df_collect,
+    collect_partitioned as df_collect_partitioned,
+    execute_stream as df_execute_stream,
+    execute_stream_partitioned as df_execute_stream_partitioned,
+};
+
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
 use crate::expr::PyExpr;
 use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
@@ -290,6 +298,9 @@ pub struct PyDataFrame {
     // In IPython environment cache batches between __repr__ and _repr_html_ calls.
     batches: SharedCachedBatches,
+
+    // Cache the last physical plan so that metrics are available after execution.
+    last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>,
 }
 
 impl PyDataFrame {
@@ -298,6 +309,7 @@
         Self {
             df: Arc::new(df),
             batches: Arc::new(Mutex::new(None)),
+            last_plan: Arc::new(Mutex::new(None)),
         }
     }
 
@@ -627,7 +639,12 @@
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
@@ -643,7 +660,12 @@
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
         batches
@@ -803,7 +825,19 @@
     }
 
     /// Get the execution plan for this `DataFrame`
+    ///
+    /// If the DataFrame has already been executed (e.g. via `collect()`),
+    /// returns the cached plan, which includes populated metrics.
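+    ///
+    /// From Python, for example:
+    ///
+    ///     df.collect()
+    ///     for name, ms in df.execution_plan().collect_metrics():
+    ///         print(name, ms.output_rows)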
     fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
+        if let Some(plan) = self.last_plan.lock().as_ref() {
+            return Ok(PyExecutionPlan::new(Arc::clone(plan)));
+        }
         let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
         Ok(plan.into())
     }
@@ -1128,13 +1156,22 @@
     fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
         let df = self.df.as_ref().clone();
-        let stream = spawn_future(py, async move { df.execute_stream().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())??;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?;
         Ok(PyRecordBatchStream::new(stream))
     }
 
     fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
         let df = self.df.as_ref().clone();
-        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let streams = spawn_future(py, async move {
+            df_execute_stream_partitioned(plan, task_ctx)
+        })?;
         Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
     }
diff --git a/src/lib.rs b/src/lib.rs
index 081366b20..7c21ae95c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,6 +43,7 @@ pub mod errors;
 pub mod expr;
 #[allow(clippy::borrow_deref_ref)]
 mod functions;
+pub mod metrics;
 mod options;
 pub mod physical_plan;
 mod pyarrow_filter_expression;
@@ -96,6 +97,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
 m.add_class::()?;
 m.add_class::()?;
 m.add_class::()?;
+    m.add_class::<metrics::PyMetric>()?;
+    m.add_class::<metrics::PyMetricsSet>()?;
 m.add_class::()?;
 m.add_class::()?;
 m.add_class::()?;
diff --git a/src/metrics.rs b/src/metrics.rs
new file mode 100644
index 000000000..e333ea791
--- /dev/null
+++ b/src/metrics.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use datafusion::physical_plan::metrics::{Metric, MetricValue, MetricsSet};
+use pyo3::prelude::*;
+
+#[pyclass(frozen, name = "MetricsSet", module = "datafusion")]
+#[derive(Debug, Clone)]
+pub struct PyMetricsSet {
+    metrics: MetricsSet,
+}
+
+impl PyMetricsSet {
+    pub fn new(metrics: MetricsSet) -> Self {
+        Self { metrics }
+    }
+}
+
+#[pymethods]
+impl PyMetricsSet {
+    /// Returns all individual metrics in this set.
+    fn metrics(&self) -> Vec<PyMetric> {
+        self.metrics
+            .iter()
+            .map(|m| PyMetric::new(Arc::clone(m)))
+            .collect()
+    }
+
+    /// Returns the sum of all `output_rows` metrics, or None if not present.
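+    ///
+    /// Effectively equivalent to `sum_by_name("output_rows")`.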
+    fn output_rows(&self) -> Option<usize> {
+        self.metrics.output_rows()
+    }
+
+    /// Returns the sum of all `elapsed_compute` metrics in nanoseconds, or None if not present.
+    fn elapsed_compute(&self) -> Option<usize> {
+        self.metrics.elapsed_compute()
+    }
+
+    /// Returns the sum of all `spill_count` metrics, or None if not present.
+    fn spill_count(&self) -> Option<usize> {
+        self.metrics.spill_count()
+    }
+
+    /// Returns the sum of all `spilled_bytes` metrics, or None if not present.
+    fn spilled_bytes(&self) -> Option<usize> {
+        self.metrics.spilled_bytes()
+    }
+
+    /// Returns the sum of all `spilled_rows` metrics, or None if not present.
+    fn spilled_rows(&self) -> Option<usize> {
+        self.metrics.spilled_rows()
+    }
+
+    /// Returns the sum of metrics matching the given name, or None if none match.
+    fn sum_by_name(&self, name: &str) -> Option<usize> {
+        self.metrics.sum_by_name(name).map(|v| v.as_usize())
+    }
+
+    fn __repr__(&self) -> String {
+        format!("{}", self.metrics)
+    }
+}
+
+#[pyclass(frozen, name = "Metric", module = "datafusion")]
+#[derive(Debug, Clone)]
+pub struct PyMetric {
+    metric: Arc<Metric>,
+}
+
+impl PyMetric {
+    pub fn new(metric: Arc<Metric>) -> Self {
+        Self { metric }
+    }
+}
+
+#[pymethods]
+impl PyMetric {
+    /// Returns the name of this metric.
+    #[getter]
+    fn name(&self) -> String {
+        self.metric.value().name().to_string()
+    }
+
+    /// Returns the numeric value of this metric, or None for non-numeric types.
+    #[getter]
+    fn value(&self) -> Option<usize> {
+        match self.metric.value() {
+            MetricValue::OutputRows(c) => Some(c.value()),
+            MetricValue::OutputBytes(c) => Some(c.value()),
+            MetricValue::ElapsedCompute(t) => Some(t.value()),
+            MetricValue::SpillCount(c) => Some(c.value()),
+            MetricValue::SpilledBytes(c) => Some(c.value()),
+            MetricValue::SpilledRows(c) => Some(c.value()),
+            MetricValue::CurrentMemoryUsage(g) => Some(g.value()),
+            MetricValue::Count { count, .. } => Some(count.value()),
+            MetricValue::Gauge { gauge, .. } => Some(gauge.value()),
+            MetricValue::Time { time, .. } => Some(time.value()),
+            MetricValue::StartTimestamp(ts) => {
+                ts.value().and_then(|dt| dt.timestamp_nanos_opt().map(|n| n as usize))
+            }
+            MetricValue::EndTimestamp(ts) => {
+                ts.value().and_then(|dt| dt.timestamp_nanos_opt().map(|n| n as usize))
+            }
+            _ => None,
+        }
+    }
+
+    /// Returns the partition this metric is for, or None if it applies to all partitions.
+    #[getter]
+    fn partition(&self) -> Option<usize> {
+        self.metric.partition()
+    }
+
+    /// Returns the labels associated with this metric as a dict.
+    fn labels(&self) -> HashMap<String, String> {
+        self.metric
+            .labels()
+            .iter()
+            .map(|l| (l.name().to_string(), l.value().to_string()))
+            .collect()
+    }
+
+    fn __repr__(&self) -> String {
+        format!("{}", self.metric.value())
+    }
+}
diff --git a/src/physical_plan.rs b/src/physical_plan.rs
index 0069e5e6e..319d27efe 100644
--- a/src/physical_plan.rs
+++ b/src/physical_plan.rs
@@ -26,6 +26,7 @@
 use pyo3::types::PyBytes;
 use crate::context::PySessionContext;
 use crate::errors::PyDataFusionResult;
+use crate::metrics::PyMetricsSet;
 
 #[pyclass(frozen, name = "ExecutionPlan", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -90,6 +91,11 @@
         Ok(Self::new(plan))
     }
 
+    /// Returns metrics for this plan node after execution, or None if unavailable.
+    pub fn metrics(&self) -> Option<PyMetricsSet> {
+        self.plan.metrics().map(PyMetricsSet::new)
+    }
+
     fn __repr__(&self) -> String {
         self.display_indent()
     }
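
For reference, a minimal end-to-end sketch of the Python API this diff adds (mirroring the new tests; the table and query are arbitrary):

```python
from datafusion import SessionContext

ctx = SessionContext()
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
df.collect()  # executes the query and caches the metrics-bearing physical plan

plan = df.execution_plan()  # returns the cached plan after execution
ms = plan.metrics()
if ms is not None:
    print(ms.output_rows, ms.elapsed_compute)

# Walk the whole operator tree instead of inspecting only the root node:
for name, ms in plan.collect_metrics():
    print(name, ms.sum_by_name("output_rows"))
```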