33
44#include " deepx/op/op.hpp"
55#include " deepx/tensorfunc/init.hpp"
6+ #include " stdutil/num.hpp"
67namespace deepx ::op{
78 template <typename T>
89 class Uniform : public OpT <T>{
@@ -18,9 +19,15 @@ namespace deepx::op{
1819 }
1920 void forward (mem::Mem &mem) override {
2021 auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
21- T low = mem.getarg <T>(this ->args [0 ]);
22- T high = mem.getarg <T>(this ->args [1 ]);
23- tensorfunc::uniform (*output,low,high);
22+ if (is_float (this ->args [0 ])){
23+ T low = std::stof (this ->args [0 ]);
24+ T high = std::stof (this ->args [1 ]);
25+ tensorfunc::uniform (*output,low,high);
26+ }else {
27+ T low = mem.getarg <T>(this ->args [0 ]);
28+ T high = mem.getarg <T>(this ->args [1 ]);
29+ tensorfunc::uniform (*output,low,high);
30+ }
2431 }
2532 void backward (mem::Mem &mem) override {
2633 throw std::runtime_error (" Uniform op does not support backward" );
@@ -41,8 +48,13 @@ namespace deepx::op{
4148 }
4249 void forward (mem::Mem &mem) override {
4350 auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
44- T value = mem.getarg <T>(this ->args [0 ]);
45- tensorfunc::constant (*output,value);
51+ if (is_float (this ->args [0 ])){
52+ T value = std::stof (this ->args [0 ]);
53+ tensorfunc::constant (*output,value);
54+ }else {
55+ T value = mem.getarg <T>(this ->args [0 ]);
56+ tensorfunc::constant (*output,value);
57+ }
4658 }
4759 void backward (mem::Mem &mem) override {
4860 throw std::runtime_error (" Constant op does not support backward" );
@@ -63,9 +75,15 @@ namespace deepx::op{
6375 }
6476 void forward (mem::Mem &mem) override {
6577 auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
66- T start = mem.getarg <T>(this ->args [0 ]);
67- T step = mem.getarg <T>(this ->args [1 ]);
68- tensorfunc::arange (*output,start,step);
78+ if (is_float (this ->args [0 ])){
79+ T start = std::stof (this ->args [0 ]);
80+ T step = std::stof (this ->args [1 ]);
81+ tensorfunc::arange (*output,start,step);
82+ }else {
83+ T start = mem.getarg <T>(this ->args [0 ]);
84+ T step = mem.getarg <T>(this ->args [1 ]);
85+ tensorfunc::arange (*output,start,step);
86+ }
6987 }
7088 void backward (mem::Mem &mem) override {
7189 throw std::runtime_error (" Arange op does not support backward" );
0 commit comments