# Source code for arpes.widgets

"""Provides interactive tools based on matplotlib Qt interactive elements.

These are generally primitive, one-off tools that are useful for accomplishing
something quickly. For example:

1. `pca_explorer` lets you interactively examine a PCA decomposition or
   other decomposition supported by `arpes.analysis.decomposition`
2. `pick_points`, `pick_rectangles` allows selecting many individual points
    or regions from a piece of data, useful to isolate locations to do
    further analysis.
3. `kspace_tool` allows interactively setting coordinate offset for
    angle-to-momentum conversion.
4. `fit_initializer` allows for seeding an XPS curve fit.

Most of these return a "context" object which can be used to get information from the current
session (i.e. the selected points or regions, or modified data); `pick_points` and
`pick_rectangles` instead return the list of points or rectangles, which is filled in as you click.
If you forget to save this context, you can recover it, as the most recent context
is saved at `arpes.config.CONFIG` under the key "CURRENT_CONTEXT".

There are also primitives for building interactive tools in matplotlib. Such as
DataArrayView, which provides an interactive and updatable plot view from an
xarray.DataArray instance.

In the future, it would be nice to get higher quality interactive tools, as
we start to run into the limits of these ones. But between this and `qt_tool`
we are doing fine for now.
"""

import itertools
import pathlib
import warnings
from functools import wraps
from typing import Callable, List, Optional, Union

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.path import Path
from matplotlib.widgets import (
    Button,
    LassoSelector,
    RectangleSelector,
    Slider,
    SpanSelector,
    TextBox,
)

import arpes.config
from arpes.fits import LorentzianModel, broadcast_model
from arpes.plotting.utils import fancy_labels, imshow_arr, invisible_axes
from arpes.utilities import normalize_to_spectrum
from arpes.utilities.conversion import convert_to_kspace
from arpes.utilities.image import imread_to_xarray

# Public interactive tools exported by this module.
__all__ = (
    "pick_rectangles",
    "pick_points",
    "pca_explorer",
    "kspace_tool",
    "fit_initializer",
)

class SelectFromCollection:
    """Select indices from a matplotlib collection using `LassoSelector`.

    Modified from https://matplotlib.org/gallery/widgets/lasso_selector_demo_sgskip.html

    Selected indices are saved in the `ind` attribute. This tool fades out the
    points that are not part of the selection (i.e., reduces their alpha
    values). If your collection has alpha < 1, this tool will permanently
    alter the alpha values.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).
    """

    def __init__(self, ax, collection, alpha_other=0.3, on_select=None):
        """Attach a lasso to `ax` that selects points of `collection`.

        Args:
            ax: The axes the collection is drawn on.
            collection: The matplotlib collection (e.g. from `ax.scatter`) to select from.
            alpha_other: Alpha applied to points outside the selection.
            on_select: Optional callback invoked with the array of selected indices.

        Raises:
            ValueError: If the collection has no facecolors.
        """
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.n_pts = len(self.xys)
        self._on_select = on_select

        # Ensure that we have separate colors for each object
        self.facecolors = collection.get_facecolors()
        if not len(self.facecolors):
            raise ValueError("Collection must have a facecolor")

        if len(self.facecolors) == 1:
            self.facecolors = np.tile(self.facecolors, (self.n_pts, 1))

        self.lasso = LassoSelector(ax, onselect=self.onselect)
        self.ind = []

    def onselect(self, verts):
        """Record points inside the lasso path, fade the others, and notify `on_select`."""
        try:
            path = Path(verts)
            self.ind = np.nonzero(path.contains_points(self.xys))[0]
            self.facecolors[:, -1] = self.alpha_other
            self.facecolors[self.ind, -1] = 1
            self.collection.set_facecolors(self.facecolors)
            self.canvas.draw_idle()

            if self._on_select is not None:
                self._on_select(self.ind)
        except Exception as exc:  # a GUI callback should report, not silently drop, failures
            warnings.warn("Lasso selection handling failed: {!r}".format(exc))

    def disconnect(self):
        """Stop listening for lasso events and restore full opacity to every point."""
        self.lasso.disconnect_events()
        self.facecolors[:, -1] = 1
        self.collection.set_facecolors(self.facecolors)
        self.canvas.draw_idle()

