We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 73addbb commit 7d40d00Copy full SHA for 7d40d00
1 file changed
excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp
@@ -3,6 +3,7 @@
3
4
#include <cuda_fp16.h>
5
#include <cuda_bf16.h>
6
+#include <cuda_fp8.h>
7
8
#include "deepx/dtype.hpp"
9
@@ -34,6 +35,27 @@ namespace deepx
34
35
else
36
return Precision::Any;
37
}
38
+
39
40
+ template <>
41
+ struct to_tensor_type<PrecisionWrapper<Precision::BFloat16>> {
42
+ using type = nv_bfloat16;
43
+ };
44
45
46
+ struct to_tensor_type<PrecisionWrapper<Precision::Float16>> {
47
+ using type = half;
48
49
50
51
+ struct to_tensor_type<PrecisionWrapper<Precision::Float8E5M2>> {
52
+ using type = __nv_fp8_e5m2;
53
54
55
56
+ struct to_tensor_type<PrecisionWrapper<Precision::Float8e4m3>> {
57
+ using type = __nv_fp8_e4m3;
58
+ }
59
60
61
#endif // DEEPX_DTYPE_CUDA_HPP
0 commit comments