-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport.cpp
More file actions
59 lines (45 loc) · 1.55 KB
/
export.cpp
File metadata and controls
59 lines (45 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include "DQN.h"
#include "export.h"
#include <vector>
extern "C" {
/// Allocates a new DQN agent and returns an owning handle to it.
///
/// @param stateSize        forwarded unchanged to the DQN constructor.
/// @param actionSize       forwarded unchanged to the DQN constructor.
/// @param hiddenLayers     pointer to `hiddenLayersSize` ints, copied into a
///                         std::vector before construction; may be null only
///                         when `hiddenLayersSize` is 0.
/// @param hiddenLayersSize number of entries readable from `hiddenLayers`.
/// @return a heap-allocated DQN that must be released with DestroyDQN, or
///         nullptr if the arguments are invalid or construction throws.
DQN* DLL_CALL CreateDQN(int stateSize, int actionSize, int* hiddenLayers, int hiddenLayersSize) {
    // A negative count would form an invalid iterator range, and a null array
    // with a positive count would be dereferenced; reject both up front.
    if (hiddenLayersSize < 0 || (hiddenLayers == nullptr && hiddenLayersSize > 0)) {
        return nullptr;
    }
    // Exceptions must not unwind across this extern "C" boundary into a C (or
    // FFI) caller, so any failure is reported as a null handle instead.
    try {
        std::vector<int> hiddenLayersVec(hiddenLayers, hiddenLayers + hiddenLayersSize);
        return new DQN(stateSize, actionSize, hiddenLayersVec);
    } catch (...) {
        return nullptr;
    }
}
/// Releases a DQN handle previously returned by CreateDQN.
///
/// @param dqn handle to free; a null handle is simply ignored. The handle must
///            not be used after this call.
void DLL_CALL DestroyDQN(DQN* dqn) {
    if (dqn != nullptr) {
        delete dqn;
    }
}
/// Asks the agent to choose an action for the given state.
///
/// @param dqn       handle from CreateDQN.
/// @param state     pointer to `stateSize` doubles, copied into a std::vector
///                  before being passed to DQN::SelectAction; may be null only
///                  when `stateSize` is 0.
/// @param stateSize number of readable entries in `state`.
/// @param epsilon   forwarded unchanged to DQN::SelectAction.
/// @return the value returned by DQN::SelectAction, or -1 if an argument is
///         invalid or the call throws.
///         NOTE(review): using -1 as the error value assumes valid actions are
///         non-negative indices — confirm against DQN::SelectAction.
int DLL_CALL SelectAction(DQN* dqn, double* state, int stateSize, double epsilon) {
    // Guard against null handles and against forming an invalid pointer range.
    if (dqn == nullptr || stateSize < 0 || (state == nullptr && stateSize > 0)) {
        return -1;
    }
    // Keep exceptions from propagating through the extern "C" interface.
    try {
        std::vector<double> stateVec(state, state + stateSize);
        return dqn->SelectAction(stateVec, epsilon);
    } catch (...) {
        return -1;
    }
}
/// Runs one training step on the agent with a single transition.
///
/// @param dqn           handle from CreateDQN.
/// @param state         pointer to `stateSize` doubles (copied); may be null
///                      only when `stateSize` is 0.
/// @param stateSize     number of readable entries in `state`.
/// @param action        forwarded unchanged to DQN::Train.
/// @param reward        forwarded unchanged to DQN::Train.
/// @param nextState     pointer to `nextStateSize` doubles (copied); may be
///                      null only when `nextStateSize` is 0.
/// @param nextStateSize number of readable entries in `nextState`.
/// @param gamma         forwarded unchanged to DQN::Train.
/// @param epsilonDecay  forwarded unchanged to DQN::Train.
///
/// Invalid arguments (null handle, negative sizes, null buffer with a positive
/// size) make the call a no-op, because this void C interface has no channel
/// for reporting the error.
/// NOTE(review): exceptions thrown by DQN::Train still propagate across the
/// extern "C" boundary; consider an error-returning variant of this export.
void DLL_CALL Train(DQN* dqn, double* state, int stateSize, int action, double reward, double* nextState, int nextStateSize, double gamma, double epsilonDecay) {
    const bool badState = stateSize < 0 || (state == nullptr && stateSize > 0);
    const bool badNextState = nextStateSize < 0 || (nextState == nullptr && nextStateSize > 0);
    if (dqn == nullptr || badState || badNextState) {
        return;
    }
    std::vector<double> stateVec(state, state + stateSize);
    std::vector<double> nextStateVec(nextState, nextState + nextStateSize);
    dqn->Train(stateVec, action, reward, nextStateVec, gamma, epsilonDecay);
}
/// Forwards to DQN::UpdateTargetNetwork on the given handle.
///
/// @param dqn handle from CreateDQN; must be non-null (it is dereferenced
///            without a check).
void DLL_CALL UpdateTargetNetwork(DQN* dqn) {
    DQN& agent = *dqn;
    agent.UpdateTargetNetwork();
}
/// Returns the value reported by DQN::GetGamma for this handle, as a double.
///
/// @param dqn handle from CreateDQN; must be non-null (not checked).
double DLL_CALL GetGamma(DQN* dqn) {
    const double gamma = dqn->GetGamma();
    return gamma;
}
/// Returns the value reported by DQN::GetEpsilon for this handle, as a double.
///
/// @param dqn handle from CreateDQN; must be non-null (not checked).
double DLL_CALL GetEpsilon(DQN* dqn) {
    const double epsilon = dqn->GetEpsilon();
    return epsilon;
}
/// Returns the value reported by DQN::GetEpsilonDecay for this handle, as a
/// double.
///
/// @param dqn handle from CreateDQN; must be non-null (not checked).
double DLL_CALL GetEpsilonDecay(DQN* dqn) {
    const double decay = dqn->GetEpsilonDecay();
    return decay;
}
/// Returns the value reported by DQN::GetQNetLoss for this handle, as a double.
///
/// @param dqn handle from CreateDQN; must be non-null (not checked).
double DLL_CALL GetQNetLoss(DQN* dqn) {
    const double qNetLoss = dqn->GetQNetLoss();
    return qNetLoss;
}
/// Returns the value reported by DQN::GetTargetNetLoss for this handle, as a
/// double.
///
/// @param dqn handle from CreateDQN; must be non-null (not checked).
double DLL_CALL GetTargetNetLoss(DQN* dqn) {
    const double targetNetLoss = dqn->GetTargetNetLoss();
    return targetNetLoss;
}
/// Persists the agent by calling DQN::SaveModel with `filepath`.
///
/// @param dqn      handle from CreateDQN.
/// @param filepath path passed unchanged to DQN::SaveModel.
/// @return 0 if DQN::SaveModel returned normally, or -1 if an argument is null
///         or the call threw.
///         NOTE(review): if DQN::SaveModel also reports failure through a
///         return value, that value is ignored here — confirm and propagate.
int DLL_CALL SaveModel(DQN* dqn, const char* filepath) {
    // Every path returns an explicit status: flowing off the end of a
    // non-void function is undefined behaviour.
    if (dqn == nullptr || filepath == nullptr) {
        return -1;
    }
    try {
        dqn->SaveModel(filepath);
        return 0;
    } catch (...) {
        // Do not let exceptions escape through the extern "C" interface.
        return -1;
    }
}
/// Restores the agent by calling DQN::LoadModel with `filepath`.
///
/// @param dqn      handle from CreateDQN.
/// @param filepath path passed unchanged to DQN::LoadModel.
/// @return 0 if DQN::LoadModel returned normally, or -1 if an argument is null
///         or the call threw.
///         NOTE(review): if DQN::LoadModel also reports failure through a
///         return value, that value is ignored here — confirm and propagate.
int DLL_CALL LoadModel(DQN* dqn, const char* filepath) {
    // Every path returns an explicit status: flowing off the end of a
    // non-void function is undefined behaviour.
    if (dqn == nullptr || filepath == nullptr) {
        return -1;
    }
    try {
        dqn->LoadModel(filepath);
        return 0;
    } catch (...) {
        // Do not let exceptions escape through the extern "C" interface.
        return -1;
    }
}
}