Summary
The aten_gru translation in onnxscript/function_libs/torch_lib/ops/core.py (added in #2674) does not set linear_before_reset=1 on the ONNX GRU op. This causes numerically incorrect results because PyTorch's nn.GRU uses the linear_before_reset=1 variant.
Details
PyTorch's GRU computes the new gate as:

`n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))`

This matches the ONNX GRU spec with `linear_before_reset=1`. The default, `linear_before_reset=0`, instead computes `n_t = tanh(W_in @ x_t + b_in + W_hn @ (r_t * h_{t-1}) + b_hn)`, applying the reset gate before the hidden-state linear transformation. The two equations produce different results whenever the reset gate is not saturated at 1.
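To see the difference concretely, here is a minimal NumPy sketch (hypothetical weights, not tied to the repro below) that computes the new gate both ways for the same inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 4
x_contrib = rng.standard_normal(hidden)      # W_in @ x_t + b_in, precomputed
W_hn = rng.standard_normal((hidden, hidden))
b_hn = rng.standard_normal(hidden)
h_prev = rng.standard_normal(hidden)
r_t = rng.uniform(0.0, 1.0, hidden)          # reset gate activations in (0, 1)

# linear_before_reset=1 (PyTorch): reset gate applied AFTER the hidden linear map
n_lbr1 = np.tanh(x_contrib + r_t * (W_hn @ h_prev + b_hn))

# linear_before_reset=0 (ONNX default): reset gate applied BEFORE the linear map
n_lbr0 = np.tanh(x_contrib + W_hn @ (r_t * h_prev) + b_hn)

print("max |n_lbr1 - n_lbr0|:", np.abs(n_lbr1 - n_lbr0).max())  # nonzero
```

The gap is nonzero for generic weights, which is why the exported model diverges from PyTorch rather than merely losing precision.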
The two `op.GRU` calls (~lines 4352 and 4362) need `linear_before_reset=1` added.
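A hedged sketch of the fix (the exact argument list in `core.py` may differ; adding the `linear_before_reset=1` attribute is the only intended change):

```diff
     Y, Y_h = op.GRU(
         input,
         weight,
         recurrence_weight,
         bias,
         hidden_size=hidden_size,
+        linear_before_reset=1,
     )
```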
Reproduction
import torch, numpy as np
m = torch.nn.GRU(1, 32, batch_first=True)
m.eval()
inp = torch.randn(1, 10, 1)
with torch.no_grad():
pt_out, _ = m(inp)
torch.onnx.export(m, (inp,), f="gru.onnx")
import onnxruntime as ort
sess = ort.InferenceSession("gru.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: inp.numpy()})[0]
print("Max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
# Expected: ~1e-7 (float32 precision)
# Actual: ~0.1 (incorrect GRU equation)Environment
- torch 2.10.0
- onnxscript 0.6.2
- onnxruntime 1.22.0
References
- ONNX GRU spec (`linear_before_reset` attribute)
- PyTorch `nn.GRU` docs
- PR #2674 "feat: implement LSTM and GRU operators for torchlib" (introduced the `aten_gru` translation)