
Commit e993df9

compare kernel for fp8 (#98)
1 parent 8d3939c commit e993df9

File tree: 3 files changed (+169, −26 lines)


executor/op-mem-cuda/src/deepx/dtype_cuda.hpp

Lines changed: 40 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
+#include <type_traits>
 
 #include "deepx/dtype.hpp"
 
@@ -56,6 +57,45 @@ namespace deepx
     struct to_tensor_type<PrecisionWrapper<Precision::Float8E4M3>> {
         using type = __nv_fp8_e4m3;
     };
+
+
+    template <typename T>
+    struct fp8_format_map;
+
+    template <>
+    struct fp8_format_map<__nv_fp8_e5m2> {
+        static constexpr __nv_fp8_interpretation_t value = __NV_E5M2;
+    };
+
+    template <>
+    struct fp8_format_map<__nv_fp8_e4m3> {
+        static constexpr __nv_fp8_interpretation_t value = __NV_E4M3;
+    };
+
+    template <typename T>
+    struct is_fp8 : std::false_type {}; // false by default
+
+    template <> struct is_fp8<__nv_fp8_e4m3> : std::true_type {};
+    template <> struct is_fp8<__nv_fp8_e5m2> : std::true_type {};
+
+    template <typename T>
+    inline constexpr bool is_fp8_v = is_fp8<T>::value;
+
+    template <typename T>
+    struct to_half {
+        static __host__ __device__ __half convert(T a) {
+            return __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(a), fp8_format_map<T>::value);
+        }
+    };
+
+    template <typename T>
+    struct to_fp8 {
+        static __host__ __device__ T convert(half a) {
+            return static_cast<T>(__nv_cvt_halfraw_to_fp8(a, __NV_SATFINITE, fp8_format_map<T>::value));
+        }
+    };
 }
 
 #endif // DEEPX_DTYPE_CUDA_HPP
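
These traits give kernels a uniform way to route fp8 math through half precision: `to_half<T>` widens either fp8 format to `__half`, `to_fp8<T>` narrows back with saturation, and `is_fp8_v<T>` gates the overloads. A minimal sketch of the intended usage — `fp8_relu_kernel` is a hypothetical example, not part of this commit:

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include "deepx/dtype_cuda.hpp"

// Hypothetical elementwise ReLU over either fp8 format: widen to half
// for the arithmetic, narrow back with saturation on the way out.
template <typename T, std::enable_if_t<deepx::is_fp8_v<T>, int> = 0>
__global__ void fp8_relu_kernel(const T *in, T *out, const int size)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size)
    {
        __half h = deepx::to_half<T>::convert(in[idx]);             // fp8 -> half
        __half zero = __float2half(0.0f);
        out[idx] = deepx::to_fp8<T>::convert(h > zero ? h : zero);  // half -> fp8
    }
}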

executor/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh

Lines changed: 6 additions & 17 deletions
@@ -6,6 +6,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cublas_v2.h>
+#include "deepx/dtype_cuda.hpp"
 
 namespace deepx::tensorfunc
 {
@@ -38,24 +39,12 @@ namespace deepx::tensorfunc
         *out = hsqrt(*a);
     }
 
-    template <>
-    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
+    template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __device__ __forceinline__ void deepx_sqrt(const T *a, T *out)
     {
-        __half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E4M3);
-        __half result_fp16 = hsqrt(input_fp16); // CUDA built-in half-precision square root
-        *out = static_cast<__nv_fp8_e4m3>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E4M3));
-    }
-
-    template <>
-    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
-    {
-        __half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E5M2);
-
-        // 2. take the square root
-        __half result_fp16 = hsqrt(input_fp16);
-
-        // 3. convert back to FP8 (E5M2 format)
-        *out = static_cast<__nv_fp8_e5m2>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E5M2));
+        __half input_half = to_half<T>::convert(*a);
+        __half result_half = hsqrt(input_half); // CUDA built-in half-precision square root
+        *out = to_fp8<T>::convert(result_half);
     }
 
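The two explicit `deepx_sqrt` specializations collapse into one overload guarded by `std::enable_if_t<is_fp8_v<T>, int> = 0`: when `T` is not an fp8 type the substitution fails and the overload silently drops out of resolution, leaving the existing non-fp8 templates untouched. A standalone sketch of that dispatch idiom (illustrative names, not from the commit):

#include <cstdio>
#include <type_traits>

template <typename T> struct is_fp8 : std::false_type {};
struct fake_fp8 {};                                   // stand-in for __nv_fp8_e4m3/e5m2
template <> struct is_fp8<fake_fp8> : std::true_type {};
template <typename T> inline constexpr bool is_fp8_v = is_fp8<T>::value;

// Selected for every non-fp8 type.
template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
void sqrt_impl(const T *, T *) { std::puts("native path"); }

// Selected only when is_fp8_v<T> holds; otherwise SFINAE removes it.
template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
void sqrt_impl(const T *, T *) { std::puts("half round-trip path"); }

int main()
{
    float f = 1.0f;
    fake_fp8 q{};
    sqrt_impl(&f, &f); // prints "native path"
    sqrt_impl(&q, &q); // prints "half round-trip path"
}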

executor/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu

Lines changed: 123 additions & 9 deletions
@@ -1,12 +1,14 @@
 #ifndef DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
 #define DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
 
+#include <cuda_fp8.h>
 #include "deepx/tensorfunc/cuda.hpp"
 #include "deepx/tensorfunc/authors.hpp"
 #include "deepx/tensorfunc/vector_cuda.cuh"
+#include "deepx/dtype_cuda.hpp"
 namespace deepx::tensorfunc
 {
-    template <typename T>
+    template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void max_kernel(const T *A, const T *B, T *C, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -16,6 +18,20 @@ namespace deepx::tensorfunc
         }
     }
 
+    template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void max_kernel(const T *A, const T *B, T *C, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            __half temp_a = to_half<T>::convert(A[idx]);
+            __half temp_b = to_half<T>::convert(B[idx]);
+            __half temp_c = temp_a > temp_b ? temp_a : temp_b;
+            C[idx] = to_fp8<T>::convert(temp_c);
+        }
+    }
+
+
     template <typename T>
     void launch_max(const T *A, const T *B, T *C, const int size)
     {
@@ -32,8 +48,10 @@ namespace deepx::tensorfunc
     template void launch_max<int32_t>(const int32_t *A, const int32_t *B, int32_t *C, const int size);
     template void launch_max<int16_t>(const int16_t *A, const int16_t *B, int16_t *C, const int size);
     template void launch_max<int8_t>(const int8_t *A, const int8_t *B, int8_t *C, const int size);
+    template void launch_max<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, __nv_fp8_e4m3 *C, const int size);
+    template void launch_max<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, __nv_fp8_e5m2 *C, const int size);
 
-    template <typename T>
+    template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void maxscalar_kernel(const T *A, const T scalar, T *C, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -43,6 +61,19 @@ namespace deepx::tensorfunc
         }
     }
 
+    template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void maxscalar_kernel(const T *A, const T scalar, T *C, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            __half temp_a = to_half<T>::convert(A[idx]);
+            __half temp_scalar = to_half<T>::convert(scalar);
+            __half temp_c = temp_a > temp_scalar ? temp_a : temp_scalar;
+            C[idx] = to_fp8<T>::convert(temp_c);
+        }
+    }
+
     template <typename T>
     void launch_maxscalar(const T *A, const T scalar, T *C, const int size)
     {
@@ -59,8 +90,10 @@ namespace deepx::tensorfunc
     template void launch_maxscalar<int32_t>(const int32_t *A, const int32_t scalar, int32_t *C, const int size);
     template void launch_maxscalar<int16_t>(const int16_t *A, const int16_t scalar, int16_t *C, const int size);
     template void launch_maxscalar<int8_t>(const int8_t *A, const int8_t scalar, int8_t *C, const int size);
+    template void launch_maxscalar<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, __nv_fp8_e4m3 *C, const int size);
+    template void launch_maxscalar<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, __nv_fp8_e5m2 *C, const int size);
 
-    template <typename T>
+    template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void min_kernel(const T *A, const T *B, T *C, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -70,6 +103,20 @@ namespace deepx::tensorfunc
         }
     }
 
+
+    template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void min_kernel(const T *A, const T *B, T *C, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            __half temp_a = to_half<T>::convert(A[idx]);
+            __half temp_b = to_half<T>::convert(B[idx]);
+            __half temp_c = temp_a < temp_b ? temp_a : temp_b;
+            C[idx] = to_fp8<T>::convert(temp_c);
+        }
+    }
+
     template <typename T>
     void launch_min(const T *A, const T *B, T *C, const int size)
     {
@@ -86,8 +133,10 @@ namespace deepx::tensorfunc
     template void launch_min<int32_t>(const int32_t *A, const int32_t *B, int32_t *C, const int size);
     template void launch_min<int16_t>(const int16_t *A, const int16_t *B, int16_t *C, const int size);
     template void launch_min<int8_t>(const int8_t *A, const int8_t *B, int8_t *C, const int size);
+    template void launch_min<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, __nv_fp8_e4m3 *C, const int size);
+    template void launch_min<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, __nv_fp8_e5m2 *C, const int size);
 
-    template <typename T>
+    template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void minscalar_kernel(const T *A, const T scalar, T *C, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -97,6 +146,19 @@ namespace deepx::tensorfunc
         }
     }
 
+    template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void minscalar_kernel(const T *A, const T scalar, T *C, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            __half temp_a = to_half<T>::convert(A[idx]);
+            __half temp_scalar = to_half<T>::convert(scalar);
+            __half temp_c = temp_a < temp_scalar ? temp_a : temp_scalar;
+            C[idx] = to_fp8<T>::convert(temp_c);
+        }
+    }
+
     template <typename T>
     void launch_minscalar(const T *A, const T scalar, T *C, const int size)
     {
@@ -113,9 +175,11 @@ namespace deepx::tensorfunc
     template void launch_minscalar<int32_t>(const int32_t *A, const int32_t scalar, int32_t *C, const int size);
     template void launch_minscalar<int16_t>(const int16_t *A, const int16_t scalar, int16_t *C, const int size);
     template void launch_minscalar<int8_t>(const int8_t *A, const int8_t scalar, int8_t *C, const int size);
+    template void launch_minscalar<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, __nv_fp8_e4m3 *C, const int size);
+    template void launch_minscalar<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, __nv_fp8_e5m2 *C, const int size);
 
     // equal
-    template <typename T,typename MaskT>
+    template <typename T,typename MaskT, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void equalwithepsilon_kernel(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -133,7 +197,28 @@ namespace deepx::tensorfunc
         }
     }
 
-    template <typename T,typename MaskT>
+    // equal
+    template <typename T, typename MaskT, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void equalwithepsilon_kernel(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            float diff = fabsf(static_cast<float>(to_half<T>::convert(A[idx])) - static_cast<float>(to_half<T>::convert(B[idx])));
+            if (diff < epsilon)
+            {
+                mask[idx] = 1;
+            }
+            else
+            {
+                mask[idx] = 0;
+            }
+        }
+    }
+
+
+
+    template <typename T,typename MaskT, std::enable_if_t<!is_fp8_v<T>, int> = 0>
     __global__ void equal_kernel(const T *A, const T *B, MaskT *mask, const int size)
     {
         int stride = blockDim.x * gridDim.x;
@@ -143,6 +228,16 @@ namespace deepx::tensorfunc
         }
     }
 
+    template <typename T,typename MaskT, std::enable_if_t<is_fp8_v<T>, int> = 0>
+    __global__ void equal_kernel(const T *A, const T *B, MaskT *mask, const int size)
+    {
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            mask[idx] = (to_half<T>::convert(A[idx]) == to_half<T>::convert(B[idx]));
+        }
+    }
+
     template <typename T,typename MaskT>
     void launch_equal(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
     {
@@ -166,6 +261,8 @@ namespace deepx::tensorfunc
     template void launch_equal<int32_t,bool>(const int32_t *A, const int32_t *B, const float epsilon, bool *mask, const int size);
     template void launch_equal<int16_t,bool>(const int16_t *A, const int16_t *B, const float epsilon, bool *mask, const int size);
     template void launch_equal<int8_t,bool>(const int8_t *A, const int8_t *B, const float epsilon, bool *mask, const int size);
+    template void launch_equal<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, const float epsilon, bool *mask, const int size);
+    template void launch_equal<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, const float epsilon, bool *mask, const int size);
 
     // equalscalar
     template <typename T,typename MaskT>
@@ -219,6 +316,8 @@ namespace deepx::tensorfunc
     template void launch_equalscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, const float epsilon, bool *mask, const int size);
     template void launch_equalscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size);
     template void launch_equalscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, const float epsilon, bool *mask, const int size);
+    // template void launch_equalscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, const float epsilon, bool *mask, const int size);
+    // template void launch_equalscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, const float epsilon, bool *mask, const int size);
 
     // not equal
     template <typename T,typename MaskT>
@@ -272,6 +371,8 @@ namespace deepx::tensorfunc
     template void launch_notequal<int32_t,bool>(const int32_t *A, const int32_t *B, const float epsilon, bool *mask, const int size);
     template void launch_notequal<int16_t,bool>(const int16_t *A, const int16_t *B, const float epsilon, bool *mask, const int size);
     template void launch_notequal<int8_t,bool>(const int8_t *A, const int8_t *B, const float epsilon, bool *mask, const int size);
+    // template void launch_notequal<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, const float epsilon, bool *mask, const int size);
+    // template void launch_notequal<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, const float epsilon, bool *mask, const int size);
 
     // notequalscalar
     template <typename T,typename MaskT>
@@ -325,6 +426,8 @@ namespace deepx::tensorfunc
     template void launch_notequalscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, const float epsilon, bool *mask, const int size);
     template void launch_notequalscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size);
    template void launch_notequalscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, const float epsilon, bool *mask, const int size);
+    // template void launch_notequalscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, const float epsilon, bool *mask, const int size);
+    // template void launch_notequalscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, const float epsilon, bool *mask, const int size);
 
     // less
     template <typename T,typename MaskT>
@@ -353,6 +456,8 @@ namespace deepx::tensorfunc
     template void launch_less<int32_t,bool>(const int32_t *A, const int32_t *B, bool *mask, const int size);
     template void launch_less<int16_t,bool>(const int16_t *A, const int16_t *B, bool *mask, const int size);
     template void launch_less<int8_t,bool>(const int8_t *A, const int8_t *B, bool *mask, const int size);
+    // template void launch_less<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, bool *mask, const int size);
+    // template void launch_less<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, bool *mask, const int size);
 
     // lessscalar
 
@@ -382,7 +487,9 @@ namespace deepx::tensorfunc
     template void launch_lessscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, bool *mask, const int size);
     template void launch_lessscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, bool *mask, const int size);
     template void launch_lessscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, bool *mask, const int size);
-
+    // template void launch_lessscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, bool *mask, const int size);
+    // template void launch_lessscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, bool *mask, const int size);
+
     // greater
     template <typename T,typename MaskT>
     __global__ void greater_kernel(const T *A, const T *B, MaskT *mask, const int size)
@@ -410,6 +517,8 @@ namespace deepx::tensorfunc
     template void launch_greater<int32_t,bool>(const int32_t *A, const int32_t *B, bool *mask, const int size);
     template void launch_greater<int16_t,bool>(const int16_t *A, const int16_t *B, bool *mask, const int size);
     template void launch_greater<int8_t,bool>(const int8_t *A, const int8_t *B, bool *mask, const int size);
+    // template void launch_greater<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, bool *mask, const int size);
+    // template void launch_greater<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, bool *mask, const int size);
 
     // greaterscalar
     template <typename T,typename MaskT>
@@ -438,6 +547,8 @@ namespace deepx::tensorfunc
     template void launch_greaterscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, bool *mask, const int size);
     template void launch_greaterscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, bool *mask, const int size);
     template void launch_greaterscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, bool *mask, const int size);
+    // template void launch_greaterscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, bool *mask, const int size);
+    // template void launch_greaterscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, bool *mask, const int size);
 
     // switch
     template <typename T,typename casesT>
@@ -476,7 +587,9 @@ namespace deepx::tensorfunc
     template void launch_switch<int16_t,int32_t>(const int16_t **tensorsdata, const int numTensors, const int32_t *cases, int16_t *C, const int size);
     template void launch_switch<int8_t,int32_t>(const int8_t **tensorsdata, const int numTensors, const int32_t *cases, int8_t *C, const int size);
     template void launch_switch<bool,int32_t>(const bool **tensorsdata, const int numTensors, const int32_t *cases, bool *C, const int size);
-
+    // template void launch_switch<__nv_fp8_e4m3,int32_t>(const __nv_fp8_e4m3 **tensorsdata, const int numTensors, const int32_t *cases, __nv_fp8_e4m3 *C, const int size);
+    // template void launch_switch<__nv_fp8_e5m2,int32_t>(const __nv_fp8_e5m2 **tensorsdata, const int numTensors, const int32_t *cases, __nv_fp8_e5m2 *C, const int size);
+
     template void launch_switch<double,bool>(const double **tensorsdata, const int numTensors, const bool *cases, double *C, const int size);
     template void launch_switch<float,bool>(const float **tensorsdata, const int numTensors, const bool *cases, float *C, const int size);
     template void launch_switch<nv_bfloat16,bool>(const nv_bfloat16 **tensorsdata, const int numTensors, const bool *cases, nv_bfloat16 *C, const int size);
@@ -486,6 +599,7 @@ namespace deepx::tensorfunc
     template void launch_switch<int16_t,bool>(const int16_t **tensorsdata, const int numTensors, const bool *cases, int16_t *C, const int size);
     template void launch_switch<int8_t,bool>(const int8_t **tensorsdata, const int numTensors, const bool *cases, int8_t *C, const int size);
     template void launch_switch<bool,bool>(const bool **tensorsdata, const int numTensors, const bool *cases, bool *C, const int size);
-
+    // template void launch_switch<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 **tensorsdata, const int numTensors, const bool *cases, __nv_fp8_e4m3 *C, const int size);
+    // template void launch_switch<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 **tensorsdata, const int numTensors, const bool *cases, __nv_fp8_e5m2 *C, const int size);
 }
 #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
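
For the launchers that did gain fp8 instantiations here (max/min, their scalar variants, and launch_equal), host code can call them directly once the buffers hold fp8 data. A minimal sketch — the allocation and fill logic is assumed for illustration, only `launch_max` itself comes from this commit:

#include <cuda_fp8.h>
#include <cuda_runtime.h>

namespace deepx::tensorfunc
{
    // Declaration only; the explicit instantiation added in this commit provides the symbol.
    template <typename T> void launch_max(const T *A, const T *B, T *C, const int size);
}

int main()
{
    const int size = 1024;
    __nv_fp8_e4m3 *A, *B, *C;
    cudaMalloc(&A, size * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&B, size * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&C, size * sizeof(__nv_fp8_e4m3));
    // ... fill A and B, e.g. from floats via __nv_cvt_float_to_fp8 ...
    deepx::tensorfunc::launch_max<__nv_fp8_e4m3>(A, B, C, size); // elementwise max in fp8
    cudaDeviceSynchronize();
    cudaFree(A); cudaFree(B); cudaFree(C);
    return 0;
}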
