Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2913dee
Add parameter support to C diff engine
Transurgeon Feb 9, 2026
fea0708
Remove redundant n_params from problem struct
Transurgeon Feb 10, 2026
36ec239
Remove redundant A_m, A_n params from new_left_param_matmul
Transurgeon Feb 10, 2026
4f93f7c
Add total_parameter_size field to problem struct
Transurgeon Feb 10, 2026
cf695a2
Inline get_scalar helper in const_scalar_mult.c
Transurgeon Feb 10, 2026
bd4f3b0
Inline get_vector helper in const_vector_mult.c
Transurgeon Feb 10, 2026
27acb7d
Simplify refresh_param_values: fill one block, memcpy the rest
Transurgeon Feb 10, 2026
00f7732
Skip AT recomputation in param refresh; param matmul is always affine
Transurgeon Feb 10, 2026
dc56c90
Run clang-format on parameter support files
Transurgeon Feb 10, 2026
978f319
Add problem-level tests for parameter support
Transurgeon Feb 10, 2026
01c6f82
Clean up comments in bivariate.h and subexpr.h
Transurgeon Feb 10, 2026
b8cf436
Fix memory leak in new_param_scalar_mult
Transurgeon Feb 10, 2026
9c23b40
Merge origin/main into parameter-support
Transurgeon Feb 13, 2026
bf8a55c
Run clang-format on merge-resolved files
Transurgeon Feb 13, 2026
9700344
Unify Constant and Parameter into single parameter type
Transurgeon Feb 15, 2026
939a910
Remove redundant NULL-after-free, rename const_ files, fix stale comm…
Transurgeon Feb 15, 2026
71dddf1
Run clang-format on cleanup changes
Transurgeon Feb 15, 2026
36864d4
Simplify new_left_matmul: accept CSR directly, remove sparse/dense br…
Transurgeon Feb 16, 2026
978aa03
Store param values in CSR data order, simplify refresh to memcpy
Transurgeon Feb 16, 2026
b3e2304
small edits
dance858 Feb 17, 2026
0edf928
fix AT workspace in left_matmul and add has_been_refreshed
dance858 Feb 17, 2026
b985028
add parameter support for right matmul
dance858 Feb 17, 2026
4e25da7
clean up
dance858 Feb 17, 2026
4b58b1c
add back setting ptrs to null after freeingmake
dance858 Feb 17, 2026
1c3ebee
added free and null macro
dance858 Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/affine.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars);
expr *new_promote(expr *child, int d1, int d2);
expr *new_trace(expr *child);

expr *new_constant(int d1, int d2, int n_vars, const double *values);
expr *new_variable(int d1, int d2, int var_id, int n_vars);
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);

expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
expr *new_reshape(expr *child, int d1, int d2);
Expand Down
17 changes: 9 additions & 8 deletions include/bivariate.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,17 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
/* Matrix multiplication: Z = X @ Y */
expr *new_matmul(expr *x, expr *y);

/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node.
Only the forward pass possibly updates the parameter. */
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A);

/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
/* Right matrix multiplication: f(x) @ A where A comes from a parameter node. */
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);

/* Constant scalar multiplication: a * f(x) where a is a constant double */
expr *new_const_scalar_mult(double a, expr *child);
/* Scalar multiplication: a * f(x) where a comes from a parameter node */
expr *new_scalar_mult(expr *param_node, expr *child);

/* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */
expr *new_const_vector_mult(const double *a, expr *child);
/* Vector elementwise multiplication: a ∘ f(x) where a comes from a parameter node */
expr *new_vector_mult(expr *param_node, expr *child);

#endif /* BIVARIATE_H */
13 changes: 13 additions & 0 deletions include/memory_wrappers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef MEMORY_WRAPPERS_H
#define MEMORY_WRAPPERS_H

#include <stdlib.h>

#define FREE_AND_NULL(p) \
do \
{ \
free(p); \
(p) = NULL; \
} while (0)

#endif /* MEMORY_WRAPPERS_H */
9 changes: 9 additions & 0 deletions include/problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ typedef struct problem
* hessian are called */
bool jacobian_called;

/* Parameter tracking for fast parameter updates. */
expr **param_nodes; /* weak references to parameter nodes in tree */
int n_param_nodes;
int total_parameter_size;

