From 5cf44851967f8673a9e0894734ec83ae14dab487 Mon Sep 17 00:00:00 2001 From: FangRui Date: Mon, 13 Apr 2026 19:49:35 +0800 Subject: [PATCH] fix: compare mrgsort_format2 by stable output prefix --- test/npu_validation/scripts/generate_testcase.py | 10 ++++++++++ test/samples/Mrgsort/mrgsort_format2.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/npu_validation/scripts/generate_testcase.py b/test/npu_validation/scripts/generate_testcase.py index ca802b567..cbc23f668 100644 --- a/test/npu_validation/scripts/generate_testcase.py +++ b/test/npu_validation/scripts/generate_testcase.py @@ -1623,6 +1623,16 @@ def generate_testcase( file_cnt = ptr_elem_counts.get(name, logical_elem_count) if file_cnt and req < int(file_cnt): compare_prefix_counts[name] = req + # TMRGSORT format2 testcase writes three contiguous regions: + # 2-way (256) + 3-way (384) + 4-way (up to 512). + # With 4-way exhausted mode, the stable worst-case valid prefix for 4-way + # is 128 elements, so compare 256 + 384 + 128 = 768 elements. + testcase_lc = testcase.lower() + if testcase_lc == "mrgsort_format2": + for p in output_ptrs: + name = p["name"] + file_cnt = int(ptr_elem_counts.get(name, logical_elem_count)) + compare_prefix_counts[name] = min(file_cnt, 768) for p in output_ptrs: np_dtype = _np_dtype_for_cpp(p["cpp_type"]) name = p["name"] diff --git a/test/samples/Mrgsort/mrgsort_format2.py b/test/samples/Mrgsort/mrgsort_format2.py index b25607f8f..af3421ae7 100644 --- a/test/samples/Mrgsort/mrgsort_format2.py +++ b/test/samples/Mrgsort/mrgsort_format2.py @@ -125,7 +125,7 @@ def build(): srcs=[tb_s0, tb_s1, tb_s2], dsts=[tb_dst3, tb_tmp3], excuted=excuted, - exhausted=True, + exhausted=False, ) # 4-way: src0 + src1 + src2 + src3 -> 1x512 @@ -133,7 +133,7 @@ def build(): srcs=[tb_s0, tb_s1, tb_s2, tb_s3], dsts=[tb_dst4, tb_tmp4], excuted=excuted, - exhausted=False, + exhausted=True, ) sv_out2 = pto.PartitionViewOp(part_view_1x256, tv_out, offsets=[c0, c0], sizes=[c1, c256]).result