@@ -29,7 +29,7 @@ def _A_b_elementwiseop_C(
2929 if a .graph .eager :
3030 varir = DeepxIR ("argset" , a .dtype , [b ], [varnode .name ])
3131 send (str (varir ))
32- ir = DeepxIR (op + "_scalar" , a .dtype , [a .node .name ,varnode .name ], [out .node .name ])
32+ ir = DeepxIR (op , a .dtype , [a .node .name ,varnode .name ], [out .node .name ])
3333 send (str (ir ))
3434#add
3535OpNode .register ("add" )
@@ -42,7 +42,7 @@ def add(
4242 if isinstance (b ,Tensor ):
4343 _A_B_elementwiseop_C (a ,b ,"add" ,out )
4444 else :
45- _A_b_elementwiseop_C (a ,b ,"add " ,out )
45+ _A_b_elementwiseop_C (a ,b ,"add_scalar " ,out )
4646
4747
4848#sub
@@ -56,7 +56,7 @@ def sub(
5656 if isinstance (b ,Tensor ):
5757 _A_B_elementwiseop_C (a ,b ,"sub" ,out )
5858 else :
59- _A_b_elementwiseop_C (a ,b ,"sub " ,out )
59+ _A_b_elementwiseop_C (a ,b ,"sub_scalar " ,out )
6060
6161
6262#mul
@@ -70,7 +70,7 @@ def mul(
7070 if isinstance (b ,Tensor ):
7171 _A_B_elementwiseop_C (a ,b ,"mul" ,out )
7272 else :
73- _A_b_elementwiseop_C (a ,b ,"mul " ,out )
73+ _A_b_elementwiseop_C (a ,b ,"mul_scalar " ,out )
7474
7575
7676#div
@@ -84,10 +84,28 @@ def div(
8484 if isinstance (b ,Tensor ):
8585 _A_B_elementwiseop_C (a ,b ,"div" ,out )
8686 else :
87- _A_b_elementwiseop_C (a ,b ,"div " ,out )
87+ _A_b_elementwiseop_C (a ,b ,"div_scalar " ,out )
8888
89-
90-
89+
90+ #clamp
91+ OpNode .register ("clamp" )
92+ def clamp (
93+ a :Tensor ,
94+ min : Optional [Union [ float , int ]] = None ,
95+ max : Optional [Union [ float , int ]] = None ,
96+ out :Tensor = None ):
97+ opnode = a .graph .add_op ("clamp" )
98+ opnode .add_input (a .node )
99+ if min is not None :
100+ min_node = a .graph .add_var ("" , min )
101+ opnode .add_input (min_node )
102+ if max is not None :
103+ max_node = a .graph .add_var ("" , max )
104+ opnode .add_input (max_node )
105+ out .node .add_input (opnode )
106+ if a .graph .eager :
107+ varir = DeepxIR ("clamp" , a .dtype , [a .node .name ,min ,max ], [out .node .name ])
108+ send (str (varir ))
91109
92110# OpNode.register("ReLU", 101)
93111# OpNode.register("Placeholder", 102)
@@ -98,22 +116,7 @@ def div(
98116# NodeType.register("Tanh", 107)
99117# NodeType.register("Reshape", 108)
100118# NodeType.register("Transpose", 109)
101- # NodeType.register("Sum", 110)
102- # NodeType.register("Mean", 111)
103-
104- # # 操作节点创建函数
105- # def matmul(a, b, name=None):
106- # node = OpNode("MatMul", name)
107- # node.add_input("a", a)
108- # node.add_input("b", b)
109- # return node
110-
111- # def add(a, b, name=None):
112- # node = OpNode("Add", name)
113- # node.add_input("a", a)
114- # node.add_input("b", b)
115- # return node
116-
119+
117120# def relu(x, name=None):
118121# node = OpNode("ReLU", name)
119122# node.add_input("x", x)
@@ -129,25 +132,7 @@ def div(
129132# node = OpNode("Neg")
130133# node.add_input("x", x)
131134# return node
132-
133- # def mul(a, b):
134- # node = OpNode("Mul")
135- # node.add_input("a", a)
136- # node.add_input("b", b)
137- # return node
138-
139- # def div(a, b):
140- # node = OpNode("Div")
141- # node.add_input("a", a)
142- # node.add_input("b", b)
143- # return node
144-
145- # def sub(a, b):
146- # node = OpNode("Sub")
147- # node.add_input("a", a)
148- # node.add_input("b", b)
149- # return node
150-
135+
151136# def less(a, b):
152137# node = OpNode("Less")
153138# node.add_input("a", a)
@@ -182,17 +167,4 @@ def div(
182167# node.set_attr("dim0", dim0)
183168# node.set_attr("dim1", dim1)
184169# return node
185-
186- # def sum(x, dim=None, keepdim=False):
187- # node = OpNode("Sum")
188- # node.add_input("x", x)
189- # node.set_attr("dim", dim)
190- # node.set_attr("keepdim", keepdim)
191- # return node
192-
193- # def mean(x, dim=None, keepdim=False):
194- # node = OpNode("Mean")
195- # node.add_input("x", x)
196- # node.set_attr("dim", dim)
197- # node.set_attr("keepdim", keepdim)
198- # return node
170+
0 commit comments