Skip to content

Commit 3ff8455

Browse files
authored
Honor cmap for numeric scatter colors (#616)
* Honor cmap for numeric scatter colors Treat 1D numeric scatter c arrays matching the point count as scalar data for colormapping instead of literal RGBA colors. This preserves Nx3/Nx4 explicit color support, keeps the Matplotlib-compatible cmap behavior for numeric values, and adds a regression test for issue #615. * Add type hints to scatter color parsing Annotate the scatter-specific color parsing helpers touched by the cmap compatibility fix so the intent of the new parameters and return values is explicit without broadening the typing changes beyond the affected code path. * Clarify scatter color semantics in docs Document the scatter color ambiguity resolved by the PR: one-dimensional numeric arrays matching the point count are treated as scalar colormap data, while explicit RGB(A) colors should be passed as N x 3 / N x 4 arrays or via color=. * Add return * Tighten scatter helper input types Replace the loose Any annotations on the scatter color parsing helpers with explicit data and color input aliases based on ArrayLike and color tuples. This keeps the typing aligned with the actual ambiguity being resolved by the cmap fix while staying practical for plotting inputs.
1 parent 69e0001 commit 3ff8455

2 files changed

Lines changed: 74 additions & 4 deletions

File tree

ultraplot/axes/plot.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sys
1212
from collections.abc import Callable, Iterable
1313
from numbers import Integral, Number
14-
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
14+
from typing import Any, Iterable, Mapping, Optional, Sequence, TypeAlias, Union
1515

1616
import matplotlib as mpl
1717
import matplotlib.artist as martist
@@ -29,6 +29,7 @@
2929
import matplotlib.ticker as mticker
3030
import numpy as np
3131
import numpy.ma as ma
32+
from numpy.typing import ArrayLike
3233
from packaging import version
3334

3435
from .. import colors as pcolors
@@ -64,6 +65,12 @@
6465
# This is half of rc['patch.linewidth'] of 0.6. Half seems like a nice default.
6566
EDGEWIDTH = 0.3
6667

68+
DataInput: TypeAlias = ArrayLike
69+
ColorTupleRGB: TypeAlias = tuple[float, float, float]
70+
ColorTupleRGBA: TypeAlias = tuple[float, float, float, float]
71+
ColorInput: TypeAlias = DataInput | str | ColorTupleRGB | ColorTupleRGBA | None
72+
ParsedColor: TypeAlias = DataInput | list[str] | str | None
73+
6774
# Data argument docstrings
6875
_args_1d_docstring = """
6976
*args : {y} or {x}, {y}
@@ -993,7 +1000,10 @@
9931000
: array-like or color-spec, optional
9941001
The marker color(s). If this is an array matching the shape of `x` and `y`,
9951002
the colors are generated using `cmap`, `norm`, `vmin`, and `vmax`. Otherwise,
996-
this should be a valid matplotlib color.
1003+
this should be a valid matplotlib color. To pass explicit RGB(A) colors,
1004+
use an ``N x 3`` or ``N x 4`` array, or pass a single color with `color=`.
1005+
One-dimensional numeric arrays matching the point count are interpreted as
1006+
scalar values for colormapping.
9971007
smin, smax : float, optional
9981008
The minimum and maximum marker size area in units ``points ** 2``. Ignored
9991009
if `absolute_size` is ``True``. Default value for `smin` is ``1`` and for
@@ -3963,7 +3973,17 @@ def _parse_2d_format(
39633973
zs = tuple(map(inputs._to_numpy_array, zs))
39643974
return (x, y, *zs, kwargs)
39653975

3966-
def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
3976+
def _parse_color(
3977+
self,
3978+
x: DataInput,
3979+
y: DataInput,
3980+
c: ColorInput,
3981+
*,
3982+
apply_cycle: bool = True,
3983+
infer_rgb: bool = False,
3984+
force_cmap: bool = False,
3985+
**kwargs: Any,
3986+
) -> tuple[ParsedColor, dict[str, Any]]:
39673987
"""
39683988
Parse either a colormap or color cycler. Colormap will be discrete and fade
39693989
to subwhite luminance by default. Returns a HEX string if needed so we don't
@@ -3972,7 +3992,7 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
39723992
# NOTE: This function is positioned above the _parse_cmap and _parse_cycle
39733993
# functions and helper functions.
39743994
parsers = (self._parse_cmap, *self._level_parsers)
3975-
if c is None or mcolors.is_color_like(c):
3995+
if c is None or (mcolors.is_color_like(c) and not force_cmap):
39763996
if infer_rgb and c is not None and (isinstance(c, str) and c != "none"):
39773997
c = pcolors.to_hex(c) # avoid scatter() ambiguous color warning
39783998
if apply_cycle: # False for scatter() so we can wait to get correct 'N'
@@ -4000,6 +4020,32 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
40004020
warnings._warn_ultraplot(f"Ignoring unused keyword arg(s): {pop}")
40014021
return (c, kwargs)
40024022

4023+
def _scatter_c_is_scalar_data(
4024+
self, x: DataInput, y: DataInput, c: ColorInput
4025+
) -> bool:
4026+
"""
4027+
Return whether scatter ``c=`` should be treated as scalar data.
4028+
4029+
Matplotlib treats 1D numeric arrays matching the point count as values to
4030+
be colormapped, even though short float sequences can also look like an
4031+
RGBA tuple to ``is_color_like``. Preserve explicit RGB/RGBA arrays via the
4032+
existing ``N x 3``/``N x 4`` path and reserve this override for the 1D
4033+
numeric case only.
4034+
"""
4035+
if c is None or isinstance(c, str):
4036+
return False
4037+
values = np.asarray(c)
4038+
if values.ndim != 1 or values.size <= 1:
4039+
return False
4040+
if not np.issubdtype(values.dtype, np.number):
4041+
return False
4042+
x = np.atleast_1d(inputs._to_numpy_array(x))
4043+
y = np.atleast_1d(inputs._to_numpy_array(y))
4044+
point_count = x.shape[0]
4045+
if y.shape[0] != point_count:
4046+
return False
4047+
return values.shape[0] == point_count
4048+
40034049
@warnings._rename_kwargs("0.6.0", centers="values")
40044050
def _parse_cmap(
40054051
self,
@@ -5527,6 +5573,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
55275573
# Only parse color if explicitly provided
55285574
infer_rgb = True
55295575
if cc is not None:
5576+
force_cmap = self._scatter_c_is_scalar_data(xs, ys, cc)
55305577
if not isinstance(cc, str):
55315578
test = np.atleast_1d(cc)
55325579
if (
@@ -5542,6 +5589,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
55425589
inbounds=inbounds,
55435590
apply_cycle=False,
55445591
infer_rgb=infer_rgb,
5592+
force_cmap=force_cmap,
55455593
**kw,
55465594
)
55475595
# Create the cycler object by manually cycling and sanitzing the inputs

ultraplot/tests/test_1dplots.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Test 1D plotting overrides.
44
"""
55

6+
import warnings
7+
68
import numpy as np
79
import numpy.ma as ma
810
import pandas as pd
@@ -378,6 +380,26 @@ def test_scatter_edgecolor_single_row():
378380
return fig
379381

380382

383+
def test_scatter_numeric_c_honors_cmap():
384+
"""
385+
Numeric 1D ``c`` arrays should be treated as scalar data for colormapping.
386+
"""
387+
fig, ax = uplt.subplots()
388+
values = np.array([0.1, 0.2, 0.3, 0.4])
389+
with warnings.catch_warnings(record=True) as caught:
390+
warnings.simplefilter("always")
391+
obj = ax.scatter(
392+
[1.0, 2.0, 3.0, 4.0],
393+
[1.0, 2.0, 3.0, 4.0],
394+
c=values,
395+
cmap="turbo",
396+
)
397+
messages = [str(item.message) for item in caught]
398+
assert not any("Ignoring unused keyword arg(s)" in message for message in messages)
399+
assert "turbo" in obj.get_cmap().name
400+
np.testing.assert_allclose(obj.get_array(), values)
401+
402+
381403
@pytest.mark.mpl_image_compare
382404
def test_scatter_inbounds():
383405
"""

0 commit comments

Comments
 (0)