From 2913deebc60a1425934ca7d30fd39814e24acc3d Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 9 Feb 2026 01:29:09 -0500 Subject: [PATCH 01/24] Add parameter support to C diff engine Add parameter node type and parameter-aware variants of scalar mult, vector mult, and left matmul. Parameters store an offset into a global theta vector and can be updated via problem_update_params without rebuilding the expression tree. Co-Authored-By: Claude Opus 4.6 --- include/affine.h | 1 + include/bivariate.h | 9 ++ include/problem.h | 10 ++ include/subexpr.h | 13 ++- src/affine/parameter.c | 74 +++++++++++++++ src/bivariate/const_scalar_mult.c | 29 +++++- src/bivariate/const_vector_mult.c | 44 ++++++++- src/bivariate/left_matmul.c | 146 +++++++++++++++++++++++++++++- src/problem.c | 29 ++++++ 9 files changed, 346 insertions(+), 9 deletions(-) create mode 100644 src/affine/parameter.c diff --git a/include/affine.h b/include/affine.h index ada58f1..cc7120a 100644 --- a/include/affine.h +++ b/include/affine.h @@ -34,6 +34,7 @@ expr *new_trace(expr *child); expr *new_constant(int d1, int d2, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); +expr *new_parameter(int d1, int d2, int param_id, int n_vars); expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs); expr *new_reshape(expr *child, int d1, int d2); diff --git a/include/bivariate.h b/include/bivariate.h index aa005ed..bf15224 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -42,4 +42,13 @@ expr *new_const_scalar_mult(double a, expr *child); /* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */ expr *new_const_vector_mult(const double *a, expr *child); +/* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter */ +expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n); + +/* Parameter scalar multiplication: p * f(x) where p is a parameter */ +expr 
*new_param_scalar_mult(expr *param_node, expr *child); + +/* Parameter vector elementwise multiplication: p ∘ f(x) where p is a parameter */ +expr *new_param_vector_mult(expr *param_node, expr *child); + #endif /* BIVARIATE_H */ diff --git a/include/problem.h b/include/problem.h index 2462ffd..83e516e 100644 --- a/include/problem.h +++ b/include/problem.h @@ -59,6 +59,11 @@ typedef struct problem * hessian are called */ bool jacobian_called; + /* Parameter tracking for fast parameter updates */ + expr **param_nodes; /* weak references to parameter nodes in tree */ + int n_param_nodes; + int n_params; /* total scalar parameters */ + /* Statistics for performance measurement */ Diff_engine_stats stats; bool verbose; @@ -78,4 +83,9 @@ void problem_gradient(problem *prob); void problem_jacobian(problem *prob); void problem_hessian(problem *prob, double obj_w, const double *w); +/* Parameter support */ +void problem_register_params(problem *prob, expr **param_nodes, + int n_param_nodes, int n_params); +void problem_update_params(problem *prob, const double *theta); + #endif diff --git a/include/subexpr.h b/include/subexpr.h index a663282..224fcb1 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -25,6 +25,13 @@ /* Forward declaration */ struct int_double_pair; +/* Parameter node: like constant but with updatable values via problem_update_params */ +typedef struct parameter_expr +{ + expr base; + int param_id; /* offset into global theta vector */ +} parameter_expr; + /* Type-specific expression structures that "inherit" from expr */ /* Linear operator: y = A * x + b */ @@ -110,6 +117,8 @@ typedef struct left_matmul_expr CSR_Matrix *A; CSR_Matrix *AT; CSC_Matrix *CSC_work; + expr *param_source; /* if non-NULL, refresh A/AT values from param_source->value */ + int src_m, src_n; /* original (non-block-diag) matrix dimensions */ } left_matmul_expr; /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. 
@@ -128,13 +137,15 @@ typedef struct const_scalar_mult_expr { expr base; double a; + expr *param_source; /* if non-NULL, read a from param_source->value[0] */ } const_scalar_mult_expr; /* Constant vector elementwise multiplication: y = a \circ child for constant a */ typedef struct const_vector_mult_expr { expr base; - double *a; /* length equals node->size */ + double *a; /* length equals node->size */ + expr *param_source; /* if non-NULL, use param_source->value instead of a */ } const_vector_mult_expr; /* Index/slicing: y = child[indices] where indices is a list of flat positions */ diff --git a/src/affine/parameter.c b/src/affine/parameter.c new file mode 100644 index 0000000..06d0724 --- /dev/null +++ b/src/affine/parameter.c @@ -0,0 +1,74 @@ +/* + * Copyright 2026 Daniel Cederberg and William Zhang + * + * This file is part of the DNLP-differentiation-engine project. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Parameter leaf node: behaviorally identical to constant (zero derivatives + w.r.t. variables), but its values are updatable via problem_update_params. + This allows re-solving with different parameter values without rebuilding + the expression tree. 
*/ + +#include "affine.h" +#include "subexpr.h" +#include +#include + +static void forward(expr *node, const double *u) +{ + /* Values are set by problem_update_params, not by forward pass */ + (void)node; + (void)u; +} + +static void jacobian_init(expr *node) +{ + /* Parameter jacobian is all zeros: size x n_vars with 0 nonzeros */ + node->jacobian = new_csr_matrix(node->size, node->n_vars, 0); +} + +static void eval_jacobian(expr *node) +{ + /* Parameter jacobian never changes */ + (void)node; +} + +static void wsum_hess_init(expr *node) +{ + /* Parameter Hessian is all zeros */ + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + (void)node; + (void)w; +} + +static bool is_affine(const expr *node) +{ + (void)node; + return true; +} + +expr *new_parameter(int d1, int d2, int param_id, int n_vars) +{ + parameter_expr *pnode = (parameter_expr *)calloc(1, sizeof(parameter_expr)); + init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, + is_affine, wsum_hess_init, eval_wsum_hess, NULL); + pnode->param_id = param_id; + /* values will be populated by problem_update_params */ + return &pnode->base; +} diff --git a/src/bivariate/const_scalar_mult.c b/src/bivariate/const_scalar_mult.c index 4898389..33ac131 100644 --- a/src/bivariate/const_scalar_mult.c +++ b/src/bivariate/const_scalar_mult.c @@ -24,6 +24,11 @@ /* Constant scalar multiplication: y = a * child where a is a constant double */ +static inline double get_scalar(const const_scalar_mult_expr *sn) +{ + return sn->param_source ? 
sn->param_source->value[0] : sn->a; +} + static void forward(expr *node, const double *u) { expr *child = node->left; @@ -32,7 +37,7 @@ static void forward(expr *node, const double *u) child->forward(child, u); /* local forward pass: multiply each element by scalar a */ - double a = ((const_scalar_mult_expr *) node)->a; + double a = get_scalar((const_scalar_mult_expr *) node); for (int i = 0; i < node->size; i++) { node->value[i] = a * child->value[i]; @@ -55,7 +60,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *child = node->left; - double a = ((const_scalar_mult_expr *) node)->a; + double a = get_scalar((const_scalar_mult_expr *) node); /* evaluate child */ child->eval_jacobian(child); @@ -85,7 +90,7 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; x->eval_wsum_hess(x, w); - double a = ((const_scalar_mult_expr *) node)->a; + double a = get_scalar((const_scalar_mult_expr *) node); for (int j = 0; j < x->wsum_hess->nnz; j++) { node->wsum_hess->x[j] = a * x->wsum_hess->x[j]; @@ -108,7 +113,37 @@ expr *new_const_scalar_mult(double a, expr *child) eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); node->left = child; mult_node->a = a; + mult_node->param_source = NULL; + expr_retain(child); + + return node; +} + +/* Release the parameter node retained by new_param_scalar_mult */ +static void free_param_scalar_type_data(expr *node) +{ + const_scalar_mult_expr *mult_node = (const_scalar_mult_expr *) node; + if (mult_node->param_source) + { + free_expr(mult_node->param_source); + mult_node->param_source = NULL; + } +} + +expr *new_param_scalar_mult(expr *param_node, expr *child) +{ + const_scalar_mult_expr *mult_node = + (const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr)); + expr *node = &mult_node->base; + + init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, + free_param_scalar_type_data); + node->left = child; + mult_node->a = param_node->value[0]; /* initial value */ + mult_node->param_source = param_node; expr_retain(child); + expr_retain(param_node); return node; } diff --git a/src/bivariate/const_vector_mult.c b/src/bivariate/const_vector_mult.c index 65823a7..2f11269 100644 --- 
a/src/bivariate/const_vector_mult.c +++ b/src/bivariate/const_vector_mult.c @@ -23,10 +23,15 @@ /* Constant vector elementwise multiplication: y = a \circ child */ +static inline const double *get_vector(const const_vector_mult_expr *vn) +{ + return vn->param_source ? vn->param_source->value : vn->a; +} + static void forward(expr *node, const double *u) { expr *child = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = get_vector((const_vector_mult_expr *) node); /* child's forward pass */ child->forward(child, u); @@ -54,7 +59,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = get_vector((const_vector_mult_expr *) node); /* evaluate x */ x->eval_jacobian(x); @@ -87,7 +92,7 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = get_vector((const_vector_mult_expr *) node); /* scale weights w by a */ for (int i = 0; i < node->size; i++) @@ -128,6 +133,39 @@ expr *new_const_vector_mult(const double *a, expr *child) /* copy a vector */ vnode->a = (double *) malloc(child->size * sizeof(double)); memcpy(vnode->a, a, child->size * sizeof(double)); + vnode->param_source = NULL; + + return node; +} + +static void free_param_type_data(expr *node) +{ + const_vector_mult_expr *vnode = (const_vector_mult_expr *) node; + /* a is an owned copy even when param_source is set, so always free it */ + free(vnode->a); + if (vnode->param_source) + { + free_expr(vnode->param_source); + } +} + +expr *new_param_vector_mult(expr *param_node, expr *child) +{ + const_vector_mult_expr *vnode = + (const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr)); + expr *node = &vnode->base; + + init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, 
wsum_hess_init, eval_wsum_hess, + free_param_type_data); + node->left = child; + expr_retain(child); + + /* Still allocate a copy for initial values (used before first update_params) */ + vnode->a = (double *) malloc(child->size * sizeof(double)); + memcpy(vnode->a, param_node->value, child->size * sizeof(double)); + vnode->param_source = param_node; + expr_retain(param_node); return node; } diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 92dd762..a1e730d 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -46,16 +46,60 @@ */ #include "utils/utils.h" +#include + +/* Refresh block-diagonal A values from param_source and recompute AT values. + The block-diagonal has n_blocks copies of the src_m x src_n source matrix. + block_diag_repeat_csr lays out values as: for each block, copy the entire + source nnz array in order. So A->x is [src_nnz | src_nnz | ... | src_nnz]. */ +static void refresh_param_values(left_matmul_expr *lin_node) +{ + const double *src = lin_node->param_source->value; + CSR_Matrix *A = lin_node->A; + int src_m = lin_node->src_m; + int src_n = lin_node->src_n; + int total_rows = A->m; + int n_blocks = total_rows / src_m; + + /* Rebuild A values from column-major source matrix. + For each block, iterate rows of the source matrix and fill CSR values. 
*/ + int nnz_cursor = 0; + for (int block = 0; block < n_blocks; block++) + { + for (int row = 0; row < src_m; row++) + { + int dest_row = block * src_m + row; + for (int j = A->p[dest_row]; j < A->p[dest_row + 1]; j++) + { + /* column index in local block coordinates */ + int col = A->i[j] - block * src_n; + /* source is column-major: src[row + col * src_m] */ + A->x[nnz_cursor] = src[row + col * src_m]; + nnz_cursor++; + } + } + } + + /* Recompute AT values from updated A */ + AT_fill_values(A, lin_node->AT, lin_node->base.iwork); +} static void forward(expr *node, const double *u) { expr *x = node->left; + left_matmul_expr *lin_node = (left_matmul_expr *) node; + + /* refresh A/AT if parameter-sourced */ + if (lin_node->param_source) + { + refresh_param_values(lin_node); + } /* child's forward pass */ node->left->forward(node->left, u); /* y = A_kron @ vec(f(x)) */ - csr_matvec_wo_offset(((left_matmul_expr *) node)->A, x->value, node->value); + csr_matvec_wo_offset(lin_node->A, x->value, node->value); } static bool is_affine(const expr *node) @@ -95,6 +139,12 @@ static void eval_jacobian(expr *node) expr *x = node->left; left_matmul_expr *lin_node = (left_matmul_expr *) node; + /* refresh A if parameter-sourced */ + if (lin_node->param_source) + { + refresh_param_values(lin_node); + } + /* evaluate child's jacobian and convert to CSC */ x->eval_jacobian(x); csr_to_csc_fill_values(x->jacobian, lin_node->CSC_work, node->iwork); @@ -121,8 +171,15 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { - /* compute A^T w*/ left_matmul_expr *lin_node = (left_matmul_expr *) node; + + /* refresh AT if parameter-sourced */ + if (lin_node->param_source) + { + refresh_param_values(lin_node); + } + + /* compute A^T w*/ csr_matvec_wo_offset(lin_node->AT, w, node->dwork); node->left->eval_wsum_hess(node->left, node->dwork); @@ -170,6 +227,91 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A) node->iwork = (int *) malloc(alloc * 
sizeof(int)); lin_node->AT = transpose(lin_node->A, node->iwork); lin_node->CSC_work = NULL; + lin_node->param_source = NULL; + lin_node->src_m = 0; + lin_node->src_n = 0; + + return node; +} + +static void free_param_matmul_type_data(expr *node) +{ + left_matmul_expr *lin_node = (left_matmul_expr *) node; + free_csr_matrix(lin_node->A); + free_csr_matrix(lin_node->AT); + if (lin_node->CSC_work) + { + free_csc_matrix(lin_node->CSC_work); + } + if (lin_node->param_source) + { + free_expr(lin_node->param_source); + } + lin_node->A = NULL; + lin_node->AT = NULL; + lin_node->CSC_work = NULL; + lin_node->param_source = NULL; +} + +expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n) +{ + /* Same dimension logic as new_left_matmul */ + int d1, d2, n_blocks; + if (u->d1 == A_n) + { + d1 = A_m; + d2 = u->d2; + n_blocks = u->d2; + } + else if (u->d2 == A_n && u->d1 == 1) + { + d1 = 1; + d2 = A_m; + n_blocks = 1; + } + else + { + fprintf(stderr, "Error in new_left_param_matmul: dimension mismatch \n"); + exit(1); + } + + /* Build a temporary CSR from param_node's current values (column-major) */ + int nnz = A_m * A_n; /* dense for now — could optimize for sparse params later */ + CSR_Matrix *A_tmp = new_csr_matrix(A_m, A_n, nnz); + int idx = 0; + for (int row = 0; row < A_m; row++) + { + A_tmp->p[row] = idx; + for (int col = 0; col < A_n; col++) + { + A_tmp->i[idx] = col; + /* param_node->value is column-major: value[row + col * A_m] */ + A_tmp->x[idx] = param_node->value[row + col * A_m]; + idx++; + } + } + A_tmp->p[A_m] = idx; + + /* Allocate the type-specific struct */ + left_matmul_expr *lin_node = + (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); + expr *node = &lin_node->base; + init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, wsum_hess_init, eval_wsum_hess, free_param_matmul_type_data); + node->left = u; + expr_retain(u); + + /* Initialize type-specific fields */ + lin_node->A = 
block_diag_repeat_csr(A_tmp, n_blocks); + int alloc = MAX(lin_node->A->n, node->n_vars); + node->iwork = (int *) malloc(alloc * sizeof(int)); + lin_node->AT = transpose(lin_node->A, node->iwork); + lin_node->CSC_work = NULL; + lin_node->param_source = param_node; + lin_node->src_m = A_m; + lin_node->src_n = A_n; + expr_retain(param_node); + free_csr_matrix(A_tmp); return node; } diff --git a/src/problem.c b/src/problem.c index d9c14a9..d2ad379 100644 --- a/src/problem.c +++ b/src/problem.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "problem.h" +#include "subexpr.h" #include "utils/utils.h" #include #include @@ -65,6 +66,10 @@ problem *new_problem(expr *objective, expr **constraints, int n_constraints, prob->stats.nnz_affine = 0; prob->stats.nnz_nonlinear = 0; + prob->param_nodes = NULL; + prob->n_param_nodes = 0; + prob->n_params = 0; + prob->verbose = verbose; return prob; @@ -287,6 +292,9 @@ void free_problem(problem *prob) free_csr_matrix(prob->lagrange_hessian); free(prob->hess_idx_map); + /* Free parameter node array (weak references, not owned) */ + free(prob->param_nodes); + /* Release expression references (decrements refcount) */ free_expr(prob->objective); for (int i = 0; i < prob->n_constraints; i++) @@ -439,3 +447,24 @@ void problem_hessian(problem *prob, double obj_w, const double *w) clock_gettime(CLOCK_MONOTONIC, &timer.end); prob->stats.time_eval_hessian += GET_ELAPSED_SECONDS(timer); } + +void problem_register_params(problem *prob, expr **param_nodes, + int n_param_nodes, int n_params) +{ + prob->n_param_nodes = n_param_nodes; + prob->n_params = n_params; + prob->param_nodes = (expr **)malloc(n_param_nodes * sizeof(expr *)); + memcpy(prob->param_nodes, param_nodes, n_param_nodes * sizeof(expr *)); +} + +void problem_update_params(problem *prob, const double *theta) +{ + for (int i = 0; i < prob->n_param_nodes; i++) + { + parameter_expr *p = (parameter_expr *)prob->param_nodes[i]; + memcpy(p->base.value, theta + p->param_id, + 
p->base.size * sizeof(double)); + } + /* Force re-evaluation of affine Jacobians on next call */ + prob->jacobian_called = false; +} From fea0708727061b8017ddbc41014bb0e9c826b13e Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 10:24:26 -0500 Subject: [PATCH 02/24] Remove redundant n_params from problem struct n_params (total scalar parameter count) can be computed from n_param_nodes and each node's size, making the field redundant. Co-Authored-By: Claude Opus 4.6 --- include/problem.h | 3 +-- src/problem.c | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/include/problem.h b/include/problem.h index 83e516e..4d8fc49 100644 --- a/include/problem.h +++ b/include/problem.h @@ -62,7 +62,6 @@ typedef struct problem /* Parameter tracking for fast parameter updates */ expr **param_nodes; /* weak references to parameter nodes in tree */ int n_param_nodes; - int n_params; /* total scalar parameters */ /* Statistics for performance measurement */ Diff_engine_stats stats; @@ -85,7 +84,7 @@ void problem_hessian(problem *prob, double obj_w, const double *w); /* Parameter support */ void problem_register_params(problem *prob, expr **param_nodes, - int n_param_nodes, int n_params); + int n_param_nodes); void problem_update_params(problem *prob, const double *theta); #endif diff --git a/src/problem.c b/src/problem.c index d2ad379..53cad18 100644 --- a/src/problem.c +++ b/src/problem.c @@ -68,7 +68,6 @@ problem *new_problem(expr *objective, expr **constraints, int n_constraints, prob->param_nodes = NULL; prob->n_param_nodes = 0; - prob->n_params = 0; prob->verbose = verbose; @@ -449,10 +448,9 @@ void problem_hessian(problem *prob, double obj_w, const double *w) } void problem_register_params(problem *prob, expr **param_nodes, - int n_param_nodes, int n_params) + int n_param_nodes) { prob->n_param_nodes = n_param_nodes; - prob->n_params = n_params; prob->param_nodes = (expr **)malloc(n_param_nodes * sizeof(expr *)); 
memcpy(prob->param_nodes, param_nodes, n_param_nodes * sizeof(expr *)); } From 36ec23916184068a4fc015d5842554d13ba0f315 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 10:55:04 -0500 Subject: [PATCH 03/24] Remove redundant A_m, A_n params from new_left_param_matmul These dimensions are always equal to param_node->d1 and param_node->d2, which are set during make_parameter. Read them from the node directly. Co-Authored-By: Claude Opus 4.6 --- include/bivariate.h | 2 +- src/bivariate/left_matmul.c | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index bf15224..5efd6e0 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -43,7 +43,7 @@ expr *new_const_scalar_mult(double a, expr *child); expr *new_const_vector_mult(const double *a, expr *child); /* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter */ -expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n); +expr *new_left_param_matmul(expr *param_node, expr *child); /* Parameter scalar multiplication: p * f(x) where p is a parameter */ expr *new_param_scalar_mult(expr *param_node, expr *child); diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index a1e730d..bf373be 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -253,17 +253,20 @@ static void free_param_matmul_type_data(expr *node) lin_node->param_source = NULL; } -expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n) +expr *new_left_param_matmul(expr *param_node, expr *child) { + int A_m = param_node->d1; + int A_n = param_node->d2; + /* Same dimension logic as new_left_matmul */ int d1, d2, n_blocks; - if (u->d1 == A_n) + if (child->d1 == A_n) { d1 = A_m; - d2 = u->d2; - n_blocks = u->d2; + d2 = child->d2; + n_blocks = child->d2; } - else if (u->d2 == A_n && u->d1 == 1) + else if (child->d2 == A_n && child->d1 == 1) { d1 = 1; d2 = A_m; 
@@ -296,10 +299,10 @@ expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n) left_matmul_expr *lin_node = (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); expr *node = &lin_node->base; - init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian, + init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, free_param_matmul_type_data); - node->left = u; - expr_retain(u); + node->left = child; + expr_retain(child); /* Initialize type-specific fields */ lin_node->A = block_diag_repeat_csr(A_tmp, n_blocks); From 4f93f7ca0be8bea6a9a3029be8248aa7af570687 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 11:11:33 -0500 Subject: [PATCH 04/24] Add total_parameter_size field to problem struct Precompute total parameter size in problem_register_params, mirroring how total_constraint_size is computed in new_problem. Co-Authored-By: Claude Opus 4.6 --- include/problem.h | 1 + src/problem.c | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/include/problem.h b/include/problem.h index 4d8fc49..268f784 100644 --- a/include/problem.h +++ b/include/problem.h @@ -62,6 +62,7 @@ typedef struct problem /* Parameter tracking for fast parameter updates */ expr **param_nodes; /* weak references to parameter nodes in tree */ int n_param_nodes; + int total_parameter_size; /* Statistics for performance measurement */ Diff_engine_stats stats; diff --git a/src/problem.c b/src/problem.c index 53cad18..4eb2eac 100644 --- a/src/problem.c +++ b/src/problem.c @@ -453,6 +453,10 @@ void problem_register_params(problem *prob, expr **param_nodes, prob->n_param_nodes = n_param_nodes; prob->param_nodes = (expr **)malloc(n_param_nodes * sizeof(expr *)); memcpy(prob->param_nodes, param_nodes, n_param_nodes * sizeof(expr *)); + + prob->total_parameter_size = 0; + for (int i = 0; i < n_param_nodes; i++) + prob->total_parameter_size += param_nodes[i]->size; } void 
problem_update_params(problem *prob, const double *theta) From cf695a2e4b4966ae9d6f841b3b0cc8ab0e74036b Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 11:19:43 -0500 Subject: [PATCH 05/24] Inline get_scalar helper in const_scalar_mult.c Co-Authored-By: Claude Opus 4.6 --- src/bivariate/const_scalar_mult.c | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/bivariate/const_scalar_mult.c b/src/bivariate/const_scalar_mult.c index 33ac131..c1d5db1 100644 --- a/src/bivariate/const_scalar_mult.c +++ b/src/bivariate/const_scalar_mult.c @@ -24,11 +24,6 @@ /* Constant scalar multiplication: y = a * child where a is a constant double */ -static inline double get_scalar(const const_scalar_mult_expr *sn) -{ - return sn->param_source ? sn->param_source->value[0] : sn->a; -} - static void forward(expr *node, const double *u) { expr *child = node->left; @@ -37,7 +32,8 @@ static void forward(expr *node, const double *u) child->forward(child, u); /* local forward pass: multiply each element by scalar a */ - double a = get_scalar((const_scalar_mult_expr *) node); + const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + double a = sn->param_source ? sn->param_source->value[0] : sn->a; for (int i = 0; i < node->size; i++) { node->value[i] = a * child->value[i]; @@ -60,7 +56,8 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *child = node->left; - double a = get_scalar((const_scalar_mult_expr *) node); + const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + double a = sn->param_source ? sn->param_source->value[0] : sn->a; /* evaluate child */ child->eval_jacobian(child); @@ -90,7 +87,8 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; x->eval_wsum_hess(x, w); - double a = get_scalar((const_scalar_mult_expr *) node); + const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + double a = sn->param_source ? 
sn->param_source->value[0] : sn->a; for (int j = 0; j < x->wsum_hess->nnz; j++) { node->wsum_hess->x[j] = a * x->wsum_hess->x[j]; From bd4f3b039b6c075f3472658f580a180a4766a647 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 11:26:04 -0500 Subject: [PATCH 06/24] Inline get_vector helper in const_vector_mult.c Co-Authored-By: Claude Opus 4.6 --- src/bivariate/const_vector_mult.c | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/bivariate/const_vector_mult.c b/src/bivariate/const_vector_mult.c index 2f11269..31f7236 100644 --- a/src/bivariate/const_vector_mult.c +++ b/src/bivariate/const_vector_mult.c @@ -23,15 +23,11 @@ /* Constant vector elementwise multiplication: y = a \circ child */ -static inline const double *get_vector(const const_vector_mult_expr *vn) -{ - return vn->param_source ? vn->param_source->value : vn->a; -} - static void forward(expr *node, const double *u) { expr *child = node->left; - const double *a = get_vector((const_vector_mult_expr *) node); + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + const double *a = vn->param_source ? vn->param_source->value : vn->a; /* child's forward pass */ child->forward(child, u); @@ -59,7 +55,8 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - const double *a = get_vector((const_vector_mult_expr *) node); + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + const double *a = vn->param_source ? vn->param_source->value : vn->a; /* evaluate x */ x->eval_jacobian(x); @@ -92,7 +89,8 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - const double *a = get_vector((const_vector_mult_expr *) node); + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + const double *a = vn->param_source ? 
vn->param_source->value : vn->a; /* scale weights w by a */ for (int i = 0; i < node->size; i++) From 27acb7d869bda68c52d6e3a8c84c91c1653cc8ea Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 11:44:13 -0500 Subject: [PATCH 07/24] Simplify refresh_param_values: fill one block, memcpy the rest Co-Authored-By: Claude Opus 4.6 --- src/bivariate/left_matmul.c | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index bf373be..f5a6f00 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -49,37 +49,32 @@ #include /* Refresh block-diagonal A values from param_source and recompute AT values. - The block-diagonal has n_blocks copies of the src_m x src_n source matrix. - block_diag_repeat_csr lays out values as: for each block, copy the entire - source nnz array in order. So A->x is [src_nnz | src_nnz | ... | src_nnz]. */ + The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix. + A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */ static void refresh_param_values(left_matmul_expr *lin_node) { const double *src = lin_node->param_source->value; CSR_Matrix *A = lin_node->A; int src_m = lin_node->src_m; int src_n = lin_node->src_n; - int total_rows = A->m; - int n_blocks = total_rows / src_m; + int src_nnz = src_m * src_n; + int n_blocks = A->m / src_m; - /* Rebuild A values from column-major source matrix. - For each block, iterate rows of the source matrix and fill CSR values. 
*/ - int nnz_cursor = 0; - for (int block = 0; block < n_blocks; block++) + /* Build first block: column-major source -> row-major CSR values */ + for (int row = 0; row < src_m; row++) { - for (int row = 0; row < src_m; row++) + for (int col = 0; col < src_n; col++) { - int dest_row = block * src_m + row; - for (int j = A->p[dest_row]; j < A->p[dest_row + 1]; j++) - { - /* column index in local block coordinates */ - int col = A->i[j] - block * src_n; - /* source is column-major: src[row + col * src_m] */ - A->x[nnz_cursor] = src[row + col * src_m]; - nnz_cursor++; - } + A->x[row * src_n + col] = src[row + col * src_m]; } } + /* Copy first block to remaining blocks */ + for (int block = 1; block < n_blocks; block++) + { + memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double)); + } + /* Recompute AT values from updated A */ AT_fill_values(A, lin_node->AT, lin_node->base.iwork); } From 00f7732fabc81688e5ffe53484d48e3deee60c53 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 11:46:19 -0500 Subject: [PATCH 08/24] Skip AT recomputation in param refresh; param matmul is always affine Co-Authored-By: Claude Opus 4.6 --- src/bivariate/left_matmul.c | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index f5a6f00..ad04dad 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -48,9 +48,11 @@ #include "utils/utils.h" #include -/* Refresh block-diagonal A values from param_source and recompute AT values. +/* Refresh block-diagonal A values from param_source. The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix. - A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */ + A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. + Note: AT is not refreshed because param matmul is always affine, so the + weighted Hessian is always zero regardless of AT values. 
*/ static void refresh_param_values(left_matmul_expr *lin_node) { const double *src = lin_node->param_source->value; @@ -75,8 +77,6 @@ static void refresh_param_values(left_matmul_expr *lin_node) memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double)); } - /* Recompute AT values from updated A */ - AT_fill_values(A, lin_node->AT, lin_node->base.iwork); } static void forward(expr *node, const double *u) @@ -168,11 +168,8 @@ static void eval_wsum_hess(expr *node, const double *w) { left_matmul_expr *lin_node = (left_matmul_expr *) node; - /* refresh AT if parameter-sourced */ - if (lin_node->param_source) - { - refresh_param_values(lin_node); - } + /* No need to refresh AT for param-sourced nodes: param matmul is always + affine, so the child's weighted Hessian is zero regardless of AT values. NOTE(review): AT * w below is passed to the child's eval_wsum_hess, so this assumes the child is affine; a non-affine child would get weights from stale AT values - confirm. */ /* compute A^T w*/ csr_matvec_wo_offset(lin_node->AT, w, node->dwork); From dc56c90266442e3449f525b9759113b45936e6ba Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 16:52:44 -0500 Subject: [PATCH 09/24] Run clang-format on parameter support files Co-Authored-By: Claude Opus 4.6 --- include/bivariate.h | 3 ++- include/problem.h | 5 ++--- include/subexpr.h | 8 +++++--- src/affine/parameter.c | 14 +++++++------- src/bivariate/left_matmul.c | 4 ++-- src/problem.c | 10 ++++------ 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index 5efd6e0..fb6bf9b 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -42,7 +42,8 @@ expr *new_const_scalar_mult(double a, expr *child); /* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */ expr *new_const_vector_mult(const double *a, expr *child); -/* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter */ +/* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter + */ expr *new_left_param_matmul(expr *param_node, expr *child); /* Parameter scalar multiplication: p * f(x) 
where p is a parameter */ diff --git a/include/problem.h b/include/problem.h index 268f784..24d354b 100644 --- a/include/problem.h +++ b/include/problem.h @@ -60,7 +60,7 @@ typedef struct problem bool jacobian_called; /* Parameter tracking for fast parameter updates */ - expr **param_nodes; /* weak references to parameter nodes in tree */ + expr **param_nodes; /* weak references to parameter nodes in tree */ int n_param_nodes; int total_parameter_size; @@ -84,8 +84,7 @@ void problem_jacobian(problem *prob); void problem_hessian(problem *prob, double obj_w, const double *w); /* Parameter support */ -void problem_register_params(problem *prob, expr **param_nodes, - int n_param_nodes); +void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes); void problem_update_params(problem *prob, const double *theta); #endif diff --git a/include/subexpr.h b/include/subexpr.h index 224fcb1..00e61c4 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -25,7 +25,8 @@ /* Forward declaration */ struct int_double_pair; -/* Parameter node: like constant but with updatable values via problem_update_params */ +/* Parameter node: like constant but with updatable values via problem_update_params + */ typedef struct parameter_expr { expr base; @@ -117,8 +118,9 @@ typedef struct left_matmul_expr CSR_Matrix *A; CSR_Matrix *AT; CSC_Matrix *CSC_work; - expr *param_source; /* if non-NULL, refresh A/AT values from param_source->value */ - int src_m, src_n; /* original (non-block-diag) matrix dimensions */ + expr * + param_source; /* if non-NULL, refresh A/AT values from param_source->value */ + int src_m, src_n; /* original (non-block-diag) matrix dimensions */ } left_matmul_expr; /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. 
diff --git a/src/affine/parameter.c b/src/affine/parameter.c index 06d0724..cd8ea95 100644 --- a/src/affine/parameter.c +++ b/src/affine/parameter.c @@ -29,8 +29,8 @@ static void forward(expr *node, const double *u) { /* Values are set by problem_update_params, not by forward pass */ - (void)node; - (void)u; + (void) node; + (void) u; } static void jacobian_init(expr *node) @@ -42,7 +42,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { /* Parameter jacobian never changes */ - (void)node; + (void) node; } static void wsum_hess_init(expr *node) @@ -53,19 +53,19 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { - (void)node; - (void)w; + (void) node; + (void) w; } static bool is_affine(const expr *node) { - (void)node; + (void) node; return true; } expr *new_parameter(int d1, int d2, int param_id, int n_vars) { - parameter_expr *pnode = (parameter_expr *)calloc(1, sizeof(parameter_expr)); + parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr)); init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); pnode->param_id = param_id; diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index ad04dad..d5d430a 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -76,7 +76,6 @@ static void refresh_param_values(left_matmul_expr *lin_node) { memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double)); } - } static void forward(expr *node, const double *u) @@ -292,7 +291,8 @@ expr *new_left_param_matmul(expr *param_node, expr *child) (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); expr *node = &lin_node->base; init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian, - is_affine, wsum_hess_init, eval_wsum_hess, free_param_matmul_type_data); + is_affine, wsum_hess_init, eval_wsum_hess, + free_param_matmul_type_data); node->left = 
child; expr_retain(child); diff --git a/src/problem.c b/src/problem.c index 4eb2eac..bacd29a 100644 --- a/src/problem.c +++ b/src/problem.c @@ -447,11 +447,10 @@ void problem_hessian(problem *prob, double obj_w, const double *w) prob->stats.time_eval_hessian += GET_ELAPSED_SECONDS(timer); } -void problem_register_params(problem *prob, expr **param_nodes, - int n_param_nodes) +void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes) { prob->n_param_nodes = n_param_nodes; - prob->param_nodes = (expr **)malloc(n_param_nodes * sizeof(expr *)); + prob->param_nodes = (expr **) malloc(n_param_nodes * sizeof(expr *)); memcpy(prob->param_nodes, param_nodes, n_param_nodes * sizeof(expr *)); prob->total_parameter_size = 0; @@ -463,9 +462,8 @@ void problem_update_params(problem *prob, const double *theta) { for (int i = 0; i < prob->n_param_nodes; i++) { - parameter_expr *p = (parameter_expr *)prob->param_nodes[i]; - memcpy(p->base.value, theta + p->param_id, - p->base.size * sizeof(double)); + parameter_expr *p = (parameter_expr *) prob->param_nodes[i]; + memcpy(p->base.value, theta + p->param_id, p->base.size * sizeof(double)); } /* Force re-evaluation of affine Jacobians on next call */ prob->jacobian_called = false; From 978f3194cd0a6af2570859b7b17cbac2114dd718 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 17:04:52 -0500 Subject: [PATCH 10/24] Add problem-level tests for parameter support Exercises param_scalar_mult, param_vector_mult, and left_param_matmul with problem_register_params/problem_update_params to verify objective, gradient, constraint, and Jacobian values update correctly. 
Co-Authored-By: Claude Opus 4.6 --- tests/all_tests.c | 4 + tests/problem/test_param_prob.h | 245 ++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 tests/problem/test_param_prob.h diff --git a/tests/all_tests.c b/tests/all_tests.c index 4e7571a..0d8278d 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -41,6 +41,7 @@ #include "jacobian_tests/test_sum.h" #include "jacobian_tests/test_trace.h" #include "jacobian_tests/test_transpose.h" +#include "problem/test_param_prob.h" #include "problem/test_problem.h" #include "utils/test_csc_matrix.h" #include "utils/test_csr_matrix.h" @@ -257,6 +258,9 @@ int main(void) mu_run_test(test_problem_jacobian_multi, tests_run); mu_run_test(test_problem_constraint_forward, tests_run); mu_run_test(test_problem_hessian, tests_run); + mu_run_test(test_param_scalar_mult_problem, tests_run); + mu_run_test(test_param_vector_mult_problem, tests_run); + mu_run_test(test_param_left_matmul_problem, tests_run); printf("\n=== All %d tests passed ===\n", tests_run); diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h new file mode 100644 index 0000000..b7716bf --- /dev/null +++ b/tests/problem/test_param_prob.h @@ -0,0 +1,245 @@ +#ifndef TEST_PARAM_PROB_H +#define TEST_PARAM_PROB_H + +#include +#include + +#include "affine.h" +#include "bivariate.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "problem.h" +#include "test_helpers.h" + +/* + * Test 1: param_scalar_mult in objective + * + * Problem: minimize a * sum(log(x)), no constraints, x size 2 + * a is a scalar parameter (param_id=0) + * + * At x=[1,2], a=3: + * obj = 3*(log(1)+log(2)) = 3*log(2) + * gradient = [3/1, 3/2] = [3.0, 1.5] + * + * After update a=5: + * obj = 5*log(2) + * gradient = [5.0, 2.5] + */ +const char *test_param_scalar_mult_problem(void) +{ + int n_vars = 2; + + /* Build tree: sum(a * log(x)) */ + expr *x = new_variable(2, 1, 0, n_vars); + expr *log_x = 
new_log(x); + expr *a_param = new_parameter(1, 1, 0, n_vars); + expr *scaled = new_param_scalar_mult(a_param, log_x); + expr *objective = new_sum(scaled, -1); + + /* Create problem (no constraints) */ + problem *prob = new_problem(objective, NULL, 0, true); + + /* Register parameter */ + expr *param_nodes[1] = {a_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set a=3 and evaluate at x=[1,2] */ + double theta[1] = {3.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + double obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + double expected_obj = 3.0 * log(2.0); + mu_assert("obj wrong (a=3)", fabs(obj_val - expected_obj) < 1e-10); + + double expected_grad[2] = {3.0, 1.5}; + mu_assert("gradient wrong (a=3)", + cmp_double_array(prob->gradient_values, expected_grad, 2)); + + /* Update a=5 and re-evaluate */ + theta[0] = 5.0; + problem_update_params(prob, theta); + + obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + expected_obj = 5.0 * log(2.0); + mu_assert("obj wrong (a=5)", fabs(obj_val - expected_obj) < 1e-10); + + double expected_grad2[2] = {5.0, 2.5}; + mu_assert("gradient wrong (a=5)", + cmp_double_array(prob->gradient_values, expected_grad2, 2)); + + free_problem(prob); + + return 0; +} + +/* + * Test 2: param_vector_mult in constraint + * + * Problem: minimize sum(x), subject to p ∘ x, x size 2 + * p is a vector parameter of size 2 (param_id=0) + * + * At x=[1,2], p=[3,4]: + * constraint_values = [3, 8] + * jacobian = diag([3, 4]) + * + * After update p=[5,6]: + * constraint_values = [5, 12] + * jacobian = diag([5, 6]) + */ +const char *test_param_vector_mult_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(2, 1, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: p ∘ x */ + expr *x_con = new_variable(2, 1, 0, n_vars); + expr *p_param = new_parameter(2, 1, 0, n_vars); + expr 
*constraint = new_param_vector_mult(p_param, x_con); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, true); + + expr *param_nodes[1] = {p_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set p=[3,4] and evaluate at x=[1,2] */ + double theta[2] = {3.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {3.0, 8.0}; + mu_assert("constraint values wrong (p=[3,4])", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + int expected_p[3] = {0, 1, 2}; + mu_assert("jac->p wrong (p=[3,4])", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[2] = {0, 1}; + mu_assert("jac->i wrong (p=[3,4])", cmp_int_array(jac->i, expected_i, 2)); + + double expected_x[2] = {3.0, 4.0}; + mu_assert("jac->x wrong (p=[3,4])", cmp_double_array(jac->x, expected_x, 2)); + + /* Update p=[5,6] and re-evaluate */ + double theta2[2] = {5.0, 6.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {5.0, 12.0}; + mu_assert("constraint values wrong (p=[5,6])", + cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[2] = {5.0, 6.0}; + mu_assert("jac->x wrong (p=[5,6])", cmp_double_array(jac->x, expected_x2, 2)); + + free_problem(prob); + + return 0; +} + +/* + * Test 3: left_param_matmul in constraint + * + * Problem: minimize sum(x), subject to A @ x, x size 2, A is 2x2 + * A is a 2x2 matrix parameter (param_id=0, size=4, column-major) + * A = [[1,2],[3,4]] → column-major theta = [1,3,2,4] + * + * At x=[1,2]: + * constraint_values = [1*1+2*2, 3*1+4*2] = [5, 11] + * jacobian = [[1,2],[3,4]] + * + * After update A = 
[[5,6],[7,8]] → theta = [5,7,6,8]: + * constraint_values = [5*1+6*2, 7*1+8*2] = [17, 23] + * jacobian = [[5,6],[7,8]] + */ +const char *test_param_left_matmul_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(2, 1, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: A @ x */ + expr *x_con = new_variable(2, 1, 0, n_vars); + expr *A_param = new_parameter(2, 2, 0, n_vars); + expr *constraint = new_left_param_matmul(A_param, x_con); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, true); + + expr *param_nodes[1] = {A_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set A = [[1,2],[3,4]], column-major: [1,3,2,4] */ + double theta[4] = {1.0, 3.0, 2.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {5.0, 11.0}; + mu_assert("constraint values wrong (A1)", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + /* Dense jacobian = [[1,2],[3,4]], CSR: row 0 → cols 0,1 vals 1,2; + * row 1 → cols 0,1 vals 3,4 */ + int expected_p[3] = {0, 2, 4}; + mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[4] = {0, 1, 0, 1}; + mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4)); + + double expected_x[4] = {1.0, 2.0, 3.0, 4.0}; + mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4)); + + /* Update A = [[5,6],[7,8]], column-major: [5,7,6,8] */ + double theta2[4] = {5.0, 7.0, 6.0, 8.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {17.0, 23.0}; + mu_assert("constraint values wrong (A2)", + 
cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[4] = {5.0, 6.0, 7.0, 8.0}; + mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4)); + + free_problem(prob); + + return 0; +} + +#endif /* TEST_PARAM_PROB_H */ From 01c6f8258c391cc44ee07102f12100e797a30c75 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 17:09:16 -0500 Subject: [PATCH 11/24] Clean up comments in bivariate.h and subexpr.h Co-Authored-By: Claude Opus 4.6 --- include/bivariate.h | 3 +-- include/subexpr.h | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index fb6bf9b..4d901eb 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -42,8 +42,7 @@ expr *new_const_scalar_mult(double a, expr *child); /* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */ expr *new_const_vector_mult(const double *a, expr *child); -/* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter - */ +/* Left matrix multiplication: P @ f(x) where P is a parameter */ expr *new_left_param_matmul(expr *param_node, expr *child); /* Parameter scalar multiplication: p * f(x) where p is a parameter */ diff --git a/include/subexpr.h b/include/subexpr.h index 00e61c4..8f11942 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -118,8 +118,7 @@ typedef struct left_matmul_expr CSR_Matrix *A; CSR_Matrix *AT; CSC_Matrix *CSC_work; - expr * - param_source; /* if non-NULL, refresh A/AT values from param_source->value */ + expr *param_source; int src_m, src_n; /* original (non-block-diag) matrix dimensions */ } left_matmul_expr; From b8cf43666a25145125b36e424535785edb53dd30 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 10 Feb 2026 17:16:29 -0500 Subject: [PATCH 12/24] Fix memory leak in new_param_scalar_mult The retained param_node was never released because free_type_data was NULL. 
Add free_param_type_data to match const_vector_mult and left_matmul. Co-Authored-By: Claude Opus 4.6 --- src/bivariate/const_scalar_mult.c | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/bivariate/const_scalar_mult.c b/src/bivariate/const_scalar_mult.c index c1d5db1..ef4c949 100644 --- a/src/bivariate/const_scalar_mult.c +++ b/src/bivariate/const_scalar_mult.c @@ -117,6 +117,15 @@ expr *new_const_scalar_mult(double a, expr *child) return node; } +static void free_param_type_data(expr *node) +{ + const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + if (sn->param_source) + { + free_expr(sn->param_source); + } +} + expr *new_param_scalar_mult(expr *param_node, expr *child) { const_scalar_mult_expr *mult_node = @@ -124,7 +133,8 @@ expr *new_param_scalar_mult(expr *param_node, expr *child) expr *node = &mult_node->base; init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, - eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); + eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, + free_param_type_data); node->left = child; mult_node->a = param_node->value[0]; /* initial value */ mult_node->param_source = param_node; From bf8a55ccf91631be34a0b08d1b88a59cd974617c Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Fri, 13 Feb 2026 16:25:06 -0500 Subject: [PATCH 13/24] Run clang-format on merge-resolved files Co-Authored-By: Claude Opus 4.6 --- include/subexpr.h | 4 ++-- src/bivariate/left_matmul.c | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/include/subexpr.h b/include/subexpr.h index f59fc24..078b61a 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -121,8 +121,8 @@ typedef struct left_matmul_expr CSC_Matrix *Jchild_CSC; CSC_Matrix *J_CSC; int *csc_to_csr_workspace; - expr *param_source; /* if non-NULL, A/AT values come from this parameter */ - int src_m, src_n; /* original matrix dimensions */ + expr *param_source; /* if non-NULL, A/AT values 
come from this parameter */ + int src_m, src_n; /* original matrix dimensions */ } left_matmul_expr; /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index ad29ab3..f2044a7 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -101,8 +101,7 @@ static void free_type_data(expr *node) free_csc_matrix(lin_node->Jchild_CSC); free_csc_matrix(lin_node->J_CSC); free(lin_node->csc_to_csr_workspace); - if (lin_node->param_source) - free_expr(lin_node->param_source); + if (lin_node->param_source) free_expr(lin_node->param_source); lin_node->A = NULL; lin_node->AT = NULL; lin_node->Jchild_CSC = NULL; From 9700344e675cb4d4b416a19e7f26be4eb1e3ba2e Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Sun, 15 Feb 2026 16:06:09 -0500 Subject: [PATCH 14/24] Unify Constant and Parameter into single parameter type Merge the separate Constant and Parameter leaf nodes into a unified parameter_expr with PARAM_FIXED sentinel (-1) for constants. 
This eliminates duplicate code paths and consolidates 7 bivariate constructors into 3 unified ones: - new_const_scalar_mult / new_param_scalar_mult -> new_scalar_mult - new_const_vector_mult / new_param_vector_mult -> new_vector_mult - new_left_matmul (CSR) / new_left_param_matmul -> new_left_matmul (param node) Key changes: - Add PARAM_FIXED define and extend new_parameter() to accept initial values - Delete constant.c (absorbed by parameter.c) - Remove direct value storage (double a, double *a) from scalar/vector mult structs; always read from param_source - left_matmul builds sparse CSR for fixed params (preserving sparsity) and dense CSR for updatable params - right_matmul internally creates a fixed parameter node from transposed A - problem_update_params skips PARAM_FIXED nodes - Update all test callers to use new_parameter with PARAM_FIXED Co-Authored-By: Claude Opus 4.6 --- include/affine.h | 3 +- include/bivariate.h | 21 +-- include/subexpr.h | 20 +-- src/affine/constant.c | 69 ---------- src/affine/parameter.c | 31 +++-- src/bivariate/const_scalar_mult.c | 31 +---- src/bivariate/const_vector_mult.c | 46 ++----- src/bivariate/left_matmul.c | 120 ++++++------------ src/bivariate/right_matmul.c | 14 +- src/problem.c | 1 + tests/forward_pass/affine/test_add.h | 3 +- tests/forward_pass/affine/test_sum.h | 7 +- .../affine/test_variable_constant.h | 3 +- tests/forward_pass/composite/test_composite.h | 3 +- tests/forward_pass/test_prod_axis_one.h | 3 +- tests/forward_pass/test_prod_axis_zero.h | 3 +- tests/jacobian_tests/test_broadcast.h | 3 +- tests/jacobian_tests/test_const_scalar_mult.h | 8 +- tests/jacobian_tests/test_const_vector_mult.h | 8 +- tests/jacobian_tests/test_left_matmul.h | 43 ++----- tests/jacobian_tests/test_transpose.h | 15 +-- tests/problem/test_param_prob.h | 12 +- tests/profiling/profile_left_matmul.h | 24 ++-- tests/wsum_hess/test_const_scalar_mult.h | 8 +- tests/wsum_hess/test_const_vector_mult.h | 8 +- tests/wsum_hess/test_left_matmul.h | 
43 ++----- 26 files changed, 193 insertions(+), 357 deletions(-) delete mode 100644 src/affine/constant.c diff --git a/include/affine.h b/include/affine.h index cc7120a..2210dd2 100644 --- a/include/affine.h +++ b/include/affine.h @@ -32,9 +32,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars); expr *new_promote(expr *child, int d1, int d2); expr *new_trace(expr *child); -expr *new_constant(int d1, int d2, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); -expr *new_parameter(int d1, int d2, int param_id, int n_vars); +expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values); expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs); expr *new_reshape(expr *child, int d1, int d2); diff --git a/include/bivariate.h b/include/bivariate.h index 4d901eb..15947a1 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -30,25 +30,16 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right); /* Matrix multiplication: Z = X @ Y */ expr *new_matmul(expr *x, expr *y); -/* Left matrix multiplication: A @ f(x) where A is a constant matrix */ -expr *new_left_matmul(expr *u, const CSR_Matrix *A); +/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */ +expr *new_left_matmul(expr *param_node, expr *child); /* Right matrix multiplication: f(x) @ A where A is a constant matrix */ expr *new_right_matmul(expr *u, const CSR_Matrix *A); -/* Constant scalar multiplication: a * f(x) where a is a constant double */ -expr *new_const_scalar_mult(double a, expr *child); +/* Scalar multiplication: a * f(x) where a comes from a parameter node */ +expr *new_scalar_mult(expr *param_node, expr *child); -/* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */ -expr *new_const_vector_mult(const double *a, expr *child); - -/* Left matrix multiplication: P @ f(x) where P is a parameter */ -expr *new_left_param_matmul(expr *param_node, expr 
*child); - -/* Parameter scalar multiplication: p * f(x) where p is a parameter */ -expr *new_param_scalar_mult(expr *param_node, expr *child); - -/* Parameter vector elementwise multiplication: p ∘ f(x) where p is a parameter */ -expr *new_param_vector_mult(expr *param_node, expr *child); +/* Vector elementwise multiplication: a ∘ f(x) where a comes from a parameter node */ +expr *new_vector_mult(expr *param_node, expr *child); #endif /* BIVARIATE_H */ diff --git a/include/subexpr.h b/include/subexpr.h index 078b61a..41a7833 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -25,12 +25,17 @@ /* Forward declaration */ struct int_double_pair; -/* Parameter node: like constant but with updatable values via problem_update_params +/* param_id value for fixed (constant) parameters */ +#define PARAM_FIXED -1 + +/* Parameter node: unified leaf for constants and updatable parameters. + * Constants use param_id == PARAM_FIXED and have values set at creation. + * Updatable parameters have param_id >= 0 and are updated via problem_update_params. 
*/ typedef struct parameter_expr { expr base; - int param_id; /* offset into global theta vector */ + int param_id; /* offset into global theta vector, or PARAM_FIXED */ } parameter_expr; /* Type-specific expression structures that "inherit" from expr */ @@ -136,20 +141,19 @@ typedef struct right_matmul_expr CSC_Matrix *CSC_work; } right_matmul_expr; -/* Constant scalar multiplication: y = a * child where a is a constant double */ +/* Scalar multiplication: y = a * child where a comes from a parameter node */ typedef struct const_scalar_mult_expr { expr base; - double a; - expr *param_source; /* if non-NULL, read a from param_source->value[0] */ + expr *param_source; /* always set; read a from param_source->value[0] */ } const_scalar_mult_expr; -/* Constant vector elementwise multiplication: y = a \circ child for constant a */ +/* Vector elementwise multiplication: y = a \circ child where a comes from a + * parameter node */ typedef struct const_vector_mult_expr { expr base; - double *a; /* length equals node->size */ - expr *param_source; /* if non-NULL, use param_source->value instead of a */ + expr *param_source; /* always set; read a from param_source->value */ } const_vector_mult_expr; /* Index/slicing: y = child[indices] where indices is a list of flat positions */ diff --git a/src/affine/constant.c b/src/affine/constant.c deleted file mode 100644 index 59da3e2..0000000 --- a/src/affine/constant.c +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2026 Daniel Cederberg and William Zhang - * - * This file is part of the DNLP-differentiation-engine project. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "affine.h" -#include -#include - -static void forward(expr *node, const double *u) -{ - /* Constants don't depend on u; values are already set */ - (void) node; - (void) u; -} - -static void jacobian_init(expr *node) -{ - /* Constant jacobian is all zeros: size x n_vars with 0 nonzeros. - * new_csr_matrix uses calloc for row pointers, so they're already 0. */ - node->jacobian = new_csr_matrix(node->size, node->n_vars, 0); -} - -static void eval_jacobian(expr *node) -{ - /* Constant jacobian never changes - nothing to evaluate */ - (void) node; -} - -static void wsum_hess_init(expr *node) -{ - /* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */ - node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0); -} - -static void eval_wsum_hess(expr *node, const double *w) -{ - /* Constant Hessian is always zero - nothing to compute */ - (void) node; - (void) w; -} - -static bool is_affine(const expr *node) -{ - (void) node; - return true; -} - -expr *new_constant(int d1, int d2, int n_vars, const double *values) -{ - expr *node = (expr *) calloc(1, sizeof(expr)); - init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine, - wsum_hess_init, eval_wsum_hess, NULL); - memcpy(node->value, values, node->size * sizeof(double)); - - return node; -} diff --git a/src/affine/parameter.c b/src/affine/parameter.c index cd8ea95..c50a5cb 100644 --- a/src/affine/parameter.c +++ b/src/affine/parameter.c @@ -16,10 +16,14 @@ * limitations under the License. 
*/ -/* Parameter leaf node: behaviorally identical to constant (zero derivatives - w.r.t. variables), but its values are updatable via problem_update_params. - This allows re-solving with different parameter values without rebuilding - the expression tree. */ +/* Unified parameter/constant leaf node. + * + * When param_id == PARAM_FIXED, this is a constant whose values are set at + * creation and never change. When param_id >= 0, values are updated via + * problem_update_params. + * + * In both cases the derivative behavior is identical: zero Jacobian and + * Hessian with respect to variables (always affine). */ #include "affine.h" #include "subexpr.h" @@ -28,26 +32,26 @@ static void forward(expr *node, const double *u) { - /* Values are set by problem_update_params, not by forward pass */ + /* Values are set at creation (constants) or by problem_update_params */ (void) node; (void) u; } static void jacobian_init(expr *node) { - /* Parameter jacobian is all zeros: size x n_vars with 0 nonzeros */ + /* Parameter/constant jacobian is all zeros: size x n_vars with 0 nonzeros */ node->jacobian = new_csr_matrix(node->size, node->n_vars, 0); } static void eval_jacobian(expr *node) { - /* Parameter jacobian never changes */ + /* Jacobian never changes */ (void) node; } static void wsum_hess_init(expr *node) { - /* Parameter Hessian is all zeros */ + /* Hessian is all zeros */ node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0); } @@ -63,12 +67,19 @@ static bool is_affine(const expr *node) return true; } -expr *new_parameter(int d1, int d2, int param_id, int n_vars) +expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values) { parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr)); init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); pnode->param_id = param_id; - /* values will be populated by problem_update_params */ + + /* If values 
provided (fixed constant), copy them now */ + if (values != NULL) + { + memcpy(pnode->base.value, values, pnode->base.size * sizeof(double)); + } + /* Otherwise values will be populated by problem_update_params */ + return &pnode->base; } diff --git a/src/bivariate/const_scalar_mult.c b/src/bivariate/const_scalar_mult.c index ef4c949..3863730 100644 --- a/src/bivariate/const_scalar_mult.c +++ b/src/bivariate/const_scalar_mult.c @@ -22,7 +22,7 @@ #include #include -/* Constant scalar multiplication: y = a * child where a is a constant double */ +/* Scalar multiplication: y = a * child where a comes from a parameter node */ static void forward(expr *node, const double *u) { @@ -33,7 +33,7 @@ static void forward(expr *node, const double *u) /* local forward pass: multiply each element by scalar a */ const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; - double a = sn->param_source ? sn->param_source->value[0] : sn->a; + double a = sn->param_source->value[0]; for (int i = 0; i < node->size; i++) { node->value[i] = a * child->value[i]; @@ -57,7 +57,7 @@ static void eval_jacobian(expr *node) { expr *child = node->left; const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; - double a = sn->param_source ? sn->param_source->value[0] : sn->a; + double a = sn->param_source->value[0]; /* evaluate child */ child->eval_jacobian(child); @@ -88,7 +88,7 @@ static void eval_wsum_hess(expr *node, const double *w) x->eval_wsum_hess(x, w); const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; - double a = sn->param_source ? 
sn->param_source->value[0] : sn->a; + double a = sn->param_source->value[0]; for (int j = 0; j < x->wsum_hess->nnz; j++) { node->wsum_hess->x[j] = a * x->wsum_hess->x[j]; @@ -101,23 +101,7 @@ static bool is_affine(const expr *node) return node->left->is_affine(node->left); } -expr *new_const_scalar_mult(double a, expr *child) -{ - const_scalar_mult_expr *mult_node = - (const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr)); - expr *node = &mult_node->base; - - init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, - eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); - node->left = child; - mult_node->a = a; - mult_node->param_source = NULL; - expr_retain(child); - - return node; -} - -static void free_param_type_data(expr *node) +static void free_type_data(expr *node) { const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; if (sn->param_source) @@ -126,7 +110,7 @@ static void free_param_type_data(expr *node) } } -expr *new_param_scalar_mult(expr *param_node, expr *child) +expr *new_scalar_mult(expr *param_node, expr *child) { const_scalar_mult_expr *mult_node = (const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr)); @@ -134,9 +118,8 @@ expr *new_param_scalar_mult(expr *param_node, expr *child) init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, - free_param_type_data); + free_type_data); node->left = child; - mult_node->a = param_node->value[0]; /* initial value */ mult_node->param_source = param_node; expr_retain(child); expr_retain(param_node); diff --git a/src/bivariate/const_vector_mult.c b/src/bivariate/const_vector_mult.c index 31f7236..8eba16a 100644 --- a/src/bivariate/const_vector_mult.c +++ b/src/bivariate/const_vector_mult.c @@ -21,13 +21,14 @@ #include #include -/* Constant vector elementwise multiplication: y = a \circ child */ +/* Vector elementwise multiplication: y = a \circ child + * where a comes 
from a parameter node */ static void forward(expr *node, const double *u) { expr *child = node->left; const_vector_mult_expr *vn = (const_vector_mult_expr *) node; - const double *a = vn->param_source ? vn->param_source->value : vn->a; + const double *a = vn->param_source->value; /* child's forward pass */ child->forward(child, u); @@ -56,7 +57,7 @@ static void eval_jacobian(expr *node) { expr *x = node->left; const_vector_mult_expr *vn = (const_vector_mult_expr *) node; - const double *a = vn->param_source ? vn->param_source->value : vn->a; + const double *a = vn->param_source->value; /* evaluate x */ x->eval_jacobian(x); @@ -90,7 +91,7 @@ static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; const_vector_mult_expr *vn = (const_vector_mult_expr *) node; - const double *a = vn->param_source ? vn->param_source->value : vn->a; + const double *a = vn->param_source->value; /* scale weights w by a */ for (int i = 0; i < node->size; i++) @@ -104,50 +105,22 @@ static void eval_wsum_hess(expr *node, const double *w) memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double)); } -static void free_type_data(expr *node) -{ - const_vector_mult_expr *vnode = (const_vector_mult_expr *) node; - free(vnode->a); -} - static bool is_affine(const expr *node) { /* Affine iff the child is affine */ return node->left->is_affine(node->left); } -expr *new_const_vector_mult(const double *a, expr *child) -{ - const_vector_mult_expr *vnode = - (const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr)); - expr *node = &vnode->base; - - init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, - eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, - free_type_data); - node->left = child; - expr_retain(child); - - /* copy a vector */ - vnode->a = (double *) malloc(child->size * sizeof(double)); - memcpy(vnode->a, a, child->size * sizeof(double)); - vnode->param_source = NULL; - - return node; -} - -static void 
free_param_type_data(expr *node) +static void free_type_data(expr *node) { const_vector_mult_expr *vnode = (const_vector_mult_expr *) node; - /* a is not owned when param_source is set */ - free(vnode->a); if (vnode->param_source) { free_expr(vnode->param_source); } } -expr *new_param_vector_mult(expr *param_node, expr *child) +expr *new_vector_mult(expr *param_node, expr *child) { const_vector_mult_expr *vnode = (const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr)); @@ -155,13 +128,10 @@ expr *new_param_vector_mult(expr *param_node, expr *child) init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, - free_param_type_data); + free_type_data); node->left = child; expr_retain(child); - /* Still allocate a copy for initial values (used before first update_params) */ - vnode->a = (double *) malloc(child->size * sizeof(double)); - memcpy(vnode->a, param_node->value, child->size * sizeof(double)); vnode->param_source = param_node; expr_retain(param_node); diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index f2044a7..96ff7d8 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -23,10 +23,10 @@ #include #include -/* This file implement the atom 'left_matmul' corresponding to the operation y = - A @ f(x), where A is a given matrix and f(x) is an arbitrary expression. - Here, f(x) can be a vector-valued expression and a matrix-valued - expression. The dimensions are A - m x n, f(x) - n x p, y - m x p. +/* This file implements the atom 'left_matmul' corresponding to the operation y = + A @ f(x), where A is a given matrix (from a parameter node) and f(x) is an + arbitrary expression. Here, f(x) can be a vector-valued expression and a + matrix-valued expression. The dimensions are A - m x n, f(x) - n x p, y - m x p. Note that here A does not have global column indices but it is a local matrix. 
This is an important distinction compared to linear_op_expr. @@ -45,9 +45,6 @@ Working in terms of A_kron unifies the implementation of f(x) being vector-valued or matrix-valued. - - I (dance858) think we can get additional big speedups when A is dense by - introducing a dense matrix class. */ #include "utils/utils.h" @@ -59,15 +56,15 @@ static void refresh_param_values(left_matmul_expr *lin_node) { const double *src = lin_node->param_source->value; int m = lin_node->src_m; - int n = lin_node->src_n; + CSR_Matrix *A = lin_node->A; - /* Fill A: column-major source -> row-major CSR values */ + /* Fill A values from column-major source, following existing sparsity pattern */ for (int row = 0; row < m; row++) - for (int col = 0; col < n; col++) - lin_node->A->x[row * n + col] = src[row + col * m]; + for (int k = A->p[row]; k < A->p[row + 1]; k++) + A->x[k] = src[row + A->i[k] * m]; /* Recompute AT values from updated A */ - AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork); + AT_fill_values(A, lin_node->AT, lin_node->base.iwork); } static void forward(expr *node, const double *u) @@ -75,11 +72,8 @@ static void forward(expr *node, const double *u) expr *x = node->left; left_matmul_expr *lin_node = (left_matmul_expr *) node; - /* refresh A/AT if parameter-sourced */ - if (lin_node->param_source) - { - refresh_param_values(lin_node); - } + /* refresh A/AT from parameter source */ + refresh_param_values(lin_node); /* child's forward pass */ node->left->forward(node->left, u); @@ -134,11 +128,8 @@ static void eval_jacobian(expr *node) CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC; CSC_Matrix *J_CSC = lnode->J_CSC; - /* refresh A if parameter-sourced */ - if (lnode->param_source) - { - refresh_param_values(lnode); - } + /* refresh A from parameter source */ + refresh_param_values(lnode); /* evaluate child's jacobian and convert to CSC */ x->eval_jacobian(x); @@ -178,64 +169,12 @@ static void eval_wsum_hess(expr *node, const double *w) node->wsum_hess->nnz * 
sizeof(double)); } -expr *new_left_matmul(expr *u, const CSR_Matrix *A) -{ - /* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users - to do A @ u where u is (n, ) which in C is actually (1, n). In that case - the result of A @ u is (m, ), which is (1, m) according to broadcasting - rules. We therefore check if this is the case. */ - int d1, d2, n_blocks; - if (u->d1 == A->n) - { - d1 = A->m; - d2 = u->d2; - n_blocks = u->d2; - } - else if (u->d2 == A->n && u->d1 == 1) - { - d1 = 1; - d2 = A->m; - n_blocks = 1; - } - else - { - fprintf(stderr, "Error in new_left_matmul: dimension mismatch \n"); - exit(1); - } - - /* Allocate the type-specific struct */ - left_matmul_expr *lin_node = - (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); - expr *node = &lin_node->base; - init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian, - is_affine, wsum_hess_init, eval_wsum_hess, free_type_data); - node->left = u; - expr_retain(u); - - /* allocate workspace. iwork is used for transposing A (requiring size A->n) - and for converting J_child csr to csc (requring size node->n_vars). 
- csc_to_csr_workspace is used for converting J_CSC to CSR (requring node->size) - */ - node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int)); - lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int)); - lin_node->n_blocks = n_blocks; - - /* store A and AT */ - lin_node->A = new_csr(A); - lin_node->AT = transpose(lin_node->A, node->iwork); - lin_node->param_source = NULL; - lin_node->src_m = 0; - lin_node->src_n = 0; - - return node; -} - -expr *new_left_param_matmul(expr *param_node, expr *child) +expr *new_left_matmul(expr *param_node, expr *child) { int A_m = param_node->d1; int A_n = param_node->d2; - /* Same dimension logic as new_left_matmul */ + /* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */ int d1, d2, n_blocks; if (child->d1 == A_n) { @@ -251,12 +190,28 @@ expr *new_left_param_matmul(expr *param_node, expr *child) } else { - fprintf(stderr, "Error in new_left_param_matmul: dimension mismatch\n"); + fprintf(stderr, "Error in new_left_matmul: dimension mismatch\n"); exit(1); } - /* Build dense CSR from param_node's column-major values */ - int nnz = A_m * A_n; + /* Build CSR from param_node's column-major values. + * For fixed parameters (PARAM_FIXED), skip zeros to preserve sparsity. + * For updatable parameters, build dense CSR since sparsity may change. 
*/ + parameter_expr *pnode = (parameter_expr *) param_node; + int sparse = (pnode->param_id == PARAM_FIXED); + + int nnz = 0; + if (sparse) + { + for (int row = 0; row < A_m; row++) + for (int col = 0; col < A_n; col++) + if (param_node->value[row + col * A_m] != 0.0) nnz++; + } + else + { + nnz = A_m * A_n; + } + CSR_Matrix *A = new_csr_matrix(A_m, A_n, nnz); int idx = 0; for (int row = 0; row < A_m; row++) @@ -264,8 +219,10 @@ expr *new_left_param_matmul(expr *param_node, expr *child) A->p[row] = idx; for (int col = 0; col < A_n; col++) { + double val = param_node->value[row + col * A_m]; + if (sparse && val == 0.0) continue; A->i[idx] = col; - A->x[idx] = param_node->value[row + col * A_m]; + A->x[idx] = val; idx++; } } @@ -276,8 +233,7 @@ expr *new_left_param_matmul(expr *param_node, expr *child) (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); expr *node = &lin_node->base; init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian, - is_affine, wsum_hess_init, eval_wsum_hess, - free_type_data); /* same free_type_data as constant version */ + is_affine, wsum_hess_init, eval_wsum_hess, free_type_data); node->left = child; expr_retain(child); diff --git a/src/bivariate/right_matmul.c b/src/bivariate/right_matmul.c index 6ec593b..e64bf3c 100644 --- a/src/bivariate/right_matmul.c +++ b/src/bivariate/right_matmul.c @@ -33,10 +33,20 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A) int *work_transpose = (int *) malloc(A->n * sizeof(int)); CSR_Matrix *AT = transpose(A, work_transpose); + /* Convert AT (CSR) to dense column-major array for parameter node */ + int m = AT->m; /* rows of AT = cols of A */ + int n = AT->n; /* cols of AT = rows of A */ + double *col_major = (double *) calloc(m * n, sizeof(double)); + for (int row = 0; row < m; row++) + for (int k = AT->p[row]; k < AT->p[row + 1]; k++) + col_major[row + AT->i[k] * m] = AT->x[k]; + expr *u_transpose = new_transpose(u); - expr *left_matmul = new_left_matmul(u_transpose, AT); - 
expr *node = new_transpose(left_matmul); + expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major); + expr *left_matmul_node = new_left_matmul(param_node, u_transpose); + expr *node = new_transpose(left_matmul_node); + free(col_major); free_csr_matrix(AT); free(work_transpose); return node; diff --git a/src/problem.c b/src/problem.c index 88af4ec..c9e9a1e 100644 --- a/src/problem.c +++ b/src/problem.c @@ -464,6 +464,7 @@ void problem_update_params(problem *prob, const double *theta) for (int i = 0; i < prob->n_param_nodes; i++) { parameter_expr *p = (parameter_expr *) prob->param_nodes[i]; + if (p->param_id == PARAM_FIXED) continue; memcpy(p->base.value, theta + p->param_id, p->base.size * sizeof(double)); } /* Force re-evaluation of affine Jacobians on next call */ diff --git a/tests/forward_pass/affine/test_add.h b/tests/forward_pass/affine/test_add.h index 7ce859c..a087c9a 100644 --- a/tests/forward_pass/affine/test_add.h +++ b/tests/forward_pass/affine/test_add.h @@ -5,6 +5,7 @@ #include "affine.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_addition() @@ -12,7 +13,7 @@ const char *test_addition() double u[2] = {3.0, 4.0}; double c[2] = {1.0, 2.0}; expr *var = new_variable(2, 1, 0, 2); - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); expr *sum = new_add(var, const_node); sum->forward(sum, u); double expected[2] = {4.0, 6.0}; diff --git a/tests/forward_pass/affine/test_sum.h b/tests/forward_pass/affine/test_sum.h index 2e3d6ab..e700ebd 100644 --- a/tests/forward_pass/affine/test_sum.h +++ b/tests/forward_pass/affine/test_sum.h @@ -6,6 +6,7 @@ #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_sum_axis_neg1() @@ -17,7 +18,7 @@ const char *test_sum_axis_neg1() Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; 
- expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, -1); sum_node->forward(sum_node, NULL); @@ -42,7 +43,7 @@ const char *test_sum_axis_0() Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, 0); sum_node->forward(sum_node, NULL); @@ -69,7 +70,7 @@ const char *test_sum_axis_1() Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, 1); sum_node->forward(sum_node, NULL); diff --git a/tests/forward_pass/affine/test_variable_constant.h b/tests/forward_pass/affine/test_variable_constant.h index c964654..ea9b609 100644 --- a/tests/forward_pass/affine/test_variable_constant.h +++ b/tests/forward_pass/affine/test_variable_constant.h @@ -5,6 +5,7 @@ #include "affine.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_variable() @@ -21,7 +22,7 @@ const char *test_constant() { double c[2] = {5.0, 10.0}; double u[2] = {0.0, 0.0}; - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); const_node->forward(const_node, u); mu_assert("Constant test failed", cmp_double_array(const_node->value, c, 2)); free_expr(const_node); diff --git a/tests/forward_pass/composite/test_composite.h b/tests/forward_pass/composite/test_composite.h index 92074a0..253aa6a 100644 --- a/tests/forward_pass/composite/test_composite.h +++ b/tests/forward_pass/composite/test_composite.h @@ -6,6 +6,7 @@ #include 
"elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_composite() @@ -16,7 +17,7 @@ const char *test_composite() /* Build tree: log(exp(x) + c) */ expr *var = new_variable(2, 1, 0, 2); expr *exp_node = new_exp(var); - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); expr *sum = new_add(exp_node, const_node); expr *log_node = new_log(sum); diff --git a/tests/forward_pass/test_prod_axis_one.h b/tests/forward_pass/test_prod_axis_one.h index 6f4b3bb..499214e 100644 --- a/tests/forward_pass/test_prod_axis_one.h +++ b/tests/forward_pass/test_prod_axis_one.h @@ -6,6 +6,7 @@ #include "expr.h" #include "minunit.h" #include "other.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_forward_prod_axis_one() @@ -16,7 +17,7 @@ const char *test_forward_prod_axis_one() Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(2, 3, 0, values); + expr *const_node = new_parameter(2, 3, PARAM_FIXED, 0, values); expr *prod_node = new_prod_axis_one(const_node); prod_node->forward(prod_node, NULL); diff --git a/tests/forward_pass/test_prod_axis_zero.h b/tests/forward_pass/test_prod_axis_zero.h index 6504502..30b8cdc 100644 --- a/tests/forward_pass/test_prod_axis_zero.h +++ b/tests/forward_pass/test_prod_axis_zero.h @@ -6,6 +6,7 @@ #include "expr.h" #include "minunit.h" #include "other.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_forward_prod_axis_zero() @@ -16,7 +17,7 @@ const char *test_forward_prod_axis_zero() Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(2, 3, 0, values); + expr *const_node = new_parameter(2, 3, PARAM_FIXED, 0, values); expr *prod_node = new_prod_axis_zero(const_node); prod_node->forward(prod_node, NULL); diff --git a/tests/jacobian_tests/test_broadcast.h 
b/tests/jacobian_tests/test_broadcast.h index 612e5cf..da99597 100644 --- a/tests/jacobian_tests/test_broadcast.h +++ b/tests/jacobian_tests/test_broadcast.h @@ -5,6 +5,7 @@ #include "affine.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_broadcast_row_jacobian() @@ -141,7 +142,7 @@ const char *test_double_broadcast(void) /* form the expression x + b */ expr *x = new_variable(5, 1, 0, 5); - expr *b = new_constant(1, 5, 5, b_vals); + expr *b = new_parameter(1, 5, PARAM_FIXED, 5, b_vals); expr *bcast_x = new_broadcast(x, 5, 5); expr *bcast_b = new_broadcast(b, 5, 5); expr *sum = new_add(bcast_x, bcast_b); diff --git a/tests/jacobian_tests/test_const_scalar_mult.h b/tests/jacobian_tests/test_const_scalar_mult.h index 3143929..6e1cfd0 100644 --- a/tests/jacobian_tests/test_const_scalar_mult.h +++ b/tests/jacobian_tests/test_const_scalar_mult.h @@ -1,9 +1,11 @@ #include +#include "affine.h" #include "bivariate.h" #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" /* Test: y = a * log(x) where a is a scalar constant */ @@ -19,7 +21,8 @@ const char *test_jacobian_const_scalar_mult_log_vector() /* Create scalar mult node: y = 2.5 * log(x) */ double a = 2.5; - expr *y = new_const_scalar_mult(a, log_node); + expr *a_node = new_parameter(1, 1, PARAM_FIXED, 3, &a); + expr *y = new_scalar_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -55,7 +58,8 @@ const char *test_jacobian_const_scalar_mult_log_matrix() /* Create scalar mult node: y = 3.0 * log(x) */ double a = 3.0; - expr *y = new_const_scalar_mult(a, log_node); + expr *a_node = new_parameter(1, 1, PARAM_FIXED, 4, &a); + expr *y = new_scalar_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); diff --git a/tests/jacobian_tests/test_const_vector_mult.h b/tests/jacobian_tests/test_const_vector_mult.h index 4658dd4..3122e2e 100644 --- 
a/tests/jacobian_tests/test_const_vector_mult.h +++ b/tests/jacobian_tests/test_const_vector_mult.h @@ -1,9 +1,11 @@ #include +#include "affine.h" #include "bivariate.h" #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" /* Test: y = a ∘ log(x) where a is a constant vector */ @@ -19,7 +21,8 @@ const char *test_jacobian_const_vector_mult_log_vector() /* Create vector mult node: y = [2.0, 3.0, 4.0] ∘ log(x) */ double a[3] = {2.0, 3.0, 4.0}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_node = new_parameter(3, 1, PARAM_FIXED, 3, a); + expr *y = new_vector_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -59,7 +62,8 @@ const char *test_jacobian_const_vector_mult_log_matrix() /* Create vector mult node: y = [1.5, 2.5, 3.5, 4.5] ∘ log(x) */ double a[4] = {1.5, 2.5, 3.5, 4.5}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_node = new_parameter(2, 2, PARAM_FIXED, 4, a); + expr *y = new_vector_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); diff --git a/tests/jacobian_tests/test_left_matmul.h b/tests/jacobian_tests/test_left_matmul.h index f79eee7..854687b 100644 --- a/tests/jacobian_tests/test_left_matmul.h +++ b/tests/jacobian_tests/test_left_matmul.h @@ -5,6 +5,7 @@ #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char *test_jacobian_left_matmul_log() @@ -30,17 +31,12 @@ const char *test_jacobian_left_matmul_log() double x_vals[3] = {1.0, 2.0, 3.0}; expr *x = new_variable(3, 1, 0, 3); - /* Create sparse matrix A in CSR format */ - CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in 
column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(A_param, log_x); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -63,7 +59,6 @@ const char *test_jacobian_left_matmul_log() mu_assert("cols fail", cmp_int_array(A_log_x->jacobian->i, expected_Ai, 7)); mu_assert("rows fail", cmp_int_array(A_log_x->jacobian->p, expected_Ap, 5)); - free_csr_matrix(A); free_expr(A_log_x); return 0; } @@ -74,17 +69,12 @@ const char *test_jacobian_left_matmul_log_matrix() double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; expr *x = new_variable(3, 2, 0, 6); - /* Create sparse matrix A in CSR format (4x3) */ - CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(A_param, log_x); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -102,7 +92,6 @@ const char *test_jacobian_left_matmul_log_matrix() mu_assert("cols fail", cmp_int_array(A_log_x->jacobian->i, expected_Ai, 14)); mu_assert("rows fail", cmp_int_array(A_log_x->jacobian->p, expected_Ap, 9)); - free_csr_matrix(A); free_expr(A_log_x); return 0; } @@ -145,18 +134,13 @@ const char *test_jacobian_left_matmul_log_composite() memcpy(B->i, B_i, 9 * sizeof(int)); memcpy(B->x, B_x, 9 * sizeof(double)); - /* Create A matrix */ - 
CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); expr *Bx = new_linear(x, B, NULL); expr *log_Bx = new_log(Bx); - expr *A_log_Bx = new_left_matmul(log_Bx, A); + expr *A_log_Bx = new_left_matmul(A_param, log_Bx); A_log_Bx->forward(A_log_Bx, x_vals); A_log_Bx->jacobian_init(A_log_Bx); @@ -176,7 +160,6 @@ const char *test_jacobian_left_matmul_log_composite() mu_assert("cols fail", cmp_int_array(A_log_Bx->jacobian->i, expected_Ai, 12)); mu_assert("rows fail", cmp_int_array(A_log_Bx->jacobian->p, expected_Ap, 5)); - free_csr_matrix(A); free_csr_matrix(B); free_expr(A_log_Bx); return 0; diff --git a/tests/jacobian_tests/test_transpose.h b/tests/jacobian_tests/test_transpose.h index 4fd22d5..581b97c 100644 --- a/tests/jacobian_tests/test_transpose.h +++ b/tests/jacobian_tests/test_transpose.h @@ -4,24 +4,20 @@ #include "affine.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" #include #include const char *test_jacobian_transpose() { - // A = [1 2; 3 4] - CSR_Matrix *A = new_csr_matrix(2, 2, 4); - int A_p[3] = {0, 2, 4}; - int A_i[4] = {0, 1, 0, 1}; - double A_x[4] = {1, 2, 3, 4}; - memcpy(A->p, A_p, 3 * sizeof(int)); - memcpy(A->i, A_i, 4 * sizeof(int)); - memcpy(A->x, A_x, 4 * sizeof(double)); + /* A = [1 2; 3 4] in column-major order: [1, 3, 2, 4] */ + double A_vals[4] = {1.0, 3.0, 2.0, 4.0}; + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 2, A_vals); // X = [1 2; 3 4] (columnwise: x = [1 3 2 4]) expr *X = new_variable(2, 2, 0, 4); - expr *AX = new_left_matmul(X, A); + expr *AX = 
new_left_matmul(A_param, X); expr *transpose_AX = new_transpose(AX); double u[4] = {1, 3, 2, 4}; transpose_AX->forward(transpose_AX, u); @@ -40,7 +36,6 @@ const char *test_jacobian_transpose() mu_assert("jacobian col idx fail", cmp_int_array(transpose_AX->jacobian->i, expected_i, 8)); free_expr(transpose_AX); - free_csr_matrix(A); return 0; } diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h index b7716bf..250651c 100644 --- a/tests/problem/test_param_prob.h +++ b/tests/problem/test_param_prob.h @@ -33,8 +33,8 @@ const char *test_param_scalar_mult_problem(void) /* Build tree: sum(a * log(x)) */ expr *x = new_variable(2, 1, 0, n_vars); expr *log_x = new_log(x); - expr *a_param = new_parameter(1, 1, 0, n_vars); - expr *scaled = new_param_scalar_mult(a_param, log_x); + expr *a_param = new_parameter(1, 1, 0, n_vars, NULL); + expr *scaled = new_scalar_mult(a_param, log_x); expr *objective = new_sum(scaled, -1); /* Create problem (no constraints) */ @@ -103,8 +103,8 @@ const char *test_param_vector_mult_problem(void) /* Constraint: p ∘ x */ expr *x_con = new_variable(2, 1, 0, n_vars); - expr *p_param = new_parameter(2, 1, 0, n_vars); - expr *constraint = new_param_vector_mult(p_param, x_con); + expr *p_param = new_parameter(2, 1, 0, n_vars, NULL); + expr *constraint = new_vector_mult(p_param, x_con); expr *constraints[1] = {constraint}; @@ -184,8 +184,8 @@ const char *test_param_left_matmul_problem(void) /* Constraint: A @ x */ expr *x_con = new_variable(2, 1, 0, n_vars); - expr *A_param = new_parameter(2, 2, 0, n_vars); - expr *constraint = new_left_param_matmul(A_param, x_con); + expr *A_param = new_parameter(2, 2, 0, n_vars, NULL); + expr *constraint = new_left_matmul(A_param, x_con); expr *constraints[1] = {constraint}; diff --git a/tests/profiling/profile_left_matmul.h b/tests/profiling/profile_left_matmul.h index a8e819b..58940d9 100644 --- a/tests/profiling/profile_left_matmul.h +++ b/tests/profiling/profile_left_matmul.h @@ -8,30 
+8,27 @@ #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" #include "utils/Timer.h" const char *profile_left_matmul() { - /* A @ X where A is 50 x 50 dense stored in CSR and X is 50 x 50 variable */ + /* A @ X where A is 100 x 100 dense (all ones) and X is 100 x 100 variable */ int n = 100; expr *X = new_variable(n, n, 0, n * n); - CSR_Matrix *A = new_csr_matrix(n, n, n * n); + + /* Create n x n parameter of all ones (column-major, but all ones so order + * doesn't matter) */ + double *A_vals = (double *) malloc(n * n * sizeof(double)); for (int i = 0; i < n * n; i++) { - A->x[i] = 1.0; /* dense matrix of all ones */ - } - for (int row = 0; row < n; row++) - { - A->p[row] = row * n; - for (int col = 0; col < n; col++) - { - A->i[row * n + col] = col; - } + A_vals[i] = 1.0; } - A->p[n] = n * n; + expr *A_param = new_parameter(n, n, PARAM_FIXED, n, A_vals); + free(A_vals); - expr *AX = new_left_matmul(X, A); + expr *AX = new_left_matmul(A_param, X); double *x_vals = (double *) malloc(n * n * sizeof(double)); for (int i = 0; i < n * n; i++) @@ -56,7 +53,6 @@ const char *profile_left_matmul() GET_ELAPSED_SECONDS(timer)); free(x_vals); - free_csr_matrix(A); free_expr(AX); return 0; } diff --git a/tests/wsum_hess/test_const_scalar_mult.h b/tests/wsum_hess/test_const_scalar_mult.h index 7172be4..b4654d5 100644 --- a/tests/wsum_hess/test_const_scalar_mult.h +++ b/tests/wsum_hess/test_const_scalar_mult.h @@ -1,10 +1,12 @@ #include #include +#include "affine.h" #include "bivariate.h" #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" /* Test: y = a * log(x) where a is a scalar constant */ @@ -20,7 +22,8 @@ const char *test_wsum_hess_const_scalar_mult_log_vector() /* Create scalar mult node: y = 2.5 * log(x) */ double a = 2.5; - expr *y = new_const_scalar_mult(a, log_node); + expr *a_node = new_parameter(1, 1, PARAM_FIXED, 3, &a); + 
expr *y = new_scalar_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -65,7 +68,8 @@ const char *test_wsum_hess_const_scalar_mult_log_matrix() /* Create scalar mult node: y = 3.0 * log(x) */ double a = 3.0; - expr *y = new_const_scalar_mult(a, log_node); + expr *a_node = new_parameter(1, 1, PARAM_FIXED, 4, &a); + expr *y = new_scalar_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); diff --git a/tests/wsum_hess/test_const_vector_mult.h b/tests/wsum_hess/test_const_vector_mult.h index 88bc127..6bee2b8 100644 --- a/tests/wsum_hess/test_const_vector_mult.h +++ b/tests/wsum_hess/test_const_vector_mult.h @@ -1,10 +1,12 @@ #include #include +#include "affine.h" #include "bivariate.h" #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" /* Test: y = a ∘ log(x) where a is a constant vector */ @@ -20,7 +22,8 @@ const char *test_wsum_hess_const_vector_mult_log_vector() /* Create vector mult node: y = [2.0, 3.0, 4.0] ∘ log(x) */ double a[3] = {2.0, 3.0, 4.0}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_node = new_parameter(3, 1, PARAM_FIXED, 3, a); + expr *y = new_vector_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -63,7 +66,8 @@ const char *test_wsum_hess_const_vector_mult_log_matrix() /* Create vector mult node: y = [1.5, 2.5, 3.5, 4.5] ∘ log(x) */ double a[4] = {1.5, 2.5, 3.5, 4.5}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_node = new_parameter(2, 2, PARAM_FIXED, 4, a); + expr *y = new_vector_mult(a_node, log_node); /* Forward pass */ y->forward(y, u_vals); diff --git a/tests/wsum_hess/test_left_matmul.h b/tests/wsum_hess/test_left_matmul.h index 28d1cec..22853b3 100644 --- a/tests/wsum_hess/test_left_matmul.h +++ b/tests/wsum_hess/test_left_matmul.h @@ -8,6 +8,7 @@ #include "elementwise_univariate.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" const char 
*test_wsum_hess_left_matmul() @@ -52,17 +53,12 @@ const char *test_wsum_hess_left_matmul() expr *x = new_variable(3, 1, 0, 3); - /* Create sparse matrix A in CSR format */ - CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(A_param, log_x); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -84,7 +80,6 @@ const char *test_wsum_hess_left_matmul() mu_assert("cols incorrect", cmp_int_array(A_log_x->wsum_hess->i, expected_i, 3)); mu_assert("rows incorrect", cmp_int_array(A_log_x->wsum_hess->p, expected_p, 4)); - free_csr_matrix(A); free_expr(A_log_x); return 0; } @@ -150,18 +145,13 @@ const char *test_wsum_hess_left_matmul_composite() memcpy(B->i, B_i, 9 * sizeof(int)); memcpy(B->x, B_x, 9 * sizeof(double)); - /* Create A matrix */ - CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); expr *Bx = new_linear(x, B, NULL); expr *log_Bx = new_log(Bx); - expr *A_log_Bx = new_left_matmul(log_Bx, A); + expr *A_log_Bx = new_left_matmul(A_param, log_Bx); 
A_log_Bx->forward(A_log_Bx, x_vals); A_log_Bx->jacobian_init(A_log_Bx); @@ -186,7 +176,6 @@ const char *test_wsum_hess_left_matmul_composite() mu_assert("rows incorrect", cmp_int_array(A_log_Bx->wsum_hess->p, expected_p, 4)); - free_csr_matrix(A); free_csr_matrix(B); free_expr(A_log_Bx); return 0; @@ -223,17 +212,12 @@ const char *test_wsum_hess_left_matmul_matrix() expr *x = new_variable(3, 2, 0, 6); - /* Create sparse matrix A in CSR format */ - CSR_Matrix *A = new_csr_matrix(4, 3, 7); - int A_p[5] = {0, 2, 4, 6, 7}; - int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; - double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - memcpy(A->p, A_p, 5 * sizeof(int)); - memcpy(A->i, A_i, 7 * sizeof(int)); - memcpy(A->x, A_x, 7 * sizeof(double)); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ + double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; + expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(A_param, log_x); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -257,7 +241,6 @@ const char *test_wsum_hess_left_matmul_matrix() mu_assert("cols incorrect", cmp_int_array(A_log_x->wsum_hess->i, expected_i, 6)); mu_assert("rows incorrect", cmp_int_array(A_log_x->wsum_hess->p, expected_p, 7)); - free_csr_matrix(A); free_expr(A_log_x); return 0; } From 939a9105b9f217cc1a487f83ef9e1771c085e9ea Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Sun, 15 Feb 2026 17:55:43 -0500 Subject: [PATCH 15/24] Remove redundant NULL-after-free, rename const_ files, fix stale comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Post-refactor cleanup after the Constant→Parameter merge: - Remove dead NULL assignments in free_type_data (quad_form, hstack, index, linear_op, left_matmul, scalar_mult, vector_mult) - Rename const_scalar_mult/const_vector_mult source and 
test files to drop const_ prefix - Rename test_variable_constant.h → test_variable_parameter.h and test_constant → test_fixed_parameter - Update stale comments in multiply.c and bivariate.h - Rename const_scalar_mult_expr/const_vector_mult_expr structs to scalar_mult_expr/vector_mult_expr - Update left_matmul param_source comment Co-Authored-By: Claude Opus 4.6 --- include/bivariate.h | 2 +- include/subexpr.h | 10 +++++----- src/affine/hstack.c | 3 --- src/affine/index.c | 6 +----- src/affine/linear_op.c | 4 ---- src/bivariate/left_matmul.c | 8 +------- src/bivariate/multiply.c | 2 +- .../{const_scalar_mult.c => scalar_mult.c} | 17 +++++++---------- .../{const_vector_mult.c => vector_mult.c} | 17 +++++++---------- src/other/quad_form.c | 1 - tests/all_tests.c | 12 ++++++------ ...ble_constant.h => test_variable_parameter.h} | 4 ++-- ...t_const_scalar_mult.h => test_scalar_mult.h} | 0 ...t_const_vector_mult.h => test_vector_mult.h} | 0 ...t_const_scalar_mult.h => test_scalar_mult.h} | 0 ...t_const_vector_mult.h => test_vector_mult.h} | 0 16 files changed, 31 insertions(+), 55 deletions(-) rename src/bivariate/{const_scalar_mult.c => scalar_mult.c} (87%) rename src/bivariate/{const_vector_mult.c => vector_mult.c} (88%) rename tests/forward_pass/affine/{test_variable_constant.h => test_variable_parameter.h} (83%) rename tests/jacobian_tests/{test_const_scalar_mult.h => test_scalar_mult.h} (100%) rename tests/jacobian_tests/{test_const_vector_mult.h => test_vector_mult.h} (100%) rename tests/wsum_hess/{test_const_scalar_mult.h => test_scalar_mult.h} (100%) rename tests/wsum_hess/{test_const_vector_mult.h => test_vector_mult.h} (100%) diff --git a/include/bivariate.h b/include/bivariate.h index 15947a1..83c31db 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -33,7 +33,7 @@ expr *new_matmul(expr *x, expr *y); /* Left matrix multiplication: A @ f(x) where A comes from a parameter node */ expr *new_left_matmul(expr *param_node, expr *child); -/* Right matrix 
multiplication: f(x) @ A where A is a constant matrix */ +/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */ expr *new_right_matmul(expr *u, const CSR_Matrix *A); /* Scalar multiplication: a * f(x) where a comes from a parameter node */ diff --git a/include/subexpr.h b/include/subexpr.h index 41a7833..74ae4b1 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -126,7 +126,7 @@ typedef struct left_matmul_expr CSC_Matrix *Jchild_CSC; CSC_Matrix *J_CSC; int *csc_to_csr_workspace; - expr *param_source; /* if non-NULL, A/AT values come from this parameter */ + expr *param_source; /* parameter node; A/AT values are refreshed from this */ int src_m, src_n; /* original matrix dimensions */ } left_matmul_expr; @@ -142,19 +142,19 @@ typedef struct right_matmul_expr } right_matmul_expr; /* Scalar multiplication: y = a * child where a comes from a parameter node */ -typedef struct const_scalar_mult_expr +typedef struct scalar_mult_expr { expr base; expr *param_source; /* always set; read a from param_source->value[0] */ -} const_scalar_mult_expr; +} scalar_mult_expr; /* Vector elementwise multiplication: y = a \circ child where a comes from a * parameter node */ -typedef struct const_vector_mult_expr +typedef struct vector_mult_expr { expr base; expr *param_source; /* always set; read a from param_source->value */ -} const_vector_mult_expr; +} vector_mult_expr; /* Index/slicing: y = child[indices] where indices is a list of flat positions */ typedef struct index_expr diff --git a/src/affine/hstack.c b/src/affine/hstack.c index e2235d6..b5d8fb8 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -165,13 +165,10 @@ static void free_type_data(expr *node) for (int i = 0; i < hnode->n_args; i++) { free_expr(hnode->args[i]); - hnode->args[i] = NULL; } free_csr_matrix(hnode->CSR_work); - hnode->CSR_work = NULL; free(hnode->args); - hnode->args = NULL; } expr *new_hstack(expr **args, int n_args, int n_vars) diff --git a/src/affine/index.c 
b/src/affine/index.c index 9577a05..5702100 100644 --- a/src/affine/index.c +++ b/src/affine/index.c @@ -154,11 +154,7 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { index_expr *idx = (index_expr *) node; - if (idx->indices) - { - free(idx->indices); - idx->indices = NULL; - } + free(idx->indices); } expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs) diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index a8d2863..04e5b9b 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -62,11 +62,7 @@ static void free_type_data(expr *node) if (lin_node->b != NULL) { free(lin_node->b); - lin_node->b = NULL; } - - lin_node->A_csr = NULL; - lin_node->A_csc = NULL; } static void jacobian_init(expr *node) diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 96ff7d8..238dbb4 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -95,13 +95,7 @@ static void free_type_data(expr *node) free_csc_matrix(lin_node->Jchild_CSC); free_csc_matrix(lin_node->J_CSC); free(lin_node->csc_to_csr_workspace); - if (lin_node->param_source) free_expr(lin_node->param_source); - lin_node->A = NULL; - lin_node->AT = NULL; - lin_node->Jchild_CSC = NULL; - lin_node->J_CSC = NULL; - lin_node->csc_to_csr_workspace = NULL; - lin_node->param_source = NULL; + free_expr(lin_node->param_source); } static void jacobian_init(expr *node) diff --git a/src/bivariate/multiply.c b/src/bivariate/multiply.c index d8836ae..dceea8d 100644 --- a/src/bivariate/multiply.c +++ b/src/bivariate/multiply.c @@ -27,7 +27,7 @@ // ------------------------------------------------------------------------------ // Implementation of elementwise multiplication when both arguments are vectors. // If one argument is a scalar variable, the broadcasting should be represented -// as a linear operator child node? How to treat if one variable is a constant? +// as a linear operator child node. 
// ------------------------------------------------------------------------------ static void forward(expr *node, const double *u) { diff --git a/src/bivariate/const_scalar_mult.c b/src/bivariate/scalar_mult.c similarity index 87% rename from src/bivariate/const_scalar_mult.c rename to src/bivariate/scalar_mult.c index 3863730..19aae60 100644 --- a/src/bivariate/const_scalar_mult.c +++ b/src/bivariate/scalar_mult.c @@ -32,7 +32,7 @@ static void forward(expr *node, const double *u) child->forward(child, u); /* local forward pass: multiply each element by scalar a */ - const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + scalar_mult_expr *sn = (scalar_mult_expr *) node; double a = sn->param_source->value[0]; for (int i = 0; i < node->size; i++) { @@ -56,7 +56,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *child = node->left; - const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + scalar_mult_expr *sn = (scalar_mult_expr *) node; double a = sn->param_source->value[0]; /* evaluate child */ @@ -87,7 +87,7 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; x->eval_wsum_hess(x, w); - const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; + scalar_mult_expr *sn = (scalar_mult_expr *) node; double a = sn->param_source->value[0]; for (int j = 0; j < x->wsum_hess->nnz; j++) { @@ -103,17 +103,14 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node; - if (sn->param_source) - { - free_expr(sn->param_source); - } + scalar_mult_expr *sn = (scalar_mult_expr *) node; + free_expr(sn->param_source); } expr *new_scalar_mult(expr *param_node, expr *child) { - const_scalar_mult_expr *mult_node = - (const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr)); + scalar_mult_expr *mult_node = + (scalar_mult_expr *) calloc(1, sizeof(scalar_mult_expr)); expr *node = &mult_node->base; 
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, diff --git a/src/bivariate/const_vector_mult.c b/src/bivariate/vector_mult.c similarity index 88% rename from src/bivariate/const_vector_mult.c rename to src/bivariate/vector_mult.c index 8eba16a..83e6cbf 100644 --- a/src/bivariate/const_vector_mult.c +++ b/src/bivariate/vector_mult.c @@ -27,7 +27,7 @@ static void forward(expr *node, const double *u) { expr *child = node->left; - const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + vector_mult_expr *vn = (vector_mult_expr *) node; const double *a = vn->param_source->value; /* child's forward pass */ @@ -56,7 +56,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + vector_mult_expr *vn = (vector_mult_expr *) node; const double *a = vn->param_source->value; /* evaluate x */ @@ -90,7 +90,7 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - const_vector_mult_expr *vn = (const_vector_mult_expr *) node; + vector_mult_expr *vn = (vector_mult_expr *) node; const double *a = vn->param_source->value; /* scale weights w by a */ @@ -113,17 +113,14 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - const_vector_mult_expr *vnode = (const_vector_mult_expr *) node; - if (vnode->param_source) - { - free_expr(vnode->param_source); - } + vector_mult_expr *vnode = (vector_mult_expr *) node; + free_expr(vnode->param_source); } expr *new_vector_mult(expr *param_node, expr *child) { - const_vector_mult_expr *vnode = - (const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr)); + vector_mult_expr *vnode = + (vector_mult_expr *) calloc(1, sizeof(vector_mult_expr)); expr *node = &vnode->base; init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init, diff --git a/src/other/quad_form.c 
b/src/other/quad_form.c index d03c70e..25fedce 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -174,7 +174,6 @@ static void free_type_data(expr *node) { quad_form_expr *qnode = (quad_form_expr *) node; free_csr_matrix(qnode->Q); - qnode->Q = NULL; } static bool is_affine(const expr *node) diff --git a/tests/all_tests.c b/tests/all_tests.c index c2d05c8..219014c 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -11,7 +11,7 @@ #include "forward_pass/affine/test_neg.h" #include "forward_pass/affine/test_promote.h" #include "forward_pass/affine/test_sum.h" -#include "forward_pass/affine/test_variable_constant.h" +#include "forward_pass/affine/test_variable_parameter.h" #include "forward_pass/composite/test_composite.h" #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" @@ -20,8 +20,8 @@ #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" #include "jacobian_tests/test_composite.h" -#include "jacobian_tests/test_const_scalar_mult.h" -#include "jacobian_tests/test_const_vector_mult.h" +#include "jacobian_tests/test_scalar_mult.h" +#include "jacobian_tests/test_vector_mult.h" #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" #include "jacobian_tests/test_index.h" @@ -57,8 +57,8 @@ #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" #include "wsum_hess/test_broadcast.h" -#include "wsum_hess/test_const_scalar_mult.h" -#include "wsum_hess/test_const_vector_mult.h" +#include "wsum_hess/test_scalar_mult.h" +#include "wsum_hess/test_vector_mult.h" #include "wsum_hess/test_hstack.h" #include "wsum_hess/test_index.h" #include "wsum_hess/test_left_matmul.h" @@ -91,7 +91,7 @@ int main(void) #ifndef PROFILE_ONLY printf("--- Forward Pass Tests ---\n"); mu_run_test(test_variable, tests_run); - mu_run_test(test_constant, tests_run); + mu_run_test(test_fixed_parameter, tests_run); mu_run_test(test_addition, 
tests_run); mu_run_test(test_linear_op, tests_run); mu_run_test(test_neg_forward, tests_run); diff --git a/tests/forward_pass/affine/test_variable_constant.h b/tests/forward_pass/affine/test_variable_parameter.h similarity index 83% rename from tests/forward_pass/affine/test_variable_constant.h rename to tests/forward_pass/affine/test_variable_parameter.h index ea9b609..f9c29fb 100644 --- a/tests/forward_pass/affine/test_variable_constant.h +++ b/tests/forward_pass/affine/test_variable_parameter.h @@ -18,13 +18,13 @@ const char *test_variable() return 0; } -const char *test_constant() +const char *test_fixed_parameter() { double c[2] = {5.0, 10.0}; double u[2] = {0.0, 0.0}; expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); const_node->forward(const_node, u); - mu_assert("Constant test failed", cmp_double_array(const_node->value, c, 2)); + mu_assert("Fixed parameter test failed", cmp_double_array(const_node->value, c, 2)); free_expr(const_node); return 0; } diff --git a/tests/jacobian_tests/test_const_scalar_mult.h b/tests/jacobian_tests/test_scalar_mult.h similarity index 100% rename from tests/jacobian_tests/test_const_scalar_mult.h rename to tests/jacobian_tests/test_scalar_mult.h diff --git a/tests/jacobian_tests/test_const_vector_mult.h b/tests/jacobian_tests/test_vector_mult.h similarity index 100% rename from tests/jacobian_tests/test_const_vector_mult.h rename to tests/jacobian_tests/test_vector_mult.h diff --git a/tests/wsum_hess/test_const_scalar_mult.h b/tests/wsum_hess/test_scalar_mult.h similarity index 100% rename from tests/wsum_hess/test_const_scalar_mult.h rename to tests/wsum_hess/test_scalar_mult.h diff --git a/tests/wsum_hess/test_const_vector_mult.h b/tests/wsum_hess/test_vector_mult.h similarity index 100% rename from tests/wsum_hess/test_const_vector_mult.h rename to tests/wsum_hess/test_vector_mult.h From 71dddf1a4cc03c887e7c7c9ae305952683462335 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Sun, 15 Feb 2026 17:56:37 -0500 
Subject: [PATCH 16/24] Run clang-format on cleanup changes Co-Authored-By: Claude Opus 4.6 --- tests/all_tests.c | 8 ++++---- tests/forward_pass/affine/test_variable_parameter.h | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/all_tests.c b/tests/all_tests.c index 219014c..8cede31 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -20,8 +20,6 @@ #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" #include "jacobian_tests/test_composite.h" -#include "jacobian_tests/test_scalar_mult.h" -#include "jacobian_tests/test_vector_mult.h" #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" #include "jacobian_tests/test_index.h" @@ -39,9 +37,11 @@ #include "jacobian_tests/test_rel_entr_scalar_vector.h" #include "jacobian_tests/test_rel_entr_vector_scalar.h" #include "jacobian_tests/test_right_matmul.h" +#include "jacobian_tests/test_scalar_mult.h" #include "jacobian_tests/test_sum.h" #include "jacobian_tests/test_trace.h" #include "jacobian_tests/test_transpose.h" +#include "jacobian_tests/test_vector_mult.h" #include "problem/test_param_prob.h" #include "problem/test_problem.h" #include "utils/test_csc_matrix.h" @@ -57,8 +57,6 @@ #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" #include "wsum_hess/test_broadcast.h" -#include "wsum_hess/test_scalar_mult.h" -#include "wsum_hess/test_vector_mult.h" #include "wsum_hess/test_hstack.h" #include "wsum_hess/test_index.h" #include "wsum_hess/test_left_matmul.h" @@ -73,9 +71,11 @@ #include "wsum_hess/test_rel_entr_scalar_vector.h" #include "wsum_hess/test_rel_entr_vector_scalar.h" #include "wsum_hess/test_right_matmul.h" +#include "wsum_hess/test_scalar_mult.h" #include "wsum_hess/test_sum.h" #include "wsum_hess/test_trace.h" #include "wsum_hess/test_transpose.h" +#include "wsum_hess/test_vector_mult.h" #endif /* PROFILE_ONLY */ #ifdef PROFILE_ONLY diff --git 
a/tests/forward_pass/affine/test_variable_parameter.h b/tests/forward_pass/affine/test_variable_parameter.h index f9c29fb..2cfea7b 100644 --- a/tests/forward_pass/affine/test_variable_parameter.h +++ b/tests/forward_pass/affine/test_variable_parameter.h @@ -24,7 +24,8 @@ const char *test_fixed_parameter() double u[2] = {0.0, 0.0}; expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); const_node->forward(const_node, u); - mu_assert("Fixed parameter test failed", cmp_double_array(const_node->value, c, 2)); + mu_assert("Fixed parameter test failed", + cmp_double_array(const_node->value, c, 2)); free_expr(const_node); return 0; } From 36864d47e63a6f53726905e222d4b5c1d3845706 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Sun, 15 Feb 2026 19:33:13 -0500 Subject: [PATCH 17/24] Simplify new_left_matmul: accept CSR directly, remove sparse/dense branching - Change signature to (expr *param_node, expr *child, const CSR_Matrix *A) - Constructor copies A with new_csr() instead of rebuilding from dense values - Remove src_m/src_n fields from left_matmul_expr (use A->m directly) - Allow param_node=NULL for fixed constants (no-op in refresh_param_values) - Update all tests to pass CSR directly; fixed-constant tests use NULL param Co-Authored-By: Claude Opus 4.6 --- include/bivariate.h | 2 +- include/subexpr.h | 1 - src/bivariate/left_matmul.c | 55 +++++-------------------- src/bivariate/right_matmul.c | 2 +- tests/jacobian_tests/test_left_matmul.h | 43 +++++++++++++------ tests/jacobian_tests/test_transpose.h | 16 +++++-- tests/problem/test_param_prob.h | 14 ++++++- tests/profiling/profile_left_matmul.h | 25 +++++++---- tests/wsum_hess/test_left_matmul.h | 42 +++++++++++++------ 9 files changed, 116 insertions(+), 84 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index 83c31db..19ecd9d 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -31,7 +31,7 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right); expr *new_matmul(expr 
*x, expr *y); /* Left matrix multiplication: A @ f(x) where A comes from a parameter node */ -expr *new_left_matmul(expr *param_node, expr *child); +expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A); /* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */ expr *new_right_matmul(expr *u, const CSR_Matrix *A); diff --git a/include/subexpr.h b/include/subexpr.h index 74ae4b1..3a9cc54 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -127,7 +127,6 @@ typedef struct left_matmul_expr CSC_Matrix *J_CSC; int *csc_to_csr_workspace; expr *param_source; /* parameter node; A/AT values are refreshed from this */ - int src_m, src_n; /* original matrix dimensions */ } left_matmul_expr; /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 238dbb4..5269c58 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -51,11 +51,14 @@ #include /* Refresh A and AT values from param_source. - A is the small m x n matrix (NOT block-diagonal). */ + A is the small m x n matrix (NOT block-diagonal). + No-op when param_source is NULL (fixed constant — values already in A). 
*/ static void refresh_param_values(left_matmul_expr *lin_node) { + if (!lin_node->param_source) return; + const double *src = lin_node->param_source->value; - int m = lin_node->src_m; + int m = lin_node->A->m; CSR_Matrix *A = lin_node->A; /* Fill A values from column-major source, following existing sparsity pattern */ @@ -163,10 +166,10 @@ static void eval_wsum_hess(expr *node, const double *w) node->wsum_hess->nnz * sizeof(double)); } -expr *new_left_matmul(expr *param_node, expr *child) +expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) { - int A_m = param_node->d1; - int A_n = param_node->d2; + int A_m = A->m; + int A_n = A->n; /* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */ int d1, d2, n_blocks; @@ -188,40 +191,6 @@ expr *new_left_matmul(expr *param_node, expr *child) exit(1); } - /* Build CSR from param_node's column-major values. - * For fixed parameters (PARAM_FIXED), skip zeros to preserve sparsity. - * For updatable parameters, build dense CSR since sparsity may change. 
*/ - parameter_expr *pnode = (parameter_expr *) param_node; - int sparse = (pnode->param_id == PARAM_FIXED); - - int nnz = 0; - if (sparse) - { - for (int row = 0; row < A_m; row++) - for (int col = 0; col < A_n; col++) - if (param_node->value[row + col * A_m] != 0.0) nnz++; - } - else - { - nnz = A_m * A_n; - } - - CSR_Matrix *A = new_csr_matrix(A_m, A_n, nnz); - int idx = 0; - for (int row = 0; row < A_m; row++) - { - A->p[row] = idx; - for (int col = 0; col < A_n; col++) - { - double val = param_node->value[row + col * A_m]; - if (sparse && val == 0.0) continue; - A->i[idx] = col; - A->x[idx] = val; - idx++; - } - } - A->p[A_m] = idx; - /* Allocate the type-specific struct */ left_matmul_expr *lin_node = (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr)); @@ -235,13 +204,11 @@ expr *new_left_matmul(expr *param_node, expr *child) node->iwork = (int *) malloc(MAX(A_n, node->n_vars) * sizeof(int)); lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int)); lin_node->n_blocks = n_blocks; - lin_node->A = A; /* transfer ownership */ - lin_node->AT = transpose(A, node->iwork); + lin_node->A = new_csr(A); + lin_node->AT = transpose(lin_node->A, node->iwork); lin_node->param_source = param_node; - lin_node->src_m = A_m; - lin_node->src_n = A_n; - expr_retain(param_node); + if (param_node) expr_retain(param_node); return node; } diff --git a/src/bivariate/right_matmul.c b/src/bivariate/right_matmul.c index e64bf3c..2f7d8c8 100644 --- a/src/bivariate/right_matmul.c +++ b/src/bivariate/right_matmul.c @@ -43,7 +43,7 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A) expr *u_transpose = new_transpose(u); expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major); - expr *left_matmul_node = new_left_matmul(param_node, u_transpose); + expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT); expr *node = new_transpose(left_matmul_node); free(col_major); diff --git a/tests/jacobian_tests/test_left_matmul.h 
b/tests/jacobian_tests/test_left_matmul.h index 854687b..feda189 100644 --- a/tests/jacobian_tests/test_left_matmul.h +++ b/tests/jacobian_tests/test_left_matmul.h @@ -1,5 +1,6 @@ #include #include +#include #include "bivariate.h" #include "elementwise_univariate.h" @@ -31,12 +32,18 @@ const char *test_jacobian_left_matmul_log() double x_vals[3] = {1.0, 2.0, 3.0}; expr *x = new_variable(3, 1, 0, 3); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(A_param, log_x); + expr *A_log_x = new_left_matmul(NULL, log_x, A); + free_csr_matrix(A); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -69,12 +76,18 @@ const char *test_jacobian_left_matmul_log_matrix() double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; expr *x = new_variable(3, 2, 0, 6); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(A_param, log_x); + expr *A_log_x 
= new_left_matmul(NULL, log_x, A); + free_csr_matrix(A); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -134,13 +147,19 @@ const char *test_jacobian_left_matmul_log_composite() memcpy(B->i, B_i, 9 * sizeof(int)); memcpy(B->x, B_x, 9 * sizeof(double)); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *Bx = new_linear(x, B, NULL); expr *log_Bx = new_log(Bx); - expr *A_log_Bx = new_left_matmul(A_param, log_Bx); + expr *A_log_Bx = new_left_matmul(NULL, log_Bx, A); + free_csr_matrix(A); A_log_Bx->forward(A_log_Bx, x_vals); A_log_Bx->jacobian_init(A_log_Bx); diff --git a/tests/jacobian_tests/test_transpose.h b/tests/jacobian_tests/test_transpose.h index 581b97c..0870d20 100644 --- a/tests/jacobian_tests/test_transpose.h +++ b/tests/jacobian_tests/test_transpose.h @@ -3,21 +3,29 @@ #define TEST_TRANSPOSE_H #include "affine.h" +#include "bivariate.h" #include "minunit.h" #include "subexpr.h" #include "test_helpers.h" #include #include +#include const char *test_jacobian_transpose() { - /* A = [1 2; 3 4] in column-major order: [1, 3, 2, 4] */ - double A_vals[4] = {1.0, 3.0, 2.0, 4.0}; - expr *A_param = new_parameter(2, 2, PARAM_FIXED, 2, A_vals); + /* A = [1 2; 3 4] as dense 2x2 CSR */ + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int Ap[3] = {0, 2, 4}; + int Ai[4] = {0, 1, 0, 1}; + double Ax[4] = {1.0, 2.0, 3.0, 4.0}; + memcpy(A->p, Ap, 3 * sizeof(int)); + memcpy(A->i, Ai, 4 * sizeof(int)); + memcpy(A->x, Ax, 4 * sizeof(double)); // X = [1 
2; 3 4] (columnwise: x = [1 3 2 4]) expr *X = new_variable(2, 2, 0, 4); - expr *AX = new_left_matmul(A_param, X); + expr *AX = new_left_matmul(NULL, X, A); + free_csr_matrix(A); expr *transpose_AX = new_transpose(AX); double u[4] = {1, 3, 2, 4}; transpose_AX->forward(transpose_AX, u); diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h index 250651c..fcab20a 100644 --- a/tests/problem/test_param_prob.h +++ b/tests/problem/test_param_prob.h @@ -3,6 +3,7 @@ #include #include +#include #include "affine.h" #include "bivariate.h" @@ -185,7 +186,18 @@ const char *test_param_left_matmul_problem(void) /* Constraint: A @ x */ expr *x_con = new_variable(2, 1, 0, n_vars); expr *A_param = new_parameter(2, 2, 0, n_vars, NULL); - expr *constraint = new_left_matmul(A_param, x_con); + + /* Dense 2x2 CSR with placeholder zeros (values refreshed from A_param) */ + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int Ap[3] = {0, 2, 4}; + int Ai[4] = {0, 1, 0, 1}; + double Ax[4] = {0.0, 0.0, 0.0, 0.0}; + memcpy(A->p, Ap, 3 * sizeof(int)); + memcpy(A->i, Ai, 4 * sizeof(int)); + memcpy(A->x, Ax, 4 * sizeof(double)); + + expr *constraint = new_left_matmul(A_param, x_con, A); + free_csr_matrix(A); expr *constraints[1] = {constraint}; diff --git a/tests/profiling/profile_left_matmul.h b/tests/profiling/profile_left_matmul.h index 58940d9..bb2f12c 100644 --- a/tests/profiling/profile_left_matmul.h +++ b/tests/profiling/profile_left_matmul.h @@ -18,17 +18,26 @@ const char *profile_left_matmul() int n = 100; expr *X = new_variable(n, n, 0, n * n); - /* Create n x n parameter of all ones (column-major, but all ones so order - * doesn't matter) */ - double *A_vals = (double *) malloc(n * n * sizeof(double)); - for (int i = 0; i < n * n; i++) + /* Build dense n x n CSR (all ones) */ + int nnz = n * n; + CSR_Matrix *A = new_csr_matrix(n, n, nnz); { - A_vals[i] = 1.0; + int idx = 0; + for (int row = 0; row < n; row++) + { + A->p[row] = idx; + for (int col = 0; col < n; col++) 
+ { + A->i[idx] = col; + A->x[idx] = 1.0; + idx++; + } + } + A->p[n] = idx; } - expr *A_param = new_parameter(n, n, PARAM_FIXED, n, A_vals); - free(A_vals); - expr *AX = new_left_matmul(A_param, X); + expr *AX = new_left_matmul(NULL, X, A); + free_csr_matrix(A); double *x_vals = (double *) malloc(n * n * sizeof(double)); for (int i = 0; i < n * n; i++) diff --git a/tests/wsum_hess/test_left_matmul.h b/tests/wsum_hess/test_left_matmul.h index 22853b3..5722f25 100644 --- a/tests/wsum_hess/test_left_matmul.h +++ b/tests/wsum_hess/test_left_matmul.h @@ -53,12 +53,18 @@ const char *test_wsum_hess_left_matmul() expr *x = new_variable(3, 1, 0, 3); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(A_param, log_x); + expr *A_log_x = new_left_matmul(NULL, log_x, A); + free_csr_matrix(A); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); @@ -145,13 +151,19 @@ const char *test_wsum_hess_left_matmul_composite() memcpy(B->i, B_i, 9 * sizeof(int)); memcpy(B->x, B_x, 9 * sizeof(double)); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 
3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *Bx = new_linear(x, B, NULL); expr *log_Bx = new_log(Bx); - expr *A_log_Bx = new_left_matmul(A_param, log_Bx); + expr *A_log_Bx = new_left_matmul(NULL, log_Bx, A); + free_csr_matrix(A); A_log_Bx->forward(A_log_Bx, x_vals); A_log_Bx->jacobian_init(A_log_Bx); @@ -212,12 +224,18 @@ const char *test_wsum_hess_left_matmul_matrix() expr *x = new_variable(3, 2, 0, 6); - /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */ - double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0}; - expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals); + /* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */ + CSR_Matrix *A = new_csr_matrix(4, 3, 7); + int A_p[5] = {0, 2, 4, 6, 7}; + int A_i[7] = {0, 2, 0, 2, 0, 2, 0}; + double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + memcpy(A->p, A_p, 5 * sizeof(int)); + memcpy(A->i, A_i, 7 * sizeof(int)); + memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(A_param, log_x); + expr *A_log_x = new_left_matmul(NULL, log_x, A); + free_csr_matrix(A); A_log_x->forward(A_log_x, x_vals); A_log_x->jacobian_init(A_log_x); From 978aa03023945143aa7656b384b4a5e9d02428a6 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 16 Feb 2026 00:19:33 -0500 Subject: [PATCH 18/24] Store param values in CSR data order, simplify refresh to memcpy - left_matmul: replace col-major loop in refresh_param_values with memcpy of nnz doubles (values now arrive in CSR data order) - right_matmul: pass AT->x directly to new_parameter(nnz, 1, ...), remove col-major round-trip allocation - test_param_prob: update theta arrays to CSR data order Co-Authored-By: Claude Opus 4.6 --- src/bivariate/left_matmul.c | 12 +++--------- src/bivariate/right_matmul.c | 12 ++---------- tests/problem/test_param_prob.h | 14 +++++++------- 3 files 
changed, 12 insertions(+), 26 deletions(-) diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 5269c58..bd73b82 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -57,17 +57,11 @@ static void refresh_param_values(left_matmul_expr *lin_node) { if (!lin_node->param_source) return; - const double *src = lin_node->param_source->value; - int m = lin_node->A->m; - CSR_Matrix *A = lin_node->A; - - /* Fill A values from column-major source, following existing sparsity pattern */ - for (int row = 0; row < m; row++) - for (int k = A->p[row]; k < A->p[row + 1]; k++) - A->x[k] = src[row + A->i[k] * m]; + memcpy(lin_node->A->x, lin_node->param_source->value, + lin_node->A->nnz * sizeof(double)); /* Recompute AT values from updated A */ - AT_fill_values(A, lin_node->AT, lin_node->base.iwork); + AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork); } static void forward(expr *node, const double *u) diff --git a/src/bivariate/right_matmul.c b/src/bivariate/right_matmul.c index 2f7d8c8..58b91e4 100644 --- a/src/bivariate/right_matmul.c +++ b/src/bivariate/right_matmul.c @@ -33,20 +33,12 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A) int *work_transpose = (int *) malloc(A->n * sizeof(int)); CSR_Matrix *AT = transpose(A, work_transpose); - /* Convert AT (CSR) to dense column-major array for parameter node */ - int m = AT->m; /* rows of AT = cols of A */ - int n = AT->n; /* cols of AT = rows of A */ - double *col_major = (double *) calloc(m * n, sizeof(double)); - for (int row = 0; row < m; row++) - for (int k = AT->p[row]; k < AT->p[row + 1]; k++) - col_major[row + AT->i[k] * m] = AT->x[k]; - + /* Parameter stores CSR data order (same as AT->x) */ expr *u_transpose = new_transpose(u); - expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major); + expr *param_node = new_parameter(AT->nnz, 1, PARAM_FIXED, u->n_vars, AT->x); expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT); expr 
*node = new_transpose(left_matmul_node); - free(col_major); free_csr_matrix(AT); free(work_transpose); return node; diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h index fcab20a..79761b4 100644 --- a/tests/problem/test_param_prob.h +++ b/tests/problem/test_param_prob.h @@ -164,14 +164,14 @@ const char *test_param_vector_mult_problem(void) * Test 3: left_param_matmul in constraint * * Problem: minimize sum(x), subject to A @ x, x size 2, A is 2x2 - * A is a 2x2 matrix parameter (param_id=0, size=4, column-major) - * A = [[1,2],[3,4]] → column-major theta = [1,3,2,4] + * A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order) + * A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4] * * At x=[1,2]: * constraint_values = [1*1+2*2, 3*1+4*2] = [5, 11] * jacobian = [[1,2],[3,4]] * - * After update A = [[5,6],[7,8]] → theta = [5,7,6,8]: + * After update A = [[5,6],[7,8]] → theta = [5,6,7,8]: * constraint_values = [5*1+6*2, 7*1+8*2] = [17, 23] * jacobian = [[5,6],[7,8]] */ @@ -208,8 +208,8 @@ const char *test_param_left_matmul_problem(void) problem_register_params(prob, param_nodes, 1); problem_init_derivatives(prob); - /* Set A = [[1,2],[3,4]], column-major: [1,3,2,4] */ - double theta[4] = {1.0, 3.0, 2.0, 4.0}; + /* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */ + double theta[4] = {1.0, 2.0, 3.0, 4.0}; problem_update_params(prob, theta); double u[2] = {1.0, 2.0}; @@ -235,8 +235,8 @@ const char *test_param_left_matmul_problem(void) double expected_x[4] = {1.0, 2.0, 3.0, 4.0}; mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4)); - /* Update A = [[5,6],[7,8]], column-major: [5,7,6,8] */ - double theta2[4] = {5.0, 7.0, 6.0, 8.0}; + /* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */ + double theta2[4] = {5.0, 6.0, 7.0, 8.0}; problem_update_params(prob, theta2); problem_constraint_forward(prob, u); From b3e23044eb859b67cb884f5f0ba20785841bc9ad Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 
08:37:54 -0800 Subject: [PATCH 19/24] small edits --- src/affine/parameter.c | 13 +++++++------ src/bivariate/left_matmul.c | 18 ++++++++++-------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/affine/parameter.c b/src/affine/parameter.c index c50a5cb..c7698db 100644 --- a/src/affine/parameter.c +++ b/src/affine/parameter.c @@ -70,16 +70,17 @@ static bool is_affine(const expr *node) expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values) { parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr)); - init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, - is_affine, wsum_hess_init, eval_wsum_hess, NULL); + expr *node = &pnode->base; + init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine, + wsum_hess_init, eval_wsum_hess, NULL); pnode->param_id = param_id; - /* If values provided (fixed constant), copy them now */ + /* If values provided (fixed constant), copy them now. + Otherwise values will be populated by problem_update_params. */ if (values != NULL) { - memcpy(pnode->base.value, values, pnode->base.size * sizeof(double)); + memcpy(node->value, values, node->size * sizeof(double)); } - /* Otherwise values will be populated by problem_update_params */ - return &pnode->base; + return node; } diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index bd73b82..679da89 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -162,21 +162,23 @@ static void eval_wsum_hess(expr *node, const double *w) expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) { - int A_m = A->m; - int A_n = A->n; + /* Dimension logic: handle numpy broadcasting (1, n) as (n, ). + We expect u->d1 == A->n. However, numpy's broadcasting rules allow users + to do A @ u where u is (n, ) which in C is actually (1, n). In that case + the result of A @ u is (m, ), which is (1, m) according to broadcasting + rules. 
We therefore check if this is the case. */ - /* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */ int d1, d2, n_blocks; - if (child->d1 == A_n) + if (child->d1 == A->n) { - d1 = A_m; + d1 = A->m; d2 = child->d2; n_blocks = child->d2; } - else if (child->d2 == A_n && child->d1 == 1) + else if (child->d2 == A->n && child->d1 == 1) { d1 = 1; - d2 = A_m; + d2 = A->m; n_blocks = 1; } else @@ -195,7 +197,7 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) expr_retain(child); /* Store small A (NOT block-diagonal) — block functions handle the rest */ - node->iwork = (int *) malloc(MAX(A_n, node->n_vars) * sizeof(int)); + node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int)); lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int)); lin_node->n_blocks = n_blocks; lin_node->A = new_csr(A); From 0edf92829b0725474fde97888957dec40a3112aa Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 09:13:32 -0800 Subject: [PATCH 20/24] fix AT workspace in left_matmul and add has_been_refreshed --- include/bivariate.h | 3 ++- include/problem.h | 2 +- include/subexpr.h | 4 +++- src/affine/parameter.c | 1 + src/bivariate/left_matmul.c | 24 +++++++++++++----------- src/problem.c | 6 +++++- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index 19ecd9d..3795786 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -30,7 +30,8 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right); /* Matrix multiplication: Z = X @ Y */ expr *new_matmul(expr *x, expr *y); -/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */ +/* Left matrix multiplication: A @ f(x) where A comes from a parameter node. + Only the forward pass possibly updates the parameter. 
*/ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A); /* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */ diff --git a/include/problem.h b/include/problem.h index 24d354b..c84be7d 100644 --- a/include/problem.h +++ b/include/problem.h @@ -59,7 +59,7 @@ typedef struct problem * hessian are called */ bool jacobian_called; - /* Parameter tracking for fast parameter updates */ + /* Parameter tracking for fast parameter updates. */ expr **param_nodes; /* weak references to parameter nodes in tree */ int n_param_nodes; int total_parameter_size; diff --git a/include/subexpr.h b/include/subexpr.h index 3a9cc54..6588621 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -35,7 +35,8 @@ struct int_double_pair; typedef struct parameter_expr { expr base; - int param_id; /* offset into global theta vector, or PARAM_FIXED */ + int param_id; /* offset into global theta vector, or PARAM_FIXED */ + bool has_been_refreshed; /* tracks whether parameter has been refreshed */ } parameter_expr; /* Type-specific expression structures that "inherit" from expr */ @@ -126,6 +127,7 @@ typedef struct left_matmul_expr CSC_Matrix *Jchild_CSC; CSC_Matrix *J_CSC; int *csc_to_csr_workspace; + int *AT_iwork; /* work for computing AT values from A */ expr *param_source; /* parameter node; A/AT values are refreshed from this */ } left_matmul_expr; diff --git a/src/affine/parameter.c b/src/affine/parameter.c index c7698db..7564827 100644 --- a/src/affine/parameter.c +++ b/src/affine/parameter.c @@ -74,6 +74,7 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL); pnode->param_id = param_id; + pnode->has_been_refreshed = false; /* If values provided (fixed constant), copy them now. Otherwise values will be populated by problem_update_params. 
*/ diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 679da89..27a6db9 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -55,13 +55,17 @@ No-op when param_source is NULL (fixed constant — values already in A). */ static void refresh_param_values(left_matmul_expr *lin_node) { - if (!lin_node->param_source) return; + parameter_expr *param = (parameter_expr *) lin_node->param_source; + if (!param || param->has_been_refreshed) return; + param->has_been_refreshed = true; + + /* update values of A */ memcpy(lin_node->A->x, lin_node->param_source->value, lin_node->A->nnz * sizeof(double)); - /* Recompute AT values from updated A */ - AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork); + /* update values of AT */ + AT_fill_values(lin_node->A, lin_node->AT, lin_node->AT_iwork); } static void forward(expr *node, const double *u) @@ -69,7 +73,7 @@ static void forward(expr *node, const double *u) expr *x = node->left; left_matmul_expr *lin_node = (left_matmul_expr *) node; - /* refresh A/AT from parameter source */ + /* possibly refresh A and AT */ refresh_param_values(lin_node); /* child's forward pass */ @@ -92,6 +96,7 @@ static void free_type_data(expr *node) free_csc_matrix(lin_node->Jchild_CSC); free_csc_matrix(lin_node->J_CSC); free(lin_node->csc_to_csr_workspace); + free(lin_node->AT_iwork); free_expr(lin_node->param_source); } @@ -119,9 +124,6 @@ static void eval_jacobian(expr *node) CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC; CSC_Matrix *J_CSC = lnode->J_CSC; - /* refresh A from parameter source */ - refresh_param_values(lnode); - /* evaluate child's jacobian and convert to CSC */ x->eval_jacobian(x); csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->iwork); @@ -167,7 +169,6 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) to do A @ u where u is (n, ) which in C is actually (1, n). 
In that case the result of A @ u is (m, ), which is (1, m) according to broadcasting rules. We therefore check if this is the case. */ - int d1, d2, n_blocks; if (child->d1 == A->n) { @@ -197,13 +198,14 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) expr_retain(child); /* Store small A (NOT block-diagonal) — block functions handle the rest */ - node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int)); + node->iwork = (int *) malloc(node->n_vars * sizeof(int)); + lin_node->AT_iwork = (int *) malloc(A->n * sizeof(int)); lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int)); lin_node->n_blocks = n_blocks; lin_node->A = new_csr(A); - lin_node->AT = transpose(lin_node->A, node->iwork); - + lin_node->AT = transpose(lin_node->A, lin_node->AT_iwork); lin_node->param_source = param_node; + if (param_node) expr_retain(param_node); return node; diff --git a/src/problem.c b/src/problem.c index c9e9a1e..4ee0078 100644 --- a/src/problem.c +++ b/src/problem.c @@ -456,7 +456,9 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node prob->total_parameter_size = 0; for (int i = 0; i < n_param_nodes; i++) + { prob->total_parameter_size += param_nodes[i]->size; + } } void problem_update_params(problem *prob, const double *theta) @@ -466,7 +468,9 @@ void problem_update_params(problem *prob, const double *theta) parameter_expr *p = (parameter_expr *) prob->param_nodes[i]; if (p->param_id == PARAM_FIXED) continue; memcpy(p->base.value, theta + p->param_id, p->base.size * sizeof(double)); + p->has_been_refreshed = false; } - /* Force re-evaluation of affine Jacobians on next call */ + + /* force re-evaluation of affine Jacobians on next call */ prob->jacobian_called = false; } From b985028c09e2c387fdca56b2144ffbff0284d34a Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 11:57:26 -0800 Subject: [PATCH 21/24] add parameter support for right matmul --- include/bivariate.h | 4 +- 
include/subexpr.h | 1 + src/bivariate/left_matmul.c | 6 +- src/bivariate/right_matmul.c | 35 +++++++-- src/bivariate/scalar_mult.c | 9 +-- src/bivariate/vector_mult.c | 18 +++-- tests/all_tests.c | 1 + tests/jacobian_tests/test_right_matmul.h | 4 +- tests/problem/test_param_prob.h | 94 ++++++++++++++++++++++++ tests/wsum_hess/test_right_matmul.h | 4 +- 10 files changed, 151 insertions(+), 25 deletions(-) diff --git a/include/bivariate.h b/include/bivariate.h index 3795786..a832d14 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -34,8 +34,8 @@ expr *new_matmul(expr *x, expr *y); Only the forward pass possibly updates the parameter. */ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A); -/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */ -expr *new_right_matmul(expr *u, const CSR_Matrix *A); +/* Right matrix multiplication: f(x) @ A where A comes from a parameter node. */ +expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A); /* Scalar multiplication: a * f(x) where a comes from a parameter node */ expr *new_scalar_mult(expr *param_node, expr *child); diff --git a/include/subexpr.h b/include/subexpr.h index 6588621..fdb57e4 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -129,6 +129,7 @@ typedef struct left_matmul_expr int *csc_to_csr_workspace; int *AT_iwork; /* work for computing AT values from A */ expr *param_source; /* parameter node; A/AT values are refreshed from this */ + void (*refresh_param_values)(struct left_matmul_expr *lin_node); } left_matmul_expr; /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 27a6db9..7298d83 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -48,6 +48,7 @@ */ #include "utils/utils.h" +#include #include /* Refresh A and AT values from param_source. 
@@ -60,6 +61,8 @@ static void refresh_param_values(left_matmul_expr *lin_node) if (!param || param->has_been_refreshed) return; param->has_been_refreshed = true; + assert(param->param_id != PARAM_FIXED); + /* update values of A */ memcpy(lin_node->A->x, lin_node->param_source->value, lin_node->A->nnz * sizeof(double)); @@ -74,7 +77,7 @@ static void forward(expr *node, const double *u) left_matmul_expr *lin_node = (left_matmul_expr *) node; /* possibly refresh A and AT */ - refresh_param_values(lin_node); + if (lin_node->refresh_param_values) lin_node->refresh_param_values(lin_node); /* child's forward pass */ node->left->forward(node->left, u); @@ -205,6 +208,7 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) lin_node->A = new_csr(A); lin_node->AT = transpose(lin_node->A, lin_node->AT_iwork); lin_node->param_source = param_node; + lin_node->refresh_param_values = refresh_param_values; if (param_node) expr_retain(param_node); diff --git a/src/bivariate/right_matmul.c b/src/bivariate/right_matmul.c index 58b91e4..dd0f38b 100644 --- a/src/bivariate/right_matmul.c +++ b/src/bivariate/right_matmul.c @@ -20,26 +20,49 @@ #include "subexpr.h" #include "utils/CSR_Matrix.h" #include "utils/linalg_sparse_matmuls.h" +#include #include +#include + +/* Refresh AT and A values from param_source for right matmul. + param_source stores values in CSR order for the original A. 
*/ +static void refresh_param_values(left_matmul_expr *lin_node) +{ + parameter_expr *param = (parameter_expr *) lin_node->param_source; + + if (!param || param->has_been_refreshed) return; + param->has_been_refreshed = true; + + assert(param->param_id != PARAM_FIXED); + + /* update values of original A (stored in lin_node->AT) */ + memcpy(lin_node->AT->x, lin_node->param_source->value, + lin_node->AT->nnz * sizeof(double)); + + /* update values of A^T (stored in lin_node->A) */ + AT_fill_values(lin_node->AT, lin_node->A, lin_node->AT_iwork); +} /* This file implements the atom 'right_matmul' corresponding to the operation y = - f(x) @ A, where A is a given matrix and f(x) is an arbitrary expression. + f(x) @ A, where A is a given matrix and f(x) is an arbitrary expression. We implement this by expressing right matmul in terms of left matmul and transpose: f(x) @ A = (A^T @ f(x)^T)^T. */ -expr *new_right_matmul(expr *u, const CSR_Matrix *A) +expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A) { /* We can express right matmul using left matmul and transpose: u @ A = (A^T @ u^T)^T. 
*/ int *work_transpose = (int *) malloc(A->n * sizeof(int)); CSR_Matrix *AT = transpose(A, work_transpose); - - /* Parameter stores CSR data order (same as AT->x) */ expr *u_transpose = new_transpose(u); - expr *param_node = new_parameter(AT->nnz, 1, PARAM_FIXED, u->n_vars, AT->x); expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT); expr *node = new_transpose(left_matmul_node); + /* functionality for parameter */ + left_matmul_expr *left_matmul_data = (left_matmul_expr *) left_matmul_node; + free(left_matmul_data->AT_iwork); + left_matmul_data->AT_iwork = work_transpose; + left_matmul_data->refresh_param_values = refresh_param_values; + free_csr_matrix(AT); - free(work_transpose); return node; } diff --git a/src/bivariate/scalar_mult.c b/src/bivariate/scalar_mult.c index 19aae60..415687e 100644 --- a/src/bivariate/scalar_mult.c +++ b/src/bivariate/scalar_mult.c @@ -32,8 +32,7 @@ static void forward(expr *node, const double *u) child->forward(child, u); /* local forward pass: multiply each element by scalar a */ - scalar_mult_expr *sn = (scalar_mult_expr *) node; - double a = sn->param_source->value[0]; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; for (int i = 0; i < node->size; i++) { node->value[i] = a * child->value[i]; @@ -56,8 +55,7 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *child = node->left; - scalar_mult_expr *sn = (scalar_mult_expr *) node; - double a = sn->param_source->value[0]; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; /* evaluate child */ child->eval_jacobian(child); @@ -87,8 +85,7 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; x->eval_wsum_hess(x, w); - scalar_mult_expr *sn = (scalar_mult_expr *) node; - double a = sn->param_source->value[0]; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; for (int j = 0; j < x->wsum_hess->nnz; j++) { node->wsum_hess->x[j] = a * x->wsum_hess->x[j]; diff --git 
a/src/bivariate/vector_mult.c b/src/bivariate/vector_mult.c index 83e6cbf..21f8231 100644 --- a/src/bivariate/vector_mult.c +++ b/src/bivariate/vector_mult.c @@ -27,8 +27,10 @@ static void forward(expr *node, const double *u) { expr *child = node->left; - vector_mult_expr *vn = (vector_mult_expr *) node; - const double *a = vn->param_source->value; + // vector_mult_expr *vn = (vector_mult_expr *) node; + // const double *a = vn->param_source->value; + + const double *a = ((vector_mult_expr *) node)->param_source->value; /* child's forward pass */ child->forward(child, u); @@ -56,8 +58,10 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - vector_mult_expr *vn = (vector_mult_expr *) node; - const double *a = vn->param_source->value; + // vector_mult_expr *vn = (vector_mult_expr *) node; + // const double *a = vn->param_source->value; + + const double *a = ((vector_mult_expr *) node)->param_source->value; /* evaluate x */ x->eval_jacobian(x); @@ -90,8 +94,10 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - vector_mult_expr *vn = (vector_mult_expr *) node; - const double *a = vn->param_source->value; + // vector_mult_expr *vn = (vector_mult_expr *) node; + // const double *a = vn->param_source->value; + + const double *a = ((vector_mult_expr *) node)->param_source->value; /* scale weights w by a */ for (int i = 0; i < node->size; i++) diff --git a/tests/all_tests.c b/tests/all_tests.c index 8cede31..e652965 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -284,6 +284,7 @@ int main(void) mu_run_test(test_param_scalar_mult_problem, tests_run); mu_run_test(test_param_vector_mult_problem, tests_run); mu_run_test(test_param_left_matmul_problem, tests_run); + mu_run_test(test_param_right_matmul_problem, tests_run); #endif /* PROFILE_ONLY */ #ifdef PROFILE_ONLY diff --git a/tests/jacobian_tests/test_right_matmul.h 
b/tests/jacobian_tests/test_right_matmul.h index 931df44..a181d60 100644 --- a/tests/jacobian_tests/test_right_matmul.h +++ b/tests/jacobian_tests/test_right_matmul.h @@ -27,7 +27,7 @@ const char *test_jacobian_right_matmul_log() memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); log_x_A->jacobian_init(log_x_A); @@ -76,7 +76,7 @@ const char *test_jacobian_right_matmul_log_vector() memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); log_x_A->jacobian_init(log_x_A); diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h index 79761b4..29b9c1b 100644 --- a/tests/problem/test_param_prob.h +++ b/tests/problem/test_param_prob.h @@ -254,4 +254,98 @@ const char *test_param_left_matmul_problem(void) return 0; } +/* + * Test 4: right_param_matmul in constraint + * + * Problem: minimize sum(x), subject to x @ A, x size 1x2, A is 2x2 + * A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order) + * A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4] + * + * At x=[1,2]: + * constraint_values = [1*1+2*3, 1*2+2*4] = [7, 10] + * jacobian = [[1,3],[2,4]] = A^T + * + * After update A = [[5,6],[7,8]] → theta = [5,6,7,8]: + * constraint_values = [1*5+2*7, 1*6+2*8] = [19, 22] + * jacobian = [[5,7],[6,8]] = A^T + */ +const char *test_param_right_matmul_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(1, 2, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: x @ A */ + expr *x_con = new_variable(1, 2, 0, n_vars); + expr *A_param = new_parameter(2, 2, 0, n_vars, NULL); + + /* Dense 2x2 CSR with placeholder zeros (values refreshed from A_param) */ + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int Ap[3] = {0, 2, 
4}; + int Ai[4] = {0, 1, 0, 1}; + double Ax[4] = {0.0, 0.0, 0.0, 0.0}; + memcpy(A->p, Ap, 3 * sizeof(int)); + memcpy(A->i, Ai, 4 * sizeof(int)); + memcpy(A->x, Ax, 4 * sizeof(double)); + + expr *constraint = new_right_matmul(A_param, x_con, A); + free_csr_matrix(A); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, true); + + expr *param_nodes[1] = {A_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */ + double theta[4] = {1.0, 2.0, 3.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {7.0, 10.0}; + mu_assert("constraint values wrong (A1)", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + /* Dense jacobian = [[1,3],[2,4]] = A^T, CSR: row 0 → cols 0,1 vals 1,3; + * row 1 → cols 0,1 vals 2,4 */ + int expected_p[3] = {0, 2, 4}; + mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[4] = {0, 1, 0, 1}; + mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4)); + + double expected_x[4] = {1.0, 3.0, 2.0, 4.0}; + mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4)); + + /* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */ + double theta2[4] = {5.0, 6.0, 7.0, 8.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {19.0, 22.0}; + mu_assert("constraint values wrong (A2)", + cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[4] = {5.0, 7.0, 6.0, 8.0}; + mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4)); + + free_problem(prob); + + return 0; 
+} + #endif /* TEST_PARAM_PROB_H */ diff --git a/tests/wsum_hess/test_right_matmul.h b/tests/wsum_hess/test_right_matmul.h index ca109f8..575f38c 100644 --- a/tests/wsum_hess/test_right_matmul.h +++ b/tests/wsum_hess/test_right_matmul.h @@ -33,7 +33,7 @@ const char *test_wsum_hess_right_matmul() memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); log_x_A->jacobian_init(log_x_A); @@ -83,7 +83,7 @@ const char *test_wsum_hess_right_matmul_vector() memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); log_x_A->jacobian_init(log_x_A); From 4e25da77feab865303ebe37ebe2e9c737043ca11 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 12:01:06 -0800 Subject: [PATCH 22/24] clean up --- src/bivariate/left_matmul.c | 5 ++++- src/bivariate/vector_mult.c | 9 --------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 7298d83..5a8e2d2 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -200,7 +200,10 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A) node->left = child; expr_retain(child); - /* Store small A (NOT block-diagonal) — block functions handle the rest */ + /* Store small A (NOT block-diagonal) — block functions handle the rest + Allocate workspace. iwork is used for converting J_child csr to csc + (requiring size node->n_vars). 
csc_to_csr_workspace is used for + converting J_CSC to CSR (requiring node->size) */ node->iwork = (int *) malloc(node->n_vars * sizeof(int)); lin_node->AT_iwork = (int *) malloc(A->n * sizeof(int)); lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int)); diff --git a/src/bivariate/vector_mult.c b/src/bivariate/vector_mult.c index 21f8231..6956aec 100644 --- a/src/bivariate/vector_mult.c +++ b/src/bivariate/vector_mult.c @@ -27,9 +27,6 @@ static void forward(expr *node, const double *u) { expr *child = node->left; - // vector_mult_expr *vn = (vector_mult_expr *) node; - // const double *a = vn->param_source->value; - const double *a = ((vector_mult_expr *) node)->param_source->value; /* child's forward pass */ @@ -58,9 +55,6 @@ static void jacobian_init(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - // vector_mult_expr *vn = (vector_mult_expr *) node; - // const double *a = vn->param_source->value; - const double *a = ((vector_mult_expr *) node)->param_source->value; /* evaluate x */ @@ -94,9 +88,6 @@ static void wsum_hess_init(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - // vector_mult_expr *vn = (vector_mult_expr *) node; - // const double *a = vn->param_source->value; - const double *a = ((vector_mult_expr *) node)->param_source->value; /* scale weights w by a */ From 4b58b1cfb5c642cfb6cf3e474108f3e6c627023d Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 12:25:11 -0800 Subject: [PATCH 23/24] add back setting ptrs to null after freeingmake --- src/affine/index.c | 1 + src/affine/linear_op.c | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/affine/index.c b/src/affine/index.c index 5702100..80fe77d 100644 --- a/src/affine/index.c +++ b/src/affine/index.c @@ -155,6 +155,7 @@ static void free_type_data(expr *node) { index_expr *idx = (index_expr *) node; free(idx->indices); + idx->indices = NULL; } expr *new_index(expr *child, 
int d1, int d2, const int *indices, int n_idxs) diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 04e5b9b..9bc262e 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -55,14 +55,13 @@ static void free_type_data(expr *node) if (!node->jacobian) { free_csr_matrix(lin_node->A_csr); + lin_node->A_csr = NULL; } free_csc_matrix(lin_node->A_csc); - - if (lin_node->b != NULL) - { - free(lin_node->b); - } + lin_node->A_csc = NULL; + free(lin_node->b); + lin_node->b = NULL; } static void jacobian_init(expr *node) From 1c3ebeecbe8843bc682c3ac9d24a81e722cd14d2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 17 Feb 2026 12:25:29 -0800 Subject: [PATCH 24/24] added free and null macro --- include/memory_wrappers.h | 13 +++++++++++++ include/utils/Vec_macros.h | 8 +++++--- src/affine/hstack.c | 3 ++- src/affine/index.c | 6 +++--- src/affine/linear_op.c | 4 ++-- src/affine/sum.c | 3 ++- src/affine/trace.c | 3 ++- src/bivariate/left_matmul.c | 5 +++-- src/bivariate/quad_over_lin.c | 3 ++- src/bivariate/right_matmul.c | 3 ++- src/expr.c | 12 +++++------- src/other/prod_axis_one.c | 7 ++++--- src/other/prod_axis_zero.c | 7 ++++--- src/other/quad_form.c | 3 ++- src/problem.c | 15 ++++++++------- src/utils/CSC_Matrix.c | 23 ++++++++++++----------- src/utils/CSR_Matrix.c | 11 ++++++----- src/utils/int_double_pair.c | 3 ++- src/utils/linalg_sparse_matmuls.c | 5 +++-- tests/profiling/profile_left_matmul.h | 3 ++- tests/utils/test_csr_csc_conversion.h | 11 ++++++----- tests/utils/test_csr_matrix.h | 3 ++- 22 files changed, 92 insertions(+), 62 deletions(-) create mode 100644 include/memory_wrappers.h diff --git a/include/memory_wrappers.h b/include/memory_wrappers.h new file mode 100644 index 0000000..00ed6bd --- /dev/null +++ b/include/memory_wrappers.h @@ -0,0 +1,13 @@ +#ifndef MEMORY_WRAPPERS_H +#define MEMORY_WRAPPERS_H + +#include <stdlib.h> + +#define FREE_AND_NULL(p) \ + do \ + { \ + free(p); \ + (p) = NULL; \ + } while (0) + +#endif /* MEMORY_WRAPPERS_H */ 
diff --git a/include/utils/Vec_macros.h b/include/utils/Vec_macros.h index def45e4..4f0c8ba 100644 --- a/include/utils/Vec_macros.h +++ b/include/utils/Vec_macros.h @@ -19,6 +19,7 @@ #ifndef VEC_MACROS_H #define VEC_MACROS_H +#include "memory_wrappers.h" #include #include #include @@ -48,7 +49,7 @@ vec->data = (TYPE *) malloc(capacity * sizeof(TYPE)); \ if (vec->data == NULL) \ { \ - free(vec); \ + FREE_AND_NULL(vec); \ return NULL; \ } \ \ @@ -59,8 +60,9 @@ \ static inline void TYPE_NAME##Vec_free(TYPE_NAME##Vec *vec) \ { \ - free(vec->data); \ - free(vec); \ + if (!vec) return; \ + FREE_AND_NULL(vec->data); \ + FREE_AND_NULL(vec); \ } \ \ static inline void TYPE_NAME##Vec_clear_no_resize(TYPE_NAME##Vec *vec) \ diff --git a/src/affine/hstack.c b/src/affine/hstack.c index b5d8fb8..39b6057 100644 --- a/src/affine/hstack.c +++ b/src/affine/hstack.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "affine.h" +#include "memory_wrappers.h" #include "utils/CSR_sum.h" #include #include @@ -168,7 +169,7 @@ static void free_type_data(expr *node) } free_csr_matrix(hnode->CSR_work); - free(hnode->args); + FREE_AND_NULL(hnode->args); } expr *new_hstack(expr **args, int n_args, int n_vars) diff --git a/src/affine/index.c b/src/affine/index.c index 80fe77d..c646d9c 100644 --- a/src/affine/index.c +++ b/src/affine/index.c @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #include "affine.h" +#include "memory_wrappers.h" #include "subexpr.h" #include #include @@ -38,7 +39,7 @@ static bool check_for_duplicates(const int *indices, int n_idxs, int max_idx) } seen[indices[i]] = true; } - free(seen); + FREE_AND_NULL(seen); return has_dup; } @@ -154,8 +155,7 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { index_expr *idx = (index_expr *) node; - free(idx->indices); - idx->indices = NULL; + FREE_AND_NULL(idx->indices); } expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs) diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 9bc262e..1eb252e 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "affine.h" +#include "memory_wrappers.h" #include #include #include @@ -60,8 +61,7 @@ static void free_type_data(expr *node) free_csc_matrix(lin_node->A_csc); lin_node->A_csc = NULL; - free(lin_node->b); - lin_node->b = NULL; + FREE_AND_NULL(lin_node->b); } static void jacobian_init(expr *node) diff --git a/src/affine/sum.c b/src/affine/sum.c index d1fec14..ce30d70 100644 --- a/src/affine/sum.c +++ b/src/affine/sum.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "affine.h" +#include "memory_wrappers.h" #include "utils/CSR_sum.h" #include "utils/int_double_pair.h" #include "utils/mini_numpy.h" @@ -175,7 +176,7 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { sum_expr *snode = (sum_expr *) node; - free(snode->idx_map); + FREE_AND_NULL(snode->idx_map); } expr *new_sum(expr *child, int axis) diff --git a/src/affine/trace.c b/src/affine/trace.c index 1732c40..5a25bfb 100644 --- a/src/affine/trace.c +++ b/src/affine/trace.c @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #include "affine.h" +#include "memory_wrappers.h" #include "utils/CSR_sum.h" #include "utils/int_double_pair.h" #include "utils/utils.h" @@ -139,7 +140,7 @@ static void free_type_data(expr *node) if (node) { trace_expr *tnode = (trace_expr *) node; - free(tnode->idx_map); + FREE_AND_NULL(tnode->idx_map); } } diff --git a/src/bivariate/left_matmul.c b/src/bivariate/left_matmul.c index 5a8e2d2..1cecfc1 100644 --- a/src/bivariate/left_matmul.c +++ b/src/bivariate/left_matmul.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "bivariate.h" +#include "memory_wrappers.h" #include "subexpr.h" #include "utils/Timer.h" #include "utils/linalg_sparse_matmuls.h" @@ -98,8 +99,8 @@ static void free_type_data(expr *node) free_csr_matrix(lin_node->AT); free_csc_matrix(lin_node->Jchild_CSC); free_csc_matrix(lin_node->J_CSC); - free(lin_node->csc_to_csr_workspace); - free(lin_node->AT_iwork); + FREE_AND_NULL(lin_node->csc_to_csr_workspace); + FREE_AND_NULL(lin_node->AT_iwork); free_expr(lin_node->param_source); } diff --git a/src/bivariate/quad_over_lin.c b/src/bivariate/quad_over_lin.c index 6b781d9..4da4135 100644 --- a/src/bivariate/quad_over_lin.c +++ b/src/bivariate/quad_over_lin.c @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #include "bivariate.h" +#include "memory_wrappers.h" #include "subexpr.h" #include "utils/CSC_Matrix.h" #include @@ -102,7 +103,7 @@ static void jacobian_init(expr *node) } assert(nonzero_cols == node->jacobian->nnz); - free(col_nz); + FREE_AND_NULL(col_nz); /* insert y variable index at correct position */ insert_idx(y->var_id, node->jacobian->i, node->jacobian->nnz); diff --git a/src/bivariate/right_matmul.c b/src/bivariate/right_matmul.c index dd0f38b..166ed72 100644 --- a/src/bivariate/right_matmul.c +++ b/src/bivariate/right_matmul.c @@ -17,6 +17,7 @@ */ #include "affine.h" #include "bivariate.h" +#include "memory_wrappers.h" #include "subexpr.h" #include "utils/CSR_Matrix.h" #include "utils/linalg_sparse_matmuls.h" @@ -59,7 +60,7 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A) /* functionality for parameter */ left_matmul_expr *left_matmul_data = (left_matmul_expr *) left_matmul_node; - free(left_matmul_data->AT_iwork); + FREE_AND_NULL(left_matmul_data->AT_iwork); left_matmul_data->AT_iwork = work_transpose; left_matmul_data->refresh_param_values = refresh_param_values; diff --git a/src/expr.c b/src/expr.c index 99a7175..d093064 100644 --- a/src/expr.c +++ b/src/expr.c @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #include "expr.h" +#include "memory_wrappers.h" #include "utils/int_double_pair.h" #include #include @@ -61,19 +62,16 @@ void free_expr(expr *node) } /* free value array and jacobian */ - free(node->value); + FREE_AND_NULL(node->value); free_csr_matrix(node->jacobian); free_csr_matrix(node->wsum_hess); - free(node->dwork); - free(node->iwork); - node->value = NULL; + FREE_AND_NULL(node->dwork); + FREE_AND_NULL(node->iwork); node->jacobian = NULL; node->wsum_hess = NULL; - node->dwork = NULL; - node->iwork = NULL; /* free the node itself */ - free(node); + FREE_AND_NULL(node); } void expr_retain(expr *node) diff --git a/src/other/prod_axis_one.c b/src/other/prod_axis_one.c index cd43229..8a3e02b 100644 --- a/src/other/prod_axis_one.c +++ b/src/other/prod_axis_one.c @@ -1,3 +1,4 @@ +#include "memory_wrappers.h" #include "other.h" #include #include @@ -371,9 +372,9 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { prod_axis *pnode = (prod_axis *) node; - free(pnode->num_of_zeros); - free(pnode->zero_index); - free(pnode->prod_nonzero); + FREE_AND_NULL(pnode->num_of_zeros); + FREE_AND_NULL(pnode->zero_index); + FREE_AND_NULL(pnode->prod_nonzero); } expr *new_prod_axis_one(expr *child) diff --git a/src/other/prod_axis_zero.c b/src/other/prod_axis_zero.c index ad85b6e..128e1a0 100644 --- a/src/other/prod_axis_zero.c +++ b/src/other/prod_axis_zero.c @@ -1,3 +1,4 @@ +#include "memory_wrappers.h" #include "other.h" #include #include @@ -330,9 +331,9 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { prod_axis *pnode = (prod_axis *) node; - free(pnode->num_of_zeros); - free(pnode->zero_index); - free(pnode->prod_nonzero); + FREE_AND_NULL(pnode->num_of_zeros); + FREE_AND_NULL(pnode->zero_index); + FREE_AND_NULL(pnode->prod_nonzero); } /* TODO: refactor to remove diagonal entry as nonzero since it's always zero */ diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 25fedce..fb7a97f 100644 --- 
a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -1,3 +1,4 @@ +#include "memory_wrappers.h" #include "other.h" #include "subexpr.h" #include "utils/CSC_Matrix.h" @@ -129,7 +130,7 @@ static void jacobian_init(expr *node) } } assert(nonzero_cols == node->jacobian->nnz); - free(col_nz); + FREE_AND_NULL(col_nz); node->jacobian->p[0] = 0; node->jacobian->p[1] = node->jacobian->nnz; diff --git a/src/problem.c b/src/problem.c index 4ee0078..294db92 100644 --- a/src/problem.c +++ b/src/problem.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "problem.h" +#include "memory_wrappers.h" #include "subexpr.h" #include "utils/CSR_sum.h" #include "utils/utils.h" @@ -235,7 +236,7 @@ void problem_init_hessian(problem *prob) prob->hess_idx_map = (int *) malloc(nnz * sizeof(int)); int *iwork = (int *) malloc(MAX(nnz, prob->n_vars) * sizeof(int)); problem_lagrange_hess_fill_sparsity(prob, iwork); - free(iwork); + FREE_AND_NULL(iwork); clock_gettime(CLOCK_MONOTONIC, &timer.end); prob->stats.time_init_derivatives += GET_ELAPSED_SECONDS(timer); @@ -286,14 +287,14 @@ void free_problem(problem *prob) if (prob == NULL) return; /* Free allocated arrays */ - free(prob->constraint_values); - free(prob->gradient_values); + FREE_AND_NULL(prob->constraint_values); + FREE_AND_NULL(prob->gradient_values); free_csr_matrix(prob->jacobian); free_csr_matrix(prob->lagrange_hessian); - free(prob->hess_idx_map); + FREE_AND_NULL(prob->hess_idx_map); /* Free parameter node array (weak references, not owned) */ - free(prob->param_nodes); + FREE_AND_NULL(prob->param_nodes); /* Release expression references (decrements refcount) */ free_expr(prob->objective); @@ -301,7 +302,7 @@ void free_problem(problem *prob) { free_expr(prob->constraints[i]); } - free(prob->constraints); + FREE_AND_NULL(prob->constraints); if (prob->verbose) { @@ -309,7 +310,7 @@ void free_problem(problem *prob) } /* Free problem struct */ - free(prob); + FREE_AND_NULL(prob); } double problem_objective_forward(problem 
*prob, const double *u) diff --git a/src/utils/CSC_Matrix.c b/src/utils/CSC_Matrix.c index d017a62..c445917 100644 --- a/src/utils/CSC_Matrix.c +++ b/src/utils/CSC_Matrix.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "utils/CSC_Matrix.h" +#include "memory_wrappers.h" #include "utils/iVec.h" #include #include @@ -32,10 +33,10 @@ CSC_Matrix *new_csc_matrix(int m, int n, int nnz) if (!matrix->p || !matrix->i || !matrix->x) { - free(matrix->p); - free(matrix->i); - free(matrix->x); - free(matrix); + FREE_AND_NULL(matrix->p); + FREE_AND_NULL(matrix->i); + FREE_AND_NULL(matrix->x); + FREE_AND_NULL(matrix); return NULL; } @@ -50,10 +51,10 @@ void free_csc_matrix(CSC_Matrix *matrix) { if (matrix) { - free(matrix->p); - free(matrix->i); - free(matrix->x); - free(matrix); + FREE_AND_NULL(matrix->p); + FREE_AND_NULL(matrix->i); + FREE_AND_NULL(matrix->x); + FREE_AND_NULL(matrix); } } @@ -105,7 +106,7 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A) symmetrize_csr(Cp, Ci->data, n, C); /* free workspace */ - free(Cp); + FREE_AND_NULL(Cp); iVec_free(Ci); return C; @@ -211,7 +212,7 @@ CSC_Matrix *csr_to_csc(const CSR_Matrix *A) } } - free(count); + FREE_AND_NULL(count); return C; } @@ -395,7 +396,7 @@ CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B) memcpy(C->i, Ci->data, nnz * sizeof(int)); /* free workspace */ - free(Cp); + FREE_AND_NULL(Cp); iVec_free(Ci); return C; diff --git a/src/utils/CSR_Matrix.c b/src/utils/CSR_Matrix.c index 513a457..0778804 100644 --- a/src/utils/CSR_Matrix.c +++ b/src/utils/CSR_Matrix.c @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #include "utils/CSR_Matrix.h" +#include "memory_wrappers.h" #include "utils/int_double_pair.h" #include "utils/utils.h" #include @@ -49,10 +50,10 @@ void free_csr_matrix(CSR_Matrix *matrix) { if (matrix) { - free(matrix->p); - free(matrix->i); - free(matrix->x); - free(matrix); + FREE_AND_NULL(matrix->p); + FREE_AND_NULL(matrix->i); + FREE_AND_NULL(matrix->x); + FREE_AND_NULL(matrix); } } @@ -448,5 +449,5 @@ void symmetrize_csr(const int *Ap, const int *Ai, int m, CSR_Matrix *C) } } - free(counts); + FREE_AND_NULL(counts); } diff --git a/src/utils/int_double_pair.c b/src/utils/int_double_pair.c index 6b49021..1c0d01b 100644 --- a/src/utils/int_double_pair.c +++ b/src/utils/int_double_pair.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "utils/int_double_pair.h" +#include "memory_wrappers.h" #include static int compare_int_double_pair(const void *a, const void *b) @@ -45,7 +46,7 @@ void set_int_double_pair_array(int_double_pair *pair, int *ints, double *doubles void free_int_double_pair_array(int_double_pair *array) { - free(array); + FREE_AND_NULL(array); } void sort_int_double_pair_array(int_double_pair *array, int size) diff --git a/src/utils/linalg_sparse_matmuls.c b/src/utils/linalg_sparse_matmuls.c index 82c03e0..0bbae7e 100644 --- a/src/utils/linalg_sparse_matmuls.c +++ b/src/utils/linalg_sparse_matmuls.c @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "memory_wrappers.h" #include "utils/CSC_Matrix.h" #include "utils/CSR_Matrix.h" #include "utils/iVec.h" @@ -183,7 +184,7 @@ CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A, memcpy(C->i, Ci->data, Ci->len * sizeof(int)); /* Clean up workspace */ - free(Cp); + FREE_AND_NULL(Cp); iVec_free(Ci); return C; @@ -312,7 +313,7 @@ CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B) CSR_Matrix *C = new_csr_matrix(m, p, nnz); memcpy(C->p, Cp, (m + 1) * sizeof(int)); memcpy(C->i, Ci->data, nnz * sizeof(int)); - free(Cp); + FREE_AND_NULL(Cp); iVec_free(Ci); return C; diff --git a/tests/profiling/profile_left_matmul.h b/tests/profiling/profile_left_matmul.h index bb2f12c..38a4bb3 100644 --- a/tests/profiling/profile_left_matmul.h +++ b/tests/profiling/profile_left_matmul.h @@ -7,6 +7,7 @@ #include "bivariate.h" #include "elementwise_univariate.h" #include "expr.h" +#include "memory_wrappers.h" #include "minunit.h" #include "subexpr.h" #include "test_helpers.h" @@ -61,7 +62,7 @@ const char *profile_left_matmul() printf("left_matmul jacobian eval time: %8.3f seconds\n", GET_ELAPSED_SECONDS(timer)); - free(x_vals); + FREE_AND_NULL(x_vals); free_expr(AX); return 0; } diff --git a/tests/utils/test_csr_csc_conversion.h b/tests/utils/test_csr_csc_conversion.h index c7daeb6..ba6f64a 100644 --- a/tests/utils/test_csr_csc_conversion.h +++ b/tests/utils/test_csr_csc_conversion.h @@ -3,6 +3,7 @@ #include #include +#include "memory_wrappers.h" #include "minunit.h" #include "test_helpers.h" #include "utils/CSC_Matrix.h" @@ -46,7 +47,7 @@ const char *test_csr_to_csc_split() mu_assert("C vals incorrect", cmp_double_array(C->x, Cx_correct, 5)); - free(iwork); + FREE_AND_NULL(iwork); free_csr_matrix(A); free_csc_matrix(C); @@ -90,7 +91,7 @@ const char *test_csc_to_csr_sparsity() mu_assert("C dimensions incorrect", C->m == 4 && C->n == 5); mu_assert("C nnz incorrect", C->nnz == 5); - free(iwork); + FREE_AND_NULL(iwork); free_csc_matrix(A); 
free_csr_matrix(C); @@ -123,7 +124,7 @@ const char *test_csc_to_csr_values() mu_assert("C vals incorrect", cmp_double_array(C->x, Cx_correct, 5)); - free(iwork); + FREE_AND_NULL(iwork); free_csc_matrix(A); free_csr_matrix(C); @@ -161,8 +162,8 @@ const char *test_csr_csc_csr_roundtrip() mu_assert("Round-trip: col indices incorrect", cmp_int_array(C->i, Ai, 8)); mu_assert("Round-trip: row pointers incorrect", cmp_int_array(C->p, Ap, 4)); - free(iwork_csc); - free(iwork_csr); + FREE_AND_NULL(iwork_csc); + FREE_AND_NULL(iwork_csr); free_csr_matrix(A); free_csc_matrix(B); free_csr_matrix(C); diff --git a/tests/utils/test_csr_matrix.h b/tests/utils/test_csr_matrix.h index 09b2e7d..c2d4f9c 100644 --- a/tests/utils/test_csr_matrix.h +++ b/tests/utils/test_csr_matrix.h @@ -2,6 +2,7 @@ #include #include +#include "memory_wrappers.h" #include "minunit.h" #include "test_helpers.h" #include "utils/CSR_Matrix.h" @@ -432,7 +433,7 @@ const char *test_AT_alloc_and_fill() free_csr_matrix(A); free_csr_matrix(AT); - free(iwork); + FREE_AND_NULL(iwork); return 0; }