Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9ac44ce
starting to add caching to inertia moment calculation during the simu…
m-rauen Nov 26, 2024
37ad0e3
correct caching problem of not handle mutable objects
m-rauen Dec 3, 2024
014b59d
move ignore_unhashable to _misc module + trying to resolve circular i…
m-rauen Dec 3, 2024
7216147
circular import solved + movem from pkg_resources to importlib
m-rauen Dec 4, 2024
184a193
updated gitignore
m-rauen Dec 4, 2024
ddb4846
change ignore to copy strategy for cache + add docstring
m-rauen Dec 4, 2024
edae933
change maxsize of copy_unhashable
m-rauen Dec 5, 2024
bcce554
changing hash strategy to eliminate errors
m-rauen Dec 10, 2024
da26124
starting to add caching to inertia moment calculation during the simu…
m-rauen Nov 26, 2024
88b20e5
correct caching problem of not handle mutable objects
m-rauen Dec 3, 2024
07da1be
move ignore_unhashable to _misc module + trying to resolve circular i…
m-rauen Dec 3, 2024
6660655
circular import solved + movem from pkg_resources to importlib
m-rauen Dec 4, 2024
687d7c5
updated gitignore
m-rauen Dec 4, 2024
98be111
change ignore to copy strategy for cache + add docstring
m-rauen Dec 4, 2024
ca3e412
change maxsize of copy_unhashable
m-rauen Dec 5, 2024
6c1f0ad
changing hash strategy to eliminate errors
m-rauen Dec 10, 2024
70ae12e
changed scipy.misc.derivative to findiff.Diff (scipy derivative remov…
m-rauen Jan 28, 2025
37d2915
changed scipy.misc.derivative to findiff.Diff (scipy derivative remov…
m-rauen Jan 28, 2025
e9ce98d
resolve funky CI + cache corrected implemented
m-rauen Jan 28, 2025
8a32b23
added 'findiff' to the poetrylock file, since it's the new derivative…
m-rauen Jan 28, 2025
557092c
updated poetrylock file + cleaned some code
m-rauen Jan 28, 2025
29df5d7
testing funky CI
m-rauen Jan 28, 2025
61bde8b
testing
m-rauen Jan 29, 2025
f2298e4
added derivative func to misc module instead of using findiff
m-rauen Jan 31, 2025
3d638ac
forgetted to call _misc._derivative in the API module
m-rauen Jan 31, 2025
d426bdf
solved bugs from tests after adding cache to 'inertia' func + deleted…
m-rauen Mar 27, 2025
4083f0a
changed some diffusion rate code
m-rauen Apr 1, 2025
f6b8b58
changed some code to adequate for ruff linter + started some tests an…
m-rauen May 6, 2025
5058db1
fix: correct 'Lint with Ruff' step in github actions
m-rauen May 6, 2025
b36e01c
fix: adjust poetry to use correct python version (3.8-3.10) before in…
m-rauen May 6, 2025
9e6ac3c
fix: added 3 rules for ruff-lint to ignore in order to pass the build
m-rauen May 6, 2025
f4df058
fix: Black v25.0.1 doesnt support python 3.8 - changed to v23.3 to v2…
m-rauen May 7, 2025
f6ae16a
fix: changed to correct Black versions in 'pyproject' + linting w/ Bl…
m-rauen May 7, 2025
fd427a6
build: changed black lint strategy and versions + code formatting
m-rauen May 7, 2025
81f8b04
test: added tests for the new code of _misc (derivative func + cache …
m-rauen May 20, 2025
8f911e2
test: added isolated test for undefined derivative order + tests for …
m-rauen May 20, 2025
cb4440b
style: updated code formatting
m-rauen May 21, 2025
7f479e8
style: added missing commas to the higher order derivative tests
m-rauen May 21, 2025
c184d46
style: adequated to Black linting
m-rauen May 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,23 @@ jobs:
with:
poetry-version: ${{ matrix.poetry-version }}

- name: Configure Poetry to use matrix Python version
run: |
poetry env use python${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
poetry install --extras "cli fast solvents"

- name: Lint with Ruff
- run: |
run: |
poetry run ruff check --output-format=github .

- name: Lint with Black
uses: psf/black@stable
run: |
poetry run black --check .

- name: Test with pytest
run: |
Expand Down
5 changes: 5 additions & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"recommendations": [
"kevinrose.vsc-python-indent"
]
}
4 changes: 2 additions & 2 deletions overreact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__docformat__ = "restructuredtext"

import pkg_resources as _pkg_resources
from importlib.metadata import version

from overreact.api import (
get_enthalpies,
Expand Down Expand Up @@ -48,7 +48,7 @@
"unparse_reactions",
]

__version__ = _pkg_resources.get_distribution(__name__).version
__version__ = version(__name__)
__license__ = "MIT" # I'm too lazy to get it from setup.py...

__headline__ = "📈 Create and analyze chemical microkinetic models built from computational chemistry data."
Expand Down
257 changes: 256 additions & 1 deletion overreact/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,270 @@
from __future__ import annotations

import contextlib
from copy import deepcopy
from functools import lru_cache as cache
from functools import wraps

import numpy as np
from numpy import arange, array, hstack, newaxis, prod
from scipy.stats import cauchy, norm

import overreact as rx
from overreact import _constants as constants


def _central_diff_weights(Np, ndiv=1):
    """
    Return weights for an Np-point central derivative.

    Assumes equally-spaced function points.

    If weights are in the vector w, then the derivative is
    w[0] * f(x - ho*dx) + ... + w[-1] * f(x + ho*dx), with ho = Np // 2,
    divided by dx**ndiv.

    Adapted from Scipy's 'finite_differences' module.
    (https://github.com/scipy/scipy/blob/d1073acbc804b721cfe356969d8461cdd25a7839/scipy/stats/_finite_differences.py)

    Parameters
    ----------
    Np : int
        Number of points for the central derivative. Must be odd.
    ndiv : int, optional
        Order of the derivative. Default is 1.

    Returns
    -------
    w : ndarray
        Weights for an Np-point central derivative. Its size is `Np`.

    Raises
    ------
    ValueError
        If `ndiv` is negative, if `Np` is smaller than `ndiv + 1`, or if
        `Np` is even.

    Notes
    -----
    Can be inaccurate for a large number of points.

    Examples
    --------
    We can calculate a derivative value of a function.

    >>> def f(x):
    ...     return 2 * x**2 + 3
    >>> x = 3.0  # derivative point
    >>> h = 0.1  # differential step
    >>> Np = 3  # point number for central derivative
    >>> weights = _central_diff_weights(Np)  # weights for first derivative
    >>> vals = [f(x + (i - Np // 2) * h) for i in range(Np)]
    >>> float(round(sum(w * v for (w, v) in zip(weights, vals)) / h, 6))
    12.0

    This value matches the analytical solution:
    f'(x) = 4x, so f'(3) = 12

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Finite_difference

    """
    if ndiv < 0:
        # A negative order would otherwise index the inverse matrix from the
        # end and silently return meaningless weights.
        msg = "The derivative order must be non-negative."
        raise ValueError(msg)
    if Np < ndiv + 1:
        msg = "Number of points must be at least the derivative order + 1."
        raise ValueError(msg)
    if Np % 2 == 0:
        msg = "The number of points must be odd."
        raise ValueError(msg)
    from scipy import linalg

    ho = Np >> 1
    # Sample offsets -ho, ..., ho as a column vector.
    x = arange(-ho, ho + 1.0)[:, newaxis]
    # Vandermonde matrix: X[i, k] = x_i**k for k = 0, ..., Np - 1.
    X = x ** arange(Np)
    # Row `ndiv` of the inverse, scaled by ndiv!, gives the weights of the
    # ndiv-th derivative (the product over an empty range is 1 for ndiv = 0).
    return prod(arange(1, ndiv + 1), axis=0) * linalg.inv(X)[ndiv]


# Pre-computed central-difference weights keyed by
# (derivative order `n`, number of points `order`); built once at import time
# instead of on every call.
_PRECOMPUTED_WEIGHTS = {
    (1, 3): array([-1, 0, 1]) / 2.0,
    (1, 5): array([1, -8, 0, 8, -1]) / 12.0,
    (1, 7): array([-1, 9, -45, 0, 45, -9, 1]) / 60.0,
    (1, 9): array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0,
    # TODO(mrauen): no caller in overreact seems to use second (or higher)
    # derivatives yet; these entries are kept for future use.
    (2, 3): array([1, -2.0, 1]),
    (2, 5): array([-1, 16, -30, 16, -1]) / 12.0,
    (2, 7): array([2, -27, 270, -490, 270, -27, 2]) / 180.0,
    (2, 9): array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9]) / 5040.0,
}


def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
    """
    Find the nth derivative of a function at a point.

    Given a function, use a central difference formula with spacing `dx` to
    compute the nth derivative at `x0`.

    Adapted from Scipy's 'finite_differences' module.
    (https://github.com/scipy/scipy/blob/d1073acbc804b721cfe356969d8461cdd25a7839/scipy/stats/_finite_differences.py)

    Parameters
    ----------
    func : function
        Input function, called as ``func(x, *args)``.
    x0 : float
        The point at which the nth derivative is found.
    dx : float, optional
        Spacing.
    n : int, optional
        Order of the derivative. Default is 1.
    args : tuple, optional
        Extra positional arguments passed to `func`.
    order : int, optional
        Number of points to use, must be odd.

    Returns
    -------
    float
        The central-difference estimate of the nth derivative at `x0`.

    Raises
    ------
    ValueError
        If `order` is smaller than `n + 1` or if `order` is even.

    Notes
    -----
    Decreasing the step size too small can result in round-off error.

    Examples
    --------
    >>> def f(x):
    ...     return x**3 + x**2
    >>> float(round(_derivative(f, 1.0, dx=1e-6), 6))
    5.0
    """
    if order < n + 1:
        msg = "'order' (the number of points used to compute the derivative), must be at least the derivative order 'n' + 1."
        raise ValueError(msg)
    if order % 2 == 0:
        msg = (
            "'order' (the number of points used to compute the derivative) must be odd."
        )
        raise ValueError(msg)

    # Use the tabulated weights for n = 1, 2 with 3 to 9 points (for speed);
    # anything else is computed on demand.
    weights = _PRECOMPUTED_WEIGHTS.get((n, order))
    if weights is None:
        weights = _central_diff_weights(order, n)

    val = 0.0
    ho = order >> 1
    # Weighted sum of the function sampled at x0 - ho*dx, ..., x0 + ho*dx.
    for k in range(order):
        val += weights[k] * func(x0 + (k - ho) * dx, *args)
    # Dividing by dx**n turns the weighted sum into the derivative estimate.
    return val / prod((dx,) * n, axis=0)


def make_hashable(obj):
    """
    Turn an array, list or set into a (nested) tuple so it can be hashed.

    Parameters
    ----------
    obj : array, list, set or any other object

    Returns
    -------
    tuple or object
        For an array, the pair ``(shape, flattened values)``, both as tuples.
        For a list or set, a tuple of its (recursively converted) items.
        Any other object is returned unchanged.

    Notes
    -----
    Lists and sets are converted item by item, so nested containers become
    nested tuples. For a set, the order of the resulting tuple follows the
    set's iteration order.
    """
    if isinstance(obj, np.ndarray):
        return (tuple(obj.shape), tuple(obj.ravel()))
    if isinstance(obj, (list, set)):
        return tuple(make_hashable(item) for item in obj)
    return obj


class _ArrayKey:
    """
    Hashable stand-in for an array inside the cache key of `copy_unhashable`.

    A dedicated type (deliberately not a tuple) is used so that an ordinary
    tuple or nested list argument can never be mistaken for an encoded array
    when the arguments are converted back.
    """

    __slots__ = ("_hash", "dtype", "flat", "shape")

    def __init__(self, array):
        self.shape, self.flat = make_hashable(array)
        self.dtype = array.dtype
        # The key is immutable, so the hash is computed only once.
        self._hash = hash((self.shape, self.dtype, self.flat))

    def __eq__(self, other):
        if not isinstance(other, _ArrayKey):
            return NotImplemented
        return (
            self.shape == other.shape
            and self.dtype == other.dtype
            and self.flat == other.flat
        )

    def __hash__(self):
        return self._hash

    def to_array(self):
        """Build a fresh array with the original shape, dtype and values."""
        return np.array(self.flat, dtype=self.dtype).reshape(self.shape)


def copy_unhashable(maxsize=128, typed=False):
    """
    Cache results of a function that receives unhashable arguments (array, list, set).

    Arguments are converted into hashable keys before being passed to
    'lru_cache'. Arrays are rebuilt (as fresh copies, keeping shape and dtype)
    before the wrapped function is called; lists and sets reach the wrapped
    function as tuples. The returned value is a deep copy of the cached
    result, so callers cannot modify the cache by mutating what they receive.

    Dictionaries (and other unhashable objects) are not converted, so passing
    them raises the usual ``TypeError`` from 'lru_cache'.

    Parameters
    ----------
    maxsize : int
        Cache size limit. Default from functools.lru_cache()
    typed : bool
        If set to True, arguments of different types will be cached separately. Default from functools.lru_cache()

    Returns
    -------
    function
        A decorator. The decorated function keeps the name and docstring of
        the original one and exposes ``cache_info()`` and ``cache_clear()``
        of the underlying cache.
    """

    def encode(arg):
        # Arrays get a tagged key; everything else goes through make_hashable.
        if isinstance(arg, np.ndarray):
            return _ArrayKey(arg)
        return make_hashable(arg)

    def decode(arg):
        # Only tagged keys are turned back into arrays; tuples stay tuples.
        if isinstance(arg, _ArrayKey):
            return arg.to_array()
        return arg

    def decorator(func):
        @cache(maxsize=maxsize, typed=typed)
        def cached_func(*hashable_args, **hashable_kwargs):
            args = tuple(decode(arg) for arg in hashable_args)
            kwargs = {k: decode(v) for k, v in hashable_kwargs.items()}
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            hashable_args = tuple(encode(arg) for arg in args)
            hashable_kwargs = {k: encode(v) for k, v in kwargs.items()}
            return deepcopy(cached_func(*hashable_args, **hashable_kwargs))

        # Expose the cache controls of the underlying lru_cache.
        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator


def _find_package(package):
"""Check if a package exists without importing it.

Expand Down Expand Up @@ -739,7 +994,7 @@
return primes


@cache(maxsize=1000000)
@cache
def _vdc(n, b=2):
"""Help haltonspace."""
res, denom = 0, 1
Expand Down
13 changes: 5 additions & 8 deletions overreact/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from __future__ import annotations

__all__ = [
"get_k",
"get_kappa",
"get_freeenergies",
"get_entropies",
"get_enthalpies",
"get_entropies",
"get_freeenergies",
"get_internal_energies",
"get_k",
"get_kappa",
]


Expand All @@ -22,11 +22,11 @@
from typing import TYPE_CHECKING

import numpy as np
from scipy.misc import derivative

import overreact as rx
from overreact import _constants as constants
from overreact import coords, rates, tunnel
from overreact._misc import _derivative as derivative

if TYPE_CHECKING:
from overreact.core import Scheme
Expand Down Expand Up @@ -74,7 +74,6 @@ def get_internal_energies(
for name in compounds:
logger.info(f"calculate internal energy: {name}")

# TODO(schneiderfelipe): inertia might benefit from caching
moments, _, _ = coords.inertia(
compounds[name].atommasses,
compounds[name].atomcoords,
Expand Down Expand Up @@ -140,7 +139,6 @@ def get_enthalpies(
for name in compounds:
logger.info(f"calculate enthalpy: {name}")

# TODO(schneiderfelipe): inertia might benefit from caching
moments, _, _ = coords.inertia(
compounds[name].atommasses,
compounds[name].atomcoords,
Expand Down Expand Up @@ -233,7 +231,6 @@ def get_entropies(
)
symmetry_number = coords.symmetry_number(point_group)

# TODO(schneiderfelipe): inertia might benefit from caching
moments, _, _ = coords.inertia(
compounds[name].atommasses,
compounds[name].atomcoords,
Expand Down
Loading