forked from Outlier01/STAR
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreplay_buffer.py
More file actions
108 lines (91 loc) · 3.03 KB
/
replay_buffer.py
File metadata and controls
108 lines (91 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from dataclasses import dataclass, fields
from typing import Optional, Self
import torch
import torch.nn.functional as F
def zero_pad_sequences(
    sequences: list[torch.Tensor], side: str = "left"
) -> torch.Tensor:
    """Stack variable-length tensors into one batch, zero-padding along dim 0.

    Args:
        sequences: Non-empty list of tensors that agree on every dimension
            except dim 0 (the length dimension).
        side: ``"left"`` inserts the padding before each sequence,
            ``"right"`` appends it after.

    Returns:
        A tensor of shape ``(len(sequences), max_len, *trailing_dims)``.

    Raises:
        ValueError: If ``side`` is not ``"left"``/``"right"`` or
            ``sequences`` is empty.
    """
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    if not sequences:
        raise ValueError("zero_pad_sequences() requires at least one sequence")

    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        # F.pad consumes (before, after) pairs starting from the LAST dim, so
        # leave every trailing dim untouched and pad only dim 0. For 1-D
        # inputs `trailing` is empty and this matches the plain (pad, 0) form.
        trailing = (0, 0) * (seq.dim() - 1)
        dim0 = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, trailing + dim0))
    return torch.stack(padded_sequences, dim=0)
@dataclass
class Experience:
sequences: torch.Tensor
action_log_probs: torch.Tensor
log_probs_ref: torch.Tensor
returns: Optional[torch.Tensor]
advantages: Optional[torch.Tensor]
attention_mask: Optional[torch.Tensor]
action_mask: torch.Tensor
kl: Optional[torch.Tensor] = None
def to(self, device: torch.device) -> Self:
members = {}
for field in fields(self):
v = getattr(self, field.name)
if isinstance(v, torch.Tensor):
v = v.to(device=device)
members[field.name] = v
return Experience(**members)
def split_experience_batch(experience: Experience) -> list[Experience]:
    """Split a batched ``Experience`` into one ``Experience`` per batch row.

    The batch size is taken from ``experience.sequences.size(0)``. Each tensor
    field is unbound along dim 0; ``None`` fields stay ``None`` in every
    output. Field names come from the dataclass definition, so optional
    fields such as ``kl`` are carried over instead of being silently dropped.

    Args:
        experience: Batched experience whose tensor fields share dim 0.

    Returns:
        A list of ``batch_size`` experiences (same concrete type as the
        input), one per row.

    Raises:
        ValueError: If a tensor field's dim 0 differs from the batch size.
    """
    batch_size = experience.sequences.size(0)
    batch_data: list[dict] = [{} for _ in range(batch_size)]
    for field in fields(experience):
        value = getattr(experience, field.name)
        if value is None:
            vals = [None] * batch_size
        else:
            vals = torch.unbind(value)
            if len(vals) != batch_size:
                raise ValueError(
                    f"field {field.name!r} has {len(vals)} rows, "
                    f"expected {batch_size}"
                )
        for row, v in zip(batch_data, vals):
            row[field.name] = v
    return [type(experience)(**data) for data in batch_data]
def join_experience_batch(items: list[Experience]) -> Experience:
    """Collate per-sample experiences into one batched ``Experience``.

    For every dataclass field (including the optional ``kl``), the per-sample
    tensors are left-padded with zeros to a common length via
    ``zero_pad_sequences`` and stacked along a new dim 0. If any item has
    ``None`` for a field, the batched field is ``None``.

    Args:
        items: Non-empty list of per-sample experiences.

    Returns:
        A single batched ``Experience``.

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("join_experience_batch() requires at least one item")

    batch_data = {}
    for field in fields(Experience):
        vals = [getattr(item, field.name) for item in items]
        if all(v is not None for v in vals):
            data = zero_pad_sequences(vals, "left")
        else:
            # A partially-populated field cannot be stacked; drop it to None.
            data = None
        batch_data[field.name] = data
    return Experience(**batch_data)
class ReplayBuffer:
    """FIFO store of per-sample ``Experience`` objects with an optional cap.

    Batched experiences are split into single samples on ``append``. The
    buffer supports ``len()`` and integer indexing.

    Args:
        limit: Maximum number of samples retained. ``0`` (the default) or any
            non-positive value means unbounded; when exceeded, the oldest
            samples are discarded first.
    """

    def __init__(self, limit: int = 0) -> None:
        self.limit = limit
        self.items: list[Experience] = []

    def append(self, experience: Experience) -> None:
        """Add every row of ``experience``, evicting the oldest past ``limit``."""
        self.items.extend(split_experience_batch(experience))
        if 0 < self.limit < len(self.items):
            # Keep only the newest `limit` samples (rebinds to a new list).
            self.items = self.items[-self.limit:]

    def clear(self) -> None:
        """Drop all stored samples."""
        self.items.clear()

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Experience:
        return self.items[idx]