Skip to content

Commit 5409286

Browse files
authored
op: switch & where — verification completed (#55)
* op: switch & where — pending verification * op: switch & where — verification completed
1 parent b2d66fd commit 5409286

12 files changed

Lines changed: 181 additions & 42 deletions

File tree

doc/excuter/op-mem-cuda/list.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
| Operation | Author | Math Formula | IR Instruction |
5151
|-----------|--------|--------------|----------------|
52-
| switch | miaobyte | C=switch(tensors,cases) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> result) |
52+
| switch | miaobyte | C=switch(tensors,cases) | switch(listtensor<any> tensors, tensor<int32|bool> cases)->(tensor<any> result) |
5353
| greaterscalar | miaobyte | mask=compare(T1, scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
5454
| notequal | miaobyte | T1!=T2->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
5555
| equalscalar | miaobyte | T1==scalar->mask | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
| Operation | Author | Math Formula | IR Instruction |
5252
|-----------|--------|--------------|----------------|
53-
| switch | miaobyte | C=switch([tensors],case) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> C) |
53+
| switch | miaobyte | C=switch([tensors],case) | switch(listtensor<any> tensors, tensor<int32|bool> cases)->(tensor<any> C) |
5454
| greaterscalar | miaobyte | mask=greater(T1,scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
5555
| notequal | miaobyte | notequal(T1,T2)->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
5656
| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
@@ -63,7 +63,7 @@
6363
| lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
6464
| notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
6565
| minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
66-
| rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
66+
| rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var<float32> scalar, tensor<any> A)->(tensor<any> C) |
6767
| rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
6868
| less | miaobyte | mask=less(T1,T2) | less(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
6969
| powscalar | miaobyte | T3=T1^scalar | powscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |

excuter/op-mem-cuda/src/client/tfs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ namespace deepx::tf
486486
tffactory.add_tf(std::make_shared<Switch<miaobyte>>(vector<Param>(
487487
{
488488
Param("tensors", DataCategory::ListTensor, Precision::Any),
489-
Param("cases", DataCategory::Tensor, Precision::Int8),
489+
Param("cases", DataCategory::Tensor, Precision::Int32|Precision::Bool),
490490
}),
491491
vector<Param>(
492492
{

excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -524,15 +524,25 @@ namespace deepx::tensorfunc
524524
}
525525
}
526526

527-
template void launch_switch<double,int8_t>(const double **tensorsdata, const int numTensors, const int8_t *cases, double *C, const int size);
528-
template void launch_switch<float,int8_t>(const float **tensorsdata, const int numTensors, const int8_t *cases, float *C, const int size);
529-
template void launch_switch<nv_bfloat16,int8_t>(const nv_bfloat16 **tensorsdata, const int numTensors, const int8_t *cases, nv_bfloat16 *C, const int size);
530-
template void launch_switch<__half,int8_t>(const __half **tensorsdata, const int numTensors, const int8_t *cases, __half *C, const int size);
531-
template void launch_switch<int64_t,int8_t>(const int64_t **tensorsdata, const int numTensors, const int8_t *cases, int64_t *C, const int size);
532-
template void launch_switch<int32_t,int8_t>(const int32_t **tensorsdata, const int numTensors, const int8_t *cases, int32_t *C, const int size);
533-
template void launch_switch<int16_t,int8_t>(const int16_t **tensorsdata, const int numTensors, const int8_t *cases, int16_t *C, const int size);
534-
template void launch_switch<int8_t,int8_t>(const int8_t **tensorsdata, const int numTensors, const int8_t *cases, int8_t *C, const int size);
535-
template void launch_switch<bool,int8_t>(const bool **tensorsdata, const int numTensors, const int8_t *cases, bool *C, const int size);
527+
template void launch_switch<double,int32_t>(const double **tensorsdata, const int numTensors, const int32_t *cases, double *C, const int size);
528+
template void launch_switch<float,int32_t>(const float **tensorsdata, const int numTensors, const int32_t *cases, float *C, const int size);
529+
template void launch_switch<nv_bfloat16,int32_t>(const nv_bfloat16 **tensorsdata, const int numTensors, const int32_t *cases, nv_bfloat16 *C, const int size);
530+
template void launch_switch<__half,int32_t>(const __half **tensorsdata, const int numTensors, const int32_t *cases, __half *C, const int size);
531+
template void launch_switch<int64_t,int32_t>(const int64_t **tensorsdata, const int numTensors, const int32_t *cases, int64_t *C, const int size);
532+
template void launch_switch<int32_t,int32_t>(const int32_t **tensorsdata, const int numTensors, const int32_t *cases, int32_t *C, const int size);
533+
template void launch_switch<int16_t,int32_t>(const int16_t **tensorsdata, const int numTensors, const int32_t *cases, int16_t *C, const int size);
534+
template void launch_switch<int8_t,int32_t>(const int8_t **tensorsdata, const int numTensors, const int32_t *cases, int8_t *C, const int size);
535+
template void launch_switch<bool,int32_t>(const bool **tensorsdata, const int numTensors, const int32_t *cases, bool *C, const int size);
536+
537+
template void launch_switch<double,bool>(const double **tensorsdata, const int numTensors, const bool *cases, double *C, const int size);
538+
template void launch_switch<float,bool>(const float **tensorsdata, const int numTensors, const bool *cases, float *C, const int size);
539+
template void launch_switch<nv_bfloat16,bool>(const nv_bfloat16 **tensorsdata, const int numTensors, const bool *cases, nv_bfloat16 *C, const int size);
540+
template void launch_switch<__half,bool>(const __half **tensorsdata, const int numTensors, const bool *cases, __half *C, const int size);
541+
template void launch_switch<int64_t,bool>(const int64_t **tensorsdata, const int numTensors, const bool *cases, int64_t *C, const int size);
542+
template void launch_switch<int32_t,bool>(const int32_t **tensorsdata, const int numTensors, const bool *cases, int32_t *C, const int size);
543+
template void launch_switch<int16_t,bool>(const int16_t **tensorsdata, const int numTensors, const bool *cases, int16_t *C, const int size);
544+
template void launch_switch<int8_t,bool>(const int8_t **tensorsdata, const int numTensors, const bool *cases, int8_t *C, const int size);
545+
template void launch_switch<bool,bool>(const bool **tensorsdata, const int numTensors, const bool *cases, bool *C, const int size);
536546

537547
}
538548
#endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU

excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -835,35 +835,99 @@ namespace deepx::tf
835835
{
836836

837837
Precision C_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
838-
838+
Precision cases_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
839+
839840
switch (C_type)
840841
{
841842
case Precision::Float64:
842-
tensorfunc::Switch<Author, double>(mem->gettensors<double>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
843-
break;
844-
case Precision::Float32:
845-
tensorfunc::Switch<Author, float>(mem->gettensors<float>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
843+
if (cases_type == Precision::Bool)
844+
{
845+
tensorfunc::Switch<Author, double,bool>(mem->gettensors<double>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
846+
}
847+
else
848+
{
849+
tensorfunc::Switch<Author, double,int32_t>(mem->gettensors<double>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
850+
}
851+
break;
852+
case Precision::Float32:
853+
if (cases_type == Precision::Bool)
854+
{
855+
tensorfunc::Switch<Author, float,bool>(mem->gettensors<float>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
856+
}
857+
else
858+
{
859+
tensorfunc::Switch<Author, float,int32_t>(mem->gettensors<float>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
860+
}
846861
break;
847862
case Precision::Float16:
848-
tensorfunc::Switch<Author, half>(mem->gettensors<half>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<half>(this->returns[0].textvalue));
863+
if (cases_type == Precision::Bool)
864+
{
865+
tensorfunc::Switch<Author, half,bool>(mem->gettensors<half>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<half>(this->returns[0].textvalue));
866+
}
867+
else
868+
{
869+
tensorfunc::Switch<Author, half,int32_t>(mem->gettensors<half>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<half>(this->returns[0].textvalue));
870+
}
849871
break;
850872
case Precision::BFloat16:
851-
tensorfunc::Switch<Author, nv_bfloat16>(mem->gettensors<nv_bfloat16>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
873+
if (cases_type == Precision::Bool)
874+
{
875+
tensorfunc::Switch<Author, nv_bfloat16,bool>(mem->gettensors<nv_bfloat16>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
876+
}
877+
else
878+
{
879+
tensorfunc::Switch<Author, nv_bfloat16,int32_t>(mem->gettensors<nv_bfloat16>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
880+
}
852881
break;
853882
case Precision::Int64:
854-
tensorfunc::Switch<Author, int64_t>(mem->gettensors<int64_t>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
883+
if (cases_type == Precision::Bool)
884+
{
885+
tensorfunc::Switch<Author, int64_t,bool>(mem->gettensors<int64_t>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
886+
}
887+
else
888+
{
889+
tensorfunc::Switch<Author, int64_t,int32_t>(mem->gettensors<int64_t>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
890+
}
855891
break;
856892
case Precision::Int32:
857-
tensorfunc::Switch<Author, int32_t>(mem->gettensors<int32_t>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
893+
if (cases_type == Precision::Bool)
894+
{
895+
tensorfunc::Switch<Author, int32_t,bool>(mem->gettensors<int32_t>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
896+
}
897+
else
898+
{
899+
tensorfunc::Switch<Author, int32_t,int32_t>(mem->gettensors<int32_t>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
900+
}
858901
break;
859902
case Precision::Int16:
860-
tensorfunc::Switch<Author, int16_t>(mem->gettensors<int16_t>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
903+
if (cases_type == Precision::Bool)
904+
{
905+
tensorfunc::Switch<Author, int16_t,bool>(mem->gettensors<int16_t>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
906+
}
907+
else
908+
{
909+
tensorfunc::Switch<Author, int16_t,int32_t>(mem->gettensors<int16_t>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
910+
}
861911
break;
862912
case Precision::Int8:
863-
tensorfunc::Switch<Author, int8_t>(mem->gettensors<int8_t>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
913+
if (cases_type == Precision::Bool)
914+
{
915+
tensorfunc::Switch<Author, int8_t,bool>(mem->gettensors<int8_t>(this->getvector<string>(0)), *mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
916+
}
917+
else
918+
{
919+
tensorfunc::Switch<Author, int8_t,int32_t>(mem->gettensors<int8_t>(this->getvector<string>(0)), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
920+
}
864921
break;
865922
case Precision::Bool:
866-
tensorfunc::Switch<Author, bool>(mem->gettensors<bool>(this->getvector<string>(0)), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<bool>(this->returns[0].textvalue));
923+
if (cases_type == Precision::Bool)
924+
{
925+
tensorfunc::Switch<Author, bool,bool>(mem->gettensors<bool>(this->getvector<string>(0)),*mem->gettensor<bool>(this->args[1].textvalue), *mem->gettensor<bool>(this->returns[0].textvalue));
926+
}
927+
else
928+
{
929+
tensorfunc::Switch<Author, bool,int32_t>(mem->gettensors<bool>(this->getvector<string>(0)),*mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<bool>(this->returns[0].textvalue));
930+
}
867931
break;
868932
default:
869933
error = "Unsupported type: " + precision_str(C_type);

excuter/op-mem-ompsimd/src/client/tfs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ namespace deepx::tf
492492
tffactory.add_tf(std::make_shared<Switch<miaobyte>>(vector<Param>(
493493
{
494494
Param("tensors", DataCategory::ListTensor, Precision::Any),
495-
Param("cases", DataCategory::Tensor, Precision::Int8),
495+
Param("cases", DataCategory::Tensor, Precision::Bool|Precision::Int32),
496496
}),
497497
vector<Param>(
498498
{

excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,8 +995,8 @@ namespace deepx::tensorfunc
995995
{
996996
for (int j = 0; j < i_end; j++)
997997
{
998-
int which_tensor=cases.data[i];
999-
C.data[i+j]=tensors[which_tensor]->data[i];
998+
int which_tensor=cases.data[i+j];
999+
C.data[i+j]=tensors[which_tensor]->data[i+j];
10001000
} });
10011001
}
10021002
else

0 commit comments

Comments
 (0)