From f9ba832ea5b1e7b7c60047f839fb0fc52619333b Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 31 Mar 2026 06:54:38 +1000 Subject: [PATCH 1/3] Fix issue where single object returns object itself --- ultraplot/gridspec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 5c4ac4066..90c1082b0 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1992,7 +1992,7 @@ def __getitem__(self, key): # De-duplicate while preserving order so method dispatch does not repeat. objs = list(dict.fromkeys(objs)) if len(objs) == 1: - return objs[0] + return SubplotGrid(objs[0]) return SubplotGrid(objs) def __setitem__(self, key, value): From a0316a03e139eec910f8030e0defe9bbfb5c1d09 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 31 Mar 2026 06:57:14 +1000 Subject: [PATCH 2/3] Add tests --- ultraplot/tests/test_gridspec.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index 3c8e8250b..2f7e89ac1 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -145,3 +145,15 @@ def test_gridspec_spanning_slice_deduplicates_axes(): legend = ax.get_legend() assert legend is not None assert [t.get_text() for t in legend.texts] == ["data"] + + +def test_return_type_after_indexing(): + """ + Inexing should always return a SubplotGrid even if we have 1 element + """ + fig, axs = uplt.subplots(ncols=2, nrows=2) + assert axs[1, 0:] is uplt.SubplotGrid + assert len(axs[1, 0:]) == 2 + + assert axs[1, 1:] is uplt.SubplotGrid + assert len(axs[1, 1:]) == 1 From 982d7c9dca376fc055155fc044b5c76a0802bd2c Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 31 Mar 2026 07:09:07 +1000 Subject: [PATCH 3/3] Fix regressed tests --- ultraplot/tests/test_gridspec.py | 16 ++++++++++++---- ultraplot/tests/test_legend.py | 24 ++++++++++++++---------- ultraplot/tests/test_subplots.py | 15 +++++++++++---- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index 2f7e89ac1..38802d2bb 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -4,6 +4,13 @@ from ultraplot.gridspec import SubplotGrid +def _singleton_axis(obj): + if isinstance(obj, SubplotGrid): + assert len(obj) == 1 + return obj[0] + return obj + + def test_grid_has_dynamic_methods(): """ Check that we can apply the methods to a SubplotGrid object. @@ -135,8 +142,9 @@ def test_gridspec_spanning_slice_deduplicates_axes(): # The first two slots in the top row refer to the same spanning subplot. ax = axs[0, :2] - assert isinstance(ax, uplt.axes.Axes) - assert ax is axs[0, 0] + assert isinstance(ax, uplt.SubplotGrid) + assert len(ax) == 1 + assert _singleton_axis(ax) is _singleton_axis(axs[0, 0]) data = np.array([[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]]) ax.scatter(data[:, 0], data[:, 1], c="grey", label="data", legend=True) @@ -152,8 +160,8 @@ def test_return_type_after_indexing(): Inexing should always return a SubplotGrid even if we have 1 element """ fig, axs = uplt.subplots(ncols=2, nrows=2) - assert axs[1, 0:] is uplt.SubplotGrid + assert isinstance(axs[1, 0:], uplt.SubplotGrid) assert len(axs[1, 0:]) == 2 - assert axs[1, 1:] is uplt.SubplotGrid + assert isinstance(axs[1, 1:], uplt.SubplotGrid) assert len(axs[1, 1:]) == 1 diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index a8ebb4455..43fae2a81 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1355,29 +1355,33 @@ def test_legend_span_inference_with_multi_panels( def test_legend_best_axis_selection_right_left(): fig, axs = uplt.subplots(nrows=1, ncols=3) axs.plot([0, 1], [0, 1], label="line") - ref = [axs[0, 0], axs[0, 2]] + left = _anchor_axis(axs[0, 0]) + right = _anchor_axis(axs[0, 2]) + ref = [left, right] fig.legend(ref=ref, loc="r", rows=1) - assert len(axs[0, 2]._panel_dict["right"]) == 1 - assert len(axs[0, 0]._panel_dict["right"]) == 0 + assert len(right._panel_dict["right"]) == 1 + assert len(left._panel_dict["right"]) == 0 fig.legend(ref=ref, loc="l", rows=1) - assert len(axs[0, 0]._panel_dict["left"]) == 1 - assert len(axs[0, 2]._panel_dict["left"]) == 0 + assert len(left._panel_dict["left"]) == 1 + assert len(right._panel_dict["left"]) == 0 def test_legend_best_axis_selection_top_bottom(): fig, axs = uplt.subplots(nrows=2, ncols=1) axs.plot([0, 1], [0, 1], label="line") - ref = [axs[0, 0], axs[1, 0]] + top = _anchor_axis(axs[0, 0]) + bottom = _anchor_axis(axs[1, 0]) + ref = [top, bottom] fig.legend(ref=ref, loc="t", cols=1) - assert len(axs[0, 0]._panel_dict["top"]) == 1 - assert len(axs[1, 0]._panel_dict["top"]) == 0 + assert len(top._panel_dict["top"]) == 1 + assert len(bottom._panel_dict["top"]) == 0 fig.legend(ref=ref, loc="b", cols=1) - assert len(axs[1, 0]._panel_dict["bottom"]) == 1 - assert len(axs[0, 0]._panel_dict["bottom"]) == 0 + assert len(bottom._panel_dict["bottom"]) == 1 + assert len(top._panel_dict["bottom"]) == 0 def test_legend_span_decode_fallback(monkeypatch): diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index c4d7c6d96..ee61616c4 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -9,6 +9,13 @@ import ultraplot as uplt +def _singleton_axis(obj): + if isinstance(obj, uplt.SubplotGrid): + assert len(obj) == 1 + return obj[0] + return obj + + @pytest.mark.mpl_image_compare def test_align_labels(): """ @@ -354,7 +361,7 @@ def test_subset_share_xlabels_implicit_column(): for axi, lab in fig._supxlabel_dict.items() if lab.get_text() == "Right-column X" ] - assert label_axes and label_axes[0] is ax[1, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[1, 1]) uplt.close(fig) @@ -406,7 +413,7 @@ def test_subset_share_ylabels_implicit_row(): label_axes = [ axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y" ] - assert label_axes and label_axes[0] is ax[0, 0] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 0]) uplt.close(fig) @@ -493,7 +500,7 @@ def test_subset_share_xlabels_implicit_column_top(): for axi, lab in fig._supxlabel_dict.items() if lab.get_text() == "Right-column X (top)" ] - assert label_axes and label_axes[0] is ax[0, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1]) uplt.close(fig) @@ -512,7 +519,7 @@ def test_subset_share_ylabels_implicit_row_right(): for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y (right)" ] - assert label_axes and label_axes[0] is ax[0, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1]) uplt.close(fig)