-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlogreg_predict.py
More file actions
23 lines (22 loc) · 979 Bytes
/
logreg_predict.py
File metadata and controls
23 lines (22 loc) · 979 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from DSLR.lr_core import MultiLogisticRegression
from DSLR.dslr_utils import load_model_params, StandardScaler
from DSLR.dslr_math import calculate_mean
import pandas as pd
if __name__ == "__main__":
    # Script entry point: read the dataset, apply the saved preprocessing and
    # model parameters, and write the resulting predictions to a CSV file.
    # NOTE(review): this prediction script reads "dataset_train.csv" --
    # confirm that the training file is the intended input.

    # Columns removed from the frame before it is handed to the model.
    excluded_columns = [
        "Hogwarts House",
        "Birthday",
        "Best Hand",
        "First Name",
        "Last Name",
        "Care of Magical Creatures",
        "Potions",
        "Arithmancy",
    ]

    features = pd.read_csv("datasets/dataset_train.csv", index_col="Index")
    features = features.drop(columns=excluded_columns)

    # Fill the gaps of every remaining column with the value that
    # calculate_mean returns for that same column.
    for name in features.columns:
        column = features[name]
        features[name] = column.fillna(calculate_mean(column))

    # Discard the "Index" labels in favour of a fresh 0..n-1 index.
    features = features.reset_index(drop=True)

    # The first two saved entries feed the scaler (means, then stds); every
    # entry after them is turned into a dict and handed to the model.
    saved = load_model_params()
    scaler = StandardScaler(saved[0], saved[1])
    features = scaler.transform(features)
    model_params = dict(saved[2:])

    # The model is built from empty placeholders; its state comes from
    # load_model (presumably no training data is needed here -- verify
    # against MultiLogisticRegression).
    model = MultiLogisticRegression(pd.DataFrame([]), pd.Series([]))
    model.load_model(model_params)

    # Persist the predictions as a single "Predictions" column.
    predictions = model.predict(features)
    pd.DataFrame(predictions, columns=["Predictions"]).to_csv("predictions.csv")