Skip to content

Commit 02472e8

Browse files
committed
[PWGDQ] simplified MLResponse features retrieval
Use a single macro for retrieving features from tracks. Also removed alternative feature getter functions.
1 parent cab539f commit 02472e8

2 files changed

Lines changed: 74 additions & 192 deletions

File tree

PWGDQ/Core/MuonMatchingMlResponse.h

Lines changed: 73 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -36,72 +36,12 @@
3636
// matches the entry in EnumInputFeatures associated to this FEATURE
3737
// if so, the inputFeatures vector is filled with the FEATURE's value
3838
// by calling the corresponding GETTER=FEATURE from track
39-
#define CHECK_AND_FILL_MUON_TRACK(FEATURE, GETTER) \
39+
#define CHECK_AND_FILL_FEATURE(FEATURE, GETTER) \
4040
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
41-
inputFeature = muon.GETTER(); \
41+
inputFeature = (GETTER); \
4242
break; \
4343
}
4444

45-
// Check if the index of mCachedIndices (index associated to a FEATURE)
46-
// matches the entry in EnumInputFeatures associated to this FEATURE
47-
// if so, the inputFeatures vector is filled with the FEATURE's value
48-
// by calling the corresponding GETTER=FEATURE from track
49-
#define CHECK_AND_FILL_MUONGLOB_TRACK(FEATURE, GETTER) \
50-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
51-
inputFeature = muonglob.GETTER(); \
52-
break; \
53-
}
54-
55-
// Check if the index of mCachedIndices (index associated to a FEATURE)
56-
// matches the entry in EnumInputFeatures associated to this FEATURE
57-
// if so, the inputFeatures vector is filled with the FEATURE's value
58-
// by calling the corresponding GETTER=FEATURE from track
59-
#define CHECK_AND_FILL_MFT_TRACK(FEATURE, GETTER) \
60-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
61-
inputFeature = mft.GETTER(); \
62-
break; \
63-
}
64-
65-
// Check if the index of mCachedIndices (index associated to a FEATURE)
66-
// matches the entry in EnumInputFeatures associated to this FEATURE
67-
// if so, the inputFeatures vector is filled with the FEATURE's value
68-
// by calling the corresponding GETTER=FEATURE from track
69-
#define CHECK_AND_FILL_MUON_COV(FEATURE, GETTER) \
70-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
71-
inputFeature = muoncov.GETTER(); \
72-
break; \
73-
}
74-
75-
// Check if the index of mCachedIndices (index associated to a FEATURE)
76-
// matches the entry in EnumInputFeatures associated to this FEATURE
77-
// if so, the inputFeatures vector is filled with the FEATURE's value
78-
// by calling the corresponding GETTER=FEATURE from track
79-
#define CHECK_AND_FILL_MFT_COV(FEATURE, GETTER) \
80-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
81-
inputFeature = mftcov.GETTER(); \
82-
break; \
83-
}
84-
85-
// Check if the index of mCachedIndices (index associated to a FEATURE)
86-
// matches the entry in EnumInputFeatures associated to this FEATURE
87-
// if so, the inputFeatures vector is filled with the FEATURE's value
88-
// by calling the corresponding GETTER1 and GETTER2 from track.
89-
#define CHECK_AND_FILL_MFTMUON_DIFF(FEATURE, GETTER1, GETTER2) \
90-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
91-
inputFeature = (mft.GETTER2() - muon.GETTER1()); \
92-
break; \
93-
}
94-
95-
// Check if the index of mCachedIndices (index associated to a FEATURE)
96-
// matches the entry in EnumInputFeatures associated to this FEATURE
97-
// if so, the inputFeatures vector is filled with the FEATURE's value
98-
// by calling the corresponding GETTER=FEATURE from collision
99-
#define CHECK_AND_FILL_MFTMUON_COLLISION(GETTER) \
100-
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::GETTER): { \
101-
inputFeature = collision.GETTER(); \
102-
break; \
103-
}
104-
10545
namespace o2::analysis
10646
{
10747
// possible input features for ML
@@ -190,148 +130,90 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
190130
/// Default destructor
191131
virtual ~MlResponseMFTMuonMatch() = default;
192132

193-
template <typename T1, typename T2, typename C1, typename C2, typename U>
194-
float returnFeature(uint8_t idx, T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
195-
{
196-
float inputFeature = 0.;
197-
switch (idx) {
198-
CHECK_AND_FILL_MFT_TRACK(zMatching, z);
199-
CHECK_AND_FILL_MFT_TRACK(xMFT, x);
200-
CHECK_AND_FILL_MFT_TRACK(yMFT, y);
201-
CHECK_AND_FILL_MFT_TRACK(qOverptMFT, signed1Pt);
202-
CHECK_AND_FILL_MFT_TRACK(tglMFT, tgl);
203-
CHECK_AND_FILL_MFT_TRACK(phiMFT, phi);
204-
CHECK_AND_FILL_MFT_TRACK(chi2MFT, chi2);
205-
CHECK_AND_FILL_MFT_TRACK(nClustersMFT, nClusters);
206-
CHECK_AND_FILL_MUON_TRACK(dcaXY, fwddcaXY);
207-
CHECK_AND_FILL_MUON_TRACK(dcaZ, fwddcaz);
208-
CHECK_AND_FILL_MUON_TRACK(xMCH, x);
209-
CHECK_AND_FILL_MUON_TRACK(yMCH, y);
210-
CHECK_AND_FILL_MUON_TRACK(qOverptMCH, signed1Pt);
211-
CHECK_AND_FILL_MUON_TRACK(tglMCH, tgl);
212-
CHECK_AND_FILL_MUON_TRACK(phiMCH, phi);
213-
CHECK_AND_FILL_MUON_TRACK(nClustersMCH, nClusters);
214-
CHECK_AND_FILL_MUON_TRACK(chi2MCH, chi2);
215-
CHECK_AND_FILL_MUON_TRACK(pdca, pDca);
216-
CHECK_AND_FILL_MFT_COV(cXXMFT, cXX);
217-
CHECK_AND_FILL_MFT_COV(cXYMFT, cXY);
218-
CHECK_AND_FILL_MFT_COV(cYYMFT, cYY);
219-
CHECK_AND_FILL_MFT_COV(cPhiYMFT, cPhiY);
220-
CHECK_AND_FILL_MFT_COV(cPhiXMFT, cPhiX);
221-
CHECK_AND_FILL_MFT_COV(cPhiPhiMFT, cPhiPhi);
222-
CHECK_AND_FILL_MFT_COV(cTglYMFT, cTglY);
223-
CHECK_AND_FILL_MFT_COV(cTglXMFT, cTglX);
224-
CHECK_AND_FILL_MFT_COV(cTglPhiMFT, cTglPhi);
225-
CHECK_AND_FILL_MFT_COV(cTglTglMFT, cTglTgl);
226-
CHECK_AND_FILL_MFT_COV(c1PtYMFT, c1PtY);
227-
CHECK_AND_FILL_MFT_COV(c1PtXMFT, c1PtX);
228-
CHECK_AND_FILL_MFT_COV(c1PtPhiMFT, c1PtPhi);
229-
CHECK_AND_FILL_MFT_COV(c1PtTglMFT, c1PtTgl);
230-
CHECK_AND_FILL_MFT_COV(c1Pt21Pt2MFT, c1Pt21Pt2);
231-
CHECK_AND_FILL_MUON_COV(cXXMCH, cXX);
232-
CHECK_AND_FILL_MUON_COV(cXYMCH, cXY);
233-
CHECK_AND_FILL_MUON_COV(cYYMCH, cYY);
234-
CHECK_AND_FILL_MUON_COV(cPhiYMCH, cPhiY);
235-
CHECK_AND_FILL_MUON_COV(cPhiXMCH, cPhiX);
236-
CHECK_AND_FILL_MUON_COV(cPhiPhiMCH, cPhiPhi);
237-
CHECK_AND_FILL_MUON_COV(cTglYMCH, cTglY);
238-
CHECK_AND_FILL_MUON_COV(cTglXMCH, cTglX);
239-
CHECK_AND_FILL_MUON_COV(cTglPhiMCH, cTglPhi);
240-
CHECK_AND_FILL_MUON_COV(cTglTglMCH, cTglTgl);
241-
CHECK_AND_FILL_MUON_COV(c1PtYMCH, c1PtY);
242-
CHECK_AND_FILL_MUON_COV(c1PtXMCH, c1PtX);
243-
CHECK_AND_FILL_MUON_COV(c1PtPhiMCH, c1PtPhi);
244-
CHECK_AND_FILL_MUON_COV(c1PtTglMCH, c1PtTgl);
245-
CHECK_AND_FILL_MUON_COV(c1Pt21Pt2MCH, c1Pt21Pt2);
246-
CHECK_AND_FILL_MFTMUON_COLLISION(posX);
247-
CHECK_AND_FILL_MFTMUON_COLLISION(posY);
248-
CHECK_AND_FILL_MFTMUON_COLLISION(posZ);
249-
CHECK_AND_FILL_MFTMUON_COLLISION(numContrib);
250-
CHECK_AND_FILL_MFTMUON_COLLISION(trackOccupancyInTimeRange);
251-
CHECK_AND_FILL_MFTMUON_COLLISION(ft0cOccupancyInTimeRange);
252-
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0A);
253-
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0C);
254-
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPV);
255-
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVeta1);
256-
CHECK_AND_FILL_MFTMUON_COLLISION(multNTracksPVetaHalf);
257-
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt0);
258-
CHECK_AND_FILL_MFTMUON_COLLISION(isInelGt1);
259-
CHECK_AND_FILL_MFTMUON_COLLISION(multFT0M);
260-
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0M);
261-
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0A);
262-
CHECK_AND_FILL_MFTMUON_COLLISION(centFT0C);
263-
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT);
264-
}
265-
return inputFeature;
266-
}
267-
268-
template <typename T1, typename T2, typename T3, typename U>
269-
float returnFeatureGlob(uint8_t idx, T1 const& muonglob, T2 const& muon, T3 const& mft, U const& collision)
270-
{
271-
float inputFeature = 0.;
272-
switch (idx) {
273-
CHECK_AND_FILL_MFT_TRACK(xMFT, getX);
274-
CHECK_AND_FILL_MFT_TRACK(yMFT, getY);
275-
CHECK_AND_FILL_MFT_TRACK(qOverptMFT, getInvQPt);
276-
CHECK_AND_FILL_MFT_TRACK(tglMFT, getTanl);
277-
CHECK_AND_FILL_MFT_TRACK(phiMFT, getPhi);
278-
CHECK_AND_FILL_MFT_TRACK(chi2MFT, getTrackChi2);
279-
CHECK_AND_FILL_MUON_TRACK(xMCH, getX);
280-
CHECK_AND_FILL_MUON_TRACK(yMCH, getY);
281-
CHECK_AND_FILL_MUON_TRACK(qOverptMCH, getInvQPt);
282-
CHECK_AND_FILL_MUON_TRACK(tglMCH, getTanl);
283-
CHECK_AND_FILL_MUON_TRACK(phiMCH, getPhi);
284-
CHECK_AND_FILL_MUON_TRACK(chi2MCH, getTrackChi2);
285-
CHECK_AND_FILL_MUONGLOB_TRACK(chi2MCHMFT, chi2MatchMCHMFT);
286-
CHECK_AND_FILL_MUONGLOB_TRACK(chi2GlobMUON, chi2);
287-
CHECK_AND_FILL_MUONGLOB_TRACK(Rabs, rAtAbsorberEnd);
288-
// Below are dummy files to remove warning of unused parameters
289-
CHECK_AND_FILL_MFTMUON_COLLISION(posZ);
290-
}
291-
return inputFeature;
292-
}
293-
294-
template <typename T1>
295-
float returnFeatureTest(uint8_t idx, T1 const& muon)
133+
template <typename T1, typename T2, typename T3, typename T4, typename T5, typename U>
134+
float returnFeature(uint8_t idx, T1 const& muon, T2 const& mft, T3 const& mch, T4 const& mftprop, T5 const& mchprop, U const& collision)
296135
{
297136
float inputFeature = 0.;
298137
switch (idx) {
299-
CHECK_AND_FILL_MUON_TRACK(chi2MCHMFT, chi2MatchMCHMFT);
138+
CHECK_AND_FILL_FEATURE(zMatching, mftprop.getZ());
139+
CHECK_AND_FILL_FEATURE(xMFT, mftprop.getX());
140+
CHECK_AND_FILL_FEATURE(yMFT, mftprop.getY());
141+
CHECK_AND_FILL_FEATURE(qOverptMFT, mftprop.getInvQPt());
142+
CHECK_AND_FILL_FEATURE(tglMFT, mftprop.getTanl());
143+
CHECK_AND_FILL_FEATURE(phiMFT, mftprop.getPhi());
144+
CHECK_AND_FILL_FEATURE(chi2MFT, mft.chi2());
145+
CHECK_AND_FILL_FEATURE(nClustersMFT, mft.nClusters());
146+
/*dummy value*/ CHECK_AND_FILL_FEATURE(dcaXY, 0);
147+
/*dummy value*/ CHECK_AND_FILL_FEATURE(dcaZ, 0);
148+
CHECK_AND_FILL_FEATURE(xMCH, mchprop.getX());
149+
CHECK_AND_FILL_FEATURE(yMCH, mchprop.getY());
150+
CHECK_AND_FILL_FEATURE(qOverptMCH, mchprop.getInvQPt());
151+
CHECK_AND_FILL_FEATURE(tglMCH, mchprop.getTanl());
152+
CHECK_AND_FILL_FEATURE(phiMCH, mchprop.getPhi());
153+
CHECK_AND_FILL_FEATURE(nClustersMCH, mch.nClusters());
154+
CHECK_AND_FILL_FEATURE(chi2MCH, mch.chi2());
155+
CHECK_AND_FILL_FEATURE(pdca, muon.pDca());
156+
CHECK_AND_FILL_FEATURE(cXXMFT, mftprop.getCovariances()(0, 0));
157+
CHECK_AND_FILL_FEATURE(cXYMFT, mftprop.getCovariances()(0, 1));
158+
CHECK_AND_FILL_FEATURE(cYYMFT, mftprop.getCovariances()(1, 1));
159+
CHECK_AND_FILL_FEATURE(cPhiXMFT, mftprop.getCovariances()(0, 2));
160+
CHECK_AND_FILL_FEATURE(cPhiYMFT, mftprop.getCovariances()(1, 2));
161+
CHECK_AND_FILL_FEATURE(cPhiPhiMFT, mftprop.getCovariances()(2, 2));
162+
CHECK_AND_FILL_FEATURE(cTglXMFT, mftprop.getCovariances()(0, 3));
163+
CHECK_AND_FILL_FEATURE(cTglYMFT, mftprop.getCovariances()(1, 3));
164+
CHECK_AND_FILL_FEATURE(cTglPhiMFT, mftprop.getCovariances()(2, 3));
165+
CHECK_AND_FILL_FEATURE(cTglTglMFT, mftprop.getCovariances()(3, 3));
166+
CHECK_AND_FILL_FEATURE(c1PtXMFT, mftprop.getCovariances()(0, 4));
167+
CHECK_AND_FILL_FEATURE(c1PtYMFT, mftprop.getCovariances()(1, 4));
168+
CHECK_AND_FILL_FEATURE(c1PtPhiMFT, mftprop.getCovariances()(2, 4));
169+
CHECK_AND_FILL_FEATURE(c1PtTglMFT, mftprop.getCovariances()(3, 4));
170+
CHECK_AND_FILL_FEATURE(c1Pt21Pt2MFT, mftprop.getCovariances()(4, 4));
171+
CHECK_AND_FILL_FEATURE(cXXMCH, mchprop.getCovariances()(0, 0));
172+
CHECK_AND_FILL_FEATURE(cXYMCH, mchprop.getCovariances()(0, 1));
173+
CHECK_AND_FILL_FEATURE(cYYMCH, mchprop.getCovariances()(1, 1));
174+
CHECK_AND_FILL_FEATURE(cPhiXMCH, mchprop.getCovariances()(0, 2));
175+
CHECK_AND_FILL_FEATURE(cPhiYMCH, mchprop.getCovariances()(1, 2));
176+
CHECK_AND_FILL_FEATURE(cPhiPhiMCH, mchprop.getCovariances()(2, 2));
177+
CHECK_AND_FILL_FEATURE(cTglXMCH, mchprop.getCovariances()(0, 3));
178+
CHECK_AND_FILL_FEATURE(cTglYMCH, mchprop.getCovariances()(1, 3));
179+
CHECK_AND_FILL_FEATURE(cTglPhiMCH, mchprop.getCovariances()(2, 3));
180+
CHECK_AND_FILL_FEATURE(cTglTglMCH, mchprop.getCovariances()(3, 3));
181+
CHECK_AND_FILL_FEATURE(c1PtXMCH, mchprop.getCovariances()(0, 4));
182+
CHECK_AND_FILL_FEATURE(c1PtYMCH, mchprop.getCovariances()(1, 4));
183+
CHECK_AND_FILL_FEATURE(c1PtPhiMCH, mchprop.getCovariances()(2, 4));
184+
CHECK_AND_FILL_FEATURE(c1PtTglMCH, mchprop.getCovariances()(3, 4));
185+
CHECK_AND_FILL_FEATURE(c1Pt21Pt2MCH, mchprop.getCovariances()(4, 4));
186+
CHECK_AND_FILL_FEATURE(posX, collision.posX());
187+
CHECK_AND_FILL_FEATURE(posY, collision.posY());
188+
CHECK_AND_FILL_FEATURE(posZ, collision.posZ());
189+
CHECK_AND_FILL_FEATURE(numContrib, collision.numContrib());
190+
CHECK_AND_FILL_FEATURE(trackOccupancyInTimeRange, collision.trackOccupancyInTimeRange());
191+
CHECK_AND_FILL_FEATURE(ft0cOccupancyInTimeRange, collision.ft0cOccupancyInTimeRange());
192+
CHECK_AND_FILL_FEATURE(multFT0A, collision.multFT0A());
193+
CHECK_AND_FILL_FEATURE(multFT0C, collision.multFT0C());
194+
CHECK_AND_FILL_FEATURE(multNTracksPV, collision.multNTracksPV());
195+
CHECK_AND_FILL_FEATURE(multNTracksPVeta1, collision.multNTracksPVeta1());
196+
CHECK_AND_FILL_FEATURE(multNTracksPVetaHalf, collision.multNTracksPVetaHalf());
197+
CHECK_AND_FILL_FEATURE(isInelGt0, collision.isInelGt0());
198+
CHECK_AND_FILL_FEATURE(isInelGt1, collision.isInelGt1());
199+
CHECK_AND_FILL_FEATURE(multFT0M, collision.multFT0M());
200+
CHECK_AND_FILL_FEATURE(centFT0M, collision.centFT0M());
201+
CHECK_AND_FILL_FEATURE(centFT0A, collision.centFT0A());
202+
CHECK_AND_FILL_FEATURE(centFT0C, collision.centFT0C());
203+
CHECK_AND_FILL_FEATURE(chi2MCHMFT, muon.chi2MatchMCHMFT());
300204
}
301205
return inputFeature;
302206
}
303207