def popout(plotting_function: Callable) -> Callable:
    """A decorator which applies the "%matplotlib qt" magic so that interactive plots are enabled.

    Sets the matplotlib backend for the function call, to allow use of 'widgets' in Jupyter
    inline use. Outside of an IPython session (or without IPython installed) the function is
    simply called.

    Args:
        plotting_function: The plotting function which should be decorated.

    Returns:
        The decorated function, carrying the original's name and docstring.
    """

    @wraps(plotting_function)
    def wrapped(*args, **kwargs):
        try:
            from IPython import get_ipython
        except ImportError:
            ipython = None
        else:
            ipython = get_ipython()  # None when not running inside an IPython shell

        if ipython is not None:
            ipython.run_line_magic("matplotlib", "qt")

        # NOTE: ideally the backend would be restored afterwards ("%matplotlib inline"), but
        # doing so closes the interactive plot; an on-close event handler could do it instead.
        return plotting_function(*args, **kwargs)

    return wrapped

class DataArrayView:
    """A model (in the sense of view models) for a DataArray in matplotlib plots.

    Offers support for 1D and 2D DataArrays with masks, selection tools, and a simpler interface
    than the matplotlib primitives.

    The first assignment to `data` creates the artists (an image via `imshow_arr` for 2D data,
    a line otherwise); later assignments update those artists in place. Several views may share
    one axes object: each view only ever replaces its own line.

    Look some more into holoviews for different features.
    """

    def __init__(
        self,
        ax,
        data=None,
        ax_kwargs=None,
        mask_kwargs=None,
        transpose_mask=False,
        auto_autoscale=True,
    ):
        """Create a view that draws into `ax`.

        Args:
            ax: The matplotlib axes to draw into.
            data: Optional initial DataArray; assigned through the `data` property.
            ax_kwargs: Extra keyword arguments for the plotting call (`imshow_arr` or `ax.plot`).
            mask_kwargs: Extra keyword arguments for drawing the mask. A "cmap" entry selects
                the mask colormap (default "Reds").
            transpose_mask: Whether a 2D mask is transposed before it is drawn.
            auto_autoscale: Whether `autoscale` is called after every data update.
        """
        self.ax = ax
        self._initialized = False
        self._data = None
        self._mask = None
        self.n_dims = None
        # Private copies: "cmap" is popped from these dicts, which must not leak to callers.
        self.ax_kwargs = dict(ax_kwargs or {})
        self.mask_kwargs = dict(mask_kwargs or {})
        self._axis_image = None
        self._mask_image = None
        self._mask_cmap = None
        self._transpose_mask = transpose_mask
        self._selector = None
        self._inner_on_select = None
        self.auto_autoscale = auto_autoscale

        if data is not None:
            self.data = data

    def handle_select(self, event_click=None, event_release=None):
        """Translate a selector callback into a region and forward it to the attached callback.

        For 2D views the arguments are the press/release mouse events of a `RectangleSelector`;
        for 1D views they are the two span bounds passed by a `SpanSelector`. The callback gets a
        dict mapping dimension names to increasing coordinate slices.
        """
        dims = self._data.dims

        if self.n_dims == 2:
            x_lo, x_hi = sorted((event_click.xdata, event_release.xdata))
            y_lo, y_hi = sorted((event_click.ydata, event_release.ydata))
            # The horizontal axis shows the second dimension, the vertical axis the first.
            region = {dims[1]: slice(x_lo, x_hi), dims[0]: slice(y_lo, y_hi)}
        else:
            lo, hi = sorted((event_click, event_release))
            region = {dims[0]: slice(lo, hi)}

        self._inner_on_select(region)

    def attach_selector(self, on_select):
        """Attach an interactive region selector whose regions are passed to `on_select`.

        1D views get a horizontal `SpanSelector`; all others get a `RectangleSelector`.

        Raises:
            RuntimeError: If no data has been assigned to the view yet.
        """
        if self.n_dims is None:
            raise RuntimeError("Assign data to the view before attaching a selector.")

        self._inner_on_select = on_select

        if self.n_dims == 1:
            self._selector = SpanSelector(
                self.ax,
                self.handle_select,
                "horizontal",
                rectprops=dict(alpha=0.35, facecolor="red"),
            )
        else:
            self._selector = RectangleSelector(
                self.ax,
                self.handle_select,
                rectprops=dict(fill=False, edgecolor="black", linewidth=2),
                lineprops=dict(linewidth=2, color="black"),
            )

    @property
    def data(self):
        """The DataArray currently displayed by this view."""
        return self._data

    @data.setter
    def data(self, new_data):
        self._data = new_data

        if not self._initialized:
            self._initialized = True
            self.n_dims = len(new_data.dims)
            self._create_artists()
        else:
            self._update_artists()

        if self.auto_autoscale:
            self.autoscale()

    def _coord_values(self, axis_index):
        """Coordinate values along the `axis_index`-th dimension of the current data."""
        return self._data.coords[self._data.dims[axis_index]].values

    def _create_artists(self):
        """Draw the current data for the first time."""
        if self.n_dims == 2:
            self._axis_image = imshow_arr(self._data, ax=self.ax, **self.ax_kwargs)[1]
            fancy_labels(self.ax)
            return

        # Lines do not take a colormap.
        self.ax_kwargs.pop("cmap", None)
        x = self._coord_values(0)
        self._axis_image = self.ax.plot(x, self._data.values, **self.ax_kwargs)
        self.ax.set_xlabel(self._data.dims[0])
        self.ax.set_xlim([np.min(x), np.max(x)])
        self._rescale_y()
        fancy_labels(self.ax)

    def _update_artists(self):
        """Push new data into the existing artists."""
        if self.n_dims == 2:
            x, y = self._coord_values(0), self._coord_values(1)
            self._axis_image.set_extent([y[0], y[-1], x[0], x[-1]])
            self._axis_image.set_data(self._data.values)
            return

        # Replace this view's own line (not `ax.lines[0]`, which may belong to another view
        # sharing the axes), keeping its color unless the caller fixed one explicitly.
        old_line = self._axis_image[0]
        plot_kwargs = dict(self.ax_kwargs)
        if "c" not in plot_kwargs:
            plot_kwargs.setdefault("color", old_line.get_color())
        old_line.remove()
        self._axis_image = self.ax.plot(self._coord_values(0), self._data.values, **plot_kwargs)
        self._rescale_y()

    def _rescale_y(self):
        """Fit the y-limits to the current 1D values with 10% padding on each side."""
        values = self._data.values
        low, high = np.min(values), np.max(values)
        pad = 0.1 * (high - low)
        self.ax.set_ylim([low - pad, high + pad])

    @property
    def mask_cmap(self):
        """Colormap used to draw the mask; masked-out (bad) values are fully transparent."""
        if self._mask_cmap is None:
            import copy

            # Copy so that `set_bad` does not modify the globally registered colormap.
            self._mask_cmap = copy.copy(plt.get_cmap(self.mask_kwargs.pop("cmap", "Reds")))
            self._mask_cmap.set_bad("k", alpha=0)

        return self._mask_cmap

    @property
    def mask(self):
        """Boolean selection mask with the same shape as the data."""
        return self._mask

    @mask.setter
    def mask(self, new_mask):
        shape = self._data.values.shape
        if new_mask is None:
            new_mask = []

        if np.array(new_mask).shape != shape:
            # Anything not shaped like the data is a collection of flat (raveled) indices.
            mask = np.zeros(shape, dtype=bool)
            mask.flat[np.asarray(new_mask, dtype=np.intp)] = True
            new_mask = mask

        self._mask = new_mask

        # 1 where selected, masked (drawn transparent) elsewhere; `* 0 + 1` keeps data NaNs.
        for_mask = np.ma.masked_where(np.logical_not(self._mask), self._data.values * 0 + 1)
        cmap = self.mask_cmap  # pops "cmap" from mask_kwargs before they are splatted below

        if self.n_dims == 2:
            if self._transpose_mask:
                for_mask = for_mask.T

            if self._mask_image is None:
                # Overlay on the data image, matching its extent, origin and aspect.
                self._mask_image = self.ax.imshow(
                    for_mask,
                    cmap=cmap,
                    interpolation="none",
                    vmin=0,
                    vmax=1,
                    origin="lower",
                    extent=self._axis_image.get_extent(),
                    aspect=self.ax.get_aspect(),
                    **self.mask_kwargs,
                )
            else:
                self._mask_image.set_data(for_mask)
        else:
            if self._mask_image is not None:
                self._mask_image.remove()

            x = self._coord_values(0)
            low, high = self.ax.get_ylim()
            self._mask_image = self.ax.fill_between(
                x, low, for_mask * high, color=cmap(1.0), **self.mask_kwargs
            )

    def autoscale(self):
        """Rescale the color limits of a 2D image to its current data; no-op for 1D views."""
        if self.n_dims == 2:
            self._axis_image.autoscale()
