From b3f4e2b85388cfb1e9390e72053e0ff172333337 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Fri, 27 Feb 2026 13:30:27 +0200 Subject: [PATCH] Add approximate parameter to GELU activation function Add support for the 'approximate' parameter in GELU, matching PyTorch's torch.nn.GELU(approximate='tanh') functionality. Changes: - Add GELU.Approximate enum with 'none' and 'tanh' values - Thread approximate parameter through all layers: native C++, PInvoke, Tensor methods, functional API, and module factory - Add new overloads (no breaking changes to existing API) - Add test for tanh approximation mode Fixes dotnet/TorchSharp#1368 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Native/LibTorchSharp/THSTensor.cpp | 8 ++-- src/Native/LibTorchSharp/THSTensor.h | 4 +- src/TorchSharp/NN/Activation/GELU.cs | 43 ++++++++++++++++++- .../PInvoke/LibTorchSharp.THSTensor.cs | 4 +- src/TorchSharp/Tensor/Tensor.cs | 20 ++++++++- test/TorchSharpTest/NN.cs | 22 ++++++++++ 6 files changed, 89 insertions(+), 12 deletions(-) diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp index 7b4a0e55e..33599be06 100644 --- a/src/Native/LibTorchSharp/THSTensor.cpp +++ b/src/Native/LibTorchSharp/THSTensor.cpp @@ -576,14 +576,14 @@ Tensor THSTensor_gather( CATCH_TENSOR(torch::gather(*tensor, dim, *index)); } -Tensor THSTensor_gelu(const Tensor tensor) +Tensor THSTensor_gelu(const Tensor tensor, const char* approximate) { - CATCH_TENSOR(torch::gelu(*tensor)); + CATCH_TENSOR(torch::gelu(*tensor, approximate)); } -Tensor THSTensor_gelu_(const Tensor tensor) +Tensor THSTensor_gelu_(const Tensor tensor, const char* approximate) { - CATCH_TENSOR(torch::gelu_(*tensor)); + CATCH_TENSOR(torch::gelu_(*tensor, approximate)); } Tensor THSTensor_get1(const Tensor tensor, int64_t index) diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h index 73bff0403..40edc85ba 100644 --- a/src/Native/LibTorchSharp/THSTensor.h +++ b/src/Native/LibTorchSharp/THSTensor.h @@ -600,8 +600,8 @@ EXPORT_API(Tensor) THSTensor_ge_scalar(const Tensor left, const Scalar right); EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right); -EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor); -EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor, const char* approximate); +EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor, const char* approximate); EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim); diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs index 90c314b99..0b57c17e0 100644 --- a/src/TorchSharp/NN/Activation/GELU.cs +++ b/src/TorchSharp/NN/Activation/GELU.cs @@ -14,17 +14,35 @@ namespace Modules /// public sealed class GELU : ParameterLessModule { - internal GELU(bool inplace) : base(nameof(GELU)) + /// + /// Specifies the approximation method for GELU. + /// + public enum Approximate + { + /// + /// Exact GELU computation. + /// + none, + /// + /// Tanh-based approximation. + /// + tanh + } + + internal GELU(bool inplace, Approximate approximate = Approximate.none) : base(nameof(GELU)) { this.inplace = inplace; + this.approximate = approximate; } public override Tensor forward(Tensor tensor) { - return torch.nn.functional.gelu(tensor, inplace); + return torch.nn.functional.gelu(tensor, approximate, inplace); } public bool inplace {get; set; } + + public Approximate approximate { get; set; } } } @@ -49,6 +67,16 @@ public static GELU GELU(bool inplace) return new GELU(inplace); } + /// + /// Gaussian Error Linear Units + /// + /// The approximation method to use. Default: none + /// Do the operation in-place. Default: False + public static GELU GELU(GELU.Approximate approximate, bool inplace = false) + { + return new GELU(inplace, approximate); + } + public static partial class functional { /// @@ -61,6 +89,17 @@ public static Tensor gelu(Tensor x, bool inplace) return inplace ? x.gelu_().alias() : x.gelu(); } + /// + /// Gaussian Error Linear Units + /// + /// The input tensor + /// The approximation method to use. + /// Do the operation in-place. Default: False + public static Tensor gelu(Tensor x, GELU.Approximate approximate, bool inplace = false) + { + return inplace ? x.gelu_(approximate).alias() : x.gelu(approximate); + } + /// /// Gaussian Error Linear Units /// diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs index e8db2c2cb..108e1b740 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs @@ -707,10 +707,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_gelu(IntPtr tensor); + internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_gelu_(IntPtr tensor); + internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim); diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index ea70b83e1..7093a5a90 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -2977,7 +2977,15 @@ public Tensor elu_(Scalar alpha, Scalar scale, Scalar input_scale) public Tensor gelu() { - var res = NativeMethods.THSTensor_gelu(Handle); + var res = NativeMethods.THSTensor_gelu(Handle, "none"); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate) + { + var res = NativeMethods.THSTensor_gelu(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); @@ -2985,7 +2993,15 @@ public Tensor gelu() public Tensor gelu_() { - var res = NativeMethods.THSTensor_gelu_(Handle); + var res = NativeMethods.THSTensor_gelu_(Handle, "none"); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate) + { + var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none"); if (res == IntPtr.Zero) CheckForErrors(); return new Tensor(res); diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index f2ed50db3..b9b5a93bb 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -618,6 +618,28 @@ public void EvaluateGELU() } } + [Fact] + public void EvaluateGELUWithTanhApproximate() + { + var rel = GELU(Modules.GELU.Approximate.tanh); + + foreach (var device in TestUtils.AvailableDevices()) { + var input = torch.randn(new long[] { 64, 8 }, device: device) * 25.0; + var output = rel.call(input); + Assert.Equal(device.type, output.device_type); + + var values = output.data().ToArray(); + Assert.Equal(input.shape, output.shape); + Assert.All(values, val => Assert.True(val >= -0.2)); + } + + // Verify that tanh approximate produces different results from exact + var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f }); + var exact = torch.nn.functional.gelu(x); + var approx = torch.nn.functional.gelu(x, Modules.GELU.Approximate.tanh); + Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5)); + } + [Fact] public void EvaluatePReLU() {