Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package org.jlab.rec.ahdc.AI;

/**
 * Shared constants for the GNN track finder: tensor layout, graph-construction
 * cuts, feature-normalization scales, and track-extraction thresholds.
 *
 * <p>Values mirror track-finding/gnn/config.py and must stay in sync with the
 * configuration the model was trained with.
 */
final class GNNConstants {

    // Tensor layout expected by the exported edge scorer.
    static final int NODE_FEAT_DIM = 10;
    static final int EDGE_FEAT_DIM = 9;

    /**
     * Smallest graph the model is asked to score. GravNet's progressive-k
     * reaches 2*k and topk needs k+1 neighbors, so in general N_nodes >= 2*k + 2;
     * however the exported model clamps topk(k+1) to N internally (see
     * track-finding/export_torchscript.py::_knn_indices), so any graph with at
     * least 3 nodes runs without crashing. Graphs smaller than that cannot form
     * a single edge under the MAX_LAYER_GAP rule anyway, so they are skipped.
     */
    static final int MIN_NODES = 3;

    // ---- Graph-construction cuts ----
    static final int MAX_LAYER_GAP = 2;
    static final double MAX_EDGE_DISTANCE = 35.0; // mm
    static final double MAX_EDGE_DIST_SQ = MAX_EDGE_DISTANCE * MAX_EDGE_DISTANCE;

    // ---- Feature-normalization scales ----
    static final double MAX_R = 100.0;           // mm
    static final double DOCA_STD = 10.0;         // mm
    static final double Z_HALF_LENGTH = 200.0;   // mm
    static final double STEREO_ANGLE_MAX = 0.03; // rad
    static final double STEREO_SCALE = 1.0 / STEREO_ANGLE_MAX;

    // ---- ATOF abs_layer convention (matches Python's build_graph) ----
    static final int ATOF_BAR_ABS_LAYER = 10;   // component == 10
    static final int ATOF_WEDGE_ABS_LAYER = 11; // every other component

    // ---- Track extraction ----
    // Connected components over edges whose score passes the threshold, then
    // components with fewer than MIN_TRACK_NODES total nodes are dropped —
    // identical to gnn/evaluate.py's extract_tracks(..., method="cc",
    // threshold=0.1) followed by its post-call size filter.
    static final double TRACK_SCORE_THRESHOLD = 0.1;
    static final int MIN_TRACK_NODES = 3;

    /** Non-instantiable constants holder. */
    private GNNConstants() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package org.jlab.rec.ahdc.AI;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.jlab.geom.prim.Line3D;
import org.jlab.geom.prim.Point3D;
import org.jlab.geom.prim.Vector3D;
import org.jlab.io.base.DataBank;
import org.jlab.rec.ahdc.Hit.Hit;

/** Builds the graph tensors expected by the exported GNN edge scorer.
 * Ports track-finding/gnn/dataset.py::build_graph — must stay byte-compatible
 * with the training-time feature layout and normalization, which is why the
 * arithmetic below is kept in exactly the training-config order.
 */
final class GNNGraphBuilder {

    /** Container for the tensors + node provenance that the caller needs. */
    static final class GraphInput {
        final float[][] nodeFeatures; // shape [N, 10]
        final long[][] edgeIndex;     // shape [2, E]; row 0 = source node, row 1 = destination node
        final float[][] edgeAttr;     // shape [E, 9]
        /** nodeToSource[i] is the backing Hit for AHDC nodes, or null for ATOF nodes. */
        final Hit[] nodeToSource;

        GraphInput(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr, Hit[] nodeToSource) {
            this.nodeFeatures = nodeFeatures;
            this.edgeIndex = edgeIndex;
            this.edgeAttr = edgeAttr;
            this.nodeToSource = nodeToSource;
        }
    }

    /** Non-instantiable static utility. */
    private GNNGraphBuilder() {}

