Skip to content
8 changes: 8 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -6541,6 +6541,14 @@
],
"sqlState" : "42601"
},
"STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES" : {
"message" : [
"Storage-partition join partition transforms produced incompatible reduced types,",
"left reducers: <leftReducers> returned: <leftReducedDataTypes>,",
"right reducers: <rightReducers> returned: <rightReducedDataTypes>."
],
"sqlState" : "42K09"
},
"STREAMING_CHECKPOINT_MISSING_METADATA_FILE" : {
"message" : [
"Checkpoint location <checkpointLocation> is in an inconsistent state: the metadata file is missing but offset and/or commit logs contain data. Please restore the metadata file or create a new checkpoint directory."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.DataType;

/**
* A 'reducer' for output of user-defined functions.
Expand All @@ -31,9 +32,10 @@
* <li> More generally, there exists reducer functions r1(x) and r2(x) such that
* r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
* </ul>
* where = means both value and data type match.
*
* @param <I> reducer input type
* @param <O> reducer output type
* @param <I> the physical Java type of the input
* @param <O> the physical Java type of the output
* @since 4.0.0
*/
@Evolving
Expand All @@ -47,4 +49,11 @@ public interface Reducer<I, O> {
/**
 * Returns a human-readable name for this reducer, used e.g. when reporting
 * incompatible reduced types in error messages.
 *
 * @return the implementing class's simple name by default.
 */
default String displayName() {
return getClass().getSimpleName();
}

/**
* Returns the {@link DataType data type} of values produced by this reducer.
*
* @return the data type of values produced by this reducer.
*/
DataType resultType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,11 @@ case class KeyedPartitioning(

/**
* Reduces this partitioning's partition keys by applying the given reducers.
* Returns the distinct reduced keys.
* Returns the reduced keys and their data types.
*/
def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] =
KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers).distinct
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This .distinct was moved to mergeAndDedupPartitions().

def reduceKeys(
reducers: Seq[Option[Reducer[_, _]]]): (Seq[DataType], Seq[InternalRowComparableWrapper]) =
KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers)

override def satisfies0(required: Distribution): Boolean = {
nonGroupedSatisfies(required) || groupedSatisfies(required)
Expand Down Expand Up @@ -586,17 +587,23 @@ object KeyedPartitioning {
def reduceKeys(
keys: Seq[InternalRowComparableWrapper],
dataTypes: Seq[DataType],
reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = {
reducers: Seq[Option[Reducer[_, _]]]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = {
val reducedDataTypes = dataTypes.zip(reducers).map {
case (_, Some(reducer: Reducer[Any, Any])) => reducer.resultType()
case (t, _) => t
}
val comparableKeyWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
keys.map { key =>
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(reducedDataTypes)
val reducedKeys = keys.map { key =>
val keyValues = key.row.toSeq(dataTypes)
val reducedKey = keyValues.zip(reducers).map {
case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
case (v, _) => v
}.toArray
comparableKeyWrapperFactory(new GenericInternalRow(reducedKey))
}

(reducedDataTypes, reducedKeys)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
import org.apache.spark.sql.catalyst.util.{sideBySide, CharsetProvider, DateTimeUtils, FailFastMode, IntervalUtils, MapData}
import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
Expand Down Expand Up @@ -3128,6 +3129,30 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

/**
 * Builds the error raised when a storage-partitioned join cannot be planned because the
 * partition-transform reducers on the two sides produced different reduced data types.
 *
 * @param leftReducers optional per-column reducers applied to the left side; `None` inside the
 *                     inner sequence means the identity transform was used for that column.
 * @param leftReducedDataTypes data types produced for the left side after reduction.
 * @param rightReducers optional per-column reducers applied to the right side.
 * @param rightReducedDataTypes data types produced for the right side after reduction.
 * @return a [[SparkException]] with error class
 *         `STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES`.
 */
def storagePartitionJoinIncompatibleReducedTypesError(
leftReducers: Option[Seq[Option[Reducer[_, _]]]],
leftReducedDataTypes: Seq[DataType],
rightReducers: Option[Seq[Option[Reducer[_, _]]]],
rightReducedDataTypes: Seq[DataType]): Throwable = {
// Renders e.g. "[MyReducer, identity]"; an absent outer Option renders as "[]".
def reducersNames(reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
reducers.toSeq.flatMap(_.map(_.map(_.displayName()).getOrElse("identity")))
.mkString("[", ", ", "]")
}

// Renders the SQL names of the data types, e.g. "[INT, BIGINT]".
def dataTypeNames(dataTypes: Seq[DataType]) = {
dataTypes.map(toSQLType).mkString("[", ", ", "]")
}

new SparkException(
errorClass = "STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES",
messageParameters = Map(
"leftReducers" -> reducersNames(leftReducers),
"leftReducedDataTypes" -> dataTypeNames(leftReducedDataTypes),
"rightReducers" -> reducersNames(rightReducers),
"rightReducedDataTypes" -> dataTypeNames(rightReducedDataTypes)),
cause = null)
}

def notAbsolutePathError(path: Path): SparkException = {
SparkException.internalError(s"$path is not absolute path.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ abstract class InMemoryBaseTable(
case YearsTransform(ref) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (days: Int, DateType) =>
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)).toInt
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
Expand All @@ -225,7 +225,7 @@ abstract class InMemoryBaseTable(
case (days, DateType) =>
days
case (micros: Long, TimestampType) =>
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)).toInt
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ case class GroupPartitionsExec(
)(keyedPartitioning.projectKeys)

// Reduce keys if reducers are specified
val reducedKeys = reducers.fold(projectedKeys)(
val (reducedDataTypes, reducedKeys) = reducers.fold((projectedDataTypes, projectedKeys))(
KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _))

val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2)

if (expectedPartitionKeys.isDefined) {
alignToExpectedKeys(keyToPartitionIndices)
} else {
(groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes), true)
(groupAndSortByKeys(keyToPartitionIndices, reducedDataTypes), true)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.GroupPartitionsExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
Expand Down Expand Up @@ -509,16 +510,24 @@ case class EnsureRequirements(
// in case of compatible but not identical partition expressions, we apply 'reduce'
// transforms to group one side's partitions as well as the common partition values
val leftReducers = leftSpec.reducers(rightSpec)
val leftReducedKeys =
leftReducers.fold(leftPartitioning.partitionKeys)(leftPartitioning.reduceKeys)
val rightReducers = rightSpec.reducers(leftSpec)
val rightReducedKeys =
rightReducers.fold(rightPartitioning.partitionKeys)(rightPartitioning.reduceKeys)
val (leftReducedDataTypes, leftReducedKeys) = leftReducers.fold(
(leftPartitioning.expressionDataTypes, leftPartitioning.partitionKeys)
)(leftPartitioning.reduceKeys)
val (rightReducedDataTypes, rightReducedKeys) = rightReducers.fold(
(rightPartitioning.expressionDataTypes, rightPartitioning.partitionKeys)
)(rightPartitioning.reduceKeys)
if (leftReducedDataTypes != rightReducedDataTypes) {
throw QueryExecutionErrors.storagePartitionJoinIncompatibleReducedTypesError(
leftReducers = leftReducers,
leftReducedDataTypes = leftReducedDataTypes,
rightReducers = rightReducers,
rightReducedDataTypes = rightReducedDataTypes)
}

// merge values on both sides
var mergedPartitionKeys =
mergePartitions(leftReducedKeys, rightReducedKeys, joinType, leftPartitioning.keyOrdering)
.map((_, 1))
var mergedPartitionKeys = mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys,
joinType, leftPartitioning.keyOrdering).map((_, 1))

logInfo(log"After merging, there are " +
log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions")
Expand Down Expand Up @@ -752,36 +761,37 @@ case class EnsureRequirements(
}

/**
* Merge and sort partitions keys for SPJ and optionally enable partition filtering.
* Merge, dedup and sort partitions keys for SPJ and optionally enable partition filtering.
* Both sides must have matching partition expressions.
* @param leftPartitionKeys left side partition keys
* @param rightPartitionKeys right side partition keys
* @param joinType join type for optional partition filtering
* @keyOrdering ordering to sort partition keys
* @param keyOrdering ordering to sort partition keys
* @return merged and sorted partition values
*/
def mergePartitions(
def mergeAndDedupPartitions(
leftPartitionKeys: Seq[InternalRowComparableWrapper],
rightPartitionKeys: Seq[InternalRowComparableWrapper],
joinType: JoinType,
keyOrdering: Ordering[InternalRowComparableWrapper]): Seq[InternalRowComparableWrapper] = {
val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
joinType match {
case Inner => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true)
case LeftOuter => leftPartitionKeys
case RightOuter => rightPartitionKeys
case _ => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys)
case Inner =>
mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true)
case LeftOuter => leftPartitionKeys.distinct
case RightOuter => rightPartitionKeys.distinct
case _ => mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys)
}
} else {
mergePartitionKeys(leftPartitionKeys, rightPartitionKeys)
mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys)
}

// SPARK-41471: We keep to order of partitions to make sure the order of
// partitions is deterministic in different case.
merged.sorted(keyOrdering)
}

private def mergePartitionKeys(
private def mergeAndDedupPartitionKeys(
leftPartitionKeys: Seq[InternalRowComparableWrapper],
rightPartitionKeys: Seq[InternalRowComparableWrapper],
intersect: Boolean = false) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector
import java.sql.Timestamp
import java.util.Collections

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
Expand Down Expand Up @@ -75,6 +75,20 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
Column.create("dept_id", IntegerType),
Column.create("data", StringType))

/**
 * Temporarily registers `fn` in the test catalog under its own name, evaluates `f`,
 * and then removes it, restoring any function that was previously registered under
 * the same identifier.
 *
 * @param fn the function to register for the duration of `f`.
 * @param f the code to run while `fn` is registered.
 * @return the result of evaluating `f`.
 */
def withFunction[T](fn: UnboundFunction)(f: => T): T = {
  val id = Identifier.of(Array.empty, fn.name())
  // Save and drop any function already registered under this identifier so we can
  // restore it afterwards. (Renamed from `fn`, which shadowed the parameter.)
  val previous = Option.when(catalog.listFunctions(Array.empty).contains(id)) {
    val existing = catalog.loadFunction(id)
    catalog.dropFunction(id)
    existing
  }
  catalog.createFunction(id, fn)
  try f finally {
    catalog.dropFunction(id)
    previous.foreach(catalog.createFunction(id, _))
  }
}

test("clustered distribution: output partitioning should be KeyedPartitioning") {
val partitions: Array[Transform] = Array(Expressions.years("ts"))

Expand All @@ -88,7 +102,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts")
val catalystDistribution = physical.ClusteredDistribution(
Seq(TransformExpression(YearsFunction, Seq(attr("ts")))))
val partitionKeys = Seq(50L, 51L, 52L).map(v => InternalRow.fromSeq(Seq(v)))
val partitionKeys = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))

