diff --git a/TrkDiag/data/TrkQual_ANN1_v2.onnx b/TrkDiag/data/TrkQual_ANN1_v2.onnx new file mode 100644 index 0000000..7b89bcb Binary files /dev/null and b/TrkDiag/data/TrkQual_ANN1_v2.onnx differ diff --git a/TrkDiag/src/SConscript b/TrkDiag/src/SConscript index 93181d8..efb348f 100644 --- a/TrkDiag/src/SConscript +++ b/TrkDiag/src/SConscript @@ -49,7 +49,8 @@ mainlib = helper.make_mainlib ( [ 'boost_filesystem', 'hep_concurrency', 'TMVA', - 'ROOTTMVASofie' + 'ROOTTMVASofie', + 'onnxruntime' ] ) # Fixme: split into link lists for each module. @@ -101,7 +102,8 @@ helper.make_plugins( [ 'hep_concurrency', 'TMVA', 'ROOTTMVASofie', -'pthread' + 'pthread', + 'onnxruntime' ] ) helper.make_dict_and_map( [ diff --git a/TrkDiag/src/TrackQuality_module.cc b/TrkDiag/src/TrackQuality_module.cc index 2cc6468..4ba84c8 100644 --- a/TrkDiag/src/TrackQuality_module.cc +++ b/TrkDiag/src/TrackQuality_module.cc @@ -21,7 +21,8 @@ // data #include "Offline/RecoDataProducts/inc/KalSeed.hh" #include "Offline/RecoDataProducts/inc/MVAResult.hh" -#include "ArtAnalysis/TrkDiag/inc/TrkQual_ANN1.hxx" +// ONNXRuntime +#include "onnxruntime/core/session/onnxruntime_cxx_api.h" // C++ #include #include @@ -33,10 +34,6 @@ using namespace std; using CLHEP::Hep3Vector; using CLHEP::HepVector; -namespace TMVA_SOFIE_TrkQual_ANN1 { - class Session; -} - namespace mu2e { @@ -49,7 +46,7 @@ namespace mu2e fhicl::Atom kalSeedPtrTag{Name("KalSeedPtrCollection"), Comment("Input tag for KalSeedPtrCollection")}; fhicl::Atom printMVA{Name("PrintMVA"), Comment("Print the MVA used"), false}; - fhicl::Atom datFilename{Name("datFilename"), Comment("Filename for the .dat file to use")}; + fhicl::Atom onnxFilename{Name("onnxFilename"), Comment("Filename for the .onnx file to use")}; fhicl::Atom debug{Name("debugLevel"), Comment("Debug printout level"), 0}; }; @@ -64,20 +61,55 @@ namespace mu2e bool _printMVA; int _debug; - std::shared_ptr mva_; - + ConfigFileLookupPolicy _configFileLookup; + + Ort::Env _env; + Ort::SessionOptions _session_options; + Ort::Session _session; + Ort::AllocatorWithDefaultOptions _allocator; + Ort::AllocatedStringPtr _input_name; + Ort::TypeInfo _type_info; + Ort::ConstTensorTypeAndShapeInfo _tensor_info; + std::vector _input_shape; + size_t _total_size; + Ort::MemoryInfo _memory_info; + Ort::AllocatedStringPtr _output_name; + + std::string print_shape(const std::vector& v) { + std::stringstream ss(""); + for (std::size_t i = 0; i < v.size() - 1; i++) ss << v[i] << "x"; + ss << v[v.size() - 1]; + return ss.str(); + } }; TrackQuality::TrackQuality(const Parameters& conf) : art::EDProducer{conf}, _kalSeedPtrTag(conf().kalSeedPtrTag()), _printMVA(conf().printMVA()), - _debug(conf().debug()) + _debug(conf().debug()), + + _env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"), + _session(_env, _configFileLookup(conf().onnxFilename()).c_str(), _session_options), + _input_name(_session.GetInputNameAllocated(0, _allocator)), + _type_info(_session.GetInputTypeInfo(0)), + _tensor_info(_type_info.GetTensorTypeAndShapeInfo()), + _input_shape(_tensor_info.GetShape()), // Get input shape from model + _memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)), + _output_name(_session.GetOutputNameAllocated(0, _allocator)) { produces(); - ConfigFileLookupPolicy configFile; - mva_ = std::make_shared(configFile(conf().datFilename())); + // Handle dynamic dimensions if needed + for (auto& dim : _input_shape) { + if (dim == -1) { dim = 1; } // Set dynamic dims to 1 (or your desired value) + } + + // Calculate total size + _total_size = 1; + for (auto dim : _input_shape) { + _total_size *= dim; + } } void TrackQuality::produce(art::Event& event ) { @@ -89,10 +121,13 @@ namespace mu2e event.getByLabel(_kalSeedPtrTag, kalSeedPtrHandle); const auto& kalSeedPtrs = *kalSeedPtrHandle; + // Prepare input tensor + std::vector features(_total_size, 0.0f); // Initialize with zeros + + // Go through the tracks and calculate their track qualities for (const auto& kalSeedPtr : kalSeedPtrs) { const auto& kalSeed = *kalSeedPtr; - std::array features; // the features we trained on // fill the hit count variables int nhits = 0; int nactive = 0; int ndouble = 0; int ndactive = 0; int nnullambig = 0; @@ -152,7 +187,25 @@ namespace mu2e features[5] = -9999; } - std::vector mvaout = mva_->infer(features.data()); + Ort::Value input_tensor = Ort::Value::CreateTensor(_memory_info, + features.data(), + features.size(), + _input_shape.data(), + _input_shape.size() + ); + // Run inference + const char* input_names[] = {_input_name.get()}; + const char* output_names[] = {_output_name.get()}; + auto output_tensors = _session.Run( + Ort::RunOptions{nullptr}, + input_names, + &input_tensor, + 1, + output_names, + 1 + ); + // Get output + float* mvaout = output_tensors[0].GetTensorMutableData(); if (!entrance_found) { mvaout[0] = 0; // this is not a good track