-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_pagent.py
More file actions
46 lines (39 loc) · 929 Bytes
/
train_pagent.py
File metadata and controls
46 lines (39 loc) · 929 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import gym
import numpy as np
import os, sys
import optuna
from envs.PspaceEnv import PspaceEnv
from stable_baselines3 import PPO
import cProfile
import pstats
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
seed = 2     # passed to PPO as its `seed` argument
arch = 64    # layer size, repeated twice in the policy's net_arch
steps = 100  # forwarded to gym.make() as the `steps` keyword argument
ts = 3e4     # training budget (float); main() casts it to int for learn()
# Base directory; in this file it is only referenced by the commented-out
# model.save() call inside main().
cwd = '/home/kylesa/avast_clf/v0.2/'

# ---------------------------------------------------------------------------
# Training environment
# ---------------------------------------------------------------------------
env = gym.make("PspaceEnv-v0", steps=steps)

# ---------------------------------------------------------------------------
# PPO agent: MLP policy whose net_arch is [arch, arch].
# ---------------------------------------------------------------------------
model = PPO(
    policy="MlpPolicy",
    env=env,
    seed=seed,
    learning_rate=0.0003,
    gamma=0.99,
    n_steps=1024,
    batch_size=32,
    policy_kwargs={"net_arch": [arch, arch]},
    # Alternative with separate value/policy networks:
    # policy_kwargs=dict(net_arch=dict(vf=[arch,arch], pi=[arch,arch]))
)
def main():
    """Run one training session on the module-level ``env`` and ``model``.

    Resets ``env``, trains ``model`` for ``int(ts)`` timesteps with the
    progress bar enabled, then calls ``env.summary()``.  Returns ``None``.
    """
    env.reset()

    # `ts` is written in scientific notation (a float), so convert it to the
    # integer step count before handing it to learn().
    total_steps = int(ts)
    model.learn(total_timesteps=total_steps, progress_bar=True)

    # env.write_hashes()
    env.summary()
    # model.save(cwd + 'agents/pspace1.pt')
# Start training only when this file is executed as a script; importing the
# module (e.g. to reuse `main` from a profiler or another script) must not
# kick off a training run by itself.
if __name__ == "__main__":
    main()
# p = cProfile.run('main()', 'profile_stats')
# stats = pstats.Stats('profile_stats')
# stats.sort_stats('tottime')
# stats.print_stats(10)