-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprod.py
More file actions
122 lines (97 loc) · 3.57 KB
/
prod.py
File metadata and controls
122 lines (97 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy
import six
from chainer import cuda
from chainer import function
from chainer.utils import type_check
class Prod(function.Function):

    """Product of array elements over a given axis.

    Args:
        axis (None, int, or tuple of int): Axis or axes to reduce. ``None``
            reduces over every axis of the input. Negative values count
            from the last axis.
        keepdims (bool): Passed through to the array's ``prod`` method in
            :meth:`forward`; when false, :meth:`backward` re-inserts the
            reduced axes into the upstream gradient before broadcasting.

    Raises:
        ValueError: If ``axis`` is a tuple containing the same value twice.
        TypeError: If ``axis`` is not ``None``, an integer, or a tuple of
            integers.

    """

    # Class-level default; every instance overrides it in ``__init__``.
    keepdims = False

    def __init__(self, axis=None, keepdims=False):
        # ``six.integer_types`` also covers ``long`` on Python 2.
        if axis is None:
            self.axis = None
        elif isinstance(axis, six.integer_types):
            self.axis = (axis,)
        elif isinstance(axis, tuple) and all(
                isinstance(a, six.integer_types) for a in axis):
            if len(set(axis)) != len(axis):
                raise ValueError('duplicate value in axis: ({})'.format(
                    ', '.join(map(str, axis))))
            self.axis = axis
        else:
            raise TypeError('None, int or tuple of int are required')

        self.keepdims = keepdims

    def check_type_forward(self, in_types):
        """Require one floating-point input and in-range axes."""
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f',
        )

        if self.axis is not None:
            for axis in self.axis:
                if axis >= 0:
                    type_check.expect(
                        axis < in_types[0].ndim,
                    )
                else:
                    # A negative axis ``-k`` is valid iff ``k <= ndim``.
                    type_check.expect(
                        -axis - 1 < in_types[0].ndim,
                    )

    def forward(self, x):
        """Return a 1-tuple holding the product of ``x[0]`` over the axes."""
        xp = cuda.get_array_module(*x)
        # ``asarray`` turns a scalar result (full reduction) into an array.
        return xp.asarray(
            x[0].prod(axis=self.axis, keepdims=self.keepdims)),

    def backward(self, x, gy):
        """Return a 1-tuple holding the gradient with respect to ``x[0]``.

        The partial derivative of a product with respect to one factor is
        the product of all the *other* factors. It is computed here by
        copying every reduced group once per element, overwriting a
        different element with 1 in each copy, and multiplying each copy
        out. No division by ``x`` is involved, so zeros in the input are
        handled. The temporary holds
        ``n_output_elements * n_reduced_elements ** 2`` values.
        """
        xp = cuda.get_array_module(*x)
        x = x[0]
        gy = gy[0]

        if x.size == 0:
            # An empty input has an empty gradient. Returning early also
            # avoids dividing by a zero ``n_reduced_elements`` below.
            return xp.zeros_like(x),

        if self.axis is None:
            axes = list(six.moves.range(x.ndim))
        else:
            # Normalise negative axes to their non-negative equivalents.
            axes = [a + x.ndim if a < 0 else a for a in self.axis]

        if not self.keepdims:
            # Re-insert the reduced axes (ascending order keeps positions
            # valid) so that ``gy`` broadcasts against ``x``.
            for axis in sorted(axes):
                gy = xp.expand_dims(gy, axis=axis)

        axes = tuple(axes)
        # indices of axes that are not reduced
        axes_kept = tuple(
            a for a in six.moves.range(x.ndim) if a not in axes)

        n_reduced_elements = 1
        for axis in axes:
            n_reduced_elements *= x.shape[axis]
        n_output_elements = x.size // n_reduced_elements

        # Move the reduced axes to the end and flatten to 2-D: one row per
        # output element, one column per element reduced into it.
        transpose_axes = axes_kept + axes
        x = x.transpose(transpose_axes)
        transposed_shape = x.shape
        x = x.reshape(-1, n_reduced_elements)

        # Row ``i * n_reduced_elements + j`` is the j-th copy of group i;
        # ``repeat`` allocates a new array, so the input is not modified.
        unrolled_x = xp.repeat(x, n_reduced_elements, 0)
        # In the j-th copy of each group, replace the j-th factor by 1 ...
        mask = xp.tile(xp.arange(n_reduced_elements), n_output_elements)
        unrolled_x[xp.arange(x.size), mask] = 1
        # ... so that the row product is the product of all other factors.
        dydx = unrolled_x.prod(1)

        # Undo the flattening and the transposition.
        dydx = dydx.reshape(transposed_shape)
        dydx = dydx.transpose(numpy.argsort(transpose_axes))

        gx = dydx * gy
        return gx,
def prod(x, axis=None, keepdims=False):
    """Compute the product of array elements along the given axes.

    Args:
        x (~chainer.Variable): Elements whose product is calculated.
        axis (None, int, or tuple of int): Axis or axes along which the
            product is taken. The default (``None``) takes the product
            over all the dimensions of the input array.
        keepdims (bool): If ``True``, the reduced axes are retained as
            axes of length one.

    Returns:
        ~chainer.Variable: Output variable.

    """
    reducer = Prod(axis=axis, keepdims=keepdims)
    return reducer(x)
if __name__ == '__main__':
    # Manual demo: chain three product reductions over a small array and
    # print the gradient that reaches the input.
    import chainer

    backend = numpy
    data = backend.arange(6, dtype=backend.float32).reshape(1, 2, 3)
    inp = chainer.Variable(data)

    # Reduce axes 1 and 2 but keep them as length-one axes ...
    partial = prod(inp, (1, -1), keepdims=True)
    # ... then reduce the last axis, and finally every remaining axis.
    squeezed = prod(partial, -1)
    total = prod(squeezed)

    # Seed the backward pass with ones shaped like the final output.
    total.grad = backend.ones_like(total.data)
    total.backward(True)
    print(inp.grad)