diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index 488434c56f2..65050790293 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -43,6 +43,7 @@
import java.util.Queue;
import java.util.Set;
import java.util.WeakHashMap;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
@@ -396,8 +397,21 @@ public GraphOperationBuilder opBuilder(String type, String name, Scope scope) {
return new GraphOperationBuilder(this, type, name, scope, dangerousGradientBuilder);
}
+ /**
+ * Attaches a {@link ConcreteFunction} to this graph.
+ *
+ *
If a function with the same defined name has already been attached, this method returns
+ * immediately without re-registering it.
+ *
+ *
The function is also stored in an internal cache to speed up subsequent lookups performed by
+ * {@link #getFunction(String)} and {@link #getFunctionCached(String)}.
+ */
@Override
public void attachFunction(ConcreteFunction function) {
+ String name = function.getDefinedName();
+ if (functionCache.putIfAbsent(name, function) != null) {
+ return;
+ }
try (Reference ref = ref();
PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
@@ -455,6 +469,10 @@ List getNativeFunctions(PointerScope outerScope) {
* name
*/
public ConcreteFunction getFunction(String key) {
+ ConcreteFunction cached = functionCache.get(key);
+ if (cached != null) {
+ return cached;
+ }
try (Reference ref = ref();
PointerScope scope = new PointerScope()) {
List funcs = getNativeFunctions(scope);
@@ -881,6 +899,44 @@ Set initializers() {
private final Set initializers = Collections.synchronizedSet(new LinkedHashSet<>());
private int newInitializersMarker = -1;
+ /**
+ * Cache of {@link ConcreteFunction}s attached to this graph, indexed by their defined name.
+ *
+ * This cache avoids repeatedly scanning the native function library when resolving functions
+ * during gradient construction or control-flow expansion.
+ *
+ *
The cache is populated lazily when {@link #attachFunction(ConcreteFunction)} is called and
+ * consulted first by {@link #getFunction(String)}.
+ *
+ *
A {@link ConcurrentHashMap} is used to allow concurrent reads during graph building without
+ * additional synchronization.
+ */
+ private final ConcurrentHashMap functionCache =
+ new ConcurrentHashMap<>();
+
+ /**
+ * Returns a cached {@link ConcreteFunction} whose name starts with the provided prefix.
+ *
+ * This is a lightweight lookup helper used when the exact function name is not known but
+ * follows a deterministic prefix (for example functions generated for control-flow constructs or
+ * custom gradient expansions).
+ *
+ *
The search is performed only in the local cache and does not query the native TensorFlow
+ * function library.
+ *
+ * @param prefix function name prefix
+ * @return a cached {@link ConcreteFunction} whose name starts with {@code prefix}, or {@code
+ * null} if none is found
+ */
+ public ConcreteFunction getFunctionCached(String prefix) {
+ for (Map.Entry e : functionCache.entrySet()) {
+ if (e.getKey().startsWith(prefix)) {
+ return e.getValue();
+ }
+ }
+ return null;
+ }
+
/**
* Use builders without locking. This should only be used during custom gradient building.
*
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java
new file mode 100644
index 00000000000..25dccc35c10
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java
@@ -0,0 +1,278 @@
+/*
+ Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================
+ */
+package org.tensorflow;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.junit.jupiter.api.Test;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Gradients;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.StatefulIf;
+import org.tensorflow.op.core.StatefulPartitionedCall;
+import org.tensorflow.op.core.StatelessIf;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TType;
+
+public class IfGradientTest {
+
+ private static ConcreteFunction thenFn() {
+ return ConcreteFunction.create(
+ (Ops tf) -> {
+ Placeholder x = tf.placeholder(TFloat32.class);
+ Operand y = tf.math.mul(x, tf.constant(3.0f));
+ return Signature.builder("thenBranch").input("x", x).output("y", y).build();
+ });
+ }
+
+ private static ConcreteFunction elseFn() {
+ return ConcreteFunction.create(
+ (Ops tf) -> {
+ Placeholder x = tf.placeholder(TFloat32.class);
+ Operand y = tf.math.mul(x, tf.constant(5.0f));
+ return Signature.builder("elseBranch").input("x", x).output("y", y).build();
+ });
+ }
+
+ private static void assertClose(float got, float expected, float eps, String msg) {
+ if (Math.abs(got - expected) > eps) {
+ throw new AssertionError(msg + " (got=" + got + ", expected=" + expected + ")");
+ }
+ }
+
+ private static void primeIfGradFunctions(Graph g) {
+
+ Iterator operations = g.operations();
+ while (operations.hasNext()) {
+ GraphOperation op = operations.next();
+ String type = op.type();
+ if (!StatefulIf.OP_NAME.equals(type) && !StatelessIf.OP_NAME.equals(type)) continue;
+
+ ConcreteFunction thenFwd = op.attributes().getAttrFunction("then_branch");
+ ConcreteFunction elseFwd = op.attributes().getAttrFunction("else_branch");
+
+ int nInputs = op.inputListLength("input");
+ int nOut = op.numOutputs();
+
+ List> tin = new ArrayList<>(nInputs);
+ for (int i = 0; i < nInputs; i++) {
+ Class extends TType> c = op.input(1 + i).asOutput().type();
+ tin.add(c);
+ }
+
+ List> tout = new ArrayList<>(nOut);
+ for (int i = 0; i < nOut; i++) {
+ Class extends TType> c = op.output(i).type();
+ tout.add(c);
+ }
+
+ ConcreteFunction thenGrad = buildBranchGradFn(op.name() + "/then_grad", thenFwd, tin, tout);
+ ConcreteFunction elseGrad = buildBranchGradFn(op.name() + "/else_grad", elseFwd, tin, tout);
+
+ g.attachFunction(thenGrad);
+ g.attachFunction(elseGrad);
+ }
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ private static ConcreteFunction buildBranchGradFn(
+ String prefix,
+ ConcreteFunction branchFn,
+ List> tin,
+ List> toutForward) {
+
+ return ConcreteFunction.create(
+ (Ops tf) -> {
+ Signature.Builder sig = Signature.builder(prefix);
+
+ List> x = new ArrayList<>(tin.size());
+ for (int i = 0; i < tin.size(); i++) {
+ Placeholder extends TType> ph = tf.placeholder((Class) tin.get(i));
+ x.add(ph);
+ sig.input("x" + i, ph);
+ }
+
+ List> dy = new ArrayList<>(toutForward.size());
+ for (int i = 0; i < toutForward.size(); i++) {
+ Placeholder extends TType> ph = tf.placeholder((Class) toutForward.get(i));
+ dy.add(ph);
+ sig.input("dy" + i, ph);
+ }
+
+ StatefulPartitionedCall yCall =
+ StatefulPartitionedCall.create(tf.scope(), x, toutForward, branchFn);
+
+ Operand> L = tf.constant(0.0f);
+ for (int i = 0; i < toutForward.size(); i++) {
+ Operand> prod = tf.math.mul((Operand) yCall.output().get(i), (Operand) dy.get(i));
+ L = tf.math.add((Operand) L, (Operand) sumAll(tf, prod));
+ }
+
+ Gradients g = tf.gradients((Iterable) List.of((Operand) L), x);
+
+ for (int i = 0; i < tin.size(); i++) {
+ Operand> dx = g.dy(i);
+ sig.output("dx" + i, dx);
+ }
+
+ return sig.build();
+ });
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ private static Operand> sumAll(Ops tf, Operand> v) {
+ Operand r = tf.rank(v);
+ Operand axes = tf.range(tf.constant(0), r, tf.constant(1));
+ return tf.reduceSum((Operand) v, axes);
+ }
+
+ @Test
+ public void testStatefullIfGradient() {
+ TensorFlow.registerCustomGradient(
+ StatefulIf.OP_NAME,
+ (tf, op, gradOutputs) -> {
+ OperationAttributeInspector attrs = op.attributes();
+ ConcreteFunction thenBranch = attrs.getAttrFunction("then_branch");
+ ConcreteFunction elseBranch = attrs.getAttrFunction("else_branch");
+
+ if (thenBranch == null || elseBranch == null) {
+ int n = 1 + op.inputListLength("input");
+ List> no = new ArrayList<>(n);
+ for (int i = 0; i < n; i++) {
+ no.add(null);
+ }
+ return no;
+ }
+
+ Operand extends TType> cond = op.input(0);
+ int nInputs = op.inputListLength("input");
+ List> inputs = new ArrayList<>(nInputs);
+ for (int i = 0; i < nInputs; i++) {
+ inputs.add(op.input(1 + i));
+ }
+
+ int nOut = op.numOutputs();
+ List> toutForward = new ArrayList<>(nOut);
+ for (int i = 0; i < nOut; i++) {
+ toutForward.add(op.output(i).type());
+ }
+
+ List> tin =
+ inputs.stream().map(input -> input.asOutput().type()).collect(Collectors.toList());
+ List> dys = new ArrayList<>(nOut);
+ for (int i = 0; i < nOut; i++) {
+ Operand> dy = null;
+ if (gradOutputs != null && i < gradOutputs.size()) {
+ dy = gradOutputs.get(i);
+ }
+ if (dy == null) {
+ dy =
+ gradOutputs == null || gradOutputs.isEmpty()
+ ? tf.onesLike((Operand) op.output(i))
+ : tf.zerosLike((Operand) op.output(i));
+ }
+ dys.add(dy);
+ }
+
+ List> input = new ArrayList<>(nInputs + nOut);
+ input.addAll(inputs);
+ input.addAll(dys);
+
+ final String thenPrefix = op.name() + "/then_grad"; // op has unique name
+ final String elsePrefix = op.name() + "/else_grad";
+
+ ConcreteFunction thenGrad = op.env().getFunctionCached(thenPrefix);
+ ConcreteFunction elseGrad = op.env().getFunctionCached(elsePrefix);
+
+ if (thenGrad == null || elseGrad == null) {
+ throw new IllegalStateException("If grad functions not primed for op=" + op.name());
+ }
+ StatefulIf dInputsIf =
+ StatefulIf.create(tf.scope(), cond, input, tin, thenGrad, elseGrad);
+ List> result = new ArrayList<>(1 + nInputs);
+ result.add(null); // no gradient for condition
+ result.addAll(dInputsIf.output());
+ return result;
+ });
+
+ Graph g = new Graph();
+ Ops tf = Ops.create(g);
+
+ var x = tf.placeholder(TFloat32.class); // scalar
+ var cond = tf.placeholder(TBool.class); // scalar
+
+ try (ConcreteFunction thenBranch = thenFn();
+ ConcreteFunction elseBranch = elseFn()) {
+
+ StatefulIf ifOp =
+ StatefulIf.create(
+ tf.scope(),
+ cond,
+ List.of((Operand) x),
+ List.of(TFloat32.class),
+ thenBranch,
+ elseBranch);
+
+ var y = ifOp.output().get(0);
+
+ primeIfGradFunctions(g);
+
+ var dy_dx = g.addGradients(y, new Output[] {x.asOutput()})[0];
+
+ try (Session session = new Session(g)) {
+
+ try (Result r =
+ session
+ .runner()
+ .feed(x, TFloat32.scalarOf(2.0f))
+ .feed(cond, TBool.scalarOf(true))
+ .fetch(y)
+ .fetch(dy_dx)
+ .run()) {
+
+ float yVal = ((TFloat32) r.get(0)).getFloat();
+ float gVal = ((TFloat32) r.get(1)).getFloat();
+
+ assertClose(yVal, 6.0f, 1e-6f, "y mismatch for cond=true");
+ assertClose(gVal, 3.0f, 1e-6f, "grad mismatch for cond=true");
+ }
+
+ // ---- cond=false
+ try (Result r =
+ session
+ .runner()
+ .feed(x, TFloat32.scalarOf(2.0f))
+ .feed(cond, TBool.scalarOf(false))
+ .fetch(y)
+ .fetch(dy_dx)
+ .run()) {
+
+ float yVal = ((TFloat32) r.get(0)).getFloat();
+ float gVal = ((TFloat32) r.get(1)).getFloat();
+ assertClose(yVal, 10.0f, 1e-6f, "y mismatch for cond=false");
+ assertClose(gVal, 5.0f, 1e-6f, "grad mismatch for cond=false");
+ }
+ }
+ ;
+ }
+ }
+ ;
+}