diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 5c4ac4066..90c1082b0 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1992,7 +1992,7 @@ def __getitem__(self, key): # De-duplicate while preserving order so method dispatch does not repeat. objs = list(dict.fromkeys(objs)) if len(objs) == 1: - return objs[0] + return SubplotGrid(objs[0]) return SubplotGrid(objs) def __setitem__(self, key, value): diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index 3c8e8250b..38802d2bb 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -4,6 +4,13 @@ from ultraplot.gridspec import SubplotGrid +def _singleton_axis(obj): + if isinstance(obj, SubplotGrid): + assert len(obj) == 1 + return obj[0] + return obj + + def test_grid_has_dynamic_methods(): """ Check that we can apply the methods to a SubplotGrid object. @@ -135,8 +142,9 @@ def test_gridspec_spanning_slice_deduplicates_axes(): # The first two slots in the top row refer to the same spanning subplot. ax = axs[0, :2] - assert isinstance(ax, uplt.axes.Axes) - assert ax is axs[0, 0] + assert isinstance(ax, uplt.SubplotGrid) + assert len(ax) == 1 + assert _singleton_axis(ax) is _singleton_axis(axs[0, 0]) data = np.array([[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]]) ax.scatter(data[:, 0], data[:, 1], c="grey", label="data", legend=True) @@ -145,3 +153,15 @@ def test_gridspec_spanning_slice_deduplicates_axes(): legend = ax.get_legend() assert legend is not None assert [t.get_text() for t in legend.texts] == ["data"] + + +def test_return_type_after_indexing(): + """ + Inexing should always return a SubplotGrid even if we have 1 element + """ + fig, axs = uplt.subplots(ncols=2, nrows=2) + assert isinstance(axs[1, 0:], uplt.SubplotGrid) + assert len(axs[1, 0:]) == 2 + + assert isinstance(axs[1, 1:], uplt.SubplotGrid) + assert len(axs[1, 1:]) == 1 diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index a8ebb4455..43fae2a81 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1355,29 +1355,33 @@ def test_legend_span_inference_with_multi_panels( def test_legend_best_axis_selection_right_left(): fig, axs = uplt.subplots(nrows=1, ncols=3) axs.plot([0, 1], [0, 1], label="line") - ref = [axs[0, 0], axs[0, 2]] + left = _anchor_axis(axs[0, 0]) + right = _anchor_axis(axs[0, 2]) + ref = [left, right] fig.legend(ref=ref, loc="r", rows=1) - assert len(axs[0, 2]._panel_dict["right"]) == 1 - assert len(axs[0, 0]._panel_dict["right"]) == 0 + assert len(right._panel_dict["right"]) == 1 + assert len(left._panel_dict["right"]) == 0 fig.legend(ref=ref, loc="l", rows=1) - assert len(axs[0, 0]._panel_dict["left"]) == 1 - assert len(axs[0, 2]._panel_dict["left"]) == 0 + assert len(left._panel_dict["left"]) == 1 + assert len(right._panel_dict["left"]) == 0 def test_legend_best_axis_selection_top_bottom(): fig, axs = uplt.subplots(nrows=2, ncols=1) axs.plot([0, 1], [0, 1], label="line") - ref = [axs[0, 0], axs[1, 0]] + top = _anchor_axis(axs[0, 0]) + bottom = _anchor_axis(axs[1, 0]) + ref = [top, bottom] fig.legend(ref=ref, loc="t", cols=1) - assert len(axs[0, 0]._panel_dict["top"]) == 1 - assert len(axs[1, 0]._panel_dict["top"]) == 0 + assert len(top._panel_dict["top"]) == 1 + assert len(bottom._panel_dict["top"]) == 0 fig.legend(ref=ref, loc="b", cols=1) - assert len(axs[1, 0]._panel_dict["bottom"]) == 1 - assert len(axs[0, 0]._panel_dict["bottom"]) == 0 + assert len(bottom._panel_dict["bottom"]) == 1 + assert len(top._panel_dict["bottom"]) == 0 def test_legend_span_decode_fallback(monkeypatch): diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index c4d7c6d96..ee61616c4 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -9,6 +9,13 @@ import ultraplot as uplt +def _singleton_axis(obj): + if isinstance(obj, uplt.SubplotGrid): + assert len(obj) == 1 + return obj[0] + return obj + + @pytest.mark.mpl_image_compare def test_align_labels(): """ @@ -354,7 +361,7 @@ def test_subset_share_xlabels_implicit_column(): for axi, lab in fig._supxlabel_dict.items() if lab.get_text() == "Right-column X" ] - assert label_axes and label_axes[0] is ax[1, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[1, 1]) uplt.close(fig) @@ -406,7 +413,7 @@ def test_subset_share_ylabels_implicit_row(): label_axes = [ axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y" ] - assert label_axes and label_axes[0] is ax[0, 0] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 0]) uplt.close(fig) @@ -493,7 +500,7 @@ def test_subset_share_xlabels_implicit_column_top(): for axi, lab in fig._supxlabel_dict.items() if lab.get_text() == "Right-column X (top)" ] - assert label_axes and label_axes[0] is ax[0, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1]) uplt.close(fig) @@ -512,7 +519,7 @@ def test_subset_share_ylabels_implicit_row_right(): for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y (right)" ] - assert label_axes and label_axes[0] is ax[0, 1] + assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1]) uplt.close(fig)