Skip to content

Commit ce73924

Browse files
committed
Honor cmap for numeric scatter colors
Treat 1D numeric scatter c arrays matching the point count as scalar data for colormapping instead of literal RGBA colors. This preserves Nx3/Nx4 explicit color support, keeps the Matplotlib-compatible cmap behavior for numeric values, and adds a regression test for issue #615.
1 parent 69e0001 commit ce73924

2 files changed

Lines changed: 59 additions & 3 deletions

File tree

ultraplot/axes/plot.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3963,7 +3963,17 @@ def _parse_2d_format(
39633963
zs = tuple(map(inputs._to_numpy_array, zs))
39643964
return (x, y, *zs, kwargs)
39653965

3966-
def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
3966+
def _parse_color(
3967+
self,
3968+
x,
3969+
y,
3970+
c,
3971+
*,
3972+
apply_cycle=True,
3973+
infer_rgb=False,
3974+
force_cmap=False,
3975+
**kwargs,
3976+
):
39673977
"""
39683978
Parse either a colormap or color cycler. Colormap will be discrete and fade
39693979
to subwhite luminance by default. Returns a HEX string if needed so we don't
@@ -3972,7 +3982,7 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
39723982
# NOTE: This function is positioned above the _parse_cmap and _parse_cycle
39733983
# functions and helper functions.
39743984
parsers = (self._parse_cmap, *self._level_parsers)
3975-
if c is None or mcolors.is_color_like(c):
3985+
if c is None or (mcolors.is_color_like(c) and not force_cmap):
39763986
if infer_rgb and c is not None and (isinstance(c, str) and c != "none"):
39773987
c = pcolors.to_hex(c) # avoid scatter() ambiguous color warning
39783988
if apply_cycle: # False for scatter() so we can wait to get correct 'N'
@@ -4000,6 +4010,30 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
40004010
warnings._warn_ultraplot(f"Ignoring unused keyword arg(s): {pop}")
40014011
return (c, kwargs)
40024012

4013+
def _scatter_c_is_scalar_data(self, x, y, c) -> bool:
4014+
"""
4015+
Return whether scatter ``c=`` should be treated as scalar data.
4016+
4017+
Matplotlib treats 1D numeric arrays matching the point count as values to
4018+
be colormapped, even though short float sequences can also look like an
4019+
RGBA tuple to ``is_color_like``. Preserve explicit RGB/RGBA arrays via the
4020+
existing ``N x 3``/``N x 4`` path and reserve this override for the 1D
4021+
numeric case only.
4022+
"""
4023+
if c is None or isinstance(c, str):
4024+
return False
4025+
values = np.asarray(c)
4026+
if values.ndim != 1 or values.size <= 1:
4027+
return False
4028+
if not np.issubdtype(values.dtype, np.number):
4029+
return False
4030+
x = np.atleast_1d(inputs._to_numpy_array(x))
4031+
y = np.atleast_1d(inputs._to_numpy_array(y))
4032+
point_count = x.shape[0]
4033+
if y.shape[0] != point_count:
4034+
return False
4035+
return values.shape[0] == point_count
4036+
40034037
@warnings._rename_kwargs("0.6.0", centers="values")
40044038
def _parse_cmap(
40054039
self,
@@ -5527,6 +5561,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
55275561
# Only parse color if explicitly provided
55285562
infer_rgb = True
55295563
if cc is not None:
5564+
force_cmap = self._scatter_c_is_scalar_data(xs, ys, cc)
55305565
if not isinstance(cc, str):
55315566
test = np.atleast_1d(cc)
55325567
if (
@@ -5542,6 +5577,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
55425577
inbounds=inbounds,
55435578
apply_cycle=False,
55445579
infer_rgb=infer_rgb,
5580+
force_cmap=force_cmap,
55455581
**kw,
55465582
)
55475583
# Create the cycler object by manually cycling and sanitzing the inputs

ultraplot/tests/test_1dplots.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Test 1D plotting overrides.
44
"""
55

6+
import warnings
7+
68
import numpy as np
79
import numpy.ma as ma
810
import pandas as pd
@@ -375,7 +377,25 @@ def test_scatter_edgecolor_single_row():
375377
len(result3.get_edgecolors()) == 0
376378
), "Single row without alpha should have no edges"
377379

378-
return fig
380+
381+
def test_scatter_numeric_c_honors_cmap():
382+
"""
383+
Numeric 1D ``c`` arrays should be treated as scalar data for colormapping.
384+
"""
385+
fig, ax = uplt.subplots()
386+
values = np.array([0.1, 0.2, 0.3, 0.4])
387+
with warnings.catch_warnings(record=True) as caught:
388+
warnings.simplefilter("always")
389+
obj = ax.scatter(
390+
[1.0, 2.0, 3.0, 4.0],
391+
[1.0, 2.0, 3.0, 4.0],
392+
c=values,
393+
cmap="turbo",
394+
)
395+
messages = [str(item.message) for item in caught]
396+
assert not any("Ignoring unused keyword arg(s)" in message for message in messages)
397+
assert "turbo" in obj.get_cmap().name
398+
np.testing.assert_allclose(obj.get_array(), values)
379399

380400

381401
@pytest.mark.mpl_image_compare

0 commit comments

Comments
 (0)