diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index ff6d225c..1aa0eac4 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -296,40 +296,10 @@ susie_rss_pipeline <- function(sumstats, LD_mat, n = NULL, var_y = NULL, L = 5, #' @noRd get_cs_index <- function(snps_idx, susie_cs) { - # Use pmap to iterate over each vector in susie_cs - idx_lengths <- tryCatch( - { - pmap(list(x = susie_cs), function(x) { - # Check if snps_idx is in the CS and return the length of the CS if it is - if (snps_idx %in% x) { - return(length(x)) - } else { - return(NA_integer_) - } - }) %>% unlist() - }, - error = function(e) NA_integer_ - ) - idx <- which(!is.na(idx_lengths)) - # idx length should be either 1 or 0 - # But in some rare cases there will be a convergence issue resulting in a variant belong to multiple CS - # In which case we will keep one of them, with a warning - if (length(idx) > 0) { - if (length(idx) > 1) { - smallest_cs_idx <- which.min(idx_lengths[idx]) - selected_cs <- idx[smallest_cs_idx] - selected_length <- idx_lengths[selected_cs] - - warning(sprintf( - "Variable %d found in multiple CS: %s. Keeping smallest: CS %d (length %d).", - snps_idx, paste(idx, collapse = ", "), selected_cs, selected_length - )) - idx <- selected_cs # Keep index with smallest length - } - return(idx) - } else { - return(NA_integer_) - } + # Return ALL CS indices that contain this variant (not just one) + idx <- which(vapply(susie_cs, function(x) snps_idx %in% x, logical(1))) + if (length(idx) == 0) return(NA_integer_) + return(idx) } #' @noRd get_top_variants_idx <- function(susie_output, signal_cutoff) { @@ -338,11 +308,21 @@ get_top_variants_idx <- function(susie_output, signal_cutoff) { sort() } #' @noRd +#' Returns a data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair. +#' Variants in multiple CSs get multiple rows. get_cs_info <- function(susie_output_sets_cs, top_variants_idx) { - cs_info_pri <- map_int(top_variants_idx, ~ get_cs_index(.x, susie_output_sets_cs)) - ifelse(is.na(cs_info_pri), 0, as.numeric(str_replace(names(susie_output_sets_cs)[cs_info_pri], "L", ""))) + cs_names <- names(susie_output_sets_cs) + rows <- lapply(top_variants_idx, function(vi) { + idx <- get_cs_index(vi, susie_output_sets_cs) + if (length(idx) == 1 && is.na(idx)) { + data.frame(variant_idx = vi, cs_idx = 0L, stringsAsFactors = FALSE) + } else { + cs_nums <- as.integer(str_replace(cs_names[idx], "L", "")) + data.frame(variant_idx = rep(vi, length(cs_nums)), cs_idx = cs_nums, stringsAsFactors = FALSE) + } + }) + do.call(rbind, rows) } - #' @noRd get_cs_and_corr <- function(susie_output, coverage, data_x, mode = c("susie", "susie_rss", "mvsusie"), min_abs_corr = NULL) { if (mode %in% c("susie", "mvsusie")) { @@ -434,30 +414,34 @@ susie_post_processor <- function(susie_output, data_x, data_y, X_scalar, y_scala if (length(eff_idx) > 0) { # Prepare for top loci table top_variants_idx_pri <- get_top_variants_idx(susie_output, signal_cutoff) - cs_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) + # get_cs_info returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair + top_loci_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri) + if (is.null(top_loci_pri)) top_loci_pri <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) susie_output$cs_corr <- if (mode %in% c("susie", "mvsusie")) get_cs_correlation(susie_output, X = data_x) else get_cs_correlation(susie_output, Xcorr = data_x) - top_loci_list <- list("coverage_0.95" = data.frame(variant_idx = top_variants_idx_pri, cs_idx = cs_pri, stringsAsFactors = FALSE)) + top_loci_list <- list("coverage_0.95" = top_loci_pri) - ## Loop over each secondary coverage value + ## Loop over each secondary coverage value independently sets_secondary <- list() if (!is.null(secondary_coverage) && length(secondary_coverage)) { for (sec_cov in secondary_coverage) { sets_secondary[[paste0("coverage_", sec_cov)]] <- get_cs_and_corr(susie_output, sec_cov, data_x, mode, min_abs_corr) top_variants_idx_sec <- get_top_variants_idx(sets_secondary[[paste0("coverage_", sec_cov)]], signal_cutoff) - cs_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec) - top_loci_list[[paste0("coverage_", sec_cov)]] <- data.frame(variant_idx = top_variants_idx_sec, cs_idx = cs_sec, stringsAsFactors = FALSE) + top_loci_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec) + if (is.null(top_loci_sec)) top_loci_sec <- data.frame(variant_idx = integer(0), cs_idx = integer(0)) + top_loci_list[[paste0("coverage_", sec_cov)]] <- top_loci_sec } } - # Iterate over the remaining tables, rename and merge them + # Merge coverage tables via full_join names(top_loci_list[[1]])[2] <- paste0("cs_", names(top_loci_list)[1]) top_loci <- top_loci_list[[1]] if (length(top_loci_list) > 1) { for (i in 2:length(top_loci_list)) { names(top_loci_list[[i]])[2] <- paste0("cs_", names(top_loci_list)[i]) - top_loci <- full_join(top_loci, top_loci_list[[i]], by = "variant_idx") + top_loci <- dplyr::full_join(top_loci, top_loci_list[[i]], by = "variant_idx") } } + if (nrow(top_loci) > 0) { top_loci[is.na(top_loci)] <- 0 variants <- res$variant_names[top_loci$variant_idx] diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index e82b5b1b..571d9b95 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -201,16 +201,29 @@ test_that("get_cs_info maps variants to CS numbers", { susie_cs <- list(L1 = c(1, 2), L3 = c(4, 5, 6)) top_idx <- c(1, 3, 5) result <- pecotmr:::get_cs_info(susie_cs, top_idx) - expect_equal(result[1], 1) - expect_equal(result[2], 0) - expect_equal(result[3], 3) + # Now returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair + expect_true(is.data.frame(result)) + expect_equal(result$variant_idx, c(1, 3, 5)) + expect_equal(result$cs_idx, c(1L, 0L, 3L)) }) test_that("get_cs_info handles all variants outside CS", { susie_cs <- list(L1 = c(1, 2)) top_idx <- c(5, 6, 7) result <- pecotmr:::get_cs_info(susie_cs, top_idx) - expect_true(all(result == 0)) + expect_true(is.data.frame(result)) + expect_true(all(result$cs_idx == 0)) +}) + +test_that("get_cs_info reports variant in multiple CSs as multiple rows", { + susie_cs <- list(L1 = c(1, 2, 3), L3 = c(2, 3, 4)) + top_idx <- c(1, 2, 4) + result <- pecotmr:::get_cs_info(susie_cs, top_idx) + expect_true(is.data.frame(result)) + # variant 2 is in both L1 and L3, so it gets two rows + expect_equal(nrow(result), 4) + expect_equal(sum(result$variant_idx == 2), 2) + expect_equal(sort(result$cs_idx[result$variant_idx == 2]), c(1L, 3L)) }) # =============================================================================