From 4ce73155ee6a96c0160936e1aadd2d9d2f26e553 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Sun, 15 Feb 2026 19:34:19 -0500 Subject: [PATCH 1/4] Add parameter support to Python bindings - Unify make_left_matmul and make_left_param_matmul into single binding - Accept (param_or_none, child, data, indices, indptr, m, n) - Delete left_param_matmul.h (merged into left_matmul.h) - Update SparseDiffEngine submodule (simplified new_left_matmul API) Co-Authored-By: Claude Opus 4.6 --- SparseDiffEngine | 2 +- .../_bindings/atoms/const_scalar_mult.h | 17 ++++- .../_bindings/atoms/const_vector_mult.h | 21 ++++-- sparsediffpy/_bindings/atoms/constant.h | 5 +- sparsediffpy/_bindings/atoms/left_matmul.h | 73 ++++++++++++++++--- sparsediffpy/_bindings/atoms/parameter.h | 26 +++++++ sparsediffpy/_bindings/atoms/scalar_mult.h | 44 +++++++++++ sparsediffpy/_bindings/atoms/vector_mult.h | 44 +++++++++++ sparsediffpy/_bindings/bindings.c | 16 +++- .../_bindings/problem/register_params.h | 58 +++++++++++++++ .../_bindings/problem/update_params.h | 38 ++++++++++ 11 files changed, 323 insertions(+), 21 deletions(-) create mode 100644 sparsediffpy/_bindings/atoms/parameter.h create mode 100644 sparsediffpy/_bindings/atoms/scalar_mult.h create mode 100644 sparsediffpy/_bindings/atoms/vector_mult.h create mode 100644 sparsediffpy/_bindings/problem/register_params.h create mode 100644 sparsediffpy/_bindings/problem/update_params.h diff --git a/SparseDiffEngine b/SparseDiffEngine index 5316598..36864d4 160000 --- a/SparseDiffEngine +++ b/SparseDiffEngine @@ -1 +1 @@ -Subproject commit 5316598a490200483a65d26af14085d7c089eb20 +Subproject commit 36864d47e63a6f53726905e222d4b5c1d3845706 diff --git a/sparsediffpy/_bindings/atoms/const_scalar_mult.h b/sparsediffpy/_bindings/atoms/const_scalar_mult.h index 1a83e5a..189320d 100644 --- a/sparsediffpy/_bindings/atoms/const_scalar_mult.h +++ b/sparsediffpy/_bindings/atoms/const_scalar_mult.h @@ -3,8 +3,10 @@ #include "bivariate.h" #include "common.h" +#include "subexpr.h" -/* Constant scalar multiplication: a * f(x) where a is a constant double */ +/* Constant scalar multiplication: a * f(x) where a is a constant double. + * Creates a fixed parameter node for the scalar and calls new_scalar_mult. */ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args) { PyObject *child_capsule; @@ -22,11 +24,20 @@ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args) return NULL; } - expr *node = new_const_scalar_mult(a, child); + /* Create a 1x1 fixed parameter for the scalar value */ + expr *a_node = new_parameter(1, 1, PARAM_FIXED, child->n_vars, &a); + if (!a_node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create scalar parameter node"); + return NULL; + } + + expr *node = new_scalar_mult(a_node, child); + if (!node) { PyErr_SetString(PyExc_RuntimeError, - "failed to create const_scalar_mult node"); + "failed to create scalar_mult node"); return NULL; } expr_retain(node); /* Capsule owns a reference */ diff --git a/sparsediffpy/_bindings/atoms/const_vector_mult.h b/sparsediffpy/_bindings/atoms/const_vector_mult.h index aa1f83e..086caaf 100644 --- a/sparsediffpy/_bindings/atoms/const_vector_mult.h +++ b/sparsediffpy/_bindings/atoms/const_vector_mult.h @@ -3,9 +3,10 @@ #include "bivariate.h" #include "common.h" +#include "subexpr.h" -/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector - */ +/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector. + * Creates a fixed parameter node for the vector and calls new_vector_mult. */ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args) { PyObject *child_capsule; @@ -42,14 +43,24 @@ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args) double *a_data = (double *) PyArray_DATA(a_array); - expr *node = new_const_vector_mult(a_data, child); - + /* Create a fixed parameter node for the vector */ + expr *a_node = new_parameter(child->d1, child->d2, PARAM_FIXED, child->n_vars, + a_data); Py_DECREF(a_array); + if (!a_node) + { + PyErr_SetString(PyExc_RuntimeError, + "failed to create vector parameter node"); + return NULL; + } + + expr *node = new_vector_mult(a_node, child); + if (!node) { PyErr_SetString(PyExc_RuntimeError, - "failed to create const_vector_mult node"); + "failed to create vector_mult node"); return NULL; } expr_retain(node); /* Capsule owns a reference */ diff --git a/sparsediffpy/_bindings/atoms/constant.h b/sparsediffpy/_bindings/atoms/constant.h index d5fcba3..4e1b9d9 100644 --- a/sparsediffpy/_bindings/atoms/constant.h +++ b/sparsediffpy/_bindings/atoms/constant.h @@ -2,6 +2,7 @@ #define ATOM_CONSTANT_H #include "common.h" +#include "subexpr.h" static PyObject *py_make_constant(PyObject *self, PyObject *args) { @@ -19,8 +20,8 @@ static PyObject *py_make_constant(PyObject *self, PyObject *args) return NULL; } - expr *node = - new_constant(d1, d2, n_vars, (const double *) PyArray_DATA(values_array)); + expr *node = new_parameter(d1, d2, PARAM_FIXED, n_vars, + (const double *) PyArray_DATA(values_array)); Py_DECREF(values_array); if (!node) diff --git a/sparsediffpy/_bindings/atoms/left_matmul.h b/sparsediffpy/_bindings/atoms/left_matmul.h index 27fe3a4..9faebc5 100644 --- a/sparsediffpy/_bindings/atoms/left_matmul.h +++ b/sparsediffpy/_bindings/atoms/left_matmul.h @@ -3,15 +3,27 @@ #include "bivariate.h" #include "common.h" +#include "subexpr.h" -/* Left matrix multiplication: A @ f(x) where A is a constant matrix */ +/* Left matrix multiplication: A @ f(x). + * + * Unified binding for both fixed-constant and updatable-parameter cases. + * Python signature: + * make_left_matmul(param_or_none, child, data, indices, indptr, m, n) + * + * - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created + * internally), or an existing parameter capsule for updatable parameters. + * - child: the child expression capsule f(x). + * - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and + * initial values of the matrix A. */ static PyObject *py_make_left_matmul(PyObject *self, PyObject *args) { + PyObject *param_obj; PyObject *child_capsule; PyObject *data_obj, *indices_obj, *indptr_obj; int m, n; - if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj, - &indptr_obj, &m, &n)) + if (!PyArg_ParseTuple(args, "OOOOOii", ¶m_obj, &child_capsule, &data_obj, + &indices_obj, &indptr_obj, &m, &n)) { return NULL; } @@ -38,18 +50,61 @@ static PyObject *py_make_left_matmul(PyObject *self, PyObject *args) return NULL; } - int nnz = (int) PyArray_SIZE(data_array); + double *csr_data = (double *) PyArray_DATA(data_array); + int *csr_indices = (int *) PyArray_DATA(indices_array); + int *csr_indptr = (int *) PyArray_DATA(indptr_array); + int nnz = csr_indptr[m]; + + /* Build CSR matrix from Python arrays */ CSR_Matrix *A = new_csr_matrix(m, n, nnz); - memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double)); - memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int)); - memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int)); + memcpy(A->p, csr_indptr, (m + 1) * sizeof(int)); + memcpy(A->i, csr_indices, nnz * sizeof(int)); + memcpy(A->x, csr_data, nnz * sizeof(double)); + + /* Determine param_node: use passed capsule, or create PARAM_FIXED internally */ + expr *param_node; + if (param_obj == Py_None) + { + /* Fixed constant: create column-major values for the parameter node */ + double *col_major = (double *) calloc(m * n, sizeof(double)); + for (int row = 0; row < m; row++) + for (int k = csr_indptr[row]; k < csr_indptr[row + 1]; k++) + col_major[row + csr_indices[k] * m] = csr_data[k]; + + param_node = new_parameter(m, n, PARAM_FIXED, child->n_vars, col_major); + free(col_major); + + if (!param_node) + { + free_csr_matrix(A); + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + PyErr_SetString(PyExc_RuntimeError, + "failed to create matrix parameter node"); + return NULL; + } + } + else + { + param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME); + if (!param_node) + { + free_csr_matrix(A); + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + PyErr_SetString(PyExc_ValueError, "invalid param capsule"); + return NULL; + } + } Py_DECREF(data_array); Py_DECREF(indices_array); Py_DECREF(indptr_array); - expr *node = new_left_matmul(child, A); - free_csr_matrix(A); + expr *node = new_left_matmul(param_node, child, A); + free_csr_matrix(A); /* constructor copies it */ if (!node) { diff --git a/sparsediffpy/_bindings/atoms/parameter.h b/sparsediffpy/_bindings/atoms/parameter.h new file mode 100644 index 0000000..fef3ace --- /dev/null +++ b/sparsediffpy/_bindings/atoms/parameter.h @@ -0,0 +1,26 @@ +#ifndef ATOM_PARAMETER_H +#define ATOM_PARAMETER_H + +#include "common.h" + +/* Updatable parameter: make_parameter(d1, d2, param_id, n_vars) + * Values are set later via problem_update_params. */ +static PyObject *py_make_parameter(PyObject *self, PyObject *args) +{ + int d1, d2, param_id, n_vars; + if (!PyArg_ParseTuple(args, "iiii", &d1, &d2, ¶m_id, &n_vars)) + { + return NULL; + } + + expr *node = new_parameter(d1, d2, param_id, n_vars, NULL); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create parameter node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_PARAMETER_H */ diff --git a/sparsediffpy/_bindings/atoms/scalar_mult.h b/sparsediffpy/_bindings/atoms/scalar_mult.h new file mode 100644 index 0000000..08eddf6 --- /dev/null +++ b/sparsediffpy/_bindings/atoms/scalar_mult.h @@ -0,0 +1,44 @@ +#ifndef ATOM_SCALAR_MULT_H +#define ATOM_SCALAR_MULT_H + +#include "bivariate.h" +#include "common.h" + +/* Parameter scalar multiplication: param * f(x) where param is a parameter capsule. + * Python name: make_param_scalar_mult(param_capsule, child_capsule) */ +static PyObject *py_make_param_scalar_mult(PyObject *self, PyObject *args) +{ + PyObject *param_capsule; + PyObject *child_capsule; + + if (!PyArg_ParseTuple(args, "OO", ¶m_capsule, &child_capsule)) + { + return NULL; + } + + expr *param_node = + (expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME); + if (!param_node) + { + PyErr_SetString(PyExc_ValueError, "invalid param capsule"); + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_scalar_mult(param_node, child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create scalar_mult node"); + return NULL; + } + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_SCALAR_MULT_H */ diff --git a/sparsediffpy/_bindings/atoms/vector_mult.h b/sparsediffpy/_bindings/atoms/vector_mult.h new file mode 100644 index 0000000..9b2c468 --- /dev/null +++ b/sparsediffpy/_bindings/atoms/vector_mult.h @@ -0,0 +1,44 @@ +#ifndef ATOM_VECTOR_MULT_H +#define ATOM_VECTOR_MULT_H + +#include "bivariate.h" +#include "common.h" + +/* Parameter vector multiplication: param ∘ f(x) where param is a parameter capsule. + * Python name: make_param_vector_mult(param_capsule, child_capsule) */ +static PyObject *py_make_param_vector_mult(PyObject *self, PyObject *args) +{ + PyObject *param_capsule; + PyObject *child_capsule; + + if (!PyArg_ParseTuple(args, "OO", ¶m_capsule, &child_capsule)) + { + return NULL; + } + + expr *param_node = + (expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME); + if (!param_node) + { + PyErr_SetString(PyExc_ValueError, "invalid param capsule"); + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_vector_mult(param_node, child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create vector_mult node"); + return NULL; + } + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_VECTOR_MULT_H */ diff --git a/sparsediffpy/_bindings/bindings.c b/sparsediffpy/_bindings/bindings.c index 8a0d794..b1adb33 100644 --- a/sparsediffpy/_bindings/bindings.c +++ b/sparsediffpy/_bindings/bindings.c @@ -24,6 +24,7 @@ #include "atoms/matmul.h" #include "atoms/multiply.h" #include "atoms/neg.h" +#include "atoms/parameter.h" #include "atoms/power.h" #include "atoms/prod.h" #include "atoms/prod_axis_one.h" @@ -36,6 +37,7 @@ #include "atoms/rel_entr_vector_scalar.h" #include "atoms/reshape.h" #include "atoms/right_matmul.h" +#include "atoms/scalar_mult.h" #include "atoms/sin.h" #include "atoms/sinh.h" #include "atoms/sum.h" @@ -44,6 +46,7 @@ #include "atoms/trace.h" #include "atoms/transpose.h" #include "atoms/variable.h" +#include "atoms/vector_mult.h" #include "atoms/xexp.h" /* Include problem bindings */ @@ -56,6 +59,8 @@ #include "problem/jacobian.h" #include "problem/make_problem.h" #include "problem/objective_forward.h" +#include "problem/register_params.h" +#include "problem/update_params.h" static int numpy_initialized = 0; @@ -70,6 +75,7 @@ static int ensure_numpy(void) static PyMethodDef DNLPMethods[] = { {"make_variable", py_make_variable, METH_VARARGS, "Create variable node"}, {"make_constant", py_make_constant, METH_VARARGS, "Create constant node"}, + {"make_parameter", py_make_parameter, METH_VARARGS, "Create parameter node"}, {"make_linear", py_make_linear, METH_VARARGS, "Create linear op node"}, {"make_log", py_make_log, METH_VARARGS, "Create log node"}, {"make_exp", py_make_exp, METH_VARARGS, "Create exp node"}, @@ -110,7 +116,11 @@ static PyMethodDef DNLPMethods[] = { {"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"}, {"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"}, {"make_left_matmul", py_make_left_matmul, METH_VARARGS, - "Create left matmul node (A @ f(x))"}, + "Create left matmul node (A @ f(x)): pass None or param capsule as first arg"}, + {"make_param_scalar_mult", py_make_param_scalar_mult, METH_VARARGS, + "Create scalar mult from parameter (p * f(x))"}, + {"make_param_vector_mult", py_make_param_vector_mult, METH_VARARGS, + "Create vector mult from parameter (p ∘ f(x))"}, {"make_right_matmul", py_make_right_matmul, METH_VARARGS, "Create right matmul node (f(x) @ A)"}, {"make_quad_form", py_make_quad_form, METH_VARARGS, @@ -150,6 +160,10 @@ static PyMethodDef DNLPMethods[] = { "Compute Lagrangian Hessian"}, {"get_hessian", py_get_hessian, METH_VARARGS, "Get Lagrangian Hessian without recomputing"}, + {"problem_register_params", py_problem_register_params, METH_VARARGS, + "Register parameter nodes with the problem"}, + {"problem_update_params", py_problem_update_params, METH_VARARGS, + "Update parameter values"}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef sparsediffpy_module = { diff --git a/sparsediffpy/_bindings/problem/register_params.h b/sparsediffpy/_bindings/problem/register_params.h new file mode 100644 index 0000000..a80d110 --- /dev/null +++ b/sparsediffpy/_bindings/problem/register_params.h @@ -0,0 +1,58 @@ +#ifndef PROBLEM_REGISTER_PARAMS_H +#define PROBLEM_REGISTER_PARAMS_H + +#include "atoms/common.h" +#include "problem/common.h" + +/* Register parameter nodes with the problem. + * Python: problem_register_params(problem_capsule, [param_capsule, ...]) */ +static PyObject *py_problem_register_params(PyObject *self, PyObject *args) +{ + PyObject *prob_capsule; + PyObject *param_list; + if (!PyArg_ParseTuple(args, "OO", &prob_capsule, ¶m_list)) + { + return NULL; + } + + problem *prob = + (problem *) PyCapsule_GetPointer(prob_capsule, PROBLEM_CAPSULE_NAME); + if (!prob) + { + PyErr_SetString(PyExc_ValueError, "invalid problem capsule"); + return NULL; + } + + if (!PyList_Check(param_list)) + { + PyErr_SetString(PyExc_TypeError, "param_nodes must be a list"); + return NULL; + } + + Py_ssize_t n = PyList_Size(param_list); + expr **param_nodes = (expr **) malloc(n * sizeof(expr *)); + if (!param_nodes) + { + PyErr_NoMemory(); + return NULL; + } + + for (Py_ssize_t i = 0; i < n; i++) + { + PyObject *cap = PyList_GetItem(param_list, i); + param_nodes[i] = (expr *) PyCapsule_GetPointer(cap, EXPR_CAPSULE_NAME); + if (!param_nodes[i]) + { + free(param_nodes); + PyErr_SetString(PyExc_ValueError, "invalid parameter capsule"); + return NULL; + } + } + + problem_register_params(prob, param_nodes, (int) n); + free(param_nodes); + + Py_RETURN_NONE; +} + +#endif /* PROBLEM_REGISTER_PARAMS_H */ diff --git a/sparsediffpy/_bindings/problem/update_params.h b/sparsediffpy/_bindings/problem/update_params.h new file mode 100644 index 0000000..6766b7f --- /dev/null +++ b/sparsediffpy/_bindings/problem/update_params.h @@ -0,0 +1,38 @@ +#ifndef PROBLEM_UPDATE_PARAMS_H +#define PROBLEM_UPDATE_PARAMS_H + +#include "problem/common.h" + +/* Update parameter values. + * Python: problem_update_params(problem_capsule, theta_array) */ +static PyObject *py_problem_update_params(PyObject *self, PyObject *args) +{ + PyObject *prob_capsule; + PyObject *theta_obj; + if (!PyArg_ParseTuple(args, "OO", &prob_capsule, &theta_obj)) + { + return NULL; + } + + problem *prob = + (problem *) PyCapsule_GetPointer(prob_capsule, PROBLEM_CAPSULE_NAME); + if (!prob) + { + PyErr_SetString(PyExc_ValueError, "invalid problem capsule"); + return NULL; + } + + PyArrayObject *theta_array = (PyArrayObject *) PyArray_FROM_OTF( + theta_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + if (!theta_array) + { + return NULL; + } + + problem_update_params(prob, (const double *) PyArray_DATA(theta_array)); + Py_DECREF(theta_array); + + Py_RETURN_NONE; +} + +#endif /* PROBLEM_UPDATE_PARAMS_H */ From 20655a09c28bcfd18428c9e7426e9a9b2f6afc03 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 16 Feb 2026 00:19:38 -0500 Subject: [PATCH 2/4] Pass CSR data directly for fixed-const left_matmul, update submodule - left_matmul.h: replace col-major conversion with direct CSR data pass to new_parameter(nnz, 1, PARAM_FIXED, ...) - Update SparseDiffEngine submodule (CSR data order for params) Co-Authored-By: Claude Opus 4.6 --- SparseDiffEngine | 2 +- sparsediffpy/_bindings/atoms/left_matmul.h | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/SparseDiffEngine b/SparseDiffEngine index 36864d4..978aa03 160000 --- a/SparseDiffEngine +++ b/SparseDiffEngine @@ -1 +1 @@ -Subproject commit 36864d47e63a6f53726905e222d4b5c1d3845706 +Subproject commit 978aa03023945143aa7656b384b4a5e9d02428a6 diff --git a/sparsediffpy/_bindings/atoms/left_matmul.h b/sparsediffpy/_bindings/atoms/left_matmul.h index 9faebc5..04e1537 100644 --- a/sparsediffpy/_bindings/atoms/left_matmul.h +++ b/sparsediffpy/_bindings/atoms/left_matmul.h @@ -65,14 +65,8 @@ static PyObject *py_make_left_matmul(PyObject *self, PyObject *args) expr *param_node; if (param_obj == Py_None) { - /* Fixed constant: create column-major values for the parameter node */ - double *col_major = (double *) calloc(m * n, sizeof(double)); - for (int row = 0; row < m; row++) - for (int k = csr_indptr[row]; k < csr_indptr[row + 1]; k++) - col_major[row + csr_indices[k] * m] = csr_data[k]; - - param_node = new_parameter(m, n, PARAM_FIXED, child->n_vars, col_major); - free(col_major); + /* Fixed constant: pass CSR data directly (values are already in CSR order) */ + param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars, csr_data); if (!param_node) { From bbb8c3f1d0205293f216e4192830650e721d140f Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 17 Feb 2026 20:45:05 -0500 Subject: [PATCH 3/4] Unify right_matmul binding to accept param_or_none like left_matmul Update py_make_right_matmul signature to (param_or_none, child, data, indices, indptr, m, n) matching left_matmul. When param_or_none is None, a PARAM_FIXED node is created internally; otherwise the passed parameter capsule is used for updatable sparse parameters. Co-Authored-By: Claude Opus 4.6 --- sparsediffpy/_bindings/atoms/right_matmul.h | 67 ++++++++++++++++++--- sparsediffpy/_bindings/bindings.c | 2 +- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/sparsediffpy/_bindings/atoms/right_matmul.h b/sparsediffpy/_bindings/atoms/right_matmul.h index c1c3481..55864ab 100644 --- a/sparsediffpy/_bindings/atoms/right_matmul.h +++ b/sparsediffpy/_bindings/atoms/right_matmul.h @@ -3,15 +3,27 @@ #include "bivariate.h" #include "common.h" +#include "subexpr.h" -/* Right matrix multiplication: f(x) @ A where A is a constant matrix */ +/* Right matrix multiplication: f(x) @ A where A is a constant or parameter matrix. + * + * Unified binding for both fixed-constant and updatable-parameter cases. + * Python signature: + * make_right_matmul(param_or_none, child, data, indices, indptr, m, n) + * + * - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created + * internally), or an existing parameter capsule for updatable parameters. + * - child: the child expression capsule f(x). + * - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and + * initial values of the matrix A. */ static PyObject *py_make_right_matmul(PyObject *self, PyObject *args) { + PyObject *param_obj; PyObject *child_capsule; PyObject *data_obj, *indices_obj, *indptr_obj; int m, n; - if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj, - &indptr_obj, &m, &n)) + if (!PyArg_ParseTuple(args, "OOOOOii", ¶m_obj, &child_capsule, &data_obj, + &indices_obj, &indptr_obj, &m, &n)) { return NULL; } @@ -38,18 +50,55 @@ static PyObject *py_make_right_matmul(PyObject *self, PyObject *args) return NULL; } - int nnz = (int) PyArray_SIZE(data_array); + double *csr_data = (double *) PyArray_DATA(data_array); + int *csr_indices = (int *) PyArray_DATA(indices_array); + int *csr_indptr = (int *) PyArray_DATA(indptr_array); + int nnz = csr_indptr[m]; + + /* Build CSR matrix from Python arrays */ CSR_Matrix *A = new_csr_matrix(m, n, nnz); - memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double)); - memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int)); - memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int)); + memcpy(A->p, csr_indptr, (m + 1) * sizeof(int)); + memcpy(A->i, csr_indices, nnz * sizeof(int)); + memcpy(A->x, csr_data, nnz * sizeof(double)); + + /* Determine param_node: use passed capsule, or create PARAM_FIXED internally */ + expr *param_node; + if (param_obj == Py_None) + { + /* Fixed constant: pass CSR data directly (values are already in CSR order) */ + param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars, csr_data); + + if (!param_node) + { + free_csr_matrix(A); + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + PyErr_SetString(PyExc_RuntimeError, + "failed to create matrix parameter node"); + return NULL; + } + } + else + { + param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME); + if (!param_node) + { + free_csr_matrix(A); + Py_DECREF(data_array); + Py_DECREF(indices_array); + Py_DECREF(indptr_array); + PyErr_SetString(PyExc_ValueError, "invalid param capsule"); + return NULL; + } + } Py_DECREF(data_array); Py_DECREF(indices_array); Py_DECREF(indptr_array); - expr *node = new_right_matmul(child, A); - free_csr_matrix(A); + expr *node = new_right_matmul(param_node, child, A); + free_csr_matrix(A); /* constructor copies it */ if (!node) { diff --git a/sparsediffpy/_bindings/bindings.c b/sparsediffpy/_bindings/bindings.c index b1adb33..da17925 100644 --- a/sparsediffpy/_bindings/bindings.c +++ b/sparsediffpy/_bindings/bindings.c @@ -122,7 +122,7 @@ static PyMethodDef DNLPMethods[] = { {"make_param_vector_mult", py_make_param_vector_mult, METH_VARARGS, "Create vector mult from parameter (p ∘ f(x))"}, {"make_right_matmul", py_make_right_matmul, METH_VARARGS, - "Create right matmul node (f(x) @ A)"}, + "Create right matmul node (f(x) @ A): pass None or param capsule as first arg"}, {"make_quad_form", py_make_quad_form, METH_VARARGS, "Create quadratic form node (x' * Q * x)"}, {"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS, From de375089465206c56e6dc890e8f3126152faf239 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Wed, 18 Feb 2026 01:11:17 -0500 Subject: [PATCH 4/4] Update SparseDiffEngine submodule for right_matmul parameter support Co-Authored-By: Claude Opus 4.6 --- SparseDiffEngine | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SparseDiffEngine b/SparseDiffEngine index 978aa03..1c3ebee 160000 --- a/SparseDiffEngine +++ b/SparseDiffEngine @@ -1 +1 @@ -Subproject commit 978aa03023945143aa7656b384b4a5e9d02428a6 +Subproject commit 1c3ebeecbe8843bc682c3ac9d24a81e722cd14d2