@popout
def fit_initializer(data, peak_type=LorentzianModel, **kwargs):
    """A tool for initializing lineshape fitting.

    Dragging a span on the upper-left plot adds a `peak_type` peak, centered on the span, whose
    sigma is the span width. After each addition, the composite model is fit with
    `broadcast_model`. The residual, best fit and initial guess are drawn on the upper-right plot.
    The "Copy Settings" button prints the accumulated parameter hints and, if `pyperclip` is
    installed, also copies them to the clipboard.

    Args:
        data: The spectrum to fit (presumably 1D; only `data.dims[0]` is used for peak placement).
        peak_type: Model class used for every added peak. Defaults to `LorentzianModel`.
        kwargs: Currently unused.

    Returns:
        A context dict holding the data ("data") and the copy-settings button ("button").
    """
    ctx = {}
    arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx

    gs = gridspec.GridSpec(2, 2)
    ax_initial = plt.subplot(gs[0, 0])
    ax_fitted = plt.subplot(gs[0, 1])
    ax_other = plt.subplot(gs[1, 0])
    ax_test = plt.subplot(gs[1, 1])
    invisible_axes(ax_other)

    # One letter prefix per peak, so at most 26 peaks are passed on as parameter hints.
    prefixes = "abcdefghijklmnopqrstuvwxyz"
    model_settings = []
    model_defs = []

    # Add a length-1 dimension for `broadcast_model` to broadcast over.
    for_fit = data.expand_dims("fit_dim")
    for_fit.coords["fit_dim"] = np.array([0])

    data_view = DataArrayView(ax_initial)
    residual_view = DataArrayView(ax_fitted, ax_kwargs=dict(linestyle=":", color="orange"))
    fitted_view = DataArrayView(ax_fitted, ax_kwargs=dict(color="red"))
    initial_fit_view = DataArrayView(ax_fitted, ax_kwargs=dict(linestyle="--", color="blue"))

    def compute_parameters():
        """Flatten per-peak hints into {"<prefix>_<param>": hint}, e.g. "a_center"."""
        return {
            "{}_{}".format(prefix, name): hint
            for prefix, settings in zip(prefixes, model_settings)
            for name, hint in settings.items()
        }

    def on_add_new_peak(selection):
        amplitude = data.sel(**selection).mean().item()

        span = selection[data.dims[0]]
        center = (span.start + span.stop) / 2
        sigma = span.stop - span.start
        if sigma <= 0:
            # A zero-width span (a bare click) cannot seed a peak width or center bounds.
            return

        model_settings.append(
            {
                "center": {"value": center, "min": center - sigma, "max": center + sigma},
                "sigma": {"value": sigma},
                "amplitude": {"min": 0, "value": amplitude},
            }
        )
        model_defs.append(peak_type)

        results = broadcast_model(model_defs, for_fit, "fit_dim", params=compute_parameters())
        result = results.results[0].item()
        if result is None:
            return

        for view, values in (
            (residual_view, result.residual),
            (fitted_view, result.best_fit),
            (initial_fit_view, result.init_fit),
        ):
            curve = data.copy(deep=True)
            curve.values = values
            view.data = curve

        ax_fitted.set_ylim(ax_initial.get_ylim())

    data_view.data = data
    data_view.attach_selector(on_select=on_add_new_peak)
    ctx["data"] = data

    def on_copy_settings(event):
        import pprint

        formatted = pprint.pformat(compute_parameters())
        print(formatted)
        try:
            import pyperclip
        except ImportError:
            return  # clipboard support is optional
        pyperclip.copy(formatted)

    copy_settings_button = Button(ax_test, "Copy Settings")
    copy_settings_button.on_clicked(on_copy_settings)
    ctx["button"] = copy_settings_button

    return ctx
