Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,7 +1992,7 @@ def __getitem__(self, key):
# De-duplicate while preserving order so method dispatch does not repeat.
objs = list(dict.fromkeys(objs))
if len(objs) == 1:
return objs[0]
return SubplotGrid(objs[0])
return SubplotGrid(objs)

def __setitem__(self, key, value):
Expand Down
24 changes: 22 additions & 2 deletions ultraplot/tests/test_gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from ultraplot.gridspec import SubplotGrid


def _singleton_axis(obj):
if isinstance(obj, SubplotGrid):
assert len(obj) == 1
return obj[0]
return obj


def test_grid_has_dynamic_methods():
"""
Check that we can apply the methods to a SubplotGrid object.
Expand Down Expand Up @@ -135,8 +142,9 @@ def test_gridspec_spanning_slice_deduplicates_axes():

# The first two slots in the top row refer to the same spanning subplot.
ax = axs[0, :2]
assert isinstance(ax, uplt.axes.Axes)
assert ax is axs[0, 0]
assert isinstance(ax, uplt.SubplotGrid)
assert len(ax) == 1
assert _singleton_axis(ax) is _singleton_axis(axs[0, 0])

data = np.array([[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]])
ax.scatter(data[:, 0], data[:, 1], c="grey", label="data", legend=True)
Expand All @@ -145,3 +153,15 @@ def test_gridspec_spanning_slice_deduplicates_axes():
legend = ax.get_legend()
assert legend is not None
assert [t.get_text() for t in legend.texts] == ["data"]


def test_return_type_after_indexing():
"""
Inexing should always return a SubplotGrid even if we have 1 element
"""
fig, axs = uplt.subplots(ncols=2, nrows=2)
assert isinstance(axs[1, 0:], uplt.SubplotGrid)
assert len(axs[1, 0:]) == 2

assert isinstance(axs[1, 1:], uplt.SubplotGrid)
assert len(axs[1, 1:]) == 1
24 changes: 14 additions & 10 deletions ultraplot/tests/test_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,29 +1355,33 @@ def test_legend_span_inference_with_multi_panels(
def test_legend_best_axis_selection_right_left():
fig, axs = uplt.subplots(nrows=1, ncols=3)
axs.plot([0, 1], [0, 1], label="line")
ref = [axs[0, 0], axs[0, 2]]
left = _anchor_axis(axs[0, 0])
right = _anchor_axis(axs[0, 2])
ref = [left, right]

fig.legend(ref=ref, loc="r", rows=1)
assert len(axs[0, 2]._panel_dict["right"]) == 1
assert len(axs[0, 0]._panel_dict["right"]) == 0
assert len(right._panel_dict["right"]) == 1
assert len(left._panel_dict["right"]) == 0

fig.legend(ref=ref, loc="l", rows=1)
assert len(axs[0, 0]._panel_dict["left"]) == 1
assert len(axs[0, 2]._panel_dict["left"]) == 0
assert len(left._panel_dict["left"]) == 1
assert len(right._panel_dict["left"]) == 0


def test_legend_best_axis_selection_top_bottom():
fig, axs = uplt.subplots(nrows=2, ncols=1)
axs.plot([0, 1], [0, 1], label="line")
ref = [axs[0, 0], axs[1, 0]]
top = _anchor_axis(axs[0, 0])
bottom = _anchor_axis(axs[1, 0])
ref = [top, bottom]

fig.legend(ref=ref, loc="t", cols=1)
assert len(axs[0, 0]._panel_dict["top"]) == 1
assert len(axs[1, 0]._panel_dict["top"]) == 0
assert len(top._panel_dict["top"]) == 1
assert len(bottom._panel_dict["top"]) == 0

fig.legend(ref=ref, loc="b", cols=1)
assert len(axs[1, 0]._panel_dict["bottom"]) == 1
assert len(axs[0, 0]._panel_dict["bottom"]) == 0
assert len(bottom._panel_dict["bottom"]) == 1
assert len(top._panel_dict["bottom"]) == 0


def test_legend_span_decode_fallback(monkeypatch):
Expand Down
15 changes: 11 additions & 4 deletions ultraplot/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
import ultraplot as uplt


def _singleton_axis(obj):
if isinstance(obj, uplt.SubplotGrid):
assert len(obj) == 1
return obj[0]
return obj


@pytest.mark.mpl_image_compare
def test_align_labels():
"""
Expand Down Expand Up @@ -354,7 +361,7 @@ def test_subset_share_xlabels_implicit_column():
for axi, lab in fig._supxlabel_dict.items()
if lab.get_text() == "Right-column X"
]
assert label_axes and label_axes[0] is ax[1, 1]
assert label_axes and label_axes[0] is _singleton_axis(ax[1, 1])

uplt.close(fig)

Expand Down Expand Up @@ -406,7 +413,7 @@ def test_subset_share_ylabels_implicit_row():
label_axes = [
axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y"
]
assert label_axes and label_axes[0] is ax[0, 0]
assert label_axes and label_axes[0] is _singleton_axis(ax[0, 0])

uplt.close(fig)

Expand Down Expand Up @@ -493,7 +500,7 @@ def test_subset_share_xlabels_implicit_column_top():
for axi, lab in fig._supxlabel_dict.items()
if lab.get_text() == "Right-column X (top)"
]
assert label_axes and label_axes[0] is ax[0, 1]
assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1])

uplt.close(fig)

Expand All @@ -512,7 +519,7 @@ def test_subset_share_ylabels_implicit_row_right():
for axi, lab in fig._supylabel_dict.items()
if lab.get_text() == "Top-row Y (right)"
]
assert label_axes and label_axes[0] is ax[0, 1]
assert label_axes and label_axes[0] is _singleton_axis(ax[0, 1])

uplt.close(fig)

Expand Down
Loading