99from experiments .datasets .factory import AUGMENT_CHOICES
1010
1111
12+ def format_dataset_name (dataset : str ) -> str :
13+ """Format dataset name for LaTeX tables."""
14+ name_map = {
15+ "cifar10" : "CIFAR-10" ,
16+ "cifar100" : "CIFAR-100" ,
17+ "tiny_imagenet" : "Tiny-IN" ,
18+ }
19+ return name_map .get (dataset , dataset )
20+
21+
1222def accuracy_table (df : pd .DataFrame , augment : str = "basic" ) -> str :
1323 """Generate LaTeX table comparing standard vs bit accuracy by model and dataset."""
1424 if "augment" in df .columns :
@@ -63,7 +73,7 @@ def accuracy_table(df: pd.DataFrame, augment: str = "basic") -> str:
6373 bit_mean_str = f"{ bit_mean :.2f} " if pd .notna (bit_mean ) else "-"
6474 bit_std_str = f"{ bit_std :.2f} " if pd .notna (bit_std ) else "-"
6575
66- cols = [model , dataset , std_mean_str , std_std_str , bit_mean_str , bit_std_str ]
76+ cols = [model , format_dataset_name ( dataset ) , std_mean_str , std_std_str , bit_mean_str , bit_std_str ]
6777
6878 if has_ttq :
6979 ttq_mean = row .get (("mean" , "ttq" ), float ("nan" ))
@@ -103,7 +113,7 @@ def augmentation_ablation_table(df: pd.DataFrame) -> str:
103113 df = df [df ["ablation" ] == "none" ]
104114
105115 for (model , dataset ), group in df .groupby (["model" , "dataset" ]):
106- row_vals = [str (model ), str (dataset )]
116+ row_vals = [str (model ), format_dataset_name ( str (dataset ) )]
107117 for augment in augments :
108118 aug_data = group [group ["augment" ] == augment ]
109119 if len (aug_data ) == 0 :
@@ -141,7 +151,7 @@ def statistical_table(comparisons: pd.DataFrame) -> str:
141151 for _ , row in valid .iterrows ():
142152 sig = "*" if row ["significant" ] else ""
143153 lines .append (
144- f"{ row ['model' ]} & { row ['dataset' ]} & { row ['diff' ]:.2f} & "
154+ f"{ row ['model' ]} & { format_dataset_name ( row ['dataset' ]) } & { row ['diff' ]:.2f} & "
145155 f"{ row ['t_stat' ]:.2f} & { row ['p_value' ]:.3f} { sig } & { row ['cohens_d' ]:.2f} \\ \\ "
146156 )
147157
@@ -208,10 +218,15 @@ def layer_ablation_table(df: pd.DataFrame, dataset: str = "cifar10", augment: st
208218 # Get FP32 baseline
209219 fp32_df = df [(df ["version" ] == "std" ) & (df ["ablation" ] == "none" )]
210220
221+ # Format caption
222+ caption = (
223+ rf"Layer-wise ablation on { format_dataset_name (dataset )} : " r"accuracy when keeping specific layers in FP32"
224+ )
225+
211226 lines = [
212227 r"\begin{table}[h]" ,
213228 r"\centering" ,
214- rf"\caption{{Layer-wise ablation on { dataset . upper () } : accuracy when keeping specific layers in FP32 }}" ,
229+ rf"\caption{{{ caption } }}" ,
215230 rf"\label{{tab:layer_ablation_{ dataset } }}" ,
216231 r"\begin{tabular}{llcc}" ,
217232 r"\toprule" ,
@@ -291,7 +306,7 @@ def kd_statistics_table(df: pd.DataFrame) -> str:
291306 for r in results :
292307 sig = "*" if r ["significant" ] else ""
293308 lines .append (
294- f"{ r ['model' ]} & { r ['dataset' ]} & { r ['baseline_mean' ]:.2f} & "
309+ f"{ r ['model' ]} & { format_dataset_name ( r ['dataset' ]) } & { r ['baseline_mean' ]:.2f} & "
295310 f"{ r ['kd_mean' ]:.2f} & { r ['mean_diff' ]:+.2f} & "
296311 f"{ r ['p_value' ]:.4f} { sig } & { r ['cohens_d' ]:.2f} \\ \\ "
297312 )
@@ -307,7 +322,7 @@ def kd_statistics_table(df: pd.DataFrame) -> str:
307322 return "\n " .join (lines )
308323
309324
310- def save_tables (df : pd .DataFrame , comparisons : pd .DataFrame , output_dir : str = "paper/tmlr/ tables" ) -> None :
325+ def save_tables (df : pd .DataFrame , comparisons : pd .DataFrame , output_dir : str = "paper/tables" ) -> None :
311326 """Save all tables to files."""
312327 output = Path (output_dir )
313328 output .mkdir (parents = True , exist_ok = True )
0 commit comments