From 906fdc519212344c16601296a25488b926d750b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Mar 2026 21:32:06 +0000 Subject: [PATCH 1/3] Initial plan From 170a407aa07beb9f04e4b7aa588a27cc7bb2e302 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Mar 2026 21:35:32 +0000 Subject: [PATCH 2/3] fix: add linear_before_reset=1 to GRU op calls in aten_gru Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 67de7076fa..4b3cfef9da 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4357,6 +4357,7 @@ def aten_gru( initial_h=layer_h, direction=direction, hidden_size=hidden_size_attr, + linear_before_reset=1, ) else: Y, Y_h = op.GRU( @@ -4366,6 +4367,7 @@ def aten_gru( initial_h=layer_h, direction=direction, hidden_size=hidden_size_attr, + linear_before_reset=1, ) # Y shape: [seq_length, num_directions, batch_size, hidden_size] From 0226bb36fc5749234114bde90e1d081ea8f569ad Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Mar 2026 17:08:47 +0000 Subject: [PATCH 3/3] fix: add explanatory comment for linear_before_reset=1 in aten_gru Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4b3cfef9da..db02faa17e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4348,6 +4348,11 @@ def aten_gru( # Extract 
hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] hidden_size_attr = hx.shape[2] + # linear_before_reset=1 matches PyTorch's GRU formulation where the linear + # transformation is applied before multiplying by the reset gate: + # ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh) + # The ONNX default (linear_before_reset=0) uses a different equation and + # would produce results that do not match PyTorch's GRU output. if B is not None: Y, Y_h = op.GRU( current_input,