Skip to content

Commit 0903bbb

Browse files
🛠️ Add first above threshold aggregation
1 parent 5338f27 commit 0903bbb

1 file changed

Lines changed: 67 additions & 5 deletions

File tree

src/seismometer/data/pandas_helpers.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ def _merge_with_strategy(
415415
return pd.merge(predictions, one_event_filtered, on=pks, how="left")
416416

417417

418-
def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
418+
def max_aggregation(
419+
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
420+
) -> pd.DataFrame:
419421
"""
420422
Aggregates the DataFrame by selecting the maximum score value.
421423
@@ -431,6 +433,8 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
431433
The column name containing the time to consider, by default None.
432434
ref_event : Optional[str], optional
433435
The column name containing the event to consider, by default None.
436+
threshold : Optional[float], optional
437+
Score threshold to compare against, by default None.
434438
435439
Returns
436440
-------
@@ -446,7 +450,9 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
446450
return df.drop_duplicates(subset=pks)
447451

448452

449-
def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
453+
def min_aggregation(
454+
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
455+
) -> pd.DataFrame:
450456
"""
451457
Aggregates the DataFrame by selecting the minimum score value.
452458
@@ -462,6 +468,8 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
462468
The column name containing the time to consider, by default None.
463469
ref_event : Optional[str], optional
464470
The column name containing the event to consider, by default None.
471+
threshold : Optional[float], optional
472+
Score threshold to compare against, by default None.
465473
466474
Returns
467475
-------
@@ -477,7 +485,9 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
477485
return df.drop_duplicates(subset=pks)
478486

479487

480-
def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
488+
def first_aggregation(
489+
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
490+
) -> pd.DataFrame:
481491
"""
482492
Aggregates the DataFrame by selecting the first occurrence based on event time.
483493
@@ -493,6 +503,8 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
493503
The column name containing the time to consider, by default None.
494504
ref_event : Optional[str], optional
495505
The column name containing the event to consider, by default None.
506+
threshold : Optional[float], optional
507+
Score threshold to compare against, by default None.
496508
497509
Returns
498510
-------
@@ -508,7 +520,51 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
508520
return df.drop_duplicates(subset=pks)
509521

510522

511-
def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
523+
def first_above_threshold_aggregation(
    df: pd.DataFrame,
    pks: list[str],
    score: str,
    ref_time: Optional[str],
    ref_event: Optional[str],
    threshold: Optional[float],
) -> pd.DataFrame:
    """
    Aggregates by selecting the first prediction with a score above the given threshold.

    Rows whose score is not strictly greater than ``threshold``, or whose reference time is
    missing, are discarded. The remaining rows are ordered by reference time and the earliest
    row for each combination of ``pks`` is kept. Key combinations with no qualifying row do
    not appear in the result.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to aggregate.
    pks : list[str]
        Keys to group by.
    score : str
        Score column name.
    ref_time : Optional[str]
        Time reference column name. Required; passing None raises ValueError.
    ref_event : Optional[str]
        Not used here but retained for API consistency.
    threshold : Optional[float]
        Score threshold to compare against. Required; passing None raises ValueError.

    Returns
    -------
    pd.DataFrame
        Aggregated dataframe with first above-threshold score per group.

    Raises
    ------
    ValueError
        If ``ref_time`` or ``threshold`` is None.
    """
    ref_score = _resolve_score_col(df, score)
    if ref_time is None:
        raise ValueError("ref_time is required for first_above_threshold aggregation")
    # Callers that dispatch by method name may forward a default of None; fail with a clear
    # message instead of an opaque comparison error from pandas.
    if threshold is None:
        raise ValueError("threshold is required for first_above_threshold aggregation")

    reference_time = _resolve_time_col(df, ref_time)
    # A NaN score compares False against the threshold, so it is dropped with the rest.
    qualifies = (df[ref_score] > threshold) & df[reference_time].notna()
    df = df[qualifies]
    # Stable sort: rows sharing a reference time keep their incoming order, so the row that
    # survives drop_duplicates (keep="first" by default) is deterministic.
    df = df.sort_values(by=reference_time, kind="mergesort")
    return df.drop_duplicates(subset=pks)
563+
564+
565+
def last_aggregation(
566+
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
567+
) -> pd.DataFrame:
512568
"""
513569
Aggregates the DataFrame by selecting the last occurrence based on event time.
514570
@@ -524,6 +580,8 @@ def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str
524580
The column name containing the time to consider, by default None.
525581
ref_event : Optional[str], optional
526582
The column name containing the event to consider, by default None.
583+
threshold : Optional[float], optional
584+
Score threshold to compare against, by default None.
527585
528586
Returns
529587
-------
@@ -546,6 +604,7 @@ def event_score(
546604
ref_time: Optional[str] = None,
547605
ref_event: Optional[str] = None,
548606
aggregation_method: str = "max",
607+
threshold: Optional[float] = None,
549608
) -> pd.DataFrame:
550609
"""
551610
Reduces a dataframe of all predictions to a single row of significance; such as the max or most recent value for
@@ -573,6 +632,8 @@ def event_score(
573632
the aggregation_method.
574633
aggregation_method : str, optional
575634
A string describing the method to select a value, by default 'max'.
635+
threshold : Optional[float], optional
636+
Score threshold to compare against, by default None.
576637
577638
Returns
578639
-------
@@ -590,12 +651,13 @@ def event_score(
590651
"min": min_aggregation,
591652
"first": first_aggregation,
592653
"last": last_aggregation,
654+
"first_above_threshold": first_above_threshold_aggregation,
593655
}
594656

595657
if aggregation_method not in aggregation_methods:
596658
raise ValueError(f"Unknown aggregation method: {aggregation_method}")
597659

598-
df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event)
660+
df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event, threshold)
599661
return df.loc[~np.isnan(df.index)]
600662

601663

0 commit comments

Comments
 (0)