# Source code for arpes.plotting.utils

"""Contains many common utility functions for managing matplotlib."""
import collections
import pickle
import contextlib
from typing import List, Tuple, Union

import datetime
import re
import errno
import itertools
import json
import os.path
import warnings
import pathlib
from collections import Counter

import matplotlib
import matplotlib.cm as cm
import matplotlib.offsetbox
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colorbar, colors, gridspec
from matplotlib.lines import Line2D

import xarray as xr
from arpes import VERSION
from arpes.config import CONFIG, SETTINGS, attempt_determine_workspace, is_using_tex
from arpes.typing import DataType
from arpes.utilities import normalize_to_spectrum
from arpes.utilities.jupyter import get_recent_history, get_notebook_name

# Public API of this module, grouped by purpose. Some helpers defined below
# (e.g. ``phase_angle_colorbar``, ``delay_colormap``) are not listed here; they
# remain importable by name but are not exported by ``from ... import *``.
__all__ = (
    # General + IO
    "path_for_plot",
    "path_for_holoviews",
    "name_for_dim",
    "unit_for_dim",
    "load_data_for_figure",
    "savefig",
    "AnchoredHScaleBar",
    "calculate_aspect_ratio",
    # context managers
    "dark_background",
    # color related
    "temperature_colormap",
    "polarization_colorbar",
    "temperature_colormap_around",
    "temperature_colorbar",
    "temperature_colorbar_around",
    "generic_colorbarmap",
    "generic_colorbarmap_for_data",
    "colorbarmaps_for_axis",
    # Axis generation
    "dos_axes",
    "simple_ax_grid",
    # matplotlib 'macros'
    "invisible_axes",
    "no_ticks",
    "get_colorbars",
    "remove_colorbars",
    "frame_with",
    "unchanged_limits",
    "imshow_arr",
    "imshow_mask",
    "lineplot_arr",  # 1D version of imshow_arr
    "plot_arr",  # generic dimension version of imshow_arr, plot_arr
    # insets related
    "inset_cut_locator",
    "swap_xaxis_side",
    "swap_yaxis_side",
    "swap_axis_sides",
    # units related
    "data_to_axis_units",
    "axis_to_data_units",
    "daxis_ddata_units",
    "ddata_daxis_units",
    # TeX related
    "quick_tex",
    "latex_escape",
    # Decorating + labeling
    "label_for_colorbar",
    "label_for_dim",
    "label_for_symmetry_point",
    "sum_annotation",
    "mean_annotation",
    "fancy_labels",
    "mod_plot_to_ax",
    # Data summaries
    "summarize",
    "transform_labels",
    "v_gradient_fill",
    "h_gradient_fill",
)


@contextlib.contextmanager
def unchanged_limits(ax):
    """Context manager that retains axis limits.

    The x and y limits of ``ax`` are recorded on entry and restored on exit,
    including when the body of the ``with`` statement raises. Plotting calls made
    inside the block therefore cannot change the visible region.

    Args:
        ax: The axes whose limits should be preserved.

    Yields:
        None.
    """
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    try:
        yield
    finally:
        # Restore even on error so a failed plot call does not leave the axes rescaled.
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
def mod_plot_to_ax(data, ax, mod, **kwargs):
    """Plots a model onto an axis using the data range from the passed data.

    The model is evaluated at the coordinate values of the first dimension of
    ``data`` and drawn with ``ax.plot``. The axis limits are kept unchanged.

    Args:
        data: Data whose first dimension supplies the evaluation points.
        ax: The axes to draw on.
        mod: A model object exposing ``eval(x=...)``.
        **kwargs: Forwarded to ``ax.plot``.
    """
    with unchanged_limits(ax):
        eval_points = data.coords[data.dims[0]].values
        model_values = mod.eval(x=eval_points)
        ax.plot(eval_points, model_values, **kwargs)


def h_gradient_fill(x1, x2, x_solid, fill_color=None, ax=None, zorder=None, alpha=None, **kwargs):
    """Fills a horizontal opacity gradient between x1 and x2.

    Opacity rises linearly from 0 at ``x1`` to ``alpha`` at ``x2`` over the current
    y-range of the axes. If ``x_solid`` is not None, a solid band at opacity
    ``alpha`` is also drawn. The band spans ``x2``→``x_solid`` when ``x_solid > x2``,
    and ``x_solid``→``x1`` otherwise. The axis limits are restored afterwards.

    Args:
        x1: x position (data units) where the gradient is fully transparent.
        x2: x position (data units) where the gradient reaches ``alpha``.
        x_solid: Optional far edge of the solid band, or None.
        fill_color: Any matplotlib color; required.
        ax: Axes to draw on, defaults to the current axes.
        zorder: zorder of the gradient image.
        alpha: Maximum opacity, defaults to 1.0.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The result of the inner imshow.
    """
    target = plt.gca() if ax is None else ax
    saved_xlim, saved_ylim = target.get_xlim(), target.get_ylim()

    assert fill_color
    peak_alpha = 1.0 if alpha is None else alpha

    # A single-row RGBA image whose alpha channel ramps along x.
    rgba = np.empty((1, 100, 4), dtype=float)
    rgba[..., :3] = matplotlib.colors.to_rgb(fill_color)
    rgba[..., 3] = np.linspace(0, peak_alpha, 100)[np.newaxis, :]

    y_low, y_high = saved_ylim
    image = target.imshow(
        rgba,
        aspect="auto",
        extent=[x1, x2, y_low, y_high],
        origin="lower",
        zorder=zorder,
    )

    if x_solid is not None:
        band = (x2, x_solid) if x_solid > x2 else (x_solid, x1)
        target.fill_betweenx(saved_ylim, *band, color=fill_color, alpha=peak_alpha)

    target.set_xlim(saved_xlim)
    target.set_ylim(saved_ylim)
    return image


def v_gradient_fill(y1, y2, y_solid, fill_color=None, ax=None, zorder=None, alpha=None, **kwargs):
    """Fills a vertical opacity gradient between y1 and y2.

    Opacity rises linearly from 0 at ``y1`` to ``alpha`` at ``y2`` over the current
    x-range of the axes. If ``y_solid`` is not None, a solid band at opacity
    ``alpha`` is also drawn. The band spans ``y2``→``y_solid`` when ``y_solid > y2``,
    and ``y_solid``→``y1`` otherwise. The axis limits are restored afterwards.

    Args:
        y1: y position (data units) where the gradient is fully transparent.
        y2: y position (data units) where the gradient reaches ``alpha``.
        y_solid: Optional far edge of the solid band, or None.
        fill_color: Any matplotlib color; required.
        ax: Axes to draw on, defaults to the current axes.
        zorder: zorder of the gradient image.
        alpha: Maximum opacity, defaults to 1.0.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The result of the inner imshow call.
    """
    target = plt.gca() if ax is None else ax
    saved_xlim, saved_ylim = target.get_xlim(), target.get_ylim()

    assert fill_color
    peak_alpha = 1.0 if alpha is None else alpha

    # A single-column RGBA image whose alpha channel ramps along y.
    rgba = np.empty((100, 1, 4), dtype=float)
    rgba[..., :3] = matplotlib.colors.to_rgb(fill_color)
    rgba[..., 3] = np.linspace(0, peak_alpha, 100)[:, np.newaxis]

    x_low, x_high = saved_xlim
    image = target.imshow(
        rgba,
        aspect="auto",
        extent=[x_low, x_high, y1, y2],
        origin="lower",
        zorder=zorder,
    )

    if y_solid is not None:
        band = (y2, y_solid) if y_solid > y2 else (y_solid, y1)
        target.fill_between(saved_xlim, *band, color=fill_color, alpha=peak_alpha)

    target.set_xlim(saved_xlim)
    target.set_ylim(saved_ylim)
    return image