checkQueryPlan(df, catalystDistribution,
physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys))
Expand Down Expand Up @@ -3385,4 +3399,83 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
checkKeywordsExistsInExplain(df, FormattedMode, formattedKeyword)
}
}

// Verifies that a storage-partitioned join succeeds when the two sides use different but
// compatible partition transforms (days vs. years) whose reducers produce the SAME reduced
// data type: no shuffle should be inserted and partitions should group to 2.
test("SPARK-56046: Reducers with same result types") {
// Left side: partitioned by days(arrive_time).
val items_partitions = Array(days("arrive_time"))
createTable(items, itemsColumns, items_partitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")

// Right side: partitioned by the coarser years(time) transform.
val purchases_partitions = Array(years("time"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
s"(5, 44.0, cast('2020-01-15' as timestamp)), " +
s"(7, 46.5, cast('2021-02-08' as timestamp))")

// Both configs are required for SPJ across compatible-but-different transforms.
withSQLConf(
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
// Run the join in both orders to cover left/right reducer application.
Seq(
s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time",
s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time"
).foreach { joinString =>
val df = sql(
s"""
|${selectWithMergeJoinHint("i", "p")} id, item_id
|FROM $joinString
|ORDER BY id, item_id
|""".stripMargin)

// SPJ should eliminate shuffles on both sides of the join.
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
// After reduction/grouping, both sides should expose 2 key-grouped partitions.
val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan)
assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2))

checkAnswer(df, Seq(Row(0, 1), Row(1, 1)))
}
}
}

