Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions core/src/indicators/cci.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
use crate::indicators::ema::ema;

pub fn cci(
highs: &[Option<f64>],
lows: &[Option<f64>],
closes: &[Option<f64>],
period: usize,
signal_period: usize,
) -> (Vec<Option<f64>>, Vec<Option<f64>>) {
let line = cci_line(highs, lows, closes, period);
let signal = ema(&line, signal_period);

(line, signal)
}

pub fn cci_line(
highs: &[Option<f64>],
lows: &[Option<f64>],
closes: &[Option<f64>],
period: usize,
) -> Vec<Option<f64>> {
let len = highs.len();
let mut result = vec![None; len];
Expand Down Expand Up @@ -60,6 +75,17 @@ pub fn cci(
result
}

pub fn cci_signal(
highs: &[Option<f64>],
lows: &[Option<f64>],
closes: &[Option<f64>],
period: usize,
signal_period: usize,
) -> Vec<Option<f64>> {
let line = cci_line(highs, lows, closes, period);
ema(&line, signal_period)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -73,7 +99,7 @@ mod tests {
let high = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "h");
let low = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "l");
let close = testutils::load_data_nullable(&format!("../data/{}.json", symbol), "c");
let result = cci(&high, &low, &close, 20);
let (result, signal) = cci(&high, &low, &close, 20, 9);
let expected = testutils::load_expected::<Option<f64>>(&format!(
"../data/expected/cci_{}.json",
symbol
Expand All @@ -85,6 +111,7 @@ mod tests {
"CCI test failed for symbol {}.",
symbol
);
assert_eq!(signal, cci_signal(&high, &low, &close, 20, 9));
}
}

Expand All @@ -94,7 +121,7 @@ mod tests {
let lows = vec![Some(2.0), Some(2.0), None, Some(6.0), Some(8.0)];
let closes = vec![Some(3.0), Some(4.0), None, Some(9.0), Some(11.0)];

let result = round_vec(cci(&highs, &lows, &closes, 2), 8);
let result = round_vec(cci_line(&highs, &lows, &closes, 2), 8);

assert_eq!(
result,
Expand All @@ -110,4 +137,33 @@ mod tests {
)
);
}

#[test]
fn test_cci_signal_follows_base_ema_contract() {
let highs = vec![
Some(4.0),
Some(6.0),
None,
Some(10.0),
Some(12.0),
Some(14.0),
];
let lows = vec![Some(2.0), Some(2.0), None, Some(6.0), Some(8.0), Some(10.0)];
let closes = vec![
Some(3.0),
Some(4.0),
None,
Some(9.0),
Some(11.0),
Some(12.0),
];

let line = cci_line(&highs, &lows, &closes, 2);
let signal = cci_signal(&highs, &lows, &closes, 2, 2);
let (composite_line, composite_signal) = cci(&highs, &lows, &closes, 2, 2);

assert_eq!(signal, ema(&line, 2));
assert_eq!(composite_line, line);
assert_eq!(composite_signal, signal);
}
}
49 changes: 45 additions & 4 deletions core/src/indicators/rsi.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
pub fn rsi(data: &[Option<f64>], period: usize) -> Vec<Option<f64>> {
use crate::indicators::ema::ema;

pub fn rsi(
data: &[Option<f64>],
period: usize,
signal_period: usize,
) -> (Vec<Option<f64>>, Vec<Option<f64>>) {
let line = rsi_line(data, period);
let signal = ema(&line, signal_period);

(line, signal)
}

pub fn rsi_line(data: &[Option<f64>], period: usize) -> Vec<Option<f64>> {
let mut rsi = vec![None; data.len()];

if data.len() < period || period <= 1 {
Expand Down Expand Up @@ -59,6 +72,11 @@ pub fn rsi(data: &[Option<f64>], period: usize) -> Vec<Option<f64>> {
rsi
}

pub fn rsi_signal(data: &[Option<f64>], period: usize, signal_period: usize) -> Vec<Option<f64>> {
let line = rsi_line(data, period);
ema(&line, signal_period)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -73,7 +91,7 @@ mod tests {
.into_iter()
.map(Some)
.collect::<Vec<_>>();
let result = rsi(&input, 14);
let (result, signal) = rsi(&input, 14, 9);
let expected = testutils::load_expected::<Option<f64>>(&format!(
"../data/expected/rsi_{}.json",
symbol
Expand All @@ -85,6 +103,7 @@ mod tests {
"RSI test failed for symbol {}.",
symbol
);
assert_eq!(signal, rsi_signal(&input, 14, 9));
}
}

Expand All @@ -100,7 +119,7 @@ mod tests {
Some(5.0),
];

let result = rsi(&aligned, 3);
let result = rsi_line(&aligned, 3);

assert_eq!(
result,
Expand Down Expand Up @@ -128,11 +147,33 @@ mod tests {
Some(4.0),
];

let result = rsi(&aligned, 3);
let result = rsi_line(&aligned, 3);

assert_eq!(
result,
vec![None, None, None, None, None, None, Some(66.66666666666666)]
);
}

#[test]
fn test_rsi_signal_follows_base_ema_contract() {
let aligned = vec![
Some(1.0),
Some(2.0),
Some(3.0),
Some(2.0),
None,
Some(4.0),
Some(5.0),
Some(6.0),
];

let line = rsi_line(&aligned, 3);
let signal = rsi_signal(&aligned, 3, 2);
let (composite_line, composite_signal) = rsi(&aligned, 3, 2);

assert_eq!(signal, ema(&line, 2));
assert_eq!(composite_line, line);
assert_eq!(composite_signal, signal);
}
}
4 changes: 2 additions & 2 deletions core/src/indicators/stochrsi.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::indicators::rsi::rsi;
use crate::indicators::rsi::rsi_line;
use crate::utils::{rolling_max_min, rolling_mean_strict};

pub fn stochrsi(
Expand All @@ -14,7 +14,7 @@ pub fn stochrsi(
return (percent_k, vec![None; len]);
}

let rsi_values = rsi(closes, period_rsi);
let rsi_values = rsi_line(closes, period_rsi);
let (rolling_max, rolling_min) = rolling_max_min(&rsi_values, &rsi_values, period_k);

for i in (period_rsi + period_k - 1)..len {
Expand Down
2 changes: 1 addition & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_techr"
version = "0.1.4"
version = "0.1.5"
edition = "2021"
description = "Polars expression plugins for techr indicators"
license = "MIT"
Expand Down
4 changes: 2 additions & 2 deletions polars/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ uv add techr

## Supported indicators

- Price/trend: `sma`, `wma`, `ema`, `disparity`, `mom`, `roc`, `rsi`, `psl`
- Price/trend: `sma`, `wma`, `ema`, `disparity`, `mom`, `roc`, `rsi`, `rsi_line`, `rsi_signal`, `psl`
- Bands/channels: `bband_middle`, `bband_lower`, `bband_upper`, `env_upper`, `env_middle`, `env_lower`, `pchan_upper`, `pchan_middle`, `pchan_lower`
- Momentum/oscillators: `macd`, `macd_line`, `macd_signal`, `macd_hist`, `macd_histogram`, `ppo_line`, `ppo_signal`, `ppo_histogram`, `pvo_line`, `pvo_signal`, `pvo_histogram`, `sonar_line`, `sonar_signal`, `trix_line`, `trix_signal`, `stochf_percent_k`, `stochf_percent_d`, `stoch_percent_k`, `stoch_percent_d`, `stochrsi_percent_k`, `stochrsi_percent_d`
- High/low/volume indicators: `ad`, `adx`, `adxr`, `aroon_up`, `aroon_down`, `aroonosc`, `atr`, `cci`, `cmf`, `co`, `cv`, `dmi_plus`, `dmi_minus`, `efi`, `eom_line`, `eom_signal`, `erbear`, `erbull`, `massi_line`, `massi_signal`, `mfi`, `nvi_line`, `nvi_signal`, `obv_line`, `obv_signal`, `psar`, `pvi_line`, `pvi_signal`, `ultosc`, `vr`, `willr`
- High/low/volume indicators: `ad`, `adx`, `adxr`, `aroon_up`, `aroon_down`, `aroonosc`, `atr`, `cci`, `cci_line`, `cci_signal`, `cmf`, `co`, `cv`, `dmi_plus`, `dmi_minus`, `efi`, `eom_line`, `eom_signal`, `erbear`, `erbull`, `massi_line`, `massi_signal`, `mfi`, `nvi_line`, `nvi_signal`, `obv_line`, `obv_signal`, `psar`, `pvi_line`, `pvi_signal`, `ultosc`, `vr`, `willr`
- Ichimoku: `ichimoku_base_line`, `ichimoku_conversion_line`, `ichimoku_leading_span_a`, `ichimoku_leading_span_b`, `ichimoku_lagging_span`

## Usage
Expand Down
62 changes: 54 additions & 8 deletions polars/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ struct PeriodKwargs {
period: u32,
}

#[derive(Deserialize)]
struct PeriodSignalKwargs {
period: u32,
signal_period: u32,
}

#[derive(Deserialize)]
struct BBandKwargs {
period: u32,
Expand Down Expand Up @@ -146,12 +152,6 @@ struct SonarSignalKwargs {
signal_period: u32,
}

#[derive(Deserialize)]
struct PeriodSignalKwargs {
period: u32,
signal_period: u32,
}

#[derive(Deserialize)]
struct StochRsiKwargs {
period_rsi: u32,
Expand Down Expand Up @@ -495,11 +495,38 @@ fn cci(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
let highs = series_to_f64_vec(&inputs[0])?;
let lows = series_to_f64_vec(&inputs[1])?;
let closes = series_to_f64_vec(&inputs[2])?;
Ok(option_vec_to_series(core::cci(
Ok(option_vec_to_series(core::cci_line(
&highs,
&lows,
&closes,
kwargs.period as usize,
)))
}

#[polars_expr(output_type=Float64)]
fn cci_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
let highs = series_to_f64_vec(&inputs[0])?;
let lows = series_to_f64_vec(&inputs[1])?;
let closes = series_to_f64_vec(&inputs[2])?;
Ok(option_vec_to_series(core::cci_line(
&highs,
&lows,
&closes,
kwargs.period as usize,
)))
}

#[polars_expr(output_type=Float64)]
fn cci_signal(inputs: &[Series], kwargs: PeriodSignalKwargs) -> PolarsResult<Series> {
let highs = series_to_f64_vec(&inputs[0])?;
let lows = series_to_f64_vec(&inputs[1])?;
let closes = series_to_f64_vec(&inputs[2])?;
Ok(option_vec_to_series(core::cci_signal(
&highs,
&lows,
&closes,
kwargs.period as usize,
kwargs.signal_period as usize,
)))
}

Expand Down Expand Up @@ -873,12 +900,31 @@ fn roc(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
#[polars_expr(output_type=Float64)]
fn rsi(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
let input = series_to_f64_vec(&inputs[0])?;
Ok(option_vec_to_series(core::rsi(
Ok(option_vec_to_series(core::rsi_line(
&input,
kwargs.period as usize,
)))
}

#[polars_expr(output_type=Float64)]
fn rsi_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
let input = series_to_f64_vec(&inputs[0])?;
Ok(option_vec_to_series(core::rsi_line(
&input,
kwargs.period as usize,
)))
}

#[polars_expr(output_type=Float64)]
fn rsi_signal(inputs: &[Series], kwargs: PeriodSignalKwargs) -> PolarsResult<Series> {
let input = series_to_f64_vec(&inputs[0])?;
Ok(option_vec_to_series(core::rsi_signal(
&input,
kwargs.period as usize,
kwargs.signal_period as usize,
)))
}

#[polars_expr(output_type=Float64)]
fn trix_line(inputs: &[Series], kwargs: PeriodKwargs) -> PolarsResult<Series> {
let input = series_to_f64_vec(&inputs[0])?;
Expand Down
35 changes: 35 additions & 0 deletions polars/techr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"bband_middle",
"bband_upper",
"cci",
"cci_line",
"cci_signal",
"cmf",
"co",
"cv",
Expand Down Expand Up @@ -68,6 +70,8 @@
"pvo_signal",
"roc",
"rsi",
"rsi_line",
"rsi_signal",
"sma",
"sonar_line",
"sonar_signal",
Expand Down Expand Up @@ -379,6 +383,25 @@ def cci(high: IntoExpr, low: IntoExpr, close: IntoExpr, *, period: int) -> pl.Ex
return _register("cci", [high, low, close], {"period": period})


def cci_line(high: IntoExpr, low: IntoExpr, close: IntoExpr, *, period: int) -> pl.Expr:
return _register("cci_line", [high, low, close], {"period": period})


def cci_signal(
high: IntoExpr,
low: IntoExpr,
close: IntoExpr,
*,
period: int,
signal_period: int,
) -> pl.Expr:
return _register(
"cci_signal",
[high, low, close],
{"period": period, "signal_period": signal_period},
)


def cmf(
high: IntoExpr,
low: IntoExpr,
Expand Down Expand Up @@ -706,6 +729,18 @@ def rsi(expr: IntoExpr, *, period: int) -> pl.Expr:
return _register("rsi", [expr], {"period": period})


def rsi_line(expr: IntoExpr, *, period: int) -> pl.Expr:
return _register("rsi_line", [expr], {"period": period})


def rsi_signal(expr: IntoExpr, *, period: int, signal_period: int) -> pl.Expr:
return _register(
"rsi_signal",
[expr],
{"period": period, "signal_period": signal_period},
)


def trix_line(expr: IntoExpr, *, period: int) -> pl.Expr:
return _register("trix_line", [expr], {"period": period})

Expand Down
Loading