def simple_ax_grid(
    n_axes, figsize=None, **kwargs
) -> Tuple[plt.Figure, List[plt.Axes], List[plt.Axes]]:
    """Generates a square-ish set of axes and hides the extra ones.

    The grid has ``ceil(sqrt(n_axes))`` columns and just enough rows to hold
    ``n_axes`` panels. It would be nice to accept an "aspect ratio" item that
    would try to pick grid dimensions giving an aspect ratio close to the desired one.

    Args:
        n_axes: Number of axes that should be visible.
        figsize: Figure size; defaults to 3 inches per cell, with at least 5 cells
            counted in each direction.
        **kwargs: Forwarded to ``plt.subplots``.

    Returns:
        The figure, the first n axes which are shown, and the remaining hidden axes.
    """
    n_cols = int(np.ceil(np.sqrt(n_axes)))
    n_rows = n_cols - 1
    if n_cols * n_rows < n_axes:
        n_rows += 1

    if figsize is None:
        figsize = (3 * max(n_cols, 5), 3 * max(n_rows, 5))

    fig, grid = plt.subplots(n_rows, n_cols, figsize=figsize, **kwargs)
    if n_axes == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid
        grid = np.array([grid])

    flat = grid.ravel()
    shown, hidden = flat[:n_axes], flat[n_axes:]

    for spare in hidden:
        invisible_axes(spare)

    return fig, shown, hidden
@contextlib.contextmanager
def dark_background(overrides=None):
    """Context manager for plotting "dark mode".

    Inside the ``with`` block, axis edges, tick marks, axis labels and text are
    drawn in white. This is applied through ``plt.rc_context``.

    Args:
        overrides: Optional mapping of extra rcParams, applied on top of the
            dark-mode defaults.
    """
    defaults = {
        "axes.edgecolor": "white",
        "xtick.color": "white",
        "ytick.color": "white",
        "axes.labelcolor": "white",
        "text.color": "white",
    }
    if overrides is not None:
        defaults.update(overrides)

    with plt.rc_context(defaults):
        yield


def data_to_axis_units(points, ax=None):
    """Converts points from data units to axis units (fractions of the axes box).

    Args:
        points: A point or array of points in data coordinates.
        ax: The axes whose transforms are used, defaults to the current axes.

    Returns:
        The points in axis coordinates.
    """
    if ax is None:
        ax = plt.gca()

    # data -> display, then display -> axes
    return ax.transAxes.inverted().transform(ax.transData.transform(points))


def axis_to_data_units(points, ax=None):
    """Converts points from axis units (fractions of the axes box) to data units.

    Args:
        points: A point or array of points in axis coordinates.
        ax: The axes whose transforms are used, defaults to the current axes.

    Returns:
        The points in data coordinates.
    """
    if ax is None:
        ax = plt.gca()

    # axes -> display, then display -> data
    return ax.transData.inverted().transform(ax.transAxes.transform(points))


def ddata_daxis_units(ax=None):
    """Gives the derivative of data units with respect to axis units.

    Returns:
        The (x, y) data-unit span covered by one full axis unit.
    """
    if ax is None:
        ax = plt.gca()

    dp1 = axis_to_data_units((1.0, 1.0), ax)
    dp0 = axis_to_data_units((0.0, 0.0), ax)
    return dp1 - dp0


def daxis_ddata_units(ax=None):
    """Gives the derivative of axis units with respect to data units.

    Returns:
        The (x, y) axis-unit span covered by one data unit.
    """
    if ax is None:
        ax = plt.gca()

    dp1 = data_to_axis_units((1.0, 1.0), ax)
    dp0 = data_to_axis_units((0.0, 0.0), ax)
    return dp1 - dp0
def swap_xaxis_side(ax):
    """Moves the x axis ticks and label to the top edge of ``ax``."""
    xaxis = ax.xaxis
    xaxis.tick_top()
    xaxis.set_label_position("top")
def swap_yaxis_side(ax):
    """Moves the y axis ticks and label to the right edge of ``ax``."""
    yaxis = ax.yaxis
    yaxis.tick_right()
    yaxis.set_label_position("right")
def swap_axis_sides(ax):
    """Moves the x axis to the top and the y axis to the right of ``ax``."""
    for swap in (swap_xaxis_side, swap_yaxis_side):
        swap(ax)
def transform_labels(transform_fn, fig=None, include_titles=True):
    """Apply a function to all axis labels (and optionally titles) in a figure.

    ``transform_fn`` is first called as ``transform_fn(text, is_title=...)``. If
    that raises ``TypeError`` (e.g. the function takes no ``is_title`` keyword),
    it is called again as ``transform_fn(text)``. All labels are read before any
    are modified, so the fallback never sees already-transformed text.

    Args:
        transform_fn: Callable mapping a label string to its replacement.
        fig: The figure to modify, defaults to the current figure.
        include_titles: Whether axes titles are transformed too.
    """
    if fig is None:
        fig = plt.gcf()

    for ax in list(fig.get_axes()):
        xlabel, ylabel, title = ax.get_xlabel(), ax.get_ylabel(), ax.get_title()

        try:
            new_xlabel = transform_fn(xlabel, is_title=False)
            new_ylabel = transform_fn(ylabel, is_title=False)
            new_title = transform_fn(title, is_title=True) if include_titles else None
        except TypeError:
            new_xlabel = transform_fn(xlabel)
            new_ylabel = transform_fn(ylabel)
            new_title = transform_fn(title) if include_titles else None

        ax.set_xlabel(new_xlabel)
        ax.set_ylabel(new_ylabel)
        if include_titles:
            ax.set_title(new_title)
def summarize(data: DataType, axes=None):
    """Makes a summary plot with one marginal image per pair of dimensions.

    For every pair of dimensions, the data is summed over all *other* dimensions,
    and the resulting 2D marginal is drawn on its own axes. Unused axes are hidden.

    Args:
        data: The data to summarize; normalized with ``normalize_to_spectrum``.
        axes: Optional axes (or array of axes) to draw on. If omitted, a grid
            sized for the dimensionality of the data (1 to 4 dims) is created.

    Returns:
        The axes that were drawn on.

    Raises:
        ValueError: If ``axes`` is omitted and the data has an unsupported number
            of dimensions.
    """
    data = normalize_to_spectrum(data)

    axes_shapes_for_dims = {
        1: (1, 1),
        2: (1, 1),
        3: (2, 2),  # one extra here
        4: (3, 2),  # corresponds to 4 choose 2 axes
    }

    if axes is None:
        shape = axes_shapes_for_dims.get(len(data.dims))
        if shape is None:
            raise ValueError(
                f"summarize supports data with 1 to 4 dimensions, got {len(data.dims)}."
            )
        n_rows, n_cols = shape
        # squeeze=False keeps a 2D array even for a 1x1 grid so it can be raveled
        _, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8), squeeze=False)

    flat_axes = np.asarray(axes).ravel()
    combinations = list(itertools.combinations(data.dims, 2))

    for axi, combination in zip(flat_axes, combinations):
        # keep the pair, integrate out everything else
        summed_dims = [d for d in data.dims if d not in combination]
        marginal = data.sum(summed_dims) if summed_dims else data
        marginal.plot(ax=axi)
        fancy_labels(axi)

    for axi in flat_axes[len(combinations):]:
        axi.set_axis_off()

    return axes
