Conversation
Gattocrucco left a comment
I reviewed the changes. Most of my comments are about merging the uv and mv code paths more thoroughly, which greatly helps keep the code maintainable going forward.
An obvious missing piece is an interface to extract the error covariance matrix in the multivariate case, but that is quite self-contained and requires designing new interfaces, so I think I should do it in a separate PR.
```diff
  self._validate_compatibility(y_train, w, type)
  if w is not None:
      w = self._process_response_input(w)
      self._check_same_length(x_train, w)

  # check data types are correct for continuous/binary regression
- self._check_type_settings(y_train, type, w)
+ if y_train.ndim == 1:
+     self._check_type_settings(y_train, type, w)
```
Merge _validate_compatibility() into _check_type_settings().
```python
# configure priors (UV vs MV)
error_cov_df, error_cov_scale, leaf_prior_cov_inv, sigest = (
    self._configure_priors(
```
Rename _configure_priors to _configure_variances.
```diff
  # set offset from the state because of buffer donation
  self.offset = result.final_state.offset
- self.sigest = sigest
+ self.sigest = sigest if y_train.ndim == 1 else None
```
Define sigest in the multivariate case.
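One way to do it, just as a sketch: return a length-k vector of per-output scale estimates instead of None. The estimator below is a placeholder (marginal sd per output), not a proposal for the actual prior calibration:

```python
import numpy as np


def mv_sigest(y_train):
    """Per-output error sd guess for y_train with shape (k, n).

    Placeholder estimate (marginal sd of each output); the real sigest
    would come from the residuals of the linear-regression pre-fit.
    """
    return y_train.std(axis=1, ddof=1)


y = np.array([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])
print(mv_sigest(y))  # [1. 0.]
```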
```python
if y_train.ndim == 2:
    return y_train.mean(axis=1)
```
Handle the case n < 1 (zero observations) in the multivariate case.
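To illustrate the failure mode: with zero observations, y_train.mean(axis=1) silently produces NaNs. A guard along these lines (function name hypothetical, numpy standing in for jnp) avoids that:

```python
import numpy as np


def default_offset(y_train):
    """Per-output default offset, guarding the empty-data case (sketch)."""
    if y_train.ndim == 2:
        if y_train.shape[1] < 1:
            # mean over an empty axis would be NaN (with a warning);
            # fall back to a zero offset per output
            return np.zeros(y_train.shape[0], y_train.dtype)
        return y_train.mean(axis=1)
    return y_train.mean()


print(default_offset(np.empty((3, 0))))  # [0. 0. 0.]
```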
```diff
  def _process_offset_settings(
-     y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+     y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
      offset: float | Float32[Any, ''] | None,
```
Add the type hint to offset for the multivariate case.
```python
raise TypeError(msg)


@classmethod
def _configure_priors(
```
Merge the uv and mv code paths better here. _process_error_variance_settings_uv/mv share much of the same logic and can probably be merged into a single function (see also my other comment on merging the part that does the linear regression). The final expressions to compute error_cov_df, error_cov_scale and leaf_prior_cov_inv are also very similar across the two cases; it should be possible to mostly merge them.
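As a sketch of the kind of merge I mean (all names and formulas here are illustrative stand-ins, not the real bartz expressions): promote the univariate y to shape (1, n) up front, then write the df/scale/leaf-precision computation once for general k:

```python
import numpy as np


def configure_variances(y_train):
    """Single code path for uv and mv variance configuration (sketch).

    Promotes a univariate y of shape (n,) to (1, n) so the df, scale and
    leaf-precision expressions are written once for general k. The values
    below are toy stand-ins for the real prior-calibration formulas.
    """
    y2d = np.atleast_2d(np.asarray(y_train))  # (k, n); k == 1 in the uv case
    k, n = y2d.shape
    resid = y2d - y2d.mean(axis=1, keepdims=True)
    # stand-in for the shared error-scale estimate (the real code would use
    # the linear-regression residuals here)
    error_cov_scale = resid @ resid.T / max(n - 1, 1)  # (k, k)
    error_cov_df = k + 2                               # toy choice of df
    leaf_prior_cov_inv = np.eye(k)                     # toy leaf precision
    return error_cov_df, error_cov_scale, leaf_prior_cov_inv
```

With this shape, the uv case is just k == 1 and no branch is needed downstream.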
```python
n, p, k = mv_data_shape
sigma_noise = 0.1

key_x, key_eps = random.split(keys.pop(), 2)
X = random.uniform(key_x, (p, n), float, -2, 2)

s = jnp.ones((k, p))
norm_s = jnp.sqrt(jnp.sum(s * s, axis=1, keepdims=True))
F = (s @ jnp.cos(jnp.pi * X)) / norm_s

y = F + sigma_noise * random.normal(key_eps, (k, n))
return X, y
```
Use bartz.testing.gen_data.
```python
def test_mv_yhat_train_shape(self, mv_data: tuple) -> None:
    """yhat_train should have shape (ndpost, k, n)."""
    X, Y = mv_data
    k, n = Y.shape
    model = Bart(
        x_train=X, y_train=Y, num_trees=5, ndpost=10, nskip=5, num_chains=None
    )
    assert model.yhat_train.shape == (model.ndpost, k, n)
    assert model.yhat_train_mean.shape == (k, n)
```
Merge this into test_initialization_and_shapes; you can run Bart once and check multiple shapes.
```python
# Check yhat convergence
yhat_train = model.yhat_train.reshape(
    num_chains, nsamples_per_chain, k_dim, n_train
)
yhat_train_mean = yhat_train.mean(axis=-1)
max_rhats_yhat = [rhat(yhat_train_mean[:, :, j]) for j in range(k_dim)]
global_max_rhat = jnp.max(jnp.stack(max_rhats_yhat))
assert global_max_rhat < 1.1

# Check covariance matrix convergence
prec_trace = model._main_trace.error_cov_inv
if prec_trace.ndim == 3:
    prec_trace = prec_trace.reshape(
        num_chains, nsamples_per_chain, k_dim, k_dim
    )

prec_flat = prec_trace.reshape(num_chains, nsamples_per_chain, -1)
assert jnp.all(jnp.std(prec_flat, axis=1) > 1e-8), 'Sigma is not updating!'

max_rhats_prec = [rhat(prec_flat[:, :, j]) for j in range(k_dim * k_dim)]
max_rhat_sigma = jnp.max(jnp.array(max_rhats_prec))
assert max_rhat_sigma < 1.1
```
Why aren't you using multivariate rhats? Because they were too broken? (That would make sense, as mv rhat is probably more sensitive to non-normality.) Please leave a comment explaining why you use max(uv rhat) instead of mv rhat.
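For reference, a toy version of the max-of-componentwise-rhats check, using the plain (non-split) Gelman-Rubin formula in numpy rather than bartz's actual rhat:

```python
import numpy as np


def rhat_uv(x):
    """Gelman-Rubin R-hat for one scalar quantity; x has shape (chains, draws)."""
    n = x.shape[1]
    between = n * x.mean(axis=1).var(ddof=1)  # B: variance of chain means
    within = x.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    var_plus = (n - 1) / n * within + between / n
    return float(np.sqrt(var_plus / within))


# componentwise rhats, reduced with max, as in the test under review
rng = np.random.default_rng(0)
draws = rng.normal(size=(4, 500, 3))  # (chains, draws, components)
max_rhat = max(rhat_uv(draws[:, :, j]) for j in range(3))
```

Taking the max of univariate rhats is conservative but ignores cross-component correlation, which is exactly what a multivariate rhat would capture, hence the request for a comment.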
```python
yhat_train = model.yhat_train.reshape(
    num_chains, nsamples_per_chain, k_dim, n_train
)
yhat_train_mean = yhat_train.mean(axis=-1)
```
Why are you checking the mean of the ys instead of all of them?
Sorry for the long gap since my last commit! I went through all the updates on main since then and decided it would be cleaner to start over on a fresh branch rather than trying to rebase.
For now the code is updated and the pre-commit hooks pass. I haven't been able to run the full test suite locally due to a container thread limit, so I'm marking this as a draft for now.