diff --git a/.github/scripts/check_tuned_op_regression.sh b/.github/scripts/check_tuned_op_regression.sh new file mode 100755 index 0000000000..a5f694ca15 --- /dev/null +++ b/.github/scripts/check_tuned_op_regression.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: MIT +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# Wrap .github/scripts/compare_benchmark.py for the tuned_op_bench CI job. +# Prints a comparison table to stdout (captured by the job step into +# $GITHUB_STEP_SUMMARY). Exits 0 unless --fail-on-regress is passed and +# at least one REGRESS row is found. +# +# Usage: check_tuned_op_regression.sh [extra args...] +set -euo pipefail + +BASE=${1:?baseline csv path required} +CURR=${2:?current csv path required} +shift 2 + +BASE_LABEL="baseline" +CURR_LABEL="current" +if [[ -n "${BASE_SHA:-}" ]]; then + BASE_LABEL="main(${BASE_SHA:0:7})" +fi +if [[ -n "${CURR_SHA:-}" ]]; then + CURR_LABEL="PR(${CURR_SHA:0:7})" +elif [[ -n "${GITHUB_SHA:-}" ]]; then + CURR_LABEL="${GITHUB_SHA:0:7}" +fi + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +python3 "${REPO_ROOT}/.github/scripts/compare_benchmark.py" \ + "$BASE" "$CURR" \ + --baseline-label "$BASE_LABEL" \ + --current-label "$CURR_LABEL" \ + --warn 1.10 \ + --fail 1.15 \ + "$@" diff --git a/.github/scripts/compare_benchmark.py b/.github/scripts/compare_benchmark.py new file mode 100644 index 0000000000..8cf4229549 --- /dev/null +++ b/.github/scripts/compare_benchmark.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +"""Compare two tuned operator benchmark CSVs (wide-table, metric=us, lower is better). + +Used by .github/workflows/aiter-test.yaml tuned_op_bench job to flag tuned operator +performance regressions between a PR and main. 
+ +CSV schema: + - Columns derived from tuned_fmoe.csv shape/params (dtype, token, + model_dim, inter_dim, E, topk, actType, ...) form the JOIN KEY. + - `us` is the metric (microseconds, lower = faster). + - `kernelName1`, `kernelName2` are NOT part of the key (kernel choice + may differ between baseline and current); they are shown on a + follow-up line when they differ. + - Any other column is treated as a key column. + +Status legend (all rows are printed, prefixed with status tag): + [REGRESS] ratio = current_us / baseline_us > FAIL threshold + [WARN] ratio > WARN threshold (and <= FAIL) + [OK] ratio <= WARN (including faster-than-baseline) + [NEW] shape present in current only (no baseline) + [REMOVED] shape present in baseline only (current missing) + [SKIPPED] missing or invalid us value on either side + +Rows are sorted worst-first: REGRESS, WARN, OK, NEW, REMOVED, SKIPPED. + +Exit code: + 0 always, unless --fail-on-regress is set AND >= 1 REGRESS row exists. + 1 with --fail-on-regress when REGRESS detected. +""" + +from __future__ import annotations + +import argparse +import csv +import sys +from pathlib import Path +from typing import Dict, Tuple + +METRIC = "us" +KERNEL_COLS = ("kernelName1", "kernelName2") +NON_KEY = {METRIC, *KERNEL_COLS} + +# Cols kept in the join key (for correct shape matching) but hidden from +# the printed table (low signal — usually constant across runs). +HIDE_DISPLAY_COLS = ( + "preshuffle", + "strict_accuracy", + "check_aot_cache", + "swiglu_limit", + # Source cols folded into derived `hip` column below + "hidden_pad", + "intermediate_pad", +) + +# Derived display columns. Each entry: derived_name -> (source_col_a, source_col_b) +# Value rendered as "(a, b)" tuple. Source cols stay in the join key; they're +# just hidden from display (covered by HIDE_DISPLAY_COLS above) and a synthetic +# tuple-valued col is inserted in their place. 
+DERIVED_TUPLE_COLS = { + "hip": ("hidden_pad", "intermediate_pad"), # hidden_pad / intermediate_pad +} + +# Display-only abbreviations. Applied at print time; underlying join key +# still uses full strings, so matching across files is unaffected. +# NOTE: `torch.float8_e4m3fnuz` is AMD's default fp8, mapped to `fp8` so it +# stays consistent with `torch.fp8` alias. The OCP / e5m2 variants keep +# the suffix so they remain distinguishable. +_VALUE_ABBREV = { + "torch.bfloat16": "bf16", + "torch.float16": "fp16", + "torch.float32": "fp32", + "torch.float8_e4m3fnuz": "fp8", # AMD default + "torch.float8_e4m3fn": "fp8e4m3fn", # OCP + "torch.float8_e5m2": "fp8e5m2", + "torch.float4_e2m1fn_x2": "fp4", # x2 = packed (2 elems per byte) + "torch.fp8": "fp8", + "torch.fp4x2": "fp4", + "torch.int8": "i8", + "torch.int4": "i4", + "torch.i4x2": "i4", + # Booleans + "True": "T", + "False": "F", +} +# Enum class prefixes to strip ("ActivationType.Silu" -> "Silu") +_STRIP_PREFIXES = ("ActivationType.", "QuantType.", "GateMode.") + + +def _abbreviate(val: str) -> str: + """Shorten verbose enum/dtype/bool values for table display.""" + if val in _VALUE_ABBREV: + return _VALUE_ABBREV[val] + for prefix in _STRIP_PREFIXES: + if val.startswith(prefix): + return val[len(prefix) :] + return val + + +def _natural_key(val: str): + """Cast numeric strings to numbers for natural sort (token=2 < 16 < 128).""" + try: + return (0, int(val)) + except ValueError: + try: + return (0, float(val)) + except ValueError: + return (1, val) + + +Row = Dict[str, str] +Key = Tuple[Tuple[str, str], ...] + + +def _normalize_row(raw: Row) -> Row: + """Strip whitespace and coerce missing CSV fields to empty strings.""" + return { + k: ("" if v is None else v.strip() if isinstance(v, str) else str(v)) + for k, v in raw.items() + if k is not None + } + + +def _read_csv_rows(path: Path) -> Tuple[list[Row], Tuple[str, ...]]: + """Return (rows, fieldnames). 
Whitespace stripped from values.""" + if not path.exists(): + raise SystemExit(f"input csv not found: {path}") + with path.open(newline="") as f: + reader = csv.DictReader(f) + if reader.fieldnames is None or METRIC not in reader.fieldnames: + raise SystemExit( + f"{path} missing required column `{METRIC}`; " + f"got columns: {reader.fieldnames}" + ) + rows = [_normalize_row(raw) for raw in reader] + return rows, tuple(reader.fieldnames) + + +def _key_cols(base_cols: Tuple[str, ...], cur_cols: Tuple[str, ...]) -> Tuple[str, ...]: + """Stable key column order across baseline/current schema drift.""" + cols = [] + seen = set() + for col in (*base_cols, *cur_cols): + if col in NON_KEY or col in seen: + continue + cols.append(col) + seen.add(col) + return tuple(cols) + + +def _index_rows(rows: list[Row], key_cols: Tuple[str, ...]) -> Dict[Key, Row]: + indexed: Dict[Key, Row] = {} + for row in rows: + key = tuple(sorted((c, row.get(c, "")) for c in key_cols)) + indexed[key] = row + return indexed + + +def _parse_us(raw: Row) -> float | None: + val = raw.get(METRIC, "") + if val in ("", "-", "skip", "nan", "NaN"): + return None + try: + return float(val) + except ValueError: + return None + + +def _fmt_key_compact( + key: Key, key_cols_order: Tuple[str, ...], constants: Dict[str, str] +) -> str: + """Format key showing only cols whose value is NOT in `constants`.""" + d = dict(key) + parts = [] + for c in key_cols_order: + if c in d and c not in constants: + parts.append(f"{c}={d[c]}") + return " ".join(parts) if parts else "(common)" + + +def _find_constants(keys: list[Key], key_cols_order: Tuple[str, ...]) -> Dict[str, str]: + """Return cols whose value is identical across all `keys`.""" + if not keys: + return {} + first = dict(keys[0]) + constants = {} + for c in key_cols_order: + if c not in first: + continue + v = first[c] + if all(dict(k).get(c) == v for k in keys): + constants[c] = v + return constants + + +def _kernel_diff(base_row: Row, cur_row: Row) -> 
list[str]: + """Return list of `kernelNameX: -> ` for cols that differ.""" + diffs = [] + for c in KERNEL_COLS: + b, k = base_row.get(c, ""), cur_row.get(c, "") + if b != k: + diffs.append(f"{c}: {b} -> {k}") + return diffs + + +def main() -> int: + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("baseline_csv", type=Path) + parser.add_argument("current_csv", type=Path) + parser.add_argument("--baseline-label", default="baseline") + parser.add_argument("--current-label", default="current") + parser.add_argument( + "--warn", + type=float, + default=1.10, + help="warn threshold: ratio > this is `warn` (default 1.10 = 10%% slower)", + ) + parser.add_argument( + "--fail", + type=float, + default=1.15, + help="regress threshold: ratio > this is `REGRESS` (default 1.15 = 15%% slower)", + ) + parser.add_argument( + "--fail-on-regress", + action="store_true", + help="exit 1 if any REGRESS row found (default: report only, exit 0)", + ) + args = parser.parse_args() + + if args.warn >= args.fail: + raise SystemExit(f"--warn ({args.warn}) must be < --fail ({args.fail})") + + baseline_rows, baseline_cols = _read_csv_rows(args.baseline_csv) + current_rows, current_cols = _read_csv_rows(args.current_csv) + key_cols = _key_cols(baseline_cols, current_cols) + baseline = _index_rows(baseline_rows, key_cols) + current = _index_rows(current_rows, key_cols) + + print(f"=== Tuned op bench: {args.current_label} vs {args.baseline_label} ===") + print(f" baseline: {args.baseline_csv} ({len(baseline)} rows)") + print(f" current: {args.current_csv} ({len(current)} rows)") + print(f" thresholds: warn>{args.warn:.2f}, fail>{args.fail:.2f}") + print() + + common = sorted(baseline.keys() & current.keys()) + only_curr = sorted(current.keys() - baseline.keys()) + only_base = sorted(baseline.keys() - current.keys()) + + # Classify every row. 
Each entry: (sort_rank, status_tag, key, base, cur, ratio) + # sort_rank: 0=REGRESS, 1=WARN, 2=OK, 3=NEW, 4=REMOVED, 5=SKIPPED + entries: list[tuple[int, str, Key, float | None, float | None, float | None]] = [] + n_regress = n_warn = n_ok = n_skip = 0 + for key in common: + b_us = _parse_us(baseline[key]) + c_us = _parse_us(current[key]) + if b_us is None or c_us is None or b_us <= 0: + n_skip += 1 + entries.append((5, "SKIPPED", key, b_us, c_us, None)) + continue + ratio = c_us / b_us + if ratio > args.fail: + rank, tag = 0, "REGRESS" + n_regress += 1 + elif ratio > args.warn: + rank, tag = 1, "WARN" + n_warn += 1 + else: + rank, tag = 2, "OK" + n_ok += 1 + entries.append((rank, tag, key, b_us, c_us, ratio)) + for key in only_curr: + c_us = _parse_us(current[key]) + entries.append((3, "NEW", key, None, c_us, None)) + for key in only_base: + b_us = _parse_us(baseline[key]) + entries.append((4, "REMOVED", key, b_us, None, None)) + + # Sort worst-first, then by key (natural sort: token=2 < 16 < 128) + def _entry_sort_key(e): + rank, _tag, key, *_ = e + return (rank, [_natural_key(v) for _, v in key]) + + entries.sort(key=_entry_sort_key) + + # ── Build proper tabular output ── + # Columns: status, ratio, cur(us), base(us), *display_cols + # display_cols = key_cols minus HIDE_DISPLAY_COLS, with each derived + # tuple col inserted at the position of its first source col. + # Hidden source cols still contribute to the join key. + # Kernel diffs (not in table) go on indented ↳ sub-lines below each row. 
+ METRIC_HDRS = ("ratio", "cur(us)", "base(us)") + + # Build display_cols: walk key_cols, drop hidden, splice derived in place + _derived_first_src = { + sources[0]: name for name, sources in DERIVED_TUPLE_COLS.items() + } + display_cols: list[str] = [] + for c in key_cols: + if c in _derived_first_src: + display_cols.append(_derived_first_src[c]) + if c not in HIDE_DISPLAY_COLS: + display_cols.append(c) + + def _cell_value(c: str, d: Dict[str, str]) -> str: + if c in DERIVED_TUPLE_COLS: + srcs = DERIVED_TUPLE_COLS[c] + return "(" + ", ".join(d.get(s, "") for s in srcs) + ")" + return _abbreviate(d.get(c, "")) + + def _row_cells(rank, tag, key, b_us, c_us, ratio): + d = dict(key) + cells = [f"[{tag}]"] + cells.append(f"{ratio:.3f}" if ratio is not None else "-") + cells.append(f"{c_us:.2f}" if c_us is not None else "-") + cells.append(f"{b_us:.2f}" if b_us is not None else "-") + for c in display_cols: + cells.append(_cell_value(c, d)) + return cells + + header = ["status", *METRIC_HDRS, *display_cols] + body = [_row_cells(*e) for e in entries] + + # Column widths = max(header, max value) + widths = [ + max(len(header[i]), *(len(r[i]) for r in body)) if body else len(header[i]) + for i in range(len(header)) + ] + + # Right-justify the 3 metric cols (numbers), left-justify the rest. 
+ def _fmt_row(cells): + out = [] + for i, c in enumerate(cells): + justify = str.rjust if 1 <= i <= 3 else str.ljust + out.append(justify(c, widths[i])) + return " ".join(out) + + print(_fmt_row(header)) + print(" ".join("-" * w for w in widths)) + for cells, e in zip(body, entries): + print(_fmt_row(cells)) + # Kernel-diff sub-lines (indented to align under shape columns) + tag = e[1] + key = e[2] + if tag in ("REGRESS", "WARN", "OK"): + for d in _kernel_diff(baseline[key], current[key]): + # Indent past the status column for visual hierarchy + print(" " * (widths[0] + 2) + "↳ " + d) + + print() + print("Summary:") + print(f" compared: {len(common)}") + print(f" REGRESS: {n_regress}") + print(f" WARN: {n_warn}") + print(f" OK: {n_ok}") + print(f" NEW: {len(only_curr)}") + print(f" REMOVED: {len(only_base)}") + print(f" SKIPPED: {n_skip} (missing/invalid us value)") + + if args.fail_on_regress and n_regress > 0: + print(f"\nFAIL: {n_regress} regression(s) above threshold.", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/workflows/aiter-test.yaml b/.github/workflows/aiter-test.yaml index b8a2f40fa6..5629535cd3 100644 --- a/.github/workflows/aiter-test.yaml +++ b/.github/workflows/aiter-test.yaml @@ -504,10 +504,13 @@ jobs: - name: Upload test logs uses: actions/upload-artifact@v4 - if: success() + if: always() with: name: standard-test-log-${{ matrix.runner }}-shard-${{ matrix.shard_idx }} - path: latest_test.log + path: | + latest_test.log + tuned_op_bench.csv + if-no-files-found: warn retention-days: 7 - name: Cleanup container @@ -756,3 +759,175 @@ jobs: if: always() run: | ./.github/scripts/clean_up_rocm.sh + + tuned_op_bench: + # Tuned operator perf regression check. 
+ # - PR: pull last main baseline CSV, compare current shard's CSV, warn-only + # - push to main / workflow_dispatch: publish current CSV as next baseline + # Only consumes csv from linux-aiter-mi35x-1 runner to avoid cross-arch noise. + name: Tuned Op Bench + if: >- + always() && + !cancelled() && + !github.event.pull_request.draft && + github.event.action != 'labeled' && + github.event_name != 'schedule' + runs-on: ubuntu-latest + needs: [standard] + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + + - name: Download standard test logs (mi35x only) + uses: actions/download-artifact@v4 + continue-on-error: true + with: + pattern: standard-test-log-linux-aiter-mi35x-1-shard-* + path: /tmp/logs/ + + - name: Locate current tuned_op_bench.csv + id: current + run: | + shopt -s nullglob + csv_files=(/tmp/logs/*/tuned_op_bench.csv) + if [[ ${#csv_files[@]} -eq 0 ]]; then + echo "::warning::No tuned op benchmark CSV found in any mi35x shard; skipping tuned_op_bench" + echo "found=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + echo "Merging tuned op benchmark CSVs:" + printf ' %s\n' "${csv_files[@]}" + python3 - "${csv_files[@]}" <<'PY' + import csv + import sys + + rows = [] + fieldnames = [] + seen = set() + for path in sys.argv[1:]: + with open(path, newline="") as f: + reader = csv.DictReader(f) + if not reader.fieldnames: + continue + for name in reader.fieldnames: + if name not in seen: + fieldnames.append(name) + seen.add(name) + rows.extend(reader) + + if "us" not in seen: + raise SystemExit("merged tuned op benchmark CSV is missing required `us` column") + + with open("/tmp/current.csv", "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({name: row.get(name, "") for name in fieldnames}) + print(f"Merged {len(rows)} row(s) into /tmp/current.csv") + PY + echo "Current merged CSV: /tmp/current.csv ($(wc -l < /tmp/current.csv) lines)" + echo "found=true" 
>> "$GITHUB_OUTPUT" + + # ── PR path: compare vs baseline ── + - name: Fetch baseline from PR base.sha + if: steps.current.outputs.found == 'true' && github.event_name == 'pull_request' + id: baseline_pinned + continue-on-error: true + uses: dawidd6/action-download-artifact@v3 + with: + workflow: aiter-test.yaml + commit: ${{ github.event.pull_request.base.sha }} + name: tuned-op-bench-${{ github.event.pull_request.base.sha }} + path: /tmp/baseline_pinned/ + if_no_artifact_found: warn + + - name: Fallback — fetch baseline from latest main + if: >- + steps.current.outputs.found == 'true' && + github.event_name == 'pull_request' + id: baseline_main + continue-on-error: true + uses: dawidd6/action-download-artifact@v3 + with: + workflow: aiter-test.yaml + branch: main + name_is_regexp: true + name: ^tuned-op-bench-[a-f0-9]+$ + path: /tmp/baseline_main/ + if_no_artifact_found: warn + + - name: Compare + if: steps.current.outputs.found == 'true' && github.event_name == 'pull_request' + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} + CURR_SHA: ${{ github.event.pull_request.head.sha }} + run: | + set -euo pipefail + baseline_csv="" + if [[ -f /tmp/baseline_pinned/tuned_op_bench.csv ]]; then + baseline_csv=/tmp/baseline_pinned/tuned_op_bench.csv + echo "Using baseline pinned to PR.base.sha=${BASE_SHA:0:7}" + else + # fallback: pick first match under /tmp/baseline_main/* + shopt -s nullglob + candidates=(/tmp/baseline_main/*/tuned_op_bench.csv /tmp/baseline_main/tuned_op_bench.csv) + for c in "${candidates[@]}"; do + if [[ -f "$c" ]]; then + baseline_csv="$c" + echo "Using fallback baseline from latest main: $c" + break + fi + done + fi + if [[ -z "$baseline_csv" ]]; then + echo "::warning::No tuned op benchmark baseline found (neither pinned PR.base.sha nor latest main); skipping compare." 
+ { + echo "## Tuned Op Bench" + echo + echo "_No baseline available — first run on this branch or main hasn't published baseline yet._" + } >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + echo "## Tuned Op Bench (vs baseline)" >> "$GITHUB_STEP_SUMMARY" + echo '```' >> "$GITHUB_STEP_SUMMARY" + bash .github/scripts/check_tuned_op_regression.sh \ + "$baseline_csv" /tmp/current.csv \ + | tee -a "$GITHUB_STEP_SUMMARY" + echo '```' >> "$GITHUB_STEP_SUMMARY" + + # ── main push / workflow_dispatch path: publish baseline ── + - name: Stage baseline payload + if: >- + steps.current.outputs.found == 'true' && + (github.event_name == 'push' && github.ref == 'refs/heads/main' + || github.event_name == 'workflow_dispatch') + run: | + mkdir -p /tmp/publish + cp /tmp/current.csv /tmp/publish/tuned_op_bench.csv + python3 -c " + import json, os, datetime + meta = { + 'commit': os.environ['GITHUB_SHA'], + 'ref': os.environ['GITHUB_REF'], + 'event': os.environ['GITHUB_EVENT_NAME'], + 'runner_pool': 'linux-aiter-mi35x-1', + 'gpu_arch_list': os.environ.get('GPU_ARCH_LIST', ''), + 'ran_at': datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ'), + } + with open('/tmp/publish/metadata.json', 'w') as f: + json.dump(meta, f, indent=2) + print(json.dumps(meta, indent=2)) + " + ls -la /tmp/publish/ + wc -l /tmp/publish/tuned_op_bench.csv + + - name: Publish baseline artifact + if: >- + steps.current.outputs.found == 'true' && + (github.event_name == 'push' && github.ref == 'refs/heads/main' + || github.event_name == 'workflow_dispatch') + uses: actions/upload-artifact@v4 + with: + name: tuned-op-bench-${{ github.sha }} + path: /tmp/publish/ + retention-days: 90 diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index f89e2c6701..2fb13b8db9 100755 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -12,6 +12,13 @@ from aiter.test_common import checkAllclose, perftest, benchmark from aiter import hipb_mm, hipb_create_extension from aiter.jit.utils.chip_info 
import get_gfx_runtime as get_gfx, get_cu_num + +try: + from tuned_op_bench_utils import append_tuned_op_bench_rows +except ModuleNotFoundError as e: + if e.name != "tuned_op_bench_utils": + raise + from op_tests.tuned_op_bench_utils import append_tuned_op_bench_rows import pandas as pd import argparse from functools import lru_cache @@ -409,7 +416,7 @@ def test_skinny_gemm_a8w8_pertoken_quant(): def _iter_flydsl_csv_cases(): - """Yield test_gemm kwargs for every flydsl row in the merged bpreshuffle tuned CSV.""" + """Yield (test_gemm kwargs, bench metadata) for flydsl tuned CSV rows.""" gfx, cu = get_gfx(), get_cu_num() merged_csv = AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE df = pd.read_csv(merged_csv) @@ -423,14 +430,21 @@ def _iter_flydsl_csv_cases(): ) for _, row in rows.iterrows(): q_dtype = dtypes.fp8 if "float8" in str(row["q_dtype_w"]) else dtypes.i8 - yield dict( - dtype=dtypes.bf16, - m=int(row["M"]), - n=int(row["N"]), - k=int(row["K"]), - quantDtype=q_dtype, - pad_a=128, - skip_ck=True, + yield ( + dict( + dtype=dtypes.bf16, + m=int(row["M"]), + n=int(row["N"]), + k=int(row["K"]), + quantDtype=q_dtype, + pad_a=128, + skip_ck=True, + ), + { + "source": "flydsl_csv", + "libtype": str(row.get("libtype", "")), + "kernelName1": str(row.get("kernelName", "")), + }, ) @@ -558,8 +572,21 @@ def _iter_flydsl_csv_cases(): args = parser.parse_args() if not args.no_flydsl_csv: - for kwargs in _iter_flydsl_csv_cases(): - test_gemm(**kwargs) + bench_csv = os.environ.get("AITER_TUNED_OP_BENCH_CSV", "tuned_op_bench.csv") + for kwargs, extras in _iter_flydsl_csv_cases(): + ret = test_gemm(**kwargs) + ret.update(extras) + written = append_tuned_op_bench_rows( + bench_csv, + [ret], + op_name="gemm_a8w8", + ) + if written: + aiter.logger.info( + "gemm_a8w8: appended %d tuned op bench row(s) to %s", + written, + bench_csv, + ) if not args.no_legacy: if args.csv is not None: diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 
16c3f55ea7..b249391791 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -31,6 +31,13 @@ from aiter.ops.flydsl.moe_common import GateMode import aiter.ops.flydsl.moe_kernels as _aiter_mk +try: + from tuned_op_bench_utils import append_tuned_op_bench_rows +except ModuleNotFoundError as e: + if e.name != "tuned_op_bench_utils": + raise + from op_tests.tuned_op_bench_utils import append_tuned_op_bench_rows + from aiter.ops.shuffle import ( shuffle_weight, @@ -818,6 +825,28 @@ def _kw( _case_iters.append(_iter_legacy_cases()) case_iter = itertools.chain(*_case_iters) +_csv_out = os.environ.get("AITER_TUNED_OP_BENCH_CSV", "tuned_op_bench.csv") + + +def _write_bench_csv(rows): + if not _csv_out or len(rows) == 0: + return + row = rows[-1] + if row.get("model") == "legacy": + return + written = append_tuned_op_bench_rows( + _csv_out, + [row], + op_name="moe_2stage", + metric_cols=("us",), + default_impl="fused_moe", + ) + if written: + aiter.logger.info( + "moe_2stage: appended %d tuned op bench row(s) to %s", written, _csv_out + ) + + df = [] seen = 0 for kwargs, extras in case_iter: @@ -853,6 +882,7 @@ def _kw( continue ret.update(extras) df.append(ret) + _write_bench_csv(df) aiter.logger.info( "moe_2stage: scanned %d cases, recorded %d results (skipped %d)", diff --git a/op_tests/tuned_op_bench_utils.py b/op_tests/tuned_op_bench_utils.py new file mode 100644 index 0000000000..8ee5404941 --- /dev/null +++ b/op_tests/tuned_op_bench_utils.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. 
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Mapping

import pandas as pd

# Column-name suffix that marks a per-implementation metric column
# ("<impl> us"); the text before the suffix becomes the `impl` value.
_DERIVED_METRIC_SUFFIXES = (" us",)
# Columns ending in one of these suffixes are never copied to the output.
_DROP_SUFFIXES = (" err", " TFLOPS", " TB/s")
# Columns with exactly these names are never copied to the output.
_DROP_COLS = {"logits_diff"}


def _is_missing(value) -> bool:
    """Return True when *value* is None or a scalar that pandas treats as NA."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # `pd.isna` returns an array for list-like input and `bool()` of a
        # multi-element array raises; such values are treated as present.
        return False


