Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions leanframe/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import ibis
import ibis.expr.types as ibis_types
import pandas as pd
from leanframe.core.dtypes import convert_pandas_to_ibis


def col(name: str) -> Expression:
Expand Down Expand Up @@ -60,3 +62,107 @@ def __ne__(self, other) -> Expression: # type: ignore[override]

def __eq__(self, other) -> Expression:  # type: ignore[override]
    """Build an Expression comparing the underlying data with ``other`` for equality.

    If ``other`` has a ``_data`` attribute, that attribute is used as the
    right-hand side; otherwise ``other`` itself is compared as-is.
    """
    rhs = getattr(other, "_data", other)
    return Expression(self._data == rhs)

def lt(self, other) -> Expression:
    """Method form of ``self < other``; yields a boolean Expression per element."""
    is_less = self < other
    return is_less

def gt(self, other) -> Expression:
    """Method form of ``self > other``; yields a boolean Expression per element."""
    is_greater = self > other
    return is_greater

def le(self, other) -> Expression:
    """Method form of ``self <= other``; yields a boolean Expression per element."""
    is_less_or_equal = self <= other
    return is_less_or_equal

def ge(self, other) -> Expression:
    """Method form of ``self >= other``; yields a boolean Expression per element."""
    is_greater_or_equal = self >= other
    return is_greater_or_equal

def ne(self, other) -> Expression:
    """Method form of ``self != other``; yields a boolean Expression per element."""
    differs = self != other
    return differs

def eq(self, other) -> Expression:
    """Method form of ``self == other``; yields a boolean Expression per element."""
    matches = self == other
    return matches

def __round__(self, n=0) -> Expression:
    """Support the built-in ``round()`` on an Expression.

    ``n`` is forwarded unchanged to the underlying data's ``round`` and
    defaults to 0 when ``round(expr)`` is called without a digit count.
    """
    rounded = self._data.round(n)
    return Expression(rounded)

def abs(self) -> Expression:
    """Take the absolute value of every element and wrap it in a new Expression."""
    magnitude = self._data.abs()
    return Expression(magnitude)

def all(self) -> Expression:
    """Return an Expression telling whether every element is True."""
    every_true = self._data.all()
    return Expression(every_true)

def any(self) -> Expression:
    """Return an Expression telling whether at least one element is True."""
    some_true = self._data.any()
    return Expression(some_true)

def sum(self) -> Expression:
    """Return an Expression holding the sum of the underlying data."""
    total = self._data.sum()
    return Expression(total)

def mean(self) -> Expression:
    """Return an Expression holding the mean of the underlying data."""
    average = self._data.mean()
    return Expression(average)

def min(self) -> Expression:
    """Return an Expression holding the minimum of the underlying data."""
    smallest = self._data.min()
    return Expression(smallest)

def max(self) -> Expression:
    """Return an Expression holding the maximum of the underlying data."""
    largest = self._data.max()
    return Expression(largest)

def std(self) -> Expression:
    """Return an Expression holding the result of ``std()`` on the underlying data."""
    deviation = self._data.std()
    return Expression(deviation)

def var(self) -> Expression:
    """Return an Expression holding the result of ``var()`` on the underlying data."""
    variance = self._data.var()
    return Expression(variance)

def count(self) -> Expression:
    """Return an Expression counting the non-null observations."""
    observed = self._data.count()
    return Expression(observed)

def cummax(self) -> Expression:
    """Return an Expression holding the running maximum of the elements."""
    running_max = self._data.cummax()
    return Expression(running_max)

def cummin(self) -> Expression:
    """Return an Expression holding the running minimum of the elements."""
    running_min = self._data.cummin()
    return Expression(running_min)

