Skip to content

Commit 18e0f02

Browse files
committed
nn.functional:增加
1 parent f08cc0f commit 18e0f02

10 files changed

Lines changed: 164 additions & 122 deletions

File tree

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ int main()
1919

2020
client::udpserver server(8080);
2121
deepx::op::OpFactory opfactory;
22-
deepx::op::register_all(opfactory);
22+
register_all(opfactory);
2323

2424
server.func = [&mem, &opfactory, &memmutex](const char *buffer)
2525
{

front/py/deepx/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from .tensor import Tensor,Shape,Device,DeviceType
22
from deepx.nn.functional import full,zeros,ones,arange,rand,randn,eye
3-
from deepx.nn.functional import add,sub,mul,div
3+
from deepx.nn.functional import add,sub,mul,div,clamp
44
from deepx.nn.functional import matmul
5+
from deepx.nn.functional import max,min,sum,prod,mean
56
__all__ = [
67
'Tensor',
78
'Shape',
89
'Device','DeviceType',
910
#init
1011
'full','zeros', 'ones', 'arange', 'rand', 'randn', 'eye',
1112
#elementwise
12-
"add","sub","mul","div",
13+
"add","sub","mul","div","clamp",
1314
#matmul
1415
"matmul",
16+
#reduce
17+
"max","min","sum","prod","mean",
1518
]
1619

1720
# 为了支持 import deepx as dx 的用法
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from .elementwise import add,sub,mul,div
1+
from .elementwise import add,sub,mul,div,clamp
22
from .new import newtensor
33
from .print import printtensor
44
from .matmul import matmul
55
from .init import full,zeros,ones,arange,rand,randn,eye
6+
from .reduce import max,min,sum,prod,mean
7+
68
__all__ = [
79
"newtensor",
810
"printtensor",
911
"full","zeros","ones","arange","rand","randn","eye",
10-
"add","sub","mul","div",
12+
"add","sub","mul","div","clamp",
1113
"matmul",
14+
"max","min","sum","prod","mean",
1215
]

front/py/deepx/nn/functional/elementwise.py

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _A_b_elementwiseop_C(
2929
if a.graph.eager:
3030
varir=DeepxIR("argset", a.dtype, [b], [varnode.name])
3131
send(str(varir))
32-
ir=DeepxIR(op+"_scalar", a.dtype, [a.node.name,varnode.name], [out.node.name])
32+
ir=DeepxIR(op, a.dtype, [a.node.name,varnode.name], [out.node.name])
3333
send(str(ir))
3434
#add
3535
OpNode.register("add")
@@ -42,7 +42,7 @@ def add(
4242
if isinstance(b,Tensor):
4343
_A_B_elementwiseop_C(a,b,"add",out)
4444
else:
45-
_A_b_elementwiseop_C(a,b,"add",out)
45+
_A_b_elementwiseop_C(a,b,"add_scalar",out)
4646

4747

4848
#sub
@@ -56,7 +56,7 @@ def sub(
5656
if isinstance(b,Tensor):
5757
_A_B_elementwiseop_C(a,b,"sub",out)
5858
else:
59-
_A_b_elementwiseop_C(a,b,"sub",out)
59+
_A_b_elementwiseop_C(a,b,"sub_scalar",out)
6060

6161

6262
#mul
@@ -70,7 +70,7 @@ def mul(
7070
if isinstance(b,Tensor):
7171
_A_B_elementwiseop_C(a,b,"mul",out)
7272
else:
73-
_A_b_elementwiseop_C(a,b,"mul",out)
73+
_A_b_elementwiseop_C(a,b,"mul_scalar",out)
7474

7575

7676
#div
@@ -84,10 +84,28 @@ def div(
8484
if isinstance(b,Tensor):
8585
_A_B_elementwiseop_C(a,b,"div",out)
8686
else:
87-
_A_b_elementwiseop_C(a,b,"div",out)
87+
_A_b_elementwiseop_C(a,b,"div_scalar",out)
8888

89-
90-
89+
90+
# clamp: constrain the values of `a` to the closed range [min, max].
OpNode.register("clamp")
def clamp(
        a: Tensor,
        min: Optional[Union[float, int]] = None,
        max: Optional[Union[float, int]] = None,
        out: Tensor = None):
    """Record a clamp op on a's graph, writing the result into `out`.

    Either bound may be None, in which case that side is left open and no
    var node is created for it. Param names shadow the builtins to mirror
    the torch-style API.
    """
    node = a.graph.add_op("clamp")
    node.add_input(a.node)
    # Each supplied bound becomes an anonymous var node feeding the op.
    for bound in (min, max):
        if bound is not None:
            node.add_input(a.graph.add_var("", bound))
    out.node.add_input(node)
    if a.graph.eager:
        # NOTE(review): unlike the elementwise helpers, the raw scalar
        # bounds (possibly None) are embedded directly in the IR argument
        # list rather than argset into var nodes — confirm the backend
        # accepts that form.
        ir = DeepxIR("clamp", a.dtype, [a.node.name, min, max], [out.node.name])
        send(str(ir))
91109

92110
# OpNode.register("ReLU", 101)
93111
# OpNode.register("Placeholder", 102)
@@ -98,22 +116,7 @@ def div(
98116
# NodeType.register("Tanh", 107)
99117
# NodeType.register("Reshape", 108)
100118
# NodeType.register("Transpose", 109)
101-
# NodeType.register("Sum", 110)
102-
# NodeType.register("Mean", 111)
103-
104-
# # 操作节点创建函数
105-
# def matmul(a, b, name=None):
106-
# node = OpNode("MatMul", name)
107-
# node.add_input("a", a)
108-
# node.add_input("b", b)
109-
# return node
110-
111-
# def add(a, b, name=None):
112-
# node = OpNode("Add", name)
113-
# node.add_input("a", a)
114-
# node.add_input("b", b)
115-
# return node
116-
119+
117120
# def relu(x, name=None):
118121
# node = OpNode("ReLU", name)
119122
# node.add_input("x", x)
@@ -129,25 +132,7 @@ def div(
129132
# node = OpNode("Neg")
130133
# node.add_input("x", x)
131134
# return node
132-
133-
# def mul(a, b):
134-
# node = OpNode("Mul")
135-
# node.add_input("a", a)
136-
# node.add_input("b", b)
137-
# return node
138-
139-
# def div(a, b):
140-
# node = OpNode("Div")
141-
# node.add_input("a", a)
142-
# node.add_input("b", b)
143-
# return node
144-
145-
# def sub(a, b):
146-
# node = OpNode("Sub")
147-
# node.add_input("a", a)
148-
# node.add_input("b", b)
149-
# return node
150-
135+
151136
# def less(a, b):
152137
# node = OpNode("Less")
153138
# node.add_input("a", a)
@@ -182,17 +167,4 @@ def div(
182167
# node.set_attr("dim0", dim0)
183168
# node.set_attr("dim1", dim1)
184169
# return node
185-
186-
# def sum(x, dim=None, keepdim=False):
187-
# node = OpNode("Sum")
188-
# node.add_input("x", x)
189-
# node.set_attr("dim", dim)
190-
# node.set_attr("keepdim", keepdim)
191-
# return node
192-
193-
# def mean(x, dim=None, keepdim=False):
194-
# node = OpNode("Mean")
195-
# node.add_input("x", x)
196-
# node.set_attr("dim", dim)
197-
# node.set_attr("keepdim", keepdim)
198-
# return node
170+
Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Optional, Union
22

3-
from .tensor import Tensor,tensor_method
3+
from deepx.tensor import Tensor
44
from deepx.autograd.graph import OpNode
55
from deepx.nn.deepxir import DeepxIR
66
from deepx.scheduler import send
77
from .elementwise import _A_b_elementwiseop_C
88

99
def _A_v_reduceop_C(
1010
a:Tensor,
11-
v: Optional[Union[Tensor, float, int]] = None,
11+
v: Optional[Union[list[int],tuple[int]]] = None,
1212
op:str=None,
1313
out:Tensor=None):
1414
opnode = a.graph.add_op(op)
@@ -20,56 +20,84 @@ def _A_v_reduceop_C(
2020
if a.graph.eager:
2121
varir=DeepxIR("argset", a.dtype, v, [vector_node.name])
2222
send(str(varir))
23-
ir=DeepxIR(op+"_scalar", a.dtype, [a.node.name,vector_node.name], [out.node.name])
23+
ir=DeepxIR(op, a.dtype, [a.node.name,vector_node.name], [out.node.name])
2424
send(str(ir))
2525

2626

2727
#max
OpNode.register("max")
def max(
        a: Tensor,
        b: Optional[Union[int, float, Tensor]] = None,
        dims: Optional[Union[list[int], tuple[int, ...]]] = None,
        out: Tensor = None):
    """Elementwise or reducing maximum, written into `out`.

    - b is an int/float: elementwise max of `a` against the scalar.
    - b is a Tensor:     elementwise max of `a` and `b`.
    - b is None:         reduce `a` over `dims` (all axes when dims is None).
    Shadows the builtin `max` deliberately, mirroring the torch-style API.
    """
    # BUGFIX: isinstance requires a *tuple* of types as its second argument;
    # the original isinstance(b, int, float) raised TypeError at runtime.
    if b is not None and isinstance(b, (int, float)):
        _A_b_elementwiseop_C(a, b, "max_scalar", out)
    elif b is not None and isinstance(b, Tensor):
        # NOTE(review): this routes a Tensor operand through the scalar
        # helper (_A_b_elementwiseop_C argsets `b` as if it were a scalar) —
        # confirm whether the tensor-tensor helper was intended here.
        _A_b_elementwiseop_C(a, b, "max_tensor", out)
    else:
        if dims is None:
            dims = list(range(a.ndim))
        _A_v_reduceop_C(a, dims, "max", out)
4545

4646
#min
OpNode.register("min")
def min(
        a: Tensor,
        b: Optional[Union[int, float, Tensor]] = None,
        dims: Optional[Union[list[int], tuple[int, ...]]] = None,
        out: Tensor = None):
    """Elementwise or reducing minimum, written into `out`.

    - b is an int/float: elementwise min of `a` against the scalar.
    - b is a Tensor:     elementwise min of `a` and `b`.
    - b is None:         reduce `a` over `dims` (all axes when dims is None).
    Shadows the builtin `min` deliberately, mirroring the torch-style API.
    """
    # BUGFIX: isinstance requires a *tuple* of types as its second argument;
    # the original isinstance(b, int, float) raised TypeError at runtime.
    if b is not None and isinstance(b, (int, float)):
        _A_b_elementwiseop_C(a, b, "min_scalar", out)
    elif b is not None and isinstance(b, Tensor):
        # NOTE(review): Tensor operand goes through the scalar helper
        # (_A_b_elementwiseop_C) — confirm the tensor-tensor helper was
        # not intended here.
        _A_b_elementwiseop_C(a, b, "min_tensor", out)
    else:
        if dims is None:
            dims = list(range(a.ndim))
        _A_v_reduceop_C(a, dims, "min", out)
64+
6365
#sum
6466
OpNode.register("sum")
6567
def sum(
6668
a:Tensor,
67-
b:list[int],
68-
out:Tensor):
69-
_A_v_reduceop_C(a,b,"sum",out)
69+
dims:Optional[Union[
70+
list[int],
71+
tuple[int],
72+
]]=None,
73+
out:Tensor=None):
74+
if dims is None:
75+
dims=list(range(a.ndim))
76+
_A_v_reduceop_C(a,dims,"sum",out)
77+
78+
#prod
79+
OpNode.register("prod")
80+
def prod(
81+
a:Tensor,
82+
dims:Optional[Union[
83+
list[int],
84+
tuple[int],
85+
]]=None,
86+
out:Tensor=None):
87+
if dims is None:
88+
dims=list(range(a.ndim))
89+
_A_v_reduceop_C(a,dims,"prod",out)
90+
91+
# mean: reduction by arithmetic mean over the given axes.
OpNode.register("mean")
def mean(
        a: Tensor,
        dims: Optional[Union[list[int], tuple[int, ...]]] = None,
        out: Tensor = None):
    """Average `a` over the axes in `dims` (all axes when None), into `out`."""
    reduce_dims = list(range(a.ndim)) if dims is None else dims
    _A_v_reduceop_C(a, reduce_dims, "mean", out)
100+
70101

71-
@tensor_method
72-
def sum_(self, other):
73-
result = Tensor(dtype=self.dtype,shape=self.shape)
74-
sum(self,other,result)
75-
return result
102+
# #var
103+
# OpNode.register("var")

front/py/deepx/nn/functional/reduction.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

front/py/deepx/tensor/elementwise.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,23 @@ def div_(self, other):
2828
div(self,other,result)
2929
return result
3030

31-
31+
@tensor_method
def min_scalar_(self, other):
    """Return a new Tensor holding the elementwise min of this tensor
    and the scalar `other`."""
    result = Tensor(dtype=self.dtype, shape=self.shape)
    # BUGFIX: deepx.nn.functional exports no `min_scalar` (its public
    # `min` dispatches to the scalar kernel when b is an int/float), so
    # the original import raised ImportError at call time.
    from deepx.nn.functional import min
    min(self, other, out=result)
    return result
37+
38+
@tensor_method
def max_scalar_(self, other):
    """Return a new Tensor holding the elementwise max of this tensor
    and the scalar `other`."""
    result = Tensor(dtype=self.dtype, shape=self.shape)
    # BUGFIX: deepx.nn.functional exports no `max_scalar` (its public
    # `max` dispatches to the scalar kernel when b is an int/float), so
    # the original import raised ImportError at call time.
    from deepx.nn.functional import max
    max(self, other, out=result)
    return result
44+
45+
@tensor_method
def clamp_(self, min, max):
    """Return a new Tensor with this tensor's values clamped to
    [min, max]; either bound may be None for an open side."""
    from deepx.nn.functional import clamp
    out = Tensor(dtype=self.dtype, shape=self.shape)
    clamp(self, min, max, out)
    return out

0 commit comments

Comments
 (0)