
Feature/new mv interface #93

Open

miaoqingyu2 wants to merge 4 commits into bartz-org:main from miaoqingyu2:feature/new-mv-interface

Conversation

@miaoqingyu2 (Contributor) commented Mar 21, 2026

Sorry for the long gap since my last commit! I went through all the updates on main since then and decided it would be cleaner to start over on a fresh branch rather than trying to rebase. The code is updated and the pre-commit hooks pass. I haven't been able to run the full test suite locally due to a container thread limit, so I'm marking this as a draft for now.

@miaoqingyu2 miaoqingyu2 marked this pull request as ready for review March 22, 2026 16:39
@Gattocrucco (Collaborator) left a comment


I reviewed the changes. Most of my comments are about merging the uv and mv code paths more thoroughly, which greatly helps keep the code maintainable going forward.

An obvious missing piece is an interface to extract the error covariance matrix in the multivariate case, but that's quite self-contained and requires designing new functionality, so I think I should do it in a separate PR.
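For concreteness, a minimal sketch of what such an extraction could look like (the function name and shape convention are my assumptions; only the `error_cov_inv` trace attribute appears in this PR):

```python
import jax.numpy as jnp

def error_cov_from_precision(error_cov_inv):
    """Turn error precision draws of shape (ndpost, k, k) into covariance draws.

    Hypothetical helper: inverts each (k, k) draw; jnp.linalg.inv
    batches over the leading ndpost axis automatically.
    """
    return jnp.linalg.inv(error_cov_inv)

# usage sketch: 5 draws of a 2x2 precision equal to 2*I -> covariance 0.5*I
prec = jnp.broadcast_to(2.0 * jnp.eye(2), (5, 2, 2))
cov = error_cov_from_precision(prec)
```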

Comment on lines +330 to +337
 self._validate_compatibility(y_train, w, type)
 if w is not None:
     w = self._process_response_input(w)
     self._check_same_length(x_train, w)

 # check data types are correct for continuous/binary regression
-self._check_type_settings(y_train, type, w)
+if y_train.ndim == 1:
+    self._check_type_settings(y_train, type, w)

Merge _validate_compatibility() into _check_type_settings().
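A rough sketch of what the merged check might look like, written as a standalone function for illustration (the actual conditions depend on code not shown in this diff, so the constraints below are assumptions):

```python
import jax.numpy as jnp

def check_type_settings(y_train, type_, w):
    """Hypothetical merged validation: absorbs the compatibility check
    (formerly _validate_compatibility) into the type-settings check."""
    if type_ not in ('continuous', 'binary'):
        raise ValueError(f'unknown type: {type_!r}')
    if type_ == 'binary':
        # assumed constraints: binary regression is uv-only and unweighted
        if y_train.ndim != 1:
            raise ValueError('binary regression requires univariate y_train')
        if w is not None:
            raise ValueError('weights are not supported for binary regression')

check_type_settings(jnp.zeros((3, 10)), 'continuous', None)  # passes silently
```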


# configure priors (UV vs MV)
error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = (
self._configure_priors(

Rename _configure_priors to _configure_variances.

 # set offset from the state because of buffer donation
 self.offset = result.final_state.offset
-self.sigest = sigest
+self.sigest = sigest if y_train.ndim == 1 else None

Define sigest in the multivariate case.
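One option, as a sketch: keep sigest's uv meaning of a residual scale estimate and generalize it to a per-output vector (this is my guess at a reasonable definition, not a bartz API):

```python
import jax.numpy as jnp

def sigest_mv(y_train, yhat):
    """Per-output residual sd, a possible mv analogue of the scalar sigest.

    y_train, yhat: arrays of shape (k, n). Returns a shape-(k,) vector.
    """
    resid = y_train - yhat
    return jnp.std(resid, axis=1)

y = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
f = jnp.array([[1.0, 2.0, 3.0], [1.0, -1.0, 0.0]])
s = sigest_mv(y, f)  # first output fits exactly, second has spread
```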

Comment on lines +856 to +857
if y_train.ndim == 2:
    return y_train.mean(axis=1)

Handle the case n < 1 in the multivariate case.
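For illustration, the guard could look something like this (the zero fallback is an assumed convention, not necessarily the right choice for bartz):

```python
import jax.numpy as jnp

def mv_offset(y_train):
    """Per-output offset with a guard for the empty case (n < 1),
    where mean(axis=1) would produce NaN. Sketch only."""
    k, n = y_train.shape
    if n < 1:
        return jnp.zeros(k, dtype=y_train.dtype)
    return y_train.mean(axis=1)

off = mv_offset(jnp.empty((3, 0)))  # zeros instead of NaN
```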

 def _process_offset_settings(
-    y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
     offset: float | Float32[Any, ''] | None,

Add the type hint to offset for the multivariate case.

raise TypeError(msg)

@classmethod
def _configure_priors(

Merge the uv and mv code paths better here. _process_error_variance_settings_uv/mv share much of the same logic and can probably be merged into a single function (see also my other comment on merging the part that does the linear regression). The final expressions that compute error_cov_df, error_cov_scale and leaf_prior_cov_inv are also very similar across the two cases; it should be possible to mostly merge them.
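As a sketch of the merge: lifting the uv case to k = 1 lets all three expressions be written once. The formulas below are placeholders to show the shape handling, not bartz's actual priors:

```python
import jax.numpy as jnp

def configure_variances(y_train, num_trees):
    """Single code path for uv and mv by viewing uv data as k = 1.

    Placeholder formulas: the point is the shared shape handling,
    not the specific prior choices.
    """
    y2d = y_train[None, :] if y_train.ndim == 1 else y_train
    k, n = y2d.shape
    resid = y2d - y2d.mean(axis=1, keepdims=True)
    sample_cov = (resid @ resid.T) / max(n - 1, 1)
    error_cov_df = k + 1                          # placeholder df
    error_cov_scale = error_cov_df * sample_cov   # placeholder scale
    leaf_prior_cov_inv = num_trees * jnp.linalg.inv(
        sample_cov + 1e-6 * jnp.eye(k)
    )
    return error_cov_df, error_cov_scale, leaf_prior_cov_inv

# the uv case just flows through with k = 1
df, scale, lpci = configure_variances(jnp.array([0.0, 1.0, 2.0, 3.0]), 10)
```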

Comment on lines +499 to +510
n, p, k = mv_data_shape
sigma_noise = 0.1

key_x, key_eps = random.split(keys.pop(), 2)
X = random.uniform(key_x, (p, n), float, -2, 2)

s = jnp.ones((k, p))
norm_s = jnp.sqrt(jnp.sum(s * s, axis=1, keepdims=True))
F = (s @ jnp.cos(jnp.pi * X)) / norm_s

y = F + sigma_noise * random.normal(key_eps, (k, n))
return X, y

Use bartz.testing.gen_data.

Comment on lines +578 to +586
def test_mv_yhat_train_shape(self, mv_data: tuple) -> None:
    """yhat_train should have shape (ndpost, k, n)."""
    X, Y = mv_data
    k, n = Y.shape
    model = Bart(
        x_train=X, y_train=Y, num_trees=5, ndpost=10, nskip=5, num_chains=None
    )
    assert model.yhat_train.shape == (model.ndpost, k, n)
    assert model.yhat_train_mean.shape == (k, n)

Merge this into test_initialization_and_shapes; you can run Bart once and check multiple shapes.

Comment on lines +612 to +633
# Check yhat convergence
yhat_train = model.yhat_train.reshape(
    num_chains, nsamples_per_chain, k_dim, n_train
)
yhat_train_mean = yhat_train.mean(axis=-1)
max_rhats_yhat = [rhat(yhat_train_mean[:, :, j]) for j in range(k_dim)]
global_max_rhat = jnp.max(jnp.stack(max_rhats_yhat))
assert global_max_rhat < 1.1

# Check covariance matrix convergence
prec_trace = model._main_trace.error_cov_inv
if prec_trace.ndim == 3:
    prec_trace = prec_trace.reshape(
        num_chains, nsamples_per_chain, k_dim, k_dim
    )

prec_flat = prec_trace.reshape(num_chains, nsamples_per_chain, -1)
assert jnp.all(jnp.std(prec_flat, axis=1) > 1e-8), 'Sigma is not updating!'

max_rhats_prec = [rhat(prec_flat[:, :, j]) for j in range(k_dim * k_dim)]
max_rhat_sigma = jnp.max(jnp.array(max_rhats_prec))
assert max_rhat_sigma < 1.1

Why aren't you using multivariate rhats? Were they too broken? (That would make sense, as mv rhat is probably more sensitive to non-normality.) Please leave a comment saying why you use max(uv rhat) instead of mv rhat.
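For reference, the max-of-univariate-rhats statistic amounts to something like this (a minimal, split-free Gelman-Rubin sketch, not bartz's actual rhat implementation):

```python
import jax.numpy as jnp

def rhat_uv(x):
    """Basic Gelman-Rubin R-hat for draws of shape (chains, samples)."""
    m, n = x.shape
    chain_means = x.mean(axis=1)
    between = n * jnp.var(chain_means, ddof=1)     # between-chain variance
    within = jnp.mean(jnp.var(x, axis=1, ddof=1))  # mean within-chain variance
    var_plus = (n - 1) / n * within + between / n
    return jnp.sqrt(var_plus / within)

def max_uv_rhat(draws):
    """Max over dims of per-dimension uv rhats, for draws of shape
    (chains, samples, dims). A cheaper, less strict proxy for mv rhat."""
    return jnp.max(
        jnp.stack([rhat_uv(draws[:, :, j]) for j in range(draws.shape[-1])])
    )

# two identical well-mixed chains give an rhat at or below 1
r = rhat_uv(jnp.array([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]))
```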

yhat_train = model.yhat_train.reshape(
    num_chains, nsamples_per_chain, k_dim, n_train
)
yhat_train_mean = yhat_train.mean(axis=-1)

Why are you checking the mean of the ys instead of all of them?
