From 4b8bcbc56a9ed12ca67c7d969f898f7c8480becb Mon Sep 17 00:00:00 2001 From: Reza Sajadiany Date: Mon, 16 Mar 2026 16:53:27 -0700 Subject: [PATCH] Op registry to warn instead of error against duplicated ops (#18179) Summary: Resolves the issue where multiple ET runners can attempt to register the same primitive ops. The error can be silent in pre-main static initialization. Differential Revision: D96544230 --- runtime/executor/test/executor_test.cpp | 5 +- runtime/kernel/operator_registry.cpp | 48 ++++++++----------- .../test/kernel_double_registration_test.cpp | 11 +++-- .../operator_registry_max_kernel_num_test.cpp | 12 ++++- .../kernel/test/operator_registry_test.cpp | 41 +++++++++++++--- 5 files changed, 76 insertions(+), 41 deletions(-) diff --git a/runtime/executor/test/executor_test.cpp b/runtime/executor/test/executor_test.cpp index de5597af0f9..33f24702dd3 100644 --- a/runtime/executor/test/executor_test.cpp +++ b/runtime/executor/test/executor_test.cpp @@ -173,8 +173,9 @@ TEST_F(ExecutorTest, OpRegistration) { auto s2 = register_kernel(Kernel("test_2", test_op)); ASSERT_EQ(Error::Ok, s1); ASSERT_EQ(Error::Ok, s2); - ET_EXPECT_DEATH( - []() { (void)register_kernel(Kernel("test", test_op)); }(), ""); + // Duplicate registration should succeed and skip gracefully + auto s3 = register_kernel(Kernel("test", test_op)); + ASSERT_EQ(Error::Ok, s3); ASSERT_TRUE(registry_has_op_function("test")); ASSERT_TRUE(registry_has_op_function("test_2")); diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index 3738f8285af..cf7d662ec4c 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -54,44 +54,39 @@ Error register_kernels_internal(const Span kernels) { // PAL init, so call it here. It is safe to call multiple times. 
::et_pal_init(); - if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) { - ET_LOG( - Error, - "The total number of kernels to be registered is larger than the limit " - "%" PRIu32 ". %" PRIu32 - " kernels are already registered and we're trying to register another " - "%" PRIu32 " kernels.", - kMaxRegisteredKernels, - (uint32_t)num_registered_kernels, - (uint32_t)kernels.size()); - ET_LOG(Error, "======== Kernels already in the registry: ========"); - for (size_t i = 0; i < num_registered_kernels; i++) { - ET_LOG(Error, "%s", registered_kernels[i].name_); - ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_); - } - ET_LOG(Error, "======== Kernels being registered: ========"); - for (size_t i = 0; i < kernels.size(); i++) { - ET_LOG(Error, "%s", kernels[i].name_); - ET_LOG_KERNEL_KEY(kernels[i].kernel_key_); - } - return Error::RegistrationExceedingMaxKernels; - } // for debugging purpose ET_UNUSED const char* lib_name = et_pal_get_shared_library_name(kernels.data()); for (const auto& kernel : kernels) { // Linear search. This is fine if the number of kernels is small. 
+ bool is_duplicate = false; for (size_t i = 0; i < num_registered_kernels; i++) { Kernel k = registered_kernels[i]; if (strcmp(kernel.name_, k.name_) == 0 && kernel.kernel_key_ == k.kernel_key_) { - ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name); + ET_LOG( + Info, + "Skipping duplicate registration of %s, from %s", + k.name_, + lib_name); ET_LOG_KERNEL_KEY(k.kernel_key_); - return Error::RegistrationAlreadyRegistered; + is_duplicate = true; + break; + } + } + if (!is_duplicate) { + if (num_registered_kernels >= kMaxRegisteredKernels) { + ET_LOG( + Error, + "Registry is full: %" PRIu32 + " kernels registered, cannot add '%s'.", + (uint32_t)num_registered_kernels, + kernel.name_); + return Error::RegistrationExceedingMaxKernels; } + registered_kernels[num_registered_kernels++] = kernel; } - registered_kernels[num_registered_kernels++] = kernel; } ET_LOG( Debug, @@ -106,8 +101,7 @@ Error register_kernels_internal(const Span kernels) { // Registers the kernels, but panics if an error occurs. Always returns Ok. 
Error register_kernels(const Span kernels) { Error success = register_kernels_internal(kernels); - if (success == Error::RegistrationAlreadyRegistered || - success == Error::RegistrationExceedingMaxKernels) { + if (success == Error::RegistrationExceedingMaxKernels) { ET_CHECK_MSG( false, "Kernel registration failed with error %" PRIu32 diff --git a/runtime/kernel/test/kernel_double_registration_test.cpp b/runtime/kernel/test/kernel_double_registration_test.cpp index 11026fd48fd..08286d8638b 100644 --- a/runtime/kernel/test/kernel_double_registration_test.cpp +++ b/runtime/kernel/test/kernel_double_registration_test.cpp @@ -35,9 +35,12 @@ TEST_F(KernelDoubleRegistrationTest, Basic) { "aten::add.out", "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3", [](KernelRuntimeContext&, Span) {})}; - Error err = Error::RegistrationAlreadyRegistered; - ET_EXPECT_DEATH( - { (void)register_kernels({kernels}); }, - std::to_string(static_cast(err))); + // First registration should succeed + Error err = register_kernels({kernels}); + EXPECT_EQ(err, Error::Ok); + + // Second registration should succeed but skip the duplicate + err = register_kernels({kernels}); + EXPECT_EQ(err, Error::Ok); } diff --git a/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp b/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp index 627638d098b..89ade68fb0d 100644 --- a/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp +++ b/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp @@ -49,5 +49,15 @@ TEST_F(OperatorRegistryMaxKernelNumTest, RegisterTwoOpsFail) { Kernel("foo2", [](KernelRuntimeContext&, Span) {})}; ET_EXPECT_DEATH( { (void)register_kernels({kernels}); }, - "The total number of kernels to be registered is larger than the limit 1"); + ""); +} + +// Re-registering a duplicate when at capacity should succeed +TEST_F(OperatorRegistryMaxKernelNumTest, DuplicateAtCapacitySucceeds) { + // "foo" was already registered by RegisterOneOp, filling the registry (1/1). 
+ // Re-registering the same kernel should succeed because it's a duplicate. + Kernel kernels[] = { + Kernel("foo", [](KernelRuntimeContext&, Span) {})}; + auto s = register_kernels({kernels}); + EXPECT_EQ(s, Error::Ok); } diff --git a/runtime/kernel/test/operator_registry_test.cpp b/runtime/kernel/test/operator_registry_test.cpp index 5bc411b43ee..e7bf2db0582 100644 --- a/runtime/kernel/test/operator_registry_test.cpp +++ b/runtime/kernel/test/operator_registry_test.cpp @@ -192,12 +192,16 @@ TEST_F(OperatorRegistryTest, Basic) { EXPECT_TRUE(registry_has_op_function("foo")); } -TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) { +TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceSkipsDuplicate) { Kernel kernels[] = { Kernel("foo", [](KernelRuntimeContext&, Span) {}), Kernel("foo", [](KernelRuntimeContext&, Span) {})}; Span kernels_span = Span(kernels); - ET_EXPECT_DEATH((void)register_kernels(kernels_span), "registration failed"); + // Should succeed and skip the duplicate + Error err = register_kernels(kernels_span); + EXPECT_EQ(err, Error::Ok); + // Verify the operator was registered + EXPECT_TRUE(registry_has_op_function("foo")); } TEST_F(OperatorRegistryTest, KernelKeyEquals) { @@ -387,7 +391,7 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { ASSERT_EQ(val_2, 50); } -TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) { +TEST_F(OperatorRegistryTest, DoubleRegisterKernelsSkipsDuplicate) { std::array buf_long_contiguous; Error err = make_kernel_key( {{ScalarType::Long, {0, 1, 2, 3}}}, @@ -406,10 +410,33 @@ TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) { (void)context; *(stack[0]) = Scalar(50); }); - Kernel kernels[] = {kernel_1, kernel_2}; - // clang-tidy off - ET_EXPECT_DEATH((void)register_kernels(kernels), "registration failed"); - // clang-tidy on + + // Register first kernel + err = register_kernels({kernel_1}); + ASSERT_EQ(err, Error::Ok); + + // Attempt to register duplicate - should succeed but skip + err = 
register_kernels({kernel_2}); + ASSERT_EQ(err, Error::Ok); + + // Verify first registration was kept (returns 100, not 50) + Tensor::DimOrderType dims[] = {0, 1, 2, 3}; + auto dim_order_type = Span(dims, 4); + TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; + Span user_kernel_key(meta); + + EXPECT_TRUE(registry_has_op_function("test::baz", user_kernel_key)); + Result op = get_op_function_from_registry("test::baz", user_kernel_key); + ASSERT_EQ(op.error(), Error::Ok); + + EValue values[1]; + values[0] = Scalar(0); + EValue* evalues[1]; + evalues[0] = &values[0]; + KernelRuntimeContext context{}; + + (*op)(context, Span(evalues)); + ASSERT_EQ(values[0].toScalar().to(), 100); } TEST_F(OperatorRegistryTest, ExecutorChecksKernel) {