# Copyright (c) DataLab Platform Developers, BSD 3-Clause License
# See LICENSE file for details
"""
Matplotlib Backend
==================
Matplotlib-based visualization backend for the DataLab kernel.
This module provides static PNG rendering for signals and images using
matplotlib with the Agg backend. It is the default backend and works
in all environments (JupyterLite, standard Jupyter, scripts).
Features include:
- Single and multi-signal plotting with error bars, curve styles, and legend
- Single and multi-image plotting in grid layout with colorbars
- ROI (Region of Interest) overlays
- Geometry result overlays (points, markers, rectangles, circles, ellipses,
segments, polygons)
- Table/geometry result annotation text boxes
- Mask visualization with semi-transparent overlay
- Axis labels with units, log scale, and axis bounds
- Colormap support with LUT range
- Non-uniform image coordinates (pcolormesh)
Usage::
from datalab_kernel.matplotlib_backend import MatplotlibPlotter
plotter = MatplotlibPlotter(workspace)
plotter.plot("s001") # Single signal
plotter.plot([sig1, sig2]) # Multiple signals
plotter.plot([img1, img2]) # Multiple images
plotter.display_table(table_result) # HTML table
plotter.display_geometry(geom_result) # HTML table
"""
from __future__ import annotations
import base64
import io
from typing import TYPE_CHECKING
import numpy as np
from datalab_kernel.plotter import (
DEFAULT_PLOT_WIDTH,
MASK_OPACITY,
GeometryResultDisplay,
TableResultDisplay,
_apply_axis_bounds,
_apply_log_scale,
_build_results_html,
_extract_geometry_results_from_metadata,
_extract_table_results_from_metadata,
_get_curve_style,
_get_geometry_coord_labels,
_get_image_colormap,
_get_image_extent_and_aspect,
_get_image_lut_range,
_get_signal_style_from_metadata,
_is_non_uniform_image,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
from datalab_kernel.workspace import DataObject, Workspace
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: Color palette cycled through by ``_get_next_style``: consecutive plot items
#: take consecutive colors, wrapping around at the end of the list.
COLORS = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive"]
#: Linestyle palette used by ``_get_next_style``: the linestyle advances by one
#: entry each time the whole color palette above has been used up.
LINESTYLES = ["-", "--", "-.", ":"]
#: Fill colors handed out in turn to the ROIs defined on one signal.
#: Matches the matplotlib ``tab10`` palette used in DataLab so that ROI
#: colors stay consistent across both applications.
ROI_FILL_COLORS = (
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # grey
    "#bcbd22",  # yellow-green
    "#17becf",  # cyan
)

#: Translucency (matplotlib alpha) used for the ROI fill.
ROI_FILL_ALPHA = 0.35


def roi_color_for_index(index: int) -> str:
    """Pick the fill color (hex string) of the ROI at position *index*.

    The palette is treated as circular: once every entry of
    ``ROI_FILL_COLORS`` has been used, the colors start over from the
    first one.

    Args:
        index: Position of the ROI among the ROIs of its signal

    Returns:
        Hex color string taken from ``ROI_FILL_COLORS``
    """
    palette_size = len(ROI_FILL_COLORS)
    wrapped_index = index % palette_size
    return ROI_FILL_COLORS[wrapped_index]
# ---------------------------------------------------------------------------
# Internal helpers (matplotlib-specific)
# ---------------------------------------------------------------------------
def _get_next_style(index: int) -> tuple[str, str]:
    """Return the (color, linestyle) pair for the plot item at *index*.

    Colors change with every item; the linestyle only changes once the
    whole color palette has been used, so each (color, linestyle)
    combination is handed out before any of them repeats.

    Args:
        index: Sequential index of the item to style

    Returns:
        A tuple (color, linestyle) for styling the plot item
    """
    # Quotient: how many times the color palette has wrapped around;
    # remainder: position inside the color palette.
    palette_round, color_slot = divmod(index, len(COLORS))
    return COLORS[color_slot], LINESTYLES[palette_round % len(LINESTYLES)]
def _format_annotation_value(value) -> str:
    """Format one result value for the on-plot annotation text.

    Floating-point values (Python floats and NumPy floating scalars) are
    shown with three decimals, or with three significant digits when their
    magnitude is below 0.001 or at least 10000. Anything else is converted
    with ``str``.
    """
    if isinstance(value, (float, np.floating)):
        if abs(value) < 0.001 or abs(value) >= 10000:
            return f"{value:.3g}"
        return f"{value:.3f}"
    return str(value)


def _add_table_results_to_axes(
    ax: Axes, table_results: list, geometry_results: list | None = None
) -> None:
    """Add table and geometry results as text annotation to matplotlib axes.

    Formats TableResult and GeometryResult objects as a text box displayed in
    the upper-left corner of the axes, similar to DataLab's result label display.
    Nothing is drawn when both result lists are empty or ``None``.

    Args:
        ax: Matplotlib axes object
        table_results: List of TableResult objects to display (``None`` is
            treated as an empty list)
        geometry_results: Optional list of GeometryResult objects to display
    """
    table_results = table_results or []
    geometry_results = geometry_results or []
    if not table_results and not geometry_results:
        return

    text_lines = []

    # Table results first (statistics): a title line followed by one
    # "header: value" line per cell.
    for table in table_results:
        text_lines.append(f"{table.title}:")
        headers = list(table.headers)
        for row in table.data:
            text_lines.extend(
                f" {header}: {_format_annotation_value(value)}"
                for header, value in zip(headers, row)
            )
        text_lines.append("")  # Empty line between results

    # Geometry results after the table results
    for geometry in geometry_results:
        text_lines.append(f"{geometry.title}:")
        text_lines.append(" Value")
        coord_labels = _get_geometry_coord_labels(geometry)
        # Only the first row of coords is listed (coords may also be 1-D,
        # in which case it is taken as that single row)
        if len(geometry.coords) > 0:
            coords = geometry.coords[0] if geometry.coords.ndim > 1 else geometry.coords
            text_lines.extend(
                f" {label}: {_format_annotation_value(value)}"
                for label, value in zip(coord_labels, coords)
            )
        text_lines.append("")  # Empty line between results

    # Remove trailing empty line
    if text_lines and text_lines[-1] == "":
        text_lines.pop()

    # Text box anchored in the upper-left corner, in axes coordinates
    ax.text(
        0.02,
        0.98,
        "\n".join(text_lines),
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        horizontalalignment="left",
        fontfamily="monospace",
        bbox={
            "boxstyle": "round,pad=0.5",
            "facecolor": "white",
            "edgecolor": "gray",
            "alpha": 0.85,
        },
    )
def _add_single_roi_to_axes(ax: Axes, roi, obj=None, roi_index: int = 0) -> None:
    """Draw one ROI on top of a matplotlib axes.

    The ROI kind is identified by its class name. Rectangular, circular and
    polygonal ROIs are drawn as red outlines with their label annotated next
    to the shape. A segment ROI (signal X interval) is drawn as a translucent
    fill between the signal curve and ``y=0``; when the curve data cannot be
    used, a vertical span covering the interval is drawn instead. Any other
    ROI class, or a segment ROI without parent object, draws nothing.

    Args:
        ax: Matplotlib axes object
        roi: Single ROI object (SegmentROI, RectangularROI, CircularROI, or
            PolygonalROI)
        obj: Parent object (used for SegmentROI to get physical coordinates
            and the underlying curve for curve-clipped fill)
        roi_index: 0-based ROI index used to pick the cycling fill color for
            signal ROIs (so several ROIs on the same signal are visually
            distinguishable).
    """
    # Delayed import
    # pylint: disable=import-outside-toplevel
    from matplotlib import patches

    kind = type(roi).__name__
    roi_label = getattr(roi, "title", "") or "ROI"
    # Appearance shared by the three outlined (image) ROI shapes
    outline = {
        "linewidth": 2,
        "edgecolor": "red",
        "facecolor": "none",
        "label": roi_label,
    }

    if kind == "SegmentROI" and obj is not None:
        # Signal ROI: X interval, filled along the signal curve with a
        # baseline at y=0 (mirrors DataLab's curve-clipped ROI rendering).
        x0, x1 = roi.get_physical_coords(obj)
        fill_color = roi_color_for_index(roi_index)
        curve_x = np.asarray(getattr(obj, "x", None), dtype=float)
        curve_y = np.asarray(getattr(obj, "y", None), dtype=float)
        if curve_x.size >= 2 and curve_y.size == curve_x.size:
            # Drop samples where either coordinate is NaN/inf
            usable = np.isfinite(curve_x) & np.isfinite(curve_y)
            curve_x = curve_x[usable]
            curve_y = curve_y[usable]
            if curve_x.size >= 2 and curve_y.size == curve_x.size:
                # Sort by X so interpolation and clipping see increasing X
                ascending = np.argsort(curve_x)
                curve_x = curve_x[ascending]
                curve_y = curve_y[ascending]
                # Clip the ROI interval to the X range covered by the curve
                left = max(float(curve_x[0]), min(x0, x1))
                right = min(float(curve_x[-1]), max(x0, x1))
                if right > left:
                    inside = (curve_x >= left) & (curve_x <= right)
                    # Interpolated end points make the fill start and stop
                    # exactly at the ROI bounds
                    fill_x = np.concatenate(([left], curve_x[inside], [right]))
                    fill_y = np.concatenate(
                        (
                            [float(np.interp(left, curve_x, curve_y))],
                            curve_y[inside],
                            [float(np.interp(right, curve_x, curve_y))],
                        )
                    )
                    ax.fill_between(
                        fill_x,
                        fill_y,
                        0.0,
                        color=fill_color,
                        alpha=ROI_FILL_ALPHA,
                        linewidth=0,
                        label=roi_label,
                    )
                    return
        # Fallback: full-height vertical strip when curve data is unusable
        ax.axvspan(x0, x1, color=fill_color, alpha=ROI_FILL_ALPHA, label=roi_label)
        return

    if kind == "RectangularROI":
        # coords = [x0, y0, dx, dy]
        x0, y0, dx, dy = roi.coords
        shape = patches.Rectangle((x0, y0), dx, dy, **outline)
        label_xy = (x0 + dx / 2, y0)
    elif kind == "CircularROI":
        # coords = [xc, yc, r]
        xc, yc, r = roi.coords
        shape = patches.Circle((xc, yc), r, **outline)
        label_xy = (xc, yc - r)
    elif kind == "PolygonalROI":
        # coords = [x0, y0, x1, y1, x2, y2, ...]
        vertices = roi.coords.reshape(-1, 2)
        shape = patches.Polygon(vertices, closed=True, **outline)
        label_xy = (vertices[:, 0].mean(), vertices[:, 1].min())
    else:
        # Unknown ROI class (or SegmentROI without parent object): nothing to draw
        return

    ax.add_patch(shape)
    ax.annotate(
        roi_label,
        xy=label_xy,
        fontsize=8,
        color="red",
        ha="center",
        va="bottom",
    )
def _add_geometry_to_axes(ax: Axes, result) -> None:
    """Add geometry result overlay to matplotlib axes.

    Iterates over all rows in ``result.coords`` to draw each geometric shape.
    Supports POINT, MARKER, RECTANGLE, CIRCLE, SEGMENT, ELLIPSE, and POLYGON;
    any other kind draws nothing.

    A compact text label showing the result title and a key value is placed
    near POINT and MARKER shapes (coordinates), CIRCLE shapes (radius) and
    SEGMENT shapes (length, e.g. ``"FWHM: 0.52"`` at the segment midpoint).
    RECTANGLE, ELLIPSE and POLYGON shapes are drawn without a label.

    Args:
        ax: Matplotlib axes object
        result: GeometryResult object with shape information. ``coords`` is
            normally a 2D array with one row per shape; a 1D array is
            accepted and treated as a single shape, and empty coords draw
            nothing.
    """
    # Delayed import
    # pylint: disable=import-outside-toplevel
    from matplotlib import patches
    from sigima.objects import KindShape

    rows = np.asarray(result.coords)
    if rows.size == 0:
        return
    if rows.ndim == 1:
        # A single shape given as a flat coordinate vector: promote it to one
        # row instead of failing on tuple unpacking below.
        rows = rows[np.newaxis, :]

    label_title = getattr(result, "title", "")
    kind = result.kind
    # Appearance shared by all outlined (patch-based) shapes
    dashed_outline = {
        "linewidth": 2,
        "edgecolor": "yellow",
        "facecolor": "none",
        "linestyle": "--",
    }

    def _fmt(v: float) -> str:
        """Format a float value compactly."""
        if abs(v) < 0.001 or abs(v) >= 10000:
            return f"{v:.3g}"
        return f"{v:.3f}"

    def _add_label(x: float, y: float, text: str) -> None:
        """Add a small text label near a geometry shape."""
        ax.annotate(
            text,
            xy=(x, y),
            fontsize=8,
            fontfamily="sans-serif",
            color="#333",
            bbox={
                "boxstyle": "round,pad=0.3",
                "facecolor": "#ffffc8",
                "edgecolor": "#c8c800",
                "alpha": 0.8,
            },
        )

    # Each row describes one shape of the result's kind
    for coords in rows:
        if kind == KindShape.POINT:
            x0, y0 = coords
            ax.plot(
                x0,
                y0,
                marker="o",
                markersize=6,
                color="yellow",
                markeredgecolor="black",
                markeredgewidth=1,
            )
            _add_label(x0, y0, f"{label_title}: ({_fmt(x0)}, {_fmt(y0)})")
        elif kind == KindShape.MARKER:
            x0, y0 = coords
            # Marker with crosshair style
            ax.axhline(y0, color="yellow", linestyle="--", linewidth=1, alpha=0.7)
            ax.axvline(x0, color="yellow", linestyle="--", linewidth=1, alpha=0.7)
            ax.plot(
                x0,
                y0,
                marker="+",
                markersize=10,
                color="yellow",
                markeredgewidth=2,
            )
            _add_label(x0, y0, f"{label_title}: ({_fmt(x0)}, {_fmt(y0)})")
        elif kind == KindShape.RECTANGLE:
            x0, y0, dx, dy = coords
            ax.add_patch(patches.Rectangle((x0, y0), dx, dy, **dashed_outline))
        elif kind == KindShape.CIRCLE:
            xc, yc, r = coords
            ax.add_patch(patches.Circle((xc, yc), r, **dashed_outline))
            _add_label(xc + r, yc, f"{label_title}: r={_fmt(r)}")
        elif kind == KindShape.SEGMENT:
            x0, y0, x1, y1 = coords
            ax.plot([x0, x1], [y0, y1], "y--", linewidth=2)
            length = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
            mx, my = (x0 + x1) / 2, (y0 + y1) / 2
            _add_label(mx, my, f"{label_title}: {_fmt(length)}")
        elif kind == KindShape.ELLIPSE:
            # For ellipse, coords are (xc, yc, a, b, theta); a and b are
            # doubled because the patch takes full width/height, and theta is
            # converted from radians to the degrees the patch expects.
            xc, yc, a, b, theta = coords
            ax.add_patch(
                patches.Ellipse(
                    (xc, yc),
                    2 * a,
                    2 * b,
                    angle=np.degrees(theta),
                    **dashed_outline,
                )
            )
        elif kind == KindShape.POLYGON:
            # Flat vertex list: x values at even positions, y values at odd ones
            x = coords[::2]
            y = coords[1::2]
            ax.plot(x, y, "y--", linewidth=2, marker="o", markersize=4)
# ============================================================================
# Result classes
# ============================================================================
class MplPlotResult:
    """Result of a single-object plot rendered via matplotlib.

    Supports Jupyter's rich display protocol: ``_repr_png_()`` returns
    static PNG bytes and ``_repr_html_()`` returns an embedded PNG in HTML.
    ``_ipython_display_()`` shows the figure and, as a separate output, the
    results HTML collected while rendering.
    """

    def __init__(
        self,
        obj: DataObject,
        title: str | None = None,
        show_roi: bool = True,
        show_results: bool = True,
        results: list | None = None,
        **kwargs,
    ) -> None:
        """Initialize plot result.

        Args:
            obj: Object to display
            title: Plot title
            show_roi: Whether to show ROIs
            show_results: Whether to show geometry/table results from metadata
            results: Optional list of GeometryResult objects to overlay (for images)
            **kwargs: Additional options (e.g., ``colormap``,
                ``height`` to override the default figure height in pixels)
        """
        self._obj = obj
        self._title = title
        self._show_roi = show_roi
        self._show_results = show_results
        self._results = results
        self._kwargs = kwargs
        # Filled in by the render methods; shown below the figure in Jupyter
        self._results_html = ""

    def _ipython_display_(self, **kwargs) -> None:
        """Display figure and results as separate outputs in Jupyter."""
        # pylint: disable=import-outside-toplevel
        from html import escape

        from IPython.display import HTML, display

        # Force a fresh render so _results_html is populated
        try:
            display(HTML(self._repr_html_()))
        except Exception as exc:  # pylint: disable=broad-exception-caught
            display(HTML(f"<div>Error rendering plot: {escape(str(exc))}</div>"))
            return
        if self._results_html:
            display(HTML(self._results_html))

    def _repr_html_(self) -> str:
        """Return HTML representation for Jupyter display.

        Signals and images are rendered to a PNG embedded as a data URI; any
        other object type yields a one-line textual summary. Rendering errors
        are reported as an HTML message instead of being raised. Titles and
        error messages are HTML-escaped so that characters such as ``<`` or
        ``&`` are shown literally.
        """
        # pylint: disable=import-outside-toplevel
        from html import escape

        obj_type = escape(type(self._obj).__name__)
        title = escape(str(self._title or getattr(self._obj, "title", "Untitled")))
        if obj_type not in ("SignalObj", "ImageObj"):
            return f"<div><strong>{title}</strong>: {obj_type}</div>"
        try:
            png_data = self._render_to_png()
        except Exception as e:  # pylint: disable=broad-exception-caught
            return f"<div>Error rendering {obj_type}: {escape(str(e))}</div>"
        b64_data = base64.b64encode(png_data).decode("utf-8")
        return (
            '<div style="text-align: center;">'
            f"<h4>{title}</h4>"
            f'<img src="data:image/png;base64,{b64_data}" />'
            "</div>"
        )

    def _repr_png_(self) -> bytes:
        """Return PNG representation for Jupyter display."""
        return self._render_to_png()

    def _render_to_png(self) -> bytes:
        """Render object to PNG bytes using matplotlib.

        The figure is always closed, even when rendering fails, so that
        failed renders do not accumulate in pyplot's figure registry.
        """
        # Delayed import: matplotlib is optional and heavy
        # pylint: disable=import-outside-toplevel
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        obj_type = type(self._obj).__name__
        title = self._title or getattr(self._obj, "title", "")

        # Figure size: honour user height override (pixels → inches at 100 dpi)
        dpi = 100
        height_in = self._kwargs.get("height", 500) / dpi
        width_in = DEFAULT_PLOT_WIDTH / dpi
        fig, ax = plt.subplots(figsize=(width_in, height_in))
        try:
            if obj_type == "SignalObj":
                self._render_signal(ax)
            elif obj_type == "ImageObj":
                self._render_image(ax, fig)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=dpi, pil_kwargs={"compress_level": 1})
        finally:
            plt.close(fig)
        return buf.getvalue()

    def _set_axis_labels(self, ax: Axes) -> None:
        """Label both axes from the object's labels, with units in parentheses.

        Falls back to ``"X"`` / ``"Y"`` when the object has no label.
        """
        obj = self._obj
        xlabel = getattr(obj, "xlabel", None) or "X"
        ylabel = getattr(obj, "ylabel", None) or "Y"
        xunit = getattr(obj, "xunit", None)
        yunit = getattr(obj, "yunit", None)
        ax.set_xlabel(f"{xlabel} ({xunit})" if xunit else xlabel)
        ax.set_ylabel(f"{ylabel} ({yunit})" if yunit else ylabel)

    def _render_signal(self, ax: Axes) -> None:
        """Render signal data to axes.

        Supports error bars (dx/dy), curve styles (Lines/Sticks/Steps),
        per-object line color/style from metadata, log scale, and axis bounds.
        Error bars are only drawn for the default (line) curve style.

        Args:
            ax: Matplotlib axes object
        """
        obj = self._obj
        x = obj.x
        y = obj.y

        # Determine base style: metadata overrides > defaults
        base_style = {"linewidth": 1, "color": "blue"}
        base_style.update(_get_signal_style_from_metadata(obj))

        curvestyle = _get_curve_style(obj)
        has_errorbars = obj.dy is not None or obj.dx is not None
        if curvestyle == "Sticks":
            stem_color = base_style.get("color", "blue")
            ax.stem(
                x,
                y,
                linefmt=f"{stem_color}",
                markerfmt=" ",
                basefmt="k-",
            )
        elif curvestyle == "Steps":
            ax.step(x, y, where="mid", **base_style)
        elif has_errorbars:
            # Use errorbar instead of plot when error data is available
            ax.errorbar(
                x,
                y,
                yerr=obj.dy,
                xerr=obj.dx,
                fmt="-",
                linewidth=base_style.get("linewidth", 1),
                color=base_style.get("color", "blue"),
                capsize=3,
                elinewidth=0.8,
            )
        else:
            ax.plot(x, y, **base_style)

        # Axis labels with units
        self._set_axis_labels(ax)
        # Log scale
        _apply_log_scale(ax, obj)
        # Axis bounds (when autoscale is disabled)
        _apply_axis_bounds(ax, obj)

        # Show ROIs
        if self._show_roi and hasattr(obj, "roi") and obj.roi:
            for roi_idx, roi in enumerate(obj.roi):
                _add_single_roi_to_axes(ax, roi, obj, roi_index=roi_idx)

        # Auto-extract and display geometry/table results from object metadata
        if self._show_results:
            metadata_results = _extract_geometry_results_from_metadata(obj)
            for result in metadata_results:
                _add_geometry_to_axes(ax, result)
            table_results = _extract_table_results_from_metadata(obj)
            self._results_html = _build_results_html(table_results, metadata_results)

    def _render_image(self, ax: Axes, fig) -> None:
        """Render image data to axes.

        Supports per-object colormap from metadata, LUT range (vmin/vmax),
        non-uniform coordinates (pcolormesh), log scale, and axis bounds.
        Complex data is displayed as its magnitude. The mask overlay is only
        drawn for uniform images, because it is positioned through the image
        extent, which non-uniform images do not have.

        Args:
            ax: Matplotlib axes object
            fig: Matplotlib figure object
        """
        obj = self._obj
        data = obj.data
        if np.iscomplexobj(data):
            data = np.abs(data)

        # Colormap: explicit kwarg > object metadata > default
        colormap = _get_image_colormap(obj, self._kwargs)
        # LUT range from object attributes
        vmin, vmax = _get_image_lut_range(obj)

        if _is_non_uniform_image(obj):
            # Use pcolormesh for non-uniform grids
            im = ax.pcolormesh(
                obj.xcoords,
                obj.ycoords,
                data,
                cmap=colormap,
                vmin=vmin,
                vmax=vmax,
                shading="auto",
            )
            ax.set_aspect("auto")
            ax.invert_yaxis()
            extent = None  # No extent for pcolormesh
        else:
            # Compute extent and aspect ratio from physical coordinates
            extent, aspect_ratio = _get_image_extent_and_aspect(obj)
            im = ax.imshow(
                data,
                aspect=aspect_ratio,
                origin="upper",
                cmap=colormap,
                extent=extent,
                vmin=vmin,
                vmax=vmax,
                interpolation="nearest",
            )

        # Overlay mask if present: masked pixels become semi-transparent red,
        # all other pixels stay fully transparent
        maskdata = getattr(obj, "maskdata", None)
        if extent is not None and maskdata is not None:
            mask_rgba = np.zeros((*maskdata.shape, 4), dtype=np.float32)
            mask_rgba[maskdata, :] = [1, 0, 0, MASK_OPACITY]
            ax.imshow(
                mask_rgba,
                origin="upper",
                extent=extent,
                interpolation="nearest",
            )

        # Colorbar with label
        zlabel = getattr(obj, "zlabel", None) or ""
        zunit = getattr(obj, "zunit", None)
        cbar = fig.colorbar(im, ax=ax)
        if zlabel:
            cbar.set_label(f"{zlabel} ({zunit})" if zunit else zlabel)

        # Axis labels with units
        self._set_axis_labels(ax)
        # Log scale
        _apply_log_scale(ax, obj)
        # Axis bounds (when autoscale is disabled)
        _apply_axis_bounds(ax, obj)

        # Show ROIs
        if self._show_roi and hasattr(obj, "roi") and obj.roi:
            for roi_idx, roi in enumerate(obj.roi):
                _add_single_roi_to_axes(ax, roi, obj, roi_index=roi_idx)

        # Overlay geometry/table results (explicit or from metadata)
        if self._show_results:
            results_to_display = []
            if self._results is not None:
                if isinstance(self._results, (list, tuple)):
                    results_to_display.extend(self._results)
                else:
                    results_to_display.append(self._results)
            # Auto-extract geometry results from object metadata
            results_to_display.extend(_extract_geometry_results_from_metadata(obj))
            for result in results_to_display:
                _add_geometry_to_axes(ax, result)
            # Auto-extract and display table results (statistics) from metadata
            table_results = _extract_table_results_from_metadata(obj)
            self._results_html = _build_results_html(table_results, results_to_display)

    def __repr__(self) -> str:
        """Return string representation."""
        obj_type = type(self._obj).__name__
        title = self._title or getattr(self._obj, "title", "Untitled")
        return f"MplPlotResult({obj_type}: {title})"
