From 52d7add2a074f74ec2099f264d9d7c8d1fe9b2e2 Mon Sep 17 00:00:00 2001 From: Feng Date: Thu, 19 Mar 2026 17:08:46 -0400 Subject: [PATCH 1/2] Fix top_loci dropping variants shared between multiple credible sets --- R/susie_wrapper.R | 89 ++++++++++++----------------- tests/testthat/test_susie_wrapper.R | 21 +++++-- 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index ff6d225c..bd95e5d3 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -296,40 +296,10 @@ susie_rss_pipeline <- function(sumstats, LD_mat, n = NULL, var_y = NULL, L = 5, #' @noRd get_cs_index <- function(snps_idx, susie_cs) { - # Use pmap to iterate over each vector in susie_cs - idx_lengths <- tryCatch( - { - pmap(list(x = susie_cs), function(x) { - # Check if snps_idx is in the CS and return the length of the CS if it is - if (snps_idx %in% x) { - return(length(x)) - } else { - return(NA_integer_) - } - }) %>% unlist() - }, - error = function(e) NA_integer_ - ) - idx <- which(!is.na(idx_lengths)) - # idx length should be either 1 or 0 - # But in some rare cases there will be a convergence issue resulting in a variant belong to multiple CS - # In which case we will keep one of them, with a warning - if (length(idx) > 0) { - if (length(idx) > 1) { - smallest_cs_idx <- which.min(idx_lengths[idx]) - selected_cs <- idx[smallest_cs_idx] - selected_length <- idx_lengths[selected_cs] - - warning(sprintf( - "Variable %d found in multiple CS: %s. Keeping smallest: CS %d (length %d).", - snps_idx, paste(idx, collapse = ", "), selected_cs, selected_length - )) - idx <- selected_cs # Keep index with smallest length - } - return(idx) - } else { - return(NA_integer_) - } + # Return ALL CS indices that contain this variant (not just one) + idx <- which(vapply(susie_cs, function(x) snps_idx %in% x, logical(1))) + if (length(idx) == 0) return(NA_integer_) + return(idx) } #' @noRd get_top_variants_idx <- function(susie_output, signal_cutoff) { @@ -338,9 +308,32 @@ get_top_variants_idx <- function(susie_output, signal_cutoff) { sort() } #' @noRd +#' Returns a data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair. +#' Variants in multiple CSs get multiple rows. get_cs_info <- function(susie_output_sets_cs, top_variants_idx) { - cs_info_pri <- map_int(top_variants_idx, ~ get_cs_index(.x, susie_output_sets_cs)) - ifelse(is.na(cs_info_pri), 0, as.numeric(str_replace(names(susie_output_sets_cs)[cs_info_pri], "L", ""))) + cs_names <- names(susie_output_sets_cs) + rows <- lapply(top_variants_idx, function(vi) { + idx <- get_cs_index(vi, susie_output_sets_cs) + if (length(idx) == 1 && is.na(idx)) { + data.frame(variant_idx = vi, cs_idx = 0L, stringsAsFactors = FALSE) + } else { + cs_nums <- as.integer(str_replace(cs_names[idx], "L", "")) + data.frame(variant_idx = rep(vi, length(cs_nums)), cs_idx = cs_nums, stringsAsFactors = FALSE) + } + }) + do.call(rbind, rows) +} +#' @noRd +#' Lookup a variant's CS number at a secondary coverage level. +#' If in multiple CSs, prefer the one matching prefer_cs (the primary CS). +lookup_cs <- function(vi, susie_cs, prefer_cs = NULL) { + cs_names <- names(susie_cs) + in_cs <- vapply(susie_cs, function(x) vi %in% x, logical(1)) + idx <- which(in_cs) + if (length(idx) == 0) return(0L) + L_nums <- as.integer(str_replace(cs_names[idx], "L", "")) + if (!is.null(prefer_cs) && prefer_cs %in% L_nums) return(prefer_cs) + return(L_nums[1]) } #' @noRd @@ -434,30 +427,24 @@ susie_post_processor <- function(susie_output, data_x, data_y, X_scalar, y_scala if (length(eff_idx) > 0) { # Prepare for top loci table top_variants_idx_pri <- get_top_variants_idx(susie_output, signal_cutoff) - cs_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) + # get_cs_info now returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair + top_loci <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) + if (is.null(top_loci)) top_loci <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) susie_output$cs_corr <- if (mode %in% c("susie", "mvsusie")) get_cs_correlation(susie_output, X = data_x) else get_cs_correlation(susie_output, Xcorr = data_x) - top_loci_list <- list("coverage_0.95" = data.frame(variant_idx = top_variants_idx_pri, cs_idx = cs_pri, stringsAsFactors = FALSE)) + names(top_loci)[2] <- "cs_coverage_0.95" - ## Loop over each secondary coverage value + ## Secondary coverages: lookup per row to avoid many-to-many join sets_secondary <- list() if (!is.null(secondary_coverage) && length(secondary_coverage)) { for (sec_cov in secondary_coverage) { sets_secondary[[paste0("coverage_", sec_cov)]] <- get_cs_and_corr(susie_output, sec_cov, data_x, mode, min_abs_corr) - top_variants_idx_sec <- get_top_variants_idx(sets_secondary[[paste0("coverage_", sec_cov)]], signal_cutoff) - cs_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec) - top_loci_list[[paste0("coverage_", sec_cov)]] <- data.frame(variant_idx = top_variants_idx_sec, cs_idx = cs_sec, stringsAsFactors = FALSE) + sec_cs <- sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs + col_name <- paste0("cs_coverage_", sec_cov) + top_loci[[col_name]] <- mapply(function(vi, pcs) lookup_cs(vi, sec_cs, prefer_cs = pcs), + top_loci$variant_idx, top_loci$cs_coverage_0.95) } } - # Iterate over the remaining tables, rename and merge them - names(top_loci_list[[1]])[2] <- paste0("cs_", names(top_loci_list)[1]) - top_loci <- top_loci_list[[1]] - if (length(top_loci_list) > 1) { - for (i in 2:length(top_loci_list)) { - names(top_loci_list[[i]])[2] <- paste0("cs_", names(top_loci_list)[i]) - top_loci <- full_join(top_loci, top_loci_list[[i]], by = "variant_idx") - } - } if (nrow(top_loci) > 0) { top_loci[is.na(top_loci)] <- 0 variants <- res$variant_names[top_loci$variant_idx] diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index e82b5b1b..571d9b95 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -201,16 +201,29 @@ test_that("get_cs_info maps variants to CS numbers", { susie_cs <- list(L1 = c(1, 2), L3 = c(4, 5, 6)) top_idx <- c(1, 3, 5) result <- pecotmr:::get_cs_info(susie_cs, top_idx) - expect_equal(result[1], 1) - expect_equal(result[2], 0) - expect_equal(result[3], 3) + # Now returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair + expect_true(is.data.frame(result)) + expect_equal(result$variant_idx, c(1, 3, 5)) + expect_equal(result$cs_idx, c(1L, 0L, 3L)) }) test_that("get_cs_info handles all variants outside CS", { susie_cs <- list(L1 = c(1, 2)) top_idx <- c(5, 6, 7) result <- pecotmr:::get_cs_info(susie_cs, top_idx) - expect_true(all(result == 0)) + expect_true(is.data.frame(result)) + expect_true(all(result$cs_idx == 0)) +}) + +test_that("get_cs_info reports variant in multiple CSs as multiple rows", { + susie_cs <- list(L1 = c(1, 2, 3), L3 = c(2, 3, 4)) + top_idx <- c(1, 2, 4) + result <- pecotmr:::get_cs_info(susie_cs, top_idx) + expect_true(is.data.frame(result)) + # variant 2 is in both L1 and L3, so it gets two rows + expect_equal(nrow(result), 4) + expect_equal(sum(result$variant_idx == 2), 2) + expect_equal(sort(result$cs_idx[result$variant_idx == 2]), c(1L, 3L)) }) # ============================================================================= From 90e4e174d04c64934a3eada0f1068c18b1f35f21 Mon Sep 17 00:00:00 2001 From: Feng Date: Thu, 19 Mar 2026 17:18:47 -0400 Subject: [PATCH 2/2] Fix top_loci dropping variants shared between multiple credible sets --- R/susie_wrapper.R | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index bd95e5d3..1aa0eac4 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -323,19 +323,6 @@ get_cs_info <- function(susie_output_sets_cs, top_variants_idx) { }) do.call(rbind, rows) } -#' @noRd -#' Lookup a variant's CS number at a secondary coverage level. -#' If in multiple CSs, prefer the one matching prefer_cs (the primary CS). -lookup_cs <- function(vi, susie_cs, prefer_cs = NULL) { - cs_names <- names(susie_cs) - in_cs <- vapply(susie_cs, function(x) vi %in% x, logical(1)) - idx <- which(in_cs) - if (length(idx) == 0) return(0L) - L_nums <- as.integer(str_replace(cs_names[idx], "L", "")) - if (!is.null(prefer_cs) && prefer_cs %in% L_nums) return(prefer_cs) - return(L_nums[1]) -} - #' @noRd get_cs_and_corr <- function(susie_output, coverage, data_x, mode = c("susie", "susie_rss", "mvsusie"), min_abs_corr = NULL) { if (mode %in% c("susie", "mvsusie")) { @@ -427,21 +414,31 @@ susie_post_processor <- function(susie_output, data_x, data_y, X_scalar, y_scala if (length(eff_idx) > 0) { # Prepare for top loci table top_variants_idx_pri <- get_top_variants_idx(susie_output, signal_cutoff) - # get_cs_info now returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair - top_loci <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) - if (is.null(top_loci)) top_loci <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) + # get_cs_info returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair + top_loci_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) + if (is.null(top_loci_pri)) top_loci_pri <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) susie_output$cs_corr <- if (mode %in% c("susie", "mvsusie")) get_cs_correlation(susie_output, X = data_x) else get_cs_correlation(susie_output, Xcorr = data_x) - names(top_loci)[2] <- "cs_coverage_0.95" + top_loci_list <- list("coverage_0.95" = top_loci_pri) - ## Secondary coverages: lookup per row to avoid many-to-many join + ## Loop over each secondary coverage value independently sets_secondary <- list() if (!is.null(secondary_coverage) && length(secondary_coverage)) { for (sec_cov in secondary_coverage) { sets_secondary[[paste0("coverage_", sec_cov)]] <- get_cs_and_corr(susie_output, sec_cov, data_x, mode, min_abs_corr) - sec_cs <- sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs - col_name <- paste0("cs_coverage_", sec_cov) - top_loci[[col_name]] <- mapply(function(vi, pcs) lookup_cs(vi, sec_cs, prefer_cs = pcs), - top_loci$variant_idx, top_loci$cs_coverage_0.95) + top_variants_idx_sec <- get_top_variants_idx(sets_secondary[[paste0("coverage_", sec_cov)]], signal_cutoff) + top_loci_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec) + if (is.null(top_loci_sec)) top_loci_sec <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) + top_loci_list[[paste0("coverage_", sec_cov)]] <- top_loci_sec + } + } + + # Merge coverage tables via full_join + names(top_loci_list[[1]])[2] <- paste0("cs_", names(top_loci_list)[1]) + top_loci <- top_loci_list[[1]] + if (length(top_loci_list) > 1) { + for (i in 2:length(top_loci_list)) { + names(top_loci_list[[i]])[2] <- paste0("cs_", names(top_loci_list)[i]) + top_loci <- dplyr::full_join(top_loci, top_loci_list[[i]], by = "variant_idx") } }