diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessage.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessage.java index fbe6d3cd723..0717771d1ae 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessage.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessage.java @@ -8,6 +8,7 @@ import lombok.Getter; import org.json.JSONObject; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.tasks.TaskCancelledException; /** Error Message. */ public class ErrorMessage { @@ -37,6 +38,9 @@ private String fetchType() { } protected String fetchReason() { + if (exception instanceof TaskCancelledException) { + return "Query cancelled"; + } return status == RestStatus.BAD_REQUEST.getStatus() ? "Invalid Query" : "There was internal problem at backend"; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactory.java index 8617f264f06..036122fc9f7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactory.java @@ -7,6 +7,7 @@ import lombok.experimental.UtilityClass; import org.opensearch.OpenSearchException; +import org.opensearch.core.tasks.TaskCancelledException; @UtilityClass public class ErrorMessageFactory { @@ -21,6 +22,9 @@ public class ErrorMessageFactory { */ public static ErrorMessage createErrorMessage(Throwable e, int status) { Throwable cause = unwrapCause(e); + if (cause instanceof TaskCancelledException) { + return new ErrorMessage(cause, status); + } if (cause instanceof OpenSearchException) { OpenSearchException exception = (OpenSearchException) cause; return new OpenSearchErrorMessage(exception, exception.status().getStatus()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactoryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactoryTest.java index 6ffe6b275ce..52f6856c831 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/error/ErrorMessageFactoryTest.java @@ -8,9 +8,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.sql.SQLException; import org.junit.jupiter.api.Test; import org.opensearch.OpenSearchException; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.tasks.TaskCancelledException; public class ErrorMessageFactoryTest { private final Throwable nonOpenSearchThrowable = new Throwable(); @@ -48,4 +50,26 @@ public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsError ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); assertTrue(msg instanceof OpenSearchErrorMessage); } + + @Test + public void wrappedTaskCancelledExceptionShouldCreateGenericErrorMessageWithPassedStatus() { + TaskCancelledException cancelled = new TaskCancelledException("The task is cancelled."); + SQLException sqlEx = new SQLException("exception while executing query", cancelled); + RuntimeException wrapped = new RuntimeException(sqlEx); + ErrorMessage msg = + ErrorMessageFactory.createErrorMessage(wrapped, RestStatus.BAD_REQUEST.getStatus()); + assertFalse(msg instanceof OpenSearchErrorMessage); + assertTrue(msg.toString().contains("\"status\": 400")); + assertTrue(msg.toString().contains("\"reason\": \"Query cancelled\"")); + } + + @Test + public void directTaskCancelledExceptionShouldCreateGenericErrorMessageWithPassedStatus() { + TaskCancelledException cancelled = new TaskCancelledException("The task is cancelled."); + ErrorMessage msg = + ErrorMessageFactory.createErrorMessage(cancelled, RestStatus.BAD_REQUEST.getStatus()); + assertFalse(msg instanceof OpenSearchErrorMessage); + assertTrue(msg.toString().contains("\"status\": 400")); + assertTrue(msg.toString().contains("\"reason\": \"Query cancelled\"")); + } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java index ffdd90504f7..49a6154ee20 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java @@ -19,6 +19,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.index.IndexNotFoundException; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; @@ -59,7 +60,17 @@ private static boolean isClientError(Exception e) { || e instanceof QueryEngineException || e instanceof SyntaxCheckException || e instanceof DataSourceClientException - || e instanceof IllegalAccessException; + || e instanceof IllegalAccessException + || hasCauseOf(e, TaskCancelledException.class); + } + + private static boolean hasCauseOf(Throwable e, Class target) { + for (Throwable cause = e; cause != null; cause = cause.getCause()) { + if (target.isInstance(cause)) { + return true; + } + } + return false; } @Override diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/rest/RestPPLQueryActionTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/rest/RestPPLQueryActionTest.java new file mode 100644 index 00000000000..136e7058ceb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/rest/RestPPLQueryActionTest.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.rest; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.lang.reflect.Method; +import java.sql.SQLException; +import org.junit.Test; +import org.opensearch.core.tasks.TaskCancelledException; + +public class RestPPLQueryActionTest { + + @Test + public void testTaskCancelledExceptionIsClientError() throws Exception { + TaskCancelledException cancelled = new TaskCancelledException("The task is cancelled."); + SQLException sqlEx = new SQLException("exception while executing query", cancelled); + RuntimeException wrapped = new RuntimeException(sqlEx); + + assertTrue(invokeIsClientError(wrapped)); + } + + @Test + public void testDirectTaskCancelledExceptionIsClientError() throws Exception { + TaskCancelledException cancelled = new TaskCancelledException("The task is cancelled."); + assertTrue(invokeIsClientError(cancelled)); + } + + @Test + public void testGenericRuntimeExceptionIsNotClientError() throws Exception { + RuntimeException e = new RuntimeException("something went wrong"); + assertFalse(invokeIsClientError(e)); + } + + private static boolean invokeIsClientError(Exception e) throws Exception { + Method method = RestPPLQueryAction.class.getDeclaredMethod("isClientError", Exception.class); + method.setAccessible(true); + return (boolean) method.invoke(null, e); + } +}