-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_us_centric_bias.py
More file actions
135 lines (115 loc) · 5.03 KB
/
plot_us_centric_bias.py
File metadata and controls
135 lines (115 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
import statsmodels.formula.api as smf
from loguru import logger
from statsmodels.regression.linear_model import RegressionResults
from multicultural_alignment.constants import CLEAN_MODEL_NAMES, LANGUAGE_MAP, OUTPUT_DIR, PLOT_DIR
from multicultural_alignment.models import add_families_df
from multicultural_alignment.plot import get_renamed_colours
from multicultural_alignment.regression import save_regression_results
# Patsy labels categorical levels as ``factor[level]``, or ``factor[T.level]``
# under treatment coding. Each pattern captures the bare level inside the
# brackets. ``(?:T\.)?`` strips the treatment prefix only when both characters
# are present, so levels that merely start with "T" (e.g. "Tulu-2") stay intact.
language_pattern: str = r"language\[(?:T\.)?([^\]]+)\]"
model_pattern: str = r"model_name\[(?:T\.)?([^\]]+)\]"
def _us_correlation_by_country(wvs_country: pl.DataFrame) -> pl.DataFrame:
    """Correlate each country's ground-truth ``pro_score`` profile with the US one.

    Rows are paired with the US rows on ``question_key``, then ``pl.corr`` of
    ``pro_score`` against the US ``pro_score`` is taken per country.

    Returns:
        One row per country with columns ``cntry_an`` and ``us_corr``.
    """
    # Restrict the right-hand side to the US *before* joining: joining every
    # country with every country and filtering to the US afterwards produces the
    # same groups but materialises all country pairs first.
    us_rows = wvs_country.filter(pl.col("cntry_an") == "US")
    return (
        wvs_country.join(us_rows, on=["question_key"])
        .group_by("cntry_an")
        .agg(pl.corr("pro_score", "pro_score_right").alias("us_corr"))
    )


def get_us_data() -> pl.DataFrame:
    """Assemble the regression data for the US-centric bias analysis.

    Reads the model-alignment results and the per-country WVS ground truth from
    ``OUTPUT_DIR`` and keeps only country-level ground truths (``gt_type ==
    "country"``) where either:

    * ``native``: the response language is one recorded for that country, or
    * ``us``: the ground-truth country is the US.

    The ``Baseline_fifty_percent`` model and the ``mistral`` family are dropped,
    each country's correlation with the US ground truth is attached as
    ``us_corr``, and exact duplicate rows are removed.

    Returns:
        A polars DataFrame ready for ``smf.ols``.
    """
    experiment_data = pl.read_csv(OUTPUT_DIR / "spearman_n100_extreme_total_gt_alignment.csv").rename(
        {"response_language": "language"}
    )
    wvs_country = pl.read_csv(OUTPUT_DIR / "ground_truth_every_country.csv")
    # Every (country, language) pairing present in the ground truth.
    country_langs = wvs_country.unique(["cntry_an", "lnge_iso"]).select(["cntry_an", "lnge_iso"])
    country_corrs_us = _us_correlation_by_country(wvs_country)
    return (
        add_families_df(experiment_data)
        .filter(pl.col("gt_type") == "country")
        .join(country_langs, left_on="gt_group", right_on="cntry_an")
        .with_columns(
            (pl.col("language") == pl.col("lnge_iso")).alias("native"),
            (pl.col("gt_group") == "US").alias("us"),
        )
        .filter(pl.col("native") | pl.col("us"))
        .filter(pl.col("model_name") != "Baseline_fifty_percent")
        .filter(pl.col("family") != "mistral")
        .sort("family")
        .cast({"model_name": pl.String})
        .join(country_corrs_us, left_on="gt_group", right_on="cntry_an")
        .unique()
    )
def extract_results_df(results: RegressionResults) -> pl.DataFrame:
    """Flatten a fitted statsmodels result into a tidy polars table.

    Returns:
        One row per model term, with columns ``term``, ``coefficient``,
        ``std_err``, ``p_value``, ``conf_int_lower`` and ``conf_int_upper``
        (the bounds come from ``results.conf_int()`` with its default level).
    """
    bounds = results.conf_int()
    columns = {
        "term": results.params.index.tolist(),
        "coefficient": results.params.values,
        "std_err": results.bse.values,
        "p_value": results.pvalues.values,
    }
    # conf_int() labels its lower/upper columns 0 and 1.
    columns["conf_int_lower"] = bounds[0].values
    columns["conf_int_upper"] = bounds[1].values
    return pl.DataFrame(columns)
def run_us_regression() -> RegressionResults:
    """Fit the OLS model behind the US-centric bias analysis.

    In patsy, ``us*model_name:language`` expands to ``us``, the
    ``model_name:language`` cells, and their interaction with ``us``.

    Returns:
        The fitted statsmodels results object.
    """
    frame = get_us_data().to_pandas()
    formula = "metric_value ~ us*model_name:language"
    return smf.ols(formula=formula, data=frame).fit()
def get_coefficients(us_results) -> pl.DataFrame:
    """Extract the per-(model, language) US-interaction coefficients.

    Keeps regression terms whose label contains ``:model_name`` (presumably the
    ``us:model_name:language`` interaction terms, since ``model_name`` is not
    the leading factor there). Each kept term's label is parsed into
    ``language`` and ``model_name`` columns, and rows are tagged with their
    model family via ``add_families_df``.

    No p-value filtering is applied; every matching term is returned.
    """
    coefficient_table = extract_results_df(us_results)
    term = pl.col("term")
    interaction_rows = coefficient_table.filter(term.str.contains(":model_name"))
    parsed = interaction_rows.with_columns(
        language=term.str.extract(language_pattern),
        model_name=term.str.extract(model_pattern),
    ).drop("term")
    return add_families_df(parsed)
def get_confidence_data(significant: pl.DataFrame) -> pl.DataFrame:
    """Stack lower bound, upper bound and point estimate into one long table.

    Each input row yields three rows, and all values land in the
    ``coefficient`` column, blocked in the order lower bounds, upper bounds,
    then point estimates. ``plot_us_centric_bias`` draws this with
    ``errorbar=("pi", 100)``, so the error bar spans the confidence interval.
    With a mean estimator, the bar height equals the point estimate whenever
    the interval is symmetric about it.
    """
    stacked_sources = ("conf_int_lower", "conf_int_upper", "coefficient")
    return pl.concat(
        [
            significant.select(pl.col(source).alias("coefficient"), "model_name", "language")
            for source in stacked_sources
        ]
    )
def plot_us_centric_bias(plot_data: pl.DataFrame, font_size: int = 27) -> None:
    """Draw the US-bias coefficients per language as a grouped bar chart.

    Languages and model names are mapped to display names via ``LANGUAGE_MAP``
    and ``CLEAN_MODEL_NAMES``. The figure is saved to
    ``PLOT_DIR / "us_bias_coefficients.png"`` and then closed.

    Args:
        plot_data: Long table with ``language``, ``model_name`` and
            ``coefficient`` columns (several rows per bar; the error bar spans
            their full range via ``errorbar=("pi", 100)``).
        font_size: Font size of the y-axis label; x tick labels use
            ``font_size - 2``.
    """
    display_data = (
        plot_data.sort("model_name")
        .cast({"model_name": str})
        .with_columns(pl.col("language").replace(LANGUAGE_MAP), pl.col("model_name").replace(CLEAN_MODEL_NAMES))
    )
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        sns.barplot(
            data=display_data,
            x="language",
            y="coefficient",
            hue="model_name",
            palette=get_renamed_colours(),
            errorbar=("pi", 100),
            ax=ax,
        )
        ax.set_ylabel("$\\beta_{BiasUS}$", fontsize=font_size)
        ax.set_xlabel("")
        # tick_params resizes labels without freezing their text, which
        # set_xticklabels(get_xticklabels()) does (and warns about).
        ax.tick_params(axis="x", labelsize=font_size - 2)
        ax.legend(ncol=3, loc="lower center", bbox_to_anchor=(0.45, -0.43))
        fig.savefig(PLOT_DIR / "us_bias_coefficients.png", bbox_inches="tight")
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
def main(font_scale: float = 1.7, font_size: int = 27):
    """Run the US-centric bias regression, persist and log it, then plot it.

    Args:
        font_scale: Passed to ``sns.set_theme`` to scale all plot fonts.
        font_size: Passed to ``plot_us_centric_bias`` for axis-label sizing.
    """
    sns.set_theme(style="whitegrid", font_scale=font_scale)

    logger.info("Running US-centric bias regression")
    fitted = run_us_regression()
    save_regression_results(fitted, regression_type="normal", rq_method="us_centric_bias")
    logger.info(f"Regression table:\n{fitted.summary()}")

    coefficients = get_coefficients(fitted)
    plot_us_centric_bias(get_confidence_data(coefficients), font_size=font_size)
if __name__ == "__main__":
    # Command-line entry point: expose main()'s two font settings as flags.
    cli = argparse.ArgumentParser()
    cli.add_argument("--scale", type=float, default=1.7, help="Font scale for the plot")
    cli.add_argument("--size", type=int, default=27, help="Font size for the plot")
    cli_args = cli.parse_args()
    main(font_scale=cli_args.scale, font_size=cli_args.size)