File tree Expand file tree Collapse file tree 2 files changed +59
-3
lines changed
Expand file tree Collapse file tree 2 files changed +59
-3
lines changed Original file line number Diff line number Diff line change @@ -475,5 +475,43 @@ namespace deepx
475475 }
476476 }
477477
478+ template <Precision P>
479+ struct PrecisionWrapper {};
480+
481+ template <typename PrecisionWrapper>
482+ struct to_tensor_type ;
483+
484+ template <>
485+ struct to_tensor_type <PrecisionWrapper<Precision::Float64>> {
486+ using type = double ;
487+ };
488+
489+ template <>
490+ struct to_tensor_type <PrecisionWrapper<Precision::Float32>> {
491+ using type = float ;
492+ };
493+
494+ template <>
495+ struct to_tensor_type <PrecisionWrapper<Precision::Int64>> {
496+ using type = int64_t ;
497+ };
498+
499+ template <>
500+ struct to_tensor_type <PrecisionWrapper<Precision::Int32>> {
501+ using type = int32_t ;
502+ };
503+
504+ template <>
505+ struct to_tensor_type <PrecisionWrapper<Precision::Int16>> {
506+ using type = int16_t ;
507+ };
508+
509+ template <>
510+ struct to_tensor_type <PrecisionWrapper<Precision::Int8>> {
511+ using type = int8_t ;
512+ };
513+
514+ template <Precision p>
515+ using tensor_t = typename to_tensor_type<PrecisionWrapper<p>>::type;
478516} // namespace deepx
479517#endif
Original file line number Diff line number Diff line change @@ -6,9 +6,7 @@ using namespace std;
66using namespace deepx ::tf;
77using namespace deepx ;
88
9- int main (int argc, char **argv)
10- {
11-
9+ void test_1 () {
1210 unordered_map<string, TypeDef> dtype_map = {
1311 {" tensor<any>" , make_dtype (DataCategory::Tensor, Precision::Any)},
1412 {" tensor<int>" , make_dtype (DataCategory::Tensor, Precision::Int)},
@@ -54,6 +52,26 @@ int main(int argc, char **argv)
5452 }
5553
5654 cout << string (80 , ' =' ) << endl;
55+ }
56+
57+ // test to tensor type
58+ void test_2 () {
59+ if (typeid (tensor_t <Precision::Float64>)== typeid (double )) {
60+ std::cout<<" it's ok" <<std::endl;
61+ } else {
62+ std::cout<<" it's wrong" <<std::endl;
63+ }
5764
65+ if (typeid (tensor_t <Precision::Float32>)== typeid (float )) {
66+ std::cout<<" it's ok" <<std::endl;
67+ } else {
68+ std::cout<<" it's wrong" <<std::endl;
69+ }
70+ }
71+
72+ int main (int argc, char **argv)
73+ {
74+ // test_1();
75+ test_2 ();
5876 return 0 ;
5977}
You can’t perform that action at this time.
0 commit comments