diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index ca1744c88..ada73ce56 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -201,7 +201,7 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement()); return areChildrenArgConstant(operand); case COMPREHENSION: - return !isNestedComprehension(navigableExpr); + return !isNestedComprehension(navigableExpr) && containsFoldableFunctionOnly(navigableExpr); default: return false; } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 1a3fac852..e259a7a35 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -47,19 +47,24 @@ @RunWith(TestParameterInjector.class) public class ConstantFoldingOptimizerTest { private static final CelOptions CEL_OPTIONS = - CelOptions.current() - .enableTimestampEpoch(true) - .build(); + CelOptions.current().populateMacroCalls(true).enableTimestampEpoch(true).build(); private static final Cel CEL = CelFactory.standardCelBuilder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addVar("list_var", ListType.create(SimpleType.STRING)) .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), + CelFunctionDecl.newFunctionDeclaration( + "get_list", + CelOverloadDecl.newGlobalOverload( + "get_list_overload", + ListType.create(SimpleType.INT), + ListType.create(SimpleType.INT)))) .addFunctionBindings( CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) .addMessageTypes(TestAllTypes.getDescriptor()) @@ -371,6 +376,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'x == 42'}") @TestParameters("{source: 'timestamp(100)'}") @TestParameters("{source: 'duration(\"1h\")'}") + @TestParameters("{source: '[true].exists(x, x == get_true())'}") + @TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst();