bert_for_seq_classification.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2020-08-30 12:08:23
# @Author : Kaiyan Zhang (minekaiyan@gmail.com)
# @Link : https://github.com/iseesaw
# @Version : 1.0.0
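"""Fine-tune BERT for sentence-pair classification (e.g. on LCQMC) using the
Hugging Face transformers Trainer API.

Expects train/dev/test CSV files with `sentence1`, `sentence2` and `label`
columns.
"""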
import ast
import pprint
from argparse import ArgumentParser

import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (BertForSequenceClassification, BertTokenizerFast,
                          Trainer, TrainingArguments)


class SimDataset(torch.utils.data.Dataset):
    """Wraps tokenizer encodings and labels as a torch Dataset."""

    def __init__(self, encodings, labels):
        """
        Args:
            encodings (Dict[str, List[List[int]]]): tokenizer output
            labels (List[int]): class labels
        """
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """
        Args:
            idx (int): sample index
        Returns:
            Dict[str, torch.Tensor]: input features plus the label
        """
item = {
key: torch.tensor(val[idx])
for key, val in self.encodings.items()
}
item['labels'] = torch.tensor(self.labels[idx])
return item

    def __len__(self):
        return len(self.labels)


def load_dataset(filename):
    """Load a CSV dataset.
    Args:
        filename (str): path to a CSV file with columns
            `sentence1`, `sentence2` and `label`
    Returns:
        ([List[str], List[str]], List[int]): sentence pairs and labels
    """
df = pd.read_csv(filename)
# array -> list
return [df['sentence1'].values.tolist(),
df['sentence2'].values.tolist()], df['label'].values.tolist()


def compute_metrics(pred):
    """Compute evaluation metrics.
    Args:
        pred (EvalPrediction): pred.label_ids holds the gold labels;
            pred.predictions holds the model logits
    Returns:
        Dict[str, float]: accuracy, f1, precision and recall
    """
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
precision, recall, f1, _ = precision_recall_fscore_support(
labels, preds, average='binary')
acc = accuracy_score(labels, preds)
return {
'accuracy': acc,
'f1': f1,
'precision': precision,
'recall': recall
}


def main(args):
    # When training, start from the pretrained checkpoint; otherwise load
    # the fine-tuned model previously saved to output_dir.
    model_path = args.model_name_or_path if args.do_train else args.output_dir
    # Initialize the pretrained model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
    # Load the CSV datasets
train_texts, train_labels = load_dataset(args.trainset_path)
dev_texts, dev_labels = load_dataset(args.devset_path)
test_texts, test_labels = load_dataset(args.testset_path)
    # Tokenize sentence pairs into model input features
train_encodings = tokenizer(text=train_texts[0],
text_pair=train_texts[1],
truncation=True,
padding=True,
max_length=args.max_length)
dev_encodings = tokenizer(text=dev_texts[0],
text_pair=dev_texts[1],
truncation=True,
padding=True,
max_length=args.max_length)
test_encodings = tokenizer(text=test_texts[0],
text_pair=test_texts[1],
truncation=True,
padding=True,
max_length=args.max_length)
    # Wrap the features as SimDataset objects for the Trainer
train_dataset = SimDataset(train_encodings, train_labels)
dev_dataset = SimDataset(dev_encodings, dev_labels)
test_dataset = SimDataset(test_encodings, test_labels)
    # Configure training arguments
training_args = TrainingArguments(
output_dir=args.output_dir,
do_train=args.do_train,
do_eval=args.do_eval,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
warmup_steps=args.warmup_steps,
weight_decay=args.weight_decay,
logging_dir=args.logging_dir,
logging_steps=args.logging_steps,
save_total_limit=args.save_total_limit)
    # Initialize the Trainer (training runs only if --do_train is set)
trainer = Trainer(model=model,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=dev_dataset)
if args.do_train:
trainer.train()
        # Save the fine-tuned model and the tokenizer
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
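
    # Report evaluation metrics on the dev and test sets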
if args.do_predict:
eval_metrics = trainer.evaluate(dev_dataset)
pprint.pprint(eval_metrics)
test_metrics = trainer.evaluate(test_dataset)
pprint.pprint(test_metrics)


if __name__ == '__main__':
    parser = ArgumentParser(description='BERT for sequence classification')
parser.add_argument('--do_train', type=ast.literal_eval, default=False)
parser.add_argument('--do_eval', type=ast.literal_eval, default=True)
parser.add_argument('--do_predict', type=ast.literal_eval, default=True)
parser.add_argument(
'--model_name_or_path',
default='/users6/kyzhang/embeddings/bert/bert-base-chinese')
parser.add_argument('--trainset_path', default='lcqmc/LCQMC_train.csv')
parser.add_argument('--devset_path', default='lcqmc/LCQMC_dev.csv')
parser.add_argument('--testset_path', default='lcqmc/LCQMC_test.csv')
parser.add_argument('--output_dir',
default='output/transformers-bert-for-classification')
parser.add_argument('--max_length',
type=int,
default=128,
help='max length of sentence1 & sentence2')
parser.add_argument('--num_train_epochs', type=int, default=10)
parser.add_argument('--per_device_train_batch_size', type=int, default=64)
parser.add_argument('--per_device_eval_batch_size', type=int, default=64)
parser.add_argument('--warmup_steps', type=int, default=500)
parser.add_argument('--weight_decay', type=float, default=0.01)
parser.add_argument('--logging_dir', type=str, default='./logs')
parser.add_argument('--logging_steps', type=int, default=10)
parser.add_argument('--save_total_limit', type=int, default=3)
args = parser.parse_args()
main(args)
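
# Example invocation (paths are illustrative and depend on your setup):
#   python bert_for_seq_classification.py \
#       --do_train True \
#       --model_name_or_path bert-base-chinese \
#       --trainset_path lcqmc/LCQMC_train.csv \
#       --devset_path lcqmc/LCQMC_dev.csv \
#       --testset_path lcqmc/LCQMC_test.csv \
#       --output_dir output/transformers-bert-for-classification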