def cumprod(self) -> Expression:
    """Return an Expression with the cumulative product of each element.

    The magnitude of the product is computed as
    ``exp(cumsum(log(abs(x))))``. Zeros and signs are tracked separately so
    that ``log`` is only ever applied to strictly positive values:

    * once a zero has been encountered, the running product is 0;
    * the sign is negative while an odd number of negative values has
      been encountered.

    Note: the magnitude is computed in floating point before being cast
    back to the original type, so results may be inexact for large
    products.
    """
    data = self._data

    # Zeros become NULL before log(); the running sum skips NULLs and the
    # zero case is resolved through ``seen_zero`` below.
    magnitude = data.abs().nullif(0).log().cumsum().exp()

    # Parity of the number of negatives seen so far gives the sign.
    negatives_so_far = (data < 0).cast("int64").cumsum()
    signed = (negatives_so_far % 2 == 1).ifelse(-magnitude, magnitude)

    # Any zero at or before the current row forces the product to zero.
    seen_zero = (data == 0).cumany()
    product = seen_zero.ifelse(0, signed)

    return Expression(product.cast(data.type()))

def cumsum(self) -> Expression:
    """Return an Expression holding the running sum of the elements."""
    running_total = self._data.cumsum()
    return Expression(running_total)

def diff(self) -> Expression:
    """Return an Expression of each element minus the element before it.

    The previous element is obtained with ``lag()`` on the underlying data.
    """
    current = self._data
    previous = self._data.lag()
    return Expression(current - previous)

def copy(self) -> Expression:
    """Return a new Expression constructed from the same underlying data."""
    underlying = self._data
    return Expression(underlying)

def isin(self, values) -> Expression:
    """Return a boolean Expression showing whether each element is contained in values.

    ``values`` may be a collection of literals or another Expression. An
    Expression is unwrapped to its underlying data before being handed to
    the underlying ``isin``, matching how the comparison operators treat
    Expression operands.
    """
    # isinstance (not a ``_data`` attribute probe) so that unrelated
    # objects which happen to expose ``_data`` are passed through untouched.
    if isinstance(values, Expression):
        values = values._data
    return Expression(self._data.isin(values))

def astype(self, dtype: pd.ArrowDtype) -> Expression:
    """Return a new Expression with the data cast to ``dtype``.

    The pandas Arrow dtype is translated with ``convert_pandas_to_ibis``
    and the result is passed to the underlying data's ``cast``.
    """
    target_type = convert_pandas_to_ibis(dtype)
    converted = self._data.cast(target_type)
    return Expression(converted)
