Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 28 additions & 44 deletions R/susie_wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -296,40 +296,10 @@ susie_rss_pipeline <- function(sumstats, LD_mat, n = NULL, var_y = NULL, L = 5,

#' Find every credible set (CS) that contains a given variant
#'
#' A variant can occasionally belong to more than one CS (e.g. after a
#' convergence issue), so all matches are reported instead of picking one.
#'
#' @param snps_idx Index of a single variant to look up.
#' @param susie_cs List of credible sets; each element is a vector of variant
#'   indices.
#' @return Integer vector with the positions in `susie_cs` of every CS that
#'   contains `snps_idx` (carrying the names of `susie_cs`, if any), or
#'   `NA_integer_` when the variant is in no CS.
#' @noRd
get_cs_index <- function(snps_idx, susie_cs) {
  # Return ALL CS indices that contain this variant (not just one)
  in_cs <- vapply(susie_cs, function(x) snps_idx %in% x, logical(1))
  idx <- which(in_cs)
  if (length(idx) == 0) {
    return(NA_integer_)
  }
  idx
}
#' @noRd
get_top_variants_idx <- function(susie_output, signal_cutoff) {
Expand All @@ -338,11 +308,21 @@ get_top_variants_idx <- function(susie_output, signal_cutoff) {
sort()
}
#' Map top variants to the credible sets (CS) that contain them
#'
#' Returns a data.frame(variant_idx, cs_idx) with one row per (variant, CS)
#' pair. Variants in multiple CSs get multiple rows; variants in no CS get a
#' single row with `cs_idx = 0`.
#'
#' @param susie_output_sets_cs Named list of credible sets; names are expected
#'   to look like "L<k>", and `<k>` is what ends up in `cs_idx`.
#' @param top_variants_idx Vector of variant indices to look up.
#' @return A data.frame with columns `variant_idx` and `cs_idx`, or `NULL`
#'   when `top_variants_idx` is empty (there are no rows to bind).
#' @noRd
get_cs_info <- function(susie_output_sets_cs, top_variants_idx) {
  cs_names <- names(susie_output_sets_cs)
  rows <- lapply(top_variants_idx, function(vi) {
    idx <- get_cs_index(vi, susie_output_sets_cs)
    if (length(idx) == 1 && is.na(idx)) {
      # Variant is not in any CS: 0 is the "no CS" label
      return(data.frame(variant_idx = vi, cs_idx = 0L, stringsAsFactors = FALSE))
    }
    # Strip the "L" prefix from the CS names to recover the CS number
    cs_nums <- as.integer(str_replace(cs_names[idx], "L", ""))
    data.frame(
      variant_idx = rep(vi, length(cs_nums)),
      cs_idx = cs_nums,
      stringsAsFactors = FALSE
    )
  })
  # Single bind after the loop; yields NULL when `rows` is empty
  do.call(rbind, rows)
}

#' @noRd
get_cs_and_corr <- function(susie_output, coverage, data_x, mode = c("susie", "susie_rss", "mvsusie"), min_abs_corr = NULL) {
if (mode %in% c("susie", "mvsusie")) {
Expand Down Expand Up @@ -434,30 +414,34 @@ susie_post_processor <- function(susie_output, data_x, data_y, X_scalar, y_scala
if (length(eff_idx) > 0) {
# Prepare for top loci table
top_variants_idx_pri <- get_top_variants_idx(susie_output, signal_cutoff)
cs_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri)
# get_cs_info returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair
top_loci_pri <- get_cs_info(susie_output$sets$cs, top_variants_idx_pri)
if (is.null(top_loci_pri)) top_loci_pri <- data.frame(variant_idx = integer(0), cs_idx = integer(0))
susie_output$cs_corr <- if (mode %in% c("susie", "mvsusie")) get_cs_correlation(susie_output, X = data_x) else get_cs_correlation(susie_output, Xcorr = data_x)
top_loci_list <- list("coverage_0.95" = data.frame(variant_idx = top_variants_idx_pri, cs_idx = cs_pri, stringsAsFactors = FALSE))
top_loci_list <- list("coverage_0.95" = top_loci_pri)

## Loop over each secondary coverage value
## Loop over each secondary coverage value independently
sets_secondary <- list()
if (!is.null(secondary_coverage) && length(secondary_coverage)) {
for (sec_cov in secondary_coverage) {
sets_secondary[[paste0("coverage_", sec_cov)]] <- get_cs_and_corr(susie_output, sec_cov, data_x, mode, min_abs_corr)
top_variants_idx_sec <- get_top_variants_idx(sets_secondary[[paste0("coverage_", sec_cov)]], signal_cutoff)
cs_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec)
top_loci_list[[paste0("coverage_", sec_cov)]] <- data.frame(variant_idx = top_variants_idx_sec, cs_idx = cs_sec, stringsAsFactors = FALSE)
top_loci_sec <- get_cs_info(sets_secondary[[paste0("coverage_", sec_cov)]]$sets$cs, top_variants_idx_sec)
if (is.null(top_loci_sec)) top_loci_sec <- data.frame(variant_idx = integer(0), cs_idx = integer(0))
top_loci_list[[paste0("coverage_", sec_cov)]] <- top_loci_sec
}
}

# Iterate over the remaining tables, rename and merge them
# Merge coverage tables via full_join
names(top_loci_list[[1]])[2] <- paste0("cs_", names(top_loci_list)[1])
top_loci <- top_loci_list[[1]]
if (length(top_loci_list) > 1) {
for (i in 2:length(top_loci_list)) {
names(top_loci_list[[i]])[2] <- paste0("cs_", names(top_loci_list)[i])
top_loci <- full_join(top_loci, top_loci_list[[i]], by = "variant_idx")
top_loci <- dplyr::full_join(top_loci, top_loci_list[[i]], by = "variant_idx")
}
}

if (nrow(top_loci) > 0) {
top_loci[is.na(top_loci)] <- 0
variants <- res$variant_names[top_loci$variant_idx]
Expand Down
21 changes: 17 additions & 4 deletions tests/testthat/test_susie_wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,29 @@ test_that("get_cs_info maps variants to CS numbers", {
susie_cs <- list(L1 = c(1, 2), L3 = c(4, 5, 6))
top_idx <- c(1, 3, 5)
result <- pecotmr:::get_cs_info(susie_cs, top_idx)
expect_equal(result[1], 1)
expect_equal(result[2], 0)
expect_equal(result[3], 3)
# Now returns data.frame(variant_idx, cs_idx) with one row per (variant, CS) pair
expect_true(is.data.frame(result))
expect_equal(result$variant_idx, c(1, 3, 5))
expect_equal(result$cs_idx, c(1L, 0L, 3L))
})

test_that("get_cs_info handles all variants outside CS", {
  susie_cs <- list(L1 = c(1, 2))
  top_idx <- c(5, 6, 7)
  result <- pecotmr:::get_cs_info(susie_cs, top_idx)
  # get_cs_info returns data.frame(variant_idx, cs_idx); only cs_idx is
  # expected to be 0 here, so compare that column rather than the whole frame
  expect_true(is.data.frame(result))
  expect_equal(nrow(result), length(top_idx))
  expect_true(all(result$cs_idx == 0))
})

test_that("get_cs_info reports variant in multiple CSs as multiple rows", {
  # Variants 2 and 3 are shared between the two credible sets
  overlapping_cs <- list(L1 = c(1, 2, 3), L3 = c(2, 3, 4))
  cs_table <- pecotmr:::get_cs_info(overlapping_cs, c(1, 2, 4))
  expect_true(is.data.frame(cs_table))
  # 1 -> L1, 2 -> L1 and L3, 4 -> L3: four (variant, CS) rows in total
  expect_equal(nrow(cs_table), 4)
  shared_rows <- cs_table$variant_idx == 2
  expect_equal(sum(shared_rows), 2)
  expect_equal(sort(cs_table$cs_idx[shared_rows]), c(1L, 3L))
})

# =============================================================================
Expand Down
Loading