11#ifndef DEEPX_TF_MATMUL_HPP
22#define DEEPX_TF_MATMUL_HPP
3-
3+
44#include " deepx/tf/tf.hpp"
55#include " deepx/dtype.hpp"
66#include " deepx/dtype_ompsimd.hpp"
@@ -21,7 +21,7 @@ namespace deepx::tf
2121 this ->args = args;
2222 this ->returns = returns;
2323 }
24-
24+
2525 string math_formula () const override
2626 {
2727 return " T3=T1 @ T2" ;
@@ -30,7 +30,17 @@ namespace deepx::tf
3030 {
3131 return make_shared<MatMul<Author>>(*this );
3232 }
33- int compute (shared_ptr<MemBase> mem, Precision a_type,string &error){
33+
34+ int run (shared_ptr<MemBase> mem, string &error) override
35+ {
36+ Precision a_type = mem->gettensor (this ->args [0 ].textvalue ).get ()->shape .dtype ;
37+ Precision b_type = mem->gettensor (this ->args [1 ].textvalue ).get ()->shape .dtype ;
38+ Precision c_type = mem->gettensor (this ->returns [0 ].textvalue ).get ()->shape .dtype ;
39+ if (a_type != b_type || a_type != c_type)
40+ {
41+ error = " Type mismatch: " + precision_str (a_type) + " != " + precision_str (b_type) + " != " + precision_str (c_type);
42+ return 1 ;
43+ }
3444 switch (a_type)
3545 {
3646 case Precision::Float64:
@@ -57,30 +67,6 @@ namespace deepx::tf
5767 }
5868 return 0 ;
5969 }
60- int run (shared_ptr<MemBase> mem, string &error) override
61- {
62- Precision a_type = mem->gettensor (this ->args [0 ].textvalue ).get ()->shape .dtype ;
63- Precision b_type = mem->gettensor (this ->args [1 ].textvalue ).get ()->shape .dtype ;
64- Precision c_type = mem->gettensor (this ->returns [0 ].textvalue ).get ()->shape .dtype ;
65- if (a_type != b_type || a_type != c_type)
66- {
67- error = " Type mismatch: " + precision_str (a_type) + " != " + precision_str (b_type) + " != " + precision_str (c_type);
68- return 1 ;
69- }
70- if (metadata.benchmark .repeat > 0 )
71- {
72- for (int i = 0 ; i < metadata.benchmark .repeat ; i++)
73- {
74- if (compute (mem, a_type, error))
75- {
76- return 1 ;
77- }
78- }
79- }else {
80- return compute (mem, a_type, error);
81- }
82- return 0 ;
83- }
8470 };
8571}
8672
0 commit comments