diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f8daf2c64d99..f9b24924e596 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -132,6 +132,18 @@ ], "sqlState" : "22003" }, + "ARROW_SCHEMA_FIELD_COUNT_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder has <encoderFieldCount> fields but Arrow data only has <arrowColumnCount> columns." + ], + "sqlState" : "42000" + }, + "ARROW_SCHEMA_FIELD_NAME_MISMATCH" : { + "message" : [ + "Arrow schema mismatch: encoder field '<encoderFieldName>' does not match Arrow column '<arrowColumnName>' at the same position." + ], + "sqlState" : "42000" + }, "ARROW_TYPE_MISMATCH" : { "message" : [ "Invalid schema from <operation>: expected <outputTypes>, got <actualDataTypes>." diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 5cd7a3a2acde..ac83300ac20a 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} -import org.apache.spark.sql.{AnalysisException, Encoders, Row} +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, 
BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} @@ -807,7 +807,6 @@ class ArrowEncoderSuite extends ConnectFunSuite { private val wideSchemaEncoder = toRowEncoder( new StructType() - .add("a", "int") .add("b", "string") .add( "c", @@ -826,31 +825,30 @@ class ArrowEncoderSuite extends ConnectFunSuite { private val narrowSchemaEncoder = toRowEncoder( new StructType() .add("b", "string") + .add( + "C", + new StructType() + .add("Ca", "array") + .add("Cb", "binary")) .add( "d", ArrayType( new StructType() .add("da", "decimal(20, 10)") - .add("dc", "boolean"))) - .add( - "C", - new StructType() - .add("Ca", "array") - .add("Cb", "binary"))) + .add("db", "string")))) test("bind to schema") { - // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly - // different field order, and uses different cased names in a couple of places. + // Binds to a wider schema. The narrow schema has fewer (nested) fields, and uses different + // cased names in a couple of places. 
withAllocator { allocator => val input = Row( - 887, "foo", Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) val expected = Row( "foo", - Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), - Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)), + Seq(Row(null, "a"), Row(javaBigDecimal(57853, 10), "b"))) val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) val result = ArrowDeserializers.deserializeFromArrow( @@ -858,18 +856,21 @@ class ArrowEncoderSuite extends ConnectFunSuite { narrowSchemaEncoder, allocator, timeZoneId = "UTC") - val actual = result.next() - assert(result.isEmpty) - assert(expected === actual) - result.close() - arrowBatches.close() + try { + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + } finally { + result.close() + arrowBatches.close() + } } } test("unknown field") { withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) - intercept[AnalysisException] { + intercept[SparkRuntimeException] { ArrowDeserializers.deserializeFromArrow( arrowBatches, wideSchemaEncoder, @@ -881,6 +882,8 @@ class ArrowEncoderSuite extends ConnectFunSuite { } test("duplicate fields") { + // Arrow data with [foO, Foo] decoded into [foo]: positional matching binds foo → foO (pos 0), + // and the extra Foo column is ignored (over-complete schema is allowed). val duplicateSchemaEncoder = toRowEncoder( new StructType() .add("foO", "string") @@ -890,13 +893,65 @@ class ArrowEncoderSuite extends ConnectFunSuite { .add("foo", "string")) withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) - intercept[AnalysisException] { + // Should not throw: RowEncoder uses positional binding, so foo binds to foO at position 0. 
+ val result = ArrowDeserializers.deserializeFromArrow( + arrowBatches, + fooSchemaEncoder, + allocator, + timeZoneId = "UTC") + assert(!result.hasNext) + arrowBatches.close() + result.close() + } + } + + test("row with duplicate column names") { + // Spark DataFrames allow duplicate column names. collect() must round-trip such rows + // without throwing AMBIGUOUS_COLUMN_OR_FIELD. + val schema = new StructType() + .add("channel", "string") + .add("channel", "string") + val encoder = toRowEncoder(schema) + val rows = Seq(Row("a", "b"), Row("c", "d"), Row(null, "e")) + val iterator = roundTrip(encoder, rows.iterator) + try { + compareIterators(rows.iterator, iterator) + } finally { + iterator.close() + } + } + + test("row schema validation - column name mismatch") { + val serializeEncoder = toRowEncoder(new StructType().add("a", "string").add("b", "string")) + val deserializeEncoder = toRowEncoder(new StructType().add("a", "string").add("x", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, serializeEncoder, allocator) + val e = intercept[SparkRuntimeException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + deserializeEncoder, + allocator, + timeZoneId = "UTC") + } + assert(e.getCondition == "ARROW_SCHEMA_FIELD_NAME_MISMATCH") + arrowBatches.close() + } + } + + test("row schema validation - encoder has more fields than Arrow data") { + val serializeEncoder = toRowEncoder(new StructType().add("a", "string")) + val deserializeEncoder = + toRowEncoder(new StructType().add("a", "string").add("b", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, serializeEncoder, allocator) + val e = intercept[SparkRuntimeException] { ArrowDeserializers.deserializeFromArrow( arrowBatches, - fooSchemaEncoder, + deserializeEncoder, allocator, timeZoneId = "UTC") } + assert(e.getCondition == "ARROW_SCHEMA_FIELD_COUNT_MISMATCH") arrowBatches.close() } } diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 82029025a7f0..cc853f3c8a8c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -33,6 +33,7 @@ import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -330,9 +331,22 @@ object ArrowDeserializers { } case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => - val lookup = createFieldLookup(vectors) - val deserializers = fields.toArray.map { field => - deserializerFor(field.enc, lookup(field.name), timeZoneId) + // Row allows duplicate column names, so bind by position rather than by name. + if (fields.length > vectors.length) { + throw new SparkRuntimeException( + errorClass = "ARROW_SCHEMA_FIELD_COUNT_MISMATCH", + messageParameters = Map( + "encoderFieldCount" -> fields.length.toString, + "arrowColumnCount" -> vectors.length.toString)) + } + val deserializers = fields.toArray.zip(vectors).map { case (field, vector) => + if (!field.name.equalsIgnoreCase(vector.getName)) { + throw new SparkRuntimeException( + errorClass = "ARROW_SCHEMA_FIELD_NAME_MISMATCH", + messageParameters = + Map("encoderFieldName" -> field.name, "arrowColumnName" -> vector.getName)) + } + deserializerFor(field.enc, vector, timeZoneId) } new StructFieldSerializer[Any](struct) { def value(i: Int): Any = {