@popout
def pca_explorer(
    pca, data, component_dim="components", initial_values=None, transpose_mask=False, **kwargs
):
    """A tool providing PCA decomposition exploration of a dataset.

    Args:
        pca: The decomposition of the data, the output of an sklearn PCA decomp.
        data: The original data.
        component_dim: The variable name or identifier associated to the PCA component
            projection in the input data. Defaults to "components" which is what is produced
            by `pca_along`.
        initial_values: Which of the PCA components to use for the 2D embedding.
            Defaults to None, meaning components 0 and 1.
        transpose_mask: Controls whether the PCA masks should be transposed before application.
            Defaults to False.
        kwargs: Currently unused.

    Returns:
        The context dict: selected components and indices, the summed data, the lasso selector,
        the integration region and the axis widgets.
    """
    if initial_values is None:
        initial_values = [0, 1]

    pca_dims = list(pca.dims)
    pca_dims.remove(component_dim)
    other_dims = [d for d in data.dims if d not in pca_dims]

    context = {
        "selected_components": initial_values,
        "selected_indices": [],
        "sum_data": None,
        "map_data": None,
        "selector": None,
        "integration_region": {},
    }
    arpes.config.CONFIG["CURRENT_CONTEXT"] = context

    def compute_for_scatter():
        """Normalized projections onto the selected components, plus marker sizes."""
        for_scatter = pca.copy(deep=True).isel(
            **dict([[component_dim, context["selected_components"]]])
        )
        for_scatter = for_scatter.S.transpose_to_back(component_dim)

        # Marker size follows the mean intensity at each point, scaled to an average of 5.
        size = data.mean(other_dims).stack(pca_dims=pca_dims).values
        norm = np.expand_dims(np.linalg.norm(pca.values, axis=(0,)), axis=-1)

        return (for_scatter / norm).stack(pca_dims=pca_dims), 5 * size / np.mean(size)

    # ===== Set up axes ======
    gs = gridspec.GridSpec(2, 2)
    ax_components = plt.subplot(gs[0, 0])
    ax_sum_selected = plt.subplot(gs[0, 1])
    ax_map = plt.subplot(gs[1, 0])

    gs_widget = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[1, 1])
    ax_widget_1 = plt.subplot(gs_widget[0, 0])
    ax_widget_2 = plt.subplot(gs_widget[1, 0])
    ax_widget_3 = plt.subplot(gs_widget[2, 0])

    selected_view = DataArrayView(ax_sum_selected, ax_kwargs=dict(cmap="viridis"))
    map_view = DataArrayView(
        ax_map,
        ax_kwargs=dict(cmap="Greys"),
        mask_kwargs=dict(cmap="Reds", alpha=0.35),
        transpose_mask=transpose_mask,
    )

    def update_from_selection(ind):
        """Recompute the summed spectrum for the lasso selection and refresh both views."""
        stacked = data.stack(pca_dims=pca_dims)
        if ind is None or not len(ind):
            # Nothing selected: sum over every point.
            context["selected_indices"] = []
            context["sum_data"] = stacked.sum("pca_dims")
        else:
            context["selected_indices"] = ind
            context["sum_data"] = stacked.isel(pca_dims=ind).sum("pca_dims")

        if context["integration_region"] is not None:
            data_sel = data.sel(**context["integration_region"]).sum(other_dims)
        else:
            data_sel = data.sum(other_dims)

        # Update all views
        map_view.data = data_sel
        map_view.mask = ind
        selected_view.data = context["sum_data"]

    def set_axes(component_x, component_y):
        """Redraw the scatter for two components and attach a fresh lasso selector."""
        ax_components.clear()
        context["selected_components"] = [component_x, component_y]
        for_scatter, size = compute_for_scatter()
        pts = ax_components.scatter(for_scatter.values[0], for_scatter.values[1], s=size)

        if context["selector"] is not None:
            context["selector"].disconnect()

        context["selector"] = SelectFromCollection(
            ax_components, pts, on_select=update_from_selection
        )
        ax_components.set_xlabel("$e_" + str(component_x) + "$")
        ax_components.set_ylabel("$e_" + str(component_y) + "$")
        update_from_selection([])

    def on_change_axes(event):
        try:
            val_x = int(context["axis_X_input"].text)
            val_y = int(context["axis_Y_input"].text)
        except ValueError:
            # Ignore text that is not an integer.
            return

        maximum = len(pca.coords[component_dim].values) - 1
        val_x = min(max(val_x, 0), maximum)
        val_y = min(max(val_y, 0), maximum)

        if val_x == val_y:
            # Plotting a component against itself is not a useful embedding.
            return

        set_axes(val_x, val_y)

    context["axis_button"] = Button(ax_widget_1, "Change Decomp Axes")
    context["axis_button"].on_clicked(on_change_axes)
    context["axis_X_input"] = TextBox(ax_widget_2, "Axis X:", initial=str(initial_values[0]))
    context["axis_Y_input"] = TextBox(ax_widget_3, "Axis Y:", initial=str(initial_values[1]))

    def on_select_summed(region):
        context["integration_region"] = region
        update_from_selection(context["selected_indices"])

    set_axes(*initial_values)
    selected_view.attach_selector(on_select_summed)

    plt.tight_layout()
    return context
