-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: calculate_werr.py
More file actions
93 lines (83 loc) · 3.59 KB
/
calculate_werr.py
File metadata and controls
93 lines (83 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
from collections import defaultdict
from distance import levenshtein
import itertools
import math
import numpy as np
import operator
import random
from srilm import LM
def calculate_wer(args, alpha=1):
    """Rescore every utterance's n-best list and return corpus WER in percent.

    The LM (and, when enabled, AM) log-probability streams are consumed in
    lock-step with the n-best list: one leading value is skipped per
    utterance (presumably the score of the reference/first entry — TODO
    confirm the probability-file layout), then exactly one value is read
    per hypothesis.

    args  -- parsed command-line namespace (file paths, debug level, flags)
    alpha -- log-linear interpolation weight given to the LM score
    """
    plain_nbest = read_nbest_list(args.plain)
    nbest = read_nbest_list(args.nbest)
    lm_logprobs = iter(read_logprobs(args.lm_prob))
    am_logprobs = get_am_logprobs(args)
    total_wer = 0
    for index in nbest:
        # Discard the stream entries aligned with each list's slot 0,
        # which is not treated as a hypothesis below.
        next(lm_logprobs)
        next(am_logprobs)
        # Slot 0 of the plain list is the reference transcription.
        reference = plain_nbest[index][0].strip()
        hypotheses = nbest[index][1:]
        # Pull one score per hypothesis from each stream, keeping alignment.
        lms = [next(lm_logprobs) for h in hypotheses]
        ams = [next(am_logprobs) for h in hypotheses]
        combined_probs = log_linear_interpolate(lms, ams, alpha)
        scores = list(zip(hypotheses, combined_probs))
        best_index, best_hypothesis = find_best_hypothesis(scores)
        # +1 because hypotheses were taken from [1:]; slot 0 is the reference.
        best_hypothesis_plain = plain_nbest[index][best_index+1].strip()
        current_wer = wer(reference, best_hypothesis_plain)
        total_wer += current_wer
        log_hypothesis(args.debug, best_hypothesis, best_hypothesis_plain, scores, current_wer, reference)
    # Mean per-utterance WER, scaled to a percentage.
    return 100*(total_wer/len(nbest))
def get_am_logprobs(args):
    """Return an iterator of acoustic-model log-probabilities.

    When args.include_am is true, stream the values parsed from
    args.am_prob; otherwise yield a constant 0 forever, which makes the
    log-linear interpolation reduce to the LM term while keeping the
    caller's next()-based bookkeeping identical in both modes.

    BUG FIX: the original `def` line was missing its trailing colon,
    which is a syntax error.
    """
    if args.include_am:
        return iter(read_logprobs(args.am_prob))
    # Endless zeros stand in for absent AM scores.
    return itertools.repeat(0)
def read_nbest_list(filename):
    """Parse a tab-separated n-best file into {utterance index: [text lines]}.

    Each input line is "<index>\t<text>"; lines sharing an index are
    accumulated in file order. Text retains its trailing newline.
    """
    entries = defaultdict(list)
    with open(filename, 'r') as handle:
        for row in handle:
            key, text = row.split("\t")
            entries[int(key)].append(text)
    return entries
def read_logprobs(filename):
    """Read one float per line from *filename* and return them as a list."""
    with open(filename, 'r') as handle:
        return list(map(float, handle))
def wer(reference, hypothesis):
    """Word error rate: token-level Levenshtein distance normalized by
    the number of reference words.

    NOTE(review): raises ZeroDivisionError for an empty reference —
    callers appear to guarantee non-empty references; confirm.
    """
    ref_words = reference.strip().split()
    hyp_words = hypothesis.strip().split()
    return levenshtein(ref_words, hyp_words) / len(ref_words)
def log_hypothesis(debug, best_hypothesis, best_hypothesis_plain, scores, werr, reference):
    """Print rescoring diagnostics according to the debug level.

    Level 2 additionally dumps every candidate with its combined score;
    any level above 0 prints the reference, chosen hypothesis, its
    log-probability and the utterance WER. Level 0 prints nothing.
    """
    best_text, best_logprob = best_hypothesis
    if debug == 2:
        # Full candidate list, one per line.
        for candidate, score in scores:
            print("{0}: {1}".format(candidate.strip(), score))
    if debug > 0:
        print("TRUE: {0}\nBEST: {1}\nPROB: {2}\nWERR: {3}\n\n".format(reference, best_hypothesis_plain, best_logprob, werr))
def log(wer, linspace=None):
    """Print either a WERR-per-alpha table or a single total figure.

    wer      -- one WER value, or a list of WER values (one per alpha)
    linspace -- alpha values matching *wer*; None selects the scalar form

    Idiom fix: `not linspace is None` replaced with the standard
    `linspace is not None` (identical behavior, PEP 8 style).
    """
    if linspace is not None:
        # Reduction relative to the first entry (alpha == 0 baseline).
        werr = [wer[0] - w for w in wer]
        for alpha, reduction in zip(linspace, werr):
            print("{0:.2f} {1}".format(alpha, reduction))
    else:
        print("Total WERR: {}".format(wer))
def find_best_hypothesis(scores):
    """Return (index, (hypothesis, logprob)) of the highest-scoring entry.

    Perf/idiom fix: dropped the unnecessary `list()` around `enumerate`
    — `max` consumes any iterable directly (same result, no throwaway
    list).
    """
    return max(enumerate(scores), key=lambda pair: pair[1][1])
def log_linear_interpolate(lms, ams, alpha):
    """Combine LM and AM log-probabilities element-wise as
    alpha*lm + (1 - alpha)*am, truncating to the shorter list."""
    am_weight = 1 - alpha
    combined = []
    for lm_score, am_score in zip(lms, ams):
        combined.append(alpha * lm_score + am_weight * am_score)
    return combined
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('lm_prob', help='path to the language model probabilities', type=str)
    parser.add_argument('nbest', help='path to the nbest list', type=str)
    parser.add_argument('-a', '--am_prob', help='path to the acoustic probabilities', type=str, default='out')
    parser.add_argument('-d', '--debug', help='debug level', type=int, default=0, choices=[0, 1, 2])
    parser.add_argument('-p', '--plain', help='plain text n-best file', type=str, default='data/test/nbest_plain')
    parser.add_argument('--am', dest='include_am', action='store_true')
    parser.add_argument('--no-am', dest='include_am', action='store_false')
    # BUG FIX: the default was set on the non-existent dest 'am'
    # (set_defaults(am=False)), so 'include_am' silently fell back to the
    # default of the last action registered on it (store_false => True),
    # enabling the AM branch even without --am. Set the intended dest.
    parser.set_defaults(include_am=False)
    args = parser.parse_args()
    if args.include_am:
        # Sweep the interpolation weight and report WERR for each alpha.
        linspace = np.arange(0, 1, 0.05)
        wer = [calculate_wer(args, alpha) for alpha in linspace]
        log(wer, linspace)
    else:
        # LM-only rescoring at the default alpha.
        log(calculate_wer(args))