-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_agent.py
More file actions
47 lines (40 loc) · 933 Bytes
/
train_agent.py
File metadata and controls
47 lines (40 loc) · 933 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import gym
import numpy as np
import os, sys
import optuna
from envs.FspaceEnv import FspaceEnv
from stable_baselines3 import PPO
import cProfile
import pstats
# ---------------------------------------------------------------------------
# Run configuration. These stay module-level globals because main() (and the
# commented-out helpers inside it) read them directly.
# ---------------------------------------------------------------------------
steps = 100   # forwarded to the environment constructor as ``steps``
seed = 2      # seed handed to PPO
arch = 32     # hidden-layer width used for both the pi and vf MLPs
ts = 3e4      # total training timesteps; main() casts this to int
cwd = '/home/kylesa/avast_clf/v0.2/'  # currently only used by commented-out save code

# Training environment. "FspaceEnv-v0" is presumably registered as a side
# effect of importing envs.FspaceEnv -- TODO confirm.
env = gym.make("FspaceEnv-v0", steps=steps)

# PPO agent: two hidden layers of ``arch`` units each for the value (vf)
# and policy (pi) networks.
model = PPO(
    policy="MlpPolicy",
    env=env,
    n_steps=1024,        # rollout steps collected per update (SB3 semantics)
    batch_size=32,       # minibatch size for gradient steps
    learning_rate=3e-4,
    gamma=0.99,          # discount factor
    seed=seed,
    policy_kwargs={
        "net_arch": {
            "vf": [arch] * 2,
            "pi": [arch] * 2,
        },
    },
)
def main():
    """Reset the training env once, then train the module-level PPO model.

    Training runs for ``int(ts)`` timesteps with the Stable-Baselines3
    progress bar enabled. Reads the module globals ``env``, ``model`` and
    ``ts``; returns None.
    """
    env.reset()
    # Disabled manual-stepping debug hook:
    # for i in range(100):
    #     env.step(2)
    #     env.printrep()
    total_timesteps = int(ts)  # ts is declared as a float literal (3e4)
    model.learn(total_timesteps=total_timesteps, progress_bar=True)
    # Disabled post-training hooks:
    # env.write_hashes()
    # model.save(cwd + 'agents/fspace1.pt')
# Train only when executed as a script. Without this guard, merely importing
# the module (e.g. from another script) would kick off a full training run.
if __name__ == "__main__":
    main()
    # To profile training instead, replace the call above with:
    # cProfile.run('main()', 'profile_stats')
    # stats = pstats.Stats('profile_stats')
    # stats.sort_stats('tottime')
    # stats.print_stats(10)