@@ -835,35 +835,99 @@ namespace deepx::tf
835835 {
836836
837837 Precision C_type = mem->gettensor (this ->returns [0 ].textvalue ).get ()->shape .dtype ;
838-
838+ Precision cases_type = mem->gettensor (this ->args [1 ].textvalue ).get ()->shape .dtype ;
839+
839840 switch (C_type)
840841 {
841842 case Precision::Float64:
842- tensorfunc::Switch<Author, double >(mem->gettensors <double >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <double >(this ->returns [0 ].textvalue ));
843- break ;
844- case Precision::Float32:
845- tensorfunc::Switch<Author, float >(mem->gettensors <float >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <float >(this ->returns [0 ].textvalue ));
843+ if (cases_type == Precision::Bool)
844+ {
845+ tensorfunc::Switch<Author, double ,bool >(mem->gettensors <double >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <double >(this ->returns [0 ].textvalue ));
846+ }
847+ else
848+ {
849+ tensorfunc::Switch<Author, double ,int32_t >(mem->gettensors <double >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <double >(this ->returns [0 ].textvalue ));
850+ }
851+ break ;
852+ case Precision::Float32:
853+ if (cases_type == Precision::Bool)
854+ {
855+ tensorfunc::Switch<Author, float ,bool >(mem->gettensors <float >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <float >(this ->returns [0 ].textvalue ));
856+ }
857+ else
858+ {
859+ tensorfunc::Switch<Author, float ,int32_t >(mem->gettensors <float >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <float >(this ->returns [0 ].textvalue ));
860+ }
846861 break ;
847862 case Precision::Float16:
848- tensorfunc::Switch<Author, half>(mem->gettensors <half>(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <half>(this ->returns [0 ].textvalue ));
863+ if (cases_type == Precision::Bool)
864+ {
865+ tensorfunc::Switch<Author, half,bool >(mem->gettensors <half>(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <half>(this ->returns [0 ].textvalue ));
866+ }
867+ else
868+ {
869+ tensorfunc::Switch<Author, half,int32_t >(mem->gettensors <half>(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <half>(this ->returns [0 ].textvalue ));
870+ }
849871 break ;
850872 case Precision::BFloat16:
851- tensorfunc::Switch<Author, nv_bfloat16>(mem->gettensors <nv_bfloat16>(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <nv_bfloat16>(this ->returns [0 ].textvalue ));
873+ if (cases_type == Precision::Bool)
874+ {
875+ tensorfunc::Switch<Author, nv_bfloat16,bool >(mem->gettensors <nv_bfloat16>(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <nv_bfloat16>(this ->returns [0 ].textvalue ));
876+ }
877+ else
878+ {
879+ tensorfunc::Switch<Author, nv_bfloat16,int32_t >(mem->gettensors <nv_bfloat16>(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <nv_bfloat16>(this ->returns [0 ].textvalue ));
880+ }
852881 break ;
853882 case Precision::Int64:
854- tensorfunc::Switch<Author, int64_t >(mem->gettensors <int64_t >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <int64_t >(this ->returns [0 ].textvalue ));
883+ if (cases_type == Precision::Bool)
884+ {
885+ tensorfunc::Switch<Author, int64_t ,bool >(mem->gettensors <int64_t >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <int64_t >(this ->returns [0 ].textvalue ));
886+ }
887+ else
888+ {
889+ tensorfunc::Switch<Author, int64_t ,int32_t >(mem->gettensors <int64_t >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <int64_t >(this ->returns [0 ].textvalue ));
890+ }
855891 break ;
856892 case Precision::Int32:
857- tensorfunc::Switch<Author, int32_t >(mem->gettensors <int32_t >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <int32_t >(this ->returns [0 ].textvalue ));
893+ if (cases_type == Precision::Bool)
894+ {
895+ tensorfunc::Switch<Author, int32_t ,bool >(mem->gettensors <int32_t >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <int32_t >(this ->returns [0 ].textvalue ));
896+ }
897+ else
898+ {
899+ tensorfunc::Switch<Author, int32_t ,int32_t >(mem->gettensors <int32_t >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <int32_t >(this ->returns [0 ].textvalue ));
900+ }
858901 break ;
859902 case Precision::Int16:
860- tensorfunc::Switch<Author, int16_t >(mem->gettensors <int16_t >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <int16_t >(this ->returns [0 ].textvalue ));
903+ if (cases_type == Precision::Bool)
904+ {
905+ tensorfunc::Switch<Author, int16_t ,bool >(mem->gettensors <int16_t >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <int16_t >(this ->returns [0 ].textvalue ));
906+ }
907+ else
908+ {
909+ tensorfunc::Switch<Author, int16_t ,int32_t >(mem->gettensors <int16_t >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <int16_t >(this ->returns [0 ].textvalue ));
910+ }
861911 break ;
862912 case Precision::Int8:
863- tensorfunc::Switch<Author, int8_t >(mem->gettensors <int8_t >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <int8_t >(this ->returns [0 ].textvalue ));
913+ if (cases_type == Precision::Bool)
914+ {
915+ tensorfunc::Switch<Author, int8_t ,bool >(mem->gettensors <int8_t >(this ->getvector <string>(0 )), *mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <int8_t >(this ->returns [0 ].textvalue ));
916+ }
917+ else
918+ {
919+ tensorfunc::Switch<Author, int8_t ,int32_t >(mem->gettensors <int8_t >(this ->getvector <string>(0 )), *mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <int8_t >(this ->returns [0 ].textvalue ));
920+ }
864921 break ;
865922 case Precision::Bool:
866- tensorfunc::Switch<Author, bool >(mem->gettensors <bool >(this ->getvector <string>(0 )), *mem->gettensor <int8_t >(this ->args [1 ].textvalue ), *mem->gettensor <bool >(this ->returns [0 ].textvalue ));
923+ if (cases_type == Precision::Bool)
924+ {
925+ tensorfunc::Switch<Author, bool ,bool >(mem->gettensors <bool >(this ->getvector <string>(0 )),*mem->gettensor <bool >(this ->args [1 ].textvalue ), *mem->gettensor <bool >(this ->returns [0 ].textvalue ));
926+ }
927+ else
928+ {
929+ tensorfunc::Switch<Author, bool ,int32_t >(mem->gettensors <bool >(this ->getvector <string>(0 )),*mem->gettensor <int32_t >(this ->args [1 ].textvalue ), *mem->gettensor <bool >(this ->returns [0 ].textvalue ));
930+ }
867931 break ;
868932 default :
869933 error = " Unsupported type: " + precision_str (C_type);
0 commit comments