@popout
def kspace_tool(
    data,
    overplot_bz: Optional[Union[Callable, List[Callable]]] = None,
    bounds=None,
    resolution=None,
    coords=None,
    **kwargs
):
    """A utility for assigning coordinate offsets using a live momentum conversion.

    The upper plot shows the working copy of the data. The lower plot shows its
    `convert_to_kspace` conversion, recomputed whenever an offset slider moves. Data with more
    than two dimensions is first summed over eV in [-0.05, 0.05].

    Args:
        data: The data to convert; passed through `normalize_to_spectrum`.
        overplot_bz: A callable, or a sequence of callables, each called with the converted-data
            axes (presumably to draw Brillouin zone outlines).
        bounds: Passed to `convert_to_kspace`.
        resolution: Passed to `convert_to_kspace`.
        coords: Passed to `convert_to_kspace`.
        kwargs: Passed to `convert_to_kspace`.

    Returns:
        A context dict with the original data, the working copy and the widgets. The
        "Apply Offsets" button writes the slider values to `original_data.attrs` and, when
        available, to the attrs of each of its spectra.

    Raises:
        ValueError: If the data has an x/X, y/Y, z/Z or T dimension.
    """
    original_data = data
    data = normalize_to_spectrum(data)

    if len(data.dims) > 2:
        data = data.sel(eV=slice(-0.05, 0.05)).sum("eV", keep_attrs=True)
        data.coords["eV"] = 0

    if "eV" in data.dims:
        data = data.S.transpose_to_front("eV")
    # Work on a copy: slider offsets are written into its attrs on every change.
    data = data.copy(deep=True)

    ctx = {"original_data": original_data, "data": data, "widgets": []}
    arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx
    gs = gridspec.GridSpec(4, 3)
    ax_initial = plt.subplot(gs[0:2, 0:2])
    ax_converted = plt.subplot(gs[2:, 0:2])

    if overplot_bz is not None:
        if callable(overplot_bz):
            overplot_bz = [overplot_bz]
        for fn in overplot_bz:
            fn(ax_converted)

    # The last two widget axes hold the buttons; the rest host the sliders.
    n_widget_axes = 8
    gs_widget = gridspec.GridSpecFromSubplotSpec(n_widget_axes, 1, subplot_spec=gs[:, 2])
    widget_axes = [plt.subplot(gs_widget[i, 0]) for i in range(n_widget_axes)]
    for widget_axis in widget_axes[:-2]:
        invisible_axes(widget_axis)

    skip_dims = {"x", "X", "y", "Y", "z", "Z", "T"}
    for dim in skip_dims:
        if dim in data.dims:
            raise ValueError("Please provide data without the {} dimension".format(dim))

    convert_dims = ["theta", "beta", "phi", "psi"]
    if "eV" not in data.dims:
        convert_dims += ["chi"]
    if "hv" in data.dims:
        convert_dims += ["hv"]

    # (low, high, step) relative to each current offset; angles are in radians (+/- 45 deg).
    ang_range = (-45 * np.pi / 180, 45 * np.pi / 180, 0.01)
    default_ranges = {
        "eV": [-0.05, 0.05, 0.001],
        "hv": [-20, 20, 0.5],
    }

    sliders = {}

    def update_kspace_plot(_):
        for name, slider in sliders.items():
            data.attrs["{}_offset".format(name)] = slider.val

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            converted_view.data = convert_to_kspace(
                data, bounds=bounds, resolution=resolution, coords=coords, **kwargs
            )

    axes = iter(widget_axes)
    for convert_dim in convert_dims:
        widget_ax = next(axes)
        low, high, delta = default_ranges.get(convert_dim, ang_range)
        init = data.S.lookup_offset(convert_dim)
        sliders[convert_dim] = Slider(
            widget_ax, convert_dim, init + low, init + high, valinit=init, valstep=delta
        )
        sliders[convert_dim].on_changed(update_kspace_plot)

    def compute_offsets():
        return {k: v.val for k, v in sliders.items()}

    def on_copy_settings(event):
        import pprint

        formatted = pprint.pformat(compute_offsets())
        print(formatted)
        try:
            import pyperclip
        except ImportError:
            return  # clipboard support is optional
        pyperclip.copy(formatted)

    def apply_offsets(event):
        for name, offset in compute_offsets().items():
            print(name, offset)
            original_data.attrs["{}_offset".format(name)] = offset

            try:
                for s in original_data.S.spectra:
                    s.attrs["{}_offset".format(name)] = offset
            except AttributeError:
                pass  # `original_data` has no spectra to propagate to

    ctx["widgets"].append(sliders)

    copy_settings_button = Button(widget_axes[-1], "Copy Offsets")
    apply_settings_button = Button(widget_axes[-2], "Apply Offsets")
    copy_settings_button.on_clicked(on_copy_settings)
    apply_settings_button.on_clicked(apply_offsets)
    ctx["widgets"].append(copy_settings_button)
    ctx["widgets"].append(apply_settings_button)

    data_view = DataArrayView(ax_initial)
    converted_view = DataArrayView(ax_converted)

    data_view.data = data
    update_kspace_plot(None)

    plt.tight_layout()

    return ctx
