diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 67de7076fa..db02faa17e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4348,6 +4348,11 @@ def aten_gru( # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] hidden_size_attr = hx.shape[2] + # linear_before_reset=1 matches PyTorch's GRU formulation where the linear + # transformation is applied before multiplying by the reset gate: + # ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh) + # The ONNX default (linear_before_reset=0) uses a different equation and + # would produce numerically incorrect results. if B is not None: Y, Y_h = op.GRU( current_input, @@ -4357,6 +4362,7 @@ def aten_gru( initial_h=layer_h, direction=direction, hidden_size=hidden_size_attr, + linear_before_reset=1, ) else: Y, Y_h = op.GRU( @@ -4366,6 +4372,7 @@ def aten_gru( initial_h=layer_h, direction=direction, hidden_size=hidden_size_attr, + linear_before_reset=1, ) # Y shape: [seq_length, num_directions, batch_size, hidden_size]