Skip to content

Commit d5fdb2b

Browse files
committed
Fix table generation: Format dataset names correctly (Tiny-IN)
- Add format_dataset_name() helper function - Maps tiny_imagenet → Tiny-IN to avoid LaTeX hyphenation issues - Maps cifar10 → CIFAR-10, cifar100 → CIFAR-100 - Apply formatting in all table generation functions - Fix output directory: paper/tmlr/tables → paper/tables - Fix line length lint warning in layer ablation caption
1 parent 8dc0dfe commit d5fdb2b

1 file changed

Lines changed: 21 additions & 6 deletions

File tree

analysis/generate_tables.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@
99
from experiments.datasets.factory import AUGMENT_CHOICES
1010

1111

12+
def format_dataset_name(dataset: str) -> str:
13+
"""Format dataset name for LaTeX tables."""
14+
name_map = {
15+
"cifar10": "CIFAR-10",
16+
"cifar100": "CIFAR-100",
17+
"tiny_imagenet": "Tiny-IN",
18+
}
19+
return name_map.get(dataset, dataset)
20+
21+
1222
def accuracy_table(df: pd.DataFrame, augment: str = "basic") -> str:
1323
"""Generate LaTeX table comparing standard vs bit accuracy by model and dataset."""
1424
if "augment" in df.columns:
@@ -63,7 +73,7 @@ def accuracy_table(df: pd.DataFrame, augment: str = "basic") -> str:
6373
bit_mean_str = f"{bit_mean:.2f}" if pd.notna(bit_mean) else "-"
6474
bit_std_str = f"{bit_std:.2f}" if pd.notna(bit_std) else "-"
6575

66-
cols = [model, dataset, std_mean_str, std_std_str, bit_mean_str, bit_std_str]
76+
cols = [model, format_dataset_name(dataset), std_mean_str, std_std_str, bit_mean_str, bit_std_str]
6777

6878
if has_ttq:
6979
ttq_mean = row.get(("mean", "ttq"), float("nan"))
@@ -103,7 +113,7 @@ def augmentation_ablation_table(df: pd.DataFrame) -> str:
103113
df = df[df["ablation"] == "none"]
104114

105115
for (model, dataset), group in df.groupby(["model", "dataset"]):
106-
row_vals = [str(model), str(dataset)]
116+
row_vals = [str(model), format_dataset_name(str(dataset))]
107117
for augment in augments:
108118
aug_data = group[group["augment"] == augment]
109119
if len(aug_data) == 0:
@@ -141,7 +151,7 @@ def statistical_table(comparisons: pd.DataFrame) -> str:
141151
for _, row in valid.iterrows():
142152
sig = "*" if row["significant"] else ""
143153
lines.append(
144-
f"{row['model']} & {row['dataset']} & {row['diff']:.2f} & "
154+
f"{row['model']} & {format_dataset_name(row['dataset'])} & {row['diff']:.2f} & "
145155
f"{row['t_stat']:.2f} & {row['p_value']:.3f}{sig} & {row['cohens_d']:.2f} \\\\"
146156
)
147157

@@ -208,10 +218,15 @@ def layer_ablation_table(df: pd.DataFrame, dataset: str = "cifar10", augment: st
208218
# Get FP32 baseline
209219
fp32_df = df[(df["version"] == "std") & (df["ablation"] == "none")]
210220

221+
# Format caption
222+
caption = (
223+
rf"Layer-wise ablation on {format_dataset_name(dataset)}: " r"accuracy when keeping specific layers in FP32"
224+
)
225+
211226
lines = [
212227
r"\begin{table}[h]",
213228
r"\centering",
214-
rf"\caption{{Layer-wise ablation on {dataset.upper()}: accuracy when keeping specific layers in FP32}}",
229+
rf"\caption{{{caption}}}",
215230
rf"\label{{tab:layer_ablation_{dataset}}}",
216231
r"\begin{tabular}{llcc}",
217232
r"\toprule",
@@ -291,7 +306,7 @@ def kd_statistics_table(df: pd.DataFrame) -> str:
291306
for r in results:
292307
sig = "*" if r["significant"] else ""
293308
lines.append(
294-
f"{r['model']} & {r['dataset']} & {r['baseline_mean']:.2f} & "
309+
f"{r['model']} & {format_dataset_name(r['dataset'])} & {r['baseline_mean']:.2f} & "
295310
f"{r['kd_mean']:.2f} & {r['mean_diff']:+.2f} & "
296311
f"{r['p_value']:.4f}{sig} & {r['cohens_d']:.2f} \\\\"
297312
)
@@ -307,7 +322,7 @@ def kd_statistics_table(df: pd.DataFrame) -> str:
307322
return "\n".join(lines)
308323

309324

310-
def save_tables(df: pd.DataFrame, comparisons: pd.DataFrame, output_dir: str = "paper/tmlr/tables") -> None:
325+
def save_tables(df: pd.DataFrame, comparisons: pd.DataFrame, output_dir: str = "paper/tables") -> None:
311326
"""Save all tables to files."""
312327
output = Path(output_dir)
313328
output.mkdir(parents=True, exist_ok=True)

0 commit comments

Comments
 (0)