Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/py4vasp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from py4vasp import demo
from py4vasp import demo, graph
from py4vasp._analysis.mlff import MLFFErrorAnalysis
from py4vasp._batch import Batch
from py4vasp._calculation import Calculation, calculation
Expand Down
269 changes: 222 additions & 47 deletions src/py4vasp/_third_party/graph/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import dataclasses
import itertools
from typing import Generator, Tuple

import numpy as np

from py4vasp import _config
from py4vasp._third_party.graph import trace
from py4vasp._util import import_
from py4vasp._util.slicing import Plane

ff = import_.optional("plotly.figure_factory")
go = import_.optional("plotly.graph_objects")
Expand All @@ -23,9 +23,60 @@ class Contour(trace.Trace):
"""Represents data on a 2d slice through the unit cell.

This class creates a visualization of the data within the unit cell based on its
configuration. Currently it supports the creation of heatmaps and quiver plots.
For heatmaps each data point corresponds to one point on the grid. For quiver plots
each data point should be a 2d vector within the plane.
configuration. It supports the creation of heatmaps, contour plots, and quiver plots
for 2D data representations.

For scalar data (heatmaps/contours), each data point corresponds to one point on the grid.
For vector data (quiver plots), each data point should be a 2D vector within the plane.

Examples
--------
Create a simple heatmap:

>>> from py4vasp.graph import Lattice, Graph
>>> lattice = Lattice(vectors=np.array([[3.0, 0.0], [0.0, 3.0]]))
>>> data = np.random.rand(50, 50)
>>> contour = Contour(
... data=data,
... lattice=lattice,
... label="Charge Density",
... colorbar_label="e/ų",
... )
>>> Graph(contour).show()

Create a contour plot with isolevels:

>>> contour = Contour(
... data=data,
... lattice=lattice,
... label="Potential",
... isolevels=True,
... show_contour_values=True,
... color_scheme="diverging",
... )
>>> Graph(contour).show()

Create a quiver plot for vector data:

>>> vector_data = np.random.rand(2, 20, 20)
>>> quiver = Contour(
... data=vector_data,
... lattice=lattice,
... label="Current Density",
... max_number_arrows=100,
... )
>>> Graph(quiver).show()

Use a custom color scheme with limits:

>>> contour = Contour(
... data=data,
... lattice=lattice,
... label="Energy",
... color_scheme="positive",
... color_limits=(0, 1.0),
... )
>>> Graph(contour).show()
"""

_interpolation_factor = 2
Expand All @@ -37,58 +88,183 @@ class Contour(trace.Trace):
"""Can be linear or cubic to determine interpolation behavior."""

data: np.array
"""2d or 3d grid data in the plane spanned by the lattice vectors. If the data is
the dimensions should be the ones of the grid, if the data is 3d the first dimension
should be a 2 for a vector in the plane of the grid and the other two dimensions
should be the grid."""
lattice: Plane
"""Lattice plane in which the data is represented spanned by 2 vectors.
Each vector should have two components, so remove any element normal to
the plane. Can be generated with the 'plane' function in py4vasp._util.slicing."""
"""Grid data representing values in the plane spanned by the lattice vectors.

- For scalar data (2D array): Shape should match the grid dimensions (ny, nx).
Used for heatmaps and contour plots.
- For vector data (3D array): Shape should be (2, ny, nx) where the first dimension
contains the x and y components of vectors in the plane. Used for quiver plots.
"""

lattice: "Lattice"
"""Lattice plane defining the visualization coordinate system.

Should contain exactly 2 lattice vectors, each with 2 components (x, y).
Any components normal to the visualization plane should be removed beforehand.
The data grid points are distributed along these lattice vectors.
"""

label: str
"Assign a label to the visualization that may be used to identify one among multiple plots."
"""Descriptive label for this visualization.

Used to identify this plot among multiple visualizations and appears in legends.
Example: "Charge Density", "Potential", "Current Density"
"""

colorbar_label: str = None
"""Label to show at the colorbar."""
"""Label displayed on the colorbar axis.

Typically includes the physical quantity and units.
Example: "e/ų", "eV", "electrons"
If None, no colorbar label is shown.
"""