/* Statistics for performance measurement */
Diff_engine_stats stats;
bool verbose;
Expand All @@ -78,4 +83,8 @@ void problem_gradient(problem *prob);
void problem_jacobian(problem *prob);
void problem_hessian(problem *prob, double obj_w, const double *w);

/* Parameter support */
void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes);
void problem_update_params(problem *prob, const double *theta);

#endif
34 changes: 26 additions & 8 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
/* Forward declaration */
struct int_double_pair;

/* param_id value for fixed (constant) parameters */
#define PARAM_FIXED -1

/* Parameter node: unified leaf for constants and updatable parameters.
* Constants use param_id == PARAM_FIXED and have values set at creation.
* Updatable parameters have param_id >= 0 and are updated via problem_update_params.
*/
typedef struct parameter_expr
{
expr base;
int param_id; /* offset into global theta vector, or PARAM_FIXED */
bool has_been_refreshed; /* tracks whether parameter has been refreshed */
} parameter_expr;

/* Type-specific expression structures that "inherit" from expr */

/* Linear operator: y = A * x + b */
Expand Down Expand Up @@ -113,6 +127,9 @@ typedef struct left_matmul_expr
CSC_Matrix *Jchild_CSC;
CSC_Matrix *J_CSC;
int *csc_to_csr_workspace;
int *AT_iwork; /* work for computing AT values from A */
expr *param_source; /* parameter node; A/AT values are refreshed from this */
void (*refresh_param_values)(struct left_matmul_expr *lin_node);
} left_matmul_expr;

/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.
Expand All @@ -126,19 +143,20 @@ typedef struct right_matmul_expr
CSC_Matrix *CSC_work;
} right_matmul_expr;

/* Constant scalar multiplication: y = a * child where a is a constant double */
typedef struct const_scalar_mult_expr
/* Scalar multiplication: y = a * child where a comes from a parameter node */
typedef struct scalar_mult_expr
{
expr base;
double a;
} const_scalar_mult_expr;
expr *param_source; /* always set; read a from param_source->value[0] */
} scalar_mult_expr;

/* Constant vector elementwise multiplication: y = a \circ child for constant a */
typedef struct const_vector_mult_expr
/* Vector elementwise multiplication: y = a \circ child where a comes from a
* parameter node */
typedef struct vector_mult_expr
{
expr base;
double *a; /* length equals node->size */
} const_vector_mult_expr;
expr *param_source; /* always set; read a from param_source->value */
} vector_mult_expr;

/* Index/slicing: y = child[indices] where indices is a list of flat positions */
typedef struct index_expr
Expand Down
8 changes: 5 additions & 3 deletions include/utils/Vec_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef VEC_MACROS_H
#define VEC_MACROS_H

#include "memory_wrappers.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -48,7 +49,7 @@
vec->data = (TYPE *) malloc(capacity * sizeof(TYPE)); \
if (vec->data == NULL) \
{ \
free(vec); \
FREE_AND_NULL(vec); \
return NULL; \
} \
\
Expand All @@ -59,8 +60,9 @@
\
static inline void TYPE_NAME##Vec_free(TYPE_NAME##Vec *vec) \
{ \
free(vec->data); \
free(vec); \
if (!vec) return; \
FREE_AND_NULL(vec->data); \
FREE_AND_NULL(vec); \
} \
\
static inline void TYPE_NAME##Vec_clear_no_resize(TYPE_NAME##Vec *vec) \
Expand Down
6 changes: 2 additions & 4 deletions src/affine/hstack.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "memory_wrappers.h"
#include "utils/CSR_sum.h"
#include <assert.h>
#include <stdio.h>
Expand Down Expand Up @@ -165,13 +166,10 @@ static void free_type_data(expr *node)
for (int i = 0; i < hnode->n_args; i++)
{
free_expr(hnode->args[i]);
hnode->args[i] = NULL;
}

free_csr_matrix(hnode->CSR_work);
hnode->CSR_work = NULL;
free(hnode->args);
hnode->args = NULL;
FREE_AND_NULL(hnode->args);
}

expr *new_hstack(expr **args, int n_args, int n_vars)
Expand Down
9 changes: 3 additions & 6 deletions src/affine/index.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "memory_wrappers.h"
#include "subexpr.h"
#include <assert.h>
#include <stdio.h>
Expand All @@ -38,7 +39,7 @@ static bool check_for_duplicates(const int *indices, int n_idxs, int max_idx)
}
seen[indices[i]] = true;
}
free(seen);
FREE_AND_NULL(seen);
return has_dup;
}

