From f1df9a7072fea31a367e773246d4bee68689d4a7 Mon Sep 17 00:00:00 2001 From: wenfei qi Date: Sat, 14 Mar 2026 00:55:13 -0400 Subject: [PATCH] add torch.special.logit --- paconvert/api_mapping.json | 13 +--- tests/test_Tensor_logit.py | 61 ++++++++++++++++++ tests/test_logit.py | 124 ++++++++++++++++++++++++++++++++++++ tests/test_special_logit.py | 112 ++++++++++++++++++++++++++++++++ 4 files changed, 298 insertions(+), 12 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 1d10fe946..e4756b0d2 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -11791,18 +11791,7 @@ } }, "torch.special.logit": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.logit", - "min_input_args": 1, - "args_list": [ - "input", - "eps", - "*", - "out" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.special.logsumexp": { "Matcher": "ChangePrefixMatcher" diff --git a/tests/test_Tensor_logit.py b/tests/test_Tensor_logit.py index afd91f6aa..b495b8658 100644 --- a/tests/test_Tensor_logit.py +++ b/tests/test_Tensor_logit.py @@ -61,3 +61,64 @@ def test_case_4(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + """2D input with eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[0.2, 0.5, 0.8], [0.1, 0.9, 0.3]]) + result = input.logit(eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """3D input, no eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[0.2, 0.8], [0.4, 0.6]], [[0.1, 0.9], [0.3, 0.7]]]) + result = input.logit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486], dtype=torch.float64) + result = input.logit(eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Gradient computation""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0.2796, 0.9331, 0.6486], requires_grad=True) + y = x.logit(eps=1e-6) + y.sum().backward() + x_grad = x.grad + """ + ) + obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False) + + +def test_case_9(): + """Chained call""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([0.2, 0.5, 0.8]).logit(1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_logit.py b/tests/test_logit.py index 5a8cb84ad..b572873d0 100644 --- a/tests/test_logit.py +++ b/tests/test_logit.py @@ -102,3 +102,127 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_8(): + """No eps, keyword input=""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = torch.logit(input=input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """No eps, positional only""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = torch.logit(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """2D input with eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[0.2, 0.5, 0.8], [0.1, 0.9, 0.3]]) + result = torch.logit(input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """3D input, no eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[0.2, 0.8], [0.4, 0.6]], [[0.1, 0.9], [0.3, 0.7]]]) + result = torch.logit(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486], dtype=torch.float64) + result = torch.logit(input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_13(): + """out parameter without eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + out = torch.zeros(5) + result = torch.logit(input, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_14(): + """Reordered kwargs: out, input, eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + out = torch.zeros(5) + result = torch.logit(out=out, input=input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_15(): + """Variable unpacking""" + pytorch_code = textwrap.dedent( + """ + import torch + args = (torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]), 1e-6) + result = torch.logit(*args) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_16(): + """Expression as eps argument""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = torch.logit(input, 1e-3 * 1e-3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_17(): + """Gradient computation""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0.2796, 0.9331, 0.6486], requires_grad=True) + y = torch.logit(x, eps=1e-6) + y.sum().backward() + x_grad = x.grad + """ + ) + obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False) diff --git a/tests/test_special_logit.py b/tests/test_special_logit.py index 98aaf394f..49dbf7a51 100644 --- a/tests/test_special_logit.py +++ b/tests/test_special_logit.py @@ -87,3 +87,115 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """No eps, keyword input=""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = torch.special.logit(input=input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """2D input with eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[0.2, 0.5, 0.8], [0.1, 0.9, 0.3]]) + result = torch.special.logit(input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """3D input, no eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[0.2, 0.8], [0.4, 0.6]], [[0.1, 0.9], [0.3, 0.7]]]) + result = torch.special.logit(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486], dtype=torch.float64) + result = torch.special.logit(input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """out parameter without eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + out = torch.zeros(5) + result = torch.special.logit(input, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_12(): + """Reordered kwargs: out, input, eps""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + out = torch.zeros(5) + result = torch.special.logit(out=out, input=input, eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_13(): + """Variable unpacking""" + pytorch_code = textwrap.dedent( + """ + import torch + args = (torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]), 1e-6) + result = torch.special.logit(*args) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_14(): + """Expression as eps argument""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = torch.special.logit(input, 1e-3 * 1e-3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_15(): + """Gradient computation""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0.2796, 0.9331, 0.6486], requires_grad=True) + y = torch.special.logit(x, eps=1e-6) + y.sum().backward() + x_grad = x.grad + """ + ) + obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False)