-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
60 lines (45 loc) · 1.81 KB
/
main.py
File metadata and controls
60 lines (45 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
An example for users connecting to the Sybl service and waiting for data.
"""
import pandas as pd # type: ignore
from sklearn.tree import DecisionTreeClassifier # type: ignore
from sklearn.linear_model import LinearRegression # type: ignore
from sybl.client import Sybl
from sybl.client import JobConfig
# Module-level Sybl client; the callback registration and connection calls
# at the bottom of this script operate on this instance.
sybl = Sybl()
def ohe(dataset):
    """
    One-hot encode the object-dtype columns of a DataFrame.

    Args:
        dataset: The DataFrame to encode

    Returns: ``dataset`` itself when it has no object-dtype columns (or no
        rows); otherwise a new DataFrame in which every object-dtype column
        is replaced by its indicator columns, appended after the remaining
        original columns
    """
    object_columns = dataset.select_dtypes("object")

    # Nothing to encode: hand back the very same frame, not a copy.
    if object_columns.empty:
        return dataset

    indicators = pd.get_dummies(object_columns)
    widened = pd.concat([dataset, indicators], axis=1)

    # The raw object columns are now redundant next to their indicators.
    return widened.drop(columns=list(object_columns.columns))
def callback(train, predict, job_config):
    """
    An example callback that runs with data provided by the service.

    Fits a decision tree when the job's prediction type is "classification"
    and a linear regression for any other prediction type, then predicts the
    target column for ``predict``.

    Args:
        train: The DataFrame for the model to learn from; must contain the
            prediction column
        predict: The DataFrame to make predictions based on; the prediction
            column is dropped from it when present
        job_config: Details about the job; read via the keys
            "prediction_column" and "prediction_type"

    Returns: Predictions generated by the model, as a single-column
        DataFrame named after the prediction column
    """
    print(job_config)

    prediction_col = job_config["prediction_column"]
    prediction_type = job_config["prediction_type"]

    X_train = ohe(train.drop(prediction_col, axis=1))
    y_train = train[prediction_col]
    # The data to predict on may arrive without the target column at all, so
    # a missing column must not abort the job.
    X_test = ohe(predict.drop(prediction_col, axis=1, errors="ignore"))

    # One-hot encoding each frame separately can yield different indicator
    # columns. Give both frames the same columns in the same order (training
    # order first, then test-only columns), filling the gaps with 0, so the
    # model sees every feature at the position it was fitted on.
    feature_columns = X_train.columns.union(X_test.columns, sort=False)
    X_train = X_train.reindex(columns=feature_columns, fill_value=0)
    X_test = X_test.reindex(columns=feature_columns, fill_value=0)

    print(X_train, X_test, y_train)

    if prediction_type == "classification":
        model = DecisionTreeClassifier()
    else:
        model = LinearRegression()

    predictions = model.fit(X_train, y_train).predict(X_test)
    return pd.DataFrame({prediction_col: predictions})
# Hand the prediction callback to the client before connecting.
sybl.register_callback(callback)
# "<email>" and "<model_name>" are placeholders; replace them with real values.
# NOTE(review): presumably these identify a model already registered with the
# service -- confirm against the Sybl client documentation.
sybl.load_model("<email>", "<model_name>")
# Per the module docstring, this connects to the Sybl service and waits for data.
sybl.connect()