feat(bioemu): Add FKC steering denoiser & refactor steering code #203
Conversation
Split steering.py into a steering/ package with modular components:

- steering/potentials.py: Potential base class, UmbrellaPotential, LinearPotential
- steering/collective_variables.py: CV framework (RMSD, FNC, CaCaDistance, PairwiseClash)
- steering/utils.py: Resampling, reward computation, sequence alignment helpers
- steering/dpm_fkc.py: FKC (Feynman-Kac Control) steered denoiser
- steering/dpm_smc.py: SMC (Sequential Monte Carlo) steered denoiser

Key changes:

- Unified DPM-Solver primitives in denoiser.py (shared by FKC, SMC, unsteered)
- Steering configs (cv_steer.yaml, physical_steering.yaml) are self-contained Hydra denoiser configs with target, potentials, and steering params
- Simplified sample.py: removed steering_config param, denoiser handles everything
- Added start/end time window for steering resampling
- Simplified loop returns to (batch, log_weights) 2-tuple

Tests:

- 60+ steering tests: unit tests for CVs, potentials, utils; integration tests for FKC/SMC loops, ODE consistency, generate_batch pipeline
- Chignolin e2e tests (require model weights)

Closes: https://github.com/msr-ai4science/feynman/issues/20268

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
```python
        return self.compute_batch(all_positions * 10.0, sequence)


class FractionNativeContacts(CollectiveVariable):
```
This one has duplication with the train/foldedness.py, we might want to decide which one to keep
Do we do something explicitly with bioemu/training/foldedness.py:foldedness? Otherwise can simply use that one and wrap it here with a CollectiveVariable.
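A minimal sketch of the wrapping suggested above: rather than duplicating the foldedness logic, an adapter can expose any existing scalar-valued function of Cα positions through the CV call interface. The `CallableCV` name, the `ca_pos_nm` argument, and the `toy_foldedness` stand-in are all illustrative here, not the PR's actual code.

```python
import torch

class CallableCV:
    """Wrap an arbitrary scalar-valued function of Ca positions as a CV,
    so an existing implementation (e.g. training/foldedness.py:foldedness)
    can be reused instead of duplicated."""

    def __init__(self, fn, **fixed_kwargs):
        self.fn = fn
        self.fixed_kwargs = fixed_kwargs

    def __call__(self, ca_pos_nm: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.fn(ca_pos_nm, **self.fixed_kwargs)

# Toy stand-in for foldedness: fraction of consecutive Ca pairs within 0.4 nm
def toy_foldedness(ca_pos_nm: torch.Tensor) -> torch.Tensor:
    d = (ca_pos_nm[:, 1:] - ca_pos_nm[:, :-1]).norm(dim=-1)  # (B, N-1)
    return (d < 0.4).float().mean(dim=-1)                     # (B,)

cv = CallableCV(toy_foldedness)
x = torch.zeros(2, 5, 3)
x[:, :, 0] = torch.arange(5) * 0.38  # consecutive Ca spacing ~0.38 nm
print(cv(x))  # tensor([1., 1.])
```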
- Change physical_steering.yaml target from dpm_solver_fkc to dpm_solver_smc
- Fix SMC loop bug: log_weights was overwritten with None outside steering window
- Rewrite README steering section: document both SMC/FKC algorithms, update CLI examples to use denoiser_config (removed old steering_config param), fix parameter descriptions to match current interface

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ioemu into yuxie1/fkc-steering

# Conflicts:
#	src/bioemu/config/steering/physical_steering.yaml
#	src/bioemu/steering/dpm_smc.py
Summary

Coverage:

| Module | Coverage |
|---|---|
| src.bioemu | 89.1% |
| src.bioemu.colabfold_setup | |
| src.bioemu.hpacker_setup | 58.8% |
| src.bioemu.openfold.np | 44% |
| src.bioemu.openfold.utils | 50.1% |
| src.bioemu.steering | 79.3% |
| src.bioemu.training | 100% |
Pull request overview
Refactors BioEmu’s steering functionality into a dedicated bioemu.steering package and adds modular, Hydra-configured steered denoisers (FKC + SMC) built on shared DPM-Solver primitives.
Changes:
- Introduces `bioemu.steering` subpackage (CVs, potentials, utilities) and new steered denoisers (`dpm_fkc`, `dpm_smc`).
- Extracts shared DPM-Solver helper primitives in `denoiser.py` and simplifies the sampling API to steer via a single `denoiser_config`.
- Replaces the legacy steering test with a broader steering test suite (unit + lightweight integration + optional e2e).
Reviewed changes
Copilot reviewed 19 out of 21 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| tests/test_steering.py | Removes legacy steering e2e test module. |
| tests/steering/\_\_init\_\_.py | Adds steering test package marker. |
| tests/steering/test_utils.py | Adds unit tests for resampling + helper utilities. |
| tests/steering/test_potentials.py | Adds unit tests for UmbrellaPotential / LinearPotential behavior. |
| tests/steering/test_integration.py | Adds lightweight integration tests for configs, solvers, resampling, and pipeline wiring. |
| tests/steering/test_denoisers.py | Adds tests for ESS computation and SO(3) gradient mapping helper. |
| tests/steering/test_collective_variables.py | Adds unit tests for CV implementations. |
| tests/steering/test_chignolin_e2e.py | Adds e2e tests that invoke sample() (requires model weights). |
| src/bioemu/steering/utils.py | New steering utilities: config validation, x0/R0 helpers, resampling, reward/grad computation. |
| src/bioemu/steering/potentials.py | New potential framework + UmbrellaPotential / LinearPotential. |
| src/bioemu/steering/dpm_smc.py | Adds SMC steered denoiser built on DPM-Solver++ utilities. |
| src/bioemu/steering/dpm_fkc.py | Adds FKC steered denoiser + analytical weight updates + ESS resampling. |
| src/bioemu/steering/collective_variables.py | Adds CV framework + implementations (FNC, RMSD, CaCaDistance, PairwiseClash). |
| src/bioemu/steering/\_\_init\_\_.py | Exposes steering public API (CVs, potentials, utilities). |
| src/bioemu/steering.py | Removes legacy monolithic steering module. |
| src/bioemu/sample.py | Simplifies sampling: steering now handled by the instantiated denoiser config. |
| src/bioemu/denoiser.py | Refactors DPM-Solver primitives into reusable helper dataclasses/functions. |
| src/bioemu/config/steering/physical_steering.yaml | Converts physical steering into a self-contained Hydra denoiser config. |
| src/bioemu/config/steering/cv_steer.yaml | Adds example self-contained FKC steering config. |
| README.md | Updates steering documentation and usage to the new single-config model. |
| .gitignore | Ignores tests/cross_repo/. |
> Two steering algorithms are available:
>
> - **SMC (Sequential Monte Carlo)**: Simulates multiple *candidate samples* (particles) per desired output sample and resamples between them according to the favorability of the provided potentials. This is the default for physical steering.
> - **FKC (Feynman–Kac Control)**: Applies importance weighting without resampling; useful when targeting a specific collective variable value (e.g., RMSD to a reference).
The README states FKC "applies importance weighting without resampling", but the current dpm_solver_fkc implementation does perform ESS-based resampling via resample_based_on_log_weights when num_particles > 1 and within the steering time window. Either update the README description to match the actual behavior, or adjust the implementation/config to make resampling optional/absent for FKC as documented.
Suggested change:

```diff
-- **FKC (Feynman–Kac Control)**: Applies importance weighting without resampling; useful when targeting a specific collective variable value (e.g., RMSD to a reference).
+- **FKC (Feynman–Kac Control)**: Uses importance weighting and may perform ESS-based resampling between particles; useful when targeting a specific collective variable value (e.g., RMSD to a reference).
```
```python
    # 1. Compute cumulative sums (CDF) for each batch
    cdf = torch.cumsum(weights, dim=-1)  # (B, N)

    # 2. Stratified positions: one per interval
    # shape (B, N): each row gets N stratified uniforms
    u = (torch.rand(B, N, device=weights.device) + torch.arange(N, device=weights.device)) / N

    # 3. Inverse-CDF search: for each u, find smallest j s.t. cdf[b, j] >= u[b, i]
    idx = torch.searchsorted(cdf, u, right=True)

    return idx  # shape (B, N)
```
stratified_resample() can return an out-of-range index when u exceeds the final CDF entry due to floating-point normalization error (i.e., searchsorted may return N). This will later crash when indexing data_list[i] during resampling. Consider explicitly forcing cdf[..., -1] = 1 (or cdf = cdf / cdf[..., -1:]) and/or clamping idx to [0, N-1] before returning.
Add this edge case: `cdf[..., -1] = 1`
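A standalone sketch of the fix under discussion, assuming `weights` rows are normalized to ~1; both the CDF guard and the index clamp are shown (belt-and-braces, either alone already prevents the out-of-range crash):

```python
import torch

def stratified_resample(weights: torch.Tensor) -> torch.Tensor:
    """Stratified resampling, robust to floating-point CDF error.

    weights: (B, N) nonnegative, rows summing to ~1.
    Returns indices of shape (B, N).
    """
    B, N = weights.shape
    cdf = torch.cumsum(weights, dim=-1)           # (B, N)
    cdf[..., -1] = 1.0                            # guard against sum < 1 from rounding
    u = (torch.rand(B, N) + torch.arange(N)) / N  # stratified uniforms in [0, 1)
    idx = torch.searchsorted(cdf, u, right=True)
    return idx.clamp_(max=N - 1)                  # bound indices to valid range

w = torch.full((2, 4), 0.25)
idx = stratified_resample(w)
print(idx.shape, int(idx.max()) <= 3)
```

With `u < 1` and the last CDF entry pinned to exactly 1, `searchsorted` can no longer return `N`, so the later `data_list[i]` indexing stays in range.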
```python
def resample_based_on_log_weights(
    batch: ChemGraph,
    log_weight: torch.Tensor,
    n_particles: int,
    is_last_step: bool,
    ess_threshold: float,
    step: int,
    t: float,
) -> tuple[ChemGraph, torch.Tensor, float, torch.Tensor]:
    """Resample particles based on importance weights.
```
The return type annotation for resample_based_on_log_weights doesn't match what the function actually returns. It currently annotates -> tuple[ChemGraph, torch.Tensor, float, torch.Tensor], but the implementation returns (batch, log_weight, indices, ess) where both indices and ess are tensors. Updating the annotation (and, if desired, the docstring wording) will prevent type-checking confusion and downstream misuse.
Okay, fix it. Seems the better type annotation to me.
```python
        potential_(None, 10 * coords, None, None,
                   t=step_index, N=num_steps, sequence=batch.sequence[0]).
```
The compute_reward_and_grad docstring describes calling potentials with the old signature (potential_(None, 10 * coords, None, None, t=step_index, N=num_steps, ...)), but the implementation now calls potential(10.0 * coords, t=..., sequence=...). Please update the docstring to reflect the current API so new potentials/CVs are implemented against the right contract.
Suggested change:

```diff
-        potential_(None, 10 * coords, None, None,
-                   t=step_index, N=num_steps, sequence=batch.sequence[0]).
+        potential(10.0 * coords, t=step_index, sequence=batch.sequence[0]).
```
```
    denoiser_type: Denoiser to use for sampling, if `denoiser_config` not specified.
        Comes with the default parameter configuration. Must be one of ['dpm', 'heun'].
    denoiser_config: Path to a denoiser config YAML, or a dict. For steered sampling (FKC/SMC),
        pass a steering config (e.g., config/steering/physical_steering.yaml) which includes
        the denoiser target, potentials, and steering parameters in one file.
```
The docstring states denoiser_config can be a Path, but the current parsing logic only treats an exact str as a path input. If a caller passes a Path, it will hit the final type-assert and fail. Either update the implementation to accept Path/os.PathLike, or adjust the docstring/type to match the actual supported inputs.
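A sketch of the permissive parsing this comment asks for, accepting `str`, `os.PathLike`, or an already-parsed dict. The function name and the use of `yaml.safe_load` are illustrative assumptions, not the actual bioemu implementation:

```python
import os
from pathlib import Path

def load_denoiser_config(denoiser_config) -> dict:
    """Accept a YAML path (str or os.PathLike) or an already-parsed dict."""
    if isinstance(denoiser_config, dict):
        return denoiser_config
    if isinstance(denoiser_config, (str, os.PathLike)):
        import yaml  # third-party; assumed available alongside Hydra
        with open(Path(denoiser_config)) as fh:
            return yaml.safe_load(fh)
    raise TypeError(
        f"denoiser_config must be a dict, str or os.PathLike, got {type(denoiser_config)!r}"
    )
```

The `isinstance(..., (str, os.PathLike))` check covers `pathlib.Path` as well, which is what the docstring currently promises.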
```yaml
  clip_max: 0.7
  cv:
    _target_: bioemu.steering.RMSD
    reference_pdb: null
```
cv_steer.yaml is described as a self-contained denoiser config, but it sets RMSD.reference_pdb: null. The RMSD CV constructor asserts reference_pdb is not None, so Hydra instantiation will fail if someone tries to use this config as-is. Consider using a concrete default path, or mark it as a required Hydra value (e.g., ???) and add a comment explaining it must be provided.
Suggested change:

```diff
-    reference_pdb: null
+    # Path to the reference PDB file; must be provided by the user when using this config.
+    reference_pdb: ???
```
Summary

Coverage:

| Module | Coverage |
|---|---|
| src.bioemu | 89.1% |
| src.bioemu.colabfold_setup | |
| src.bioemu.hpacker_setup | 58.8% |
| src.bioemu.openfold.np | 44% |
| src.bioemu.openfold.utils | 50.1% |
| src.bioemu.steering | 79.4% |
| src.bioemu.training | 100% |
Good afternoon! I see that this pull request has been open for quite a while. Is there anything I can help with to move it forward and get it closed faster? Especially with preparing the notebooks, tests, and related tasks. |
Hi @vkuzniak, we appreciate the keen eye and your interest in the feature. |
Hi @ludwigwinkler, I am glad to hear that and will be waiting. Interesting to see this list. |
ludwigwinkler left a comment:
Thanks for the refactoring and the work you put in!
```python
    H = torch.einsum("bni,nj->bij", samples_centered, ref_centered)

    # SVD decomposition
    U, S, Vh = torch.linalg.svd(H)

    # Optimal rotation (handle reflection)
    d = torch.det(torch.bmm(Vh.transpose(-2, -1), U.transpose(-2, -1)))
    sign_matrix = torch.ones(batch_size, 3, device=device)
    sign_matrix[:, -1] = d.sign()
    R = torch.bmm(Vh.transpose(-2, -1) * sign_matrix.unsqueeze(-1), U.transpose(-2, -1))

    # Detach R so gradients don't flow through SVD (numerically unstable)
    R = R.detach()

    # Apply rotation
    samples_rotated = torch.einsum("bij,bnj->bni", R, samples_centered)
```
nit: could be its own function in utils
```python
    steer_start = steering_config.get("start", 1.0)
    steer_end = steering_config.get("end", 0.0)
```
Do we want to let it fail explicitly? I like explicit failures more than implicit default values.
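A sketch of the explicit-failure variant the reviewer prefers: plain dict indexing instead of `.get` defaults, with a readable error. The helper name is illustrative; the `start`/`end` keys are taken from the snippet above.

```python
def read_steering_window(steering_config: dict) -> tuple[float, float]:
    """Fail loudly if the steering window is not configured."""
    try:
        return float(steering_config["start"]), float(steering_config["end"])
    except KeyError as err:
        raise KeyError(
            f"steering config is missing required key {err.args[0]!r}; "
            "expected explicit 'start' and 'end' diffusion times"
        ) from err

print(read_steering_window({"start": 1.0, "end": 0.0}))  # (1.0, 0.0)
```

The trade-off: silent defaults keep old configs working, while explicit lookup surfaces typos (e.g. `strt:`) immediately instead of silently steering over the whole trajectory.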
```python
    max_t,
    potentials,
    step_idx,
    use_x0_for_reward: bool = False,
```
we support use_x0_for_reward=False even though SMC itself is not defined for intermediate values? All our potentials should be defined at t=0, I believe. Is this for solver_step argument consistency (same as for fkc solver)?
```python
        self.guidance_steering = guidance_steering
        self.cv = cv

    @staticmethod
```
```python
    BS_offset = torch.arange(n_groups).unsqueeze(-1) * effective_n_particles  # [n_groups, 1]
    indices = (indices + BS_offset.to(indices.device)).flatten()  # [n_groups, n_particles]

    # Resample samples
    data_list = batch.to_data_list()
    resampled_data_list = [data_list[i] for i in indices]
    batch = Batch.from_data_list(
        resampled_data_list
    )  # TODO: there should be a more efficient way
```
We can apply a chunking algorithm that gracefully respects non-integer-division batch sizes. For example, if we have a batch size of 27 due to GPU memory constraints and num_particles=10, the chunks would be [10, 10, 7]; by chunking the normalized weights the same way, we can apply stratified_resample to each chunk, resample each sub-list, and finally restack/concatenate.
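A standalone sketch of that chunking idea, reusing a 1-D stratified resampler per chunk and offsetting the per-chunk indices back into global batch positions. This is an illustration of the suggestion, not the PR's code:

```python
import torch

def stratified_resample_1d(weights: torch.Tensor) -> torch.Tensor:
    """Stratified resampling for one normalized weight vector of length n."""
    n = weights.shape[0]
    cdf = torch.cumsum(weights, dim=0)
    cdf[-1] = 1.0  # guard against rounding so indices stay in range
    u = (torch.rand(n) + torch.arange(n)) / n
    return torch.searchsorted(cdf, u, right=True).clamp_(max=n - 1)

def chunked_resample_indices(log_weights: torch.Tensor, num_particles: int) -> torch.Tensor:
    """Split a ragged batch into chunks of <= num_particles, resample each
    chunk independently, and return global indices into the batch."""
    out, offset = [], 0
    for chunk in log_weights.split(num_particles):  # e.g. 27 -> [10, 10, 7]
        w = torch.softmax(chunk, dim=0)  # normalize within the chunk
        out.append(stratified_resample_1d(w) + offset)
        offset += chunk.shape[0]
    return torch.cat(out)

idx = chunked_resample_indices(torch.zeros(27), num_particles=10)
print(idx.shape)  # torch.Size([27])
```

Because each chunk's indices are offset by its start position, particles can only be resampled within their own group, which matches the intended per-group semantics.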
```python
    if use_x0_for_reward or eval_score:
        # Lazy import to avoid circular dependency (denoiser.py imports from steering)
        from ..denoiser import get_score
```
can we make that an explicit import instead of relative?
```python
    RtG = R.transpose(-2, -1) @ dJ_dR        # (...,3,3)
    A = 0.5 * (RtG - RtG.transpose(-2, -1))  # skew(...) in so(3)
    return 2.0 * skew_matrix_to_vector(A)    # (...,3) vee-map
```
where did you find that equation?
I'll trust you on this one.
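For context on the exchange above: the returned vector is the gradient of J with respect to a right-multiplied rotation perturbation, since d/dt J(R exp(tW))|₀ = ⟨Rᵀ∇J, W⟩, and for skew W only the skew part of Rᵀ∇J contributes (with ⟨skew(a), skew(b)⟩ = 2a·b giving the factor of 2). This can be checked by finite differences; the sketch below is standalone, with its own toy objective and helper names:

```python
import numpy as np

def skew(w):
    return np.array([[0, -w[2], w[1]],
                     [w[2], 0, -w[0]],
                     [-w[1], w[0], 0]])

def vee(A):
    return np.array([A[2, 1], A[0, 2], A[1, 0]])

def rodrigues(w):
    """Matrix exponential of skew(w) via the Rodrigues formula."""
    th = np.linalg.norm(w)
    if th == 0.0:
        return np.eye(3)
    K = skew(w / th)
    return np.eye(3) + np.sin(th) * K + (1 - np.cos(th)) * (K @ K)

# Toy objective J(R) = trace(A^T R); its Euclidean gradient dJ/dR = A.
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
R = rodrigues(rng.standard_normal(3))  # a random rotation
J = lambda R_: np.trace(A.T @ R_)
dJ_dR = A

# Formula under discussion: grad = 2 * vee(0.5 * (R^T G - (R^T G)^T))
RtG = R.T @ dJ_dR
grad = 2.0 * vee(0.5 * (RtG - RtG.T))

# Finite-difference check along each so(3) basis direction
eps = 1e-6
fd = np.array([
    (J(R @ rodrigues(eps * e)) - J(R @ rodrigues(-eps * e))) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(grad, fd, atol=1e-4))  # True
```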
Add a bimodal Gaussian mixture model with an analytical score function. Implements a toy steering example in `notebooks/fkc_steering.py`. Adds two numerical tests for `dpm_fkc` and `dpm_smc` that simulate a large ensemble of samples to numerically match the analytically tractable (in 1D) biased target distribution.

---------

Co-authored-by: ludwigwinkler <luwinkler@microsoft.com>
…conversion layer

- CVs and potentials now receive Cα positions in nm directly
- Remove 10x multiplication at call site and /10 divisions in CVs
- Rename Ca_pos/ca_pos to ca_pos_nm for explicit units
- Rescale config constants to nm (target, flatbottom, slope, min_dist)
- Remove log_physicality helper (unused)
- Update all tests to use nm-scale values

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Put steering functionalities into a steering/ package with modular components:
Other key changes:
- Steering configs moved into the `steering` dir; see `dpm.yaml` for an example.
- Renamed `resampling_frequency` to `ess_threshold`

Tests:

- Replaced the legacy `test_steering.py`

TODOs:

- Decide between the `FractionNativeContacts` CV class and `train/foldedness.py`

Closes: https://github.com/msr-ai4science/feynman/issues/20268