@@ -636,6 +636,75 @@ namespace deepx::tf
636636 }
637637 };
638638
639+ // rsubscalar
640+ template <typename Author>
641+ class RSubScalar : public TF
642+ {
643+ public:
644+ RSubScalar (const vector<Param> &args, const vector<Param> &returns)
645+ {
646+ this ->name = " rsubscalar" ;
647+ this ->metadata .author = Author::name ();
648+ this ->tftype = " elementwise" ;
649+ this ->args = args;
650+ this ->returns = returns;
651+ }
652+
653+ string math_formula () const override
654+ {
655+ return " T3=scalar-T1" ;
656+ }
657+ shared_ptr<TF> clone () const override
658+ {
659+ return make_shared<RSubScalar<Author>>(*this );
660+ }
661+ int run (shared_ptr<MemBase> mem, string &error) override
662+ {
663+ if (!checktensors ({this ->args [0 ].textvalue , this ->returns [0 ].textvalue }, mem, error))
664+ {
665+ return 1 ;
666+ }
667+ Precision a_type = mem->gettensor (this ->args [0 ].textvalue ).get ()->shape .dtype ;
668+ Precision c_type = mem->gettensor (this ->returns [0 ].textvalue ).get ()->shape .dtype ;
669+ if (a_type != c_type)
670+ {
671+ error = " Type mismatch: " + precision_str (a_type) + " != " + precision_str (c_type);
672+ return 1 ;
673+ }
674+ switch (a_type)
675+ {
676+ case Precision::Float64:
677+ tensorfunc::rsubscalar<Author, double >(this ->getvar <double >(1 , mem), *mem->gettensor <double >(this ->args [0 ].textvalue ), *mem->gettensor <double >(this ->returns [0 ].textvalue ));
678+ break ;
679+ case Precision::Float32:
680+ tensorfunc::rsubscalar<Author, float >(this ->getvar <float >(1 , mem), *mem->gettensor <float >(this ->args [0 ].textvalue ), *mem->gettensor <float >(this ->returns [0 ].textvalue ));
681+ break ;
682+ case Precision::Float16:
683+ tensorfunc::rsubscalar<Author, half>(this ->getvar <half>(1 , mem), *mem->gettensor <half>(this ->args [0 ].textvalue ), *mem->gettensor <half>(this ->returns [0 ].textvalue ));
684+ break ;
685+ case Precision::BFloat16:
686+ tensorfunc::rsubscalar<Author, nv_bfloat16>(this ->getvar <nv_bfloat16>(1 , mem), *mem->gettensor <nv_bfloat16>(this ->args [0 ].textvalue ), *mem->gettensor <nv_bfloat16>(this ->returns [0 ].textvalue ));
687+ break ;
688+ case Precision::Int64:
689+ tensorfunc::rsubscalar<Author, int32_t >(this ->getvar <int32_t >(1 , mem), *mem->gettensor <int32_t >(this ->args [0 ].textvalue ), *mem->gettensor <int32_t >(this ->returns [0 ].textvalue ));
690+ break ;
691+ case Precision::Int32:
692+ tensorfunc::rsubscalar<Author, int32_t >(this ->getvar <int32_t >(1 , mem), *mem->gettensor <int32_t >(this ->args [0 ].textvalue ), *mem->gettensor <int32_t >(this ->returns [0 ].textvalue ));
693+ break ;
694+ case Precision::Int16:
695+ tensorfunc::rsubscalar<Author, int16_t >(this ->getvar <int16_t >(1 , mem), *mem->gettensor <int16_t >(this ->args [0 ].textvalue ), *mem->gettensor <int16_t >(this ->returns [0 ].textvalue ));
696+ break ;
697+ case Precision::Int8:
698+ tensorfunc::rsubscalar<Author, int8_t >(this ->getvar <int8_t >(1 , mem), *mem->gettensor <int8_t >(this ->args [0 ].textvalue ), *mem->gettensor <int8_t >(this ->returns [0 ].textvalue ));
699+ break ;
700+ default :
701+ error = " Unsupported dtype: " + precision_str (a_type);
702+ return 1 ;
703+ }
704+ return 0 ;
705+ }
706+ };
707+
639708 template <typename Author>
640709 class Mul : public TF
641710 {
0 commit comments