@popout
def pick_rectangles(data, **kwargs):
    """A utility allowing for selection of rectangular regions.

    Every two clicks inside the plot give the opposite corners of one rectangle. Each completed
    rectangle is drawn in red and appended to the returned list as
    [[x_min, y_min], [x_max, y_max]].

    Args:
        data: The data to plot with `data.S.plot`.
        kwargs: Passed to `data.S.plot`.

    Returns:
        The list of rectangles, filled in as the user clicks.
    """
    ctx = {"points": [], "rect_next": False}
    arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx

    rects = []

    fig = plt.figure()
    data.S.plot(**kwargs)
    ax = fig.gca()

    def onclick(event):
        if event.xdata is None or event.ydata is None:
            # Clicks outside the axes carry no data coordinates.
            return

        ctx["points"].append([event.xdata, event.ydata])
        if ctx["rect_next"]:
            p1, p2 = ctx["points"][-2], ctx["points"][-1]
            # Normalize in place so p1 is the lower-left corner and p2 the upper-right one.
            p1[0], p2[0] = min(p1[0], p2[0]), max(p1[0], p2[0])
            p1[1], p2[1] = min(p1[1], p2[1]), max(p1[1], p2[1])

            rects.append([p1, p2])
            ax.add_patch(
                plt.Rectangle(
                    (p1[0], p1[1]),
                    p2[0] - p1[0],
                    p2[1] - p1[1],
                    edgecolor="red",
                    linewidth=2,
                    fill=False,
                )
            )

        ctx["rect_next"] = not ctx["rect_next"]
        plt.draw()

    plt.connect("button_press_event", onclick)

    return rects
