-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_and_eval.py
More file actions
46 lines (39 loc) · 1.64 KB
/
train_and_eval.py
File metadata and controls
46 lines (39 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from shufflenetv2 import ShuffleNetV2
from typing import List
# Names of the regression outputs, in the order used to index the evaluation
# metrics (position i of the metrics corresponds to output_classes_names[i]).
output_classes_names: List[str] = ["Height", "Width", "Length"]


def _print_metrics(metrics, class_names: List[str]) -> None:
    """Print MAE, MSE, MAPE and R-squared for every output class.

    Args:
        metrics: Container indexed by class position (``metrics[i]``); each
            entry is read with the keys ``'mae'``, ``'mse'``, ``'mape'`` and
            ``'r2'``.
        class_names: Display names, one per output class, in index order.
    """
    # enumerate() yields the integer positions; the previous
    # ``range(output_classes_names)`` passed a list to range() and raised
    # TypeError before anything was printed.
    for i, class_name in enumerate(class_names):
        # Index by position rather than zip() so that ``metrics`` may be any
        # container that supports integer lookup.
        class_metrics = metrics[i]
        print(f"Metrics for {class_name}:")
        print(f"MAE: {class_metrics['mae']}")
        print(f"MSE: {class_metrics['mse']}")
        print(f"MAPE: {class_metrics['mape']}%")
        print(f"R-squared: {class_metrics['r2']}")


def _train_and_report(banner: str, save_path: str, *, cbam_status: bool,
                      train_enhanced: bool):
    """Train one ShuffleNetV2 variant, save it, and print its metrics.

    Args:
        banner: Message printed before training starts.
        save_path: File name passed to ``trained_model.save``.
        cbam_status: Forwarded to ``ShuffleNetV2`` as ``CBAM_status``.
        train_enhanced: Forwarded to ``ShuffleNetV2`` as ``train_enhanced``.

    Returns:
        A ``(model, metrics)`` tuple: the trained wrapper and the value
        returned by its ``evaluate_model_metrics()``.
    """
    print(banner)
    model = ShuffleNetV2(CBAM_status=cbam_status, train_enhanced=train_enhanced)
    model.train_model()
    model.trained_model.save(save_path)
    metrics = model.evaluate_model_metrics()
    _print_metrics(metrics, output_classes_names)
    return model, metrics


# Base Model train + eval
model, metrics = _train_and_report(
    "Train and evaluating base model...",
    "trained_basemodel.keras",
    cbam_status=False,
    train_enhanced=False,
)

# CBAM Model unenhanced train + eval
model, metrics = _train_and_report(
    "Train and evaluating CBAM model...",
    "trained_CBAMmodel.keras",
    cbam_status=True,
    train_enhanced=False,
)

# CBAM Model enhanced train + eval
model, metrics = _train_and_report(
    "Train and evaluating CBAM + image enhancements model...",
    "trained_CBAMEnhancedmodel.keras",
    cbam_status=True,
    train_enhanced=True,
)