Skip to content

Commit 7d40d00

Browse files
committed
cuda dtype transfer
1 parent 73addbb commit 7d40d00

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <cuda_fp16.h>
55
#include <cuda_bf16.h>
6+
#include <cuda_fp8.h>
67

78
#include "deepx/dtype.hpp"
89

@@ -34,6 +35,27 @@ namespace deepx
3435
else
3536
return Precision::Any;
3637
}
38+
39+
40+
template <>
41+
struct to_tensor_type<PrecisionWrapper<Precision::BFloat16>> {
42+
using type = nv_bfloat16;
43+
};
44+
45+
template <>
46+
struct to_tensor_type<PrecisionWrapper<Precision::Float16>> {
47+
using type = half;
48+
};
49+
50+
template <>
51+
struct to_tensor_type<PrecisionWrapper<Precision::Float8E5M2>> {
52+
using type = __nv_fp8_e5m2;
53+
};
54+
55+
template <>
56+
struct to_tensor_type<PrecisionWrapper<Precision::Float8e4m3>> {
57+
using type = __nv_fp8_e4m3;
58+
}
3759
}
3860

3961
#endif // DEEPX_DTYPE_CUDA_HPP

0 commit comments

Comments
 (0)