From bcf9e2ecc4787c4e97115250338f493854090f19 Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Sun, 8 Feb 2026 13:56:33 +0100 Subject: [PATCH 01/10] Add GradientDispatch bridge for custom gradient adapter dispatch --- .../op/DispatchingGradientAdapter.java | 76 +++++++++++++++++++ .../org/tensorflow/op/GradientDispatch.java | 25 ++++++ 2 files changed, 101 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java new file mode 100644 index 00000000000..413dd7afb42 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -0,0 +1,76 @@ +package org.tensorflow.op; + +import java.lang.reflect.Constructor; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.tensorflow.AbstractGradientAdapter; +import org.tensorflow.Graph; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.internal.c_api.TFJ_Scope; + +final class DispatchingGradientAdapter extends AbstractGradientAdapter { + + private final ConcurrentMap raw = new ConcurrentHashMap<>(); + private final ConcurrentMap> typed = new ConcurrentHashMap<>(); + + static final class TypedEntry> { + final CustomGradient grad; + final Class inputClass; + final Constructor ctor; + + TypedEntry(CustomGradient grad, Class inputClass) { + this.grad = grad; + this.inputClass = inputClass; + try { + this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException( + "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", e); + } + } + } + + void putRaw(String opType, RawCustomGradient g) { + raw.put(opType, g); + } + + > void putTyped(String opType, CustomGradient g, Class inputClass) { + typed.put(opType, new TypedEntry<>(g, inputClass)); + } + + @Override + protected List> apply( + Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs) { + + final String opType = operation.type(); + + RawCustomGradient rg = raw.get(opType); + if (rg != null) { + // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op + Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + return rg.call(new Ops(nativeScope), operation, gradInputs); + } + + @SuppressWarnings("rawtypes") + TypedEntry te = typed.get(opType); + if (te != null) { + return applyTyped(graph, scope, operation, gradInputs, te); + } + + throw new IllegalStateException("No Java custom gradient registered for op type: " + opType); + } + + private > List> applyTyped( + Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs, TypedEntry te) { + try { + T inputs = te.ctor.newInstance(operation); + Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + return te.grad.call(new Ops(nativeScope), inputs, gradInputs); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e); 
+    }
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
new file mode 100644
index 00000000000..64504121677
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java
@@ -0,0 +1,25 @@
+package org.tensorflow.op;
+
+import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+
+/** Public bridge to a single native gradient adapter. */
+public final class GradientDispatch {
+
+  // package-private adapter that can access NativeScope/Ops constructors
+  static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();
+
+  private GradientDispatch() {}
+
+  public static TFJ_GradFuncAdapter adapter() {
+    return ADAPTER;
+  }
+
+  public static void putRaw(String opType, RawCustomGradient gradient) {
+    ADAPTER.putRaw(opType, gradient);
+  }
+
+  public static <T extends RawOpInputs<?>> void putTyped(
+      String opType, CustomGradient<T> gradient, Class<T> inputClass) {
+    ADAPTER.putTyped(opType, gradient, inputClass);
+  }
+}

From 7c7dc54b57555bb5cc49526ce0bf1b8732eb74d9 Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Sun, 8 Feb 2026 13:56:40 +0100
Subject: [PATCH 02/10] Fix custom gradient registration scalability and support NoGradient

---
 .../tensorflow/AbstractGradientAdapter.java   | 12 ++++++-
 .../main/java/org/tensorflow/TensorFlow.java  | 10 ++++--
 .../org/tensorflow/op/CustomGradient.java     | 32 ++++++++++++++++++-
 .../op/DispatchingGradientAdapter.java        | 18 ++++++++---
 .../org/tensorflow/op/RawCustomGradient.java  | 31 +++++++++++++++++-
 5 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
index 2119cddaa67..18f95d2197d 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
@@ -92,8 +92,18 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
 
     for (int i = 0; i < outputs.size(); ++i) {
-      var output = outputs.get(i).asOutput();
       var nativeOutput = nativeOutputs.getPointer(i);
+
+      Operand<?> operand = outputs.get(i);
+      if (operand == null) {
+        // "NoGradient" sentinel: null oper + index 0.
+        // Native side must tolerate TF_Output.oper == nullptr.
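+        // Example: a gradient returning Arrays.asList(dLogits, null) marks its second input
+        // as having no gradient (see the CustomGradientsTest added later in this series).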
+ nativeOutput.oper((org.tensorflow.internal.c_api.TF_Operation) null); + nativeOutput.index(0); + continue; + } + + var output = operand.asOutput(); nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle()); nativeOutput.index(output.index()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 7eba6d7ce30..76c0f168eb6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -207,7 +207,10 @@ public static synchronized boolean registerCustomGradient( if (hasGradient(opType)) { return false; } - TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient); + + org.tensorflow.op.GradientDispatch.putRaw(opType, gradient); + TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + if (!TFJ_RegisterCustomGradient(opType, g)) { return false; } @@ -255,7 +258,10 @@ public static synchronized > boolean registerCustomGrad if (hasGradient(opType)) { return false; } - TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass); + + org.tensorflow.op.GradientDispatch.putTyped(opType, gradient, inputClass); + TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + if (!TFJ_RegisterCustomGradient(opType, g)) { return false; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java index 02acce1cb37..4c3b80a6cad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java @@ -17,10 +17,15 @@ package org.tensorflow.op; import java.util.List; +import org.bytedeco.javacpp.PointerPointer; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.TensorFlow; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; +import org.tensorflow.internal.c_api.TFJ_GraphId; +import org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; /** * A custom gradient for ops of type {@link T}. Should be registered using {@link @@ -57,6 +62,31 @@ public interface CustomGradient { */ static > TFJ_GradFuncAdapter adapter( CustomGradient gradient, Class opClass) { - return new TypedGradientAdapter(gradient, opClass); + + final TypedGradientAdapter impl = new TypedGradientAdapter(gradient, opClass); + + // IMPORTANT: + // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function + // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper + // subclass. 
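+    // (Likely because JavaCPP materializes the native thunk from the concrete class it has
+    // generated bindings for; the anonymous subclass below only forwards to the typed adapter.)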
+ return new TFJ_GradFuncAdapter() { + @Override + public int call( + TFJ_GraphId nativeGraphId, + TFJ_Scope nativeScope, + TF_Operation nativeOperation, + TF_Output nativeGradInputs, + int nativeGradInputsLength, + PointerPointer nativeGradOutputsPtr) { + + return impl.call( + nativeGraphId, + nativeScope, + nativeOperation, + nativeGradInputs, + nativeGradInputsLength, + nativeGradOutputsPtr); + } + }; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java index 413dd7afb42..380dd6b555d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -28,7 +28,8 @@ static final class TypedEntry> { this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class); } catch (NoSuchMethodException e) { throw new IllegalArgumentException( - "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", e); + "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", + e); } } } @@ -37,7 +38,8 @@ void putRaw(String opType, RawCustomGradient g) { raw.put(opType, g); } - > void putTyped(String opType, CustomGradient g, Class inputClass) { + > void putTyped( + String opType, CustomGradient g, Class inputClass) { typed.put(opType, new TypedEntry<>(g, inputClass)); } @@ -50,7 +52,8 @@ protected List> apply( RawCustomGradient rg = raw.get(opType); if (rg != null) { // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op - Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + Scope nativeScope = + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); return rg.call(new Ops(nativeScope), operation, gradInputs); } @@ -64,10 +67,15 @@ protected List> apply( } private > List> applyTyped( - Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs, TypedEntry te) { + Graph graph, + TFJ_Scope scope, + GraphOperation operation, + List> gradInputs, + TypedEntry te) { try { T inputs = te.ctor.newInstance(operation); - Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); + Scope nativeScope = + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); return te.grad.call(new Ops(nativeScope), inputs, gradInputs); } catch (ReflectiveOperationException e) { throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java index c2d5496de2a..723d45d58ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java @@ -17,11 +17,16 @@ package org.tensorflow.op; import java.util.List; +import org.bytedeco.javacpp.PointerPointer; import org.tensorflow.GraphOperation; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.TensorFlow; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; +import org.tensorflow.internal.c_api.TFJ_GraphId; +import 
org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; /** * A custom gradient for an op of unspecified type. Should be registered using {@link @@ -54,6 +59,30 @@ public interface RawCustomGradient { * TensorFlow#registerCustomGradient(String, RawCustomGradient)}. */ static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) { - return new RawGradientAdapter(gradient); + final RawGradientAdapter impl = new RawGradientAdapter(gradient); + + // IMPORTANT: + // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function + // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper + // subclass. + return new TFJ_GradFuncAdapter() { + @Override + public int call( + TFJ_GraphId nativeGraphId, + TFJ_Scope nativeScope, + TF_Operation nativeOperation, + TF_Output nativeGradInputs, + int nativeGradInputsLength, + PointerPointer nativeGradOutputsPtr) { + + return impl.call( + nativeGraphId, + nativeScope, + nativeOperation, + nativeGradInputs, + nativeGradInputsLength, + nativeGradOutputsPtr); + } + }; } } From e2fa04b5729159308162bbebd8caa24d314f977b Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Sun, 8 Feb 2026 13:56:51 +0100 Subject: [PATCH 03/10] Handle NoGradient in native custom gradient bridge --- .../internal/c_api/presets/tensorflow.java | 13 ++++ .../internal/c_api/tfj_gradients_impl.cc | 78 ++++++++++++++++--- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index cd3be39fde2..01f8fe59b4b 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -213,6 +213,19 @@ public void map(InfoMap infoMap) { // Skip C++ classes infoMap.put(new Info("tsl::StatusGroup").skip()); + + // Force correct marshalling of TFJ_RegisterCustomGradient callback argument. + // Without an explicit cast, JavaCPP may pass a NULL function pointer for some FunctionPointer + // instances. + infoMap.put( + new Info("TFJ_RegisterCustomGradient") + .javaText( + "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient(" + + "@Cast(\"const char*\") BytePointer op_type, " + + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n" + + "public static native @Cast(\"bool\") boolean TFJ_RegisterCustomGradient(" + + "@Cast(\"const char*\") String op_type, " + + "@Cast(\"TFJ_GradFuncAdapter\") TFJ_GradFuncAdapter custom_gradient_adapter);\n")); } @Override diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 6882cfa704c..0e30623dbc3 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -18,6 +18,12 @@ limitations under the License. 
#include #include +// IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, +// so we must not rely on transitive includes from other headers). +#include +#include +#include + #include "tfj_graph.h" #include "tsl/platform/errors.h" #include "tensorflow/c/c_api.h" @@ -32,7 +38,6 @@ namespace tensorflow { unordered_map g_grad_func_adapters; /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit - /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 template T* struct_cast(U* ptr) { @@ -53,16 +58,34 @@ namespace tensorflow { if (found_adapter == g_grad_func_adapters.end()) { return errors::NotFound("No gradient adapter found for operation ", op_type); } - int num_inputs = grad_inputs.size(); - TF_Output* inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output)); + + const int num_inputs = static_cast(grad_inputs.size()); + + TF_Output* inputs = nullptr; + if (num_inputs > 0) { + inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); + if (inputs == nullptr) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + for (int i = 0; i < num_inputs; ++i) { - Output grad_input = grad_inputs[i]; + const Output& grad_input = grad_inputs[i]; inputs[i].oper = struct_cast(grad_input.node()); inputs[i].index = grad_input.index(); } - TF_Output* outputs = NULL; + + TF_Output* outputs = nullptr; LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - int num_outputs = found_adapter->second( + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == nullptr) { + if (inputs != nullptr) free(inputs); + return errors::Unknown("Null Java gradient adapter for op ", op_type); + } + LOG(INFO) << "Adapter ptr for " << op_type << " = " << reinterpret_cast(found_adapter->second); + const int num_outputs = adapter( static_cast(scope.graph()), struct_cast(const_cast(&scope)), struct_cast(op.node()), @@ -70,12 +93,39 @@ namespace tensorflow { num_inputs, &outputs ); + + // Always free inputs, even on error paths. + if (inputs != nullptr) free(inputs); + + // Adapter contract hardening: + // - On Java exception / failure, adapter should return negative or outputs==nullptr. + if (num_outputs < 0) { + if (outputs != nullptr) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for op ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == nullptr) { + return errors::Unknown("Java custom gradient adapter returned null outputs for op ", + op_type, " with num_outputs=", num_outputs); + } + + grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); + for (int i = 0; i < num_outputs; ++i) { - TF_Output output = outputs[i]; - grad_outputs->push_back(Output(struct_cast(output.oper), output.index)); + const TF_Output out = outputs[i]; + + // "NoGradient" sentinel from Java: TF_Output.oper == nullptr + if (out.oper == nullptr) { + // Represent "no gradient" as an empty Output. + // TF's gradient builder should tolerate missing gradients for non-differentiable inputs. 
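+          // (A default-constructed Output carries a null node, matching the null-Operand
+          // convention used on the Java side in AbstractGradientAdapter.toNativeOutputs.)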
+          grad_outputs->push_back(Output());
+          continue;
+        }
+
+        grad_outputs->push_back(Output(struct_cast<Node>(out.oper), out.index));
       }
-      free(inputs);
-      free(outputs); // outputs are allocated from Java but must be freed here
+
+      if (outputs != nullptr) free(outputs);
       return OkStatus();
     }
   }
@@ -91,6 +141,14 @@ bool TFJ_HasGradient(const char* op_type) {
 }
 
 bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) {
+  LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr="
+            << reinterpret_cast<const void*>(grad_func_adapter);
+
+  if (grad_func_adapter == nullptr) {
+    LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type;
+    return false;
+  }
+
   if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash
     LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type
                  << ", which has already a registered function";

From 1d43d4cd9542d632071b512cd5c690d4cd25f5f1 Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Tue, 10 Feb 2026 04:53:09 +0100
Subject: [PATCH 04/10] Add tests for NoGradient support in Java custom gradients

---
 .../org/tensorflow/CustomGradientsTest.java   | 124 ++++++++++++++++++
 1 file changed, 124 insertions(+)
 create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java

diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
new file mode 100644
index 00000000000..102623ffc5b
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
@@ -0,0 +1,124 @@
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.List;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnOs;
+import org.junit.jupiter.api.condition.OS;
+import org.tensorflow.op.CustomGradient;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.RawCustomGradient;
+import org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt32;
+
+@DisabledOnOs(OS.WINDOWS)
+public class CustomGradientsTest {
+
+  @Test
+  public void noGradientNullIsSupported() {
+    // Register a custom gradient for an op that has NO native gradient in TF core.
+    CustomGradient<SparseSoftmaxCrossEntropyWithLogits.Inputs> grad =
+        (tf, op, gradInputs) -> {
+          @SuppressWarnings("unchecked")
+          Operand<TFloat32> gLoss = (Operand<TFloat32>) gradInputs.get(0); // [B]
+
+          @SuppressWarnings("unchecked")
+          Operand<TFloat32> logits = op.features;
+
+          SparseSoftmaxCrossEntropyWithLogits<TFloat32> xent =
+              SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, op.labels);
+
+          Operand<TFloat32> backprop = xent.backprop(); // [B,C]
+          Operand<TFloat32> gLossE = tf.expandDims(gLoss, tf.constant(1)); // [B,1]
+          Operand<TFloat32> dLogits = tf.math.mul(gLossE, backprop); // [B,C]
+
+          // labels: NoGradient
+          return java.util.Arrays.asList(dLogits, null);
+        };
+
+    assertTrue(
+        TensorFlow.registerCustomGradient(SparseSoftmaxCrossEntropyWithLogits.Inputs.class, grad));
+
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+
+      // Small fixed shapes to be able to create an explicit seed (avoid OnesLike in addGradients).
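+      // addGradients needs one explicit seed per requested y, shaped like that y: loss is
+      // [2] here, so the seed below is a two-element vector of ones.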
+      Operand<TFloat32> logits = tf.constant(new float[][] {{1f, 2f, 3f}, {3f, 2f, 1f}});
+      Operand<TInt32> labels = tf.constant(new int[] {2, 0});
+
+      SparseSoftmaxCrossEntropyWithLogits<TFloat32> xent =
+          SparseSoftmaxCrossEntropyWithLogits.create(tf.scope(), logits, labels);
+
+      Output<TFloat32> loss = xent.loss(); // [2]
+      Operand<TFloat32> seed = tf.constant(new float[] {1f, 1f}); // same shape as loss
+
+      Output<?>[] grads =
+          g.addGradients(
+              "seed",
+              new Output[] {loss},
+              new Output[] {logits.asOutput(), labels.asOutput()},
+              new Output[] {seed.asOutput()});
+
+      // logits grad exists, labels grad must be "NoGradient" (represented as a CLOSED Output)
+      assertNotNull(grads);
+      assertEquals(2, grads.length);
+      assertNotNull(grads[0], "Expected gradient for logits");
+      assertNotNull(grads[1], "Expected an Output placeholder for labels gradient");
+      assertTrue(grads[1].isClosed(), "Expected closed gradient (NoGradient) for labels");
+    }
+  }
+
+  @Test
+  public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() {
+    // Register custom gradient for SigmoidGrad (if already registered, it will return false,
+    // but the test can still pass because the gradient exists in the current process).
+    TensorFlow.registerCustomGradient(
+        "SigmoidGrad",
+        (RawCustomGradient)
+            (tf, op, gradInputs) -> {
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> y = (Operand<TFloat32>) op.input(0); // sigmoid(x)
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> dy = (Operand<TFloat32>) op.input(1); // upstream into SigmoidGrad
+              @SuppressWarnings("unchecked")
+              Operand<TFloat32> upstream = (Operand<TFloat32>) gradInputs.get(0);
+
+              Operand<TFloat32> one = tf.constant(1.0f);
+              Operand<TFloat32> yTimesOneMinusY = tf.math.mul(y, tf.math.sub(one, y));
+
+              // dL/d(dy) = upstream * y*(1-y)
+              Operand<TFloat32> dDy = tf.math.mul(upstream, yTimesOneMinusY);
+
+              // dL/d(y) not needed for this test; return zeros to keep it non-null.
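+              // (Why y * (1 - y) above: SigmoidGrad(y, dy) computes dy * y * (1 - y), so its
+              // partial derivative with respect to dy is exactly y * (1 - y).)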
+              Operand<TFloat32> dY = tf.zerosLike(y);
+
+              return List.of(dY, dDy);
+            });
+
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+
+      Operand<TFloat32> x = tf.placeholder(TFloat32.class);
+      Operand<TFloat32> y = tf.math.sigmoid(x);
+
+      // Provide an explicit seed dy to avoid Graph.addGradients defaulting to OnesLike(y)
+      Operand<TFloat32> seed = tf.fill(tf.shape(y), tf.constant(1.0f));
+
+      Output<?>[] grads =
+          g.addGradients(
+              "seed",
+              new Output[] {y.asOutput()},
+              new Output[] {x.asOutput()},
+              new Output[] {seed.asOutput()});
+
+      assertNotNull(grads);
+      assertEquals(1, grads.length);
+      assertNotNull(grads[0], "Expected a non-null gradient for sigmoid(x) wrt x.");
+      assertTrue(!grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx.");
+    }
+  }
+}

From 69a9a3672760188a25229eb57380d388000acafb Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Tue, 10 Feb 2026 09:58:12 +0100
Subject: [PATCH 05/10] Fix custom gradients: support NoGradient and stabilize adapter

---
 .../tensorflow/AbstractGradientAdapter.java   |  17 +-
 .../org/tensorflow/CustomGradientsTest.java   |   5 +-
 .../internal/c_api/tfj_gradients_impl.cc      | 243 +++++++++---------
 3 files changed, 130 insertions(+), 135 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
index 18f95d2197d..ef581db7ae0 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java
@@ -80,11 +80,11 @@ private static List<Output<?>> fromNativeOutputs(Graph g, TF_Output nativeOutput
   }
 
   /**
-   * Put the Java outputs into the array of native outputs, resizing it to the necessary size.
-   *
-   * @param outputs the outputs to put
-   * @return pointer to the native array of outputs
-   */
+  * Put the Java outputs into the array of native outputs, resizing it to the necessary size.
+  *
+  * @param outputs the outputs to put
+  * @return pointer to the native array of outputs
+  */
   private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
     // Use malloc to allocate native outputs, as they will be freed by the native layer and we do
     // not want JavaCPP to deallocate them
@@ -92,13 +92,12 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
 
     for (int i = 0; i < outputs.size(); ++i) {
+      Operand<?> operand = outputs.get(i);
       var nativeOutput = nativeOutputs.getPointer(i);
 
-      Operand<?> operand = outputs.get(i);
+      // Convention: null Operand => NoGradient
       if (operand == null) {
-        // "NoGradient" sentinel: null oper + index 0.
-        // Native side must tolerate TF_Output.oper == nullptr.
-        nativeOutput.oper((org.tensorflow.internal.c_api.TF_Operation) null);
- nativeOutput.oper((org.tensorflow.internal.c_api.TF_Operation) null); + nativeOutput.oper((TF_Operation) null); nativeOutput.index(0); continue; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java index 102623ffc5b..baaa7bdb742 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; import java.util.List; import org.junit.jupiter.api.Test; @@ -96,7 +97,7 @@ public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() { // dL/d(y) not needed for this test; return zeros to keep it non-null. Operand dY = tf.zerosLike(y); - return List.of(dY, dDy); + return java.util.Arrays.asList(dY, dDy); }); try (Graph g = new Graph()) { @@ -118,7 +119,7 @@ public void sigmoidGradHasCustomGradientWithoutOnesLikeSeed() { assertNotNull(grads); assertEquals(1, grads.length); assertNotNull(grads[0], "Expected a non-null gradient for sigmoid(x) wrt x."); - assertTrue(!grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx."); + assertFalse(grads[0].isClosed(), "Expected an active Output for d(sigmoid)/dx."); } } } diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 0e30623dbc3..32e6043b1c3 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -21,8 +21,8 @@ limitations under the License. // IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, // so we must not rely on transitive includes from other headers). #include -#include #include +#include #include "tfj_graph.h" #include "tsl/platform/errors.h" @@ -31,141 +31,136 @@ limitations under the License. #include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { - namespace java { - using namespace tsl; - using namespace std; - - unordered_map g_grad_func_adapters; - - /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit - /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: - /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 - template T* struct_cast(U* ptr) { - return static_cast(static_cast(ptr)); - } - - /// This function is called by the TensorFlow runtime when it is time to add gradient operations of `op` to the - /// graph using the given `scope`. - /// We use it as a bridge between the C++ signature in TensorFlow (tensorflow::op::GradFunc) and our custom - /// "C" version (TFJ_GradFuncAdapter). 
- Status CustomGradFunc(const Scope& scope, - const Operation& op, - const vector& grad_inputs, - vector* grad_outputs) - { - const string& op_type = op.node()->type_string(); - auto found_adapter = g_grad_func_adapters.find(op_type); - if (found_adapter == g_grad_func_adapters.end()) { - return errors::NotFound("No gradient adapter found for operation ", op_type); - } - - const int num_inputs = static_cast(grad_inputs.size()); - - TF_Output* inputs = nullptr; - if (num_inputs > 0) { - inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); - if (inputs == nullptr) { - return errors::ResourceExhausted( - "Out of memory allocating inputs for custom gradient of op ", op_type); - } - } - - for (int i = 0; i < num_inputs; ++i) { - const Output& grad_input = grad_inputs[i]; - inputs[i].oper = struct_cast(grad_input.node()); - inputs[i].index = grad_input.index(); - } - - TF_Output* outputs = nullptr; - LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - - TFJ_GradFuncAdapter adapter = found_adapter->second; - if (adapter == nullptr) { - if (inputs != nullptr) free(inputs); - return errors::Unknown("Null Java gradient adapter for op ", op_type); - } - LOG(INFO) << "Adapter ptr for " << op_type << " = " << reinterpret_cast(found_adapter->second); - const int num_outputs = adapter( - static_cast(scope.graph()), - struct_cast(const_cast(&scope)), - struct_cast(op.node()), - inputs, - num_inputs, - &outputs - ); - - // Always free inputs, even on error paths. - if (inputs != nullptr) free(inputs); - - // Adapter contract hardening: - // - On Java exception / failure, adapter should return negative or outputs==nullptr. - if (num_outputs < 0) { - if (outputs != nullptr) free(outputs); - return errors::Unknown("Java custom gradient adapter failed for op ", op_type, - " (num_outputs=", num_outputs, ")"); - } - if (num_outputs > 0 && outputs == nullptr) { - return errors::Unknown("Java custom gradient adapter returned null outputs for op ", - op_type, " with num_outputs=", num_outputs); - } - - grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); - - for (int i = 0; i < num_outputs; ++i) { - const TF_Output out = outputs[i]; - - // "NoGradient" sentinel from Java: TF_Output.oper == nullptr - if (out.oper == nullptr) { - // Represent "no gradient" as an empty Output. - // TF's gradient builder should tolerate missing gradients for non-differentiable inputs. 
- grad_outputs->push_back(Output()); - continue; - } - - grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); - } - - if (outputs != nullptr) free(outputs); - return OkStatus(); - } +namespace java { + +using namespace tsl; +using namespace std; + +unordered_map g_grad_func_adapters; + +// Cast helper (inspired by TF C-API) +template +T* struct_cast(U* ptr) { + return static_cast(static_cast(ptr)); +} + +// Bridge called by TF runtime when building gradients for op +Status CustomGradFunc(const Scope& scope, + const Operation& op, + const vector& grad_inputs, + vector* grad_outputs) { + const string& op_type = op.node()->type_string(); + auto found_adapter = g_grad_func_adapters.find(op_type); + if (found_adapter == g_grad_func_adapters.end()) { + return errors::NotFound("No gradient adapter found for operation ", op_type); + } + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == nullptr) { + return errors::Unknown("Null Java gradient adapter for op ", op_type); + } + + const int num_inputs = static_cast(grad_inputs.size()); + + TF_Output* inputs = nullptr; + if (num_inputs > 0) { + inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); + if (inputs == nullptr) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + + for (int i = 0; i < num_inputs; ++i) { + const Output& grad_input = grad_inputs[i]; + inputs[i].oper = struct_cast(grad_input.node()); + inputs[i].index = grad_input.index(); + } + + TF_Output* outputs = nullptr; + + LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; + const int num_outputs = adapter( + static_cast(scope.graph()), + struct_cast(const_cast(&scope)), + struct_cast(op.node()), + inputs, + num_inputs, + &outputs); + + if (inputs != nullptr) free(inputs); + + // Adapter contract: + // - num_outputs < 0 indicates failure + // - num_outputs == 0: OK, outputs may be nullptr + // - num_outputs > 0: outputs must be non-null + if (num_outputs < 0) { + if (outputs != nullptr) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for op ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == nullptr) { + return errors::Unknown("Java custom gradient adapter returned null outputs for op ", + op_type, " with num_outputs=", num_outputs); + } + + grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); + + for (int i = 0; i < num_outputs; ++i) { + const TF_Output out = outputs[i]; + + // Convention: out.oper == nullptr => NoGradient + if (out.oper == nullptr) { + grad_outputs->push_back(Output()); // TF interprets empty Output as "no grad" + continue; } + + grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); + } + + if (outputs != nullptr) free(outputs); // allocated from Java via malloc + return OkStatus(); } +} // namespace java +} // namespace tensorflow + using namespace tensorflow::ops; using namespace tensorflow::java; bool TFJ_HasGradient(const char* op_type) { - GradFunc dummy; - tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); - return status.ok(); + GradFunc dummy; + tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); + return status.ok(); } bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" - << reinterpret_cast(grad_func_adapter); - - if (grad_func_adapter == nullptr) { - LOG(ERROR) << 
"Refusing to register NULL Java gradient adapter for op " << op_type; - return false; - } - - if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash - LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type - << ", which has already a registered function"; - return false; - } - bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); - if (registered) { - g_grad_func_adapters.insert({op_type, grad_func_adapter}); - } - return registered; + LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" + << reinterpret_cast(grad_func_adapter); + + if (grad_func_adapter == nullptr) { + LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type; + return false; + } + + if (TFJ_HasGradient(op_type)) { + LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type + << ", which has already a registered function"; + return false; + } + + bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); + if (registered) { + g_grad_func_adapters.insert({op_type, grad_func_adapter}); + } + return registered; } -#else // #ifndef _WIN32 - -/* This extension is not available on Windows */ +#else // _WIN32 bool TFJ_HasGradient(const char* op_type) { return true; } -bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; } +bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { + return false; +} -#endif // #ifndef _WIN32 +#endif // _WIN32 From 8d80312b8c0b0efd53132fa5ac8f1d2cb5d5f50c Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Tue, 10 Feb 2026 17:55:07 +0100 Subject: [PATCH 06/10] apply mvn spotless --- .../tensorflow/AbstractGradientAdapter.java | 10 +- .../org/tensorflow/CustomGradientsTest.java | 3 +- .../internal/c_api/tfj_gradients_impl.cc | 230 +++++++++--------- 3 files changed, 116 insertions(+), 127 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java index ef581db7ae0..94f82786ed3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java @@ -80,11 +80,11 @@ private static List> fromNativeOutputs(Graph g, TF_Output nativeOutput } /** - * Put the Java outputs into the array of native outputs, resizing it to the necessary size. - * - * @param outputs the outputs to put - * @return pointer to the native array of outputs - */ + * Put the Java outputs into the array of native outputs, resizing it to the necessary size. 
+ * + * @param outputs the outputs to put + * @return pointer to the native array of outputs + */ private static TF_Output toNativeOutputs(List> outputs) { // Use malloc to allocate native outputs, as they will be freed by the native layer and we do // not want JavaCPP to deallocate them diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java index baaa7bdb742..6d7e9a098cd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java @@ -1,11 +1,10 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.assertFalse; -import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledOnOs; import org.junit.jupiter.api.condition.OS; diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 32e6043b1c3..9c5ecd75e07 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include -// IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds, -// so we must not rely on transitive includes from other headers). #include #include #include @@ -31,136 +29,128 @@ limitations under the License. 
#include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { -namespace java { - -using namespace tsl; -using namespace std; - -unordered_map g_grad_func_adapters; - -// Cast helper (inspired by TF C-API) -template -T* struct_cast(U* ptr) { - return static_cast(static_cast(ptr)); -} - -// Bridge called by TF runtime when building gradients for op -Status CustomGradFunc(const Scope& scope, - const Operation& op, - const vector& grad_inputs, - vector* grad_outputs) { - const string& op_type = op.node()->type_string(); - auto found_adapter = g_grad_func_adapters.find(op_type); - if (found_adapter == g_grad_func_adapters.end()) { - return errors::NotFound("No gradient adapter found for operation ", op_type); - } - - TFJ_GradFuncAdapter adapter = found_adapter->second; - if (adapter == nullptr) { - return errors::Unknown("Null Java gradient adapter for op ", op_type); - } - - const int num_inputs = static_cast(grad_inputs.size()); - - TF_Output* inputs = nullptr; - if (num_inputs > 0) { - inputs = static_cast(malloc(num_inputs * sizeof(TF_Output))); - if (inputs == nullptr) { - return errors::ResourceExhausted( - "Out of memory allocating inputs for custom gradient of op ", op_type); - } - } - - for (int i = 0; i < num_inputs; ++i) { - const Output& grad_input = grad_inputs[i]; - inputs[i].oper = struct_cast(grad_input.node()); - inputs[i].index = grad_input.index(); - } - - TF_Output* outputs = nullptr; - - LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; - const int num_outputs = adapter( - static_cast(scope.graph()), - struct_cast(const_cast(&scope)), - struct_cast(op.node()), - inputs, - num_inputs, - &outputs); - - if (inputs != nullptr) free(inputs); - - // Adapter contract: - // - num_outputs < 0 indicates failure - // - num_outputs == 0: OK, outputs may be nullptr - // - num_outputs > 0: outputs must be non-null - if (num_outputs < 0) { - if (outputs != nullptr) free(outputs); - return errors::Unknown("Java custom gradient adapter failed for op ", op_type, - " (num_outputs=", num_outputs, ")"); - } - if (num_outputs > 0 && outputs == nullptr) { - return errors::Unknown("Java custom gradient adapter returned null outputs for op ", - op_type, " with num_outputs=", num_outputs); - } - - grad_outputs->reserve(grad_outputs->size() + static_cast(num_outputs)); - - for (int i = 0; i < num_outputs; ++i) { - const TF_Output out = outputs[i]; - - // Convention: out.oper == nullptr => NoGradient - if (out.oper == nullptr) { - grad_outputs->push_back(Output()); // TF interprets empty Output as "no grad" - continue; + namespace java { + using namespace tsl; + using namespace std; + + unordered_map g_grad_func_adapters; + + /// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit + /// + /// It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing: + /// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658 + template T* struct_cast(U* ptr) { + return static_cast(static_cast(ptr)); + } + + /// This function is called by the TensorFlow runtime when it is time to add gradient operations of `op` to the + /// graph using the given `scope`. + /// We use it as a bridge between the C++ signature in TensorFlow (tensorflow::op::GradFunc) and our custom + /// "C" version (TFJ_GradFuncAdapter). 
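+    /// Contract with the Java adapter (enforced below): a negative return value signals
+    /// failure, and a returned TF_Output whose oper is NULL maps to an empty Output, i.e.
+    /// the NoGradient convention.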
+ Status CustomGradFunc(const Scope& scope, + const Operation& op, + const vector& grad_inputs, + vector* grad_outputs) + { + const string& op_type = op.node()->type_string(); + auto found_adapter = g_grad_func_adapters.find(op_type); + if (found_adapter == g_grad_func_adapters.end()) { + return errors::NotFound("No gradient adapter found for operation ", op_type); + } + + TFJ_GradFuncAdapter adapter = found_adapter->second; + if (adapter == NULL) { + return errors::Unknown("Null Java gradient adapter for operation ", op_type); + } + + int num_inputs = grad_inputs.size(); + TF_Output* inputs = NULL; + if (num_inputs > 0) { + inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output)); + if (inputs == NULL) { + return errors::ResourceExhausted( + "Out of memory allocating inputs for custom gradient of op ", op_type); + } + } + + for (int i = 0; i < num_inputs; ++i) { + Output grad_input = grad_inputs[i]; + inputs[i].oper = struct_cast(grad_input.node()); + inputs[i].index = grad_input.index(); + } + + TF_Output* outputs = NULL; + LOG(INFO) << "Calling Java gradient function for operation of type " << op_type; + int num_outputs = adapter( + static_cast(scope.graph()), + struct_cast(const_cast(&scope)), + struct_cast(op.node()), + inputs, + num_inputs, + &outputs + ); + + if (inputs != NULL) free(inputs); + + if (num_outputs < 0) { + if (outputs != NULL) free(outputs); + return errors::Unknown("Java custom gradient adapter failed for operation ", op_type, + " (num_outputs=", num_outputs, ")"); + } + if (num_outputs > 0 && outputs == NULL) { + return errors::Unknown("Java custom gradient adapter returned null outputs for operation ", + op_type, " with num_outputs=", num_outputs); + } + + for (int i = 0; i < num_outputs; ++i) { + TF_Output output = outputs[i]; + + // Convention: output.oper == NULL => NoGradient + if (output.oper == NULL) { + grad_outputs->push_back(Output()); + } else { + grad_outputs->push_back(Output(struct_cast(output.oper), output.index)); + } + } + + if (outputs != NULL) free(outputs); // outputs are allocated from Java but must be freed here + return OkStatus(); + } } - - grad_outputs->push_back(Output(struct_cast(out.oper), out.index)); - } - - if (outputs != nullptr) free(outputs); // allocated from Java via malloc - return OkStatus(); } -} // namespace java -} // namespace tensorflow - using namespace tensorflow::ops; using namespace tensorflow::java; bool TFJ_HasGradient(const char* op_type) { - GradFunc dummy; - tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); - return status.ok(); + GradFunc dummy; + tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy); + return status.ok(); } bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr=" - << reinterpret_cast(grad_func_adapter); - - if (grad_func_adapter == nullptr) { - LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type; - return false; - } - - if (TFJ_HasGradient(op_type)) { - LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type - << ", which has already a registered function"; - return false; - } - - bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); - if (registered) { - g_grad_func_adapters.insert({op_type, grad_func_adapter}); - } - return registered; + if (grad_func_adapter == NULL) { + LOG(ERROR) << "Refusing to register NULL Java gradient adapter for operation " << 
op_type; + return false; + } + + if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash + LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type + << ", which has already a registered function"; + return false; + } + bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); + if (registered) { + g_grad_func_adapters.insert({op_type, grad_func_adapter}); + } + return registered; } -#else // _WIN32 +#else // #ifndef _WIN32 + +/* This extension is not available on Windows */ bool TFJ_HasGradient(const char* op_type) { return true; } -bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { - return false; -} +bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; } -#endif // _WIN32 +#endif // #ifndef _WIN32 From 1240293e6799ca0caef6f439349f1aef304f7785 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Thu, 12 Feb 2026 17:29:28 +0100 Subject: [PATCH 07/10] Fix review comments: enforce mutual exclusion for raw/typed gradients, remove inline ifs, add license headers and imports - Prevent dual registration of raw and typed gradients for the same op type - Use putIfAbsent and explicit exceptions to avoid silent overwrites - Replace inline if statements in tfj_gradients_impl.cc with brace blocks - Add Apache 2.0 headers to new files - Replace fully-qualified GradientDispatch reference with import --- .../main/java/org/tensorflow/TensorFlow.java | 9 ++-- .../op/DispatchingGradientAdapter.java | 43 ++++++++++++++++++- .../org/tensorflow/op/GradientDispatch.java | 15 +++++++ .../internal/c_api/tfj_gradients_impl.cc | 14 ++++-- 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 76c0f168eb6..a6b9e86dd9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -39,6 +39,7 @@ import org.tensorflow.internal.c_api.TF_Library; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.CustomGradient; +import org.tensorflow.op.GradientDispatch; import org.tensorflow.op.RawCustomGradient; import org.tensorflow.op.RawOpInputs; import org.tensorflow.op.annotation.OpInputsMetadata; @@ -208,8 +209,8 @@ public static synchronized boolean registerCustomGradient( return false; } - org.tensorflow.op.GradientDispatch.putRaw(opType, gradient); - TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + GradientDispatch.putRaw(opType, gradient); + TFJ_GradFuncAdapter g = GradientDispatch.adapter(); if (!TFJ_RegisterCustomGradient(opType, g)) { return false; @@ -259,8 +260,8 @@ public static synchronized > boolean registerCustomGrad return false; } - org.tensorflow.op.GradientDispatch.putTyped(opType, gradient, inputClass); - TFJ_GradFuncAdapter g = org.tensorflow.op.GradientDispatch.adapter(); + GradientDispatch.putTyped(opType, gradient, inputClass); + TFJ_GradFuncAdapter g = GradientDispatch.adapter(); if (!TFJ_RegisterCustomGradient(opType, g)) { return false; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java index 
380dd6b555d..80b934460dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -1,3 +1,18 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op; import java.lang.reflect.Constructor; @@ -16,6 +31,16 @@ final class DispatchingGradientAdapter extends AbstractGradientAdapter { private final ConcurrentMap raw = new ConcurrentHashMap<>(); private final ConcurrentMap> typed = new ConcurrentHashMap<>(); + private static String dupMsg(String opType, String existingKind, String newKind) { + return "A " + + existingKind + + " gradient is already registered for op type '" + + opType + + "'. Raw and typed registrations are mutually exclusive; cannot register " + + newKind + + "."; + } + static final class TypedEntry> { final CustomGradient grad; final Class inputClass; @@ -35,12 +60,26 @@ static final class TypedEntry> { } void putRaw(String opType, RawCustomGradient g) { - raw.put(opType, g); + if (typed.containsKey(opType)) { + throw new IllegalStateException(dupMsg(opType, "typed", "raw")); + } + RawCustomGradient prev = raw.putIfAbsent(opType, g); + if (prev != null) { + throw new IllegalStateException( + "A raw gradient is already registered for op type '" + opType + "'."); + } } > void putTyped( String opType, CustomGradient g, Class inputClass) { - typed.put(opType, new TypedEntry<>(g, inputClass)); + if (raw.containsKey(opType)) { + throw new IllegalStateException(dupMsg(opType, "raw", "typed")); + } + TypedEntry prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass)); + if (prev != null) { + throw new IllegalStateException( + "A typed gradient is already registered for op type '" + opType + "'."); + } } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java index 64504121677..441cff5a2fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientDispatch.java @@ -1,3 +1,18 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+======================================================================= +*/ package org.tensorflow.op; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; diff --git a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc index 9c5ecd75e07..ad68e1e5c05 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc +++ b/tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc @@ -90,13 +90,18 @@ namespace tensorflow { &outputs ); - if (inputs != NULL) free(inputs); + if (inputs != NULL) { + free(inputs); + } if (num_outputs < 0) { - if (outputs != NULL) free(outputs); + if (outputs != NULL) { + free(outputs); + } return errors::Unknown("Java custom gradient adapter failed for operation ", op_type, " (num_outputs=", num_outputs, ")"); } + if (num_outputs > 0 && outputs == NULL) { return errors::Unknown("Java custom gradient adapter returned null outputs for operation ", op_type, " with num_outputs=", num_outputs); @@ -113,7 +118,10 @@ namespace tensorflow { } } - if (outputs != NULL) free(outputs); // outputs are allocated from Java but must be freed here + if (outputs != NULL) { + free(outputs); // outputs are allocated from Java but must be freed here + } + return OkStatus(); } } From fd1cc192ef41ca9bd87ff3a65d3b05618022b984 Mon Sep 17 00:00:00 2001 From: nfeybesse Date: Fri, 13 Feb 2026 04:54:16 +0100 Subject: [PATCH 08/10] Add Javadoc to DispatchingGradientAdapter Document purpose as Java-side gradient dispatcher mirroring TF-Python, clarify raw vs typed gradient registration contract, and note duplicate registration rejection. --- .../op/DispatchingGradientAdapter.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java index 80b934460dc..d82aa5df8a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/DispatchingGradientAdapter.java @@ -26,6 +26,28 @@ import org.tensorflow.Output; import org.tensorflow.internal.c_api.TFJ_Scope; +/** + * Dispatching adapter for Java-side custom gradient registration. + * + *
+ * <p>This class mirrors the behavior of TensorFlow Python's {@code tf.RegisterGradient} mechanism
+ * by providing a centralized dispatch layer for custom gradients in the Java API.
+ *
+ * <p>Gradients may be registered in one of two forms for a given op type:
+ *
+ * <ul>
+ *   <li>A raw gradient ({@link RawCustomGradient}) operating directly on {@link GraphOperation}
+ *       and {@link Output} objects.
+ *   <li>A typed gradient ({@link CustomGradient}) operating on generated {@link RawOpInputs}
+ *       subclasses.
+ * </ul>
+ *
+ * <p>For any given op type, exactly one gradient definition is permitted: either raw or typed.
+ * Duplicate registrations, or attempts to mix raw and typed gradients for the same op type, are
+ * rejected to prevent ambiguous dispatch behavior.
+ *
+ * <p>At runtime, {@link #apply(Graph, TFJ_Scope, GraphOperation, List)} determines the operation
+ * type and dispatches to the corresponding registered gradient implementation.
+ */
 final class DispatchingGradientAdapter extends AbstractGradientAdapter {
 
   private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
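For orientation, the two registration forms documented above look roughly as follows from user code. This is a sketch, not part of the patch: the raw overload matches TensorFlow.registerCustomGradient(String, RawCustomGradient), which the registry capacity test in PATCH 10/10 exercises, while the typed entry point (registerCustomGradient(Class, CustomGradient)) and the generated Abs.Inputs class are assumed here for illustration.

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import org.tensorflow.Operand;
    import org.tensorflow.TensorFlow;
    import org.tensorflow.op.math.Abs;

    final class RegistrationFormsSketch {

      static void register() {
        // Raw form: keyed by the op type string; the callback sees the bare GraphOperation.
        TensorFlow.registerCustomGradient(
            "SomeOpWithoutGradient", // hypothetical op type with no registered gradient
            (tf, op, gradInputs) -> {
              List<Operand<?>> grads = new ArrayList<>(op.numInputs());
              for (int i = 0; i < op.numInputs(); i++) {
                grads.add(tf.zerosLike(op.input(i))); // one (zero) gradient per forward input
              }
              return grads;
            });

        // Typed form (assumed entry point): keyed by a generated RawOpInputs subclass,
        // which exposes the forward op's inputs as fields. Abs is illustration only; its
        // gradient already exists natively, so this exact registration would be rejected.
        TensorFlow.registerCustomGradient(
            Abs.Inputs.class,
            (tf, inputs, gradInputs) -> {
              Operand<?> dx = tf.math.sign(inputs.x); // d|x|/dx; incoming gradient elided
              return Collections.singletonList(dx);
            });
      }
    }

With the duplicate checks above in place, registering both forms for one op type, or the same form twice, fails fast with IllegalStateException instead of silently replacing the earlier entry.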
From cdd943c23b3fdf4765545b3a4444a7178681c5b9 Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Sat, 14 Feb 2026 10:25:03 +0100
Subject: [PATCH 09/10] Add Apache 2.0 license header to CustomGradientsTest

---
 .../java/org/tensorflow/CustomGradientsTest.java | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
index 6d7e9a098cd..c1d528c8a20 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java
@@ -1,3 +1,19 @@
+/*
+ Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================
+*/
 package org.tensorflow;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;

From b0dbea2df4c3cd9de796c9e43928881cb0d74c4b Mon Sep 17 00:00:00 2001
From: nfeybesse
Date: Tue, 24 Feb 2026 17:08:34 +0100
Subject: [PATCH 10/10] Add test registering >10 custom gradients and verifying
 registry presence

This adds a regression test for PR #632. The test dynamically discovers
op types with no registered gradient (using TF_GetAllOpList +
TensorFlow.hasGradient), registers 11 custom gradients, and verifies
that all are present in the native gradient registry.

This directly validates that registering more than 10 gradients works
and that all entries are correctly stored in the native registry,
without relying on Graph.addGradients() execution.

Addresses reviewer comment about missing test for >10 gradients.
---
 .../CustomGradientRegistryCapacityTest.java   | 105 +++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientRegistryCapacityTest.java

diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientRegistryCapacityTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientRegistryCapacityTest.java
new file mode 100644
index 00000000000..67c61195a2f
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientRegistryCapacityTest.java
@@ -0,0 +1,105 @@
+/*
+ Copyright 2026 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================
+*/
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+import org.tensorflow.internal.c_api.TF_Buffer;
+import org.tensorflow.internal.c_api.global.tensorflow;
+import org.tensorflow.proto.OpDef;
+import org.tensorflow.proto.OpList;
+
+/**
+ * Regression test for PR #632: registering >10 custom gradients must work; all registered op
+ * types must be present in the native gradient registry (TensorFlow.hasGradient(opType) becomes
+ * true).
+ */
+public class CustomGradientRegistryCapacityTest {
+
+  private static List<String> listAllOpTypes() {
+    TF_Buffer buf = tensorflow.TF_GetAllOpList();
+    try {
+      OpList opList = OpList.parseFrom(buf.dataAsByteBuffer());
+      List<String> names = new ArrayList<>(opList.getOpCount());
+      for (OpDef op : opList.getOpList()) {
+        names.add(op.getName());
+      }
+      Collections.sort(names);
+      return names;
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to parse TF_GetAllOpList()", e);
+    } finally {
+      // TF_GetAllOpList transfers buffer ownership to the caller; release it here.
+      tensorflow.TF_DeleteBuffer(buf);
+    }
+  }
+
+  @Test
+  public void registerMoreThanTenGradients_thenHasGradientIsTrue() {
+
+    // 1) Discover op types that currently have NO gradient registered.
+    List<String> ops = listAllOpTypes();
+
+    List<String> noGrad = new ArrayList<>();
+    for (String opType : ops) {
+      if (!TensorFlow.hasGradient(opType)) {
+        // Skip internal/private ops to reduce the risk of unexpected behavior (optional).
+        if (!opType.startsWith("_")) {
+          noGrad.add(opType);
+        }
+      }
+    }
+
+    // 2) Pick 11 op types (stable order: alphabetical).
+    //    We intentionally pick from "noGrad" so the test is meaningful:
+    //    before registration hasGradient is false; after registration it is true.
+    assertTrue(noGrad.size() >= 11, "Need at least 11 ops with no gradient in this runtime.");
+
+    List<String> selected = noGrad.subList(0, 11);
+
+    // 3) Before: ensure hasGradient is false.
+    for (String opType : selected) {
+      assertFalse(
+          TensorFlow.hasGradient(opType),
+          "Precondition failed: expected no gradient for " + opType);
+    }
+
+    // 4) Register 11 custom gradients (a simple zerosLike per input).
+    for (String opType : selected) {
+      TensorFlow.registerCustomGradient(
+          opType,
+          (tf, op, gradInputs) -> {
+            int n = op.numInputs();
+            java.util.ArrayList<Operand<?>> grads = new java.util.ArrayList<>(n);
+            for (int i = 0; i < n; i++) {
+              grads.add(tf.zerosLike(op.input(i)));
+            }
+            return grads;
+          });
+    }
+
+    // 5) After: ensure hasGradient is true for all.
+    for (String opType : selected) {
+      assertTrue(
+          TensorFlow.hasGradient(opType), "Expected gradient to be registered for " + opType);
+    }
+  }
+}
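For completeness, one way a Java-registered gradient would actually be exercised, which the capacity test above deliberately avoids, is to build a backward pass with Graph.addGradients(). The following is a sketch only, not part of the patch; tf.math.square stands in for an op whose gradient would come from the Java registry (Square already has a native gradient, so a real end-to-end check would use one of the discovered no-gradient op types):

    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    final class AddGradientsSketch {

      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          Ops tf = Ops.create(g);
          Output<TFloat32> x = tf.placeholder(TFloat32.class).output();
          Output<TFloat32> y = tf.math.square(x).asOutput();

          // Building the backward pass resolves a gradient function for each forward op;
          // for op types registered via TensorFlow.registerCustomGradient, the native
          // registry calls back into DispatchingGradientAdapter.apply(...) to emit the
          // gradient subgraph.
          Output<?>[] dyDx = g.addGradients(new Output<?>[] {y}, new Output<?>[] {x});
          System.out.println("dy/dx node: " + dyDx[0].op().name());
        }
      }
    }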