def _cell(value) -> str:
    """Render *value* as a CSV cell: empty string when missing, else `str(value)`."""
    if _is_missing(value):
        return ""
    return str(value)


def _metric_columns(row: Mapping[str, object], metric_cols: Iterable[str] | None):
    """Return the metric column names to emit for *row*.

    An explicit *metric_cols* is filtered down to the columns present in the
    row. Otherwise a plain `us` column wins if present; failing that, every
    column ending in one of `_DERIVED_METRIC_SUFFIXES` is used.
    """
    if metric_cols is not None:
        return [col for col in metric_cols if col in row]
    if "us" in row:
        return ["us"]
    return [col for col in row if col.endswith(_DERIVED_METRIC_SUFFIXES)]


def _impl_from_metric(metric_col: str, default_impl: str) -> str:
    """Derive the `impl` label for a metric column.

    `us` maps to *default_impl*; `"<name> us"` maps to `"<name>"`; any other
    column name is returned unchanged.
    """
    if metric_col == "us":
        return default_impl
    for suffix in _DERIVED_METRIC_SUFFIXES:
        if metric_col.endswith(suffix):
            return metric_col[: -len(suffix)]
    return metric_col


def _base_columns(row: Mapping[str, object], metric_cols: set[str]):
    """Return the non-metric columns of *row* as a `{name: cell-string}` dict.

    Skips the selected metric columns, `_DROP_COLS`, and any column whose
    name ends in a drop suffix or a derived-metric suffix.
    """
    skipped_suffixes = _DROP_SUFFIXES + _DERIVED_METRIC_SUFFIXES
    return {
        col: _cell(value)
        for col, value in row.items()
        if col not in metric_cols
        and col not in _DROP_COLS
        and not col.endswith(skipped_suffixes)
    }


