Source code for arpes.plotting.stack_plot

"""Plotting routines for making the classic stacked line plots.

Think the album art for "Unknown Pleasures".
"""
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import xarray as xr
from arpes.analysis.general import rebin
from arpes.plotting.tof import scatter_with_std
from arpes.plotting.utils import (
    colorbarmaps_for_axis,
    generic_colorbarmap_for_data,
    fancy_labels,
    path_for_plot,
    label_for_dim,
)
from arpes.provenance import save_plot_provenance
from arpes.typing import DataType
from arpes.utilities import normalize_to_spectrum

# Names exported by ``from arpes.plotting.stack_plot import *``.
# NOTE: ``overlapped_stack_dispersion_plot`` is defined in this module but is
# not listed here, so it has to be imported explicitly by name.
__all__ = (
    "stack_dispersion_plot",
    "flat_stack_plot",
    "offset_scatter_plot",
)


@save_plot_provenance
def offset_scatter_plot(
    data: DataType,
    name_to_plot=None,
    stack_axis=None,
    fermi_level=True,
    cbarmap=None,
    ax=None,
    out=None,
    scale_coordinate=0.5,
    ylim=None,
    aux_errorbars=True,
    **kwargs
):
    """Makes a stack plot but which uses scatters rather than a lineplot for each curve in the set.

    Each slice along ``stack_axis`` is drawn with ``scatter_with_std``; successive
    slices are displaced along the other dimension by a fraction of that
    dimension's stride so that the curves do not sit on top of one another.

    Args:
        data: An ``xr.Dataset`` holding the variable to plot and a companion
            variable named ``<name_to_plot>_std``.
        name_to_plot: Name of the data variable to plot. If ``None`` it is
            inferred as the single data variable whose name does not contain
            ``"_std"``.
        stack_axis: Dimension to iterate over. Defaults to the first dimension
            of the plotted variable.
        fermi_level: If true and ``"eV"`` is a non-stack dimension, draw a dashed
            red horizontal line at 0 and a light shaded band from x=0 to x=0.2.
        cbarmap: A ``(colorbar_function, colormap)`` pair. If ``None`` one is
            looked up in ``colorbarmaps_for_axis`` and, failing that, built with
            ``generic_colorbarmap_for_data``; in both of those cases the colorbar
            is rendered into an inset axis.
        ax: Axes to draw on. A new figure is created if ``None``.
        out: If given, the figure is saved to ``path_for_plot(out)`` and that
            path is returned.
        scale_coordinate: Scales the per-curve displacement, which is
            ``i * stride * scale_coordinate / 10`` for the ``i``-th curve.
        ylim: Passed to ``ax.set_ylim``. Required when ``aux_errorbars`` is true
            because the auxiliary error bars are drawn at ``ylim[0]``.
        aux_errorbars: If true, draw a second set of error bars (without
            markers) for every curve along the lower y-limit.
        **kwargs: ``figsize`` and ``ticks`` are read here; the whole mapping is
            also forwarded to the colorbar function.

    Returns:
        ``path_for_plot(out)`` if ``out`` is given, otherwise ``(fig, ax)`` where
        ``fig`` is ``None`` if the caller supplied ``ax``.

    Raises:
        TypeError: If ``data`` is not an ``xr.Dataset``.
        ValueError: If the variable to plot cannot be inferred, has no ``_std``
            companion, is not two dimensional, or if ``aux_errorbars`` is
            requested without ``ylim``.
    """
    if not isinstance(data, xr.Dataset):
        raise TypeError(
            "offset_scatter_plot requires an xr.Dataset, received {}".format(type(data))
        )

    if name_to_plot is None:
        var_names = [k for k in data.data_vars if "_std" not in k]
        if len(var_names) != 1:
            raise ValueError(
                "Could not infer which variable to plot, candidates were: {}. "
                "Please pass name_to_plot explicitly.".format(var_names)
            )

        name_to_plot = var_names[0]
        if (name_to_plot + "_std") not in data.data_vars:
            raise ValueError(
                "Expected a data variable named '{}_std' with the standard deviations.".format(
                    name_to_plot
                )
            )

    plotted_dims = data.data_vars[name_to_plot].dims
    if len(plotted_dims) != 2:
        raise ValueError(
            "In order to produce a stack plot, data must be image-like. "
            "Passed data included dimensions: {}".format(plotted_dims)
        )

    # The auxiliary error bars are anchored at ylim[0], so fail early with a
    # clear message rather than part-way through plotting.
    if aux_errorbars and ylim is None:
        raise ValueError(
            "aux_errorbars=True requires ylim, pass ylim=(low, high) or aux_errorbars=False."
        )

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get("figsize", (11, 5)))

    # The inset for the colorbar is created whether or not `ax` was supplied.
    inset_ax = inset_axes(ax, width="40%", height="5%", loc="upper left")

    if stack_axis is None:
        stack_axis = plotted_dims[0]

    # A caller-provided cbarmap is assumed to manage its own colorbar.
    skip_colorbar = cbarmap is not None
    if cbarmap is None:
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(
                data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get("ticks")
            )

    cbar, cmap = cbarmap
    if not isinstance(cmap, matplotlib.colors.Colormap):
        # Do our best: the entry may be a factory producing the colormap.
        try:
            cmap = cmap()
        except Exception:
            # Might still be fine, e.g. it is already a callable mapping
            # coordinate values to colors.
            pass

    # There should be exactly two dimensions, so this is "the other one".
    other_dim = [d for d in data.dims if d != stack_axis][0]

    if "eV" in data.dims and "eV" != stack_axis and fermi_level:
        ax.axhline(0, linestyle="--", color="red")
        ax.fill_betweenx([-1e6, 1e6], 0, 0.2, color="black", alpha=0.07)

    ax.set_ylim(ylim)

    # The stride does not depend on the curve being drawn, so look it up once.
    delta = data.G.stride(generic_dim_names=False)[other_dim]

    # real plotting here
    for i, (coord, value) in enumerate(data.G.iterate_axis(stack_axis)):
        # Work on copies so that displacing the coordinate never touches `data`.
        data_for = value.copy(deep=True)
        data_for.coords[other_dim] = data_for.coords[other_dim].copy(deep=True)
        data_for.coords[other_dim].values = data_for.coords[other_dim].values.copy()
        data_for.coords[other_dim].values -= i * delta * scale_coordinate / 10

        curve_color = cmap(coord[stack_axis])
        scatter_with_std(data_for, name_to_plot, ax=ax, color=curve_color)

        if aux_errorbars:
            # Replace the values by the lower y-limit so that only the error
            # bars, sitting on the bottom of the axes, are visible.
            data_for = data_for.copy(deep=True)
            flattened = data_for.data_vars[name_to_plot].copy(deep=True)
            flattened.values = ylim[0] * np.ones(flattened.values.shape)
            data_for = data_for.assign(**{name_to_plot: flattened})
            scatter_with_std(data_for, name_to_plot, ax=ax, color=curve_color, fmt="none")

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    try:
        if inset_ax is not None and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # colorbar already rendered
        pass

    if out is not None:
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, ax


