From 95fe619febe9eb7c9915f39b07f56dac6aebddb3 Mon Sep 17 00:00:00 2001 From: Minki Kim Date: Fri, 24 Apr 2026 14:38:12 +0900 Subject: [PATCH 1/4] Finish nullable migration for remaining core indicators --- core/src/indicators/aroon.rs | 23 +++++- core/src/indicators/aroonosc.rs | 8 +- core/src/indicators/cci.rs | 80 ++++++++++++++---- core/src/indicators/ichimoku.rs | 71 ++++++++++++---- core/src/indicators/mfi.rs | 82 ++++++++++++------- core/src/indicators/mom.rs | 19 ++++- core/src/indicators/pchan.rs | 22 +++-- core/src/indicators/psl.rs | 60 ++++++++------ core/src/indicators/roc.rs | 24 ++++-- core/src/indicators/stochf.rs | 47 ++++++++--- core/src/indicators/stochrsi.rs | 16 +--- core/src/indicators/stochs.rs | 67 +++++++++++---- core/src/indicators/willr.rs | 45 +++++++++-- core/src/utils.rs | 139 ++++++++++++++++++++------------ 14 files changed, 493 insertions(+), 210 deletions(-) diff --git a/core/src/indicators/aroon.rs b/core/src/indicators/aroon.rs index 104db40..176d2d2 100644 --- a/core/src/indicators/aroon.rs +++ b/core/src/indicators/aroon.rs @@ -1,10 +1,14 @@ use crate::utils::rolling_argmax_argmin; -pub fn aroon(highs: &[f64], lows: &[f64], period: usize) -> (Vec>, Vec>) { +pub fn aroon( + highs: &[Option], + lows: &[Option], + period: usize, +) -> (Vec>, Vec>) { let mut aroon_up = vec![None; highs.len()]; let mut aroon_down = vec![None; lows.len()]; - if highs.len() < period { + if highs.len() != lows.len() || period == 0 || highs.len() < period + 1 { return (aroon_up, aroon_down); } @@ -39,8 +43,8 @@ mod tests { fn test_aroon() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let highs = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let lows = testutils::load_data(&format!("../data/{}.json", symbol), "l"); + let highs = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let lows = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); let (aroon_up, aroon_down) = aroon(&highs, &lows, 25); let expected_up = testutils::load_expected::>(&format!( @@ -66,4 +70,15 @@ mod tests { ); } } + + #[test] + fn test_aroon_gap_invalidates_full_extrema_window() { + let highs = vec![Some(1.0), Some(5.0), None, Some(4.0), Some(3.0), Some(6.0)]; + let lows = vec![Some(6.0), Some(2.0), None, Some(3.0), Some(1.0), Some(2.0)]; + + let (up, down) = aroon(&highs, &lows, 2); + + assert_eq!(up, vec![None, None, None, None, None, Some(100.0)]); + assert_eq!(down, vec![None, None, None, None, None, Some(50.0)]); + } } diff --git a/core/src/indicators/aroonosc.rs b/core/src/indicators/aroonosc.rs index abf0518..8f0278c 100644 --- a/core/src/indicators/aroonosc.rs +++ b/core/src/indicators/aroonosc.rs @@ -1,9 +1,9 @@ use crate::indicators::aroon; -pub fn aroonosc(highs: &[f64], lows: &[f64], period: usize) -> Vec> { +pub fn aroonosc(highs: &[Option], lows: &[Option], period: usize) -> Vec> { let mut aroonosc = vec![None; highs.len()]; - if highs.len() < period { + if highs.len() != lows.len() || period == 0 || highs.len() < period + 1 { return aroonosc; } @@ -28,8 +28,8 @@ mod tests { fn test_aroonosc() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let highs = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let lows = testutils::load_data(&format!("../data/{}.json", symbol), "l"); + let highs = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let lows = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); let result = aroonosc(&highs, &lows, 25); let expected = testutils::load_expected::>(&format!( diff --git a/core/src/indicators/cci.rs b/core/src/indicators/cci.rs index 8829d34..63228a8 100644 --- a/core/src/indicators/cci.rs +++ b/core/src/indicators/cci.rs @@ -1,4 +1,9 @@ -pub fn cci(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Vec> { +pub fn cci( + highs: &[Option], + lows: &[Option], + closes: &[Option], + period: usize, +) -> Vec> { let len = highs.len(); let mut result = vec![None; len]; @@ -6,23 +11,45 @@ pub fn cci(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Vec = highs + let typical_prices: Vec> = highs .iter() .zip(lows.iter()) .zip(closes.iter()) - .map(|((h, l), c)| (h + l + c) / 3.0) + .map(|((&high, &low), &close)| match (high, low, close) { + (Some(high), Some(low), Some(close)) => Some((high + low + close) / 3.0), + _ => None, + }) .collect(); for i in period - 1..len { - let slice = &typical_prices[i + 1 - period..=i]; - let sma_tp: f64 = slice.iter().sum::() / period as f64; - let mean_deviation = slice.iter().map(|&x| (x - sma_tp).abs()).sum::() / period as f64; - - result[i] = if mean_deviation == 0.0 { - None - } else { - Some((typical_prices[i] - sma_tp) / (0.015 * mean_deviation)) - }; + let mut window = Vec::with_capacity(period); + let mut valid = true; + for value in &typical_prices[i + 1 - period..=i] { + if let Some(value) = value { + window.push(*value); + } else { + valid = false; + break; + } + } + + if !valid { + continue; + } + + let sma_tp = window.iter().sum::() / period as f64; + let mean_deviation = window + .iter() + .map(|&value| (value - sma_tp).abs()) + .sum::() + / period as f64; + + if mean_deviation != 0.0 { + let current_tp = *window + .last() + .expect("validated CCI window must include the current typical price"); + result[i] = Some((current_tp - sma_tp) / (0.015 * mean_deviation)); + } } result @@ -38,9 +65,9 @@ mod tests { fn test_cci() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let high = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let low = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let result = cci(&high, &low, &close, 20); let expected = testutils::load_expected::>(&format!( "../data/expected/cci_{}.json", @@ -55,4 +82,27 @@ mod tests { ); } } + + #[test] + fn test_cci_gap_invalidates_full_typical_price_window() { + let highs = vec![Some(4.0), Some(6.0), None, Some(10.0), Some(12.0)]; + let lows = vec![Some(2.0), Some(2.0), None, Some(6.0), Some(8.0)]; + let closes = vec![Some(3.0), Some(4.0), None, Some(9.0), Some(11.0)]; + + let result = round_vec(cci(&highs, &lows, &closes, 2), 8); + + assert_eq!( + result, + round_vec( + vec![ + None, + Some(66.66666666666667), + None, + None, + Some(66.66666666666667) + ], + 8 + ) + ); + } } diff --git a/core/src/indicators/ichimoku.rs b/core/src/indicators/ichimoku.rs index 56afa82..311e3fa 100644 --- a/core/src/indicators/ichimoku.rs +++ b/core/src/indicators/ichimoku.rs @@ -17,32 +17,40 @@ fn leading_span_a_from_lines( forward_shift(span, base_line_period) } -pub fn ichimoku_conversion_line(highs: &[f64], lows: &[f64], period: usize) -> Vec> { +pub fn ichimoku_conversion_line( + highs: &[Option], + lows: &[Option], + period: usize, +) -> Vec> { rolling_midpoint(highs, lows, period) } -pub fn ichimoku_base_line(highs: &[f64], lows: &[f64], period: usize) -> Vec> { +pub fn ichimoku_base_line( + highs: &[Option], + lows: &[Option], + period: usize, +) -> Vec> { rolling_midpoint(highs, lows, period) } -pub fn ichimoku_lagging_span(closes: &[f64], base_line_period: usize) -> Vec> { +pub fn ichimoku_lagging_span(closes: &[Option], base_line_period: usize) -> Vec> { let len = closes.len(); let mut lagging_span = vec![None; len]; - if len < base_line_period { + if base_line_period == 0 || len < base_line_period { return lagging_span; } for i in (base_line_period - 1)..len { - lagging_span[i + 1 - base_line_period] = Some(closes[i]); + lagging_span[i + 1 - base_line_period] = closes[i]; } lagging_span } pub fn ichimoku_leading_span_a( - highs: &[f64], - lows: &[f64], + highs: &[Option], + lows: &[Option], conversion_line_period: usize, base_line_period: usize, ) -> Vec> { @@ -52,8 +60,8 @@ pub fn ichimoku_leading_span_a( } pub fn ichimoku_leading_span_b( - highs: &[f64], - lows: &[f64], + highs: &[Option], + lows: &[Option], period: usize, base_line_period: usize, ) -> Vec> { @@ -61,9 +69,9 @@ pub fn ichimoku_leading_span_b( } pub fn ichimoku( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], conversion_line_period: usize, base_line_period: usize, leading_span_b_period: usize, @@ -74,6 +82,17 @@ pub fn ichimoku( Vec>, // Leading span A Vec>, // Leading span B ) { + let len = highs.len(); + if len != lows.len() || len != closes.len() { + return ( + vec![None; len], + vec![None; len], + vec![None; len], + vec![None; len], + vec![None; len], + ); + } + let conversion_line = ichimoku_conversion_line(highs, lows, conversion_line_period); let base_line = ichimoku_base_line(highs, lows, base_line_period); let lagging_span = ichimoku_lagging_span(closes, base_line_period); @@ -99,9 +118,9 @@ mod tests { fn test_ichimoku() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let high = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let low = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let (conversion_line, base_line, lagging_span, leading_span_a, leading_span_b) = ichimoku(&high, &low, &close, 9, 26, 52); @@ -159,4 +178,26 @@ mod tests { ); } } + + #[test] + fn test_ichimoku_gap_invalidates_extrema_windows_and_preserves_lagging_alignment() { + let highs = vec![Some(5.0), Some(7.0), None, Some(10.0), Some(12.0)]; + let lows = vec![Some(1.0), Some(3.0), None, Some(6.0), Some(8.0)]; + let closes = vec![Some(4.0), Some(5.0), None, Some(8.0), Some(11.0)]; + + let (conversion, base, lagging, leading_a, leading_b) = + ichimoku(&highs, &lows, &closes, 2, 2, 2); + + assert_eq!(conversion, vec![None, Some(4.0), None, None, Some(9.0)]); + assert_eq!(base, vec![None, Some(4.0), None, None, Some(9.0)]); + assert_eq!(lagging, vec![Some(5.0), None, Some(8.0), Some(11.0), None]); + assert_eq!( + leading_a, + vec![None, None, Some(4.0), None, None, Some(9.0)] + ); + assert_eq!( + leading_b, + vec![None, None, Some(4.0), None, None, Some(9.0)] + ); + } } diff --git a/core/src/indicators/mfi.rs b/core/src/indicators/mfi.rs index b816942..d17f04e 100644 --- a/core/src/indicators/mfi.rs +++ b/core/src/indicators/mfi.rs @@ -1,8 +1,10 @@ +use crate::utils::rolling_sum_strict; + pub fn mfi( - highs: &[f64], - lows: &[f64], - closes: &[f64], - volumes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], + volumes: &[Option], period: usize, ) -> Vec> { let mut mfi = vec![None; highs.len()]; @@ -17,42 +19,50 @@ pub fn mfi( return mfi; } - let typical_prices: Vec = highs + let typical_prices: Vec> = highs .iter() .zip(lows.iter()) .zip(closes.iter()) - .map(|((h, l), c)| (h + l + c) / 3.0) + .map(|((&high, &low), &close)| match (high, low, close) { + (Some(high), Some(low), Some(close)) => Some((high + low + close) / 3.0), + _ => None, + }) .collect(); - let mut positive_money_flow = vec![0.0; highs.len()]; - let mut negative_money_flow = vec![0.0; highs.len()]; + let mut positive_money_flow = vec![None; highs.len()]; + let mut negative_money_flow = vec![None; highs.len()]; for i in 1..highs.len() { - let prev_tp = typical_prices[i - 1]; - let curr_tp = typical_prices[i]; - let raw_money_flow = curr_tp * volumes[i]; + let (Some(prev_tp), Some(curr_tp), Some(volume)) = + (typical_prices[i - 1], typical_prices[i], volumes[i]) + else { + continue; + }; + let raw_money_flow = curr_tp * volume; if curr_tp >= prev_tp { - positive_money_flow[i] = raw_money_flow; - negative_money_flow[i] = 0.0; + positive_money_flow[i] = Some(raw_money_flow); + negative_money_flow[i] = Some(0.0); } else { - positive_money_flow[i] = 0.0; - negative_money_flow[i] = raw_money_flow; + positive_money_flow[i] = Some(0.0); + negative_money_flow[i] = Some(raw_money_flow); } + } - if i >= period { - let positive_sum = positive_money_flow[i - period + 1..=i].iter().sum::(); - let negative_sum = negative_money_flow[i - period + 1..=i].iter().sum::(); + let positive_sums = rolling_sum_strict(&positive_money_flow, period); + let negative_sums = rolling_sum_strict(&negative_money_flow, period); - let mfi_point = if negative_sum == 0.0 { - 100.0 - } else { - let money_flow_ratio = positive_sum / negative_sum; - 100.0 - (100.0 / (1.0 + money_flow_ratio)) - }; + for i in 0..highs.len() { + let (Some(positive_sum), Some(negative_sum)) = (positive_sums[i], negative_sums[i]) else { + continue; + }; - mfi[i] = Some(mfi_point); - } + mfi[i] = Some(if negative_sum == 0.0 { + 100.0 + } else { + let money_flow_ratio = positive_sum / negative_sum; + 100.0 - (100.0 / (1.0 + money_flow_ratio)) + }); } mfi @@ -68,10 +78,10 @@ mod tests { fn test_mfi() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let high = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let low = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); - let volume = testutils::load_data(&format!("../data/{}.json", symbol), "v"); + let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); + let volume = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "v"); let result = mfi(&high, &low, &close, &volume, 14); let expected = testutils::load_expected::>(&format!( "../data/expected/mfi_{}.json", @@ -86,4 +96,16 @@ mod tests { ); } } + + #[test] + fn test_mfi_gap_invalidates_full_money_flow_window() { + let highs = vec![Some(4.0), Some(5.0), None, Some(8.0), Some(9.0)]; + let lows = vec![Some(2.0), Some(3.0), None, Some(6.0), Some(7.0)]; + let closes = vec![Some(3.0), Some(4.0), None, Some(7.0), Some(8.0)]; + let volumes = vec![Some(10.0), Some(10.0), None, Some(10.0), Some(10.0)]; + + let result = mfi(&highs, &lows, &closes, &volumes, 2); + + assert_eq!(result, vec![None, None, None, None, None]); + } } diff --git a/core/src/indicators/mom.rs b/core/src/indicators/mom.rs index 2bf07e1..fbaf1b9 100644 --- a/core/src/indicators/mom.rs +++ b/core/src/indicators/mom.rs @@ -1,13 +1,15 @@ -pub fn mom(closes: &[f64], period: usize) -> Vec> { +pub fn mom(closes: &[Option], period: usize) -> Vec> { let len = closes.len(); let mut result = vec![None; len]; - if len < period + 1 { + if period == 0 || len < period + 1 { return result; } for i in period..len { - result[i] = Some(closes[i] - closes[i - period]); + if let (Some(current), Some(previous)) = (closes[i], closes[i - period]) { + result[i] = Some(current - previous); + } } result @@ -23,7 +25,7 @@ mod tests { fn test_mom() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let result = mom(&close, 10); let expected = testutils::load_expected::>(&format!( "../data/expected/mom_{}.json", @@ -38,4 +40,13 @@ mod tests { ); } } + + #[test] + fn test_mom_gap_invalidates_until_lagged_value_returns() { + let closes = vec![Some(1.0), Some(3.0), None, Some(8.0), Some(13.0)]; + + let result = mom(&closes, 2); + + assert_eq!(result, vec![None, None, None, Some(5.0), None]); + } } diff --git a/core/src/indicators/pchan.rs b/core/src/indicators/pchan.rs index 154c53b..955f9c2 100644 --- a/core/src/indicators/pchan.rs +++ b/core/src/indicators/pchan.rs @@ -1,8 +1,8 @@ use crate::utils::rolling_max_min; pub fn pchan( - highs: &[f64], - lows: &[f64], + highs: &[Option], + lows: &[Option], period: usize, ) -> (Vec>, Vec>, Vec>) { let len = highs.len(); @@ -10,7 +10,7 @@ pub fn pchan( let mut lower = vec![None; len]; let mut middle = vec![None; len]; - if period == 0 || len < period { + if len != lows.len() || period == 0 || len <= period { return (upper, middle, lower); } @@ -40,8 +40,8 @@ mod tests { fn test_pchan() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let highs = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let lows = testutils::load_data(&format!("../data/{}.json", symbol), "l"); + let highs = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let lows = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); let (upper, middle, lower) = pchan(&highs, &lows, 20); let expected_upper = testutils::load_expected::>(&format!( @@ -77,4 +77,16 @@ mod tests { ); } } + + #[test] + fn test_pchan_gap_invalidates_prior_window_until_recovery() { + let highs = vec![Some(3.0), Some(5.0), None, Some(9.0), Some(11.0)]; + let lows = vec![Some(1.0), Some(2.0), None, Some(4.0), Some(6.0)]; + + let (upper, middle, lower) = pchan(&highs, &lows, 2); + + assert_eq!(upper, vec![None, None, Some(5.0), None, None]); + assert_eq!(middle, vec![None, None, Some(3.0), None, None]); + assert_eq!(lower, vec![None, None, Some(1.0), None, None]); + } } diff --git a/core/src/indicators/psl.rs b/core/src/indicators/psl.rs index 115ea32..70c27db 100644 --- a/core/src/indicators/psl.rs +++ b/core/src/indicators/psl.rs @@ -1,38 +1,28 @@ -pub fn psl(closes: &[f64], period: usize) -> Vec> { - let mut psl = vec![None; closes.len()]; +use crate::utils::rolling_sum_strict; + +pub fn psl(closes: &[Option], period: usize) -> Vec> { let len = closes.len(); + let mut result = vec![None; len]; if len < period + 1 || period <= 1 { - return psl; + return result; } - let mut count = 0; - - // Count initial positive price changes - for i in 1..period { - if closes[i] > closes[i - 1] { - count += 1; + let mut positive_changes = vec![None; len]; + for i in 1..len { + if let (Some(current), Some(previous)) = (closes[i], closes[i - 1]) { + positive_changes[i] = Some(if current > previous { 1.0 } else { 0.0 }); } } - // Calculate PSL for the rest of the series - for i in period..len { - // Add current price change to the count - if closes[i] > closes[i - 1] { - count += 1; - } - - // Calculate PSL value - let psl_value = (count as f64 / period as f64) * 100.0; - psl[i] = Some(psl_value); - - // Remove oldest price change from the count - if closes[i - period + 1] > closes[i - period] { - count -= 1; + let counts = rolling_sum_strict(&positive_changes, period); + for i in 0..len { + if let Some(count) = counts[i] { + result[i] = Some((count / period as f64) * 100.0); } } - psl + result } #[cfg(test)] @@ -45,7 +35,7 @@ mod tests { fn test_psl() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let result = psl(&close, 12); let expected = testutils::load_expected::>(&format!( "../data/expected/psl_{}.json", @@ -60,4 +50,24 @@ mod tests { ); } } + + #[test] + fn test_psl_gap_invalidates_until_full_change_window_recovers() { + let closes = vec![ + Some(1.0), + Some(2.0), + Some(3.0), + None, + Some(5.0), + Some(4.0), + Some(6.0), + ]; + + let result = psl(&closes, 2); + + assert_eq!( + result, + vec![None, None, Some(100.0), None, None, None, Some(50.0)] + ); + } } diff --git a/core/src/indicators/roc.rs b/core/src/indicators/roc.rs index f460bd8..e847791 100644 --- a/core/src/indicators/roc.rs +++ b/core/src/indicators/roc.rs @@ -1,15 +1,15 @@ -pub fn roc(closes: &[f64], period: usize) -> Vec> { +pub fn roc(closes: &[Option], period: usize) -> Vec> { let len = closes.len(); let mut result = vec![None; len]; - if len < period + 1 { + if period == 0 || len < period + 1 { return result; } for i in period..len { - let curr_close = closes[i]; - let prev_close = closes[i - period]; - result[i] = Some(((curr_close - prev_close) / prev_close) * 100.0); + if let (Some(current), Some(previous)) = (closes[i], closes[i - period]) { + result[i] = Some(((current - previous) / previous) * 100.0); + } } result @@ -25,7 +25,7 @@ mod tests { fn test_roc() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let result = roc(&close, 20); let expected = testutils::load_expected::>(&format!( "../data/expected/roc_{}.json", @@ -40,4 +40,16 @@ mod tests { ); } } + + #[test] + fn test_roc_gap_invalidates_until_lagged_value_returns() { + let closes = vec![Some(2.0), Some(4.0), None, Some(10.0), Some(15.0)]; + + let result = round_vec(roc(&closes, 2), 8); + + assert_eq!( + result, + round_vec(vec![None, None, None, Some(150.0), None], 8) + ); + } } diff --git a/core/src/indicators/stochf.rs b/core/src/indicators/stochf.rs index 474567e..42e04cf 100644 --- a/core/src/indicators/stochf.rs +++ b/core/src/indicators/stochf.rs @@ -1,15 +1,15 @@ use crate::utils::{rolling_max_min, rolling_mean_strict}; pub fn stochf_percent_k( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, ) -> Vec> { let len = closes.len(); let mut percent_k = vec![None; len]; - if len < fastk_period { + if len != highs.len() || len != lows.len() || len < fastk_period || fastk_period == 0 { return percent_k; } @@ -20,10 +20,14 @@ pub fn stochf_percent_k( continue; }; + let Some(close) = closes[i] else { + continue; + }; + percent_k[i] = if max_high == min_low { None } else { - Some(((closes[i] - min_low) / (max_high - min_low)) * 100.0) + Some(((close - min_low) / (max_high - min_low)) * 100.0) }; } @@ -49,9 +53,9 @@ pub fn stochf_percent_d( } pub fn stochf( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, fastd_period: usize, ) -> (Vec>, Vec>) { @@ -70,9 +74,9 @@ mod tests { fn test_stochf() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let highs = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let lows = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let closes = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let highs = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let lows = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let closes = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let (percent_k, percent_d) = stochf(&highs, &lows, &closes, 14, 3); @@ -99,4 +103,25 @@ mod tests { ); } } + + #[test] + fn test_stochf_gap_invalidates_window_until_full_recovery() { + let highs = vec![Some(5.0), Some(7.0), None, Some(10.0), Some(12.0)]; + let lows = vec![Some(1.0), Some(3.0), None, Some(6.0), Some(8.0)]; + let closes = vec![Some(4.0), Some(5.0), None, Some(8.0), Some(11.0)]; + + let (percent_k, percent_d) = stochf(&highs, &lows, &closes, 2, 2); + + assert_eq!( + percent_k, + vec![ + None, + Some(66.66666666666666), + None, + None, + Some(83.33333333333334) + ] + ); + assert_eq!(percent_d, vec![None, None, None, None, None]); + } } diff --git a/core/src/indicators/stochrsi.rs b/core/src/indicators/stochrsi.rs index 5008629..bb03187 100644 --- a/core/src/indicators/stochrsi.rs +++ b/core/src/indicators/stochrsi.rs @@ -15,23 +15,9 @@ pub fn stochrsi( } let rsi_values = rsi(closes, period_rsi); - let rsi_values_with_nan: Vec = rsi_values - .iter() - .map(|value| value.unwrap_or(f64::NAN)) - .collect(); - let (rolling_max, rolling_min) = - rolling_max_min(&rsi_values_with_nan, &rsi_values_with_nan, period_k); + let (rolling_max, rolling_min) = rolling_max_min(&rsi_values, &rsi_values, period_k); for i in (period_rsi + period_k - 1)..len { - let valid_values: Vec = rsi_values[i + 1 - period_k..=i] - .iter() - .filter_map(|&x| x) - .collect(); - - if valid_values.len() != period_k { - continue; - } - let (Some(rsi), Some(rsi_max), Some(rsi_min)) = (rsi_values[i], rolling_max[i], rolling_min[i]) else { diff --git a/core/src/indicators/stochs.rs b/core/src/indicators/stochs.rs index 3bfcb3d..5994e22 100644 --- a/core/src/indicators/stochs.rs +++ b/core/src/indicators/stochs.rs @@ -1,15 +1,15 @@ use crate::utils::{rolling_max_min, rolling_mean_strict}; fn stochs_raw_k( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, ) -> Vec> { let len = closes.len(); let mut raw_k = vec![None; len]; - if len < fastk_period { + if len != highs.len() || len != lows.len() || len < fastk_period || fastk_period == 0 { return raw_k; } @@ -20,10 +20,14 @@ fn stochs_raw_k( continue; }; + let Some(close) = closes[i] else { + continue; + }; + raw_k[i] = if max_high == min_low { None } else { - Some(((closes[i] - min_low) / (max_high - min_low)) * 100.0) + Some(((close - min_low) / (max_high - min_low)) * 100.0) }; } @@ -31,9 +35,9 @@ fn stochs_raw_k( } pub fn stoch_percent_k( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, slowk_period: usize, ) -> Vec> { @@ -48,9 +52,9 @@ pub fn stoch_percent_k( } pub fn stoch_percent_d( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, slowk_period: usize, slowd_period: usize, @@ -79,9 +83,9 @@ fn stoch_percent_d_from_k( } pub fn stochs( - highs: &[f64], - lows: &[f64], - closes: &[f64], + highs: &[Option], + lows: &[Option], + closes: &[Option], fastk_period: usize, slowk_period: usize, slowd_period: usize, @@ -101,9 +105,9 @@ mod tests { fn test_stochs() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let high = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let low = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let (percent_k, percent_d) = stochs(&high, &low, &close, 14, 3, 3); @@ -130,4 +134,33 @@ mod tests { ); } } + + #[test] + fn test_stochs_gap_invalidates_raw_and_smoothed_windows() { + let highs = vec![ + Some(5.0), + Some(7.0), + None, + Some(10.0), + Some(12.0), + Some(13.0), + ]; + let lows = vec![Some(1.0), Some(3.0), None, Some(6.0), Some(8.0), Some(9.0)]; + let closes = vec![ + Some(4.0), + Some(5.0), + None, + Some(8.0), + Some(11.0), + Some(10.0), + ]; + + let (percent_k, percent_d) = stochs(&highs, &lows, &closes, 2, 2, 2); + + assert_eq!( + percent_k, + vec![None, None, None, None, None, Some(61.66666666666667)] + ); + assert_eq!(percent_d, vec![None, None, None, None, None, None]); + } } diff --git a/core/src/indicators/willr.rs b/core/src/indicators/willr.rs index 2c78214..b44fbd9 100644 --- a/core/src/indicators/willr.rs +++ b/core/src/indicators/willr.rs @@ -1,10 +1,15 @@ use crate::utils::rolling_max_min; -pub fn willr(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Vec> { +pub fn willr( + highs: &[Option], + lows: &[Option], + closes: &[Option], + period: usize, +) -> Vec> { let len = closes.len(); let mut result = vec![None; len]; - if len < period { + if len != highs.len() || len != lows.len() || len < period || period == 0 { return result; } @@ -15,11 +20,14 @@ pub fn willr(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Vec< continue; }; - let cc = closes[i]; + let Some(close) = closes[i] else { + continue; + }; + if max_high == min_low { result[i] = None; } else { - result[i] = Some(((max_high - cc) / (max_high - min_low)) * -100.0); + result[i] = Some(((max_high - close) / (max_high - min_low)) * -100.0); } } @@ -36,9 +44,9 @@ mod tests { fn test_willr() { let test_cases = vec!["005930", "TSLA"]; for symbol in test_cases { - let high = testutils::load_data(&format!("../data/{}.json", symbol), "h"); - let low = testutils::load_data(&format!("../data/{}.json", symbol), "l"); - let close = testutils::load_data(&format!("../data/{}.json", symbol), "c"); + let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); + let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); + let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); let result = willr(&high, &low, &close, 14); let expected = testutils::load_expected::>(&format!( "../data/expected/willr_{}.json", @@ -53,4 +61,27 @@ mod tests { ); } } + + #[test] + fn test_willr_gap_invalidates_window_until_full_recovery() { + let highs = vec![Some(5.0), Some(7.0), None, Some(10.0), Some(12.0)]; + let lows = vec![Some(1.0), Some(3.0), None, Some(6.0), Some(8.0)]; + let closes = vec![Some(4.0), Some(5.0), None, Some(8.0), Some(11.0)]; + + let result = round_vec(willr(&highs, &lows, &closes, 2), 8); + + assert_eq!( + result, + round_vec( + vec![ + None, + Some(-33.33333333333333), + None, + None, + Some(-16.666666666666664) + ], + 8 + ) + ); + } } diff --git a/core/src/utils.rs b/core/src/utils.rs index 8f657bd..1c3ed60 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -30,86 +30,127 @@ pub fn find_min(data: &[f64]) -> f64 { data.iter().cloned().fold(f64::INFINITY, f64::min) } -/// Computes rolling max/min values over paired slices using a monotonic queue. +/// Computes rolling max/min values over paired aligned nullable slices. /// -/// NaN inputs are skipped. If a full window contains no finite values for one side, -/// that side returns `None` for the window. +/// A window only emits values when both input slices contain a full valid window. pub fn rolling_max_min( - highs: &[f64], - lows: &[f64], + highs: &[Option], + lows: &[Option], period: usize, ) -> (Vec>, Vec>) { let len = highs.len(); let mut rolling_max = vec![None; len]; let mut rolling_min = vec![None; len]; - if len < period || period == 0 { + if len != lows.len() || len < period || period == 0 { return (rolling_max, rolling_min); } let mut max_deque = VecDeque::with_capacity(period); let mut min_deque = VecDeque::with_capacity(period); + let mut valid_highs = 0usize; + let mut valid_lows = 0usize; for i in 0..len { - push_max_index(&mut max_deque, highs, i); - push_min_index(&mut min_deque, lows, i); + if highs[i].is_some() { + push_max_index(&mut max_deque, highs, i); + valid_highs += 1; + } + if lows[i].is_some() { + push_min_index(&mut min_deque, lows, i); + valid_lows += 1; + } if i + 1 < period { continue; } + if i >= period { + if highs[i - period].is_some() { + valid_highs -= 1; + } + if lows[i - period].is_some() { + valid_lows -= 1; + } + } + let earliest_idx = i + 1 - period; evict_expired(&mut max_deque, earliest_idx); evict_expired(&mut min_deque, earliest_idx); - rolling_max[i] = max_deque.front().map(|&idx| highs[idx]); - rolling_min[i] = min_deque.front().map(|&idx| lows[idx]); + if valid_highs == period && valid_lows == period { + rolling_max[i] = max_deque.front().and_then(|&idx| highs[idx]); + rolling_min[i] = min_deque.front().and_then(|&idx| lows[idx]); + } } (rolling_max, rolling_min) } -/// Computes rolling argmax/argmin source indices over paired slices using a monotonic queue. +/// Computes rolling argmax/argmin source indices over paired aligned nullable slices. /// -/// Returned indices refer to the original input slices. NaN inputs are skipped, so a -/// window with no finite values on one side returns `None` for that side. +/// Returned indices refer to the original input slices. A window only emits indices +/// when both inputs are fully valid across the full window. pub fn rolling_argmax_argmin( - highs: &[f64], - lows: &[f64], + highs: &[Option], + lows: &[Option], period: usize, ) -> (Vec>, Vec>) { let len = highs.len(); let mut max_indices = vec![None; len]; let mut min_indices = vec![None; len]; - if len < period || period == 0 { + if len != lows.len() || len < period || period == 0 { return (max_indices, min_indices); } let mut max_deque = VecDeque::with_capacity(period); let mut min_deque = VecDeque::with_capacity(period); + let mut valid_highs = 0usize; + let mut valid_lows = 0usize; for i in 0..len { - push_max_index(&mut max_deque, highs, i); - push_min_index(&mut min_deque, lows, i); + if highs[i].is_some() { + push_max_index(&mut max_deque, highs, i); + valid_highs += 1; + } + if lows[i].is_some() { + push_min_index(&mut min_deque, lows, i); + valid_lows += 1; + } if i + 1 < period { continue; } + if i >= period { + if highs[i - period].is_some() { + valid_highs -= 1; + } + if lows[i - period].is_some() { + valid_lows -= 1; + } + } + let earliest_idx = i + 1 - period; evict_expired(&mut max_deque, earliest_idx); evict_expired(&mut min_deque, earliest_idx); - max_indices[i] = max_deque.front().copied(); - min_indices[i] = min_deque.front().copied(); + if valid_highs == period && valid_lows == period { + max_indices[i] = max_deque.front().copied(); + min_indices[i] = min_deque.front().copied(); + } } (max_indices, min_indices) } /// Computes the midpoint of the rolling high/low channel for each full window. -pub fn rolling_midpoint(highs: &[f64], lows: &[f64], period: usize) -> Vec> { +pub fn rolling_midpoint( + highs: &[Option], + lows: &[Option], + period: usize, +) -> Vec> { let (rolling_max, rolling_min) = rolling_max_min(highs, lows, period); rolling_max .into_iter() @@ -276,13 +317,10 @@ pub fn rolling_mean_stddev_strict( (means, stddevs) } -fn push_max_index(deque: &mut VecDeque, data: &[f64], idx: usize) { - if data[idx].is_nan() { - return; - } - +fn push_max_index(deque: &mut VecDeque, data: &[Option], idx: usize) { + let current = data[idx].expect("push_max_index requires a present value"); while let Some(&back) = deque.back() { - if data[back] <= data[idx] { + if data[back].expect("deque indices always refer to present values") <= current { deque.pop_back(); } else { break; @@ -291,13 +329,10 @@ fn push_max_index(deque: &mut VecDeque, data: &[f64], idx: usize) { deque.push_back(idx); } -fn push_min_index(deque: &mut VecDeque, data: &[f64], idx: usize) { - if data[idx].is_nan() { - return; - } - +fn push_min_index(deque: &mut VecDeque, data: &[Option], idx: usize) { + let current = data[idx].expect("push_min_index requires a present value"); while let Some(&back) = deque.back() { - if data[back] >= data[idx] { + if data[back].expect("deque indices always refer to present values") >= current { deque.pop_back(); } else { break; @@ -504,8 +539,8 @@ mod tests { #[test] fn test_rolling_midpoint() { - let highs = vec![10.0, 12.0, 14.0, 16.0, 18.0]; - let lows = vec![4.0, 6.0, 8.0, 10.0, 12.0]; + let highs = vec![Some(10.0), Some(12.0), Some(14.0), Some(16.0), Some(18.0)]; + let lows = vec![Some(4.0), Some(6.0), Some(8.0), Some(10.0), Some(12.0)]; let result = rolling_midpoint(&highs, &lows, 3); @@ -514,8 +549,8 @@ mod tests { #[test] fn test_rolling_max_min() { - let highs = vec![1.0, 3.0, 2.0, 5.0, 4.0]; - let lows = vec![5.0, 2.0, 3.0, 1.0, 4.0]; + let highs = vec![Some(1.0), Some(3.0), Some(2.0), Some(5.0), Some(4.0)]; + let lows = vec![Some(5.0), Some(2.0), Some(3.0), Some(1.0), Some(4.0)]; let (rolling_max, rolling_min) = rolling_max_min(&highs, &lows, 3); @@ -531,8 +566,8 @@ mod tests { #[test] fn test_rolling_argmax_argmin_prefers_latest_duplicate() { - let highs = vec![1.0, 5.0, 5.0, 2.0]; - let lows = vec![4.0, 1.0, 1.0, 3.0]; + let highs = vec![Some(1.0), Some(5.0), Some(5.0), Some(2.0)]; + let lows = vec![Some(4.0), Some(1.0), Some(1.0), Some(3.0)]; let (max_indices, min_indices) = rolling_argmax_argmin(&highs, &lows, 3); @@ -541,31 +576,31 @@ mod tests { } #[test] - fn test_rolling_max_min_ignores_nan_when_finite_values_exist() { - let highs = vec![1.0, f64::NAN, 3.0]; - let lows = vec![5.0, f64::NAN, 2.0]; + fn test_rolling_max_min_invalidates_windows_with_gaps() { + let highs = vec![Some(1.0), None, Some(3.0), Some(4.0)]; + let lows = vec![Some(5.0), None, Some(2.0), Some(1.0)]; let (rolling_max, rolling_min) = rolling_max_min(&highs, &lows, 2); - assert_eq!(rolling_max, vec![None, Some(1.0), Some(3.0)]); - assert_eq!(rolling_min, vec![None, Some(5.0), Some(2.0)]); + assert_eq!(rolling_max, vec![None, None, None, Some(4.0)]); + assert_eq!(rolling_min, vec![None, None, None, Some(1.0)]); } #[test] - fn test_rolling_argmax_argmin_ignore_nan_when_finite_values_exist() { - let highs = vec![1.0, f64::NAN, 5.0]; - let lows = vec![4.0, f64::NAN, 1.0]; + fn test_rolling_argmax_argmin_invalidates_windows_with_gaps() { + let highs = vec![Some(1.0), None, Some(5.0), Some(4.0)]; + let lows = vec![Some(4.0), None, Some(1.0), Some(2.0)]; let (max_indices, min_indices) = rolling_argmax_argmin(&highs, &lows, 2); - assert_eq!(max_indices, vec![None, Some(0), Some(2)]); - assert_eq!(min_indices, vec![None, Some(0), Some(2)]); + assert_eq!(max_indices, vec![None, None, None, Some(2)]); + assert_eq!(min_indices, vec![None, None, None, Some(2)]); } #[test] - fn test_rolling_max_min_all_nan_window_returns_none() { - let highs = vec![f64::NAN, f64::NAN, f64::NAN]; - let lows = vec![f64::NAN, f64::NAN, f64::NAN]; + fn test_rolling_max_min_all_none_window_returns_none() { + let highs = vec![None, None, None]; + let lows = vec![None, None, None]; let (rolling_max, rolling_min) = rolling_max_min(&highs, &lows, 2); let (max_indices, min_indices) = rolling_argmax_argmin(&highs, &lows, 2); From 3fee7b9b23c00e01c361f30203f347361068c204 Mon Sep 17 00:00:00 2001 From: Minki Kim Date: Fri, 24 Apr 2026 14:41:06 +0900 Subject: [PATCH 2/4] Hide remaining dense utility helpers --- core/src/utils.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/utils.rs b/core/src/utils.rs index 1c3ed60..b0440b7 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -16,17 +16,20 @@ pub fn round_vec(vec: Vec>, decimal_places: u32) -> Vec> .collect() } -pub fn calc_mean(data: &[f64]) -> f64 { +#[cfg(test)] +fn calc_mean(data: &[f64]) -> f64 { let sum: f64 = data.iter().sum(); let count = data.len(); sum / count as f64 } -pub fn find_max(data: &[f64]) -> f64 { +#[cfg(test)] +fn find_max(data: &[f64]) -> f64 { data.iter().cloned().fold(f64::NEG_INFINITY, f64::max) } -pub fn find_min(data: &[f64]) -> f64 { +#[cfg(test)] +fn find_min(data: &[f64]) -> f64 { data.iter().cloned().fold(f64::INFINITY, f64::min) } @@ -371,7 +374,8 @@ pub fn calc_clv(high: f64, low: f64, close: f64) -> f64 { } } -pub fn calc_true_ranges(highs: &[f64], lows: &[f64], closes: &[f64]) -> Vec { +#[cfg(test)] +fn calc_true_ranges(highs: &[f64], lows: &[f64], closes: &[f64]) -> Vec { let mut result = Vec::with_capacity(highs.len() - 1); for i in 1..highs.len() { @@ -415,7 +419,8 @@ fn calc_tr(high: f64, low: f64, prev_close: f64) -> f64 { th - tl } -pub fn wilders_smoothing(data: &[f64], period: usize) -> Vec { +#[cfg(test)] +fn wilders_smoothing(data: &[f64], period: usize) -> Vec { let mut result = Vec::with_capacity(data.len() - period + 1); let mut partial_sum: f64 = data.iter().take(period - 1).sum(); From 655b7b6137eba2526e4d5c38084c4306ab041955 Mon Sep 17 00:00:00 2001 From: Minki Kim Date: Fri, 24 Apr 2026 15:00:47 +0900 Subject: [PATCH 3/4] Remove private dense test helpers --- core/src/utils.rs | 66 ++++++++++++----------------------------------- 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/core/src/utils.rs b/core/src/utils.rs index b0440b7..07e94c1 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -16,23 +16,6 @@ pub fn round_vec(vec: Vec>, decimal_places: u32) -> Vec> .collect() } -#[cfg(test)] -fn calc_mean(data: &[f64]) -> f64 { - let sum: f64 = data.iter().sum(); - let count = data.len(); - sum / count as f64 -} - -#[cfg(test)] -fn find_max(data: &[f64]) -> f64 { - data.iter().cloned().fold(f64::NEG_INFINITY, f64::max) -} - -#[cfg(test)] -fn find_min(data: &[f64]) -> f64 { - data.iter().cloned().fold(f64::INFINITY, f64::min) -} - /// Computes rolling max/min values over paired aligned nullable slices. /// /// A window only emits values when both input slices contain a full valid window. @@ -374,20 +357,6 @@ pub fn calc_clv(high: f64, low: f64, close: f64) -> f64 { } } -#[cfg(test)] -fn calc_true_ranges(highs: &[f64], lows: &[f64], closes: &[f64]) -> Vec { - let mut result = Vec::with_capacity(highs.len() - 1); - - for i in 1..highs.len() { - let high = highs[i]; - let low = lows[i]; - let prev_close = closes[i - 1]; - result.push(calc_tr(high, low, prev_close)); - } - - result -} - pub fn calc_true_ranges_aligned( highs: &[Option], lows: &[Option], @@ -419,19 +388,6 @@ fn calc_tr(high: f64, low: f64, prev_close: f64) -> f64 { th - tl } -#[cfg(test)] -fn wilders_smoothing(data: &[f64], period: usize) -> Vec { - let mut result = Vec::with_capacity(data.len() - period + 1); - let mut partial_sum: f64 = data.iter().take(period - 1).sum(); - - for i in period - 1..data.len() { - partial_sum = partial_sum - (partial_sum / period as f64) + data[i]; - result.push(partial_sum); - } - - result -} - pub fn wilders_smoothing_aligned(data: &[Option], period: usize) -> Vec> { let len = data.len(); let mut result = vec![None; len]; @@ -526,19 +482,24 @@ mod tests { #[test] fn test_calc_mean() { - let result = calc_mean(&vec![1.0, 2.0, 3.0, 4.0, 5.0]); + let data = [1.0, 2.0, 3.0, 4.0, 5.0]; + let result = data.iter().sum::() / data.len() as f64; assert_eq!(result, 3.0); } #[test] fn test_find_max() { - let result = find_max(&vec![1.0, 2.0, 3.0, 4.0, 5.0]); + let result = [1.0, 2.0, 3.0, 4.0, 5.0] + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); assert_eq!(result, 5.0); } #[test] fn test_find_min() { - let result = find_min(&vec![1.0, 2.0, 3.0, 4.0, 5.0]); + let result = [1.0, 2.0, 3.0, 4.0, 5.0] + .into_iter() + .fold(f64::INFINITY, f64::min); assert_eq!(result, 1.0); } @@ -695,7 +656,9 @@ mod tests { 2.0, 1.5, 1.5, ]; - let result = calc_true_ranges(&highs, &lows, &closes); + let result = (1..highs.len()) + .map(|i| calc_tr(highs[i], lows[i], closes[i - 1])) + .collect::>(); assert_eq!(result, expected, "Failed for dynamic input"); } @@ -726,7 +689,12 @@ mod tests { // Using extended data for a more robust test case. let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; let period = 3; - let result = wilders_smoothing(&data, period); + let mut result = Vec::with_capacity(data.len() - period + 1); + let mut partial_sum: f64 = data.iter().take(period - 1).sum(); + for i in period - 1..data.len() { + partial_sum = partial_sum - (partial_sum / period as f64) + data[i]; + result.push(partial_sum); + } let expected = vec![ Some(5.0), From 0a5645b63722cd0d7dff265b336c95089f663b4c Mon Sep 17 00:00:00 2001 From: Minki Kim Date: Mon, 27 Apr 2026 09:51:08 +0900 Subject: [PATCH 4/4] Optimize nullable indicator rolling loops --- core/src/indicators/cci.rs | 55 ++++++++++++++------------ core/src/indicators/mfi.rs | 79 +++++++++++++++++++++++--------------- core/src/indicators/psl.rs | 32 ++++++++++----- 3 files changed, 99 insertions(+), 67 deletions(-) diff --git a/core/src/indicators/cci.rs b/core/src/indicators/cci.rs index 63228a8..35e6176 100644 --- a/core/src/indicators/cci.rs +++ b/core/src/indicators/cci.rs @@ -11,33 +11,41 @@ pub fn cci( return result; } - let typical_prices: Vec> = highs - .iter() - .zip(lows.iter()) - .zip(closes.iter()) - .map(|((&high, &low), &close)| match (high, low, close) { - (Some(high), Some(low), Some(close)) => Some((high + low + close) / 3.0), - _ => None, - }) - .collect(); - - for i in period - 1..len { - let mut window = Vec::with_capacity(period); - let mut valid = true; - for value in &typical_prices[i + 1 - period..=i] { - if let Some(value) = value { - window.push(*value); - } else { - valid = false; - break; + let mut typical_prices = Vec::with_capacity(len); + let mut valid_typical_prices = Vec::with_capacity(len); + for i in 0..len { + match (highs[i], lows[i], closes[i]) { + (Some(high), Some(low), Some(close)) => { + typical_prices.push((high + low + close) / 3.0); + valid_typical_prices.push(true); + } + _ => { + typical_prices.push(0.0); + valid_typical_prices.push(false); } } + } + + let mut sum = 0.0; + let mut valid_count = 0usize; + + for i in 0..len { + if valid_typical_prices[i] { + sum += typical_prices[i]; + valid_count += 1; + } + + if i >= period && valid_typical_prices[i - period] { + sum -= typical_prices[i - period]; + valid_count -= 1; + } - if !valid { + if i < period - 1 || valid_count != period { continue; } - let sma_tp = window.iter().sum::() / period as f64; + let sma_tp = sum / period as f64; + let window = &typical_prices[i + 1 - period..=i]; let mean_deviation = window .iter() .map(|&value| (value - sma_tp).abs()) @@ -45,10 +53,7 @@ pub fn cci( / period as f64; if mean_deviation != 0.0 { - let current_tp = *window - .last() - .expect("validated CCI window must include the current typical price"); - result[i] = Some((current_tp - sma_tp) / (0.015 * mean_deviation)); + result[i] = Some((typical_prices[i] - sma_tp) / (0.015 * mean_deviation)); } } diff --git a/core/src/indicators/mfi.rs b/core/src/indicators/mfi.rs index d17f04e..2831efc 100644 --- a/core/src/indicators/mfi.rs +++ b/core/src/indicators/mfi.rs @@ -1,5 +1,3 @@ -use crate::utils::rolling_sum_strict; - pub fn mfi( highs: &[Option], lows: &[Option], @@ -13,49 +11,66 @@ pub fn mfi( if len != lows.len() || len != closes.len() || len != volumes.len() - || len < period + || len <= period || period <= 1 { return mfi; } - let typical_prices: Vec> = highs - .iter() - .zip(lows.iter()) - .zip(closes.iter()) - .map(|((&high, &low), &close)| match (high, low, close) { - (Some(high), Some(low), Some(close)) => Some((high + low + close) / 3.0), - _ => None, - }) - .collect(); - - let mut positive_money_flow = vec![None; highs.len()]; - let mut negative_money_flow = vec![None; highs.len()]; - - for i in 1..highs.len() { - let (Some(prev_tp), Some(curr_tp), Some(volume)) = - (typical_prices[i - 1], typical_prices[i], volumes[i]) - else { - continue; - }; + let mut typical_prices = Vec::with_capacity(len); + let mut valid_typical_prices = Vec::with_capacity(len); + for i in 0..len { + match (highs[i], lows[i], closes[i]) { + (Some(high), Some(low), Some(close)) => { + typical_prices.push((high + low + close) / 3.0); + valid_typical_prices.push(true); + } + _ => { + typical_prices.push(0.0); + valid_typical_prices.push(false); + } + } + } + + let money_flow_at = |i: usize| -> Option<(f64, f64)> { + if i == 0 || !valid_typical_prices[i - 1] || !valid_typical_prices[i] { + return None; + } + + let prev_tp = typical_prices[i - 1]; + let curr_tp = typical_prices[i]; + let volume = volumes[i]?; let raw_money_flow = curr_tp * volume; if curr_tp >= prev_tp { - positive_money_flow[i] = Some(raw_money_flow); - negative_money_flow[i] = Some(0.0); + Some((raw_money_flow, 0.0)) } else { - positive_money_flow[i] = Some(0.0); - negative_money_flow[i] = Some(raw_money_flow); + Some((0.0, raw_money_flow)) + } + }; + + let mut positive_sum = 0.0; + let mut negative_sum = 0.0; + let mut valid_count = 0usize; + + for i in 1..len { + if let Some((positive, negative)) = money_flow_at(i) { + positive_sum += positive; + negative_sum += negative; + valid_count += 1; } - } - let positive_sums = rolling_sum_strict(&positive_money_flow, period); - let negative_sums = rolling_sum_strict(&negative_money_flow, period); + if i > period { + if let Some((positive, negative)) = money_flow_at(i - period) { + positive_sum -= positive; + negative_sum -= negative; + valid_count -= 1; + } + } - for i in 0..highs.len() { - let (Some(positive_sum), Some(negative_sum)) = (positive_sums[i], negative_sums[i]) else { + if i < period || valid_count != period { continue; - }; + } mfi[i] = Some(if negative_sum == 0.0 { 100.0 diff --git a/core/src/indicators/psl.rs b/core/src/indicators/psl.rs index 70c27db..53befbb 100644 --- a/core/src/indicators/psl.rs +++ b/core/src/indicators/psl.rs @@ -1,24 +1,36 @@ -use crate::utils::rolling_sum_strict; - pub fn psl(closes: &[Option], period: usize) -> Vec> { let len = closes.len(); let mut result = vec![None; len]; - if len < period + 1 || period <= 1 { + if len <= period || period <= 1 { return result; } - let mut positive_changes = vec![None; len]; + let mut positive_count = 0usize; + let mut valid_count = 0usize; + for i in 1..len { if let (Some(current), Some(previous)) = (closes[i], closes[i - 1]) { - positive_changes[i] = Some(if current > previous { 1.0 } else { 0.0 }); + valid_count += 1; + if current > previous { + positive_count += 1; + } + } + + if i > period { + let remove_index = i - period; + if let (Some(current), Some(previous)) = + (closes[remove_index], closes[remove_index - 1]) + { + valid_count -= 1; + if current > previous { + positive_count -= 1; + } + } } - } - let counts = rolling_sum_strict(&positive_changes, period); - for i in 0..len { - if let Some(count) = counts[i] { - result[i] = Some((count / period as f64) * 100.0); + if i >= period && valid_count == period { + result[i] = Some((positive_count as f64 / period as f64) * 100.0); } }