diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 000000000..6d2890793 --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +8.5.0 diff --git a/optimizer/optimizers/BUILD.bazel b/optimizer/optimizers/BUILD.bazel index 21241241f..26d98c574 100644 --- a/optimizer/optimizers/BUILD.bazel +++ b/optimizer/optimizers/BUILD.bazel @@ -14,3 +14,8 @@ java_library( name = "common_subexpression_elimination", exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:common_subexpression_elimination"], ) + +java_library( + name = "inlining", + exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:inlining"], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index d68842315..629918a3d 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -71,6 +71,40 @@ java_library( ], ) +java_library( + name = "inlining", + srcs = [ + "InliningOptimizer.java", + ], + tags = [ + ], + deps = [ + ":default_optimizer_constants", + "//:auto_value", + "//bundle:cel", + "//common:cel_ast", + "//common:cel_source", + "//common:compiler_common", + "//common:mutable_ast", + "//common:mutable_source", + "//common/ast", + "//common/ast:mutable_expr", + "//common/navigation", + "//common/navigation:common", + "//common/navigation:mutable_navigation", + "//common:operator", + "//common/values:values", + "//common/types", + "//common/types:type_providers", + "//optimizer:ast_optimizer", + "//optimizer:mutable_ast", + "//optimizer:optimization_exception", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:org_jspecify_jspecify", + ], +) + java_library( name = "default_optimizer_constants", srcs = [ diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/InliningOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/InliningOptimizer.java new file mode 100644 index 000000000..ac4018816 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/InliningOptimizer.java @@ -0,0 +1,189 @@ +package dev.cel.optimizer.optimizers; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelMutableAst; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension; +import dev.cel.common.ast.CelMutableExprConverter; +import dev.cel.common.navigation.CelNavigableMutableAst; +import dev.cel.common.navigation.CelNavigableMutableExpr; +import dev.cel.common.Operator; +import dev.cel.common.ast.CelConstant; +import dev.cel.optimizer.AstMutator; +import dev.cel.optimizer.CelAstOptimizer; +import dev.cel.optimizer.CelOptimizationException; +import java.util.Optional; +import dev.cel.common.values.NullValue; +import java.util.stream.Stream; + +/** + * Performs optimization for inlining variables within function calls and select + * statements with + * their associated AST. + */ +public final class InliningOptimizer implements CelAstOptimizer { + + private final ImmutableList inlineVariables; + private final AstMutator astMutator; + + public static InliningOptimizer newInstance(InlineVariable... inlineVariables) { + return newInstance(InliningOptions.newBuilder().build(), ImmutableList.copyOf(inlineVariables)); + } + + public static InliningOptimizer newInstance(InliningOptions options, InlineVariable... inlineVariables) { + return newInstance(options, ImmutableList.copyOf(inlineVariables)); + } + + public static InliningOptimizer newInstance(InliningOptions options, Iterable inlineVariables) { + return new InliningOptimizer(options, ImmutableList.copyOf(inlineVariables)); + } + + @Override + public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) + throws CelOptimizationException { + CelMutableAst mutableAst = CelMutableAst.fromCelAst(ast); + for (InlineVariable inlineVariable : inlineVariables) { + ImmutableList inlinableExprs = CelNavigableMutableAst.fromAst(mutableAst) + .getRoot() + .allNodes() + .filter(node -> canInline(node, inlineVariable.name())) + .collect(toImmutableList()); + + for (CelNavigableMutableExpr inlinableExpr : inlinableExprs) { + CelExpr replacementExpr = inlineVariable.ast().getExpr(); + + if (inlinableExpr.getKind().equals(Kind.SELECT) && inlinableExpr.expr().select().testOnly()) { + if (replacementExpr.getKind().equals(Kind.SELECT)) { + // Preserve testOnly property for Select replacements (has(A) -> has(B)) + replacementExpr = replacementExpr.toBuilder().setSelect( + replacementExpr.select().toBuilder().setTestOnly(true).build()).build(); + } else { + // Rewrite has(X) -> X != null for non-select replacements + replacementExpr = CelExpr.newBuilder() + .setId(0) + .setCall( + CelExpr.CelCall.newBuilder() + .setFunction(Operator.NOT_EQUALS.getFunction()) + .addArgs(replacementExpr) + .addArgs(CelExpr.newBuilder().setId(0).setConstant(CelConstant.ofValue(NullValue.NULL_VALUE)) + .build()) + .build()) + .build(); + } + } + + mutableAst = astMutator.replaceSubtree(mutableAst, + CelMutableExprConverter.fromCelExpr(replacementExpr), inlinableExpr.id()); + } + } + + return OptimizationResult.create(mutableAst.toParsedAst()); + } + + private static boolean canInline(CelNavigableMutableExpr node, String identifier) { + boolean matches = maybeToQualifiedName(node) + .map(name -> name.equals(identifier)) + .orElse(false); + + if (!matches) { + return false; + } + + for (CelNavigableMutableExpr p = node.parent().orElse(null); p != null; p = p.parent().orElse(null)) { + if (p.getKind() != Kind.COMPREHENSION) { + continue; + } + + CelMutableComprehension comp = p.expr().comprehension(); + boolean shadows = Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar()) + .anyMatch(identifier::equals); + + if (shadows) { + return false; + } + } + + return true; + } + + private static Optional maybeToQualifiedName(CelNavigableMutableExpr node) { + if (node.getKind().equals(Kind.IDENT)) { + return Optional.of(node.expr().ident().name()); + } + + if (node.getKind().equals(Kind.SELECT)) { + return node.children().findFirst() + .flatMap(InliningOptimizer::maybeToQualifiedName) + .map(operandName -> operandName + "." + node.expr().select().field()); + } + + return Optional.empty(); + } + + /** + * Represents a variable to be inlined. + */ + @AutoValue + public abstract static class InlineVariable { + public abstract String name(); + + public abstract CelAbstractSyntaxTree ast(); + + /** + * Creates a new {@link InlineVariable} with the given name and AST. + * + *

+ * The name must be a qualified name (e.g. "a.b.c") and cannot be an internal + * variable (starting with @). + */ + public static InlineVariable of(String name, CelAbstractSyntaxTree ast) { + if (name.startsWith("@")) { + throw new IllegalArgumentException("Internal variables cannot be inlined: " + name); + } + return new AutoValue_InliningOptimizer_InlineVariable(name, ast); + } + } + + /** Options to configure how Inlining behaves. */ + @AutoValue + public abstract static class InliningOptions { + public abstract int maxIterationLimit(); + + /** Builder for configuring the {@link InliningOptimizer.InliningOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** + * Limit the number of iteration while inlining variables. An exception is + * thrown if + * the iteration count exceeds the set value. + */ + public abstract InliningOptions.Builder maxIterationLimit(int value); + + public abstract InliningOptimizer.InliningOptions build(); + + Builder() { + } + } + + /** Returns a new options builder with recommended defaults pre-configured. */ + public static InliningOptimizer.InliningOptions.Builder newBuilder() { + return new AutoValue_InliningOptimizer_InliningOptions.Builder() + .maxIterationLimit(400); + } + + InliningOptions() { + } + } + + private InliningOptimizer(InliningOptions options, ImmutableList inlineVariables) { + this.inlineVariables = inlineVariables; + this.astMutator = AstMutator.newInstance(options.maxIterationLimit()); + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 46a063a8d..ce9a5dc77 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -91,7 +91,7 @@ * } * */ -public class SubexpressionOptimizer implements CelAstOptimizer { +public final class SubexpressionOptimizer implements CelAstOptimizer { private static final SubexpressionOptimizer INSTANCE = new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index acc8a4c3f..d91e48f54 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -27,6 +27,7 @@ java_library( "//optimizer:optimizer_builder", "//optimizer/optimizers:common_subexpression_elimination", "//optimizer/optimizers:constant_folding", + "//optimizer/optimizers:inlining", "//parser:macro", "//parser:unparser", "//runtime", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java new file mode 100644 index 000000000..7f117d084 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java @@ -0,0 +1,163 @@ +package dev.cel.optimizer.optimizers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelOptions; +import dev.cel.common.CelSource; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.SimpleType; +import dev.cel.extensions.CelExtensions; +import dev.cel.common.CelContainer; +import dev.cel.common.types.StructTypeReference; +import dev.cel.expr.conformance.proto3.TestAllTypes; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.optimizer.CelOptimizer; +import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.optimizer.optimizers.InliningOptimizer.InlineVariable; +import dev.cel.optimizer.optimizers.InliningOptimizer.InliningOptions; +import dev.cel.parser.CelStandardMacro; +import dev.cel.parser.CelUnparserFactory; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class InliningOptimizerTest { + + private static final Cel CEL = CelFactory.standardCelBuilder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setContainer(CelContainer.ofName("google.expr.proto3.test")) + .addFileTypes(TestAllTypes.getDescriptor().getFile()) + .addCompilerLibraries(CelExtensions.bindings()) + .addVar("int_var_to_inline", SimpleType.INT) + .addVar("a", SimpleType.DYN) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .addVar("unpacked_wrapper", SimpleType.STRING) + .addVar("wrapper_var", StructTypeReference.create("google.protobuf.Int64Value")) + .addVar("child", StructTypeReference.create(TestAllTypes.NestedMessage.getDescriptor().getFullName())) + .addVar("shadowed_ident", SimpleType.INT) + .addVar("x", SimpleType.DYN) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) + .build(); + + @Test + public void inlining_success(@TestParameter SuccessTestCase testCase) throws Exception { + CelAbstractSyntaxTree astToInline = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree replacementAst = CEL.compile(testCase.replacement).getAst(); + + CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(InliningOptimizer.newInstance( + InlineVariable.of(testCase.varName, replacementAst))) + .build(); + + CelAbstractSyntaxTree optimized = optimizer.optimize(astToInline); + + String unparsed = CelUnparserFactory.newUnparser().unparse(optimized); + assertThat(unparsed).isEqualTo(testCase.expected); + } + + @Test + public void inlining_noop(@TestParameter NoOpTestCase testCase) throws Exception { + CelAbstractSyntaxTree astToInline = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree replacementAst = CEL.compile(testCase.replacement).getAst(); + + CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(InliningOptimizer.newInstance( + InlineVariable.of(testCase.varName, replacementAst))) + .build(); + + CelAbstractSyntaxTree optimized = optimizer.optimize(astToInline); + + String unparsed = CelUnparserFactory.newUnparser().unparse(optimized); + assertThat(unparsed).isEqualTo(testCase.source); + } + + private enum SuccessTestCase { + CONSTANT( + "int_var_to_inline + 2 + int_var_to_inline", + "int_var_to_inline", + "1", + "1 + 2 + 1"), + REPEATED( + "a + [a]", + "a", + "dyn([1, 2])", + "dyn([1, 2]) + [dyn([1, 2])]"), + SELECT( + "has(msg.single_any)", + "msg.single_any", + "msg.single_int64_wrapper", + "has(msg.single_int64_wrapper)"), + PRESENCE( + "has(msg.single_int64_wrapper)", + "msg.single_int64_wrapper", + "wrapper_var", + "wrapper_var != null"), + NESTED( + "msg.standalone_message.bb", + "msg.standalone_message", + "child", + "child.bb"), + ; + + private final String source; + private final String varName; + private final String replacement; + private final String expected; + + SuccessTestCase(String source, String varName, String replacement, String expected) { + this.source = source; + this.varName = varName; + this.replacement = replacement; + this.expected = expected; + } + } + + private enum NoOpTestCase { + NO_INLINE_ITER_VAR( + "[0].exists(shadowed_ident, shadowed_ident == 0)", + "shadowed_ident", + "1"), + NO_INLINE_BIND_VAR( + "cel.bind(shadowed_ident, 2, shadowed_ident + 1)", + "shadowed_ident", + "1"), + ; + + private final String source; + private final String varName; + private final String replacement; + + NoOpTestCase(String source, String varName, String replacement) { + this.source = source; + this.varName = varName; + this.replacement = replacement; + } + } + + @Test + public void inline_exceededIterationLimit_throws() throws Exception { + String expression = "int_var_to_inline + int_var_to_inline + int_var_to_inline"; + CelAbstractSyntaxTree astToInline = CEL.compile(expression).getAst(); + CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(InliningOptimizer.newInstance( + InliningOptions.newBuilder().maxIterationLimit(2).build(), + InlineVariable.of("int_var_to_inline", CEL.compile("1").getAst()))) + .build(); + + CelOptimizationException e = assertThrows(CelOptimizationException.class, () -> optimizer.optimize(astToInline)); + assertThat(e).hasMessageThat().contains("Max iteration count reached."); + } + + @Test + public void inlineVariableDecl_internalVar_throws() throws Exception { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> InlineVariable.of("@internal_var", + CelAbstractSyntaxTree.newParsedAst(CelExpr.ofNotSet(0L), CelSource.newBuilder().build()))); + assertThat(e).hasMessageThat().contains("Internal variables cannot be inlined: @internal_var"); + } +}