def sum_annotation(eV=None, phi=None):
    """Annotates that a given axis was summed over by listing the integration range.

    Args:
        eV: Optional slice giving the energy integration range.
        phi: Optional slice giving the angular integration range.

    Returns:
        The energy annotation followed by the angle annotation. LaTeX is used when
        ``SETTINGS["use_tex"]`` is set. The result is empty when neither range is given.
    """

    def to_text(bound):
        return "" if bound is None else "{:.2f}".format(bound)

    def describe(bounds, tex_symbol, plain_symbol):
        use_tex = SETTINGS["use_tex"]
        lo, hi = to_text(bounds.start), to_text(bounds.stop)
        if use_tex:
            return "$" + tex_symbol + "_{" + lo + "}^{" + hi + "}$"
        return lo + " < " + plain_symbol + " < " + hi

    parts = []
    if eV is not None:
        parts.append(describe(eV, "\\text{E}", "E"))
    if phi is not None:
        parts.append(describe(phi, "\\phi", "φ"))
    return "".join(parts)
def mean_annotation(eV=None, phi=None):
    """Annotates that a given axis was meaned (summed) over by listing the integration range.

    Args:
        eV: Optional slice giving the energy averaging range.
        phi: Optional slice giving the angular averaging range.

    Returns:
        The energy annotation followed by the angle annotation. LaTeX is used when
        ``SETTINGS["use_tex"]`` is set. The result is empty when neither range is given.
    """

    def to_text(bound):
        return "" if bound is None else "{:.2f}".format(bound)

    def describe(bounds, tex_symbol, plain_symbol):
        use_tex = SETTINGS["use_tex"]
        lo, hi = to_text(bounds.start), to_text(bounds.stop)
        if use_tex:
            return "$" + tex_symbol + "_{" + lo + "}^{" + hi + "}$"
        return "Mean<" + lo + " < " + plain_symbol + " < " + hi + ">"

    parts = []
    if eV is not None:
        parts.append(describe(eV, "\\bar{\\text{E}}", "E"))
    if phi is not None:
        parts.append(describe(phi, "\\bar{\\phi}", "φ"))
    return "".join(parts)
def frame_with(ax, color="red", linewidth=2):
    """Makes thick, visually striking borders on a matplotlib plot.

    All four spines of ``ax`` get the given color and line width. Very useful for
    color coding results in a slideshow.
    """
    for side in ("left", "right", "top", "bottom"):
        spine = ax.spines[side]
        spine.set_color(color)
        spine.set_linewidth(linewidth)
# Characters with special meaning in LaTeX, mapped to text that renders them literally.
LATEX_ESCAPE_MAP = {
    "_": r"\_",
    "<": r"\textless{}",
    ">": r"\textgreater{}",
    "{": r"\{",
    "}": r"\}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "~": r"\textasciitilde{}",
    "^": r"\^{}",
    "\\": r"\textbackslash{}",
}

# One alternation over all keys, longest first, so longer keys would win over prefixes.
# A single ``sub`` pass means escaped output is never re-escaped.
LATEX_ESCAPE_REGEX = re.compile(
    "|".join(map(re.escape, sorted(LATEX_ESCAPE_MAP, key=len, reverse=True)))
)
def latex_escape(text: str, force: bool = False) -> str:
    """Conditionally escapes a string based on the matplotlib settings.

    Escaping happens when matplotlib is rendering with LaTeX, or when
    ``force=True`` is passed to get the escaped string regardless.

    Adjusted from suggestions at:
    https://stackoverflow.com/questions/16259923/how-can-i-escape-latex-special-characters-inside-django-templates

    Args:
        text: The contents which should be escaped
        force: Whether we should perform escaping even if matplotlib is
          not being used with LaTeX support.

    Returns:
        The escaped string which should appear in LaTeX with the same
        contents as the original.
    """
    should_escape = is_using_tex() or force
    if not should_escape:
        return text

    def replacement(match):
        return LATEX_ESCAPE_MAP[match.group()]

    return LATEX_ESCAPE_REGEX.sub(replacement, text)
