-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampler.h
More file actions
122 lines (85 loc) · 3.33 KB
/
sampler.h
File metadata and controls
122 lines (85 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#ifndef _SAMPLER_H_
#define _SAMPLER_H_
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "aligned_tree.h"
#include "alignment_constructor.h"
#include "dictionary.h"
#include "distributed_rule_counts.h"
#include "rule_extractor.h"
#include "rule_reorderer.h"
#include "util.h"
using namespace std;
// Forward declarations: this header refers to these tables only through
// std::shared_ptr, so their full definitions are not required here.
class PCFGTable;
class TranslationTable;

// Random engine type handed to Sampler (which keeps it by reference).
using RandomGenerator = std::mt19937;
// Sampler over a shared corpus of training `Instance`s.
//
// Usage visible from this header: construct it with the corpus, dictionary,
// probability tables, random generator and tuning parameters; call Sample();
// then call the Serialize*() methods to write results. The sampling model
// itself is implemented in the .cpp and is not described here.
//
// Ownership / lifetime:
//   * `dictionary` and `generator` are stored as references, so the caller
//     must keep both alive for as long as the Sampler is used.
//   * The corpus and the three tables are held through std::shared_ptr and
//     are therefore co-owned with the caller.
//   * Because of the reference members, the implicitly declared copy and
//     move assignment operators are deleted.
//
// Standard-library names are written fully qualified so that this
// declaration does not rely on a header-scope `using namespace std;`.
class Sampler {
 public:
  // NOTE(review): parameter roles below are inferred from their names only;
  // confirm against the constructor definition.
  //   training                   corpus to sample over (shared, not copied).
  //   dictionary, generator      kept by reference; must outlive *this.
  //   pcfg_table, forward_table,
  //   reverse_table              probability tables (shared ownership).
  //   num_threads                presumably the parallelism used by Sample().
  //   enable_all_stats           presumably enables extra statistics output.
  //   min_rule_count             presumably a count threshold for output rules.
  //   smart_expand, reorder, penalty, max_leaves, max_tree_size, alpha,
  //   pexpand, pchild, pterm     model / search parameters.
  //   output_directory           presumably where Serialize*() writes files.
  Sampler(const std::shared_ptr<std::vector<Instance>>& training,
          Dictionary& dictionary,
          const std::shared_ptr<PCFGTable>& pcfg_table,
          const std::shared_ptr<TranslationTable>& forward_table,
          const std::shared_ptr<TranslationTable>& reverse_table,
          RandomGenerator& generator, int num_threads, bool enable_all_stats,
          bool smart_expand, int min_rule_count, bool reorder, double penalty,
          int max_leaves, int max_tree_size, double alpha,
          double pexpand, double pchild, double pterm,
          const std::string& output_directory);

  // Runs the sampler. NOTE(review): presumably performs `iterations` passes,
  // logs every `log_frequency` iterations, and restricts work to the
  // [start_index, final_index] slice of the corpus -- confirm bounds.
  void Sample(int iterations, int log_frequency,
              int start_index, int final_index);

  // Output writers. Each takes an optional `iteration` tag (empty by
  // default), presumably used to label per-iteration snapshots.
  void SerializeAlignments(const std::string& iteration = "");
  void SerializeGrammar(bool scfg_format, const std::string& iteration = "");
  void SerializeReorderings(const std::string& iteration = "");
  void SerializeInternalState(const std::string& iteration = "");

 private:
  // --- Setup -------------------------------------------------------------
  void InitializeRuleCounts();
  void CacheSentence(const Instance& instance);

  // --- Statistics / diagnostics -----------------------------------------
  void DisplayStats();
  double ComputeDataLikelihood();
  double ComputeAverageNumInteriorNodes();
  int GetGrammarSize();
  std::map<int, int> GenerateRuleHistogram();

  // --- Sampling moves -----------------------------------------------------
  void SampleAlignments(const Instance& instance, int index);
  void SampleSwaps(const Instance& instance);
  std::vector<NodeIter> GetRandomSchedule(const AlignedTree& tree);
  std::vector<std::pair<int, int>> GetLegalSpans(const AlignedTree& tree,
                                                 const NodeIter& node,
                                                 const NodeIter& ancestor);

  // --- Probabilities (log space, per the names) -----------------------------
  double ComputeLogBaseProbability(const Rule& rule);
  double ComputeLogProbability(const Rule& r);
  double ComputeLogProbability(const Rule& r1, const Rule& r2);
  double ComputeLogProbability(const Rule& r1, const Rule& r2, const Rule& r3);
  void IncrementRuleCount(const Rule& rule);
  void DecrementRuleCount(const Rule& rule);

  // --- Reorderings ----------------------------------------------------------
  void InferReorderings();
  void ExtractReordering(const Instance& instance,
                         const NodeIter& node,
                         String& reordering);

  // Data members. Declaration order is also construction order, regardless
  // of how the constructor's mem-initializer list is written -- do not
  // reorder without checking the .cpp.
  std::shared_ptr<std::vector<Instance>> training;
  DistributedRuleCounts counts;
  RuleExtractor extractor;
  AlignmentConstructor alignment_constructor;
  Dictionary& dictionary;  // Non-owning; must outlive *this.
  std::shared_ptr<PCFGTable> pcfg_table;
  std::shared_ptr<TranslationTable> forward_table;
  std::shared_ptr<TranslationTable> reverse_table;
  RandomGenerator& generator;  // Non-owning; must outlive *this.
  // NOTE(review): range depends on how the constructor initializes it;
  // default construction would give [0, 1).
  std::uniform_real_distribution<double> uniform_distribution;
  int num_threads;
  bool enable_all_stats;
  // Parameters for filtering the final rules.
  int min_rule_count;
  bool reorder;
  RuleReorderer rule_reorderer;
  std::vector<std::map<String, int>> reorder_counts;
  bool smart_expand;
  std::unordered_map<int, double> expand_probs;
  std::unordered_map<int, double> not_expand_probs;
  double alpha;
  // NOTE(review): presumably derived from the pexpand/pchild/pterm
  // constructor arguments -- confirm in the .cpp.
  double prob_expand, prob_not_expand;
  double prob_stop_child, prob_cont_child;
  double prob_stop_str, prob_cont_str;
  double prob_nt, prob_st, prob_tt;
  std::string output_directory;
};
#endif