
Fix aten_gru: add linear_before_reset=1 to match PyTorch GRU semantics #2853

Open

Copilot wants to merge 3 commits into main from copilot/add-linear-before-reset-to-gru

Conversation

Copilot AI (Contributor) commented Mar 14, 2026

PyTorch nn.GRU applies the hidden-state linear transformation before multiplying by the reset gate (linear_before_reset=1), but the aten_gru translation was emitting ONNX GRU ops with the default linear_before_reset=0, producing numerically incorrect results (max abs error ~0.1 instead of the expected ~1e-7).

Changes

  • onnxscript/function_libs/torch_lib/ops/core.py: Add linear_before_reset=1 to both op.GRU calls in aten_gru — the biased and unbiased variants.
# Before (incorrect — uses default linear_before_reset=0)
Y, Y_h = op.GRU(current_input, W, R, B, initial_h=layer_h,
                direction=direction, hidden_size=hidden_size_attr)

# After (correct — matches PyTorch GRU: ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh))
Y, Y_h = op.GRU(current_input, W, R, B, initial_h=layer_h,
                direction=direction, hidden_size=hidden_size_attr,
                linear_before_reset=1)
Original prompt

This section details the original issue you should resolve

<issue_title>GRU translation missing linear_before_reset=1 (produces incorrect results for PyTorch GRU)</issue_title>
<issue_description>

Summary

The aten_gru translation in onnxscript/function_libs/torch_lib/ops/core.py (added in #2674) does not set linear_before_reset=1 on the ONNX GRU op. This causes numerically incorrect results because PyTorch's nn.GRU uses the linear_before_reset=1 variant.

Details

PyTorch GRU computes the new gate as:

n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))

This matches the ONNX GRU spec with linear_before_reset=1. But the default linear_before_reset=0 applies the reset gate before the linear transformation — a different equation.

The two op.GRU calls (~lines 4352 and 4362) need linear_before_reset=1 added.
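
To make the mismatch concrete, here is a minimal NumPy sketch (not from the issue; the weights, biases, and reset-gate values below are made-up placeholders) that evaluates the new gate under both conventions:

import numpy as np

rng = np.random.default_rng(0)
hidden = 4

x_t = rng.standard_normal(1)          # one input feature
h_prev = rng.standard_normal(hidden)  # previous hidden state h_{t-1}
r_t = rng.uniform(size=hidden)        # reset gate activations, assumed already computed
W_in = rng.standard_normal((hidden, 1))
W_hn = rng.standard_normal((hidden, hidden))
b_in = rng.standard_normal(hidden)
b_hn = rng.standard_normal(hidden)

# linear_before_reset=1 (PyTorch nn.GRU): the reset gate scales the *output*
# of the hidden-to-hidden linear transform.
n_pytorch = np.tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_prev + b_hn))

# linear_before_reset=0 (ONNX default): the reset gate scales the hidden state
# *before* the linear transform, which is a different equation.
n_onnx_default = np.tanh(W_in @ x_t + b_in + W_hn @ (r_t * h_prev) + b_hn)

print(np.abs(n_pytorch - n_onnx_default).max())  # nonzero: the two variants disagree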

Reproduction

import torch, numpy as np

m = torch.nn.GRU(1, 32, batch_first=True)
m.eval()
inp = torch.randn(1, 10, 1)

with torch.no_grad():
    pt_out, _ = m(inp)

# Export to ONNX and run the same input through onnxruntime.
torch.onnx.export(m, (inp,), f="gru.onnx")

import onnxruntime as ort
sess = ort.InferenceSession("gru.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: inp.numpy()})[0]

print("Max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
# Expected: ~1e-7 (float32 precision)
# Actual:   ~0.1 (incorrect GRU equation)
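
As an additional check (not part of the original report), the exported model can be inspected with the onnx package to confirm which linear_before_reset value the exporter actually emitted on the GRU node:

import onnx

model = onnx.load("gru.onnx")
for node in model.graph.node:
    if node.op_type == "GRU":
        attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
        # A missing attribute means the ONNX default (0); with the fix this prints 1.
        print("linear_before_reset =", attrs.get("linear_before_reset", 0))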

Environment

  • torch 2.10.0
  • onnxscript 0.6.2
  • onnxruntime 1.22.0


Comments on the Issue (you are @copilot in this section)



Copilot AI changed the title from "[WIP] Fix missing linear_before_reset in GRU translation for ONNX" to "Fix aten_gru: add linear_before_reset=1 to match PyTorch GRU semantics" Mar 14, 2026
Copilot AI requested a review from justinchuby March 14, 2026 21:36
@justinchuby justinchuby marked this pull request as ready for review March 14, 2026 21:40
@justinchuby justinchuby enabled auto-merge (squash) March 14, 2026 21:40
@justinchuby justinchuby added the module: torchlib (Related to the torch/aten function lib in development) label Mar 14, 2026
codecov bot commented Mar 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.86%. Comparing base (4c4f7a0) to head (170a407).
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2853   +/-   ##
=======================================
  Coverage   71.86%   71.86%           
=======================================
  Files         239      239           
  Lines       29139    29139           
  Branches     2875     2875           
=======================================
  Hits        20942    20942           
  Misses       7219     7219           
  Partials      978      978           


@justinchuby (Collaborator) commented

@copilot add a comment inline for future readers

Copilot AI (Contributor, Author) commented Mar 15, 2026

@copilot add a comment inline for future readers

Added a comment (commit 0226bb3) explaining that linear_before_reset=1 is needed to match PyTorch's GRU equation and that the ONNX default would produce incorrect results.
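
For reference, the added inline comment might look roughly like the sketch below (hypothetical wording; the exact text is in commit 0226bb3), reusing the variable names from the snippet above:

Y, Y_h = op.GRU(
    current_input,
    W,
    R,
    B,
    initial_h=layer_h,
    direction=direction,
    hidden_size=hidden_size_attr,
    # PyTorch nn.GRU applies the hidden linear transform before multiplying by
    # the reset gate. The ONNX default (linear_before_reset=0) implements a
    # different equation and produces numerically incorrect results.
    linear_before_reset=1,
)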


Labels

module: torchlib (Related to the torch/aten function lib in development)

Development

Successfully merging this pull request may close these issues.

GRU translation missing linear_before_reset=1 (produces incorrect results for PyTorch GRU)