def quick_tex(latex_fragment: str, ax=None, fontsize=30) -> "plt.Axes":
    """Renders a LaTeX fragment onto an otherwise invisible axes.

    Useful when you just need to see some LaTeX rendered without setting up a
    separate LaTeX session. For quick previews, a web renderer such as KaTeX
    also works well.

    Args:
        latex_fragment: The fragment to render.
        ax: The axes to draw on; a new figure and axes are created if omitted.
        fontsize: Font size of the rendered text.

    Returns:
        The axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    invisible_axes(ax)
    ax.text(0.2, 0.2, latex_fragment, fontsize=fontsize)
    return ax


def lineplot_arr(arr, ax=None, method="plot", mask=None, mask_kwargs=None, **kwargs):
    """Plots a 1D array and optionally shades masked x-ranges.

    Args:
        arr: 1D data with ``dims``, ``coords`` and ``values``, or None to draw
            only the mask.
        ax: Axes to draw on; a new figure and axes are created if omitted.
        method: ``"scatter"`` for ``ax.scatter``; anything else uses ``ax.plot``.
        mask: Optional list of ``slice`` objects. Each is shaded with
            ``fill_betweenx`` over the current y-range.
        mask_kwargs: Keyword arguments forwarded to ``fill_betweenx``.
        **kwargs: Forwarded to the plotting call.

    Returns:
        The axes drawn on.

    Raises:
        NotImplementedError: If ``mask`` is not a list of slices.
    """
    if mask_kwargs is None:
        mask_kwargs = {}

    if ax is None:
        _, ax = plt.subplots()

    if arr is not None:
        # Draw on the requested axes rather than whatever pyplot considers current.
        plot_fn = ax.scatter if method == "scatter" else ax.plot
        xs = arr.coords[arr.dims[0]].values
        plot_fn(xs, arr.values, **kwargs)

    if mask is not None:
        y_lim = ax.get_ylim()
        if not (isinstance(mask, list) and all(isinstance(m, slice) for m in mask)):
            raise NotImplementedError

        for slice_mask in mask:
            ax.fill_betweenx(y_lim, slice_mask.start, slice_mask.stop, **mask_kwargs)

        # fill_betweenx may autoscale; keep the y-range the data established
        ax.set_ylim(y_lim)

    return ax
def plot_arr(arr=None, ax=None, over=None, mask=None, **kwargs):
    """Convenience method to plot an array with a mask over some other data.

    Dimensionality is taken from ``mask`` when given, otherwise from ``arr``.
    Objects without ``dims`` are treated as 1D. 2D inputs go through
    ``imshow_arr``/``imshow_mask``; 1D inputs go through ``lineplot_arr``.

    Returns:
        The axes drawn on.
    """
    reference = mask if mask is not None else arr
    try:
        n_dims = len(reference.dims)
    except AttributeError:
        n_dims = 1

    if n_dims == 2:
        quad = None
        if arr is not None:
            ax, quad = imshow_arr(arr, ax=ax, over=over, **kwargs)
        if mask is not None:
            # without an explicit target, align the mask to the image just drawn
            target = quad if over is None else over
            imshow_mask(mask, ax=ax, over=target, **kwargs)
    elif n_dims == 1:
        ax = lineplot_arr(arr, ax=ax, mask=mask, **kwargs)

    return ax
def imshow_mask(mask, ax=None, over=None, cmap=None, **kwargs):
    """Plots a mask using a fixed colormap with fully transparent bad values.

    The mask is drawn with ``imshow``, reusing the extent of ``over`` (an existing
    image) and the current aspect of ``ax``. Color normalization is fixed to [0, 1].

    Args:
        mask: Object with a ``values`` attribute holding the 2D mask.
        ax: Axes to draw on, defaults to the current axes.
        over: The image (e.g. from ``imshow_arr``) whose extent is reused. Required.
        cmap: Colormap or colormap name, defaults to ``"Reds"``. The passed colormap
            is not modified; a copy receives the transparent "bad" color.
        **kwargs: Forwarded to ``ax.imshow``.
    """
    import copy

    assert over is not None

    if ax is None:
        ax = plt.gca()

    if cmap is None:
        cmap = "Reds"

    if isinstance(cmap, str):
        # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; ``pyplot.get_cmap`` remains.
        cmap = plt.get_cmap(cmap)

    # Work on a copy so ``set_bad`` does not alter a caller-owned or globally shared colormap.
    cmap = copy.copy(cmap)
    cmap.set_bad("k", alpha=0)

    ax.imshow(
        mask.values,
        cmap=cmap,
        interpolation="none",
        vmax=1,
        vmin=0,
        origin="lower",
        extent=over.get_extent(),
        aspect=ax.get_aspect(),
        **kwargs,
    )
def imshow_arr(
    arr,
    ax=None,
    over=None,
    origin="lower",
    aspect="auto",
    alpha=None,
    vmin=None,
    vmax=None,
    cmap=None,
    **kwargs,
):
    """Similar to plt.imshow but uses a different default origin and sets appropriate extents.

    The extent comes from the first and last coordinate values of the two
    dimensions of ``arr``: the second dimension runs along x and the first along y.
    When ``over`` is given, its extent and the current aspect of ``ax`` are reused.

    Args:
        arr: 2D data with ``dims``, ``coords`` and ``values``.
        ax: Axes to draw on; a new figure and axes are created if omitted.
        over: Optional existing image whose extent should be matched.
        origin: Passed to ``imshow``.
        aspect: Passed to ``imshow`` when ``over`` is not given.
        alpha: If given (and ``over`` is not), values are color mapped manually,
            and this constant opacity is written into the RGBA image.
        vmin: Lower color limit for the manual mapping, defaults to the data minimum.
        vmax: Upper color limit for the manual mapping, defaults to the data maximum.
        cmap: Colormap or name. It is used for the manual mapping (default
            ``"viridis"``), or otherwise passed to ``imshow``.
        **kwargs: Forwarded to ``imshow``.

    Returns:
        The axes and quadmesh instance.
    """
    if ax is None:
        _, ax = plt.subplots()

    x, y = arr.coords[arr.dims[0]].values, arr.coords[arr.dims[1]].values
    extent = [y[0], y[-1], x[0], x[-1]]

    if over is None:
        if alpha is not None:
            if vmin is None:
                vmin = arr.min().item()
            if vmax is None:
                vmax = arr.max().item()
            if cmap is None:
                cmap = "viridis"
            if isinstance(cmap, str):
                # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
                cmap = plt.get_cmap(cmap)

            norm = colors.Normalize(vmin=vmin, vmax=vmax)
            mappable = cm.ScalarMappable(cmap=cmap, norm=norm)
            mapped_colors = mappable.to_rgba(arr.values)
            mapped_colors[:, :, 3] = alpha

            quad = ax.imshow(mapped_colors, origin=origin, extent=extent, aspect=aspect, **kwargs)
        else:
            quad = ax.imshow(
                arr.values, origin=origin, extent=extent, aspect=aspect, cmap=cmap, **kwargs
            )
        ax.grid(False)
        ax.set_xlabel(arr.dims[1])
        ax.set_ylabel(arr.dims[0])
    else:
        quad = ax.imshow(
            arr.values, extent=over.get_extent(), aspect=ax.get_aspect(), origin=origin, **kwargs
        )

    return ax, quad


def dos_axes(
    orientation="horiz", figsize=None, with_cbar=True
) -> "Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]":
    """Makes axes corresponding to density of states data.

    This has one image like region and one small marginal for an EDC.
    Orientation option should be 'horiz' or 'vert'.

    Args:
        orientation: ``"horiz"`` stacks the marginal below the image with a shared x
            axis. Any other value places it to the left with a shared y axis.
        figsize: Figure size, defaults to (9, 9), or (12, 9) for ``"vert"``.
        with_cbar: Currently unused.

    Returns:
        The generated figure and a tuple ``(image_axes, marginal_axes)``.
    """
    if figsize is None:
        figsize = (12, 9) if orientation == "vert" else (9, 9)

    fig = plt.figure(figsize=figsize)

    if orientation == "horiz":
        fig.subplots_adjust(hspace=0.00)
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

        ax0 = plt.subplot(gs[0])
        axes = (ax0, plt.subplot(gs[1], sharex=ax0))
        plt.setp(axes[0].get_xticklabels(), visible=False)
    else:
        fig.subplots_adjust(wspace=0.00)
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 4])

        ax0 = plt.subplot(gs[1])
        axes = (ax0, plt.subplot(gs[0], sharey=ax0))
        plt.setp(axes[0].get_yticklabels(), visible=False)

    return fig, axes


def inset_cut_locator(data, reference_data=None, ax=None, location=None, color=None, **kwargs):
    """Plots a reference cut location over a figure.

    Another approach is to separately plot the locator and add it in Illustrator or
    another tool.

    Args:
        data: The data you are plotting
        reference_data: The reference data containing the location of the cut. It is
            used to resolve "theta"/"beta"/"phi" dimensions of ``data`` that
            ``location`` omits.
        ax: The axes to plot on, defaults to the current axes
        location: Mapping from dimension name to a value or slice describing the cut
        color: The color to use for the indicator line, defaults to "red"
        kwargs: Passed to ax.plot when making the indicator lines

    Raises:
        ValueError: If ``location`` omits dimensions and no ``reference_data`` is given.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    if ax is None:
        ax = plt.gca()
    if location is None:
        location = {}

    quad = data.plot(ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("")
    try:
        quad.colorbar.remove()
    except Exception:
        # Best effort: the plot may not have produced a colorbar at all.
        pass

    # add more as necessary
    missing_dim_resolvers = {
        "theta": lambda: reference_data.S.theta,
        "beta": lambda: reference_data.S.beta,
        "phi": lambda: reference_data.S.phi,
    }

    missing_dims = [d for d in data.dims if d not in location]
    if missing_dims and reference_data is None:
        raise ValueError(
            f"reference_data is required to resolve dimensions missing from location: {missing_dims}"
        )
    missing_values = {d: missing_dim_resolvers[d]() for d in missing_dims}
    ordered_selector = [location.get(d, missing_values.get(d)) for d in data.dims]

    n = 200

    def resolve(name, value):
        # Slices become n evenly spaced points; scalars become a constant array.
        if isinstance(value, slice):
            low = value.start
            high = value.stop

            if low is None:
                low = data.coords[name].min().item()
            if high is None:
                high = data.coords[name].max().item()

            return np.linspace(low, high, n)

        return np.ones((n,)) * value

    n_cut_dims = len([d for d in ordered_selector if isinstance(d, (Iterable, slice))])
    ordered_selector = [resolve(d, v) for d, v in zip(data.dims, ordered_selector)]

    if n_cut_dims == 2:
        # a region cut, illustrate with a rect or by suppressing background
        return

    if color is None:
        color = "red"

    if n_cut_dims == 1:
        # a line cut, illustrate with a line
        ax.plot(*ordered_selector[::-1], color=color, **kwargs)
    elif n_cut_dims == 0:
        # a single point cut, illustrate with a marker
        pass


def generic_colormap(low, high):
    """Generates a colormap from the cm.Blues palette, suitable for most purposes.

    The range is padded by one sixth of its width on each side, so ``low`` and
    ``high`` do not map to the extreme colors.
    """
    delta = high - low
    low = low - delta / 6
    high = high + delta / 6

    def get_color(value):
        return cm.Blues(float((value - low) / (high - low)))

    return get_color


def phase_angle_colormap(low=0, high=np.pi * 2):
    """Generates a colormap suitable for angular data or data on a unit circle like a phase."""

    def get_color(value):
        return cm.twilight_shifted(float((value - low) / (high - low)))

    return get_color


def delay_colormap(low=-1, high=1):
    """Generates a colormap suitable for pump-probe delay data."""

    def get_color(value):
        return cm.coolwarm(float((value - low) / (high - low)))

    return get_color


def temperature_colormap(high=300, low=0, cmap=None):
    """Generates a colormap suitable for temperature data with fixed extent.

    ``cmap`` is any callable taking a value in [0, 1]; defaults to ``cm.Blues_r``.
    """
    if cmap is None:
        cmap = cm.Blues_r

    def get_color(value):
        return cmap(float((value - low) / (high - low)))

    return get_color


def temperature_colormap_around(central, range=50):  # noqa: A002 - public keyword name
    """Generates a colormap suitable for temperature data around a central value."""

    def get_color(value):
        return cm.RdBu_r(float((value - central) / range))

    return get_color


def generic_colorbar(low, high, label="", cmap=None, ax=None, ticks=None, **kwargs):
    """Generates a horizontal colorbar over a padded [low, high] range.

    Ticks default to the unpadded ``[low, high]``. The color normalization is padded
    by one sixth of the range on each side, matching ``generic_colormap``.
    """
    extra_kwargs = {
        "orientation": "horizontal",
        "label": label,
        "ticks": ticks if ticks is not None else [low, high],
    }

    delta = high - low
    low = low - delta / 6
    high = high + delta / 6

    extra_kwargs.update(kwargs)
    cb = colorbar.ColorbarBase(
        ax,
        cmap=plt.get_cmap(cmap or "Blues"),
        norm=colors.Normalize(vmin=low, vmax=high),
        **extra_kwargs,
    )

    return cb


def phase_angle_colorbar(high=np.pi * 2, low=0, ax=None, **kwargs):
    """Generates a colorbar suitable for plotting an angle or value on a unit circle.

    By default, ticks sit at 0, π and 2π and are labeled accordingly. LaTeX labels
    are used when ``SETTINGS["use_tex"]`` is set. Passing ``ticks`` overrides the
    positions and keeps matplotlib's default tick labels.
    """
    if SETTINGS["use_tex"]:
        tick_labels = ["0", r"$\pi$", r"$2\pi$"]
    else:
        tick_labels = ["0", "π", "2π"]

    extra_kwargs = {
        "orientation": "horizontal",
        "label": "Angle",
        # Colorbar ticks must be numeric positions; the text belongs in the tick labels.
        "ticks": [0, np.pi, 2 * np.pi],
    }
    user_ticks = "ticks" in kwargs
    extra_kwargs.update(kwargs)

    cb = colorbar.ColorbarBase(
        ax,
        cmap=plt.get_cmap("twilight_shifted"),
        norm=colors.Normalize(vmin=low, vmax=high),
        **extra_kwargs,
    )
    if not user_ticks:
        cb.set_ticklabels(tick_labels)

    return cb


def temperature_colorbar(high=300, low=0, ax=None, cmap=None, **kwargs):
    """Generates a colorbar suitable for temperature data with fixed extent."""
    if cmap is None:
        cmap = "Blues_r"

    extra_kwargs = {
        "orientation": "horizontal",
        "label": "Temperature (K)",
        "ticks": [low, high],
    }
    extra_kwargs.update(kwargs)
    cb = colorbar.ColorbarBase(
        ax, cmap=cmap, norm=colors.Normalize(vmin=low, vmax=high), **extra_kwargs
    )

    return cb


def delay_colorbar(low=-1, high=1, ax=None, **kwargs):
    """Generates a colorbar suitable for delay data.

    TODO make this nonsequential for use in case where you want to have a long time period after
    the delay or before.
    """
    extra_kwargs = {
        "orientation": "horizontal",
        "label": "Probe Pulse Delay (ps)",
        "ticks": [low, 0, high],
    }
    extra_kwargs.update(kwargs)
    cb = colorbar.ColorbarBase(
        ax, cmap="coolwarm", norm=colors.Normalize(vmin=low, vmax=high), **extra_kwargs
    )

    return cb


def temperature_colorbar_around(central, range=50, ax=None, **kwargs):  # noqa: A002
    """Generates a colorbar suitable for temperature axes around a central value."""
    extra_kwargs = {
        "orientation": "horizontal",
        "label": "Temperature (K)",
        "ticks": [central - range, central + range],
    }
    extra_kwargs.update(kwargs)
    cb = colorbar.ColorbarBase(
        ax,
        cmap="RdBu_r",
        norm=colors.Normalize(vmin=central - range, vmax=central + range),
        **extra_kwargs,
    )

    return cb


# Maps a coordinate kind to its (colorbar factory, colormap factory) pair.
colorbarmaps_for_axis = {
    "temp": (
        temperature_colorbar,
        temperature_colormap,
    ),
    "delay": (
        delay_colorbar,
        delay_colormap,
    ),
    "theta": (
        phase_angle_colorbar,
        phase_angle_colormap,
    ),
    "volts": (
        generic_colorbar,
        generic_colormap,
    ),
}


def get_colorbars(fig=None) -> "List[plt.Axes]":
    """Collects likely colorbars in a figure.

    Axes whose aspect equals 20 are treated as colorbars (a heuristic).
    """
    if fig is None:
        fig = plt.gcf()

    return [ax for ax in fig.axes if ax.get_aspect() == 20]
def remove_colorbars(fig=None):
    """Removes colorbars from given (or, if no given figure, current) matplotlib figure.

    Colorbars are recognized as axes whose aspect equals 20. Any error during
    removal is ignored (best effort).

    Args:
        fig: The figure to modify, by default uses the current figure (`plt.gcf()`)
    """
    # TODO after colorbar removal, plots should be relaxed/rescaled to occupy space previously allocated to colorbars
    # for now, can follow this with plt.tight_layout()
    try:
        target = plt.gcf() if fig is None else fig
        for candidate in target.axes:
            if candidate.get_aspect() == 20:  # a bit of a hack
                candidate.remove()
    except Exception:
        pass
generic_colorbarmap = (
    generic_colorbar,
    generic_colormap,
)


def generic_colorbarmap_for_data(data: xr.DataArray, keep_ticks=True, ax=None, **kwargs):
    """Generates a colorbar and colormap which is useful in general context.

    Args:
        data: Values spanning the color range. When ``keep_ticks`` is set, they are
            also used as tick positions.
        keep_ticks: Whether to place a tick at each value of ``data``.
        ax: The axes to draw the colorbar into.
        **kwargs: Only ``ticks`` is read; it overrides the ticks derived from ``data``.

    Returns:
        A ``(colorbar, colormap)`` pair.
    """
    low, high = data.min().item(), data.max().item()
    ticks = None
    if keep_ticks:
        ticks = data.values
    return (
        generic_colorbar(low=low, high=high, ax=ax, ticks=kwargs.get("ticks", ticks)),
        generic_colormap(low=low, high=high),
    )


def polarization_colorbar(ax=None):
    """Makes a colorbar which is appropriate for "polarization" (e.g. spin) data.

    The bar spans [-1, 1] with the ``RdBu`` colormap and ticks at -1, 0 and 1.
    """
    cb = colorbar.ColorbarBase(
        ax,
        cmap="RdBu",
        norm=colors.Normalize(vmin=-1, vmax=1),
        orientation="horizontal",
        label="Polarization",
        ticks=[-1, 0, 1],
    )

    return cb


def calculate_aspect_ratio(data: DataType):
    """Calculate the aspect ratio which should be used for plotting some data based on extent.

    Returns:
        The coordinate span of the second dimension divided by that of the first.
    """
    data = normalize_to_spectrum(data)

    assert len(data.dims) == 2

    x_extent = np.ptp(data.coords[data.dims[0]].values)
    y_extent = np.ptp(data.coords[data.dims[1]].values)

    return y_extent / x_extent


class AnchoredHScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """Provides an anchored scale bar on the X axis.

    Modified from `this StackOverflow question <https://stackoverflow.com/questions/43258638/>`_
    as alternate to the one provided through matplotlib.
    """

    def __init__(
        self,
        size=1,
        extent=0.03,
        label="",
        loc=2,
        ax=None,
        pad=0.4,
        borderpad=0.5,
        ppad=0,
        sep=2,
        prop=None,
        label_color=None,
        frameon=True,
        **kwargs,
    ):
        """Setup the scale bar and coordinate transforms to the parent axis.

        Args:
            size: Length of the bar in x data units.
            extent: Height of the end caps. It is expressed in axes coordinates,
                since the bar uses the x-axis transform.
            label: Text drawn below the bar.
            loc: Anchor location passed to ``AnchoredOffsetbox``.
            ax: Parent axes, defaults to the current axes.
            pad: Padding passed to ``AnchoredOffsetbox``.
            borderpad: Border padding passed to ``AnchoredOffsetbox``.
            ppad: Padding of the inner ``VPacker``.
            sep: Separation between bar and label in the ``VPacker``.
            prop: Font properties passed to ``AnchoredOffsetbox``.
            label_color: Color of the label text.
            frameon: Whether to draw a frame around the scale bar.
            **kwargs: Forwarded to each ``Line2D`` making up the bar.
        """
        if not ax:
            ax = plt.gca()

        # x in data units, y in axes units
        trans = ax.get_xaxis_transform()

        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        line = Line2D([0, size], [0, 0], **kwargs)
        vline1 = Line2D([0, 0], [-extent / 2.0, extent / 2.0], **kwargs)
        vline2 = Line2D([size, size], [-extent / 2.0, extent / 2.0], **kwargs)
        size_bar.add_artist(line)
        size_bar.add_artist(vline1)
        size_bar.add_artist(vline2)

        # ``minimumdescent`` was deprecated in matplotlib 3.4 and removed in 3.6.
        txt = matplotlib.offsetbox.TextArea(
            label,
            textprops={
                "color": label_color,
            },
        )
        self.vpac = matplotlib.offsetbox.VPacker(
            children=[size_bar, txt], align="center", pad=ppad, sep=sep
        )
        super().__init__(
            loc, pad=pad, borderpad=borderpad, child=self.vpac, prop=prop, frameon=frameon
        )


def load_data_for_figure(p: Union[str, pathlib.Path]):
    """Tries to load the data associated with a given figure by unpickling the saved data.

    The figure's extension and any ``-PAPER`` or ``-low-PAPER`` suffix are stripped,
    and ``<stem>.pickle`` is loaded. These are the names ``savefig`` uses when
    ``save_data`` is supplied.

    Security note: this calls ``pickle.load``; only use it on files you created yourself.

    Args:
        p: Path to the figure file.

    Returns:
        The unpickled data.

    Raises:
        ValueError: If no pickle matching the figure exists.
    """
    stem = os.path.splitext(str(p))[0]
    # check the longer suffix first, "-low-PAPER" also ends with "-PAPER"
    for suffix in ("-low-PAPER", "-PAPER"):
        if stem.endswith(suffix):
            stem = stem[: -len(suffix)]
            break

    pickle_file = stem + ".pickle"

    if not os.path.exists(pickle_file):
        raise ValueError("No saved data matching figure.")

    with open(pickle_file, "rb") as f:
        data = pickle.load(f)

    return data
def savefig(desired_path, dpi=400, data=None, save_data=None, paper=False, **kwargs):
    """The PyARPES preferred figure saving routine.

    Provides a number of conveniences over matplotlib's `savefig`:

    #. Output is scoped per project and per day, which aids organization
    #. The dpi is set to a reasonable value for the year 2021.
    #. By omitting a file extension you will get high and low res formats in .png and .pdf
       which is useful for figure drafting in external software (Adobe Illustrator)
    #. Data and plot provenance is tracked, which makes it easier to find your analysis
       after the fact if you have many many plots.

    Args:
        desired_path: Output path, resolved through ``path_for_plot``.
        dpi: Resolution of the saved image.
        data: Optional list/tuple/set of data whose ``provenance`` attrs are recorded.
        save_data: Optional object pickled to ``<path without extension>.pickle``.
            Required in paper mode.
        paper: Write ``-PAPER.pdf``, ``-PAPER.png`` and ``-low-PAPER.pdf`` variants.
            Forced on when ``desired_path`` has no extension.
        **kwargs: Forwarded to ``plt.savefig``.

    Raises:
        ValueError: In paper mode when ``save_data`` is not supplied.
    """
    if not os.path.splitext(desired_path)[1]:
        paper = True

    if save_data is None:
        if paper:
            raise ValueError(
                "You must supply save_data when outputting in paper mode. This "
                "is for your own good so you can more easily regenerate "
                "the figure later!"
            )
    else:
        output_location = path_for_plot(os.path.splitext(desired_path)[0])
        with open(output_location + ".pickle", "wb") as f:
            pickle.dump(save_data, f)

    if paper:
        # automatically generate useful file formats
        high_dpi = max(dpi, 400)
        formats_for_paper = ["pdf", "png"]  # not including SVG anymore because files too large

        for extension in formats_for_paper:
            savefig(
                f"{desired_path}-PAPER.{extension}", dpi=high_dpi, data=data, paper=False, **kwargs
            )

        savefig(f"{desired_path}-low-PAPER.pdf", dpi=200, data=data, paper=False, **kwargs)

        return

    full_path = path_for_plot(desired_path)
    provenance_path = full_path + ".provenance.json"
    provenance_context = {
        "VERSION": VERSION,
        "time": datetime.datetime.now().isoformat(),
        "jupyter_notebook_name": get_notebook_name(),
        "name": "savefig",
    }

    def extract(for_data):
        # Best effort: anything without ``attrs`` contributes an empty record.
        try:
            return for_data.attrs.get("provenance", {})
        except Exception:
            return {}

    if data is not None:
        assert isinstance(
            data,
            (
                list,
                tuple,
                set,
            ),
        )
        provenance_context.update(
            {
                "jupyter_context": get_recent_history(1),
                "data": [extract(d) for d in data],
            }
        )
    else:
        # get more recent history because we don't have the data
        provenance_context.update(
            {
                "jupyter_context": get_recent_history(5),
            }
        )

    with open(provenance_path, "w", encoding="utf-8") as f:
        json.dump(provenance_context, f, indent=2)

    plt.savefig(full_path, dpi=dpi, **kwargs)
def path_for_plot(desired_path):
    """Provides workspace and date scoped path generation for plots.

    This is used to ensure that analysis products are grouped together
    and organized in a reasonable way (by each day, together).

    This will be used automatically if you use `arpes.plotting.utils.savefig`
    instead of the one from matplotlib.

    Returns:
        An absolute path ``<figure_path>/<workspace name>/<YYYY-MM-DD>/<desired_path>``
        whose parent directory has been created. It falls back to
        ``<cwd>/<desired_path>`` (with a warning) when no workspace is found or the
        figure path cannot be used.
    """
    if not CONFIG["WORKSPACE"]:
        attempt_determine_workspace()

    workspace = CONFIG["WORKSPACE"]

    if not workspace:
        warnings.warn("Saving locally, no workspace found.")
        return os.path.join(os.getcwd(), desired_path)

    try:
        import arpes.config

        figure_path = arpes.config.FIGURE_PATH
        if figure_path is None:
            figure_path = os.path.join(workspace["path"], "figures")

        filename = os.path.join(
            figure_path, workspace["name"], datetime.date.today().isoformat(), desired_path
        )
        filename = str(pathlib.Path(filename).absolute())

        # exist_ok avoids the check-then-create race of testing os.path.exists first
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        return filename
    except Exception as e:
        warnings.warn("Misconfigured FIGURE_PATH saving locally: {}".format(e))
        return os.path.join(os.getcwd(), desired_path)
def path_for_holoviews(desired_path):
    """Determines an appropriate output path for a holoviews save.

    Image extensions (.svg, .png, .jpeg, .jpg, .gif) are stripped; any other path
    is returned unchanged.
    """
    skip_paths = [".svg", ".png", ".jpeg", ".jpg", ".gif"]

    prefix, ext = os.path.splitext(desired_path)

    if ext in skip_paths:
        return prefix

    return prefix + ext


def name_for_dim(dim_name, escaped=True):
    """Alternate variant of `label_for_dim`.

    Args:
        dim_name: The dimension name to look up.
        escaped: If False, ``$`` math delimiters are stripped from the result.

    Returns:
        The display name (LaTeX when ``SETTINGS["use_tex"]`` is set), or None for
        unknown dimensions.
    """
    if SETTINGS["use_tex"]:
        name = {
            "temperature": "Temperature",
            "beta": r"$\beta$",
            "theta": r"$\theta$",
            "chi": r"$\chi$",
            "alpha": r"$\alpha$",
            "psi": r"$\psi$",
            "phi": r"$\phi$",
            "eV": r"$\textnormal{E}$",
            "kx": r"$\textnormal{k}_\textnormal{x}$",
            "ky": r"$\textnormal{k}_\textnormal{y}$",
            "kz": r"$\textnormal{k}_\textnormal{z}$",
            # \parallel is a math-mode symbol and cannot appear inside \textnormal
            "kp": r"$\textnormal{k}_\parallel$",
            "hv": r"$h\nu$",
        }.get(dim_name)
    else:
        name = {
            "temperature": "Temperature",
            "beta": "β",
            "theta": "θ",
            "chi": "χ",
            "alpha": "α",
            "psi": "ψ",
            "phi": "φ",
            "eV": "E",
            "kx": "Kx",
            "ky": "Ky",
            "kz": "Kz",
            "kp": "Kp",
            "hv": "Photon Energy",
        }.get(dim_name)

    if not escaped and name is not None:
        name = name.replace("$", "")

    return name


def unit_for_dim(dim_name, escaped=True):
    """Calculate LaTeX or fancy display label for the unit associated to a dimension.

    Args:
        dim_name: The dimension name to look up.
        escaped: If False, ``$`` math delimiters are stripped from the result.

    Returns:
        The unit string, or None for unknown dimensions.
    """
    if SETTINGS["use_tex"]:
        unit = {
            "temperature": "K",
            "theta": r"rad",
            "beta": r"rad",
            "psi": r"rad",
            "chi": r"rad",
            "alpha": r"rad",
            "phi": r"rad",
            "eV": r"eV",
            "kx": r"$\AA^{-1}$",
            "ky": r"$\AA^{-1}$",
            "kz": r"$\AA^{-1}$",
            "kp": r"$\AA^{-1}$",
            "hv": r"eV",
        }.get(dim_name)
    else:
        unit = {
            "temperature": "K",
            "theta": r"rad",
            "beta": r"rad",
            "psi": r"rad",
            "chi": r"rad",
            "alpha": r"rad",
            "phi": r"rad",
            "eV": r"eV",
            "kx": "1/Å",
            "ky": "1/Å",
            "kz": "1/Å",
            "kp": "1/Å",
            "hv": "eV",
        }.get(dim_name)

    if not escaped and unit is not None:
        unit = unit.replace("$", "")

    return unit


def label_for_colorbar(data):
    """Returns an appropriate label for an ARPES intensity colorbar.

    For differentiated data, the label describes either the curvature directions
    or the partial derivatives recorded in ``data.S.history``.
    """
    if not data.S.is_differentiated:
        return r"Spectrum Intensity (arb.)"

    # determine which axis was differentiated
    hist = data.S.history
    records = [h["record"] for h in hist if isinstance(h, dict)]
    if "curvature" in [r["by"] for r in records]:
        curvature_record = [r for r in records if r["by"] == "curvature"][0]
        directions = curvature_record["directions"]
        return r"Curvature along {} and {}".format(
            name_for_dim(directions[0]), name_for_dim(directions[1])
        )

    derivative_records = [r for r in records if r["by"] == "dn_along_axis"]
    # count total derivative order per axis
    c = Counter(itertools.chain(*[[d["axis"]] * d["order"] for d in derivative_records]))

    partial_frag = r""
    if sum(c.values()) > 1:
        partial_frag = r"^" + str(sum(c.values()))

    return (
        r"$\frac{\partial"
        + partial_frag
        + r" \textnormal{Int.}}{"
        + r"".join(
            [
                r"\partial {}^{}".format(name_for_dim(item, escaped=False), n)
                for item, n in c.items()
            ]
        )
        + "}$ (arb.)"
    )


def label_for_dim(data=None, dim_name=None, escaped=True):
    """Generates a fancy label (LaTeX, if available) for a dimension according to standard conventions.

    Known dimensions use a fixed table. Any other name has underscores replaced by
    spaces and is titlecased. ``data`` and ``escaped`` are currently unused.
    """
    if SETTINGS.get("use_tex", False):
        raw_dim_names = {
            "temperature": "Temperature",
            "theta": r"$\theta$",
            "beta": r"$\beta$",
            "chi": r"$\chi$",
            "alpha": r"$\alpha$",
            "psi": r"$\psi$",
            "phi": r"$\varphi$",
            "eV": r"Binding Energy (eV)",
            "angle": r"Interp. Angle",
            "kinetic": r"Kinetic Energy (eV)",
            "temp": r"Temperature",
            "kp": r"$k_\parallel$",
            "kx": r"$k_\text{x}$",
            "ky": r"$k_\text{y}$",
            "kz": r"$k_\perp$",
            "hv": "Photon Energy",
            "x": "X (mm)",
            "y": "Y (mm)",
            "z": "Z (mm)",
            "spectrum": "Intensity (arb.)",
        }
    else:
        raw_dim_names = {
            "temperature": "Temperature",
            "beta": "β",
            "theta": "θ",
            "chi": "χ",
            "alpha": "α",
            "psi": "ψ",
            "phi": "φ",
            "eV": "Binding Energy (eV)",
            "angle": "Interp. Angle",
            "kinetic": "Kinetic Energy (eV)",
            "temp": "Temperature (K)",
            "kp": "Kp",
            "kx": "Kx",
            "ky": "Ky",
            "kz": "Kz",
            "hv": "Photon Energy (eV)",
            "x": "X (mm)",
            "y": "Y (mm)",
            "z": "Z (mm)",
            "spectrum": "Intensity (arb.)",
        }

    if dim_name in raw_dim_names:
        return raw_dim_names.get(dim_name)

    try:
        from titlecase import titlecase
    except ImportError:
        warnings.warn("Using alternative titlecase, for better results `pip install titlecase`.")

        def titlecase(s: str) -> str:
            """Poor man's titlecase.

            Args:
                s: The input string

            Returns:
                The titlecased string.
            """
            return s.title()

    result = titlecase(dim_name.replace("_", " "))
    return result
def fancy_labels(ax_or_ax_set, data=None):
    """Attaches better display axis labels for all axes.

    Axes are determined by those that can be traversed in the passed figure or axes.
    Each axis's current x and y labels are treated as dimension names and replaced
    with the result of ``label_for_dim``. A label that cannot be generated triggers
    a warning and is left as is.

    Args:
        ax_or_ax_set: An axes, or a (possibly nested) list/tuple/set/ndarray of axes
        data: The source data, used to calculate names, typically you can leave this empty
    """
    if isinstance(ax_or_ax_set, (list, tuple, set, np.ndarray)):
        for ax in ax_or_ax_set:
            # forward ``data`` so nested axes are labeled the same way as top-level ones
            fancy_labels(ax, data=data)
        return

    ax = ax_or_ax_set

    for get_label, set_label in (
        (ax.get_xlabel, ax.set_xlabel),
        (ax.get_ylabel, ax.set_ylabel),
    ):
        try:
            set_label(label_for_dim(data=data, dim_name=get_label()))
        except Exception as exc:  # labeling is cosmetic; don't abort plotting over it
            warnings.warn(f"Could not generate a fancy axis label: {exc}")
def label_for_symmetry_point(point_name: str) -> str:
    """Determines the LaTeX label for a symmetry point shortcode.

    Unknown shortcodes are returned unchanged.
    """
    if SETTINGS["use_tex"]:
        proper_names = {
            "G": r"$\Gamma$",
            "X": r"X",
            "Y": r"Y",
        }
    else:
        proper_names = {
            "G": r"Γ",
            "X": r"X",
            "Y": r"Y",
        }

    return proper_names.get(point_name, point_name)


class CoincidentLinesPlot:
    """Helper to allow drawing lines at the same location.

    Will draw n lines offset so that their center appears at the data center, and the lines will
    end up nonoverlapping.

    Only works for straight lines.

    Technique adapted from `StackOverflow <https://stackoverflow.com/questions/19394505/matplotlib-expand-the-line-with-specified-width-in-data-unit>`_.
    """

    linewidth = 3

    def __init__(self, **kwargs):
        """Set up the helper and connect canvas events.

        Args:
            **kwargs: ``ax`` selects the axes (default: current axes); ``fig`` the
                figure (default: the figure of ``ax``). The remaining keywords are
                stored as ``extra_kwargs``.
        """
        # Only fall back to pyplot state when needed; ``kwargs.pop("ax", plt.gca())``
        # evaluated ``plt.gca()`` eagerly and could create a stray figure.
        ax = kwargs.pop("ax", None)
        fig = kwargs.pop("fig", None)
        self.ax = plt.gca() if ax is None else ax
        self.fig = self.ax.figure if fig is None else fig
        self.extra_kwargs = kwargs

        # points per pixel (72 points per inch, ``dpi`` pixels per inch)
        self.ppd = 72.0 / self.fig.dpi
        self.has_drawn = False

        canvas = self.ax.figure.canvas
        self.events = {
            name: canvas.mpl_connect(name, self._resize)
            for name in ("resize_event", "motion_notify_event", "button_release_event")
        }
        self.handles = []
        self.lines = []  # saved args and kwargs for plotting, does not verify coincidence

    def add_line(self, *args, **kwargs):
        """Adds an additional line into the collection to be drawn."""
        assert not self.has_drawn
        self.lines.append(
            (
                args,
                kwargs,
            )
        )

    def draw(self):
        """Draw all of the lines after offsetting them slightly.

        Lines are spread symmetrically in y around their nominal position. They are
        spaced by ``linewidth`` points converted to y data units.
        """
        self.has_drawn = True

        # Offsets are applied to the y data below, so use the y conversion factor;
        # multiplying the (x, y) tuple itself would repeat the tuple, not scale it.
        _, y_units_per_point = self.data_units_per_pixel
        offset_in_data_units = y_units_per_point * self.linewidth
        n_lines = len(self.lines)
        self.offsets = [offset_in_data_units * (o - (n_lines - 1) / 2) for o in range(n_lines)]

        for offset, (line_args, line_kwargs) in zip(self.offsets, self.lines):
            line_args = self.normalize_line_args(line_args)
            line_args[1] = np.array(line_args[1]) + offset
            handle = self.ax.plot(*line_args, **line_kwargs)
            self.handles.append(handle)

    @property
    def data_units_per_pixel(self):
        """Gets the data/pixel conversion ratio.

        Because of the ``ppd`` factor, the returned ``(x, y)`` pair is in data units
        per point.
        """
        trans = self.ax.transData.transform
        inverse = (trans((1, 1)) - trans((0, 0))) * self.ppd
        return (1 / inverse[0], 1 / inverse[1])

    def normalize_line_args(self, args):
        """Return ``args`` as a mutable ``[x, y, *rest]`` list for ``ax.plot``.

        If only y data is present, ``range(len(y))`` is inserted as the x data.
        """

        def is_data_type(value):
            # ``np.array`` is a function, not a type, so only ``np.ndarray`` belongs here.
            return isinstance(value, (np.ndarray, list, tuple))

        assert is_data_type(args[0])

        if len(args) > 1 and is_data_type(args[1]) and len(args[0]) == len(args[1]):
            # looks like we have x and y data; copy to a list so callers can assign items
            return list(args)

        # otherwise we should pad the args with the x data
        return [range(len(args[0])), *args]

    def _resize(self, event=None):
        """Canvas event callback.

        Currently a no-op: the offsets computed in ``draw`` are not refreshed when
        the view changes.
        """
        # TODO: recompute offsets and update ``self.handles`` when the pixel scale changes.
        return None
def invisible_axes(ax):
    """Make an Axes instance completely invisible: no grid, no axis, transparent face."""
    ax.grid(False)
    ax.set_axis_off()
    background = ax.patch
    background.set_alpha(0)
def no_ticks(ax):
    """Remove all ticks from both the x and the y axis of ``ax``."""
    for get_axis in (ax.get_xaxis, ax.get_yaxis):
        get_axis().set_ticks([])