|
8 | 8 | from plotly import graph_objects as go |
9 | 9 | from plotly.subplots import make_subplots |
10 | 10 |
|
11 | | -from optimagic.optimization.optimize_result import OptimizeResult |
12 | 11 | from tranquilo.clustering import cluster |
13 | 12 | from tranquilo.geometry import log_d_quality_calculator |
14 | 13 | from tranquilo.volume import get_radius_after_volume_scaling |
15 | 14 |
|
| 15 | +from typing import Any, Protocol, runtime_checkable |
| 16 | + |
16 | 17 |
|
17 | 18 | def visualize_tranquilo(results, iterations): |
18 | 19 | """Plot diagnostic information of optimization result in given iteration(s). |
@@ -56,7 +57,7 @@ def visualize_tranquilo(results, iterations): |
56 | 57 | if isinstance(iterations, int): |
57 | 58 | iterations = {case: iterations for case in results} |
58 | 59 | results = {case: _process_results(results[case]) for case in results} |
59 | | - elif isinstance(results, OptimizeResult): |
| 60 | + elif isinstance(results, OptimizeResultLike): |
60 | 61 | results = _process_results(results) |
61 | 62 | results = {f"iteration {i}": results for i in iterations} |
62 | 63 | iterations = {f"iteration {iteration}": iteration for iteration in iterations} |
@@ -588,3 +589,13 @@ def _get_model_indices(xs, state): |
588 | 589 | for point in state.model_points: |
589 | 590 | model_indices = np.concatenate([model_indices, _find_index(xs, point)]) |
590 | 591 | return model_indices.astype(int) |
| 592 | + |
| 593 | + |
| 594 | +@runtime_checkable |
| 595 | +class OptimizeResultLike(Protocol): |
| 596 | + """Runtime-checkable stand-in for optimagic's OptimizeResult object.""" |
| 597 | + |
| 598 | + algorithm: str |
| 599 | + history: Any |
| 600 | + params: Any |
| 601 | + algorithm_output: dict |
0 commit comments