Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions runtime/executor/test/executor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ TEST_F(ExecutorTest, OpRegistration) {
auto s2 = register_kernel(Kernel("test_2", test_op));
ASSERT_EQ(Error::Ok, s1);
ASSERT_EQ(Error::Ok, s2);
ET_EXPECT_DEATH(
[]() { (void)register_kernel(Kernel("test", test_op)); }(), "");
// Duplicate registration should succeed and skip gracefully
auto s3 = register_kernel(Kernel("test", test_op));
ASSERT_EQ(Error::Ok, s3);

ASSERT_TRUE(registry_has_op_function("test"));
ASSERT_TRUE(registry_has_op_function("test_2"));
Expand Down
48 changes: 21 additions & 27 deletions runtime/kernel/operator_registry.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -54,44 +54,39 @@
// PAL init, so call it here. It is safe to call multiple times.
::et_pal_init();

if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) {
ET_LOG(
Error,
"The total number of kernels to be registered is larger than the limit "
"%" PRIu32 ". %" PRIu32
" kernels are already registered and we're trying to register another "
"%" PRIu32 " kernels.",
kMaxRegisteredKernels,
(uint32_t)num_registered_kernels,
(uint32_t)kernels.size());
ET_LOG(Error, "======== Kernels already in the registry: ========");
for (size_t i = 0; i < num_registered_kernels; i++) {
ET_LOG(Error, "%s", registered_kernels[i].name_);
ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_);
}
ET_LOG(Error, "======== Kernels being registered: ========");
for (size_t i = 0; i < kernels.size(); i++) {
ET_LOG(Error, "%s", kernels[i].name_);
ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
}
return Error::RegistrationExceedingMaxKernels;
}
// for debugging purpose
ET_UNUSED const char* lib_name =
et_pal_get_shared_library_name(kernels.data());

for (const auto& kernel : kernels) {
// Linear search. This is fine if the number of kernels is small.
bool is_duplicate = false;
for (size_t i = 0; i < num_registered_kernels; i++) {
Kernel k = registered_kernels[i];
if (strcmp(kernel.name_, k.name_) == 0 &&
kernel.kernel_key_ == k.kernel_key_) {
ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
ET_LOG(
Info,
"Skipping duplicate registration of %s, from %s",
k.name_,
lib_name);
ET_LOG_KERNEL_KEY(k.kernel_key_);
return Error::RegistrationAlreadyRegistered;
is_duplicate = true;
break;
}
}
if (!is_duplicate) {
if (num_registered_kernels >= kMaxRegisteredKernels) {
ET_LOG(
Error,
"Registry is full: %" PRIu32
" kernels registered, cannot add '%s'.",
(uint32_t)num_registered_kernels,
kernel.name_);
return Error::RegistrationExceedingMaxKernels;
}
registered_kernels[num_registered_kernels++] = kernel;
}
registered_kernels[num_registered_kernels++] = kernel;
}
ET_LOG(
Debug,
Expand All @@ -106,8 +101,7 @@
// Registers the kernels, but panics if an error occurs. Always returns Ok.
Error register_kernels(const Span<const Kernel> kernels) {
Error success = register_kernels_internal(kernels);
if (success == Error::RegistrationAlreadyRegistered ||
success == Error::RegistrationExceedingMaxKernels) {
if (success == Error::RegistrationExceedingMaxKernels) {
ET_CHECK_MSG(
false,
"Kernel registration failed with error %" PRIu32
Expand Down
11 changes: 7 additions & 4 deletions runtime/kernel/test/kernel_double_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ TEST_F(KernelDoubleRegistrationTest, Basic) {
"aten::add.out",
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
[](KernelRuntimeContext&, Span<EValue*>) {})};
Error err = Error::RegistrationAlreadyRegistered;

ET_EXPECT_DEATH(
{ (void)register_kernels({kernels}); },
std::to_string(static_cast<uint32_t>(err)));
// First registration should succeed
Error err = register_kernels({kernels});
EXPECT_EQ(err, Error::Ok);

// Second registration should succeed but skip the duplicate
err = register_kernels({kernels});
EXPECT_EQ(err, Error::Ok);
}
12 changes: 11 additions & 1 deletion runtime/kernel/test/operator_registry_max_kernel_num_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -49,5 +49,15 @@
Kernel("foo2", [](KernelRuntimeContext&, Span<EValue*>) {})};
ET_EXPECT_DEATH(
{ (void)register_kernels({kernels}); },
"The total number of kernels to be registered is larger than the limit 1");
"");
}

// Re-registering a duplicate when at capacity should succeed
TEST_F(OperatorRegistryMaxKernelNumTest, DuplicateAtCapacitySucceeds) {
// "foo" was already registered by RegisterOneOp, filling the registry (1/1).
// Re-registering the same kernel should succeed because it's a duplicate.
Kernel kernels[] = {
Kernel("foo", [](KernelRuntimeContext&, Span<EValue*>) {})};
auto s = register_kernels({kernels});
EXPECT_EQ(s, Error::Ok);
}
41 changes: 34 additions & 7 deletions runtime/kernel/test/operator_registry_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -192,12 +192,16 @@
EXPECT_TRUE(registry_has_op_function("foo"));
}

TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) {
TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceSkipsDuplicate) {
Kernel kernels[] = {
Kernel("foo", [](KernelRuntimeContext&, Span<EValue*>) {}),
Kernel("foo", [](KernelRuntimeContext&, Span<EValue*>) {})};
Span<const Kernel> kernels_span = Span<const Kernel>(kernels);
ET_EXPECT_DEATH((void)register_kernels(kernels_span), "registration failed");
// Should succeed and skip the duplicate
Error err = register_kernels(kernels_span);
EXPECT_EQ(err, Error::Ok);
// Verify the operator was registered
EXPECT_TRUE(registry_has_op_function("foo"));
}

TEST_F(OperatorRegistryTest, KernelKeyEquals) {
Expand Down Expand Up @@ -387,7 +391,7 @@
ASSERT_EQ(val_2, 50);
}

TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsSkipsDuplicate) {
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
Error err = make_kernel_key(
{{ScalarType::Long, {0, 1, 2, 3}}},
Expand All @@ -406,10 +410,33 @@
(void)context;
*(stack[0]) = Scalar(50);
});
Kernel kernels[] = {kernel_1, kernel_2};
// clang-tidy off
ET_EXPECT_DEATH((void)register_kernels(kernels), "registration failed");
// clang-tidy on

// Register first kernel
err = register_kernels({kernel_1});
ASSERT_EQ(err, Error::Ok);

// Attempt to register duplicate - should succeed but skip
err = register_kernels({kernel_2});
ASSERT_EQ(err, Error::Ok);

// Verify first registration was kept (returns 100, not 50)
Tensor::DimOrderType dims[] = {0, 1, 2, 3};
auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
Span<const TensorMeta> user_kernel_key(meta);

EXPECT_TRUE(registry_has_op_function("test::baz", user_kernel_key));
Result<OpFunction> op = get_op_function_from_registry("test::baz", user_kernel_key);
ASSERT_EQ(op.error(), Error::Ok);

EValue values[1];
values[0] = Scalar(0);
EValue* evalues[1];
evalues[0] = &values[0];
KernelRuntimeContext context{};

(*op)(context, Span<EValue*>(evalues));
ASSERT_EQ(values[0].toScalar().to<int64_t>(), 100);
}

TEST_F(OperatorRegistryTest, ExecutorChecksKernel) {
Expand Down
Loading