Source code for arpes.plotting.spin

"""Some general plotting routines for presentation of spin-ARPES data."""
import matplotlib.cm as cm
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import xarray as xr
from arpes.analysis.sarpes import to_intensity_polarization
from arpes.analysis.statistics import mean_and_deviation
from arpes.bootstrap import bootstrap
from arpes.plotting.tof import scatter_with_std
from arpes.plotting.utils import label_for_dim, savefig, path_for_plot, polarization_colorbar
from arpes.provenance import save_plot_provenance
from arpes.utilities.math import polarization, propagate_statistical_error


__all__ = (
    "spin_polarized_spectrum",
    "spin_colored_spectrum",
    "spin_difference_spectrum",
)


test_polarization = propagate_statistical_error(polarization)


[docs]@save_plot_provenance def spin_colored_spectrum(spin_dr, title=None, ax=None, out=None, scatter=False, **kwargs): """Plots a spin spectrum using total intensity and assigning color with the spin polarization.""" if ax is None: _, ax = plt.subplots(figsize=(6, 4)) as_intensity = to_intensity_polarization(spin_dr) intensity = as_intensity.intensity pol = as_intensity.polarization.copy(deep=True) if len(intensity.dims) == 1: inset_ax = inset_axes(ax, width="30%", height="5%", loc=1) coord = intensity.coords[intensity.dims[0]] points = np.array([coord.values, intensity.values]).G.reshape(-1, 1, 2) pol.values[np.isnan(pol.values)] = 0 pol.values[pol.values > 1] = 1 pol.values[pol.values < -1] = -1 pol_colors = cm.get_cmap("RdBu")(pol.values[:-1]) if scatter: pol_colors = cm.get_cmap("RdBu")(pol.values) ax.scatter(coord.values, intensity.values, c=pol_colors, s=1.5) else: segments = np.concatenate([points[:-1], points[1:]], axis=1) lc = LineCollection(segments, colors=pol_colors) ax.add_collection(lc) ax.set_xlim(coord.min().item(), coord.max().item()) ax.set_ylim(0, intensity.max().item() * 1.15) ax.set_ylabel("ARPES Spectrum Intensity (arb.)") ax.set_xlabel(label_for_dim(spin_dr, dim_name=intensity.dims[0])) ax.set_title(title if title is not None else "Spin Polarization") polarization_colorbar(inset_ax) if out is not None: savefig(out, dpi=400) plt.clf() return path_for_plot(out) else: plt.show()
[docs]@save_plot_provenance def spin_difference_spectrum(spin_dr, title=None, ax=None, out=None, scatter=False, **kwargs): """Plots a spin difference spectrum.""" if ax is None: _, ax = plt.subplots(figsize=(6, 4)) try: as_intensity = to_intensity_polarization(spin_dr) except AssertionError: as_intensity = spin_dr intensity = as_intensity.intensity pol = as_intensity.polarization.copy(deep=True) if len(intensity.dims) == 1: inset_ax = inset_axes(ax, width="30%", height="5%", loc=1) coord = intensity.coords[intensity.dims[0]] points = np.array([coord.values, intensity.values]).G.reshape(-1, 1, 2) pol.values[np.isnan(pol.values)] = 0 pol.values[pol.values > 1] = 1 pol.values[pol.values < -1] = -1 pol_colors = cm.get_cmap("RdBu")(pol.values[:-1]) if scatter: pol_colors = cm.get_cmap("RdBu")(pol.values) ax.scatter(coord.values, intensity.values, c=pol_colors, s=1.5) else: segments = np.concatenate([points[:-1], points[1:]], axis=1) lc = LineCollection(segments, colors=pol_colors) ax.add_collection(lc) ax.set_xlim(coord.min().item(), coord.max().item()) ax.set_ylim(0, intensity.max().item() * 1.15) ax.set_ylabel("ARPES Spectrum Intensity (arb.)") ax.set_xlabel(label_for_dim(spin_dr, dim_name=intensity.dims[0])) ax.set_title(title if title is not None else "Spin Polarization") polarization_colorbar(inset_ax) if out is not None: savefig(out, dpi=400) plt.clf() return path_for_plot(out) else: plt.show()
[docs]@save_plot_provenance def spin_polarized_spectrum( spin_dr, title=None, ax=None, out=None, component="y", scatter=False, stats=False, norm=None ): """Plots a simple spin polarized spectrum using curves for the up and down components.""" if ax is None: _, ax = plt.subplots(2, 1, sharex=True) if stats: spin_dr = bootstrap(lambda x: x)(spin_dr, N=100) pol = mean_and_deviation(to_intensity_polarization(spin_dr)) counts = mean_and_deviation(spin_dr) else: counts = spin_dr pol = to_intensity_polarization(counts) ax_left = ax[0] ax_right = ax[1] up = counts.down.data down = counts.up.data energies = spin_dr.coords["eV"].values min_e, max_e = np.min(energies), np.max(energies) # Plot the spectra if stats: if scatter: scatter_with_std(counts, "up", color="red", ax=ax_left) scatter_with_std(counts, "down", color="blue", ax=ax_left) else: v, s = counts.up.values, counts.up_std.values ax_left.plot(energies, v, "r") ax_left.fill_between(energies, v - s, v + s, color="r", alpha=0.25) v, s = counts.down.values, counts.down_std.values ax_left.plot(energies, v, "b") ax_left.fill_between(energies, v - s, v + s, color="b", alpha=0.25) else: ax_left.plot(energies, up, "r") ax_left.plot(energies, down, "b") ax_left.set_title(title if title is not None else "Spin spectrum {}".format("")) ax_left.set_ylabel(r"\textbf{Spectrum Intensity}") ax_left.set_xlabel(r"\textbf{Kinetic energy} (eV)") ax_left.set_xlim(min_e, max_e) max_up = np.max(up) max_down = np.max(down) ax_left.set_ylim(0, max(max_down, max_up) * 1.2) # Plot the polarization and associated statistical error bars if stats: if scatter: scatter_with_std(pol, "polarization", ax=ax_right, color="black") else: v = pol.polarization.data s = pol.polarization_std.data ax_right.plot(energies, v, color="black") ax_right.fill_between(energies, v - s, v + s, color="black", alpha=0.25) else: ax_right.plot(energies, pol.polarization.data, color="black") ax_right.fill_between(energies, 0, 1, facecolor="blue", alpha=0.1) ax_right.fill_between(energies, -1, 0, facecolor="red", alpha=0.1) ax_right.set_title("Spin polarization, $\\text{S}_\\textbf{" + component + "}$") ax_right.set_ylabel(r"\textbf{Polarization}") ax_right.set_xlabel(r"\textbf{Kinetic Energy} (eV)") ax_right.set_xlim(min_e, max_e) ax_right.axhline(0, color="white", linestyle=":") ax_right.set_ylim(-1, 1) ax_right.grid(True, axis="y") plt.tight_layout() if out is not None: savefig(out, dpi=400) plt.clf() return path_for_plot(out) else: pass return ax
def polarization_intensity_to_color(data: xr.Dataset, vmax=None, pmax=1): """Converts a dataset with intensity and polarization into a RGB colorarray. This consists of a few steps: 1. first we take the polarization to get a RdBu RGB value 2. We convert the RGB value to HSV 3. We use the relative intensity to compute a new value for the V ('value') channel 4. We convert back to RGB Args: data: The input intensity/data to convert to a color representation. Returns: The rgb color data. """ if vmax is None: # use the 98th percentile data if not provided vmax = np.percentile(data.intensity.values, 98) rgbas = cm.RdBu((data.polarization.values / pmax + 1) / 2) slices = [slice(None) for _ in data.polarization.dims] + [slice(0, 3)] rgbs = rgbas[slices] hsvs = matplotlib.colors.rgb_to_hsv(rgbs) intensity_values = data.intensity.values.copy() / vmax intensity_values[intensity_values > 1] = 1 hsvs[:, :, 2] = intensity_values return matplotlib.colors.hsv_to_rgb(hsvs) @save_plot_provenance def hue_brightness_plot(data: xr.Dataset, ax=None, out=None, **kwargs): assert "intensity" in data and "polarization" in data fig = None if ax is None: fig, ax = plt.subplots( figsize=kwargs.get( "figsize", ( 7, 5, ), ) ) x, y = data.coords[data.intensity.dims[0]].values, data.coords[data.intensity.dims[1]].values extent = [y[0], y[-1], x[0], x[-1]] ax.imshow( polarization_intensity_to_color(data, **kwargs), extent=extent, aspect="auto", origin="lower", ) ax.set_xlabel(data.intensity.dims[1]) ax.set_ylabel(data.intensity.dims[0]) ax.grid(False) if out is not None: plt.savefig(path_for_plot(out), dpi=400) return path_for_plot(out) return fig, ax