304208
/// Method to get the input features vector needed for ML inference
305209
/// \param track is the single track, \param collision is the collision
306210
/// \return inputFeatures vector
307-
template <typename T1, typename T2, typename C1, typename C2, typename U>
308-
std::vector<float> getInputFeatures(T1 const& muon, T2 const& mft, C1 const& muoncov, C2 const& mftcov, U const& collision)
309-
{
310-
std::vector<float> inputFeatures;
311-
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
312-
float inputFeature = returnFeature(idx, muon, mft, muoncov, mftcov, collision);
313-
inputFeatures.emplace_back(inputFeature);
314-
}
315-
return inputFeatures;
316-
}
317-
318-
template <typename T1>
319-
std::vector<float> getInputFeaturesTest(T1 const& muon)
320-
{
321-
std::vector<float> inputFeatures;
322-
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
323-
float inputFeature = returnFeatureTest(idx, muon);
324-
inputFeatures.emplace_back(inputFeature);
325-
}
326-
return inputFeatures;
327-
}
328-
329-
template <typename T1, typename T2, typename T3, typename U>
330-
std::vector<float> getInputFeaturesGlob(T1 const& muonglob, T2 const& muon, T3 const& mft, U const& collision)
211+
template <typename T1, typename T2, typename T3, typename T4, typename T5, typename U>
212+
std::vector<float> getInputFeatures(T1 const& muon, T2 const& mft, T3 const& mch, T4 const& mftprop, T5 const& mchprop, U const& collision)
331213
{
332214
std::vector<float> inputFeatures;
333215
for (const auto& idx : MlResponse<TypeOutputScore>::mCachedIndices) {
334-
float inputFeature = returnFeatureGlob(idx, muonglob, muon, mft, collision);
216+
float inputFeature = returnFeature(idx, muon, mft, mch, mftprop, mchprop, collision);
335217
inputFeatures.emplace_back(inputFeature);
336218
}
337219
return inputFeatures;

PWGDQ/Tasks/qaMatching.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2734,7 +2734,7 @@ struct QaMatching {
27342734

27352735
// run the ML model
27362736
std::vector<float> output;
2737-
std::vector<float> inputML = mlResponse.getInputFeaturesGlob(muonTrack, mchTrackProp, mftTrackProp, collision);
2737+
std::vector<float> inputML = mlResponse.getInputFeatures(muonTrack, mftTrack, mchTrack, mftTrackProp, mchTrackProp, collision);
27382738
mlResponse.isSelectedMl(inputML, 0, output);
27392739
float matchScore = output[0];
27402740
float matchChi2Prod = muonTrack.chi2MatchMCHMFT() / MatchingDegreesOfFreedom;

0 commit comments

Comments
 (0)