Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Native/LibTorchSharp/THSTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,14 @@ Tensor THSTensor_gather(
CATCH_TENSOR(torch::gather(*tensor, dim, *index));
}

// Applies GELU to `tensor`. `approximate` selects the algorithm:
// "none" for the exact (erf-based) form, "tanh" for the tanh approximation
// (the strings the managed layer passes — see Tensor.cs in this change).
Tensor THSTensor_gelu(const Tensor tensor, const char* approximate)
{
    CATCH_TENSOR(torch::gelu(*tensor, approximate));
}

// In-place variant of THSTensor_gelu; mutates `tensor` and returns it.
// `approximate` is "none" (exact, erf-based) or "tanh" (tanh approximation).
Tensor THSTensor_gelu_(const Tensor tensor, const char* approximate)
{
    CATCH_TENSOR(torch::gelu_(*tensor, approximate));
}

Tensor THSTensor_get1(const Tensor tensor, int64_t index)
Expand Down
4 changes: 2 additions & 2 deletions src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ EXPORT_API(Tensor) THSTensor_ge_scalar(const Tensor left, const Scalar right);

EXPORT_API(void) THSTensor_ge_scalar_(const Tensor left, const Scalar right);

// GELU (out-of-place) and GELU_ (in-place). `approximate` is "none" or "tanh".
EXPORT_API(Tensor) THSTensor_gelu(const Tensor tensor, const char* approximate);
EXPORT_API(Tensor) THSTensor_gelu_(const Tensor tensor, const char* approximate);

EXPORT_API(Tensor) THSTensor_glu(const Tensor tensor, const int64_t dim);

Expand Down
43 changes: 41 additions & 2 deletions src/TorchSharp/NN/Activation/GELU.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,35 @@ namespace Modules
/// </summary>
public sealed class GELU : ParameterLessModule<Tensor, Tensor>
{
    /// <summary>
    /// Specifies the approximation method for GELU.
    /// </summary>
    public enum Approximate
    {
        /// <summary>
        /// Exact GELU computation.
        /// </summary>
        none,
        /// <summary>
        /// Tanh-based approximation.
        /// </summary>
        tanh
    }

    /// <summary>
    /// Constructs a GELU activation module.
    /// </summary>
    /// <param name="inplace">Do the operation in-place.</param>
    /// <param name="approximate">The approximation method to use. Default: none (exact).</param>
    internal GELU(bool inplace, Approximate approximate = Approximate.none) : base(nameof(GELU))
    {
        this.inplace = inplace;
        this.approximate = approximate;
    }

    /// <summary>
    /// Applies GELU to the input, forwarding the configured approximation and in-place flags.
    /// </summary>
    public override Tensor forward(Tensor tensor)
    {
        return torch.nn.functional.gelu(tensor, approximate, inplace);
    }

    /// <summary>
    /// When true, the activation is performed in-place on the input tensor.
    /// </summary>
    public bool inplace { get; set; }

    /// <summary>
    /// The approximation algorithm used by <see cref="forward"/>.
    /// </summary>
    public Approximate approximate { get; set; }
}
}

Expand All @@ -49,6 +67,16 @@ public static GELU GELU(bool inplace)
return new GELU(inplace);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="approximate">The approximation method to use. Default: none</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static GELU GELU(GELU.Approximate approximate, bool inplace = false)
    => new GELU(inplace, approximate);

