Skip to content

Commit e6cd0cb

Browse files
authored
add template for dtype transfer (#47)
1 parent 8afb4ac commit e6cd0cb

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

excuter/cpp-common/src/deepx/dtype.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,5 +475,43 @@ namespace deepx
475475
}
476476
}
477477

478+
template <Precision P>
479+
struct PrecisionWrapper {};
480+
481+
template <typename PrecisionWrapper>
482+
struct to_tensor_type;
483+
484+
template <>
485+
struct to_tensor_type<PrecisionWrapper<Precision::Float64>> {
486+
using type = double;
487+
};
488+
489+
template <>
490+
struct to_tensor_type<PrecisionWrapper<Precision::Float32>> {
491+
using type = float;
492+
};
493+
494+
template <>
495+
struct to_tensor_type<PrecisionWrapper<Precision::Int64>> {
496+
using type = int64_t;
497+
};
498+
499+
template <>
500+
struct to_tensor_type<PrecisionWrapper<Precision::Int32>> {
501+
using type = int32_t;
502+
};
503+
504+
template <>
505+
struct to_tensor_type<PrecisionWrapper<Precision::Int16>> {
506+
using type = int16_t;
507+
};
508+
509+
template <>
510+
struct to_tensor_type<PrecisionWrapper<Precision::Int8>> {
511+
using type = int8_t;
512+
};
513+
514+
template <Precision p>
515+
using tensor_t = typename to_tensor_type<PrecisionWrapper<p>>::type;
478516
} // namespace deepx
479517
#endif

excuter/cpp-common/test/0_dtypes.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ using namespace std;
66
using namespace deepx::tf;
77
using namespace deepx;
88

9-
int main(int argc, char **argv)
10-
{
11-
9+
void test_1() {
1210
unordered_map<string, TypeDef> dtype_map = {
1311
{"tensor<any>", make_dtype(DataCategory::Tensor, Precision::Any)},
1412
{"tensor<int>", make_dtype(DataCategory::Tensor, Precision::Int)},
@@ -54,6 +52,26 @@ int main(int argc, char **argv)
5452
}
5553

5654
cout << string(80, '=') << endl;
55+
}
56+
57+
// test to tensor type
58+
void test_2() {
59+
if (typeid(tensor_t<Precision::Float64>)== typeid(double)) {
60+
std::cout<<"it's ok"<<std::endl;
61+
} else {
62+
std::cout<<"it's wrong"<<std::endl;
63+
}
5764

65+
if (typeid(tensor_t<Precision::Float32>)== typeid(float)) {
66+
std::cout<<"it's ok"<<std::endl;
67+
} else {
68+
std::cout<<"it's wrong"<<std::endl;
69+
}
70+
}
71+
72+
int main(int argc, char **argv)
73+
{
74+
// test_1();
75+
test_2();
5876
return 0;
5977
}

0 commit comments

Comments
 (0)