isolevels: bool = False
"Defines whether isolevels should be added or a heatmap is used."
"""Display mode for scalar data visualization.

- True: Show contour lines with isolevels (constant value curves)
- False: Show heatmap with continuous color gradients (default)
Only applies to 2D scalar data, ignored for vector data.
"""

show_contour_values: bool = None
"Defines whether contour values should be shown along contour plot lines."
"""Whether to display numerical values along contour lines.

Only relevant when isolevels=True. If None, uses plotly's default behavior.
Set to True to show the isovalues, False to hide them.
"""

color_scheme: str = "auto"
"""The color_scheme argument informs the chosen color map and parameters for the contours plot.
It should be chosen according to the nature of the data to be plotted, as one of the following:
- "auto" (Default): py4vasp will try to infer the color scheme on its own.
- "monochrome" OR "stm": Standard colorscheme for STM.
- "sequential": Use a sequential color scheme.
- "positive": Values are only positive. Use a Reds color scheme. Consider setting color_limits=(0, None).
- "diverging": Use a diverging color scheme.
- "negative": Values are only negative. Use a reverse Blues color scheme. Consider setting color_limits=(None, 0).
- "cyclical": Use a cyclical color scheme.
"""Color mapping strategy for the visualization.

Available options:
- "auto" (default): Automatically select based on data range:
* "diverging" if data spans negative and positive values
* "positive" if all values are non-negative
* "negative" if all values are non-positive
* "default" otherwise
- "monochrome" or "stm": Single-color gradient (suitable for STM images)
- "sequential": Perceptually uniform progression (Viridis)
- "positive": Red gradient for non-negative data
- "negative": Reverse blue gradient for non-positive data
- "diverging": Red-white-blue for data crossing zero (RdBu_r)

Choose based on your data's physical meaning and range.
"""

color_limits: tuple = None
"""Is a tuple that sets the minimum and maximum of the color scale. Can be:
- None | (None, None): No limits are imposed.
- (float, None): Sets the minimum of the color scale.
- (None, float): Sets the maximum of the color scale.
- (float, float): Sets minimum and maximum of the color scale."""
"""Explicit bounds for the color scale mapping.

Controls which data values map to the minimum and maximum colors:
- None or (None, None): Use data's actual min/max (default)
- (vmin, None): Set minimum, auto-detect maximum
- (None, vmax): Auto-detect minimum, set maximum
- (vmin, vmax): Set both bounds explicitly

Useful for:
- Comparing multiple plots on the same scale
- Emphasizing specific data ranges
- Clipping outliers
"""

traces_as_periodic: bool = False
"""If True, traces (contour and quiver) are shifted so that quiver and heatmap 'cell'
centers align with the positions they were computed at. Periodic images will be drawn
so that the supercell still appears completely covered on all sides.
"""Alignment mode for visualization elements relative to the computational grid.

- True: Align visualization elements (contours, arrows) with the actual grid points
where data was computed. Periodic images are drawn to fully cover the supercell,
providing physically accurate representation but potentially less aesthetic appearance.

- False (default): Align heatmap cells with supercell boundaries for clean visual
appearance. Grid points appear at cell corners rather than centers. No periodic
images needed, but may be slightly misleading about where data was computed.

Recommended: True for quantitative analysis, False for presentation graphics.
"""

If False, traces (contour and quiver) are shifted so that the heatmap cells visually
align with the supercell instead. No periodic images are required, but the visual
presentation might be misleading."""
supercell: np.array = (1, 1)
"Multiple of each lattice vector to be drawn."
"""Number of unit cell repetitions along each lattice vector.

Array of 2 integers (na, nb) specifying how many times to tile the unit cell
along the first and second lattice vectors. Useful for visualizing periodic
patterns or showing context around a single cell.
Default (1, 1) shows one unit cell.
"""

show_cell: bool = True
"Show the unit cell in the resulting visualization."
"""Whether to draw the unit cell boundaries.

