GRU translation missing linear_before_reset=1 (produces incorrect results for PyTorch GRU) #2852

@daniel-om-weber

Description

Summary

The aten_gru translation in onnxscript/function_libs/torch_lib/ops/core.py (added in #2674) does not set linear_before_reset=1 on the ONNX GRU op. This causes numerically incorrect results because PyTorch's nn.GRU uses the linear_before_reset=1 variant.

Details

PyTorch GRU computes the new gate as:

n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))

This matches the ONNX GRU spec with linear_before_reset=1. The ONNX default, linear_before_reset=0, instead applies the reset gate before the hidden-state linear transformation:

n_t = tanh(W_in @ x_t + b_in + W_hn @ (r_t * h_{t-1}) + b_hn)

a different equation that yields different hidden states whenever r_t != 1.
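To see that the two variants genuinely disagree, here is a minimal NumPy sketch of a single new-gate computation, with random weights standing in for a trained GRU (all names are illustrative, not from the codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
H = 4
Wn_x = rng.standard_normal(H)                     # stands in for W_in @ x_t + b_in
Whn = rng.standard_normal((H, H))                 # hidden-to-new-gate weights W_hn
b_hn = rng.standard_normal(H)                     # hidden bias b_hn
h_prev = rng.standard_normal(H)                   # previous hidden state h_{t-1}
r_t = 1 / (1 + np.exp(-rng.standard_normal(H)))   # sigmoid-activated reset gate

# linear_before_reset=1 (PyTorch): reset gate applied AFTER the hidden linear transform
n_lbr1 = np.tanh(Wn_x + r_t * (Whn @ h_prev + b_hn))

# linear_before_reset=0 (ONNX default): reset gate applied BEFORE the linear transform
n_lbr0 = np.tanh(Wn_x + Whn @ (r_t * h_prev) + b_hn)

print("Max abs diff between variants:", np.abs(n_lbr1 - n_lbr0).max())
```

The difference is nonzero for any reset gate not saturated at 1, and it compounds across time steps, which is why the end-to-end repro below shows errors around 0.1 rather than float32 noise.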

The two op.GRU calls (~lines 4352 and 4362) need linear_before_reset=1 added.
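A sketch of the proposed change, assuming the call shape used in core.py (the surrounding argument names here are placeholders, not a verbatim diff):

```python
# onnxscript/function_libs/torch_lib/ops/core.py -- sketch of the fix
Y, Y_h = op.GRU(
    input, W, R, B,
    hidden_size=hidden_size,
    linear_before_reset=1,  # match PyTorch's GRU new-gate equation
)
```

Both call sites need the attribute, since the ONNX default of 0 silently selects the other equation.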

Reproduction

import numpy as np
import torch

m = torch.nn.GRU(1, 32, batch_first=True)
m.eval()
inp = torch.randn(1, 10, 1)

with torch.no_grad():
    pt_out, _ = m(inp)

# Export through the dynamo exporter, which lowers aten::gru via the
# onnxscript torch_lib translation in question.
torch.onnx.export(m, (inp,), f="gru.onnx", dynamo=True)

import onnxruntime as ort

sess = ort.InferenceSession("gru.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: inp.numpy()})[0]

print("Max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
# Expected: ~1e-7 (float32 precision)
# Actual:   ~0.1 (wrong GRU equation)

Environment

  • torch 2.10.0
  • onnxscript 0.6.2
  • onnxruntime 1.22.0