Expand Down Expand Up @@ -154,11 +155,7 @@ static bool is_affine(const expr *node)
static void free_type_data(expr *node)
{
index_expr *idx = (index_expr *) node;
if (idx->indices)
{
free(idx->indices);
idx->indices = NULL;
}
FREE_AND_NULL(idx->indices);
}

expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs)
Expand Down
11 changes: 3 additions & 8 deletions src/affine/linear_op.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "memory_wrappers.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
Expand Down Expand Up @@ -55,18 +56,12 @@ static void free_type_data(expr *node)
if (!node->jacobian)
{
free_csr_matrix(lin_node->A_csr);
lin_node->A_csr = NULL;
}

free_csc_matrix(lin_node->A_csc);

if (lin_node->b != NULL)
{
free(lin_node->b);
lin_node->b = NULL;
}

lin_node->A_csr = NULL;
lin_node->A_csc = NULL;
FREE_AND_NULL(lin_node->b);
}

static void jacobian_init(expr *node)
Expand Down
36 changes: 27 additions & 9 deletions src/affine/constant.c → src/affine/parameter.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,48 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/* Unified parameter/constant leaf node.
*
* When param_id == PARAM_FIXED, this is a constant whose values are set at
* creation and never change. When param_id >= 0, values are updated via
* problem_update_params.
*
* In both cases the derivative behavior is identical: zero Jacobian and
* Hessian with respect to variables (always affine). */

#include "affine.h"
#include "subexpr.h"
#include <stdlib.h>
#include <string.h>

static void forward(expr *node, const double *u)
{
/* Constants don't depend on u; values are already set */
/* Values are set at creation (constants) or by problem_update_params */
(void) node;
(void) u;
}

static void jacobian_init(expr *node)
{
/* Constant jacobian is all zeros: size x n_vars with 0 nonzeros.
* new_csr_matrix uses calloc for row pointers, so they're already 0. */
/* Parameter/constant jacobian is all zeros: size x n_vars with 0 nonzeros */
node->jacobian = new_csr_matrix(node->size, node->n_vars, 0);
}

static void eval_jacobian(expr *node)
{
/* Constant jacobian never changes - nothing to evaluate */
/* Jacobian never changes */
(void) node;
}

static void wsum_hess_init(expr *node)
{
/* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */
/* Hessian is all zeros */
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
}

static void eval_wsum_hess(expr *node, const double *w)
{
/* Constant Hessian is always zero - nothing to compute */
(void) node;
(void) w;
}
Expand All @@ -58,12 +67,21 @@ static bool is_affine(const expr *node)
return true;
}

expr *new_constant(int d1, int d2, int n_vars, const double *values)
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
{
expr *node = (expr *) calloc(1, sizeof(expr));
parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr));
expr *node = &pnode->base;
init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine,
wsum_hess_init, eval_wsum_hess, NULL);
memcpy(node->value, values, node->size * sizeof(double));
pnode->param_id = param_id;
pnode->has_been_refreshed = false;

/* If values provided (fixed constant), copy them now.
Otherwise values will be populated by problem_update_params. */
if (values != NULL)
{
memcpy(node->value, values, node->size * sizeof(double));
}

return node;
}
3 changes: 2 additions & 1 deletion src/affine/sum.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "memory_wrappers.h"
#include "utils/CSR_sum.h"
#include "utils/int_double_pair.h"
#include "utils/mini_numpy.h"
Expand Down Expand Up @@ -175,7 +176,7 @@ static bool is_affine(const expr *node)
static void free_type_data(expr *node)
{
sum_expr *snode = (sum_expr *) node;
free(snode->idx_map);
FREE_AND_NULL(snode->idx_map);
}

expr *new_sum(expr *child, int axis)
Expand Down
3 changes: 2 additions & 1 deletion src/affine/trace.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "memory_wrappers.h"
#include "utils/CSR_sum.h"
#include "utils/int_double_pair.h"
#include "utils/utils.h"
Expand Down Expand Up @@ -139,7 +140,7 @@ static void free_type_data(expr *node)
if (node)
{
trace_expr *tnode = (trace_expr *) node;
free(tnode->idx_map);
FREE_AND_NULL(tnode->idx_map);
}
}

Expand Down
Loading