public static partial class functional
{
/// <summary>
Expand All @@ -61,6 +89,17 @@ public static Tensor gelu(Tensor x, bool inplace)
return inplace ? x.gelu_().alias() : x.gelu();
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
/// <param name="x">The input tensor</param>
/// <param name="approximate">The approximation method to use.</param>
/// <param name="inplace">Do the operation in-place. Default: False</param>
public static Tensor gelu(Tensor x, GELU.Approximate approximate, bool inplace = false)
{
    if (inplace) {
        // In-place variant mutates x; alias() matches the out-of-place return semantics.
        return x.gelu_(approximate).alias();
    }
    return x.gelu(approximate);
}

/// <summary>
/// Gaussian Error Linear Units
/// </summary>
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern void THSTensor_elu_(IntPtr tensor, IntPtr alpha, IntPtr scale, IntPtr input_scale);

// NOTE(review): BestFitMapping/ThrowOnUnmappableChar are set to match the other
// LPStr-based imports in this file (per review comment) — disables best-fit
// mapping and fails loudly on unmappable characters. Confirm against the
// existing pattern (e.g. THSTensor_load) before merging.
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);
Comment on lines 709 to +713
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new P/Invoke declarations for THSTensor_gelu/THSTensor_gelu_ introduce an LPStr string parameter but don’t specify CharSet/BestFitMapping/ThrowOnUnmappableChar like the other LPStr-based imports in this file (e.g., THSTensor_load/meshgrid/div). This can lead to inconsistent marshaling behavior across platforms and re-enables best-fit character mapping. Consider updating these DllImport attributes to match the existing pattern used for other string parameters in LibTorchSharp.THSTensor.cs.

Copilot uses AI. Check for mistakes.

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_glu(IntPtr tensor, long dim);
Expand Down
20 changes: 18 additions & 2 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2977,15 +2977,31 @@ public Tensor elu_(Scalar alpha, Scalar scale, Scalar input_scale)

/// <summary>
/// Gaussian Error Linear Unit, using the exact (erf-based) algorithm ("none").
/// </summary>
public Tensor gelu()
{
    var res = NativeMethods.THSTensor_gelu(Handle, "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

/// <summary>
/// Gaussian Error Linear Unit with a selectable approximation algorithm.
/// </summary>
/// <param name="approximate">The approximation method to use.</param>
public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate)
{
    // Map the enum onto the string values the native layer expects.
    var algorithm = approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none";
    var res = NativeMethods.THSTensor_gelu(Handle, algorithm);
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

/// <summary>
/// Gaussian Error Linear Unit, in-place, using the exact (erf-based) algorithm ("none").
/// </summary>
public Tensor gelu_()
{
    var res = NativeMethods.THSTensor_gelu_(Handle, "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}

public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate)
{
var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none");
Comment on lines +2986 to +3004
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Tensor.gelu overloads take TorchSharp.Modules.GELU.Approximate, which is a nested enum on an nn.Module type. That makes a core Tensor API depend on the Modules layer and forces callers of tensor.gelu(...) / functional.gelu(...) to reference Modules.GELU for what is essentially an ATen algorithm option. Consider moving the approximation enum to a more neutral location (e.g., torch.nn or torch) and having the Tensor/functional overloads use that type (keeping the current overload as a forwarding shim if you want to preserve source compatibility).

Copilot uses AI. Check for mistakes.
if (res == IntPtr.Zero)
CheckForErrors();
return new Tensor(res);
Expand Down
22 changes: 22 additions & 0 deletions test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,28 @@ public void EvaluateGELU()
}
}

[Fact]
public void EvaluateGELUWithTanhApproximate()
{
    // Module configured with the tanh-based approximation.
    var gelu = GELU(Modules.GELU.Approximate.tanh);

    foreach (var device in TestUtils.AvailableDevices()) {
        var t = torch.randn(new long[] { 64, 8 }, device: device) * 25.0;
        var result = gelu.call(t);
        Assert.Equal(device.type, result.device_type);
        Assert.Equal(t.shape, result.shape);

        // GELU output is bounded below near -0.17, so -0.2 is a safe floor.
        var data = result.data<float>().ToArray();
        Assert.All(data, val => Assert.True(val >= -0.2));
    }

    // The tanh approximation must diverge measurably from the exact form.
    var probe = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
    var exact = torch.nn.functional.gelu(probe);
    var approx = torch.nn.functional.gelu(probe, Modules.GELU.Approximate.tanh);
    Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
}

[Fact]
public void EvaluatePReLU()
{
Expand Down
Loading