@save_plot_provenance
def flat_stack_plot(
    data: DataType,
    stack_axis=None,
    fermi_level=True,
    cbarmap=None,
    ax=None,
    mode="line",
    title=None,
    out=None,
    transpose=False,
    **kwargs
):
    """Generates a stack plot with all the lines distinguished by color rather than offset.

    Args:
        data: Input data; it is passed through ``normalize_to_spectrum`` and the
            result must have exactly two dimensions.
        stack_axis: Dimension to iterate over, one curve per coordinate value.
            Defaults to the first dimension.
        fermi_level: If true and ``"eV"`` is the non-stack dimension, draw a
            dashed red line at 0 along the energy axis.
        cbarmap: A ``(colorbar_function, colormap)`` pair. If ``None`` one is
            looked up in ``colorbarmaps_for_axis`` and, failing that, built with
            ``generic_colorbarmap_for_data``.
        ax: Axes to draw on. If ``None`` a new figure is created together with
            an inset axis for the colorbar.
        mode: ``"line"`` or ``"scatter"``.
        title: Title for the axes.
        out: If given, the figure is saved to ``path_for_plot(out)`` and that
            path is returned.
        transpose: If true, intensity is drawn along x and the coordinate along
            y. Only implemented for ``mode="line"``.
        **kwargs: ``figsize`` and ``ticks`` are consumed for figure and colorbar
            setup; every other entry is forwarded to the per-curve plotting
            call. The complete mapping is also forwarded to the colorbar
            function.

    Returns:
        ``path_for_plot(out)`` if ``out`` is given, otherwise ``(fig, ax)`` where
        ``fig`` is ``None`` if the caller supplied ``ax``.

    Raises:
        ValueError: If the data is not two dimensional or ``mode`` is unknown.
        NotImplementedError: For ``mode="scatter"`` combined with ``transpose``.
    """
    data = normalize_to_spectrum(data)
    if len(data.dims) != 2:
        raise ValueError(
            "In order to produce a stack plot, data must be image-like. "
            "Passed data included dimensions: {}".format(data.dims)
        )

    if mode not in ("line", "scatter"):
        raise ValueError("mode must be 'line' or 'scatter', received {!r}".format(mode))

    if transpose and mode == "scatter":
        raise NotImplementedError("Transposed scatter plots are not implemented.")

    fig = None
    inset_ax = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get("figsize", (7, 5)))
        inset_ax = inset_axes(ax, width="40%", height="5%", loc=1)

    if stack_axis is None:
        stack_axis = data.dims[0]

    # A caller-provided cbarmap is assumed to manage its own colorbar.
    skip_colorbar = cbarmap is not None
    if cbarmap is None:
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(
                data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get("ticks")
            )

    cbar, cmap = cbarmap

    # There should be exactly two dimensions, so this is "the other one".
    other_dim = [d for d in data.dims if d != stack_axis][0]
    other_coord = data.coords[other_dim]

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # Do our best: the entry may be a factory producing the colormap.
        try:
            cmap = cmap()
        except Exception:
            # Might still be fine, e.g. it is already a callable mapping
            # coordinate values to colors.
            pass

    if "eV" in data.dims and "eV" != stack_axis and fermi_level:
        if transpose:
            ax.axhline(0, color="red", alpha=0.8, linestyle="--", linewidth=1)
        else:
            ax.axvline(0, color="red", alpha=0.8, linestyle="--", linewidth=1)

    # `figsize` and `ticks` configure the figure and the colorbar; they are not
    # valid artist properties, so keep them away from the per-curve calls.
    plot_kwargs = {k: v for k, v in kwargs.items() if k not in ("figsize", "ticks")}

    # meat of the plotting
    for coord_dict, marginal in data.G.iterate_axis(stack_axis):
        curve_color = cmap(coord_dict[stack_axis])

        if transpose:
            ax.plot(
                marginal.values,
                marginal.coords[marginal.dims[0]].values,
                color=curve_color,
                **plot_kwargs
            )
        elif mode == "line":
            marginal.plot(ax=ax, color=curve_color, **plot_kwargs)
        else:
            ax.scatter(*marginal.G.to_arrays(), color=curve_color, **plot_kwargs)
            # ax.scatter does not label the axis for us.
            ax.set_xlabel(marginal.dims[0])

    coord_range = [other_coord.min().item(), other_coord.max().item()]
    if transpose:
        # The coordinate runs along y here, so labels and limits swap axes.
        ax.set_xlabel("Spectrum Intensity (arb).")
        ax.set_ylabel(label_for_dim(data, other_dim))
        ax.set_ylim(coord_range)
    else:
        ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
        ax.set_ylabel("Spectrum Intensity (arb).")
        ax.set_xlim(coord_range)

    ax.set_title(title, fontsize=14)

    try:
        if inset_ax is not None and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # already rendered
        pass

    if out is not None:
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, ax


