-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathneuron.py
More file actions
87 lines (65 loc) · 2.27 KB
/
neuron.py
File metadata and controls
87 lines (65 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from surrogate import *
from copy import deepcopy
class BaseNeuron(nn.Module):
    """Base class for stateful spiking neurons.

    Each ``forward`` call advances the neuron by one step:

    1. the subclass's ``cal_mem`` updates ``self.mem`` from the input;
    2. ``act_func`` is applied to ``mem - threshold`` to give ``self.spike``;
    3. when ``record_sum`` is set, the input and the spikes are added to running sums;
    4. the membrane is reset, soft or hard depending on ``soft_mode``.

    Args:
        act_func: callable mapping ``mem - threshold`` to a spike tensor.
            Presumably a surrogate-gradient spike function from the
            ``surrogate`` module -- TODO confirm against callers.
        threshold: firing threshold, stored as a frozen float32 Parameter.
        soft_mode: if True, subtract ``threshold`` from the membrane on a
            spike (soft reset); otherwise set it to ``v_reset`` (hard reset).
        **kwargs: ignored here; lets subclasses accept extra options (e.g. ``tau``).
    """

    def __init__(self,
                 act_func,
                 threshold: float = 1.,
                 soft_mode: bool = False,
                 **kwargs
                 ):
        super(BaseNeuron, self).__init__()
        # Always float32, matching set_threshold(). An int threshold used to give an
        # int64 Parameter, which load_state_dict() would silently truncate into.
        self.threshold = self._frozen_param(threshold)
        self.mem = 0.
        self.spike = 0.
        self.summem = 0.    # running sum of the inputs x (not of mem) while record_sum is True
        self.sumspike = 0.  # running sum of emitted spikes while record_sum is True
        self.record_sum = True
        self.v_reset = 0.   # membrane value restored by reset() and by the hard reset
        self.soft_mode = soft_mode
        self.act_func = act_func

    @staticmethod
    def _frozen_param(value, device=None):
        """Return ``value`` as a float32, non-trainable Parameter that owns its storage."""
        # as_tensor accepts Python numbers, sequences and tensors without the
        # copy-construct warning torch.tensor() emits for tensors. detach().clone()
        # keeps the Parameter from aliasing a caller-owned tensor.
        tensor = torch.as_tensor(value, dtype=torch.float, device=device).detach().clone()
        return Parameter(tensor, requires_grad=False)

    def cal_spike(self):
        """Compute ``self.spike`` from how far the membrane exceeds the threshold."""
        self.spike = self.act_func(self.mem - self.threshold)

    def cal_mem(self, x):
        """Update ``self.mem`` from input ``x``; must be implemented by subclasses."""
        raise NotImplementedError

    def hard_reset(self):
        """Set the membrane of neurons that spiked to ``v_reset``; leave the rest unchanged."""
        # spike is not detached here (unlike soft_reset), so gradients flow through
        # the reset. With v_reset == 0 this equals mem * (1 - spike).
        self.mem = self.mem * (1 - self.spike) + self.v_reset * self.spike

    def soft_reset(self):
        """Subtract the threshold from the membrane wherever a spike occurred."""
        self.mem = self.mem - self.threshold * self.spike.detach()

    def forward(self, x):
        """Advance the neuron by one step with input ``x`` and return the spikes."""
        self.cal_mem(x)
        self.cal_spike()
        if self.record_sum:
            self.summem += x.detach()
            self.sumspike += self.spike.detach()
        if self.soft_mode:
            self.soft_reset()
        else:
            self.hard_reset()
        return self.spike

    def reset(self):
        """Clear the state: membrane back to ``v_reset``; spikes and running sums to 0."""
        self.mem = self.v_reset
        self.spike = 0.
        self.summem = 0.
        self.sumspike = 0.

    def set_threshold(self, threshold):
        """Replace the firing threshold (scalar or tensor), keeping the module's device."""
        self.threshold = self._frozen_param(threshold, device=self.threshold.device)

    def set_tau(self, tau):
        """Replace ``tau`` with a frozen float32 Parameter on the module's device.

        Raises:
            NotImplementedError: if this neuron has no ``tau`` attribute.
        """
        if not hasattr(self, 'tau'):
            raise NotImplementedError
        self.tau = self._frozen_param(tau, device=self.threshold.device)
class IF(BaseNeuron):
    """Integrate-and-fire neuron: the membrane accumulates its input with no leak.

    Args:
        act_func: spike function, forwarded to ``BaseNeuron``.
        threshold: firing threshold, forwarded to ``BaseNeuron``.
        **kwargs: forwarded unchanged to ``BaseNeuron`` (e.g. ``soft_mode``).
    """

    def __init__(self, act_func, threshold=1., **kwargs):
        super(IF, self).__init__(act_func, threshold, **kwargs)

    def cal_mem(self, x):
        """Add the input ``x`` to the membrane potential."""
        # Out-of-place add: self.mem is rebound to a new value, not mutated.
        integrated = self.mem + x
        self.mem = integrated
class LIF(BaseNeuron):
    """Leaky integrate-and-fire neuron.

    Each step moves the membrane toward the input by a fraction ``1 / tau`` of
    the gap between them: ``mem <- mem + (x - mem) / tau``.

    Args:
        act_func: spike function, forwarded to ``BaseNeuron``.
        threshold: firing threshold, forwarded to ``BaseNeuron``.
        **kwargs: must contain ``tau`` (``KeyError`` otherwise); every entry
            is also forwarded to ``BaseNeuron``.
    """

    def __init__(self, act_func, threshold=1., **kwargs):
        super(LIF, self).__init__(act_func, threshold, **kwargs)
        self.tau = kwargs['tau']

    def cal_mem(self, x):
        """Leakily integrate input ``x`` into the membrane potential."""
        # Algebraically mem * (1 - 1/tau) + x / tau: the membrane decays and the
        # input is scaled by 1/tau (an unscaled-input variant was used previously).
        step = (x - self.mem) / self.tau
        self.mem = self.mem + step