From f91bafa3e134a7e2196fdce9679e343e23c1b4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herman=20van=20H=C3=B6vell?= Date: Mon, 16 Mar 2026 19:29:16 +0000 Subject: [PATCH 1/3] [SPARK-56007][CONNECT] Fix ArrowDeserializer to use positional binding for RowEncoder and validate schema Switch RowEncoder deserialization from name-based lookup to positional binding to correctly handle duplicate column names. Add field-count and field-name mismatch error conditions with new tests. Co-authored-by: Isaac --- .../resources/error/error-conditions.json | 12 ++++ .../client/arrow/ArrowEncoderSuite.scala | 57 ++++++++++++++++++- .../client/arrow/ArrowDeserializer.scala | 21 ++++++- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 63c54a71b904b..7b0e81d383cad 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -105,6 +105,18 @@ }, "sqlState" : "42604" }, + "ARROW_SCHEMA_FIELD_COUNT_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder has <encoderFieldCount> fields but Arrow data only has <arrowColumnCount> columns." + ], + "sqlState" : "42000" + }, + "ARROW_SCHEMA_FIELD_NAME_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder field '<encoderFieldName>' does not match Arrow column '<arrowColumnName>' at the same position." 
+ ], + "sqlState" : "42000" + }, "AVRO_INCOMPATIBLE_READ_TYPE" : { "message" : [ "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer.", diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 10e4c11c406fe..20693c3559b9c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -739,6 +739,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } test("duplicate fields") { + // Arrow data with [foO, Foo] decoded into [foo]: positional matching binds foo → foO (pos 0), + // and the extra Foo column is ignored (over-complete schema is allowed). val duplicateSchemaEncoder = toRowEncoder( new StructType() .add("foO", "string") @@ -748,13 +750,64 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { .add("foo", "string")) withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) - intercept[AnalysisException] { + // Should not throw: RowEncoder uses positional binding, so foo binds to foO at position 0. + val result = ArrowDeserializers.deserializeFromArrow( + arrowBatches, + fooSchemaEncoder, + allocator, + timeZoneId = "UTC") + assert(!result.hasNext) + result.close() + } + } + + test("row with duplicate column names") { + // Spark DataFrames allow duplicate column names. collect() must round-trip such rows + // without throwing AMBIGUOUS_COLUMN_OR_FIELD. 
+ val schema = new StructType() + .add("channel", "string") + .add("channel", "string") + val encoder = toRowEncoder(schema) + val rows = Seq(Row("a", "b"), Row("c", "d"), Row(null, "e")) + val iterator = roundTrip(encoder, rows.iterator) + try { + compareIterators(rows.iterator, iterator) + } finally { + iterator.close() + } + } + + test("row schema validation - column name mismatch") { + val serializeEncoder = toRowEncoder(new StructType().add("a", "string").add("b", "string")) + val deserializeEncoder = toRowEncoder(new StructType().add("a", "string").add("x", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, serializeEncoder, allocator) + val e = intercept[SparkRuntimeException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + deserializeEncoder, + allocator, + timeZoneId = "UTC") + } + assert(e.getCondition == "ARROW_SCHEMA_FIELD_NAME_MISMATCH") + arrowBatches.close() + } + } + + test("row schema validation - encoder has more fields than Arrow data") { + val serializeEncoder = toRowEncoder(new StructType().add("a", "string")) + val deserializeEncoder = + toRowEncoder(new StructType().add("a", "string").add("b", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, serializeEncoder, allocator) + val e = intercept[SparkRuntimeException] { ArrowDeserializers.deserializeFromArrow( arrowBatches, - fooSchemaEncoder, + deserializeEncoder, allocator, timeZoneId = "UTC") } + assert(e.getCondition == "ARROW_SCHEMA_FIELD_COUNT_MISMATCH") arrowBatches.close() } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index f3abaddb0110b..6a014747ce6b6 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -33,6 +33,7 @@ import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -325,9 +326,23 @@ object ArrowDeserializers { } case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => - val lookup = createFieldLookup(vectors) - val deserializers = fields.toArray.map { field => - deserializerFor(field.enc, lookup(field.name), timeZoneId) + // Row allows duplicate column names, so bind by position rather than by name. + if (fields.length > vectors.length) { + throw new SparkRuntimeException( + errorClass = "ARROW_SCHEMA_FIELD_COUNT_MISMATCH", + messageParameters = Map( + "encoderFieldCount" -> fields.length.toString, + "arrowColumnCount" -> vectors.length.toString)) + } + val deserializers = fields.toArray.zip(vectors).map { case (field, vector) => + if (!field.name.equalsIgnoreCase(vector.getName)) { + throw new SparkRuntimeException( + errorClass = "ARROW_SCHEMA_FIELD_NAME_MISMATCH", + messageParameters = Map( + "encoderFieldName" -> field.name, + "arrowColumnName" -> vector.getName)) + } + deserializerFor(field.enc, vector, timeZoneId) } new StructFieldSerializer[Any](struct) { def value(i: Int): Any = { From be167862b8c47e5faabd2ff4b30298e1f9c89947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herman=20van=20H=C3=B6vell?= Date: Thu, 19 Mar 2026 22:15:25 +0000 Subject: [PATCH 2/3] [SPARK-56007][CONNECT] Fix ArrowEncoderSuite bind-to-schema test Fix the `bind to schema` test: - Correct `wideSchemaEncoder` (remove stray `a: int` field) - Fix narrow schema field order (C before d) and element struct fields (da, db 
not da, dc) - Supply complete wide-schema input rows (include dc boolean in d elements) - Correct expected output to match narrow schema projection - Add try/finally to ensure both iterators are always closed - Fix `unknown field` to expect `SparkRuntimeException` not `AnalysisException` Co-authored-by: Isaac --- .../client/arrow/ArrowEncoderSuite.scala | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index d7307c5dcb738..20458e61e6728 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} -import org.apache.spark.sql.{AnalysisException, Encoders, Row} +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, 
RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} @@ -807,7 +807,6 @@ class ArrowEncoderSuite extends ConnectFunSuite { private val wideSchemaEncoder = toRowEncoder( new StructType() - .add("a", "int") .add("b", "string") .add( "c", @@ -826,31 +825,30 @@ class ArrowEncoderSuite extends ConnectFunSuite { private val narrowSchemaEncoder = toRowEncoder( new StructType() .add("b", "string") + .add( + "C", + new StructType() + .add("Ca", "array") + .add("Cb", "binary")) .add( "d", ArrayType( new StructType() .add("da", "decimal(20, 10)") - .add("dc", "boolean"))) - .add( - "C", - new StructType() - .add("Ca", "array") - .add("Cb", "binary"))) + .add("db", "string")))) test("bind to schema") { - // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly - // different field order, and uses different cased names in a couple of places. + // Binds to a wider schema. The narrow schema has fewer (nested) fields, and uses different + // cased names in a couple of places. 
withAllocator { allocator => val input = Row( - 887, "foo", Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) val expected = Row( "foo", - Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), - Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)), + Seq(Row(null, "a"), Row(javaBigDecimal(57853, 10), "b"))) val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) val result = ArrowDeserializers.deserializeFromArrow( @@ -858,18 +856,21 @@ class ArrowEncoderSuite extends ConnectFunSuite { narrowSchemaEncoder, allocator, timeZoneId = "UTC") - val actual = result.next() - assert(result.isEmpty) - assert(expected === actual) - result.close() - arrowBatches.close() + try { + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + } finally { + result.close() + arrowBatches.close() + } } } test("unknown field") { withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) - intercept[AnalysisException] { + intercept[SparkRuntimeException] { ArrowDeserializers.deserializeFromArrow( arrowBatches, wideSchemaEncoder, From ef095953da5484a93c97a22a59ad78c7db388a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herman=20van=20H=C3=B6vell?= Date: Fri, 20 Mar 2026 16:06:58 +0000 Subject: [PATCH 3/3] Fixes... 
--- .../resources/error/error-conditions.json | 24 +++++++++---------- .../client/arrow/ArrowEncoderSuite.scala | 1 + .../client/arrow/ArrowDeserializer.scala | 5 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 5d39bbbcc131d..f9b24924e5960 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -132,6 +132,18 @@ ], "sqlState" : "22003" }, + "ARROW_SCHEMA_FIELD_COUNT_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder has <encoderFieldCount> fields but Arrow data only has <arrowColumnCount> columns." + ], + "sqlState" : "42000" + }, + "ARROW_SCHEMA_FIELD_NAME_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder field '<encoderFieldName>' does not match Arrow column '<arrowColumnName>' at the same position." + ], + "sqlState" : "42000" + }, "ARROW_TYPE_MISMATCH" : { "message" : [ "Invalid schema from : expected , got ." @@ -173,18 +185,6 @@ }, "sqlState" : "42604" }, - "ARROW_SCHEMA_FIELD_COUNT_MISMATCH" : { - "message" : [ - "Arrow schema mismatch: encoder has <encoderFieldCount> fields but Arrow data only has <arrowColumnCount> columns." - ], - "sqlState" : "42000" - }, - "ARROW_SCHEMA_FIELD_NAME_MISMATCH" : { - "message" : [ - "Arrow schema mismatch: encoder field '<encoderFieldName>' does not match Arrow column '<arrowColumnName>' at the same position." - ], - "sqlState" : "42000" - }, "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION" : { "message" : [ "Operations that trigger DataFrame analysis or execution are not allowed in pipeline query functions. Move code outside of the pipeline query function." 
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 20458e61e6728..ac83300ac20ac 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -900,6 +900,7 @@ class ArrowEncoderSuite extends ConnectFunSuite { allocator, timeZoneId = "UTC") assert(!result.hasNext) + arrowBatches.close() result.close() } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 1044c023800d9..cc853f3c8a8c6 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -343,9 +343,8 @@ object ArrowDeserializers { if (!field.name.equalsIgnoreCase(vector.getName)) { throw new SparkRuntimeException( errorClass = "ARROW_SCHEMA_FIELD_NAME_MISMATCH", - messageParameters = Map( - "encoderFieldName" -> field.name, - "arrowColumnName" -> vector.getName)) + messageParameters = + Map("encoderFieldName" -> field.name, "arrowColumnName" -> vector.getName)) } deserializerFor(field.enc, vector, timeZoneId) }