@save_plot_provenance
def stack_dispersion_plot(
    data: DataType,
    stack_axis=None,
    ax=None,
    title=None,
    out=None,
    max_stacks=100,
    transpose=False,
    use_constant_correction=False,
    correction_side=None,
    color=None,
    c=None,
    label=None,
    shift=0,
    no_scatter=False,
    negate=False,
    s=1,
    scale_factor=None,
    linewidth=1,
    palette=None,
    zero_offset=False,
    uniform=False,
    **kwargs
):
    """Generates a stack plot with all the lines distinguished by offset rather than color.

    Every slice along ``stack_axis`` becomes one curve. A curve has a background
    removed, is divided by the global maximum of the data, multiplied by
    ``scale_factor`` and finally offset by its stack coordinate (or by its index
    when ``uniform`` is set).

    The background options are checked in this order:

    1. ``use_constant_correction``: subtract the first sample of the curve, or
       the last one if ``correction_side == "right"``.
    2. ``zero_offset``: subtract nothing.
    3. ``uniform``: subtract nothing and offset the curve by its index.
    4. otherwise: subtract a straight line joining the first and last sample.

    Args:
        data: Input data, passed through ``normalize_to_spectrum``.
        stack_axis: Dimension to iterate over. Defaults to the first dimension.
        ax: Axes to draw on. A new 7x7 figure is created if ``None``.
        title: Axes title. Defaults to ``"<label> Stack"`` using ``data.S.label``.
        out: If given, the figure is saved to ``path_for_plot(out)`` and that
            path is returned.
        max_stacks: If the stack coordinate has more values than this, the data
            is rebinned along ``stack_axis`` by ``ceil(len / max_stacks)``.
        transpose: Swap the plotted x and y arrays (and the axis labels).
        use_constant_correction: See the background options above.
        correction_side: ``"right"`` selects the last sample for the constant
            correction; anything else selects the first.
        color: Line/marker color, or a callable receiving the stack coordinate
            value and returning a color.
        c: Fallback for ``color``.
        label: Legend label, attached to the first curve drawn only.
        shift: The ``i``-th curve drawn has ``i * shift`` subtracted from its
            x values.
        no_scatter: Force a line plot even when the colors are not a plain
            string or tuple.
        negate: Plot the negated data values.
        s: Marker size used when scatter plotting.
        scale_factor: Multiplier applied to the normalized curves. If ``None`` it
            is derived from the coordinate range of the non-stack axis and the
            largest background-subtracted deviation.
        linewidth: Line width used when line plotting.
        palette: Colormap, or the name of one, used to color individual points
            by intensity. Takes precedence over ``color`` and ``c``.
        zero_offset: See the background options above.
        uniform: See the background options above.
        **kwargs: Forwarded to ``ax.plot`` / ``ax.scatter``.

    Returns:
        ``path_for_plot(out)`` if ``out`` is given, otherwise ``(fig, ax)`` where
        ``fig`` is ``None`` if the caller supplied ``ax``.
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        data = rebin(
            data,
            reduction={stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks))},
        )

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = "{} Stack".format(data.S.label.replace("_", " "))

    max_over_stacks = np.max(data.values)

    cvalues = data.coords[other_axis].values
    if scale_factor is None:
        # Pick a scale such that the largest excursion of any curve spans 2% of
        # the coordinate range of the non-stack axis.
        maximum_deviation = -np.inf

        for _, marginal in data.G.iterate_axis(stack_axis):
            marginal_values = -marginal.values if negate else marginal.values
            marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

            if use_constant_correction:
                true_ys = marginal_values - marginal_offset
            elif zero_offset:
                true_ys = marginal_values
            else:
                true_ys = marginal_values - np.linspace(
                    marginal_offset, right_marginal_offset, len(marginal_values)
                )

            maximum_deviation = np.maximum(maximum_deviation, np.max(np.abs(true_ys)))

        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    # Resolve the color configuration once rather than for every curve.
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    if palette and isinstance(palette, str):
        palette = plt.get_cmap(palette)
    base_colors = color or c or "black"

    iteration_order = -1  # might need to fiddle with this in certain cases
    lim = [-np.inf, np.inf]
    labeled = False
    for i, (coord_dict, marginal) in enumerate(
        list(data.G.iterate_axis(stack_axis))[::iteration_order]
    ):
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        marginal_values = -marginal.values if negate else marginal.values
        marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

        if use_constant_correction:
            offset = right_marginal_offset if correction_side == "right" else marginal_offset
            true_ys = (marginal_values - offset) / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif zero_offset:
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif uniform:
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + i
        else:
            true_ys = (
                marginal_values
                - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))
            ) / max_over_stacks
            ys = scale_factor * true_ys + coord_value

        raw_colors = base_colors
        if palette:
            # NOTE(review): true_ys has already been divided by max_over_stacks
            # above, so this normalizes twice. Kept as is to avoid changing the
            # colors of existing figures; confirm whether it is intended.
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        xs = xs - i * shift

        # Track the range common to all curves on the array which carries the
        # `other_axis` coordinate: x normally, y when transposed. This has to
        # match the axis the limits are applied to after the loop.
        coordinate_values = ys if transpose else xs
        lim = [
            max(lim[0], np.min(coordinate_values)),
            min(lim[1], np.max(coordinate_values)),
        ]

        # Only the first curve drawn carries the legend entry.
        label_for = "_nolegend_"
        if not labeled:
            labeled = True
            label_for = label

        color_for_plot = raw_colors
        if callable(color_for_plot):
            color_for_plot = color_for_plot(coord_value)

        if isinstance(raw_colors, (str, tuple)) or no_scatter:
            ax.plot(xs, ys, linewidth=linewidth, color=color_for_plot, label=label_for, **kwargs)
        else:
            ax.scatter(xs, ys, color=color_for_plot, s=s, label=label_for, **kwargs)

    x_label = other_axis
    y_label = stack_axis

    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))

    if transpose:
        ax.set_ylim(lim)
    else:
        ax.set_xlim(lim)

    ax.set_title(title)

    if out is not None:
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, ax


@save_plot_provenance
def overlapped_stack_dispersion_plot(
    data: DataType,
    stack_axis=None,
    ax=None,
    title=None,
    out=None,
    max_stacks=100,
    use_constant_correction=False,
    transpose=False,
    negate=False,
    s=1,
    scale_factor=None,
    linewidth=1,
    palette=None,
    **kwargs
):
    """Generates a stack plot of offset curves, drawn in black or colored by intensity.

    Every slice along ``stack_axis`` becomes one curve. A curve has a background
    removed (its first sample if ``use_constant_correction``, otherwise a
    straight line joining its first and last sample), is divided by the global
    maximum of the data, multiplied by ``scale_factor`` and offset by its stack
    coordinate.

    Args:
        data: Input data, passed through ``normalize_to_spectrum``.
        stack_axis: Dimension to iterate over. Defaults to the first dimension.
        ax: Axes to draw on. A new 7x7 figure is created if ``None``.
        title: Axes title. Defaults to ``"<label> Stack"`` using ``data.S.label``.
        out: If given, the figure is saved to ``path_for_plot(out)`` and that
            path is returned.
        max_stacks: If the stack coordinate has more values than this, the data
            is rebinned along ``stack_axis`` by ``ceil(len / max_stacks)``.
        use_constant_correction: Subtract the first sample of each curve instead
            of a linear background.
        transpose: Swap the plotted x and y arrays (and the axis labels).
        negate: Plot the negated data values.
        s: Marker size used when scatter plotting.
        scale_factor: Multiplier applied to the normalized curves. If ``None`` it
            is derived from the coordinate range of the non-stack axis and the
            largest background-subtracted deviation.
        linewidth: Line width used when line plotting.
        palette: Colormap, or the name of one. If given, curves are drawn as
            scatter points colored by intensity; otherwise as black lines.
        **kwargs: Forwarded to ``ax.plot`` / ``ax.scatter``.

    Returns:
        ``path_for_plot(out)`` if ``out`` is given. Otherwise ``plt.show()`` is
        called and ``(fig, ax)`` is returned, where ``fig`` is ``None`` if the
        caller supplied ``ax``.
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        data = rebin(
            data,
            reduction={stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks))},
        )

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = "{} Stack".format(data.S.label.replace("_", " "))

    max_over_stacks = np.max(data.values)

    cvalues = data.coords[other_axis].values
    if scale_factor is None:
        # Pick a scale such that the largest excursion of any curve spans 2% of
        # the coordinate range of the non-stack axis.
        maximum_deviation = -np.inf

        for _, marginal in data.G.iterate_axis(stack_axis):
            marginal_values = -marginal.values if negate else marginal.values
            marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

            if use_constant_correction:
                true_ys = marginal_values - marginal_offset
            else:
                true_ys = marginal_values - np.linspace(
                    marginal_offset, right_marginal_offset, len(marginal_values)
                )

            maximum_deviation = np.maximum(maximum_deviation, np.max(np.abs(true_ys)))

        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    # Resolve a named colormap once. plt.get_cmap is used because
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    if palette and isinstance(palette, str):
        palette = plt.get_cmap(palette)

    iteration_order = -1  # might need to fiddle with this in certain cases
    for coord_dict, marginal in list(data.G.iterate_axis(stack_axis))[::iteration_order]:
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        marginal_values = -marginal.values if negate else marginal.values
        marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

        if use_constant_correction:
            true_ys = (marginal_values - marginal_offset) / max_over_stacks
        else:
            true_ys = (
                marginal_values
                - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))
            ) / max_over_stacks

        ys = scale_factor * true_ys + coord_value

        raw_colors = "black"
        if palette:
            # NOTE(review): true_ys has already been divided by max_over_stacks
            # above, so this normalizes twice. Kept as is to avoid changing the
            # colors of existing figures; confirm whether it is intended.
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        # Draw on `ax` explicitly: the pyplot-level functions target whatever
        # the current axes happen to be, which need not be the one passed in.
        if isinstance(raw_colors, str):
            ax.plot(xs, ys, linewidth=linewidth, color=raw_colors, **kwargs)
        else:
            ax.scatter(xs, ys, color=raw_colors, s=s, **kwargs)

    x_label = other_axis
    y_label = stack_axis

    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))

    ax.set_title(title)

    if out is not None:
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    plt.show()

    return fig, ax