- True (default): Draw outline of the unit cell as a box/parallelogram
- False: Hide unit cell boundaries

Helpful for understanding the periodicity and lattice geometry.
"""

max_number_arrows: int = None
"Subsample the data until the number of arrows falls below this limit."
"""Maximum arrow count for quiver plots (vector data only).

If the vector field grid has more points than this limit, data is automatically
subsampled to reduce visual clutter. Subsampling is done uniformly in both
directions, attempting to keep arrows evenly distributed.

None (default): Show all arrows without subsampling.
Recommended: ~100-500 for readable visualizations.
"""

scale_arrows: float = None
"""Scale arrows by this factor when converting their length to Å. None means
autoscale them so that the arrows do not overlap."""
"""Arrow length scaling factor for quiver plots (vector data only).

Multiplier applied to vector magnitudes when converting to visual arrow lengths:
- None (default): Automatically scale so the longest arrow equals the smaller
of the two grid spacings (prevents overlap)
- float > 0: Manual scaling factor. Larger values = longer arrows.
Value of 1.0 means arrow length in Ångströms equals vector magnitude.

def to_plotly(self):
Use manual scaling when comparing multiple plots or adjusting readability.
"""

def to_plotly(self) -> Generator[Tuple["Contour" | "Heatmap", dict], None, None]:
"""
Convert the contour data to Plotly figure format.

This method transforms the internal data representation into a format suitable
for visualization with Plotly. It handles three types of plots: contour plots,
heatmaps, and quiver (vector field) plots, depending on the data configuration.

Returns
-------
-
A generator yielding tuples of (plot_data, options) where plot_data is the
Plotly-compatible data structure and options contains the plot configuration.

Notes
-----
The method tiles the data according to the supercell dimensions and transposes
it to match Plotly's expected data layout (swapping a and b axes). The lattice
vectors are scaled by the supercell dimensions before creating the plot.

Examples
--------
Generate a monochrome heatmap

>>> from py4vasp.graph import Lattice
>>> lattice = Lattice(vectors=np.array([[3.0, 0.0], [0.0, 3.0]]))
>>> data = np.random.rand(50, 50)
>>> heatmap = Contour(data, lattice, "heatmap", color_scheme="monochrome")
>>> for plot_data, options in heatmap.to_plotly():
... print(plot_data)
... print(options)
Heatmap(...)
{'shapes': ..., 'annotations': ...}

Create a contour plot

>>> contour = Contour(data, lattice, "contour", isolevels=True)
>>> for plot_data, options in contour.to_plotly():
... print(plot_data)
... print(options)
Contour(...)
{'shapes': ..., 'annotations': ...}
"""
lattice_supercell = np.diag(self.supercell) @ self.lattice.vectors
# swap a and b axes because that is the way plotly expects the data
data = np.tile(self.data, self.supercell).T
Expand Down Expand Up @@ -292,10 +468,9 @@ def _use_data_without_interpolation(self, lattice, data):
)

def _get_periodic_left(self, periodic_expand: int) -> int:
"""When we periodically expand the data, we may wish to do so symmetrically.
This function returns the integer number of points to prepend to the line mesh,
data row or column. Generally, line meshes will need to be shifted by this number.
"""
# When we periodically expand the data, we may wish to do so symmetrically.
# This function returns the integer number of points to prepend to the line mesh,
# data row or column. Generally, line meshes will need to be shifted by this number.
periodic_left = 0
if (self.traces_as_periodic) and (periodic_expand > 1):
periodic_left = int(np.floor((periodic_expand - 1) / 2))
Expand Down Expand Up @@ -331,7 +506,7 @@ def _make_mesh(self, lattice, num_point, index, periodic_expand: int = 1):
return mesh

def _mask_outside_supercell(self, x_out, y_out, z_out, lattice_supercell):
"""Mask points that are outside the supercell area."""
# Mask points that are outside the supercell area.
# Convert Cartesian coordinates to lattice coordinates
# lattice_supercell has vectors as rows, so we need its inverse
lattice_inv = np.linalg.inv(lattice_supercell)
Expand Down
Loading
Loading