diff --git a/core/src/indicators/cci.rs b/core/src/indicators/cci.rs index 35e6176..71cfca5 100644 --- a/core/src/indicators/cci.rs +++ b/core/src/indicators/cci.rs @@ -1,8 +1,23 @@ +use crate::indicators::ema::ema; + pub fn cci( highs: &[Option], lows: &[Option], closes: &[Option], period: usize, + signal_period: usize, +) -> (Vec>, Vec>) { + let line = cci_line(highs, lows, closes, period); + let signal = ema(&line, signal_period); + + (line, signal) +} + +pub fn cci_line( + highs: &[Option], + lows: &[Option], + closes: &[Option], + period: usize, ) -> Vec> { let len = highs.len(); let mut result = vec![None; len]; @@ -60,6 +75,17 @@ pub fn cci( result } +pub fn cci_signal( + highs: &[Option], + lows: &[Option], + closes: &[Option], + period: usize, + signal_period: usize, +) -> Vec> { + let line = cci_line(highs, lows, closes, period); + ema(&line, signal_period) +} + #[cfg(test)] mod tests { use super::*; @@ -73,7 +99,7 @@ mod tests { let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h"); let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l"); let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c"); - let result = cci(&high, &low, &close, 20); + let (result, signal) = cci(&high, &low, &close, 20, 9); let expected = testutils::load_expected::>(&format!( "../data/expected/cci_{}.json", symbol @@ -85,6 +111,7 @@ mod tests { "CCI test failed for symbol {}.", symbol ); + assert_eq!(signal, cci_signal(&high, &low, &close, 20, 9)); } } @@ -94,7 +121,7 @@ mod tests { let lows = vec![Some(2.0), Some(2.0), None, Some(6.0), Some(8.0)]; let closes = vec![Some(3.0), Some(4.0), None, Some(9.0), Some(11.0)]; - let result = round_vec(cci(&highs, &lows, &closes, 2), 8); + let result = round_vec(cci_line(&highs, &lows, &closes, 2), 8); assert_eq!( result, @@ -110,4 +137,33 @@ mod tests { ) ); } + + #[test] + fn test_cci_signal_follows_base_ema_contract() { + let highs = vec![ + Some(4.0), + Some(6.0), + None, + Some(10.0), + Some(12.0), + Some(14.0), + ]; + let lows = vec![Some(2.0), Some(2.0), None, Some(6.0), Some(8.0), Some(10.0)]; + let closes = vec![ + Some(3.0), + Some(4.0), + None, + Some(9.0), + Some(11.0), + Some(12.0), + ]; + + let line = cci_line(&highs, &lows, &closes, 2); + let signal = cci_signal(&highs, &lows, &closes, 2, 2); + let (composite_line, composite_signal) = cci(&highs, &lows, &closes, 2, 2); + + assert_eq!(signal, ema(&line, 2)); + assert_eq!(composite_line, line); + assert_eq!(composite_signal, signal); + } } diff --git a/core/src/indicators/rsi.rs b/core/src/indicators/rsi.rs index d1ae07b..3956f91 100644 --- a/core/src/indicators/rsi.rs +++ b/core/src/indicators/rsi.rs @@ -1,4 +1,17 @@ -pub fn rsi(data: &[Option], period: usize) -> Vec> { +use crate::indicators::ema::ema; + +pub fn rsi( + data: &[Option], + period: usize, + signal_period: usize, +) -> (Vec>, Vec>) { + let line = rsi_line(data, period); + let signal = ema(&line, signal_period); + + (line, signal) +} + +pub fn rsi_line(data: &[Option], period: usize) -> Vec> { let mut rsi = vec![None; data.len()]; if data.len() < period || period <= 1 { @@ -59,6 +72,11 @@ pub fn rsi(data: &[Option], period: usize) -> Vec> { rsi } +pub fn rsi_signal(data: &[Option], period: usize, signal_period: usize) -> Vec> { + let line = rsi_line(data, period); + ema(&line, signal_period) +} + #[cfg(test)] mod tests { use super::*; @@ -73,7 +91,7 @@ mod tests { .into_iter() .map(Some) .collect::>(); - let result = rsi(&input, 14); + let (result, signal) = rsi(&input, 14, 9); let expected = testutils::load_expected::>(&format!( "../data/expected/rsi_{}.json", symbol @@ -85,6 +103,7 @@ mod tests { "RSI test failed for symbol {}.", symbol ); + assert_eq!(signal, rsi_signal(&input, 14, 9)); } } @@ -100,7 +119,7 @@ mod tests { Some(5.0), ]; - let result = rsi(&aligned, 3); + let result = rsi_line(&aligned, 3); assert_eq!( result, @@ -128,11 +147,33 @@ mod tests { Some(4.0), ]; - let result = rsi(&aligned, 3); + let result = rsi_line(&aligned, 3); assert_eq!( result, vec![None, None, None, None, None, None, Some(66.66666666666666)] ); } + + #[test] + fn test_rsi_signal_follows_base_ema_contract() { + let aligned = vec![ + Some(1.0), + Some(2.0), + Some(3.0), + Some(2.0), + None, + Some(4.0), + Some(5.0), + Some(6.0), + ]; + + let line = rsi_line(&aligned, 3); + let signal = rsi_signal(&aligned, 3, 2); + let (composite_line, composite_signal) = rsi(&aligned, 3, 2); + + assert_eq!(signal, ema(&line, 2)); + assert_eq!(composite_line, line); + assert_eq!(composite_signal, signal); + } } diff --git a/core/src/indicators/stochrsi.rs b/core/src/indicators/stochrsi.rs index bb03187..b8ab687 100644 --- a/core/src/indicators/stochrsi.rs +++ b/core/src/indicators/stochrsi.rs @@ -1,4 +1,4 @@ -use crate::indicators::rsi::rsi; +use crate::indicators::rsi::rsi_line; use crate::utils::{rolling_max_min, rolling_mean_strict}; pub fn stochrsi( @@ -14,7 +14,7 @@ pub fn stochrsi( return (percent_k, vec![None; len]); } - let rsi_values = rsi(closes, period_rsi); + let rsi_values = rsi_line(closes, period_rsi); let (rolling_max, rolling_min) = rolling_max_min(&rsi_values, &rsi_values, period_k); for i in (period_rsi + period_k - 1)..len { diff --git a/polars/Cargo.toml b/polars/Cargo.toml index 7727875..7c49789 100644 --- a/polars/Cargo.toml +++ b/polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "polars_techr" -version = "0.1.4" +version = "0.1.5" edition = "2021" description = "Polars expression plugins for techr indicators" license = "MIT" diff --git a/polars/README.md b/polars/README.md index 201d35a..56256dd 100644 --- a/polars/README.md +++ b/polars/README.md @@ -10,10 +10,10 @@ uv add techr ## Supported indicators -- Price/trend: `sma`, `wma`, `ema`, `disparity`, `mom`, `roc`, `rsi`, `psl` +- Price/trend: `sma`, `wma`, `ema`, `disparity`, `mom`, `roc`, `rsi`, `rsi_line`, `rsi_signal`, `psl` - Bands/channels: `bband_middle`, `bband_lower`, `bband_upper`, `env_upper`, `env_middle`, `env_lower`, `pchan_upper`, `pchan_middle`, `pchan_lower` - Momentum/oscillators: `macd`, `macd_line`, `macd_signal`, `macd_hist`, `macd_histogram`, `ppo_line`, `ppo_signal`, `ppo_histogram`, `pvo_line`, `pvo_signal`, `pvo_histogram`, `sonar_line`, `sonar_signal`, `trix_line`, `trix_signal`, `stochf_percent_k`, `stochf_percent_d`, `stoch_percent_k`, `stoch_percent_d`, `stochrsi_percent_k`, `stochrsi_percent_d` -- High/low/volume indicators: `ad`, `adx`, `adxr`, `aroon_up`, `aroon_down`, `aroonosc`, `atr`, `cci`, `cmf`, `co`, `cv`, `dmi_plus`, `dmi_minus`, `efi`, `eom_line`, `eom_signal`, `erbear`, `erbull`, `massi_line`, `massi_signal`, `mfi`, `nvi_line`, `nvi_signal`, `obv_line`, `obv_signal`, `psar`, `pvi_line`, `pvi_signal`, `ultosc`, `vr`, `willr` +- High/low/volume indicators: `ad`, `adx`, `adxr`, `aroon_up`, `aroon_down`, `aroonosc`, `atr`, `cci`, `cci_line`, `cci_signal`, `cmf`, `co`, `cv`, `dmi_plus`, `dmi_minus`, `efi`, `eom_line`, `eom_signal`, `erbear`, `erbull`, `massi_line`, `massi_signal`, `mfi`, `nvi_line`, `nvi_signal`, `obv_line`, `obv_signal`, `psar`, `pvi_line`, `pvi_signal`, `ultosc`, `vr`, `willr` - Ichimoku: `ichimoku_base_line`, `ichimoku_conversion_line`, `ichimoku_leading_span_a`, `ichimoku_leading_span_b`, `ichimoku_lagging_span` ## Usage diff --git a/polars/src/expressions.rs b/polars/src/expressions.rs index e20b465..d51d7e9 100644 --- a/polars/src/expressions.rs +++ b/polars/src/expressions.rs @@ -21,6 +21,12 @@ struct PeriodKwargs { period: u32, } +#[derive(Deserialize)] +struct PeriodSignalKwargs { + period: u32, + signal_period: u32, +} + #[derive(Deserialize)] struct BBandKwargs { period: u32, @@ -146,12 +152,6 @@ struct SonarSignalKwargs { signal_period: u32, } -#[derive(Deserialize)] -struct PeriodSignalKwargs { - period: u32, - signal_period: u32, -} - #[derive(Deserialize)] struct StochRsiKwargs { period_rsi: u32, @@ -495,11 +495,38 @@ fn cci(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { let highs = series_to_f64_vec(&inputs[0])?; let lows = series_to_f64_vec(&inputs[1])?; let closes = series_to_f64_vec(&inputs[2])?; - Ok(option_vec_to_series(core::cci( + Ok(option_vec_to_series(core::cci_line( + &highs, + &lows, + &closes, + kwargs.period as usize, + ))) +} + +#[polars_expr(output_type=Float64)] +fn cci_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { + let highs = series_to_f64_vec(&inputs[0])?; + let lows = series_to_f64_vec(&inputs[1])?; + let closes = series_to_f64_vec(&inputs[2])?; + Ok(option_vec_to_series(core::cci_line( + &highs, + &lows, + &closes, + kwargs.period as usize, + ))) +} + +#[polars_expr(output_type=Float64)] +fn cci_signal(inputs: &[Series], kwargs: PeriodSignalKwargs) -> PolarsResult { + let highs = series_to_f64_vec(&inputs[0])?; + let lows = series_to_f64_vec(&inputs[1])?; + let closes = series_to_f64_vec(&inputs[2])?; + Ok(option_vec_to_series(core::cci_signal( &highs, &lows, &closes, kwargs.period as usize, + kwargs.signal_period as usize, ))) } @@ -873,12 +900,31 @@ fn roc(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { #[polars_expr(output_type=Float64)] fn rsi(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { let input = series_to_f64_vec(&inputs[0])?; - Ok(option_vec_to_series(core::rsi( + Ok(option_vec_to_series(core::rsi_line( &input, kwargs.period as usize, ))) } +#[polars_expr(output_type=Float64)] +fn rsi_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { + let input = series_to_f64_vec(&inputs[0])?; + Ok(option_vec_to_series(core::rsi_line( + &input, + kwargs.period as usize, + ))) +} + +#[polars_expr(output_type=Float64)] +fn rsi_signal(inputs: &[Series], kwargs: PeriodSignalKwargs) -> PolarsResult { + let input = series_to_f64_vec(&inputs[0])?; + Ok(option_vec_to_series(core::rsi_signal( + &input, + kwargs.period as usize, + kwargs.signal_period as usize, + ))) +} + #[polars_expr(output_type=Float64)] fn trix_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult { let input = series_to_f64_vec(&inputs[0])?; diff --git a/polars/techr/__init__.py b/polars/techr/__init__.py index e3ffb1d..2f5dbef 100644 --- a/polars/techr/__init__.py +++ b/polars/techr/__init__.py @@ -20,6 +20,8 @@ "bband_middle", "bband_upper", "cci", + "cci_line", + "cci_signal", "cmf", "co", "cv", @@ -68,6 +70,8 @@ "pvo_signal", "roc", "rsi", + "rsi_line", + "rsi_signal", "sma", "sonar_line", "sonar_signal", @@ -379,6 +383,25 @@ def cci(high: IntoExpr, low: IntoExpr, close: IntoExpr, *, period: int) -> pl.Ex return _register("cci", [high, low, close], {"period": period}) +def cci_line(high: IntoExpr, low: IntoExpr, close: IntoExpr, *, period: int) -> pl.Expr: + return _register("cci_line", [high, low, close], {"period": period}) + + +def cci_signal( + high: IntoExpr, + low: IntoExpr, + close: IntoExpr, + *, + period: int, + signal_period: int, +) -> pl.Expr: + return _register( + "cci_signal", + [high, low, close], + {"period": period, "signal_period": signal_period}, + ) + + def cmf( high: IntoExpr, low: IntoExpr, @@ -706,6 +729,18 @@ def rsi(expr: IntoExpr, *, period: int) -> pl.Expr: return _register("rsi", [expr], {"period": period}) +def rsi_line(expr: IntoExpr, *, period: int) -> pl.Expr: + return _register("rsi_line", [expr], {"period": period}) + + +def rsi_signal(expr: IntoExpr, *, period: int, signal_period: int) -> pl.Expr: + return _register( + "rsi_signal", + [expr], + {"period": period, "signal_period": signal_period}, + ) + + def trix_line(expr: IntoExpr, *, period: int) -> pl.Expr: return _register("trix_line", [expr], {"period": period}) diff --git a/polars/tests/test_indicators.py b/polars/tests/test_indicators.py index d07d253..1f26046 100644 --- a/polars/tests/test_indicators.py +++ b/polars/tests/test_indicators.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Callable +from typing import Callable, cast import polars as pl import pytest @@ -51,11 +51,10 @@ def assert_values_close( def select_expr(df: pl.DataFrame, expr: pl.Expr, alias: str, lazy: bool) -> pl.Series: - query = df.lazy() if lazy else df - result = query.select(expr.alias(alias)) if lazy: - result = result.collect() - return result.get_column(alias) + result = cast(pl.DataFrame, df.lazy().select(expr.alias(alias)).collect()) + return result.get_column(alias) + return df.select(expr.alias(alias)).get_column(alias) SeriesExprBuilder = Callable[[], pl.Expr] @@ -125,6 +124,11 @@ def select_expr(df: pl.DataFrame, expr: pl.Expr, alias: str, lazy: bool) -> pl.S lambda: ta.cci(pl.col("high"), pl.col("low"), pl.col("close"), period=20), "cci", ), + ( + "cci_line", + lambda: ta.cci_line(pl.col("high"), pl.col("low"), pl.col("close"), period=20), + "cci", + ), ( "cmf", lambda: ta.cmf( @@ -392,6 +396,7 @@ def select_expr(df: pl.DataFrame, expr: pl.Expr, alias: str, lazy: bool) -> pl.S ), ("roc", lambda: ta.roc(pl.col("close"), period=20), "roc"), ("rsi", lambda: ta.rsi(pl.col("close"), period=14), "rsi"), + ("rsi_line", lambda: ta.rsi_line(pl.col("close"), period=14), "rsi"), ("bband_middle", lambda: ta.bband_middle(pl.col("close"), period=20), "sma"), ( "bband_lower", @@ -657,6 +662,63 @@ def test_single_input_null_values_follow_core_gap_semantics(lazy: bool) -> None: assert_values_close(result.to_list(), [None, None, None, 3.5]) +@pytest.mark.parametrize("lazy", [False, True]) +def test_rsi_signal_matches_ema_of_rsi(lazy: bool) -> None: + """RSI signal is the EMA of the RSI line.""" + # given + df = load_ohlcv("TSLA") + + # when + signal = select_expr( + df, + ta.rsi_signal(pl.col("close"), period=14, signal_period=9), + "signal", + lazy, + ) + expected = select_expr( + df, + ta.ema(ta.rsi(pl.col("close"), period=14), period=9), + "expected", + lazy, + ) + + # then + assert_values_close(signal.to_list(), expected.to_list()) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_cci_signal_matches_ema_of_cci(lazy: bool) -> None: + """CCI signal is the EMA of the CCI line.""" + # given + df = load_ohlcv("TSLA") + + # when + signal = select_expr( + df, + ta.cci_signal( + pl.col("high"), + pl.col("low"), + pl.col("close"), + period=20, + signal_period=9, + ), + "signal", + lazy, + ) + expected = select_expr( + df, + ta.ema( + ta.cci(pl.col("high"), pl.col("low"), pl.col("close"), period=20), + period=9, + ), + "expected", + lazy, + ) + + # then + assert_values_close(signal.to_list(), expected.to_list()) + + @pytest.mark.parametrize("lazy", [False, True]) def test_multi_input_null_values_follow_core_gap_semantics(lazy: bool) -> None: """Accept null values for multi-input indicators."""