-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpcfg_table.cc
More file actions
49 lines (42 loc) · 1.3 KB
/
pcfg_table.cc
File metadata and controls
49 lines (42 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include "pcfg_table.h"
#include "aligned_tree.h"
#include "log_add.h"
#include "node.h"
// Builds the rule table from the training trees in two passes:
//   1. For every node of every training tree, count one occurrence of the rule
//      (node tag -> rhs). For a node with children, rhs is the sequence of the
//      children's tags. For a node without children, rhs is the single word
//      stored at that node.
//   2. For each left-hand-side tag, replace every count with
//      log(count / total count of rules sharing that left-hand side).
//
// training: instances whose `.first` member is the AlignedTree rules are
//           read from. Only `.first` is used here.
PCFGTable::PCFGTable(const shared_ptr<vector<Instance>>& training) {
  // Count CFG rules. Iterate by const reference: each Instance owns a full
  // tree, and the original by-value loop copied one per iteration.
  for (const auto& instance: *training) {
    const AlignedTree& tree = instance.first;
    for (auto node = tree.begin(); node != tree.end(); ++node) {
      vector<int> rhs;
      if (node.number_of_children() > 0) {
        // Internal node: the expansion is the ordered list of child tags.
        for (auto child = tree.begin(node); child != tree.end(node); ++child) {
          rhs.push_back(child->GetTag());
        }
      } else {
        // Childless node: the expansion is the word at this node.
        rhs.push_back(node->GetWord());
      }
      rule_probs[node->GetTag()][rhs] += 1;
    }
  }
  // Normalize counts to compute log probabilities. Every stored entry was
  // incremented at least once above, so `total` is always >= 1 here.
  for (auto &lhs_entry: rule_probs) {
    double total = 0;
    // Read-only pass: bind by const reference to avoid copying each
    // (vector<int>, count) entry.
    for (const auto& rhs_prob_pair: lhs_entry.second) {
      total += rhs_prob_pair.second;
    }
    for (auto &rhs_prob_pair: lhs_entry.second) {
      rhs_prob_pair.second = log(rhs_prob_pair.second / total);
    }
  }
}
// Returns the log probability stored for the rule (lhs -> rhs).
//
// A rule that was never seen during construction is treated as a programming
// error: `assert(false)` fires in builds where assertions are enabled. When
// assertions are compiled out (NDEBUG), the function instead returns
// Log<double>::zero(). That value comes from log_add.h and is presumably the
// log-domain encoding of probability 0 — confirm there.
double PCFGTable::GetLogProbability(int lhs, const vector<int>& rhs) const {
  const auto lhs_it = rule_probs.find(lhs);
  if (lhs_it == rule_probs.end()) {
    // No rule with this left-hand side was ever counted.
    assert(false);
    return Log<double>::zero();
  }

  const auto& expansions = lhs_it->second;
  const auto rhs_it = expansions.find(rhs);
  if (rhs_it == expansions.end()) {
    // The left-hand side is known, but never expanded to this rhs.
    assert(false);
    return Log<double>::zero();
  }

  return rhs_it->second;
}