Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion SparseDiffEngine
Submodule SparseDiffEngine updated 48 files
+1 −1 include/affine.h
+9 −8 include/bivariate.h
+13 −0 include/memory_wrappers.h
+9 −0 include/problem.h
+26 −8 include/subexpr.h
+5 −3 include/utils/Vec_macros.h
+2 −4 src/affine/hstack.c
+3 −6 src/affine/index.c
+3 −8 src/affine/linear_op.c
+27 −9 src/affine/parameter.c
+2 −1 src/affine/sum.c
+2 −1 src/affine/trace.c
+59 −36 src/bivariate/left_matmul.c
+1 −1 src/bivariate/multiply.c
+2 −1 src/bivariate/quad_over_lin.c
+32 −6 src/bivariate/right_matmul.c
+17 −9 src/bivariate/scalar_mult.c
+16 −16 src/bivariate/vector_mult.c
+5 −7 src/expr.c
+4 −3 src/other/prod_axis_one.c
+4 −3 src/other/prod_axis_zero.c
+2 −2 src/other/quad_form.c
+41 −6 src/problem.c
+12 −11 src/utils/CSC_Matrix.c
+6 −5 src/utils/CSR_Matrix.c
+2 −1 src/utils/int_double_pair.c
+3 −2 src/utils/linalg_sparse_matmuls.c
+11 −6 tests/all_tests.c
+2 −1 tests/forward_pass/affine/test_add.h
+4 −3 tests/forward_pass/affine/test_sum.h
+5 −3 tests/forward_pass/affine/test_variable_parameter.h
+2 −1 tests/forward_pass/composite/test_composite.h
+2 −1 tests/forward_pass/test_prod_axis_one.h
+2 −1 tests/forward_pass/test_prod_axis_zero.h
+2 −1 tests/jacobian_tests/test_broadcast.h
+11 −9 tests/jacobian_tests/test_left_matmul.h
+2 −2 tests/jacobian_tests/test_right_matmul.h
+6 −2 tests/jacobian_tests/test_scalar_mult.h
+12 −9 tests/jacobian_tests/test_transpose.h
+6 −2 tests/jacobian_tests/test_vector_mult.h
+351 −0 tests/problem/test_param_prob.h
+20 −14 tests/profiling/profile_left_matmul.h
+6 −5 tests/utils/test_csr_csc_conversion.h
+2 −1 tests/utils/test_csr_matrix.h
+10 −9 tests/wsum_hess/test_left_matmul.h
+2 −2 tests/wsum_hess/test_right_matmul.h
+6 −2 tests/wsum_hess/test_scalar_mult.h
+6 −2 tests/wsum_hess/test_vector_mult.h
17 changes: 14 additions & 3 deletions sparsediffpy/_bindings/atoms/const_scalar_mult.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include "bivariate.h"
#include "common.h"
#include "subexpr.h"

/* Constant scalar multiplication: a * f(x) where a is a constant double */
/* Constant scalar multiplication: a * f(x) where a is a constant double.
* Creates a fixed parameter node for the scalar and calls new_scalar_mult. */
static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
{
PyObject *child_capsule;
Expand All @@ -22,11 +24,20 @@ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
return NULL;
}

expr *node = new_const_scalar_mult(a, child);
/* Create a 1x1 fixed parameter for the scalar value */
expr *a_node = new_parameter(1, 1, PARAM_FIXED, child->n_vars, &a);
if (!a_node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create scalar parameter node");
return NULL;
}

expr *node = new_scalar_mult(a_node, child);

if (!node)
{
PyErr_SetString(PyExc_RuntimeError,
"failed to create const_scalar_mult node");
"failed to create scalar_mult node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
Expand Down
21 changes: 16 additions & 5 deletions sparsediffpy/_bindings/atoms/const_vector_mult.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

#include "bivariate.h"
#include "common.h"
#include "subexpr.h"

/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector
*/
/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector.
* Creates a fixed parameter node for the vector and calls new_vector_mult. */
static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)
{
PyObject *child_capsule;
Expand Down Expand Up @@ -42,14 +43,24 @@ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)

