From 5ad4a4dc44052f6f7c836564fc375f90a8113448 Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Tue, 9 Sep 2025 15:20:34 +0200 Subject: [PATCH] fix: mangled errors closes #1226 --- python/tests/test_catalog.py | 33 +++++++++++++++++++++++++++++++++ python/tests/test_sql.py | 5 ++++- src/catalog.rs | 3 ++- src/context.rs | 7 +++++-- src/errors.rs | 12 +++++++++++- 5 files changed, 55 insertions(+), 5 deletions(-) diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index dd4c82469..71c08da26 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -81,6 +81,12 @@ def table_exist(self, name: str) -> bool: return name in self.tables +class CustomErrorSchemaProvider(CustomSchemaProvider): + def table(self, name: str) -> Table | None: + message = f"{name} is not an acceptable name" + raise ValueError(message) + + class CustomCatalogProvider(dfn.catalog.CatalogProvider): def __init__(self): self.schemas = {"my_schema": CustomSchemaProvider()} @@ -219,6 +225,33 @@ def test_schema_register_table_with_pyarrow_dataset(ctx: SessionContext): schema.deregister_table(table_name) +def test_exception_not_mangled(ctx: SessionContext): + """Test registering all python providers and running a query against them.""" + + catalog_name = "custom_catalog" + schema_name = "custom_schema" + + ctx.register_catalog_provider(catalog_name, CustomCatalogProvider()) + + catalog = ctx.catalog(catalog_name) + + # Clean out previous schemas if they exist so we can start clean + for schema_name in catalog.schema_names(): + catalog.deregister_schema(schema_name, cascade=False) + + catalog.register_schema(schema_name, CustomErrorSchemaProvider()) + + schema = catalog.schema(schema_name) + + for table_name in schema.table_names(): + schema.deregister_table(table_name) + + schema.register_table("test_table", create_dataset()) + + with pytest.raises(ValueError, match="^test_table is not an acceptable name$"): + ctx.sql(f"select * from 
{catalog_name}.{schema_name}.test_table") + + def test_in_end_to_end_python_providers(ctx: SessionContext): """Test registering all python providers and running a query against them.""" diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 48c374660..12710cf08 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -29,7 +29,10 @@ def test_no_table(ctx): - with pytest.raises(Exception, match="DataFusion error"): + with pytest.raises( + ValueError, + match="^Error during planning: table 'datafusion.public.b' not found$", + ): ctx.sql("SELECT a FROM b").collect() diff --git a/src/catalog.rs b/src/catalog.rs index b5b983970..d10d5b8b3 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -364,7 +364,8 @@ impl SchemaProvider for RustWrappedPySchemaProvider { &self, name: &str, ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> { - self.table_inner(name).map_err(to_datafusion_err) + self.table_inner(name) + .map_err(|e| DataFusionError::External(Box::new(e))) } fn register_table( diff --git a/src/context.rs b/src/context.rs index 89bbe9344..fc3d595c1 100644 --- a/src/context.rs +++ b/src/context.rs @@ -65,7 +65,9 @@ use crate::catalog::{ use crate::common::data_type::PyScalarValue; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::errors::{ + from_datafusion_error, py_datafusion_err, PyDataFusionError, PyDataFusionResult, +}; use crate::expr::sort_expr::PySortExpr; use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; @@ -465,7 +467,8 @@ impl PySessionContext { let mut df = wait_for_future(py, async { self.ctx.sql_with_options(&query, options).await - })??; + })? 
+ .map_err(from_datafusion_error)?; if !param_values.is_empty() { df = df.with_param_values(param_values)?; diff --git a/src/errors.rs b/src/errors.rs index d1b518042..108072101 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -22,7 +22,7 @@ use std::fmt::Debug; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; use prost::EncodeError; -use pyo3::exceptions::PyException; +use pyo3::exceptions::{PyException, PyValueError}; use pyo3::PyErr; pub type PyDataFusionResult<T> = std::result::Result<T, PyDataFusionError>; @@ -96,3 +96,13 @@ pub fn py_unsupported_variant_err(e: impl Debug) -> PyErr { pub fn to_datafusion_err(e: impl Debug) -> InnerDataFusionError { InnerDataFusionError::Execution(format!("{e:?}")) } + +pub fn from_datafusion_error(err: InnerDataFusionError) -> PyErr { + match err { + InnerDataFusionError::External(boxed) => match boxed.downcast::<PyErr>() { + Ok(py_err) => *py_err, + Err(original_boxed) => PyValueError::new_err(format!("{original_boxed}")), + }, + _ => PyValueError::new_err(format!("{err}")), + } +}