    /** Build a graph from AHDC hits (required) plus the ATOF::hits bank (optional).
     *
     * Nodes are appended AHDC-first, then ATOF; each raw node record is a
     * 13-slot double[] whose slots 0-9 become the normalized node features and
     * whose slots 10-12 (raw x, raw y, det_type) are used only for edge
     * construction. Returns empty tensors when fewer than 2 nodes survive.
     */
    static GraphInput build(List<Hit> ahdcHits, DataBank atofHitsBank) {
        int nAhdc = ahdcHits == null ? 0 : ahdcHits.size();

        // Node state buffers (grow as we append AHDC then ATOF nodes).
        List<double[]> nodeBuf = new ArrayList<>(); // per-node raw values (see slot comments below)
        List<Line3D> nodeLine = new ArrayList<>();  // wire line for AHDC; null for ATOF
        List<Hit> nodeHit = new ArrayList<>();      // backing Hit for AHDC; null for ATOF

        // --- AHDC nodes -------------------------------------------------------------
        for (int i = 0; i < nAhdc; i++) {
            Hit h = ahdcHits.get(i);
            Line3D line = h.getLine();
            if (line == null) continue; // missing geometry → skip (shouldn't happen after setWirePosition)

            Point3D mid = line.midpoint();
            Vector3D dir = line.toVector();
            // Guard against a degenerate zero-length wire vector before normalizing.
            double len = Math.max(dir.mag(), 1e-12);
            double ux = dir.x() / len, uy = dir.y() / len, uz = dir.z() / len;
            // Stereo angle = tilt of the wire direction w.r.t. the z axis (>= 0).
            double stereo = Math.atan2(Math.sqrt(ux*ux + uy*uy), uz);

            // abs_layer = (superlayer-1)*2 + (layer-1); presumably 0..9 for AHDC
            // (5 superlayers × 2 layers — TODO confirm), ATOF then takes 10/11.
            int absLayer = (h.getSuperLayerId() - 1) * 2 + (h.getLayerId() - 1);
            nodeBuf.add(new double[]{
                absLayer,      // 0: abs_layer
                h.getPhi(),    // 1: phi
                h.getRadius(), // 2: r
                stereo,        // 3: stereo_angle
                mid.x(),       // 4: x_mid
                mid.y(),       // 5: y_mid
                mid.z(),       // 6: z_mid
                ux,            // 7: ux
                uy,            // 8: uy
                uz,            // 9: uz
                h.getX(),      // 10: x (raw, for edge distance mask)
                h.getY(),      // 11: y (raw, for edge distance mask)
                0.0,           // 12: det_type = 0 (AHDC)
            });
            nodeLine.add(line);
            nodeHit.add(h);
        }

        // --- ATOF nodes -------------------------------------------------------------
        // Deduplicate by (sector, layer, component) — inference-time variant of the
        // Python dedup which also keys on track id (only needed at training time).
        // Key packs the triple assuming each id < 1000 — TODO confirm id ranges.
        if (atofHitsBank != null) {
            Set<Long> seen = new HashSet<>();
            int rows = atofHitsBank.rows();
            for (int r = 0; r < rows; r++) {
                int sector = atofHitsBank.getInt("sector", r);
                int layer = atofHitsBank.getInt("layer", r);
                int component = atofHitsBank.getInt("component", r);
                long key = (((long)sector * 1000L) + layer) * 1000L + component;
                if (!seen.add(key)) continue;

                double x = atofHitsBank.getFloat("x", r);
                double y = atofHitsBank.getFloat("y", r);
                double radius = Math.hypot(x, y);
                double phi = Math.atan2(y, x);
                // component == 10 marks the bar; all other components are wedges.
                int absLayer = (component == 10) ? GNNConstants.ATOF_BAR_ABS_LAYER
                                                 : GNNConstants.ATOF_WEDGE_ABS_LAYER;

                // ATOF nodes have no wire: zero stereo, z-aligned unit direction,
                // mid point at z = 0.
                nodeBuf.add(new double[]{
                    absLayer, phi, radius,
                    0.0,           // stereo
                    x, y, 0.0,     // mid
                    0.0, 0.0, 1.0, // (ux, uy, uz)
                    x, y,          // raw x, y (for edge mask)
                    1.0,           // det_type = 1 (ATOF)
                });
                nodeLine.add(null);
                nodeHit.add(null);
            }
        }

        int n = nodeBuf.size();
        // Fewer than 2 nodes cannot form an edge; return empty tensors (the
        // caller's MIN_NODES check rejects such graphs anyway).
        if (n < 2) {
            return new GraphInput(new float[0][GNNConstants.NODE_FEAT_DIM],
                                  new long[][]{new long[0], new long[0]},
                                  new float[0][GNNConstants.EDGE_FEAT_DIM],
                                  new Hit[0]);
        }

        // --- Node feature tensor [N, 10] --------------------------------------------
        // Normalization scales must match gnn/config.py exactly.
        float[][] nodeFeatures = new float[n][GNNConstants.NODE_FEAT_DIM];
        for (int i = 0; i < n; i++) {
            double[] v = nodeBuf.get(i);
            nodeFeatures[i][0] = (float)(v[0] / 11.0); // 11 = ATOF_WEDGE_ABS_LAYER, the largest abs_layer
            nodeFeatures[i][1] = (float)(v[1] / Math.PI);
            nodeFeatures[i][2] = (float)(v[2] / GNNConstants.DOCA_STD);
            nodeFeatures[i][3] = (float)(v[3] / GNNConstants.STEREO_ANGLE_MAX);
            nodeFeatures[i][4] = (float)(v[4] / GNNConstants.MAX_R);
            nodeFeatures[i][5] = (float)(v[5] / GNNConstants.MAX_R);
            nodeFeatures[i][6] = (float)(v[6] / GNNConstants.Z_HALF_LENGTH);
            nodeFeatures[i][7] = (float)(v[7] * GNNConstants.STEREO_SCALE);
            nodeFeatures[i][8] = (float)(v[8] * GNNConstants.STEREO_SCALE);
            nodeFeatures[i][9] = (float)(v[9]); // uz left unscaled
        }

        // --- Edge construction (directed, layer_gap in [1, MAX_LAYER_GAP]) -----------
        // Mirrors Python's np.where(mask) on a non-symmetric mask: edges only go
        // from a lower abs_layer to a strictly higher one.
        // Unpack the raw node records into flat arrays for the O(n^2) scan.
        int[] absLayer = new int[n];
        double[] xRaw = new double[n];
        double[] yRaw = new double[n];
        double[] rRaw = new double[n];
        double[] phiRaw = new double[n];
        double[] stereoRaw = new double[n];
        double[] detTypeRaw = new double[n];
        for (int i = 0; i < n; i++) {
            double[] v = nodeBuf.get(i);
            absLayer[i] = (int) v[0];
            phiRaw[i] = v[1];
            rRaw[i] = v[2];
            stereoRaw[i] = v[3];
            xRaw[i] = v[10];
            yRaw[i] = v[11];
            detTypeRaw[i] = v[12];
        }

        List<long[]> edgePairs = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j) continue;
                int gap = absLayer[j] - absLayer[i];
                if (gap < 1 || gap > GNNConstants.MAX_LAYER_GAP) continue;
                // Transverse (x,y) proximity cut; squared form avoids the sqrt.
                double dx = xRaw[i] - xRaw[j];
                double dy = yRaw[i] - yRaw[j];
                if (dx*dx + dy*dy > GNNConstants.MAX_EDGE_DIST_SQ) continue;
                edgePairs.add(new long[]{i, j});
            }
        }

        int e = edgePairs.size();
        long[][] edgeIndex = new long[2][e];
        float[][] edgeAttr = new float[e][GNNConstants.EDGE_FEAT_DIM];

        for (int k = 0; k < e; k++) {
            long[] p = edgePairs.get(k);
            int s = (int) p[0];
            int d = (int) p[1];
            edgeIndex[0][k] = s;
            edgeIndex[1][k] = d;

            // dphi wrapped into [-pi, pi]; the double modulo handles Java's
            // sign-preserving % for negative operands.
            double dphi = phiRaw[s] - phiRaw[d];
            dphi = ((dphi + Math.PI) % (2.0 * Math.PI) + 2.0 * Math.PI) % (2.0 * Math.PI) - Math.PI;
            // Layer gap normalized to (0, 1]; always positive by construction.
            double dlayer = (double)(absLayer[d] - absLayer[s]) / GNNConstants.MAX_LAYER_GAP;

            double doca, z1, z2;
            Line3D ls = nodeLine.get(s);
            Line3D ld = nodeLine.get(d);
            if (ls != null && ld != null) {
                // Both endpoints are AHDC wires: true wire-to-wire distance.
                // NOTE(review): assumes Line3D.distance(Line3D) returns the
                // common-perpendicular segment and Line3D.distance(Point3D) a
                // segment whose origin lies on the receiver line — confirm
                // against the org.jlab.geom.prim API.
                doca = ls.distance(ld).length();
                // Python: z1 = cp_d.z, where cp_d is the point on line_s closest to line_d's midpoint
                //         z2 = cp_s.z, where cp_s is the point on line_d closest to line_s's midpoint
                z1 = clampZ(ls.distance(ld.midpoint()).origin().z());
                z2 = clampZ(ld.distance(ls.midpoint()).origin().z());
            } else {
                // At least one ATOF node: fall back to the transverse point
                // distance and no z information.
                double ex = xRaw[s] - xRaw[d];
                double ey = yRaw[s] - yRaw[d];
                doca = Math.hypot(ex, ey);
                z1 = 0.0;
                z2 = 0.0;
            }

            // 0.0 = AHDC-AHDC, 0.5 = mixed, 1.0 = ATOF-ATOF.
            double edgeDetType = 0.5 * (detTypeRaw[s] + detTypeRaw[d]);

            edgeAttr[k][0] = (float)(dphi / Math.PI);
            edgeAttr[k][1] = (float) dlayer;
            edgeAttr[k][2] = (float)(doca / GNNConstants.MAX_R);
            edgeAttr[k][3] = (float)(z1 / GNNConstants.Z_HALF_LENGTH);
            edgeAttr[k][4] = (float)(z2 / GNNConstants.Z_HALF_LENGTH);
            edgeAttr[k][5] = (float)(rRaw[s] / GNNConstants.DOCA_STD);
            edgeAttr[k][6] = (float)(rRaw[d] / GNNConstants.DOCA_STD);
            edgeAttr[k][7] = (float)((stereoRaw[s] - stereoRaw[d]) / (2.0 * GNNConstants.STEREO_ANGLE_MAX));
            edgeAttr[k][8] = (float) edgeDetType;
        }

        Hit[] nodeToHit = nodeHit.toArray(new Hit[0]);
        return new GraphInput(nodeFeatures, edgeIndex, edgeAttr, nodeToHit);
    }

    /** Clamps a z coordinate to the detector's active range [-Z_HALF_LENGTH, +Z_HALF_LENGTH]. */
    private static double clampZ(double z) {
        if (z < -GNNConstants.Z_HALF_LENGTH) return -GNNConstants.Z_HALF_LENGTH;
        if (z > GNNConstants.Z_HALF_LENGTH) return GNNConstants.Z_HALF_LENGTH;
        return z;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.jlab.rec.ahdc.AI;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.jlab.io.base.DataBank;
import org.jlab.rec.ahdc.Cluster.Cluster;
import org.jlab.rec.ahdc.Hit.Hit;
import org.jlab.rec.ahdc.PreCluster.PreCluster;
import org.jlab.rec.ahdc.PreCluster.PreClusterFinder;
import org.jlab.rec.ahdc.Track.Track;

/** Orchestrates GNN-based track finding: builds the graph, runs the exported
 * edge scorer, extracts tracks via connected components on edge scores
 * thresholded at 0.1, and converts each node-set back into a {@link Track}
 * carrying per-superlayer Clusters so the downstream helix fit / Kalman
 * stages can consume it.
 */
public final class GNNPrediction {

    private static final Logger LOGGER = Logger.getLogger(GNNPrediction.class.getName());

    /**
     * Runs the full GNN track-finding chain on one event.
     *
     * @param ahdcHits     AHDC hits for this event; null or empty yields no tracks
     * @param atofHitsBank optional ATOF::hits bank used as graph context only (may be null)
     * @param model        exported edge-scoring model; null yields no tracks
     * @return candidate tracks; empty when the graph is too small, inference
     *         fails, or no component survives the size filters
     */
    public ArrayList<Track> prediction(List<Hit> ahdcHits,
                                       DataBank atofHitsBank,
                                       ModelTrackFindingGNN model) {
        ArrayList<Track> out = new ArrayList<>();
        if (ahdcHits == null || ahdcHits.isEmpty() || model == null) return out;

        GNNGraphBuilder.GraphInput g = GNNGraphBuilder.build(ahdcHits, atofHitsBank);
        int nNodes = g.nodeToSource.length;
        int nEdges = g.edgeIndex[0].length;
        if (nNodes < GNNConstants.MIN_NODES || nEdges == 0) {
            return out; // model cannot run on graphs this small
        }

        float[] edgeScores;
        try {
            edgeScores = model.predictEdgeScores(g.nodeFeatures, g.edgeIndex, g.edgeAttr);
        } catch (Exception ex) {
            // Boundary catch: a model failure must not abort event reconstruction.
            // Pass the throwable to the logger so the full stack trace is kept
            // (a plain warning(String) would discard it).
            LOGGER.log(Level.WARNING, "GNN inference failed", ex);
            return out;
        }

        // Connected components at TRACK_SCORE_THRESHOLD, filtered to
        // components of size >= MIN_TRACK_NODES — mirrors gnn/evaluate.py.
        List<int[]> trackNodeSets = SeedExtendTrackExtractor.extract(edgeScores, g.edgeIndex, nNodes);

        for (int[] nodes : trackNodeSets) {
            // Collect just the AHDC Hits in this track — ATOF nodes were graph
            // context only, they don't belong in AHDC::track or AHDC::hits.
            ArrayList<Hit> trackHits = new ArrayList<>(nodes.length);
            for (int n : nodes) {
                Hit h = g.nodeToSource[n];
                if (h != null) trackHits.add(h);
            }
            if (trackHits.isEmpty()) continue; // component was pure ATOF context

            ArrayList<Cluster> clusters = buildSuperlayerClusters(trackHits);
            if (clusters.size() < 3) continue; // matches the downstream >=3 filter

            out.add(new Track(clusters));
        }

        return out;
    }

    /** One {@link Cluster} per superlayer built from two {@link PreCluster}s (one
     * per layer within the superlayer). Using real PreClusters — instead of the
     * 3-arg {@code Cluster(x,y,z)} constructor — keeps
     * {@code Track.generateHitList()} and {@code DocaClusterRefiner}'s stereo
     * pairing working for GNN-discovered tracks just like they do for MLP tracks.
     *
     * @param hits AHDC hits belonging to one GNN-extracted track
     * @return clusters ordered by ascending superlayer (possibly empty)
     */
    private static ArrayList<Cluster> buildSuperlayerClusters(List<Hit> hits) {
        // Feed the track's hits through the same preclustering the MLP path uses.
        // findPreclusters mutates its input (it calls setUse(true) on consumed
        // hits), so pass a copy and ensure each hit starts unmarked.
        ArrayList<Hit> hitsForPre = new ArrayList<>(hits.size());
        for (Hit h : hits) { h.setUse(false); hitsForPre.add(h); }
        PreClusterFinder pcf = new PreClusterFinder();
        pcf.findPreclusters(hitsForPre);
        ArrayList<PreCluster> preclusters = pcf.get_AHDCPreClusters();

        // Index by (superlayer, layer). If the GNN assigns two PreClusters of the
        // same superlayer+layer to one track (rare — it would mean two disjoint
        // wire runs on the same layer), keep the largest and drop the rest.
        Map<Integer, PreCluster[]> bySuperlayer = new HashMap<>();
        for (PreCluster pc : preclusters) {
            int sl = pc.get_Super_layer();
            int layerIdx = pc.get_Layer() - 1; // layer is 1-based, slots are [0,1]
            if (layerIdx < 0 || layerIdx > 1) continue; // defensive: ignore out-of-range layer ids
            PreCluster[] slot = bySuperlayer.computeIfAbsent(sl, k -> new PreCluster[2]);
            PreCluster prev = slot[layerIdx];
            if (prev == null || pc.get_Num_wire() > prev.get_Num_wire()) slot[layerIdx] = pc;
        }

        ArrayList<Cluster> clusters = new ArrayList<>();
        // Iterate superlayers in ascending order to keep downstream output stable.
        // If both stereo layers have a PreCluster, pair them (full stereo cluster).
        // If only one has hits, use the single-layer Cluster(PreCluster) ctor —
        // DocaClusterRefiner handles PreClusters_list.size() != 2 with a
        // degenerate DocaCluster fallback, so the helix fit still runs.
        for (int sl = 1; sl <= 5; sl++) {
            PreCluster[] slot = bySuperlayer.get(sl);
            if (slot == null) continue;
            if (slot[0] != null && slot[1] != null) {
                clusters.add(new Cluster(slot[0], slot[1]));
            } else {
                PreCluster single = (slot[0] != null) ? slot[0] : slot[1];
                if (single != null) clusters.add(new Cluster(single));
            }
        }
        return clusters;
    }
}
Loading