diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index c2ce4a740ec..f0dcca93d3d 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -580,8 +580,10 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte ? Collections.emptyList() : arguments.subList(1, arguments.size()); List nodes = - PPLFuncImpTable.INSTANCE.validateAggFunctionSignature( - functionName, field, args, context.rexBuilder); + requiresWindowAggValidation(functionName) + ? PPLFuncImpTable.INSTANCE.validateAggFunctionSignature( + functionName, field, args, context.rexBuilder) + : null; return nodes != null ? PlanUtils.makeOver( context, @@ -606,6 +608,13 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte "Unexpected window function: " + windowFunction.getFuncName())); } + private boolean requiresWindowAggValidation(BuiltinFunctionName functionName) { + return switch (functionName) { + case ROW_NUMBER, RANK, DENSE_RANK, NTH_VALUE -> false; + default -> true; + }; + } + /** extract the expression of Alias from a node */ private RexNode extractRexNodeFromAlias(RexNode node) { requireNonNull(node); diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java index 39f3a6f2d05..003cb5acc21 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java @@ -211,6 +211,22 @@ static RexNode makeOver( true, lowerBound, upperBound); + case RANK: + return withOver( + context.relBuilder.aggregateCall(SqlStdOperatorTable.RANK), + partitions, + orderKeys, + true, + lowerBound, + upperBound); + case DENSE_RANK: + return withOver( + context.relBuilder.aggregateCall(SqlStdOperatorTable.DENSE_RANK), + partitions, + orderKeys, + true, + lowerBound, + upperBound); case NTH_VALUE: return withOver( context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 14f058a75d0..a08635e46c4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -425,6 +425,11 @@ public enum BuiltinFunctionName { .put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX) .put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX) .put("pattern", BuiltinFunctionName.INTERNAL_PATTERN) + .put("row_number", BuiltinFunctionName.ROW_NUMBER) + .put("rank", BuiltinFunctionName.RANK) + .put("dense_rank", BuiltinFunctionName.DENSE_RANK) + .put("nth", BuiltinFunctionName.NTH_VALUE) + .put("nth_value", BuiltinFunctionName.NTH_VALUE) .build(); public static Optional of(String str) { diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java index f1ee8df35ea..262a9423abe 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java @@ -322,6 +322,82 @@ public void testUnsupportedWindowFunctions() { } } + @Test + public void testEventstatsRowNumberWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | eventstats row_number() as row_num by country |" + + " stats min(row_num) as min_row_num, max(row_num) as max_row_num by" + + " country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_row_num", "bigint"), + schema("max_row_num", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 1, "Canada"), rows(1, 1, "USA")); + } + + @Test + public void testEventstatsRankWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | eventstats rank() as rank_value by country |" + + " stats min(rank_value) as min_rank, max(rank_value) as max_rank by" + + " country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_rank", "bigint"), + schema("max_rank", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 1, "Canada"), rows(1, 1, "USA")); + } + + @Test + public void testEventstatsDenseRankWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | eventstats dense_rank() as dense_rank_value by" + + " country | stats min(dense_rank_value) as min_dense_rank," + + " max(dense_rank_value) as max_dense_rank by country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_dense_rank", "bigint"), + schema("max_dense_rank", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 1, "Canada"), rows(1, 1, "USA")); + } + + @Test + public void testEventstatsNthValueWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | eventstats nth(age, 2) as nth_age by" + + " country | stats min(nth_age) as min_nth_age, max(nth_age) as" + + " max_nth_age by country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_nth_age", "int"), + schema("max_nth_age", "int"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(25, 25, "Canada"), rows(70, 70, "USA")); + } + @Test public void testMultipleEventstats() throws IOException { JSONObject actual = diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteStreamstatsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteStreamstatsCommandIT.java index dcf36f510bf..c92b178afb3 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteStreamstatsCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteStreamstatsCommandIT.java @@ -806,6 +806,82 @@ public void testUnsupportedWindowFunctions() { } } + @Test + public void testStreamstatsRowNumberWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | streamstats row_number() as row_num by country |" + + " stats min(row_num) as min_row_num, max(row_num) as max_row_num by" + + " country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_row_num", "bigint"), + schema("max_row_num", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 2, "Canada"), rows(1, 2, "USA")); + } + + @Test + public void testStreamstatsRankWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | streamstats rank() as rank_value by country |" + + " stats min(rank_value) as min_rank, max(rank_value) as max_rank by" + + " country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_rank", "bigint"), + schema("max_rank", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 1, "Canada"), rows(1, 1, "USA")); + } + + @Test + public void testStreamstatsDenseRankWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | streamstats dense_rank() as dense_rank_value by" + + " country | stats min(dense_rank_value) as min_dense_rank," + + " max(dense_rank_value) as max_dense_rank by country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_dense_rank", "bigint"), + schema("max_dense_rank", "bigint"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(1, 1, "Canada"), rows(1, 1, "USA")); + } + + @Test + public void testStreamstatsNthValueWindowFunction() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | sort country, age | streamstats nth(age, 2) as nth_age by" + + " country | stats min(nth_age) as min_nth_age, max(nth_age) as" + + " max_nth_age by country | sort country", + TEST_INDEX_STATE_COUNTRY)); + + verifySchemaInOrder( + actual, + schema("min_nth_age", "int"), + schema("max_nth_age", "int"), + schema("country", "string")); + + verifyDataRowsInOrder(actual, rows(25, 25, "Canada"), rows(70, 70, "USA")); + } + @Test public void testMultipleStreamstats() throws IOException { JSONObject actual = diff --git a/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5168.yml b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5168.yml new file mode 100644 index 00000000000..16ed08d632f --- /dev/null +++ b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5168.yml @@ -0,0 +1,88 @@ +setup: + - do: + query.settings: + body: + transient: + plugins.calcite.enabled: true + + - do: + indices.create: + index: bounty-types + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + int_field: + type: integer + str_field: + type: keyword + + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "bounty-types", "_id": "1"}}' + - '{"int_field": 42, "str_field": "alpha"}' + - '{"index": {"_index": "bounty-types", "_id": "2"}}' + - '{"int_field": -1, "str_field": "alpha"}' + - '{"index": {"_index": "bounty-types", "_id": "3"}}' + - '{"int_field": 0, "str_field": "beta"}' + +--- +teardown: + - do: + indices.delete: + index: bounty-types + ignore_unavailable: true + - do: + query.settings: + body: + transient: + plugins.calcite.enabled: false + +--- +"Issue 5168: eventstats row_number() should execute for grouped rows": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=bounty-types | sort str_field | eventstats row_number() as rn by str_field | stats max(rn) as max_rn by str_field | sort str_field + + - match: { total: 2 } + - match: { datarows: [ [ 1, alpha ], [ 1, beta ] ] } + +--- +"Issue 5168: streamstats rank() should execute for grouped rows": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=bounty-types | streamstats rank() as rk by int_field | stats max(rk) as max_rk by int_field | sort int_field + + - match: { total: 3 } + - match: { datarows: [ [ 1, -1 ], [ 1, 0 ], [ 1, 42 ] ] } + +--- +"Issue 5168: eventstats dense_rank() should execute for grouped rows": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=bounty-types | sort str_field | eventstats dense_rank() as dr by str_field | stats max(dr) as max_dr by str_field | sort str_field + + - match: { total: 2 } + - match: { datarows: [ [ 1, alpha ], [ 1, beta ] ] }