class MplMultiSignalResult:
    """Result of a multi-signal plot rendered via matplotlib.

    Supports plotting multiple SignalObj, numpy arrays, or (x, y) tuples
    on a single plot with automatic styling.
    """

    def __init__(
        self,
        objs: list,
        title: str | None = None,
        xlabel: str | None = None,
        ylabel: str | None = None,
        xunit: str | None = None,
        yunit: str | None = None,
        show_roi: bool = True,
        show_results: bool = True,
        **kwargs,
    ) -> None:
        """Initialize multi-signal plot result.

        Args:
            objs: List of objects to display (SignalObj, ndarray, or (x, y) tuples)
            title: Plot title
            xlabel: Label for the x-axis
            ylabel: Label for the y-axis
            xunit: Unit for the x-axis
            yunit: Unit for the y-axis
            show_roi: Whether to show ROIs
            show_results: Whether to show geometry/table results from metadata
            **kwargs: Additional options
        """
        self._objs = objs
        self._title = title
        self._xlabel = xlabel
        self._ylabel = ylabel
        self._xunit = xunit
        self._yunit = yunit
        self._show_roi = show_roi
        self._show_results = show_results
        self._kwargs = kwargs
        # Filled in while rendering; shown below the figure in Jupyter
        self._results_html = ""

    def _ipython_display_(self, **kwargs) -> None:
        """Display figure and results as separate outputs in Jupyter."""
        # pylint: disable=import-outside-toplevel
        from html import escape

        from IPython.display import HTML, display

        try:
            display(HTML(self._repr_html_()))
        except Exception as exc:  # pylint: disable=broad-exception-caught
            display(HTML(f"<div>Error rendering signals: {escape(str(exc))}</div>"))
            return
        if self._results_html:
            display(HTML(self._results_html))

    def _repr_html_(self) -> str:
        """Return HTML representation for Jupyter display.

        Rendering errors are reported as an HTML message instead of being
        raised. The title and error text are HTML-escaped.
        """
        # pylint: disable=import-outside-toplevel
        from html import escape

        try:
            png_data = self._render_to_png()
        except Exception as e:  # pylint: disable=broad-exception-caught
            return f"<div>Error rendering signals: {escape(str(e))}</div>"
        b64_data = base64.b64encode(png_data).decode("utf-8")
        title = escape(str(self._title or "Signals"))
        return (
            '<div style="text-align: center;">'
            f"<h4>{title}</h4>"
            f'<img src="data:image/png;base64,{b64_data}" />'
            "</div>"
        )

    def _repr_png_(self) -> bytes:
        """Return PNG representation for Jupyter display."""
        return self._render_to_png()

    def _render_to_png(self) -> bytes:
        """Render signals to PNG bytes using matplotlib.

        The figure is always closed, even when drawing fails (for example on
        an unsupported item type), so failed renders do not accumulate in
        pyplot's figure registry.

        Raises:
            TypeError: If an item is neither a SignalObj, a 2-tuple, nor an
                ndarray
        """
        # pylint: disable=import-outside-toplevel
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(DEFAULT_PLOT_WIDTH / 100, 6))
        try:
            self._draw_all(fig, ax)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, pil_kwargs={"compress_level": 1})
        finally:
            plt.close(fig)
        return buf.getvalue()

    def _draw_all(self, fig, ax: Axes) -> None:
        """Draw every item, then axis labels, legend and grid, on *ax*."""
        if self._title:
            fig.suptitle(self._title)

        # Explicit labels/units win; missing ones come from the first SignalObj
        x_label, y_label = self._xlabel, self._ylabel
        x_unit, y_unit = self._xunit, self._yunit
        all_tbl_results: list = []
        all_geo_results: list = []
        seen_signal_obj = False

        for idx, data_or_obj in enumerate(self._objs):
            color, linestyle = _get_next_style(idx)
            if type(data_or_obj).__name__ == "SignalObj":
                # "First SignalObj" is tracked explicitly: it is not
                # necessarily the first item when raw arrays precede it.
                is_first = not seen_signal_obj
                seen_signal_obj = True
                if is_first:
                    x_label = x_label or getattr(data_or_obj, "xlabel", None) or ""
                    y_label = y_label or getattr(data_or_obj, "ylabel", None) or ""
                    x_unit = x_unit or getattr(data_or_obj, "xunit", None) or ""
                    y_unit = y_unit or getattr(data_or_obj, "yunit", None) or ""
                tables, geometries = self._draw_signal_obj(
                    ax, data_or_obj, idx, color, linestyle, is_first
                )
                all_tbl_results.extend(tables)
                all_geo_results.extend(geometries)
            elif isinstance(data_or_obj, tuple) and len(data_or_obj) == 2:
                # Tuple of (x, y) arrays
                xdata, ydata = data_or_obj
                ax.plot(
                    xdata,
                    ydata,
                    color=color,
                    linestyle=linestyle,
                    label=f"Signal {idx + 1}",
                )
            elif isinstance(data_or_obj, np.ndarray):
                # Just y data, use indices for x
                ax.plot(
                    np.arange(len(data_or_obj)),
                    data_or_obj,
                    color=color,
                    linestyle=linestyle,
                    label=f"Signal {idx + 1}",
                )
            else:
                raise TypeError(f"Unsupported data type: {type(data_or_obj)}")

        # Build HTML for results below the plot
        self._results_html = _build_results_html(all_tbl_results, all_geo_results)

        # Set axis labels with units
        if x_label:
            ax.set_xlabel(f"{x_label} ({x_unit})" if x_unit else x_label)
        if y_label:
            ax.set_ylabel(f"{y_label} ({y_unit})" if y_unit else y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _draw_signal_obj(
        self,
        ax: Axes,
        obj,
        idx: int,
        color: str,
        linestyle: str,
        configure_axes: bool,
    ) -> tuple[list, list]:
        """Draw one SignalObj with its ROIs and geometry results.

        Args:
            ax: Matplotlib axes object
            obj: SignalObj to draw
            idx: Position of the object in the plotted list (used for the
                default legend label)
            color: Default color from the style cycle
            linestyle: Default linestyle from the style cycle
            configure_axes: Whether log scale and axis bounds are taken from
                this object

        Returns:
            A tuple (table_results, geometry_results) extracted from the
            object's metadata; both are empty when results are disabled.
        """
        label = obj.title or f"Signal {idx + 1}"

        # Determine style: metadata overrides > cycling defaults
        plot_style = {"color": color, "linestyle": linestyle}
        plot_style.update(_get_signal_style_from_metadata(obj))

        curvestyle = _get_curve_style(obj)
        has_errorbars = obj.dy is not None or obj.dx is not None
        if curvestyle == "Sticks":
            stem_color = plot_style.get("color", color)
            ax.stem(
                obj.x,
                obj.y,
                linefmt=f"{stem_color}",
                markerfmt=" ",
                basefmt="k-",
                label=label,
            )
        elif curvestyle == "Steps":
            ax.step(obj.x, obj.y, where="mid", label=label, **plot_style)
        elif has_errorbars:
            ax.errorbar(
                obj.x,
                obj.y,
                yerr=obj.dy,
                xerr=obj.dx,
                fmt="-",
                linewidth=plot_style.get("linewidth", 1),
                color=plot_style.get("color", color),
                capsize=3,
                elinewidth=0.8,
                label=label,
            )
        else:
            ax.plot(obj.x, obj.y, label=label, **plot_style)

        # Log scale and axis bounds (from first SignalObj)
        if configure_axes:
            _apply_log_scale(ax, obj)
            _apply_axis_bounds(ax, obj)

        # Plot ROIs if requested: vertical spans in the curve's color; only
        # the first ROI of each signal gets a legend entry
        if self._show_roi and hasattr(obj, "roi") and obj.roi:
            for roi_idx, single_roi in enumerate(obj.roi):
                x0, x1 = single_roi.get_physical_coords(obj)
                roi_title = getattr(single_roi, "title", "")
                roi_label = roi_title if roi_title else f"{label} ROI {roi_idx + 1}"
                ax.axvspan(
                    x0,
                    x1,
                    alpha=0.2,
                    color=plot_style.get("color", color),
                    label=roi_label if roi_idx == 0 else None,
                )

        # Auto-extract and display geometry/table results from metadata
        if not self._show_results:
            return [], []
        metadata_results = _extract_geometry_results_from_metadata(obj)
        for result in metadata_results:
            _add_geometry_to_axes(ax, result)
        table_results = _extract_table_results_from_metadata(obj)
        return list(table_results), list(metadata_results)

    def __repr__(self) -> str:
        """Return string representation."""
        return f"MplMultiSignalResult({len(self._objs)} signals)"
class MplMultiImageResult:
"""Result of a multi-image plot rendered via matplotlib.
Supports plotting multiple ImageObj or numpy arrays in a grid layout
with automatic styling, ROI overlays, and geometry result overlays.
"""
def __init__(
self,
objs: list,
title: str | None = None,
titles: list[str] | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
zlabel: str | None = None,
xunit: str | None = None,
yunit: str | None = None,
zunit: str | None = None,
show_roi: bool = True,
show_results: bool = True,
results: list | None = None,
rows: int | None = None,
share_axes: bool = True,
**kwargs,
) -> None:
"""Initialize multi-image plot result.
Args:
objs: List of objects to display (ImageObj or ndarray)
title: Overall figure title
titles: Optional list of titles for each image
xlabel: Label for the x-axis
ylabel: Label for the y-axis
zlabel: Label for the colorbar
xunit: Unit for the x-axis
yunit: Unit for the y-axis
zunit: Unit for the colorbar
show_roi: Whether to show ROIs
show_results: Whether to show geometry/table results from metadata
results: Optional list of GeometryResult objects to overlay
rows: Fixed number of rows in the grid, or None to compute automatically
share_axes: Whether to share axes across plots
**kwargs: Additional options (e.g., ``colormap``,
``height`` to override default per-subplot height in pixels)
"""
self._objs = objs
self._title = title
self._titles = titles
self._xlabel = xlabel
self._ylabel = ylabel
self._zlabel = zlabel
self._xunit = xunit
self._yunit = yunit
self._zunit = zunit
self._show_roi = show_roi
self._show_results = show_results
self._results = results
self._rows = rows
self._share_axes = share_axes
self._kwargs = kwargs
self._results_html = ""
def _ipython_display_(self, **kwargs) -> None:
"""Display figure and results as separate outputs in Jupyter."""
from IPython.display import ( # pylint: disable=import-outside-toplevel
HTML,
display,
)
try:
html = self._repr_html_()
display(HTML(html))
except Exception as exc: # pylint: disable=broad-exception-caught
display(HTML(f"<div>Error rendering images: {exc}</div>"))
return
if self._results_html:
display(HTML(self._results_html))
def _repr_html_(self) -> str:
"""Return HTML representation for Jupyter display."""
try:
png_data = self._render_to_png()
b64_data = base64.b64encode(png_data).decode("utf-8")
title = self._title or "Images"
return (
'<div style="text-align: center;">'
f"<h4>{title}</h4>"
f'<img src="data:image/png;base64,{b64_data}" />'
"</div>"
)
except Exception as e: # pylint: disable=broad-exception-caught
return f"<div>Error rendering images: {e}</div>"
def _repr_png_(self) -> bytes:
"""Return PNG representation for Jupyter display."""
return self._render_to_png()
def _render_to_png(self) -> bytes:
"""Render images to PNG bytes using matplotlib."""
# pylint: disable=import-outside-toplevel
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
n_images = len(self._objs)
# Compute grid layout
if self._rows is not None:
nrows = self._rows
ncols = (n_images + nrows - 1) // nrows
else:
ncols = min(4, n_images)
nrows = (n_images + ncols - 1) // ncols
# Create figure — honour user height override (pixels → inches at 100 dpi)
_dpi = 100
_h_in = self._kwargs.get("height", 600) / _dpi
fig, axes = plt.subplots(
nrows,
ncols,
figsize=(6 * ncols, _h_in * nrows),
sharex=self._share_axes,
sharey=self._share_axes,
squeeze=False,
)
if self._title:
fig.suptitle(self._title)
# Flatten axes for easier iteration
axes_flat = axes.flatten()
# Prepare titles list
titles = self._titles or [None] * n_images
# Prepare results list
if self._results is None:
results_list = [None] * n_images
elif isinstance(self._results, (list, tuple)):
if len(self._results) != n_images:
# If single result, apply to all images
results_list = self._results * n_images
else:
results_list = self._results
else:
results_list = [self._results] * n_images
# Track labels/units from first ImageObj
x_label = self._xlabel
y_label = self._ylabel
z_label = self._zlabel
x_unit = self._xunit
y_unit = self._yunit
z_unit = self._zunit
default_colormap = self._kwargs.get("colormap", None)
all_tbl_results: list = []
all_geo_results: list = []
for idx, (ax, img, img_title, result) in enumerate(
zip(axes_flat, self._objs, titles, results_list)
):
obj_type = type(img).__name__
# Extract data
if obj_type == "ImageObj":
data = img.data
img_title = (
img_title or getattr(img, "title", None) or f"Image {idx + 1}"
)
is_image_obj = True
# Update labels/units from first ImageObj
if idx == 0:
x_label = x_label or getattr(img, "xlabel", None) or ""
y_label = y_label or getattr(img, "ylabel", None) or ""
z_label = z_label or getattr(img, "zlabel", None) or ""
x_unit = x_unit or getattr(img, "xunit", None) or ""
y_unit = y_unit or getattr(img, "yunit", None) or ""
z_unit = z_unit or getattr(img, "zunit", None) or ""
elif isinstance(img, np.ndarray):
data = img
img_title = img_title or f"Image {idx + 1}"
is_image_obj = False
else:
raise TypeError(f"Unsupported image type: {type(img)}")
# Handle complex data
if np.iscomplexobj(data):
data = np.abs(data)
img_title = f"|{img_title}|"
# Determine colormap: explicit kwarg > per-object metadata > viridis
if is_image_obj:
kwargs_for_cmap = self._kwargs if default_colormap else {}
colormap = _get_image_colormap(img, kwargs_for_cmap)
else:
colormap = default_colormap or "viridis"
# LUT range for ImageObj
vmin, vmax = _get_image_lut_range(img) if is_image_obj else (None, None)
# Non-uniform image coordinates
if is_image_obj and _is_non_uniform_image(img):
im = ax.pcolormesh(
img.xcoords,
img.ycoords,
data,
cmap=colormap,
vmin=vmin,
vmax=vmax,
shading="auto",
)
ax.set_aspect("auto")
ax.invert_yaxis()
extent = None # No extent for pcolormesh
else:
# Compute extent and aspect ratio
if is_image_obj:
extent, aspect_ratio = _get_image_extent_and_aspect(img)
else:
nrows_img, ncols_img = data.shape[:2]
extent = [-0.5, ncols_img - 0.5, nrows_img - 0.5, -0.5]
aspect_ratio = 1.0
# Display image
im = ax.imshow(
data,
cmap=colormap,
origin="upper",
aspect=aspect_ratio,
extent=extent,
vmin=vmin,
vmax=vmax,
interpolation="nearest",
)
ax.set_title(img_title)
# Overlay mask if ImageObj has maskdata
if (
is_image_obj
and extent is not None
and hasattr(img, "maskdata")
and img.maskdata is not None
):
mask = img.maskdata
mask_rgba = np.zeros((*mask.shape, 4), dtype=np.float32)
mask_rgba[mask, :] = [1, 0, 0, MASK_OPACITY]
ax.imshow(
mask_rgba,
origin="upper",
extent=extent,
interpolation="nearest",
)
# Add colorbar
cbar = plt.colorbar(im, ax=ax)
if z_label:
cbar.set_label(f"{z_label} ({z_unit})" if z_unit else z_label)
# Set axis labels
if x_label:
ax.set_xlabel(f"{x_label} ({x_unit})" if x_unit else x_label)
if y_label:
ax.set_ylabel(f"{y_label} ({y_unit})" if y_unit else y_label)
# Log scale and axis bounds (from first ImageObj)
if is_image_obj and idx == 0:
_apply_log_scale(ax, img)
_apply_axis_bounds(ax, img)
# Overlay ROIs
if self._show_roi and is_image_obj and hasattr(img, "roi") and img.roi:
for roi_idx, roi in enumerate(img.roi):
_add_single_roi_to_axes(ax, roi, img, roi_index=roi_idx)
# Collect and display geometry/table results if enabled
if self._show_results:
results_to_display = []
if result is not None:
result_list_item = (
result if isinstance(result, (list, tuple)) else [result]
)
results_to_display.extend(result_list_item)
# Auto-extract geometry results from object metadata
if is_image_obj:
metadata_results = _extract_geometry_results_from_metadata(img)
results_to_display.extend(metadata_results)
for res in results_to_display:
_add_geometry_to_axes(ax, res)
# Auto-extract and display table results (statistics) from metadata
if is_image_obj:
table_results = _extract_table_results_from_metadata(img)
all_tbl_results.extend(table_results)
all_geo_results.extend(results_to_display)
# Build HTML for results below the plot
self._results_html = _build_results_html(all_tbl_results, all_geo_results)
# Hide unused subplots
for idx in range(n_images, len(axes_flat)):
axes_flat[idx].axis("off")
plt.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100, pil_kwargs={"compress_level": 1})
plt.close(fig)
buf.seek(0)
return buf.read()
def __repr__(self) -> str:
    """Summarise this result as the class name plus its image count."""
    image_count = len(self._objs)
    return f"MplMultiImageResult({image_count} images)"
