-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcaption_classifier.py
More file actions
127 lines (99 loc) · 3.76 KB
/
caption_classifier.py
File metadata and controls
127 lines (99 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Caption Quality Classifier
A machine learning model to differentiate between good and bad captions
"""
import json
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
import pickle
class CaptionClassifier:
    """Binary text classifier that labels captions as 'good' or 'bad'.

    Captions are converted to TF-IDF features (at most 1000 terms, English
    stop words removed) and classified with logistic regression.

    Attributes:
        vectorizer: TfidfVectorizer used to turn caption text into features.
        model: LogisticRegression estimator fitted by ``train``.
        is_trained: True once ``train`` has run or a trained model was loaded.
    """

    def __init__(self):
        self.vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
        self.model = LogisticRegression(random_state=42)
        self.is_trained = False

    def load_data(self, filepath):
        """Load captions and their labels from a JSON file.

        The file must contain an object with a ``captions`` list whose items
        each have a ``text`` and a ``label`` key.

        Args:
            filepath: Path of the JSON file to read.

        Returns:
            A ``(captions, labels)`` tuple of two equally long lists. A label
            is 1 when the item's label is exactly ``'good'`` and 0 for any
            other value.

        Raises:
            KeyError: If ``captions``, ``text`` or ``label`` is missing.
        """
        # Read as UTF-8 explicitly so caption text does not depend on the
        # platform's default encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        items = data['captions']
        captions = [item['text'] for item in items]
        labels = [1 if item['label'] == 'good' else 0 for item in items]
        return captions, labels

    def preprocess(self, captions, *, fit=True):
        """Convert captions into a TF-IDF feature matrix.

        Args:
            captions: Iterable of caption strings.
            fit: When True (the default) the vectorizer learns its vocabulary
                from ``captions`` before transforming them. Pass False to
                transform held-out captions with the vocabulary that was
                already fitted, so evaluation data does not refit it.

        Returns:
            The feature matrix produced by the vectorizer.
        """
        if fit:
            return self.vectorizer.fit_transform(captions)
        return self.vectorizer.transform(captions)

    def train(self, X, y):
        """Fit the model on feature matrix ``X`` and labels ``y``."""
        self.model.fit(X, y)
        self.is_trained = True
        print("Model trained successfully!")

    def predict(self, caption):
        """Classify a single caption.

        Args:
            caption: The caption text to classify.

        Returns:
            A dict with ``label`` (``'good'`` or ``'bad'``) and ``confidence``
            (the highest class probability reported by the model).

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        if not self.is_trained:
            # RuntimeError is a subclass of Exception, so callers that catch
            # Exception keep working.
            raise RuntimeError("Model not trained yet!")
        X = self.vectorizer.transform([caption])
        prediction = self.model.predict(X)[0]
        probability = self.model.predict_proba(X)[0]
        return {
            'label': 'good' if prediction == 1 else 'bad',
            'confidence': max(probability),
        }

    def evaluate(self, X, y):
        """Score the model on feature matrix ``X`` against true labels ``y``.

        Returns:
            An ``(accuracy, report)`` tuple: the accuracy score and the text
            classification report with classes named ``bad`` and ``good``.
        """
        predictions = self.model.predict(X)
        accuracy = accuracy_score(y, predictions)
        report = classification_report(y, predictions, target_names=['bad', 'good'])
        return accuracy, report

    def save_model(self, filepath):
        """Pickle the model, the vectorizer and the trained flag to ``filepath``."""
        with open(filepath, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'vectorizer': self.vectorizer,
                # Persist the flag so loading an unfitted model does not
                # claim that it is ready for prediction.
                'is_trained': self.is_trained,
            }, f)
        print(f"Model saved to {filepath}")

    def load_model(self, filepath):
        """Restore the model and vectorizer written by ``save_model``.

        Args:
            filepath: Path of the pickle file to read.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load model
        # files that come from a trusted source.
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        self.model = data['model']
        self.vectorizer = data['vectorizer']
        # Files written before the flag was stored are treated as trained,
        # which matches the earlier behaviour.
        self.is_trained = data.get('is_trained', True)
        print(f"Model loaded from {filepath}")
def main():
    """Train on caption_data.json, print metrics and demo predictions, then save."""
    classifier = CaptionClassifier()

    # Read the labelled captions from disk.
    captions, labels = classifier.load_data('caption_data.json')
    print(f"Loaded {len(captions)} captions")

    # Vectorise the text, then fit the model on every loaded caption.
    features, targets = classifier.preprocess(captions), np.array(labels)
    classifier.train(features, targets)

    # Scored on the same data the model was fitted on, hence "Training Accuracy".
    accuracy, report = classifier.evaluate(features, targets)
    print(f"\nTraining Accuracy: {accuracy:.2%}")
    print("\nClassification Report:")
    print(report)

    # Run a few hand-written captions through the trained model.
    print("\nTest Predictions:")
    for caption in (
        "A stunning mountain landscape with snow-capped peaks",
        "meh",
        "Beautiful flowers blooming in spring garden",
    ):
        result = classifier.predict(caption)
        print(f"Caption: '{caption}'")
        print(f" -> {result['label']} (confidence: {result['confidence']:.2%})\n")

    # Persist the fitted model and vectorizer for later use.
    classifier.save_model('caption_model.pkl')


# Run the training demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()