@popout
def pick_gamma(data, **kwargs):
    """Mark the Gamma point of `data` with a click.

    Each click inside the plot replaces data.attrs["symmetry_points"] with {"G": {...}}. The
    inner dict uses the click's y coordinate for `data.dims[0]` and its x coordinate for
    `data.dims[1]`; an "eV" dimension is skipped.

    Args:
        data: The data to plot with `data.S.plot`; its attrs are updated in place.
        kwargs: Passed to `data.S.plot`.

    Returns:
        `data` itself.
    """
    plt.figure()
    data.S.plot(**kwargs)
    dims = data.dims

    def onclick(event):
        if event.xdata is None or event.ydata is None:
            # Clicks outside the axes carry no data coordinates.
            return

        print(event.x, event.xdata, event.y, event.ydata)

        data.attrs["symmetry_points"] = {
            "G": {
                dim: value
                for dim, value in zip(dims, [event.ydata, event.xdata])
                if dim != "eV"
            }
        }
        plt.draw()

    plt.connect("button_press_event", onclick)

    return data
@popout
def pick_points(data_or_str, **kwargs):
    """A utility allowing for selection of points in a dataset.

    Each click inside the plot records [x, y] in data coordinates and marks it with a red
    ellipse sized to about 3% of the larger axes side.

    Args:
        data_or_str: Data to plot with `data.S.plot`, or a path (str or pathlib.Path) to an
            image file, which is loaded with `imread_to_xarray` and shown with `plt.imshow`.
        kwargs: Passed to `data.S.plot` (ignored for images).

    Returns:
        The list of picked points, filled in as the user clicks.
    """
    using_image_data = isinstance(data_or_str, (str, pathlib.Path))

    ctx = {"points": []}
    arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx

    fig = plt.figure()

    if using_image_data:
        data = imread_to_xarray(data_or_str)
        plt.imshow(data.values)
    else:
        data = data_or_str
        data.S.plot(**kwargs)

    ax = fig.gca()

    if using_image_data:
        ax.grid(False)

    # Size the marker in data units so it appears round with a fixed size on screen.
    x0, y0 = ax.transAxes.transform((0, 0))  # lower left in pixels
    x1, y1 = ax.transAxes.transform((1, 1))  # upper right in pixels
    dx = x1 - x0
    dy = y1 - y0
    maxd = max(dx, dy)
    xlim, ylim = ax.get_xlim(), ax.get_ylim()

    width = 0.03 * maxd / dx * (xlim[1] - xlim[0])
    height = 0.03 * maxd / dy * (ylim[1] - ylim[0])

    def onclick(event):
        if event.xdata is None or event.ydata is None:
            # Clicks outside the axes carry no data coordinates.
            return

        ctx["points"].append([event.xdata, event.ydata])
        ax.add_patch(
            matplotlib.patches.Ellipse(
                (event.xdata, event.ydata),
                width,
                height,
                color="red",
            )
        )
        plt.draw()

    plt.connect("button_press_event", onclick)

    return ctx["points"]