-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_env.py
More file actions
83 lines (65 loc) · 1.99 KB
/
test_env.py
File metadata and controls
83 lines (65 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gym
import numpy as np
from scipy.misc import imresize
from keras.models import load_model
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from controller import Controller
from load_encoder import load_encoder
# When True, the script shows the decoder's reconstruction of each frame in a
# matplotlib animation; when False it only runs the agent in a plain loop.
VIEW_DECODED = False
def process_obs(obs):
    """Turn a raw environment frame into a single-item batch for the encoder.

    The frame is resized to 64x64 with ``imresize``, given a new leading axis
    of length 1, and divided by 255 (presumably mapping 8-bit pixel values
    into [0, 1] -- confirm against the encoder's training pipeline).

    NOTE(review): ``scipy.misc.imresize`` was deprecated in SciPy 1.0 and
    removed in SciPy 1.3, so this function needs an older SciPy to run.
    """
    resized = imresize(obs, (64, 64))
    batched = resized[np.newaxis, ...]
    return batched / 255.
# Load the pretrained pieces of the agent from disk and create the environment.
encoder = load_encoder()
decoder = load_model('./models/decoder.h5')
# NOTE(review): `rnn` is loaded but not referenced anywhere else in the visible script.
rnn = load_model('./models/mdn-rnn-forward.h5')
controller_params = np.load('./models/controller-params.npy')
controller = Controller(controller_params)
env = gym.make('CarRacing-v0')
# Rollout state, shared through module globals with the stepping code below.
# `done` starts True so that the very first step resets the environment.
done = True
# Step counter within the current episode.
t = 0
# Most recent action returned by the controller.
a = None
# Current observation, stored after it has gone through process_obs().
obs = None
# Reward accumulated over the current episode.
total_reward = 0
def _run_step(decode=False):
    """Advance the rollout by one environment step.

    Shared by the animated and the plain loop so both behave identically.
    If the previous episode has finished (or none has started yet), the total
    reward is printed, the counters are cleared and the environment is reset
    first. Then the current observation is encoded, the controller picks an
    action, and the environment is stepped and rendered.

    Args:
        decode: When true, also run the decoder on the latent vector so the
            reconstruction can be displayed. Skipped otherwise, because the
            result would be thrown away.

    Returns:
        The decoder output for the frame that was just encoded, or ``None``
        when ``decode`` is false.
    """
    global t, a, obs, done, total_reward

    if done:
        # `done` is initialised to True, so this also runs before the first
        # episode (printing the initial total of 0).
        print('total reward: %d' % total_reward)
        t = 0
        total_reward = 0
        obs = env.reset()
        env.render()
        obs = process_obs(obs)

    z = encoder.predict(obs)[0]
    decoded = decoder.predict(np.array([z]))[0] if decode else None

    a = controller.get_action(z)
    # Force the third action component to zero (presumably the brake in
    # CarRacing's [steer, gas, brake] action -- confirm).
    a[2] = 0

    obs, reward, done, _info = env.step(a)
    total_reward += reward
    env.render()
    obs = process_obs(obs)
    t += 1
    return decoded


try:
    if VIEW_DECODED:
        fig = plt.figure()
        im = plt.imshow(np.zeros((64, 64, 3)), animated=True)

        def update_fig(*args):
            """Animation callback: step the agent and show the decoded frame."""
            im.set_array(_run_step(decode=True))
            # blit=True expects an iterable of the artists that were redrawn.
            return im,

        # Keep a reference to the animation object; matplotlib stops the
        # animation if it gets garbage-collected.
        ani = animation.FuncAnimation(fig, update_fig, interval=50, blit=True)
        plt.show()
    else:
        # Runs until interrupted (e.g. Ctrl-C) or an error is raised.
        while True:
            _run_step()
finally:
    # Release the environment's resources however the loop or window exits.
    env.close()