From 509270564dadbc8ff8446bdcac308e429832461c Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Tue, 17 Mar 2026 16:02:55 -0700 Subject: [PATCH] =?UTF-8?q?All=20four=20Android=20JNI=20modules=20(Generic?= =?UTF-8?q?,=20LLM,=20ASR,=20Training)=20previously=20used=20inconsistent?= =?UTF-8?q?=20exception=20types=20=E2=80=94=20a=20mix=20of=20RuntimeExcept?= =?UTF-8?q?ion,=20IllegalStateException,=20IllegalArgumentException,=20and?= =?UTF-8?q?=20the=20structured=20ExecutorchRuntimeException.=20This=20unif?= =?UTF-8?q?ies=20them=20so=20every=20JNI=20error=20path=20throws=20Executo?= =?UTF-8?q?rchRuntimeException=20with=20a=20proper=20error=20code,=20givin?= =?UTF-8?q?g=20callers=20a=20single=20type=20to=20catch=20and=20programmat?= =?UTF-8?q?ic=20access=20to=20the=20underlying=20runtime=20error.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Key changes per module: LLM (jni_layer_llama.cpp) — generate() now captures the Error return from runner_->generate(), wraps the call in try/catch, reports failures via a new onError(errorCode, message) callback, and returns the actual error code instead of always returning 0. ASR (jni_layer_asr.cpp) — replaced six env->ThrowNew(...) calls with setExecutorchPendingException (for pure JNI path) Training (jni_layer_training.cpp) — added jni_helper.h include; replaced five throwNewJavaException("java/lang/Exception", ...) calls with throwExecutorchException, preserving the actual error codes from the failed operations. Additionally: added default onError callbacks to LlmCallback (Java) and AsrCallback (Kotlin); This PR was authored with the assistance of Claude. --- .../executorch/extension/asr/AsrCallback.kt | 8 ++ .../executorch/extension/llm/LlmCallback.java | 10 ++ extension/android/jni/jni_helper.cpp | 109 ++++++++++++++++++ extension/android/jni/jni_helper.h | 24 ++++ extension/android/jni/jni_layer_asr.cpp | 49 ++++---- extension/android/jni/jni_layer_llama.cpp | 80 ++++++++++--- extension/android/jni/jni_layer_training.cpp | 52 ++++----- 7 files changed, 268 insertions(+), 64 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt index 51a220167c0..0a2e02d63d6 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt @@ -25,4 +25,12 @@ interface AsrCallback { * @param token The decoded text token */ fun onToken(token: String) + + /** + * Called when an error occurs during transcription. + * + * @param errorCode Error code from the ExecuTorch runtime + * @param message Human-readable error description + */ + fun onError(errorCode: Int, message: String) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index 42b44c7d4c5..4e834d06721 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -37,4 +37,14 @@ public interface LlmCallback { */ @DoNotStrip default void onStats(String stats) {} + + /** + * Called when an error occurs during generate(). + * + * @param errorCode Error code from the ExecuTorch runtime (see {@link + * org.pytorch.executorch.ExecutorchRuntimeException}) + * @param message Human-readable error description + */ + @DoNotStrip + default void onError(int errorCode, String message) {} } diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp index 04c3cbeb58f..f7c9c4abb74 100644 --- a/extension/android/jni/jni_helper.cpp +++ b/extension/android/jni/jni_helper.cpp @@ -35,6 +35,115 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) { facebook::jni::throwNewJavaException(exception.get()); } +void setExecutorchPendingException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details) { + if (!env) { + return; + } + if (env->ExceptionCheck()) { + // Preserve any preexisting pending exception; do not overwrite it here. + return; + } + + jclass exceptionClass = + env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException"); + if (env->ExceptionCheck()) { + if (exceptionClass) { + env->DeleteLocalRef(exceptionClass); + } + // Preserve the original exception; do not clear or overwrite it. + return; + } + + if (!exceptionClass) { + // FindClass failed. It should have set a pending exception; leave it as is. + return; + } + + jmethodID factoryMethod = env->GetStaticMethodID( + exceptionClass, + "makeExecutorchException", + "(ILjava/lang/String;)Ljava/lang/RuntimeException;"); + if (env->ExceptionCheck()) { + // Preserve the original exception; do not overwrite it. + env->DeleteLocalRef(exceptionClass); + return; + } + + if (!factoryMethod) { + // Factory unavailable; try the (int, String) constructor directly so + // callers still receive a structured ExecutorchRuntimeException. + jmethodID ctor = + env->GetMethodID(exceptionClass, "", "(ILjava/lang/String;)V"); + if (!env->ExceptionCheck() && ctor) { + jstring jDetails = env->NewStringUTF(details.c_str()); + if (!env->ExceptionCheck() && jDetails) { + jobject exObj = env->NewObject( + exceptionClass, ctor, static_cast(errorCode), jDetails); + if (!env->ExceptionCheck() && exObj) { + env->Throw(reinterpret_cast(exObj)); + env->DeleteLocalRef(exObj); + } + env->DeleteLocalRef(jDetails); + } + } else { + // Clear any NoSuchMethodError from GetMethodID before falling back. + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + jclass runtimeExClass = env->FindClass("java/lang/RuntimeException"); + if (!env->ExceptionCheck() && runtimeExClass) { + env->ThrowNew(runtimeExClass, details.c_str()); + env->DeleteLocalRef(runtimeExClass); + } + } + env->DeleteLocalRef(exceptionClass); + return; + } + + jstring jDetails = env->NewStringUTF(details.c_str()); + if (env->ExceptionCheck()) { + // Preserve the original exception; do not overwrite it. + if (jDetails) { + env->DeleteLocalRef(jDetails); + } + env->DeleteLocalRef(exceptionClass); + return; + } + + if (!jDetails) { + // NewStringUTF returned null without setting an exception; fall back to + // a standard RuntimeException since ExecutorchRuntimeException lacks a + // (String) ctor. + jclass runtimeExClass = env->FindClass("java/lang/RuntimeException"); + if (!env->ExceptionCheck() && runtimeExClass) { + env->ThrowNew(runtimeExClass, details.c_str()); + env->DeleteLocalRef(runtimeExClass); + } + env->DeleteLocalRef(exceptionClass); + return; + } + + auto exception = static_cast(env->CallStaticObjectMethod( + exceptionClass, factoryMethod, static_cast(errorCode), jDetails)); + if (env->ExceptionCheck() || !exception) { + // If a Java exception was thrown, it is already pending; just clean up. + if (exception) { + env->DeleteLocalRef(exception); + } + env->DeleteLocalRef(jDetails); + env->DeleteLocalRef(exceptionClass); + return; + } + + env->Throw(exception); + env->DeleteLocalRef(exception); + env->DeleteLocalRef(jDetails); + env->DeleteLocalRef(exceptionClass); +} + bool utf8_check_validity(const char* str, size_t length) { for (size_t i = 0; i < length; ++i) { uint8_t byte = static_cast(str[i]); diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h index 45b28c3b9ff..ec31085be30 100644 --- a/extension/android/jni/jni_helper.h +++ b/extension/android/jni/jni_helper.h @@ -20,11 +20,35 @@ namespace executorch::jni_helper { * code and details. Uses the Java factory method * ExecutorchRuntimeException.makeExecutorchException(int, String). * + * IMPORTANT: This attempts to throw a C++ exception (via fbjni). Only use in + * fbjni HybridClass methods where fbjni catches it at the JNI boundary. + * For plain extern "C" JNIEXPORT functions, use setExecutorchPendingException. + * + * Note: If there is no current JNI environment (for example, if + * facebook::jni::Environment::current() returns null), this function is a + * no-op and does not throw. Callers must not rely on this always aborting + * control flow. + * * @param errorCode The error code from the C++ Executorch runtime. * @param details Additional details to include in the exception message. */ void throwExecutorchException(uint32_t errorCode, const std::string& details); +/** + * Sets a pending Java ExecutorchRuntimeException without throwing a C++ + * exception. Safe to call from plain extern "C" JNIEXPORT functions. + * After calling this, the caller must return from the JNI function promptly; + * the Java exception will be delivered when control returns to the JVM. + * + * @param env The JNI environment pointer. + * @param errorCode The error code from the C++ Executorch runtime. + * @param details Additional details to include in the exception message. + */ +void setExecutorchPendingException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details); + // Define the JavaClass wrapper struct JExecutorchRuntimeException : public facebook::jni::JavaClass { diff --git a/extension/android/jni/jni_layer_asr.cpp b/extension/android/jni/jni_layer_asr.cpp index dc053f69925..1daa23d6113 100644 --- a/extension/android/jni/jni_layer_asr.cpp +++ b/extension/android/jni/jni_layer_asr.cpp @@ -128,8 +128,9 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeCreate( auto load_error = handle->preprocessor->load(); if (load_error != Error::Ok) { ET_LOG(Error, "Failed to load preprocessor module"); - env->ThrowNew( - env->FindClass("java/lang/RuntimeException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(load_error), "Failed to load preprocessor module"); return 0; } @@ -138,9 +139,10 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeCreate( return reinterpret_cast(handle.release()); } catch (const std::exception& e) { ET_LOG(Error, "Failed to create AsrModule: %s", e.what()); - env->ThrowNew( - env->FindClass("java/lang/RuntimeException"), - ("Failed to create AsrModule: " + std::string(e.what())).c_str()); + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::Internal), + "Failed to create AsrModule: " + std::string(e.what())); return 0; } } @@ -172,8 +174,9 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeLoad( jobject /* this */, jlong nativeHandle) { if (nativeHandle == 0) { - env->ThrowNew( - env->FindClass("java/lang/IllegalStateException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::InvalidState), "Module has been destroyed"); return -1; } @@ -218,15 +221,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe( jlong decoderStartTokenId, jobject callback) { if (nativeHandle == 0) { - env->ThrowNew( - env->FindClass("java/lang/IllegalStateException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::InvalidState), "Module has been destroyed"); return -1; } if (wavPath == nullptr) { - env->ThrowNew( - env->FindClass("java/lang/IllegalArgumentException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::InvalidArgument), "WAV path cannot be null"); return -1; } @@ -239,15 +244,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe( try { audioData = ::executorch::extension::llm::load_wav_audio_data(wavPathStr); } catch (const std::exception& e) { - env->ThrowNew( - env->FindClass("java/lang/RuntimeException"), - ("Failed to load WAV file: " + std::string(e.what())).c_str()); + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::AccessFailed), + "Failed to load WAV file: " + std::string(e.what())); return -1; } if (audioData.empty()) { - env->ThrowNew( - env->FindClass("java/lang/IllegalArgumentException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::InvalidArgument), "WAV file contains no audio data"); return -1; } @@ -267,16 +274,18 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe( auto processedResult = handle->preprocessor->execute("forward", audioTensor); if (processedResult.error() != Error::Ok) { - env->ThrowNew( - env->FindClass("java/lang/RuntimeException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(processedResult.error()), "Audio preprocessing failed"); return -1; } auto outputs = std::move(processedResult.get()); if (outputs.empty() || !outputs[0].isTensor()) { - env->ThrowNew( - env->FindClass("java/lang/RuntimeException"), + executorch::jni_helper::setExecutorchPendingException( + env, + static_cast(Error::Internal), "Preprocessor returned unexpected output"); return -1; } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index ac0eb46c0eb..b1474288a2f 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -71,6 +71,14 @@ class ExecuTorchLlmCallbackJni facebook::jni::make_jstring( executorch::extension::llm::stats_to_json_string(result))); } + + void onError(int errorCode, const std::string& message) const { + static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); + static const auto on_error_method = + cls->getMethod)>( + "onError"); + on_error_method(self(), errorCode, facebook::jni::make_jstring(message)); + } }; class ExecuTorchLlmJni : public facebook::jni::HybridClass { @@ -198,6 +206,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jfloat temperature, jint num_bos, jint num_eos) { + Error err = Error::Ok; + if (!prompt) { + err = Error::InvalidArgument; + if (callback) { + callback->onError( + static_cast(err), + "generate() failed: prompt must not be null"); + } + return static_cast(err); + } + float effective_temperature = temperature >= 0 ? temperature : temperature_; std::string token_buffer; auto token_callback = [callback, &token_buffer](const std::string& token) { @@ -209,26 +228,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { } std::string result = token_buffer; token_buffer.clear(); - callback->onResult(result); + if (callback) { + callback->onResult(result); + } }; if (!runner_) { - return static_cast(Error::InvalidState); - } - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = effective_temperature, - .num_bos = needs_bos_ ? num_bos_ : 0, - .num_eos = num_eos, - }; - auto err = runner_->generate( - prompt->toStdString(), - config, - token_callback, - [callback](const llm::Stats& result) { callback->onStats(result); }); - if (err == Error::Ok) { - needs_bos_ = false; + err = Error::InvalidState; + if (callback) { + callback->onError( + static_cast(err), "generate() failed: runner not initialized"); + } + return static_cast(err); + } + + try { + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = effective_temperature, + .num_bos = needs_bos_ ? num_bos_ : 0, + .num_eos = num_eos, + }; + err = runner_->generate( + prompt->toStdString(), + config, + token_callback, + [callback](const llm::Stats& result) { + if (callback) { + callback->onStats(result); + } + }); + if (err == Error::Ok) { + needs_bos_ = false; + } + if (err != Error::Ok && callback) { + callback->onError( + static_cast(err), + "generate() failed with error code " + + std::to_string(static_cast(err))); + } + } catch (const std::exception& e) { + if (callback) { + callback->onError( + static_cast(Error::Internal), + std::string("generate() threw: ") + e.what()); + } + return static_cast(Error::Internal); } return static_cast(err); } diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 90658006ee5..4b002de5dc7 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -75,10 +76,10 @@ class ExecuTorchTrainingJni auto modelPathString = modelPath->toStdString(); auto modelLoaderRes = FileDataLoader::from(modelPathString.c_str()); if (modelLoaderRes.error() != Error::Ok) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Failed to open model file: %s", - modelPathString.c_str()); + executorch::jni_helper::throwExecutorchException( + static_cast(modelLoaderRes.error()), + "Failed to open model file: " + modelPathString); + return; } auto modelLoader = std::make_unique(std::move(modelLoaderRes.get())); @@ -88,10 +89,10 @@ class ExecuTorchTrainingJni if (!dataPathString.empty()) { auto dataLoaderRes = FileDataLoader::from(dataPathString.c_str()); if (dataLoaderRes.error() != Error::Ok) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Failed to open ptd file: %s", - dataPathString.c_str()); + executorch::jni_helper::throwExecutorchException( + static_cast(dataLoaderRes.error()), + "Failed to open ptd file: " + dataPathString); + return; } dataLoader = std::make_unique(std::move(dataLoaderRes.get())); @@ -148,11 +149,11 @@ class ExecuTorchTrainingJni auto result = module_->execute_forward_backward(methodName->toStdString(), evalues); if (!result.ok()) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Execution of forward_backward for method %s failed with status 0x%" PRIx32, - methodName->toStdString().c_str(), - static_cast(result.error())); + executorch::jni_helper::throwExecutorchException( + static_cast(result.error()), + "Execution of forward_backward for method " + + methodName->toStdString() + " failed"); + return {}; } facebook::jni::local_ref> jresult = @@ -171,11 +172,10 @@ class ExecuTorchTrainingJni auto method = methodName->toStdString(); auto result = module_->named_parameters(method); if (!result.ok()) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Getting named parameters for method %s failed with status 0x%" PRIx32, - method.c_str(), - static_cast(result.error())); + executorch::jni_helper::throwExecutorchException( + static_cast(result.error()), + "Getting named parameters for method " + method + " failed"); + return {}; } facebook::jni::local_ref< facebook::jni::JHashMap> @@ -195,11 +195,10 @@ class ExecuTorchTrainingJni auto method = methodName->toStdString(); auto result = module_->named_gradients(method); if (!result.ok()) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Getting named gradients for method %s failed with status 0x%" PRIx32, - method.c_str(), - static_cast(result.error())); + executorch::jni_helper::throwExecutorchException( + static_cast(result.error()), + "Getting named gradients for method " + method + " failed"); + return {}; } facebook::jni::local_ref< facebook::jni::JHashMap> @@ -322,10 +321,9 @@ class SGDHybrid : public facebook::jni::HybridClass { auto result = sgdOptimizer_->step(cppNamedGradients); if (result != ::executorch::runtime::Error::Ok) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "SGD optimization step failed with status 0x%" PRIx32, - static_cast(result)); + executorch::jni_helper::throwExecutorchException( + static_cast(result), "SGD optimization step failed"); + return; } }