-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrandomwalks.py
More file actions
108 lines (85 loc) · 3.87 KB
/
randomwalks.py
File metadata and controls
108 lines (85 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import torch as th
from torch import tensor
from torch.utils.data import TensorDataset
import torch.nn.functional as F
from utils import randexclude
import networkx as nx
# Toy dataset from Decision Transformer (Chen et. al 2021)
class RandomWalks:
    """Toy random-walk dataset on a random directed graph.

    From the Decision Transformer paper (Chen et al. 2021): builds a random
    digraph with ``n_nodes`` nodes where node 0 is an absorbing goal state,
    then samples ``n_walks`` random walks of at most ``max_length`` nodes.
    Exposes ``dataset`` (padded state sequences, attention masks, per-step
    rewards) and ``eval_dataset`` (every non-goal start node) as
    ``TensorDataset``s, plus ``logit_mask`` for masking invalid transitions.
    """

    def __init__(self, n_nodes=20, max_length=10, n_walks=1000, p_edge=0.1, seed=1002):
        """Generate the graph and the walk dataset.

        n_nodes: number of graph nodes (node 0 is the goal).
        max_length: maximum number of nodes in a walk (so at most
            ``max_length - 1`` steps, each with a reward).
        n_walks: number of random walks to sample.
        p_edge: independent probability of each directed edge existing.
        seed: seed for the NumPy RandomState used throughout generation.
        """
        self.n_nodes = n_nodes
        self.n_walks = n_walks
        self.max_length = max_length
        rng = np.random.RandomState(seed)
        walks, rewards = [], []

        # Resample the adjacency matrix until every node has at least one
        # outgoing edge, so a walk can never get stuck.
        while True:
            self.adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
            np.fill_diagonal(self.adj, 0)
            if np.all(self.adj.sum(1)):
                break

        # Node 0 is the absorbing goal state: its only edge is a self-loop.
        self.adj[0, :] = 0
        self.adj[0, 0] = 1
        self.goal = 0

        for _ in range(n_walks):
            # Start anywhere except the goal, then follow random edges.
            node = randexclude(rng, n_nodes, self.goal)
            walk = [node]
            for istep in range(max_length - 1):
                node = rng.choice(np.nonzero(self.adj[node])[0])
                walk.append(node)
                if node == self.goal:
                    break

            # Per-step reward: -1 for each step of a successful walk,
            # -100 per step for a walk that never reached the goal.
            r = th.zeros(max_length - 1)
            r[:len(walk) - 1] = -1 if walk[-1] == self.goal else -100

            rewards.append(r)
            walks.append(walk)

        states = []
        attention_masks = []
        for r, walk in zip(rewards, map(th.tensor, walks)):
            # Mask marks real nodes (1) vs right-padding (0).
            attention_mask = th.zeros(max_length, dtype=int)
            attention_mask[:len(walk)] = 1
            attention_masks.append(attention_mask)
            states.append(F.pad(walk, (0, max_length - len(walk))))

        # Reference lengths used by eval()'s 'optimal' metric.
        self.worstlen = self.max_length
        self.avglen = sum(map(len, walks)) / self.n_walks
        self.bestlen = 0
        g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph)
        for start in set(range(self.n_nodes)) - {self.goal}:
            try:
                shortest_path = nx.shortest_path(g, start, self.goal)[:self.max_length]
                self.bestlen += len(shortest_path)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # Goal unreachable from this start: charge the worst-case
                # length. (Was a bare `except:` that hid real errors.)
                self.bestlen += self.max_length
        self.bestlen /= self.n_nodes - 1

        print(f'{self.n_walks} walks of which {(np.array([r[0] for r in rewards])==-1).mean()*100:.0f}% arrived at destination')

        # Disallows selecting inaccessible nodes in the graph.
        self.logit_mask = tensor(~self.adj)

        self.dataset = TensorDataset(th.stack(states), th.stack(attention_masks), th.stack(rewards))
        self.eval_dataset = TensorDataset(th.arange(1, self.n_nodes).unsqueeze(1))

    def render(self):
        """Draw the graph with matplotlib; the goal node is dark blue."""
        from matplotlib import pyplot
        g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph)
        pos = nx.spring_layout(g, seed=7357)
        pyplot.figure(figsize=(10, 8))
        nx.draw_networkx_edges(g, pos=pos, alpha=0.5, width=1, edge_color='#d3d3d3')
        nx.draw_networkx_nodes(g, nodelist=set(range(len(self.adj))) - {self.goal}, pos=pos, node_size=300, node_color='orange')
        nx.draw_networkx_nodes(g, nodelist=[self.goal], pos=pos, node_size=300, node_color='darkblue')
        pyplot.show()

    def eval(self, samples, beta):
        """Score sampled walks.

        samples: indexable of shape (n_nodes - 1, >= max_length); row i is
            the walk sampled from start node i + 1. (`beta` is unused here;
            kept for a common evaluator signature.)
        Returns (-actlen, stats) where actlen is the mean walk length and
        stats compares it against the best/average/worst reference lengths.
        """
        narrived = 0
        actlen = 0
        for node in range(self.n_nodes - 1):
            for istep in range(self.max_length):
                if samples[node, istep] == self.goal:
                    narrived += 1
                    break
            # istep retains its value after break, so (istep + 1) is the
            # number of steps this walk took (capped at max_length).
            actlen += (istep + 1) / (self.n_nodes - 1)

        # Normalize actual/average lengths into [0, 1] against best/worst.
        current = (self.worstlen - actlen) / (self.worstlen - self.bestlen)
        average = (self.worstlen - self.avglen) / (self.worstlen - self.bestlen)
        stats = { 'actlen': actlen,
                  'avglen': self.avglen,
                  'bestlen': self.bestlen,
                  'worstlen': self.worstlen,
                  'arrived': f'{narrived / (self.n_nodes-1) * 100:.0f}%',
                  'optimal': f'{current*100:.0f}% > {average*100:.0f}%' }
        return -actlen, stats