# ============================================================================
# Main Plotter class (drop-in API)
# ============================================================================
class MatplotlibPlotter:
    """Matplotlib-based visualization frontend for the DataLab kernel.

    This class provides the same public API as
    :class:`datalab_kernel.plotly_backend.PlotlyPlotter` but produces static
    matplotlib PNGs instead of interactive Plotly figures.

    Example::

        from datalab_kernel.matplotlib_backend import MatplotlibPlotter

        plotter = MatplotlibPlotter(workspace)
        plotter.plot("s001")                   # Single signal
        plotter.plot([sig1, sig2])             # Multiple signals
        plotter.plot([img1, img2])             # Multiple images
        plotter.display_table(table_result)    # HTML table
        plotter.display_geometry(geom_result)  # HTML table
    """

    def __init__(self, workspace: Workspace) -> None:
        """Initialize plotter with workspace reference.

        Args:
            workspace: The workspace containing objects to plot. It is used to
                resolve objects that are passed to :meth:`plot` by name.
        """
        self._workspace = workspace

    def plot(
        self,
        obj_or_name: DataObject | str | list,
        title: str | None = None,
        show_roi: bool = True,
        show_results: bool = True,
        *,
        xlabel: str | None = None,
        ylabel: str | None = None,
        xunit: str | None = None,
        yunit: str | None = None,
        zlabel: str | None = None,
        zunit: str | None = None,
        titles: list[str] | None = None,
        results: list | None = None,
        **kwargs,
    ) -> MplPlotResult | MplMultiSignalResult | MplMultiImageResult:
        """Plot one or more objects.

        Accepts a single object (or workspace name) **or** a list.

        * **Single object** — renders one signal or image.
        * **List of signals** — overlays all curves on shared axes.
        * **List of images** — displays in a subplot grid.
        * **Mixed list** — raises :class:`TypeError`.

        A single-item list is unwrapped and treated exactly like a single
        object: in particular, when ``title`` is not given it defaults to the
        workspace name (for a string item) or to the object's ``title``
        attribute (when it has one), just as for a bare object or name.

        Args:
            obj_or_name: Object to plot, workspace name, or a *list* of
                objects / names.
            title: Plot title (overall figure title for multi-plots).
            show_roi: Whether to show ROIs defined in the objects.
            show_results: Whether to show geometry/table results from metadata.
            xlabel: X-axis label override (multi-plots).
            ylabel: Y-axis label override (multi-plots).
            xunit: X-axis unit override (multi-plots).
            yunit: Y-axis unit override (multi-plots).
            zlabel: Colorbar label override (images only).
            zunit: Colorbar unit override (images only).
            titles: Per-image title list (images only).
            results: List of ``GeometryResult`` objects to overlay (images only).
            **kwargs: Additional plotting options (``height``, ``colormap``).

        Returns:
            A result object with Jupyter display capabilities.

        Raises:
            TypeError: If a list mixes signals and images.
            KeyError: If a workspace name is not found.
        """
        from datalab_kernel.plotter import (  # pylint: disable=import-outside-toplevel
            _IMAGE,
            _resolve_and_classify,
        )

        is_list = isinstance(obj_or_name, list)

        # --- list input (zero or several items): multi-object dispatch ---
        if is_list and len(obj_or_name) != 1:
            objs, category = _resolve_and_classify(obj_or_name, self._workspace)
            if category == _IMAGE:
                return MplMultiImageResult(
                    objs,
                    title=title,
                    titles=titles,
                    xlabel=xlabel,
                    ylabel=ylabel,
                    zlabel=zlabel,
                    xunit=xunit,
                    yunit=yunit,
                    zunit=zunit,
                    show_roi=show_roi,
                    show_results=show_results,
                    results=results,
                    **kwargs,
                )
            # Anything that is not classified as images is overlaid as signals;
            # ``results`` and the z-axis/``titles`` options do not apply there.
            return MplMultiSignalResult(
                objs,
                title=title,
                xlabel=xlabel,
                ylabel=ylabel,
                xunit=xunit,
                yunit=yunit,
                show_roi=show_roi,
                show_results=show_results,
                **kwargs,
            )

        # --- single-object path (bare object/name or single-item list) ---
        # Unwrapping here, rather than building the result in a separate branch,
        # guarantees that ``plot([x])`` and ``plot(x)`` resolve names and
        # default titles identically.
        target = obj_or_name[0] if is_list else obj_or_name

        if isinstance(target, str):
            obj = self._workspace.get(target)
            if title is None:
                # Fall back to the workspace name the caller asked for.
                title = target
        else:
            obj = target
            if title is None and hasattr(obj, "title"):
                title = obj.title

        return MplPlotResult(
            obj,
            title=title,
            show_roi=show_roi,
            show_results=show_results,
            results=results,
            **kwargs,
        )

    def display_table(
        self,
        result,
        title: str | None = None,
        visible_only: bool = True,
        transpose_single_row: bool = True,
    ) -> TableResultDisplay:
        """Display a TableResult with rich HTML rendering.

        Args:
            result: TableResult object to display
            title: Optional title override (uses result.title if None)
            visible_only: If True, show only visible columns based on display prefs
            transpose_single_row: If True, transpose single-row tables for readability

        Returns:
            TableResultDisplay with Jupyter display capabilities
        """
        return TableResultDisplay(
            result,
            title=title,
            visible_only=visible_only,
            transpose_single_row=transpose_single_row,
        )

    def display_geometry(
        self,
        result,
        title: str | None = None,
    ) -> GeometryResultDisplay:
        """Display a GeometryResult with rich HTML rendering.

        Args:
            result: GeometryResult object to display
            title: Optional title override (uses result.title if None)

        Returns:
            GeometryResultDisplay with Jupyter display capabilities
        """
        return GeometryResultDisplay(result, title=title)