double *a_data = (double *) PyArray_DATA(a_array);

expr *node = new_const_vector_mult(a_data, child);

/* Create a fixed parameter node for the vector */
expr *a_node = new_parameter(child->d1, child->d2, PARAM_FIXED, child->n_vars,
a_data);
Py_DECREF(a_array);

if (!a_node)
{
PyErr_SetString(PyExc_RuntimeError,
"failed to create vector parameter node");
return NULL;
}

expr *node = new_vector_mult(a_node, child);

if (!node)
{
PyErr_SetString(PyExc_RuntimeError,
"failed to create const_vector_mult node");
"failed to create vector_mult node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
Expand Down
5 changes: 3 additions & 2 deletions sparsediffpy/_bindings/atoms/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define ATOM_CONSTANT_H

#include "common.h"
#include "subexpr.h"

static PyObject *py_make_constant(PyObject *self, PyObject *args)
{
Expand All @@ -19,8 +20,8 @@ static PyObject *py_make_constant(PyObject *self, PyObject *args)
return NULL;
}

expr *node =
new_constant(d1, d2, n_vars, (const double *) PyArray_DATA(values_array));
expr *node = new_parameter(d1, d2, PARAM_FIXED, n_vars,
(const double *) PyArray_DATA(values_array));
Py_DECREF(values_array);

if (!node)
Expand Down
67 changes: 58 additions & 9 deletions sparsediffpy/_bindings/atoms/left_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,27 @@

#include "bivariate.h"
#include "common.h"
#include "subexpr.h"

/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
/* Left matrix multiplication: A @ f(x).
*
* Unified binding for both fixed-constant and updatable-parameter cases.
* Python signature:
* make_left_matmul(param_or_none, child, data, indices, indptr, m, n)
*
* - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created
* internally), or an existing parameter capsule for updatable parameters.
* - child: the child expression capsule f(x).
* - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and
* initial values of the matrix A. */
static PyObject *py_make_left_matmul(PyObject *self, PyObject *args)
{
PyObject *param_obj;
PyObject *child_capsule;
PyObject *data_obj, *indices_obj, *indptr_obj;
int m, n;
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
&indptr_obj, &m, &n))
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule, &data_obj,
&indices_obj, &indptr_obj, &m, &n))
{
return NULL;
}
Expand All @@ -38,18 +50,55 @@ static PyObject *py_make_left_matmul(PyObject *self, PyObject *args)
return NULL;
}

int nnz = (int) PyArray_SIZE(data_array);
double *csr_data = (double *) PyArray_DATA(data_array);
int *csr_indices = (int *) PyArray_DATA(indices_array);
int *csr_indptr = (int *) PyArray_DATA(indptr_array);
int nnz = csr_indptr[m];

/* Build CSR matrix from Python arrays */
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));
memcpy(A->p, csr_indptr, (m + 1) * sizeof(int));
memcpy(A->i, csr_indices, nnz * sizeof(int));
memcpy(A->x, csr_data, nnz * sizeof(double));

/* Determine param_node: use passed capsule, or create PARAM_FIXED internally */
expr *param_node;
if (param_obj == Py_None)
{
/* Fixed constant: pass CSR data directly (values are already in CSR order) */
param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars, csr_data);

if (!param_node)
{
free_csr_matrix(A);
Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);
PyErr_SetString(PyExc_RuntimeError,
"failed to create matrix parameter node");
return NULL;
}
}
else
{
param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
if (!param_node)
{
free_csr_matrix(A);
Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
return NULL;
}
}

Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);

expr *node = new_left_matmul(child, A);
free_csr_matrix(A);
expr *node = new_left_matmul(param_node, child, A);
free_csr_matrix(A); /* constructor copies it */