def append_tuned_op_bench_rows(
    csv_path: str | Path,
    rows: Iterable[Mapping[str, object]],
    *,
    op_name: str,
    metric_cols: Iterable[str] | None = None,
    default_impl: str = "",
) -> int:
    """Append benchmark rows to the shared tuned-op CI CSV.

    Input rows are usually wide benchmark dictionaries. This writes a stable
    long-table schema with one `us` metric per row so different operator tests
    can share the same artifact.

    Args:
        csv_path: Destination CSV. Created when absent; otherwise its existing
            rows are kept and the new rows are appended after them.
        rows: Benchmark result mappings (column name -> value).
        op_name: Value written to the `op` column of every emitted row.
        metric_cols: Columns to treat as metrics. A bare string is taken as a
            single column name. When None, the columns are auto-detected per
            row (see `_metric_columns`).
        default_impl: `impl` label used for a plain `us` metric column. An
            empty label leaves the `impl` cell blank.

    Returns:
        The number of rows appended. 0 means nothing was written and the file
        was not touched.
    """
    # Materialize once: `metric_cols` is consulted for every row, so a
    # one-shot iterator would otherwise be exhausted after the first row.
    if metric_cols is None:
        wanted_metrics = None
    elif isinstance(metric_cols, str):
        wanted_metrics = (metric_cols,)
    else:
        wanted_metrics = tuple(metric_cols)

    output_rows = []
    for row in rows:
        row_metric_cols = _metric_columns(row, wanted_metrics)
        base = _base_columns(row, set(row_metric_cols))
        base["op"] = op_name
        for metric_col in row_metric_cols:
            value = row.get(metric_col)
            if _is_missing(value):
                continue
            out = dict(base)
            impl = _impl_from_metric(metric_col, default_impl)
            if impl:
                out["impl"] = impl
            out["us"] = _cell(value)
            output_rows.append(out)

    if not output_rows:
        return 0

    csv_path = Path(csv_path)
    frame = pd.DataFrame(output_rows)
    if csv_path.exists() and csv_path.stat().st_size > 0:
        # keep_default_na=False: keep stored strings such as "NA" or "None"
        # verbatim instead of letting pandas parse them as missing values.
        previous = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
        frame = pd.concat([previous, frame], ignore_index=True)
    # The whole file is rewritten so the header can grow when new rows add
    # columns. Cells absent for a given row are written as empty strings,
    # never as the literal text "nan".
    frame.to_csv(csv_path, index=False, na_rep="")
    return len(output_rows)