Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -2768,12 +2768,7 @@
"Matcher": "ChangePrefixMatcher"
},
"torch._assert": {
"Matcher": "AssertMatcher",
"min_input_args": 2,
"args_list": [
"condition",
"message"
]
"Matcher": "ChangePrefixMatcher"
},
"torch._foreach_abs": {
"Matcher": "ForeachMatcher",
Expand Down
82 changes: 82 additions & 0 deletions tests/test__assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,85 @@ def test_case_5():
"""
)
obj.run(pytorch_code)


def test_case_6():
"""Tensor condition"""
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([True])
torch._assert(x, "tensor should be true")
"""
)
obj.run(pytorch_code)


def test_case_7():
"""Boolean literal condition"""
pytorch_code = textwrap.dedent(
"""
import torch
torch._assert(True, "literal true")
"""
)
obj.run(pytorch_code)


def test_case_8():
"""Integer as condition (truthy)"""
pytorch_code = textwrap.dedent(
"""
import torch
torch._assert(1, "one is truthy")
"""
)
obj.run(pytorch_code)


def test_case_9():
"""Empty string message"""
pytorch_code = textwrap.dedent(
"""
import torch
torch._assert(True, "")
"""
)
obj.run(pytorch_code)


def test_case_10():
"""Variable as condition"""
pytorch_code = textwrap.dedent(
"""
import torch
cond = (2 > 1)
msg = "two is greater"
torch._assert(cond, msg)
"""
)
obj.run(pytorch_code)


def test_case_11():
"""Comparison expression with len()"""
pytorch_code = textwrap.dedent(
"""
import torch
data = [1, 2, 3]
torch._assert(len(data) == 3, "length mismatch")
"""
)
obj.run(pytorch_code)


def test_case_12():
"""Tensor comparison condition"""
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor(5.0)
torch._assert(x > 0, "x should be positive")
"""
)
obj.run(pytorch_code)