if (!node)
{
Expand Down
26 changes: 26 additions & 0 deletions sparsediffpy/_bindings/atoms/parameter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef ATOM_PARAMETER_H
#define ATOM_PARAMETER_H

#include "common.h"

/* Updatable parameter: make_parameter(d1, d2, param_id, n_vars)
* Values are set later via problem_update_params. */
static PyObject *py_make_parameter(PyObject *self, PyObject *args)
{
int d1, d2, param_id, n_vars;
if (!PyArg_ParseTuple(args, "iiii", &d1, &d2, &param_id, &n_vars))
{
return NULL;
}

expr *node = new_parameter(d1, d2, param_id, n_vars, NULL);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create parameter node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_PARAMETER_H */
67 changes: 58 additions & 9 deletions sparsediffpy/_bindings/atoms/right_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,27 @@

#include "bivariate.h"
#include "common.h"
#include "subexpr.h"

/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter matrix.
*
* Unified binding for both fixed-constant and updatable-parameter cases.
* Python signature:
* make_right_matmul(param_or_none, child, data, indices, indptr, m, n)
*
* - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created
* internally), or an existing parameter capsule for updatable parameters.
* - child: the child expression capsule f(x).
* - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and
* initial values of the matrix A. */
static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
{
PyObject *param_obj;
PyObject *child_capsule;
PyObject *data_obj, *indices_obj, *indptr_obj;
int m, n;
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
&indptr_obj, &m, &n))
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule, &data_obj,
&indices_obj, &indptr_obj, &m, &n))
{
return NULL;
}
Expand All @@ -38,18 +50,55 @@ static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
return NULL;
}

int nnz = (int) PyArray_SIZE(data_array);
double *csr_data = (double *) PyArray_DATA(data_array);
int *csr_indices = (int *) PyArray_DATA(indices_array);
int *csr_indptr = (int *) PyArray_DATA(indptr_array);
int nnz = csr_indptr[m];

/* Build CSR matrix from Python arrays */
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));
memcpy(A->p, csr_indptr, (m + 1) * sizeof(int));
memcpy(A->i, csr_indices, nnz * sizeof(int));
memcpy(A->x, csr_data, nnz * sizeof(double));

/* Determine param_node: use passed capsule, or create PARAM_FIXED internally */
expr *param_node;
if (param_obj == Py_None)
{
/* Fixed constant: pass CSR data directly (values are already in CSR order) */
param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars, csr_data);

if (!param_node)
{
free_csr_matrix(A);
Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);
PyErr_SetString(PyExc_RuntimeError,
"failed to create matrix parameter node");
return NULL;
}
}
else
{
param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
if (!param_node)
{
free_csr_matrix(A);
Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
return NULL;
}
}

Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);

expr *node = new_right_matmul(child, A);
free_csr_matrix(A);
expr *node = new_right_matmul(param_node, child, A);
free_csr_matrix(A); /* constructor copies it */

if (!node)
{
Expand Down
44 changes: 44 additions & 0 deletions sparsediffpy/_bindings/atoms/scalar_mult.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef ATOM_SCALAR_MULT_H
#define ATOM_SCALAR_MULT_H

#include "bivariate.h"
#include "common.h"

/* Parameter scalar multiplication: param * f(x) where param is a parameter capsule.
* Python name: make_param_scalar_mult(param_capsule, child_capsule) */
static PyObject *py_make_param_scalar_mult(PyObject *self, PyObject *args)
{
PyObject *param_capsule;
PyObject *child_capsule;

if (!PyArg_ParseTuple(args, "OO", &param_capsule, &child_capsule))
{
return NULL;
}

expr *param_node =
(expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME);
if (!param_node)
{
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
return NULL;
}

expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
if (!child)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

expr *node = new_scalar_mult(param_node, child);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create scalar_mult node");
return NULL;
}
expr_retain(node);
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_SCALAR_MULT_H */
Loading