-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathmodel.py
More file actions
66 lines (45 loc) · 1.75 KB
/
model.py
File metadata and controls
66 lines (45 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
class Generator(nn.Module):
    '''
    Predict bounded scores from a sequence of features.

    Pipeline: position-wise feed-forward -> stack of self-attention layers ->
    [mean-pool(attention output) ; max-pool(feed-forward output)] -> dropout ->
    linear head -> tanh rescaled into (score_min, score_max).
    With the default bounds the output range is (1, 5), i.e. ``tanh(x) * 2 + 3``.
    '''

    def __init__(self,
                 in_features,
                 ffd_hidden_size,
                 num_classes,
                 attn_layer_num,
                 *,
                 num_heads=8,
                 dropout=0.2,
                 score_min=1.0,
                 score_max=5.0,
                 ):
        '''
        in_features: feature size D of the input. It is also the attention
            embedding size, so it must be divisible by ``num_heads``.
        ffd_hidden_size: hidden width of the position-wise feed-forward layer.
        num_classes: number of outputs per example.
        attn_layer_num: number of self-attention layers applied in sequence.
        num_heads: attention heads per layer (default 8, the previous fixed value).
        dropout: dropout probability inside attention and on the pooled features
            (default 0.2, the previous fixed value).
        score_min, score_max: open output interval. The defaults (1, 5) reproduce
            the previous ``tanh * 2 + 3`` mapping.
        Raises ValueError if score_max <= score_min.
        '''
        super().__init__()
        if score_max <= score_min:
            raise ValueError(
                f"score_max ({score_max}) must be greater than score_min ({score_min})"
            )
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )
        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )
        self.dropout = nn.Dropout(dropout)
        # The head consumes two pooled D-vectors concatenated, hence 2 * D inputs.
        self.fc = nn.Linear(in_features * 2, num_classes)
        self.proj = nn.Tanh()
        # Affine map sending tanh's (-1, 1) onto (score_min, score_max).
        # These are plain floats, not parameters or buffers, so state_dict is unchanged.
        self.score_scale = (score_max - score_min) / 2.0
        self.score_offset = (score_max + score_min) / 2.0

    def forward(self, ssl_feature, judge_id=None):
        '''
        ssl_feature: [B, T, D] with D == in_features and T >= 1
        judge_id: accepted for interface compatibility; it is not used.
        output: [B, num_classes], each value in (score_min, score_max)
        Raises ValueError if ssl_feature is not 3-D or has an empty time axis.
        '''
        if ssl_feature.dim() != 3 or ssl_feature.size(1) == 0:
            raise ValueError(
                "expected ssl_feature of shape [B, T, D] with T >= 1, "
                f"got {tuple(ssl_feature.shape)}"
            )
        ssl_feature = self.ffd(ssl_feature)
        attn_feature = ssl_feature
        for attn in self.attn:
            # Layers are chained directly, with no residual connection or norm.
            attn_feature, _ = attn(attn_feature, attn_feature, attn_feature)
        # NOTE(review): mean pooling uses the attention output, but max pooling
        # uses the pre-attention FFN output. Confirm this asymmetry is intended.
        pooled = torch.concat(
            [attn_feature.mean(dim=1), ssl_feature.max(dim=1).values], dim=1
        )  # B, 2D
        x = self.fc(self.dropout(pooled))  # B, num_classes
        return self.proj(x) * self.score_scale + self.score_offset