// Verifies that a storage-partitioned join FAILS with a clear error when the reducers on the
// two sides produce incompatible reduced data types (a days() implementation whose reducer's
// result type differs from the other side is swapped in via withFunction).
test("SPARK-56046: Reducers with different result types") {
withFunction(UnboundDaysFunctionWithIncompatibleResultTypeReducer) {
// Left side: partitioned by days(arrive_time), using the incompatible-reducer function.
val items_partitions = Array(days("arrive_time"))
createTable(items, itemsColumns, items_partitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")

// Right side: partitioned by years(time).
val purchases_partitions = Array(years("time"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
s"(5, 44.0, cast('2020-01-15' as timestamp)), " +
s"(7, 46.5, cast('2021-02-08' as timestamp))")

withSQLConf(
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
// Run the join in both orders; each must raise the same planning error.
Seq(
s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time",
s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time"
).foreach { joinString =>
val e = intercept[SparkException] {
sql(
s"""
|${selectWithMergeJoinHint("i", "p")} id, item_id
|FROM $joinString
|ORDER BY id, item_id
|""".stripMargin).collect()
}
// Message text comes from STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES.
assert(e.getMessage.contains(
"Storage-partition join partition transforms produced incompatible reduced types"))
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
Invoke(
Literal.create(YearsFunction, ObjectType(YearsFunction.getClass)),
"invoke",
LongType,
IntegerType,
Seq(Cast(attr("day"), TimestampType, Some("America/Los_Angeles"))),
Seq(TimestampType),
propagateNull = false),
Expand Down
Loading