-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBernpro.py
More file actions
55 lines (44 loc) · 1.85 KB
/
Bernpro.py
File metadata and controls
55 lines (44 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Optional
from torch_geometric.typing import OptTensor
import math
import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.utils import get_laplacian
from scipy.special import comb
import torch.nn.functional as F
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import numpy as np
class Bern_prop(MessagePassing):
    """Bernstein-polynomial graph filter (propagation only, no feature transform).

    Given node features ``x``, computes

        out = sum_{k=0}^{K} relu(temp[k]) * C(K, k) / 2**K
                            * (2I - L)^(K - k) L^k x

    where ``L = I - D^(-1/2) A D^(-1/2)`` is the symmetric normalized
    Laplacian produced by ``get_laplacian(..., normalization='sym')``.

    Args:
        K: Polynomial order; the layer owns ``K + 1`` learnable coefficients
            stored in ``self.temp``.
        bias: Accepted for signature compatibility; not used by this layer.
        **kwargs: Forwarded to ``MessagePassing.__init__`` (aggregation is
            fixed to ``'add'``).
    """

    def __init__(self, K, bias=True, **kwargs):
        super().__init__(aggr='add', **kwargs)
        self.K = K
        # One raw coefficient per Bernstein basis term, k = 0..K.
        self.temp = Parameter(torch.empty(self.K + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every coefficient to 1.

        With all coefficients equal to 1 the Bernstein terms sum to
        ((2I - L) + L)^K / 2^K = I, so a freshly reset layer is the
        identity filter.
        """
        self.temp.data.fill_(1)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the Bernstein filter to ``x``.

        Args:
            x: Node feature tensor; nodes are indexed along ``self.node_dim``.
            edge_index: Graph connectivity passed to ``get_laplacian``.
            edge_weight: Optional per-edge weights (``None`` = unweighted).

        Returns:
            The filtered features: a linear combination of propagated copies
            of ``x`` weighted by ``relu(self.temp)`` and the binomial factors.
        """
        K = self.K
        # ReLU keeps the effective Bernstein coefficients non-negative.
        theta = F.relu(self.temp)
        num_nodes = x.size(self.node_dim)

        # L = I - D^(-1/2) A D^(-1/2)
        edge_index1, norm1 = get_laplacian(
            edge_index, edge_weight, normalization='sym',
            dtype=x.dtype, num_nodes=num_nodes)
        # 2I - L: negate L's weights and add self-loops of weight 2.
        edge_index2, norm2 = add_self_loops(
            edge_index1, -norm1, fill_value=2., num_nodes=num_nodes)

        # powers[j] = (2I - L)^j x  for j = 0..K
        powers = [x]
        for _ in range(K):
            powers.append(self.propagate(
                edge_index2, x=powers[-1], norm=norm2, size=None))

        # Effective weight of basis term k: C(K, k) / 2^K * relu(temp[k]).
        scale = 2 ** K
        coeffs = [(comb(K, k) / scale) * theta[k] for k in range(K + 1)]

        # Horner's scheme in L (L and 2I - L commute):
        #   sum_k c_k L^k (2I - L)^(K-k) x
        #   = c_0 P_K + L(c_1 P_{K-1} + L(c_2 P_{K-2} + ... + L(c_K P_0)))
        # with P_j = powers[j]. This needs K applications of L instead of
        # the K(K+1)/2 required when each L^k is recomputed from scratch.
        out = coeffs[K] * powers[0]
        for k in range(K - 1, -1, -1):
            out = self.propagate(edge_index1, x=out, norm=norm1, size=None)
            out = out + coeffs[k] * powers[K - k]
        return out

    def message(self, x_j, norm):
        """Scale each neighbor message ``x_j`` by its edge weight ``norm``.

        Assumes ``norm`` holds one value per edge and that ``x_j`` has edges
        along its first dimension (TODO confirm for non-2-D features).
        """
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return f'{self.__class__.__name__}(K={self.K}, temp={self.temp})'