Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@
],
"sqlState" : "22003"
},
"ARROW_SCHEMA_FIELD_COUNT_MISMATCH" : {
"message" : [
"Arrow schema mismatch: encoder has <encoderFieldCount> fields but Arrow data only has <arrowColumnCount> columns."
],
"sqlState" : "42000"
},
"ARROW_SCHEMA_FIELD_NAME_MISMATCH" : {
"message" : [
"Arrow schema mismatch: encoder field '<encoderFieldName>' does not match Arrow column '<arrowColumnName>' at the same position."
],
"sqlState" : "42000"
},
"ARROW_TYPE_MISMATCH" : {
"message" : [
"Invalid schema from <operation>: expected <outputTypes>, got <actualDataTypes>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.VarBinaryVector

import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql.{AnalysisException, Encoders, Row}
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
Expand Down Expand Up @@ -807,7 +807,6 @@ class ArrowEncoderSuite extends ConnectFunSuite {

private val wideSchemaEncoder = toRowEncoder(
new StructType()
.add("a", "int")
.add("b", "string")
.add(
"c",
Expand All @@ -826,50 +825,52 @@ class ArrowEncoderSuite extends ConnectFunSuite {
private val narrowSchemaEncoder = toRowEncoder(
new StructType()
.add("b", "string")
.add(
"C",
new StructType()
.add("Ca", "array<int>")
.add("Cb", "binary"))
.add(
"d",
ArrayType(
new StructType()
.add("da", "decimal(20, 10)")
.add("dc", "boolean")))
.add(
"C",
new StructType()
.add("Ca", "array<int>")
.add("Cb", "binary")))
.add("db", "string"))))

test("bind to schema") {
// Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly
// different field order, and uses different cased names in a couple of places.
// Binds to a wider schema. The narrow schema has fewer (nested) fields, and uses different
// cased names in a couple of places.
withAllocator { allocator =>
val input = Row(
887,
"foo",
Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f),
Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false)))
val expected = Row(
"foo",
Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)),
Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)))
Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)),
Seq(Row(null, "a"), Row(javaBigDecimal(57853, 10), "b")))
val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator)
val result =
ArrowDeserializers.deserializeFromArrow(
arrowBatches,
narrowSchemaEncoder,
allocator,
timeZoneId = "UTC")
val actual = result.next()
assert(result.isEmpty)
assert(expected === actual)
result.close()
arrowBatches.close()
try {
val actual = result.next()
assert(result.isEmpty)
assert(expected === actual)
} finally {
result.close()
arrowBatches.close()
}
}
}

test("unknown field") {
withAllocator { allocator =>
val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator)
intercept[AnalysisException] {
intercept[SparkRuntimeException] {
ArrowDeserializers.deserializeFromArrow(
arrowBatches,
wideSchemaEncoder,
Expand All @@ -881,6 +882,8 @@ class ArrowEncoderSuite extends ConnectFunSuite {
}

test("duplicate fields") {
// Arrow data with [foO, Foo] decoded into [foo]: positional matching binds foo → foO (pos 0),
// and the extra Foo column is ignored (over-complete schema is allowed).
val duplicateSchemaEncoder = toRowEncoder(
new StructType()
.add("foO", "string")
Expand All @@ -890,13 +893,65 @@ class ArrowEncoderSuite extends ConnectFunSuite {
.add("foo", "string"))
withAllocator { allocator =>
val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator)
intercept[AnalysisException] {
// Should not throw: RowEncoder uses positional binding, so foo binds to foO at position 0.
val result = ArrowDeserializers.deserializeFromArrow(
arrowBatches,
fooSchemaEncoder,
allocator,
timeZoneId = "UTC")
assert(!result.hasNext)
arrowBatches.close()
result.close()
}
}

test("row with duplicate column names") {
  // Spark DataFrames allow duplicate column names. collect() must round-trip such rows
  // without throwing AMBIGUOUS_COLUMN_OR_FIELD.
  val duplicateNameSchema = new StructType()
    .add("channel", "string")
    .add("channel", "string")
  val rowEncoder = toRowEncoder(duplicateNameSchema)
  val inputRows = Seq(Row("a", "b"), Row("c", "d"), Row(null, "e"))
  // Serialize to Arrow and deserialize back; the result must equal the input rows.
  val roundTripped = roundTrip(rowEncoder, inputRows.iterator)
  try {
    compareIterators(inputRows.iterator, roundTripped)
  } finally {
    roundTripped.close()
  }
}

test("row schema validation - column name mismatch") {
  // An encoder field name that differs from the Arrow column name at the same
  // position must be rejected with ARROW_SCHEMA_FIELD_NAME_MISMATCH.
  val writeEncoder = toRowEncoder(new StructType().add("a", "string").add("b", "string"))
  val readEncoder = toRowEncoder(new StructType().add("a", "string").add("x", "string"))
  withAllocator { allocator =>
    val batches = serializeToArrow(Iterator.empty, writeEncoder, allocator)
    val error = intercept[SparkRuntimeException] {
      ArrowDeserializers.deserializeFromArrow(
        batches,
        readEncoder,
        allocator,
        timeZoneId = "UTC")
    }
    assert(error.getCondition == "ARROW_SCHEMA_FIELD_NAME_MISMATCH")
    batches.close()
  }
}

test("row schema validation - encoder has more fields than Arrow data") {
  // The deserialization encoder declares more fields (a, b) than the serialized
  // Arrow data contains (a), which must fail with ARROW_SCHEMA_FIELD_COUNT_MISMATCH.
  val serializeEncoder = toRowEncoder(new StructType().add("a", "string"))
  val deserializeEncoder =
    toRowEncoder(new StructType().add("a", "string").add("b", "string"))
  withAllocator { allocator =>
    val arrowBatches = serializeToArrow(Iterator.empty, serializeEncoder, allocator)
    val e = intercept[SparkRuntimeException] {
      // Fixed: the original also passed a stray `fooSchemaEncoder` argument here
      // (copy-paste leftover), giving deserializeFromArrow five arguments instead
      // of the expected (batches, encoder, allocator, timeZoneId) four.
      ArrowDeserializers.deserializeFromArrow(
        arrowBatches,
        deserializeEncoder,
        allocator,
        timeZoneId = "UTC")
    }
    assert(e.getCondition == "ARROW_SCHEMA_FIELD_COUNT_MISMATCH")
    arrowBatches.close()
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.ipc.ArrowReader

import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
Expand Down Expand Up @@ -330,9 +331,22 @@ object ArrowDeserializers {
}

case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
val lookup = createFieldLookup(vectors)
val deserializers = fields.toArray.map { field =>
deserializerFor(field.enc, lookup(field.name), timeZoneId)
// Row allows duplicate column names, so bind by position rather than by name.
if (fields.length > vectors.length) {
throw new SparkRuntimeException(
errorClass = "ARROW_SCHEMA_FIELD_COUNT_MISMATCH",
messageParameters = Map(
"encoderFieldCount" -> fields.length.toString,
"arrowColumnCount" -> vectors.length.toString))
}
val deserializers = fields.toArray.zip(vectors).map { case (field, vector) =>
if (!field.name.equalsIgnoreCase(vector.getName)) {
throw new SparkRuntimeException(
errorClass = "ARROW_SCHEMA_FIELD_NAME_MISMATCH",
messageParameters =
Map("encoderFieldName" -> field.name, "arrowColumnName" -> vector.getName))
}
deserializerFor(field.enc, vector, timeZoneId)
}
new StructFieldSerializer[Any](struct) {
def value(i: Int): Any = {
Expand Down