Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added TrkDiag/data/TrkQual_ANN1_v2.onnx
Binary file not shown.
6 changes: 4 additions & 2 deletions TrkDiag/src/SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ mainlib = helper.make_mainlib ( [
'boost_filesystem',
'hep_concurrency',
'TMVA',
'ROOTTMVASofie'
'ROOTTMVASofie',
'onnxruntime'
] )

# Fixme: split into link lists for each module.
Expand Down Expand Up @@ -101,7 +102,8 @@ helper.make_plugins( [
'hep_concurrency',
'TMVA',
'ROOTTMVASofie',
'pthread'
'pthread',
'onnxruntime'
] )

helper.make_dict_and_map( [
Expand Down
79 changes: 66 additions & 13 deletions TrkDiag/src/TrackQuality_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
// data
#include "Offline/RecoDataProducts/inc/KalSeed.hh"
#include "Offline/RecoDataProducts/inc/MVAResult.hh"
#include "ArtAnalysis/TrkDiag/inc/TrkQual_ANN1.hxx"
// ONNXRuntime
#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
// C++
#include <iostream>
#include <fstream>
Expand All @@ -33,10 +34,6 @@ using namespace std;
using CLHEP::Hep3Vector;
using CLHEP::HepVector;

namespace TMVA_SOFIE_TrkQual_ANN1 {
class Session;
}

namespace mu2e
{

Expand All @@ -49,7 +46,7 @@ namespace mu2e

fhicl::Atom<art::InputTag> kalSeedPtrTag{Name("KalSeedPtrCollection"), Comment("Input tag for KalSeedPtrCollection")};
fhicl::Atom<bool> printMVA{Name("PrintMVA"), Comment("Print the MVA used"), false};
fhicl::Atom<std::string> datFilename{Name("datFilename"), Comment("Filename for the .dat file to use")};
fhicl::Atom<std::string> onnxFilename{Name("onnxFilename"), Comment("Filename for the .onnx file to use")};
fhicl::Atom<int> debug{Name("debugLevel"), Comment("Debug printout level"), 0};
};

Expand All @@ -64,20 +61,55 @@ namespace mu2e
bool _printMVA;
int _debug;

std::shared_ptr<TMVA_SOFIE_TrkQual_ANN1::Session> mva_;

ConfigFileLookupPolicy _configFileLookup;

Ort::Env _env;
Ort::SessionOptions _session_options;
Ort::Session _session;
Ort::AllocatorWithDefaultOptions _allocator;
Ort::AllocatedStringPtr _input_name;
Ort::TypeInfo _type_info;
Ort::ConstTensorTypeAndShapeInfo _tensor_info;
std::vector<int64_t> _input_shape;
size_t _total_size;
Ort::MemoryInfo _memory_info;
Ort::AllocatedStringPtr _output_name;

std::string print_shape(const std::vector<std::int64_t>& v) {
std::stringstream ss("");
for (std::size_t i = 0; i < v.size() - 1; i++) ss << v[i] << "x";
ss << v[v.size() - 1];
return ss.str();
}
};

TrackQuality::TrackQuality(const Parameters& conf) :
art::EDProducer{conf},
_kalSeedPtrTag(conf().kalSeedPtrTag()),
_printMVA(conf().printMVA()),
_debug(conf().debug())
_debug(conf().debug()),

_env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"),
_session(_env, _configFileLookup(conf().onnxFilename()).c_str(), _session_options),
_input_name(_session.GetInputNameAllocated(0, _allocator)),
_type_info(_session.GetInputTypeInfo(0)),
_tensor_info(_type_info.GetTensorTypeAndShapeInfo()),
_input_shape(_tensor_info.GetShape()), // Get input shape from model
_memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
_output_name(_session.GetOutputNameAllocated(0, _allocator))
{
produces<MVAResultCollection>();

ConfigFileLookupPolicy configFile;
mva_ = std::make_shared<TMVA_SOFIE_TrkQual_ANN1::Session>(configFile(conf().datFilename()));
// Handle dynamic dimensions if needed
for (auto& dim : _input_shape) {
if (dim == -1) { dim = 1; } // Set dynamic dims to 1 (or your desired value)
}

// Calculate total size
_total_size = 1;
for (auto dim : _input_shape) {
_total_size *= dim;
}
}

void TrackQuality::produce(art::Event& event ) {
Expand All @@ -89,10 +121,13 @@ namespace mu2e
event.getByLabel(_kalSeedPtrTag, kalSeedPtrHandle);
const auto& kalSeedPtrs = *kalSeedPtrHandle;

// Prepare input tensor
std::vector<float> features(_total_size, 0.0f); // Initialize with zeros


// Go through the tracks and calculate their track qualities
for (const auto& kalSeedPtr : kalSeedPtrs) {
const auto& kalSeed = *kalSeedPtr;
std::array<float,7> features; // the features we trained on

// fill the hit count variables
int nhits = 0; int nactive = 0; int ndouble = 0; int ndactive = 0; int nnullambig = 0;
Expand Down Expand Up @@ -152,7 +187,25 @@ namespace mu2e
features[5] = -9999;
}

std::vector<float> mvaout = mva_->infer(features.data());
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(_memory_info,
features.data(),
features.size(),
_input_shape.data(),
_input_shape.size()
);
// Run inference
const char* input_names[] = {_input_name.get()};
const char* output_names[] = {_output_name.get()};
auto output_tensors = _session.Run(
Ort::RunOptions{nullptr},
input_names,
&input_tensor,
1,
output_names,
1
);
// Get output
float* mvaout = output_tensors[0].GetTensorMutableData<float>();

if (!entrance_found) {
mvaout[0] = 0; // this is not a good track
Expand Down