@@ -415,7 +415,9 @@ def _merge_with_strategy(
415415 return pd .merge (predictions , one_event_filtered , on = pks , how = "left" )
416416
417417
418- def max_aggregation (df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str ) -> pd .DataFrame :
418+ def max_aggregation (
419+ df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str , threshold : float = None
420+ ) -> pd .DataFrame :
419421 """
420422 Aggregates the DataFrame by selecting the maximum score value.
421423
@@ -431,6 +433,8 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
431433 The column name containing the time to consider, by default None.
432434 ref_event : Optional[str], optional
433435 The column name containing the event to consider, by default None.
436+ threshold : Optional[float], optional
437+ Score threshold to compare against, by default None.
434438
435439 Returns
436440 -------
@@ -446,7 +450,9 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
446450 return df .drop_duplicates (subset = pks )
447451
448452
449- def min_aggregation (df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str ) -> pd .DataFrame :
453+ def min_aggregation (
454+ df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str , threshold : float = None
455+ ) -> pd .DataFrame :
450456 """
451457 Aggregates the DataFrame by selecting the minimum score value.
452458
@@ -462,6 +468,8 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
462468 The column name containing the time to consider, by default None.
463469 ref_event : Optional[str], optional
464470 The column name containing the event to consider, by default None.
471+ threshold : Optional[float], optional
472+ Score threshold to compare against, by default None.
465473
466474 Returns
467475 -------
@@ -477,7 +485,9 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
477485 return df .drop_duplicates (subset = pks )
478486
479487
480- def first_aggregation (df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str ) -> pd .DataFrame :
488+ def first_aggregation (
489+ df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str , threshold : float = None
490+ ) -> pd .DataFrame :
481491 """
482492 Aggregates the DataFrame by selecting the first occurrence based on event time.
483493
@@ -493,6 +503,8 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
493503 The column name containing the time to consider, by default None.
494504 ref_event : Optional[str], optional
495505 The column name containing the event to consider, by default None.
506+ threshold : Optional[float], optional
507+ Score threshold to compare against, by default None.
496508
497509 Returns
498510 -------
@@ -508,7 +520,51 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
508520 return df .drop_duplicates (subset = pks )
509521
510522
511- def last_aggregation (df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str ) -> pd .DataFrame :
def first_above_threshold_aggregation(
    df: pd.DataFrame,
    pks: list[str],
    score: str,
    ref_time: Optional[str],
    ref_event: Optional[str],
    threshold: Optional[float],
) -> pd.DataFrame:
    """
    Aggregates by selecting the first prediction with a score above the given threshold.

    Rows are kept only when their score is strictly greater than ``threshold`` and
    their reference time is not missing. The remaining rows are ordered by the
    reference time (ascending) and the first row per ``pks`` combination is returned.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to aggregate.
    pks : list[str]
        Keys to group by.
    score : str
        Score column name.
    ref_time : Optional[str]
        Time reference column name. Required; passing None raises ValueError.
    ref_event : Optional[str]
        Not used here but retained for API consistency.
    threshold : Optional[float]
        Score threshold to compare against. Required; passing None raises ValueError.

    Returns
    -------
    pd.DataFrame
        Aggregated dataframe with first above-threshold score per group.

    Raises
    ------
    ValueError
        If ``threshold`` or ``ref_time`` is None.
    """
    # event_score forwards its own ``threshold`` default of None, so reject it
    # here with a clear message instead of failing inside the comparison below.
    if threshold is None:
        raise ValueError("threshold is required for first_above_threshold aggregation")

    ref_score = _resolve_score_col(df, score)
    if ref_time is None:
        raise ValueError("ref_time is required for first_above_threshold aggregation")

    reference_time = _resolve_time_col(df, ref_time)
    # Strict comparison: a score exactly equal to the threshold is excluded.
    df = df[df[ref_score] > threshold]
    # Rows without a reference time cannot be ordered, so they are dropped.
    df = df[df[reference_time].notna()]
    df = df.sort_values(by=reference_time)
    # drop_duplicates keeps the first row per key, i.e. the earliest after sorting.
    return df.drop_duplicates(subset=pks)
563+
564+
565+ def last_aggregation (
566+ df : pd .DataFrame , pks : list [str ], score : str , ref_time : str , ref_event : str , threshold : float = None
567+ ) -> pd .DataFrame :
512568 """
513569 Aggregates the DataFrame by selecting the last occurrence based on event time.
514570
@@ -524,6 +580,8 @@ def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str
524580 The column name containing the time to consider, by default None.
525581 ref_event : Optional[str], optional
526582 The column name containing the event to consider, by default None.
583+ threshold : Optional[float], optional
584+ Score threshold to compare against, by default None.
527585
528586 Returns
529587 -------
@@ -546,6 +604,7 @@ def event_score(
546604 ref_time : Optional [str ] = None ,
547605 ref_event : Optional [str ] = None ,
548606 aggregation_method : str = "max" ,
607+ threshold : Optional [float ] = None ,
549608) -> pd .DataFrame :
550609 """
551610 Reduces a dataframe of all predictions to a single row of significance; such as the max or most recent value for
@@ -573,6 +632,8 @@ def event_score(
573632 the aggregation_method.
574633 aggregation_method : str, optional
575634 A string describing the method to select a value, by default 'max'.
635+ threshold : Optional[float], optional
636+ Score threshold to compare against, by default None.
576637
577638 Returns
578639 -------
@@ -590,12 +651,13 @@ def event_score(
590651 "min" : min_aggregation ,
591652 "first" : first_aggregation ,
592653 "last" : last_aggregation ,
654+ "first_above_threshold" : first_above_threshold_aggregation ,
593655 }
594656
595657 if aggregation_method not in aggregation_methods :
596658 raise ValueError (f"Unknown aggregation method: { aggregation_method } " )
597659
598- df = aggregation_methods [aggregation_method ](merged_frame , pks , score , ref_time , ref_event )
660+ df = aggregation_methods [aggregation_method ](merged_frame , pks , score , ref_time , ref_event , threshold )
599661 return df .loc [~ np .isnan (df .index )]
600662
601663
0 commit comments