153 changes: 153 additions & 0 deletions tests/unit/test_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2025 Google LLC, LeanFrame Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Expression methods evaluation."""

import ibis
import pandas as pd
import pyarrow as pa
import pytest
from leanframe import Session, col

@pytest.fixture
def session():
    """Provide a Session backed by an in-memory duckdb connection."""
    connection = ibis.duckdb.connect()
    return Session(connection)

@pytest.fixture
def df(session):
    """Provide a two-column test DataFrame read through the session."""
    columns = {
        'a': [1, -2, 3, -4, 5],
        'b': [5, 4, 3, 2, 1],
    }
    table = ibis.memtable(columns)
    return session.read_ibis(table)

def test_expression_comparison_methods(df):
    """Check every named comparison method of column ``a`` against column ``b``."""
    frame = df.assign(
        lt=col('a').lt(col('b')),
        gt=col('a').gt(col('b')),
        le=col('a').le(col('b')),
        ge=col('a').ge(col('b')),
        ne=col('a').ne(col('b')),
        eq=col('a').eq(col('b')),
    ).to_pandas()

    # Expected row-by-row outcome for a = [1, -2, 3, -4, 5], b = [5, 4, 3, 2, 1].
    expected_by_column = {
        'lt': [True, True, False, True, False],
        'gt': [False, False, False, False, True],
        'le': [True, True, True, True, False],
        'ge': [False, False, True, False, True],
        'ne': [True, True, False, True, True],
        'eq': [False, False, True, False, False],
    }
    for column, values in expected_by_column.items():
        pd.testing.assert_series_equal(
            frame[column], pd.Series(values, name=column), check_dtype=False
        )

def test_expression_math_methods(session):
    """Check ``round()`` (with and without digits) and ``abs()`` on floats."""
    frame = session.read_ibis(ibis.memtable({'a': [1.5, -2.1, 3.8]}))
    computed = frame.assign(
        r=round(col('a')),
        r1=round(col('a'), 1),
        abs=col('a').abs()
    ).to_pandas()

    expected_by_column = {
        'r': [2.0, -2.0, 4.0],
        'r1': [1.5, -2.1, 3.8],
        'abs': [1.5, 2.1, 3.8],
    }
    for column, values in expected_by_column.items():
        pd.testing.assert_series_equal(
            computed[column], pd.Series(values, name=column), check_dtype=False
        )

def test_expression_aggregation_methods(session):
    """Check that each aggregation is evaluated and repeated on every row.

    Covers ``all``, ``any``, ``sum``, ``mean``, ``min``, ``max``, ``count``,
    ``std`` and ``var``. The ``count`` column was previously computed but
    never asserted, and ``std``/``var`` were not exercised.
    """
    df = session.read_ibis(ibis.memtable({'a': [1, 2, 3], 'b': [True, False, True]}))
    result = df.assign(
        all_b=col('b').all(),
        any_b=col('b').any(),
        sum_a=col('a').sum(),
        mean_a=col('a').mean(),
        min_a=col('a').min(),
        max_a=col('a').max(),
        count_a=col('a').count(),
        std_a=col('a').std(),
        var_a=col('a').var(),
    ).to_pandas()

    expected_by_column = {
        'all_b': [False, False, False],
        'any_b': [True, True, True],
        'sum_a': [6, 6, 6],
        'mean_a': [2.0, 2.0, 2.0],
        'min_a': [1, 1, 1],
        'max_a': [3, 3, 3],
        'count_a': [3, 3, 3],
        # Sample standard deviation / variance of [1, 2, 3] are both 1.
        'std_a': [1.0, 1.0, 1.0],
        'var_a': [1.0, 1.0, 1.0],
    }
    for column, values in expected_by_column.items():
        pd.testing.assert_series_equal(
            result[column], pd.Series(values, name=column), check_dtype=False
        )

def test_expression_cumulative_methods(session):
    """Check the cumulative methods and ``diff()`` on a small integer column."""
    frame = session.read_ibis(ibis.memtable({'a': [1, 2, 3]}))
    computed = frame.assign(
        cummax=col('a').cummax(),
        cummin=col('a').cummin(),
        cumsum=col('a').cumsum(),
        cumprod=col('a').cumprod(),
        diff=col('a').diff()
    ).to_pandas()

    exact_by_column = {
        'cummax': [1, 2, 3],
        'cummin': [1, 1, 1],
        'cumsum': [1, 3, 6],
    }
    for column, values in exact_by_column.items():
        pd.testing.assert_series_equal(
            computed[column], pd.Series(values, name=column), check_dtype=False
        )

    # cumprod is computed through logarithms, so compare after rounding:
    # 1, 1*2, 1*2*3 = 1, 2, 6.
    rounded_product = computed['cumprod'].round()
    pd.testing.assert_series_equal(
        rounded_product, pd.Series([1.0, 2.0, 6.0], name='cumprod'), check_dtype=False
    )

    # The first row has no predecessor, so its difference is missing.
    expected_diff = pd.Series([float('nan'), 1.0, 1.0], name='diff')
    pd.testing.assert_series_equal(computed['diff'], expected_diff, check_dtype=False)

def test_expression_utility_methods(df):
    """Check ``isin()`` with a literal list and ``astype()`` to Arrow float64."""
    float_dtype = pd.ArrowDtype(pa.float64())
    computed = df.assign(
        isin=col('a').isin([1, 3]),
        cast=col('a').astype(float_dtype)
    ).to_pandas()

    expected_by_column = {
        'isin': [True, False, True, False, False],
        'cast': [1.0, -2.0, 3.0, -4.0, 5.0],
    }
    for column, values in expected_by_column.items():
        pd.testing.assert_series_equal(
            computed[column], pd.Series(values, name=column), check_dtype=False
        )
Loading