"""Establishes the PyARPES data model by extending the `xarray` types.
This is another core part of PyARPES. It provides a lot of extensions to
what comes out of the box in xarray. Some of these are useful generics,
generally on the `.G` extension; others collect and manipulate metadata,
interface with plotting routines, provide functional programming utilities,
etc.
If `f` is an ARPES spectrum, then `f.S` should provide a nice representation of your data
in a Jupyter cell. This is a complement to the text-based approach offered by merely
printing `f`. Note that as of PyARPES v3.x.y, the xarray version has been bumped and this
representation is no longer necessary as one is provided upstream.
The main accessors are `.S`, `.G`, `.X`, and `.F`.
The `.S` accessor:
The `.S` accessor contains functionality related to spectroscopy. Utilities
which only make sense in this context should be placed here, while more generic
tools should be placed elsewhere.
The `.G` accessor:
This is a general-purpose collection of tools which exists to provide conveniences
over what already exists in the xarray data model. As an example, there are
various tools for simultaneous iteration of data and coordinates here, as well as
for vectorized application of functions to data or coordinates.
The `.X` accessor:
This is an accessor which contains tools related to selecting and subselecting
data. The two most notable tools here are `.X.first_exceeding` which is very useful
for initializing curve fits and `.X.max_in_window` which is useful for refining
these initial parameter choices.
The `.F` accessor:
This is an accessor which contains tools related to interpreting curve fitting
results. In particular there are utilities for vectorized extraction of parameters,
for plotting several curve fits, or for selecting "worst" or "best" fits according
to some measure.
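
Example:
    A quick sketch of the accessors in use, assuming a hypothetical spectrum `f`
    loaded elsewhere:

    >>> f.S.spectrum_type                             # doctest: +SKIP
    'map'
    >>> f.S.fermi_surface.plot()                      # doctest: +SKIP
    >>> for coords, cut in f.G.iterate_axis("beta"):  # doctest: +SKIP
    ...     pass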
"""
import pandas as pd
import lmfit
import arpes
import contextlib
import collections
import copy
import itertools
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy import ndimage as ndi
import arpes.constants
import arpes.plotting as plotting
from arpes.analysis.general import rebin
from arpes.analysis.band_analysis_utils import param_getter, param_stderr_getter
from arpes.models.band import MultifitBand
from arpes.plotting.utils import fancy_labels, remove_colorbars
from arpes.plotting.parameter import plot_parameter
from arpes.typing import DataType, DTypeLike
from arpes.utilities import apply_dataarray
from arpes.utilities.collections import MappableDict
from arpes.utilities.conversion import slice_along_path
import arpes.utilities.math
from arpes.utilities.region import DesignatedRegions, normalize_region
from arpes.utilities.xarray import unwrap_xarray_item, unwrap_xarray_dict
__all__ = ["ARPESDataArrayAccessor", "ARPESDatasetAccessor", "ARPESFitToolsAccessor"]
def _iter_groups(grouped: Dict[str, Any]) -> Iterator[Any]:
"""Iterates through a flattened sequence.
Sequentially yields keys and values from each sequence associated with a key.
    If a key "k" is associated with a value "v0" which is not iterable, it will be emitted as a
single pair [k, v0].
Otherwise, one pair is yielded for every item in the associated value.
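
    Example:
        >>> list(_iter_groups({"a": 1, "b": [2, 3]}))
        [('a', 1), ('b', 2), ('b', 3)]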
"""
for k, value_or_list in grouped.items():
try:
for list_item in value_or_list:
yield k, list_item
except TypeError:
yield k, value_or_list
class ARPESAccessorBase:
"""Base class for the xarray extensions in PyARPES."""
def along(self, directions, **kwargs):
return slice_along_path(self._obj, directions, **kwargs)
def find(self, name):
return [n for n in dir(self) if name in n]
@property
def sherman_function(self):
for option in ["sherman", "sherman_function", "SHERMAN"]:
if option in self._obj.attrs:
return self._obj.attrs[option]
raise ValueError("No Sherman function could be found on the data. Is this a spin dataset?")
@property
def experimental_conditions(self):
try:
temp = self.temp
except AttributeError:
temp = None
return {
"hv": self.hv,
"polarization": self.polarization,
"temp": temp,
}
@property
def polarization(self):
if "epu_pol" in self._obj.attrs:
# merlin: TODO normalize these
# check and complete
try:
return {
0: "p",
1: "rc",
2: "s",
}.get(int(self._obj.attrs["epu_pol"]))
except ValueError:
return self._obj.attrs["epu_pol"]
return None
@property
def is_subtracted(self):
if self._obj.attrs.get("subtracted"):
return True
if isinstance(self._obj, xr.DataArray):
# if at least 5% of the values are < 0 we should consider the data
# to be best represented by a coolwarm map
return (((self._obj < 0) * 1).mean() > 0.05).item()
@property
def is_spatial(self) -> bool:
"""Infers whether a given scan has real-space dimensions (SPEM or u/nARPES).
Returns:
            True if the data is explicitly a "ucut" or "spem" or contains "X", "Y", or "Z" dimensions.
False otherwise.
"""
if self.spectrum_type in {"ucut", "spem"}:
return True
return any(d in {"X", "Y", "Z"} for d in self._obj.dims)
@property
def is_kspace(self) -> bool:
"""Infers whether the scan is k-space converted or not.
Because of the way this is defined, it will return
true for XPS spectra, which I suppose is true but trivially.
Returns:
True if the data is k-space converted.
False otherwise.
"""
return not any(d in {"phi", "theta", "beta", "angle"} for d in self._obj.dims)
@property
def is_slit_vertical(self) -> bool:
"""Whether the data was taken on an analyzer with vertical slit.
Caveat emptor: this assumes that the alpha coordinate is not some intermediate value.
Returns:
True if the alpha value is consistent with a vertical slit analyzer.
False otherwise.
"""
return np.abs(self.lookup_offset_coord("alpha") - np.pi / 2) < (np.pi / 180)
@property
def endstation(self) -> str:
"""Alias for the location attribute used to load the data.
Returns:
The name of loader/location which was used to load data.
"""
return self._obj.attrs["location"]
def with_values(self, new_values: np.ndarray) -> xr.DataArray:
"""Copy with new array values.
Easy way of creating a DataArray that has the same shape as the calling object but data populated
from the array `new_values`
Args:
new_values: The new values which should be used for the data.
Returns:
A copy of the data with new values but identical dimensions, coordinates, and attrs.
"""
return xr.DataArray(
new_values.reshape(self._obj.values.shape),
coords=self._obj.coords,
dims=self._obj.dims,
attrs=self._obj.attrs,
)
def with_standard_coords(self):
obj = self._obj
collected_renamings = {}
coord_names = {
"pixel",
"eV",
"phi",
}
for coord_name in coord_names:
clarified = [name for name in obj.coords.keys() if (coord_name + "-") in name]
assert len(clarified) < 2
if clarified:
collected_renamings[clarified[0]] = coord_name
return obj.rename(collected_renamings)
@property
def logical_offsets(self):
if "long_x" not in self._obj.coords:
raise ValueError(
"Logical offsets can currently only be "
"accessed for hierarchical motor systems like nanoARPES."
)
return MappableDict(
unwrap_xarray_dict(
{
"x": self._obj.coords["long_x"] - self._obj.coords["physical_long_x"],
"y": self._obj.coords["long_y"] - self._obj.coords["physical_long_y"],
"z": self._obj.coords["long_z"] - self._obj.coords["physical_long_z"],
}
)
)
@property
def hv(self):
if "hv" in self._obj.coords:
value = float(self._obj.coords["hv"])
if not np.isnan(value):
return value
if "hv" in self._obj.attrs:
value = float(self._obj.attrs["hv"])
if not np.isnan(value):
return value
if "location" in self._obj.attrs:
if self._obj.attrs["location"] == "ALG-MC":
return 5.93
return None
def fetch_ref_attrs(self):
if "ref_attrs" in self._obj.attrs:
            return self._obj.attrs["ref_attrs"]
raise NotImplementedError
@property
def scan_type(self):
return self._obj.attrs.get("daq_type")
@property
def spectrum_type(self):
if "spectrum_type" in self._obj.attrs and self._obj.attrs["spectrum_type"]:
return self._obj.attrs["spectrum_type"]
dim_types = {
("eV",): "xps_spectrum",
("eV", "phi"): "spectrum",
            # TODO: should this check whether the other angular axis is perpendicular to the scan axis?
            # NB: keys are written in sorted order since we look them up with sorted dims below
            ("beta", "eV", "phi"): "map",
("eV", "phi", "theta"): "map",
("eV", "hv", "phi"): "hv_map",
# kspace
("eV", "kp"): "spectrum",
("eV", "kx", "ky"): "map",
("eV", "kp", "kz"): "hv_map",
}
dims = tuple(sorted(list(self._obj.dims)))
return dim_types.get(dims)
@property
def is_differentiated(self):
history = self.short_history()
return "dn_along_axis" in history or "curvature" in history
def transpose_to_front(self, dim):
dims = list(self._obj.dims)
assert dim in dims
dims.remove(dim)
return self._obj.transpose(*([dim] + dims))
def transpose_to_back(self, dim):
dims = list(self._obj.dims)
assert dim in dims
dims.remove(dim)
return self._obj.transpose(*(dims + [dim]))
def select_around_data(
self,
points: Union[Dict[str, Any], xr.Dataset],
radius: Dict[str, float] = None,
fast: bool = False,
safe: bool = True,
mode: str = "sum",
**kwargs,
):
"""Performs a binned selection around a point or points.
Can be used to perform a selection along one axis as a function of another, integrating a region
in the other dimensions.
Example:
            As an example, suppose we have a dataset with dimensions ('eV', 'kp', 'T',)
            and that by fitting we have determined the Fermi momentum as a function of T, kp_F('T'),
            stored in the dataarray kFs. Then we could select momentum-integrated EDCs in a small
            window around the Fermi momentum for each temperature by using
>>> edcs_at_fermi_momentum = full_data.S.select_around_data({'kp': kFs}, radius={'kp': 0.04}, fast=True) # doctest: +SKIP
The resulting data will be EDCs for each T, in a region of radius 0.04 inverse angstroms around the
Fermi momentum.
Args:
points: The set of points where the selection should be performed.
radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized
selection will be made as a compromise.
            fast: If true, uses a rectangular rather than a circular region for selection.
safe: If true, infills radii with default values. Defaults to `True`.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum"
kwargs: Can be used to pass radii parameters by keyword with `_r` postfix.
Returns:
The binned selection around the desired point or points.
"""
if isinstance(self._obj, xr.Dataset):
raise TypeError("Cannot use select_around on Datasets only DataArrays!")
if mode not in {"sum", "mean"}:
raise ValueError("mode parameter should be either sum or mean.")
if isinstance(
points,
(
tuple,
list,
),
):
warnings.warn("Dangerous iterable points argument to `select_around`")
points = dict(zip(points, self._obj.dims))
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
default_radii = {
"kp": 0.02,
"kz": 0.05,
"phi": 0.02,
"beta": 0.02,
"theta": 0.02,
"eV": 0.05,
"delay": 0.2,
"T": 2,
"temp": 2,
}
unspecified = 0.1
if isinstance(radius, float):
radius = {d: radius for d in points.keys()}
else:
collected_terms = set("{}_r".format(k) for k in points.keys()).intersection(
set(kwargs.keys())
)
if collected_terms:
radius = {
d: kwargs.get("{}_r".format(d), default_radii.get(d, unspecified))
for d in points.keys()
}
elif radius is None:
radius = {d: default_radii.get(d, unspecified) for d in points.keys()}
assert isinstance(radius, dict)
radius = {d: radius.get(d, default_radii.get(d, unspecified)) for d in points.keys()}
along_dims = list(points.values())[0].dims
selected_dims = list(points.keys())
stride = self._obj.G.stride(generic_dim_names=False)
new_dim_order = [d for d in self._obj.dims if d not in along_dims] + list(along_dims)
data_for = self._obj.transpose(*new_dim_order)
new_data = data_for.sum(selected_dims, keep_attrs=True)
for coord, value in data_for.G.iterate_axis(along_dims):
nearest_sel_params = {}
if safe:
for d, v in radius.items():
if v < stride[d]:
nearest_sel_params[d] = points[d].sel(**coord)
radius = {d: v for d, v in radius.items() if d not in nearest_sel_params}
if fast:
selection_slices = {
d: slice(
points[d].sel(**coord) - radius[d],
points[d].sel(**coord) + radius[d],
)
for d in points.keys()
if d in radius
}
selected = value.sel(**selection_slices)
else:
raise NotImplementedError
if nearest_sel_params:
selected = selected.sel(**nearest_sel_params, method="nearest")
for d in nearest_sel_params:
# need to remove the extra dims from coords
del selected.coords[d]
if mode == "sum":
new_data.loc[coord] = selected.sum(list(radius.keys())).values
elif mode == "mean":
new_data.loc[coord] = selected.mean(list(radius.keys())).values
return new_data
def select_around(
self,
point: Union[Dict[str, Any], xr.Dataset],
radius: Dict[str, float] = None,
fast: bool = False,
safe: bool = True,
mode: str = "sum",
**kwargs,
) -> xr.DataArray:
"""Selects and integrates a region around a one dimensional point.
This method is useful to do a small region integration, especially around
points on a path of a k-point of interest. See also the companion method `select_around_data`.
        If the fast flag is set, we will sum over rectangular regions (the Chebyshev norm)
        rather than ellipsoids, as this is less costly.
If radii are not set, or provided through kwargs as 'eV_r' or 'phi_r' for instance,
then we will try to use reasonable default values; buyer beware.
Args:
point: The points where the selection should be performed.
radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized
selection will be made as a compromise.
            fast: If true, uses a rectangular rather than a circular region for selection.
safe: If true, infills radii with default values. Defaults to `True`.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum"
kwargs: Can be used to pass radii parameters by keyword with `_r` postfix.
Returns:
The binned selection around the desired point or points.
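
        Example:
            A minimal sketch, assuming a hypothetical cut `f` with 'eV' and 'phi' dimensions:

            >>> f.S.select_around({"eV": 0.0, "phi": 0.0}, radius={"eV": 0.05, "phi": 0.02}, fast=True)  # doctest: +SKIP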
"""
if isinstance(self._obj, xr.Dataset):
            raise TypeError("Cannot use select_around on Datasets, only DataArrays!")
if mode not in {"sum", "mean"}:
raise ValueError("mode parameter should be either sum or mean.")
if isinstance(
point,
(
tuple,
list,
),
):
warnings.warn("Dangerous iterable point argument to `select_around`")
            point = dict(zip(self._obj.dims, point))
if isinstance(point, xr.Dataset):
point = {k: point[k].item() for k in point.data_vars}
default_radii = {
"kp": 0.02,
"kz": 0.05,
"phi": 0.02,
"beta": 0.02,
"theta": 0.02,
"eV": 0.05,
"delay": 0.2,
"T": 2,
}
unspecified = 0.1
if isinstance(radius, float):
radius = {d: radius for d in point.keys()}
else:
collected_terms = set("{}_r".format(k) for k in point.keys()).intersection(
set(kwargs.keys())
)
if collected_terms:
radius = {
d: kwargs.get("{}_r".format(d), default_radii.get(d, unspecified))
for d in point.keys()
}
elif radius is None:
radius = {d: default_radii.get(d, unspecified) for d in point.keys()}
assert isinstance(radius, dict)
radius = {d: radius.get(d, default_radii.get(d, unspecified)) for d in point.keys()}
# make sure we are taking at least one pixel along each
nearest_sel_params = {}
if safe:
stride = self._obj.G.stride(generic_dim_names=False)
for d, v in radius.items():
if v < stride[d]:
nearest_sel_params[d] = point[d]
radius = {d: v for d, v in radius.items() if d not in nearest_sel_params}
if fast:
selection_slices = {
d: slice(point[d] - radius[d], point[d] + radius[d])
for d in point.keys()
if d in radius
}
selected = self._obj.sel(**selection_slices)
else:
# selected = self._obj
raise NotImplementedError
if nearest_sel_params:
selected = selected.sel(**nearest_sel_params, method="nearest")
for d in nearest_sel_params:
# need to remove the extra dims from coords
del selected.coords[d]
if mode == "sum":
return selected.sum(list(radius.keys()))
elif mode == "mean":
return selected.mean(list(radius.keys()))
def short_history(self, key="by"):
return [h["record"][key] if isinstance(h, dict) else h for h in self.history]
def _calculate_symmetry_points(
self,
symmetry_points,
projection_distance=0.05,
include_projected=True,
epsilon=0.01,
):
# For each symmetry point, we need to determine if it is projected or not
# if it is projected, we need to calculate its projected coordinates
points = collections.defaultdict(list)
projected_points = collections.defaultdict(list)
fixed_coords = {k: v for k, v in self._obj.coords.items() if k not in self._obj.indexes}
index_coords = self._obj.indexes
for point, locations in symmetry_points.items():
if not isinstance(locations, list):
locations = [locations]
for location in locations:
# determine whether the location needs to be projected
projected = False
skip = False
for axis, value in location.items():
if axis in fixed_coords and np.abs(value - fixed_coords[axis]) > epsilon:
projected = True
if axis not in fixed_coords and axis not in index_coords:
# cannot even hope to do anything here, we don't have enough info
skip = True
if skip:
continue
new_location = location.copy()
if projected:
# Go and do the projection, for now we will assume we just get it by
# replacing the value of the mismatched coordinates.
# This does not work if the coordinate system is not orthogonal
                    for axis, v in location.items():
                        if axis in fixed_coords:
                            new_location[axis] = fixed_coords[axis]
                    projected_points[point].append(new_location)
else:
points[point].append(location)
return points, projected_points
def symmetry_points(self, raw=False, **kwargs):
try:
symmetry_points = self.fetch_ref_attrs().get("symmetry_points", {})
        except Exception:
symmetry_points = {}
our_symmetry_points = self._obj.attrs.get("symmetry_points", {})
symmetry_points = copy.deepcopy(symmetry_points)
symmetry_points.update(our_symmetry_points)
if raw:
return symmetry_points
return self._calculate_symmetry_points(symmetry_points, **kwargs)
@property
def iter_own_symmetry_points(self):
sym_points, _ = self.symmetry_points()
return _iter_groups(sym_points)
@property
def iter_projected_symmetry_points(self):
_, sym_points = self.symmetry_points()
return _iter_groups(sym_points)
@property
def iter_symmetry_points(self):
for sym_point in self.iter_own_symmetry_points:
yield sym_point
for sym_point in self.iter_projected_symmetry_points:
yield sym_point
@property
def history(self):
provenance = self._obj.attrs.get("provenance", None)
def unlayer(prov):
if prov is None:
return [], None
if isinstance(prov, str):
return [prov], None
first_layer = copy.copy(prov)
rest = first_layer.pop("parents_provenance", None)
            if rest is None:
                # fall back to a historically misspelled version of this key
                rest = first_layer.pop("parents_provanence", None)
if isinstance(rest, list):
warnings.warn(
"Encountered multiple parents in history extraction, "
"throwing away all but the first."
)
if rest:
rest = rest[0]
else:
rest = None
return [first_layer], rest
def _unwrap_provenance(prov):
if prov is None:
return []
first, rest = unlayer(prov)
return first + _unwrap_provenance(rest)
return _unwrap_provenance(provenance)
@property
def spectrometer(self):
ds = self._obj
spectrometers = {
"SToF": arpes.constants.SPECTROMETER_SPIN_TOF,
"ToF": arpes.constants.SPECTROMETER_STRAIGHT_TOF,
"DLD": arpes.constants.SPECTROMETER_DLD,
"BL7": arpes.constants.SPECTROMETER_BL7,
"ANTARES": arpes.constants.SPECTROMETER_ANTARES,
}
if "spectrometer_name" in ds.attrs:
return spectrometers.get(ds.attrs["spectrometer_name"])
if isinstance(ds, xr.Dataset):
if "up" in ds.data_vars or ds.attrs.get("18 MCP3") == 0:
return spectrometers["SToF"]
elif isinstance(ds, xr.DataArray):
if ds.name == "up" or ds.attrs.get("18 MCP3") == 0:
return spectrometers["SToF"]
if "location" in ds.attrs:
return {
"ALG-MC": arpes.constants.SPECTROMETER_MC,
"BL403": arpes.constants.SPECTROMETER_BL4,
"ALG-SToF": arpes.constants.SPECTROMETER_STRAIGHT_TOF,
"Kaindl": arpes.constants.SPECTROMETER_KAINDL,
"BL7": arpes.constants.SPECTROMETER_BL7,
"ANTARES": arpes.constants.SPECTROMETER_ANTARES,
}.get(ds.attrs["location"])
try:
return spectrometers[ds.attrs["spectrometer_name"]]
except KeyError:
return {}
@property
def dshape(self):
arr = self._obj
return dict(zip(arr.dims, arr.shape))
@property
def original_id(self):
history = self.history
if len(history) >= 3:
first_modification = history[-3]
return first_modification["parent_id"]
return self._obj.attrs["id"]
@property
def original_parent_scan_name(self):
try:
history = self.history
if len(history) >= 3:
first_modification = history[-3]
df = self._obj.attrs["df"]
return df[df.id == first_modification["parent_id"]].index[0]
        except Exception:
pass
return ""
@property
def scan_row(self):
df = self._obj.attrs["df"]
sdf = df[df.path == self._obj.attrs["file"]]
return list(sdf.iterrows())[0]
@property
def df_index(self):
return self.scan_row[0]
@property
def df_after(self):
return self._obj.attrs["df"][self._obj.attrs["df"].index > self.df_index]
def df_until_type(self, df=None, spectrum_type=None):
if df is None:
df = self.df_after
if spectrum_type is None:
spectrum_type = (self.spectrum_type,)
if isinstance(spectrum_type, str):
spectrum_type = (spectrum_type,)
try:
indices = [df[df["spectrum_type"].eq(s)] for s in spectrum_type]
indices = [d.index[0] for d in indices if not d.empty]
if not indices:
raise IndexError()
min_index = min(indices)
return df[df.index < min_index]
except IndexError:
# nothing
return df
@property
def scan_name(self):
for option in ["scan", "file"]:
if option in self._obj.attrs:
return self._obj.attrs[option]
id = self._obj.attrs.get("id")
if id is None:
return "No ID"
try:
df = self._obj.attrs["df"]
return df[df.id == id].index[0]
except (IndexError, KeyError, AttributeError):
# data is probably not raw data
return self.original_parent_scan_name
@property
def label(self):
return str(self._obj.attrs.get("description", self.scan_name))
@property
def t0(self):
if "t0" in self._obj.attrs:
value = float(self._obj.attrs["t0"])
if not np.isnan(value):
return value
if "T0_ps" in self._obj.attrs:
value = float(self._obj.attrs["T0_ps"])
if not np.isnan(value):
return value
return None
@contextlib.contextmanager
def with_rotation_offset(self, offset: float):
"""Temporarily rotates the chi_offset by `offset`."""
old_chi_offset = self.offsets.get("chi", 0)
self.apply_offsets({"chi": old_chi_offset + offset})
yield old_chi_offset + offset
self.apply_offsets({"chi": old_chi_offset})
def apply_offsets(self, offsets):
for k, v in offsets.items():
self._obj.attrs["{}_offset".format(k)] = v
@property
def offsets(self):
return {
c: self.lookup_offset(c) for c in self._obj.coords if f"{c}_offset" in self._obj.attrs
}
def lookup_offset_coord(self, name):
return self.lookup_coord(name) - self.lookup_offset(name)
def lookup_coord(self, name):
if name in self._obj.coords:
return unwrap_xarray_item(self._obj.coords[name])
if name in self._obj.attrs:
return unwrap_xarray_item(self._obj.attrs[name])
raise ValueError("Could not find coordinate {}.".format(name))
def lookup_offset(self, attr_name):
symmetry_points = self.symmetry_points(raw=True)
if "G" in symmetry_points:
gamma_point = symmetry_points["G"]
if attr_name in gamma_point:
return unwrap_xarray_item(gamma_point[attr_name])
offset_name = attr_name + "_offset"
if offset_name in self._obj.attrs:
return unwrap_xarray_item(self._obj.attrs[offset_name])
return unwrap_xarray_item(self._obj.attrs.get("data_preparation", {}).get(offset_name, 0))
@property
def beta_offset(self):
return self.lookup_offset("beta")
@property
def psi_offset(self):
return self.lookup_offset("psi")
@property
def theta_offset(self):
return self.lookup_offset("theta")
@property
def phi_offset(self):
return self.lookup_offset("phi")
@property
def chi_offset(self):
return self.lookup_offset("chi")
@property
def work_function(self) -> float:
"""Provides the work function, if present in metadata.
Otherwise, uses something approximate.
"""
if "sample_workfunction" in self._obj.attrs:
return self._obj.attrs["sample_workfunction"]
return 4.3
@property
def inner_potential(self) -> float:
"""Provides the inner potential, if present in metadata.
Otherwise, 10eV is assumed.
"""
if "inner_potential" in self._obj.attrs:
return self._obj.attrs["inner_potential"]
return 10
def find_spectrum_energy_edges(self, indices=False):
energy_marginal = self._obj.sum([d for d in self._obj.dims if d not in ["eV"]])
embed_size = 20
embedded = np.ndarray(shape=[embed_size] + list(energy_marginal.values.shape))
embedded[:] = energy_marginal.values
embedded = ndi.gaussian_filter(embedded, embed_size / 3)
# try to avoid dependency conflict with numpy v0.16
from skimage import feature # pylint: disable=import-error
edges = (
feature.canny(embedded, sigma=embed_size / 5, use_quantiles=True, low_threshold=0.3) * 1
)
edges = np.where(edges[int(embed_size / 2)] == 1)[0]
if indices:
return edges
delta = self._obj.G.stride(generic_dim_names=False)
return edges * delta["eV"] + self._obj.coords["eV"].values[0]
def find_spectrum_angular_edges_full(self, indices=False):
# as a first pass, we need to find the bottom of the spectrum, we will use this
        # to select the active region and then to rebin into coarse steps in energy from 0
# down to this region
# we will then find the appropriate edge for each slice, and do a fit to the edge locations
energy_edge = self.find_spectrum_energy_edges()
low_edge = np.min(energy_edge) + 0.05
high_edge = np.max(energy_edge) - 0.05
if high_edge - low_edge < 0.15:
# Doesn't look like the automatic inference of the energy edge was valid
high_edge = 0
low_edge = np.min(self._obj.coords["eV"].values)
angular_dim = "pixel" if "pixel" in self._obj.dims else "phi"
energy_cut = self._obj.sel(eV=slice(low_edge, high_edge)).S.sum_other(["eV", angular_dim])
        n_cuts = int(np.ceil((high_edge - low_edge) / 0.05))
new_shape = {"eV": n_cuts}
new_shape[angular_dim] = len(energy_cut.coords[angular_dim].values)
rebinned = rebin(energy_cut, shape=new_shape)
embed_size = 20
embedded = np.ndarray(shape=[embed_size] + [len(rebinned.coords[angular_dim].values)])
low_edges = []
high_edges = []
for e_cut in rebinned.coords["eV"].values:
e_slice = rebinned.sel(eV=e_cut)
values = e_slice.values
values[values > np.mean(values)] = np.mean(values)
embedded[:] = values
embedded = ndi.gaussian_filter(embedded, embed_size / 1.5)
# try to avoid dependency conflict with numpy v0.16
from skimage import feature # pylint: disable=import-error
edges = (
feature.canny(
embedded,
sigma=4,
use_quantiles=False,
low_threshold=0.7,
high_threshold=1.5,
)
* 1
)
edges = np.where(edges[int(embed_size / 2)] == 1)[0]
low_edges.append(np.min(edges))
high_edges.append(np.max(edges))
if indices:
return np.array(low_edges), np.array(high_edges), rebinned.coords["eV"]
delta = self._obj.G.stride(generic_dim_names=False)
low_edges = (
np.array(low_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0]
)
high_edges = (
np.array(high_edges) * delta[angular_dim] + rebinned.coords[angular_dim].values[0]
)
return low_edges, high_edges, rebinned.coords["eV"]
def zero_spectrometer_edges(self, cut_margin=None, interp_range=None, low=None, high=None):
if low is not None:
assert high is not None
assert len(low) == len(high) == 2
low_edges = low
high_edges = high
(
low_edges,
high_edges,
rebinned_eV_coord,
) = self.find_spectrum_angular_edges_full(indices=True)
angular_dim = "pixel" if "pixel" in self._obj.dims else "phi"
if cut_margin is None:
if "pixel" in self._obj.dims:
cut_margin = 50
else:
cut_margin = int(0.08 / self._obj.G.stride(generic_dim_names=False)[angular_dim])
else:
if isinstance(cut_margin, float):
assert angular_dim == "phi"
cut_margin = int(
cut_margin / self._obj.G.stride(generic_dim_names=False)[angular_dim]
)
if interp_range is not None:
low_edge = xr.DataArray(low_edges, coords={"eV": rebinned_eV_coord}, dims=["eV"])
high_edge = xr.DataArray(high_edges, coords={"eV": rebinned_eV_coord}, dims=["eV"])
low_edge = low_edge.sel(eV=interp_range)
high_edge = high_edge.sel(eV=interp_range)
other_dims = list(self._obj.dims)
other_dims.remove("eV")
other_dims.remove(angular_dim)
copied = self._obj.copy(deep=True).transpose(*(["eV", angular_dim] + other_dims))
low_edges += cut_margin
high_edges -= cut_margin
for i, energy in enumerate(copied.coords["eV"].values):
index = np.searchsorted(rebinned_eV_coord, energy)
other = index + 1
if other >= len(rebinned_eV_coord):
other = len(rebinned_eV_coord) - 1
index = len(rebinned_eV_coord) - 2
low = int(np.interp(energy, rebinned_eV_coord, low_edges))
high = int(np.interp(energy, rebinned_eV_coord, high_edges))
copied.values[i, 0:low] = 0
            copied.values[i, high:] = 0  # zero through the final channel, not just up to it
return copied
def sum_other(self, dim_or_dims, keep_attrs=False):
if isinstance(dim_or_dims, str):
dim_or_dims = [dim_or_dims]
return self._obj.sum(
[d for d in self._obj.dims if d not in dim_or_dims], keep_attrs=keep_attrs
)
def mean_other(self, dim_or_dims, keep_attrs=False):
if isinstance(dim_or_dims, str):
dim_or_dims = [dim_or_dims]
return self._obj.mean(
[d for d in self._obj.dims if d not in dim_or_dims], keep_attrs=keep_attrs
)
def find_spectrum_angular_edges(self, indices=False):
angular_dim = "pixel" if "pixel" in self._obj.dims else "phi"
energy_edge = self.find_spectrum_energy_edges()
energy_slice = slice(np.max(energy_edge) - 0.1, np.max(energy_edge))
near_ef = self._obj.sel(eV=energy_slice).sum(
[d for d in self._obj.dims if d not in [angular_dim]]
)
embed_size = 20
embedded = np.ndarray(shape=[embed_size] + list(near_ef.values.shape))
embedded[:] = near_ef.values
embedded = ndi.gaussian_filter(embedded, embed_size / 3)
# try to avoid dependency conflict with numpy v0.16
from skimage import feature # pylint: disable=import-error
edges = (
feature.canny(embedded, sigma=embed_size / 5, use_quantiles=True, low_threshold=0.2) * 1
)
edges = np.where(edges[int(embed_size / 2)] == 1)[0]
if indices:
return edges
delta = self._obj.G.stride(generic_dim_names=False)
return edges * delta[angular_dim] + self._obj.coords[angular_dim].values[0]
def trimmed_selector(self):
raise NotImplementedError
def wide_angle_selector(self, include_margin=True):
edges = self.find_spectrum_angular_edges()
low_edge, high_edge = np.min(edges), np.max(edges)
# go and build in a small margin
if include_margin:
if "pixels" in self._obj.dims:
low_edge += 50
high_edge -= 50
else:
low_edge += 0.05
high_edge -= 0.05
return slice(low_edge, high_edge)
def narrow_angle_selector(self):
raise NotImplementedError
def meso_effective_selector(self):
energy_edge = self.find_spectrum_energy_edges()
return slice(np.max(energy_edge) - 0.3, np.max(energy_edge) - 0.1)
def region_sel(self, *regions):
def process_region_selector(selector: Union[slice, DesignatedRegions], dimension_name: str):
if isinstance(selector, slice):
return selector
# need to read out the region
options = {
"eV": (
DesignatedRegions.ABOVE_EF,
DesignatedRegions.BELOW_EF,
DesignatedRegions.EF_NARROW,
DesignatedRegions.MESO_EF,
DesignatedRegions.MESO_EFFECTIVE_EF,
DesignatedRegions.ABOVE_EFFECTIVE_EF,
DesignatedRegions.BELOW_EFFECTIVE_EF,
DesignatedRegions.EFFECTIVE_EF_NARROW,
),
"phi": (
DesignatedRegions.NARROW_ANGLE,
DesignatedRegions.WIDE_ANGLE,
DesignatedRegions.TRIM_EMPTY,
),
}
options_for_dim = options.get(dimension_name, [d for d in DesignatedRegions])
assert selector in options_for_dim
# now we need to resolve out the region
resolution_methods = {
DesignatedRegions.ABOVE_EF: slice(0, None),
DesignatedRegions.BELOW_EF: slice(None, 0),
DesignatedRegions.EF_NARROW: slice(-0.1, 0.1),
DesignatedRegions.MESO_EF: slice(-0.3, -0.1),
DesignatedRegions.MESO_EFFECTIVE_EF: self.meso_effective_selector,
# Implement me
# DesignatedRegions.TRIM_EMPTY: ,
DesignatedRegions.WIDE_ANGLE: self.wide_angle_selector,
# DesignatedRegions.NARROW_ANGLE: self.narrow_angle_selector,
}
resolution_method = resolution_methods[selector]
if isinstance(resolution_method, slice):
return resolution_method
if callable(resolution_method):
return resolution_method()
raise NotImplementedError("Unable to determine resolution method.")
obj = self._obj
def unpack_dim(dim_name):
if dim_name == "angular":
return "pixel" if "pixel" in obj.dims else "phi"
return dim_name
for region in regions:
region = {unpack_dim(k): v for k, v in normalize_region(region).items()}
            # remove missing dimensions from the selection for permissiveness
            # and to allow transparent composition of regions
region = {k: process_region_selector(v, k) for k, v in region.items() if k in obj.dims}
obj = obj.sel(**region)
return obj
def fat_sel(self, widths: Optional[Dict[str, Any]] = None, **kwargs) -> xr.DataArray:
"""Allows integrating a selection over a small region.
The produced dataset will be normalized by dividing by the number
of slices integrated over.
This can be used to produce temporary datasets that have reduced
uncorrelated noise.
Args:
            widths: Override the widths for the slices. Reasonable defaults are used otherwise. Defaults to None.
kwargs: slice dict. Has the same function as xarray.DataArray.sel
Returns:
The data after selection.
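
        Example:
            A minimal sketch, assuming hypothetical data `f` with an 'eV' dimension; this
            integrates a 50 meV wide window around the Fermi level:

            >>> near_ef = f.S.fat_sel(eV=0, eV_width=0.05)  # doctest: +SKIP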
"""
if widths is None:
widths = {}
default_widths = {
"eV": 0.05,
"phi": 2,
"beta": 2,
"theta": 2,
"kx": 0.02,
"ky": 0.02,
"kp": 0.02,
"kz": 0.1,
}
extra_kwargs = {k: v for k, v in kwargs.items() if k not in self._obj.dims}
slice_kwargs = {k: v for k, v in kwargs.items() if k not in extra_kwargs}
slice_widths = {
k: widths.get(k, extra_kwargs.get(k + "_width", default_widths.get(k)))
for k in slice_kwargs
}
slices = {
k: slice(v - slice_widths[k] / 2, v + slice_widths[k] / 2)
for k, v in slice_kwargs.items()
}
sliced = self._obj.sel(**slices)
        thickness = np.prod([len(sliced.coords[k]) for k in slice_kwargs.keys()])
normalized = sliced.sum(slices.keys(), keep_attrs=True) / thickness
for k, v in slices.items():
normalized.coords[k] = (v.start + v.stop) / 2
normalized.attrs.update(self._obj.attrs.copy())
return normalized
@property
def reference_settings(self):
settings = self.spectrometer_settings or {}
settings.update(
{
"hv": self.hv,
}
)
return settings
@property
def beamline_settings(self):
find_keys = {
"entrance_slit": {
"entrance_slit",
},
"exit_slit": {
"exit_slit",
},
"hv": {
"hv",
"photon_energy",
},
"grating": {},
}
settings = {}
for key, options in find_keys.items():
for option in options:
if option in self._obj.attrs:
settings[key] = self._obj.attrs[option]
break
if self.endstation == "BL403":
settings["grating"] = "HEG" # for now assume we always use the first order light
return settings
@property
def spectrometer_settings(self):
find_keys = {
"lens_mode": {
"lens_mode",
},
"pass_energy": {
"pass_energy",
},
"scan_mode": {
"scan_mode",
},
"scan_region": {
"scan_region",
},
"slit": {
"slit",
"slit_plate",
},
}
settings = {}
for key, options in find_keys.items():
for option in options:
if option in self._obj.attrs:
settings[key] = self._obj.attrs[option]
break
if isinstance(settings.get("slit"), (float, np.float32, np.float64)):
settings["slit"] = int(round(settings["slit"]))
return settings
@property
def sample_pos(self):
x, y, z = None, None, None
try:
x = self._obj.attrs["x"]
except KeyError:
pass
try:
y = self._obj.attrs["y"]
except KeyError:
pass
try:
z = self._obj.attrs["z"]
except KeyError:
pass
def do_float(w):
return float(w) if w is not None else None
return (do_float(x), do_float(y), do_float(z))
@property
def sample_angles(self):
return (
# manipulator
self.lookup_coord("beta"),
self.lookup_coord("theta"),
self.lookup_coord("chi"),
# analyzer
self.lookup_coord("phi"),
self.lookup_coord("psi"),
self.lookup_coord("alpha"),
)
@property
def full_coords(self):
full_coords = {}
full_coords.update(dict(zip(["x", "y", "z"], self.sample_pos)))
full_coords.update(
dict(zip(["beta", "theta", "chi", "phi", "psi", "alpha"], self.sample_angles))
)
full_coords.update(
{
"hv": self.hv,
}
)
full_coords.update(self._obj.coords)
return full_coords
@property
def sample_info(self):
return unwrap_xarray_dict(
{
"id": self._obj.attrs.get("sample_id"),
"name": self._obj.attrs.get("sample_name"),
"source": self._obj.attrs.get("sample_source"),
"reflectivity": self._obj.attrs.get("sample_reflectivity"),
}
)
@property
def scan_info(self):
return unwrap_xarray_dict(
{
"time": self._obj.attrs.get("time"),
"date": self._obj.attrs.get("date"),
"type": self.scan_type,
"spectrum_type": self.spectrum_type,
"experimenter": self._obj.attrs.get("experimenter"),
"sample": self._obj.attrs.get("sample_name"),
}
)
@property
def experiment_info(self):
return unwrap_xarray_dict(
{
"temperature": self._obj.attrs.get("temperature"),
"temperature_cryotip": self._obj.attrs.get("temperature_cryotip"),
"pressure": self._obj.attrs.get("pressure"),
"polarization": self.probe_polarization,
"photon_flux": self._obj.attrs.get("photon_flux"),
"photocurrent": self._obj.attrs.get("photocurrent"),
"probe": self._obj.attrs.get("probe"),
"probe_detail": self._obj.attrs.get("probe_detail"),
"analyzer": self._obj.attrs.get("analyzer"),
"analyzer_detail": self.analyzer_detail,
}
)
@property
def pump_info(self):
return unwrap_xarray_dict(
{
"pump_wavelength": self._obj.attrs.get("pump_wavelength"),
"pump_energy": self._obj.attrs.get("pump_energy"),
"pump_fluence": self._obj.attrs.get("pump_fluence"),
"pump_pulse_energy": self._obj.attrs.get("pump_pulse_energy"),
"pump_spot_size": (
self._obj.attrs.get("pump_spot_size_x"),
self._obj.attrs.get("pump_spot_size_y"),
),
"pump_profile": self._obj.attrs.get("pump_profile"),
"pump_linewidth": self._obj.attrs.get("pump_linewidth"),
"pump_temporal_width": self._obj.attrs.get("pump_temporal_width"),
"pump_polarization": self.pump_polarization,
}
)
@property
def probe_info(self):
return unwrap_xarray_dict(
{
"probe_wavelength": self._obj.attrs.get("probe_wavelength"),
"probe_energy": self._obj.coords["hv"],
"probe_fluence": self._obj.attrs.get("probe_fluence"),
"probe_pulse_energy": self._obj.attrs.get("probe_pulse_energy"),
"probe_spot_size": (
self._obj.attrs.get("probe_spot_size_x"),
self._obj.attrs.get("probe_spot_size_y"),
),
"probe_profile": self._obj.attrs.get("probe_profile"),
"probe_linewidth": self._obj.attrs.get("probe_linewidth"),
"probe_temporal_width": self._obj.attrs.get("probe_temporal_width"),
"probe_polarization": self.probe_polarization,
}
)
@property
def laser_info(self):
return {
**self.probe_info,
**self.pump_info,
"repetition_rate": self._obj.attrs.get("repetition_rate"),
}
@property
def analyzer_info(self) -> Dict[str, Any]:
"""General information about the photoelectron analyzer used."""
return unwrap_xarray_dict(
{
"lens_mode": self._obj.attrs.get("lens_mode"),
"lens_mode_name": self._obj.attrs.get("lens_mode_name"),
"acquisition_mode": self._obj.attrs.get("acquisition_mode"),
"pass_energy": self._obj.attrs.get("pass_energy"),
"slit_shape": self._obj.attrs.get("slit_shape"),
"slit_width": self._obj.attrs.get("slit_width"),
"slit_number": self._obj.attrs.get("slit_number"),
"lens_table": self._obj.attrs.get("lens_table"),
"analyzer_type": self._obj.attrs.get("analyzer_type"),
"mcp_voltage": self._obj.attrs.get("mcp_voltage"),
}
)
@property
def daq_info(self) -> Dict[str, Any]:
"""General information about the acquisition settings for an ARPES experiment."""
return unwrap_xarray_dict(
{
"daq_type": self._obj.attrs.get("daq_type"),
"region": self._obj.attrs.get("daq_region"),
"region_name": self._obj.attrs.get("daq_region_name"),
"center_energy": self._obj.attrs.get("daq_center_energy"),
"prebinning": self.prebinning,
"trapezoidal_correction_strategy": self._obj.attrs.get(
"trapezoidal_correction_strategy"
),
"dither_settings": self._obj.attrs.get("dither_settings"),
"sweep_settings": self.sweep_settings,
"frames_per_slice": self._obj.attrs.get("frames_per_slice"),
"frame_duration": self._obj.attrs.get("frame_duration"),
}
)
@property
def beamline_info(self) -> Dict[str, Any]:
"""Information about the beamline or light source used for a measurement."""
return unwrap_xarray_dict(
{
"hv": self._obj.coords["hv"],
"linewidth": self._obj.attrs.get("probe_linewidth"),
"photon_polarization": self.probe_polarization,
"undulator_info": self.undulator_info,
"repetition_rate": self._obj.attrs.get("repetition_rate"),
"beam_current": self._obj.attrs.get("beam_current"),
"entrance_slit": self._obj.attrs.get("entrance_slit"),
"exit_slit": self._obj.attrs.get("exit_slit"),
"monochromator_info": self.monochromator_info,
}
)
@property
def sweep_settings(self) -> Dict[str, Any]:
"""For datasets acquired with swept acquisition settings, provides those settings."""
return {
"high_energy": self._obj.attrs.get("sweep_high_energy"),
"low_energy": self._obj.attrs.get("sweep_low_energy"),
"n_sweeps": self._obj.attrs.get("n_sweeps"),
"step": self._obj.attrs.get("sweep_step"),
}
@property
def probe_polarization(self) -> Tuple[float, float]:
"""Provides the probe polarization of the UV/x-ray source."""
return (
self._obj.attrs.get("probe_polarization_theta"),
self._obj.attrs.get("probe_polarization_alpha"),
)
@property
def pump_polarization(self) -> Tuple[float, float]:
"""For Tr-ARPES experiments, provides the pump polarization."""
return (
self._obj.attrs.get("pump_polarization_theta"),
self._obj.attrs.get("pump_polarization_alpha"),
)
@property
def prebinning(self) -> Dict[str, Any]:
"""Information about the prebinning performed during scan acquisition."""
prebinning = {}
for d in self._obj.indexes:
if "{}_prebinning".format(d) in self._obj.attrs:
prebinning[d] = self._obj.attrs["{}_prebinning".format(d)]
return prebinning
@property
def monochromator_info(self) -> Dict[str, Any]:
"""Details about the monochromator used on the UV/x-ray source."""
return {
"grating_lines_per_mm": self._obj.attrs.get("grating_lines_per_mm"),
}
@property
def undulator_info(self) -> Dict[str, Any]:
"""Details about the undulator for data performed at an undulator source."""
return {
"gap": self._obj.attrs.get("undulator_gap"),
"z": self._obj.attrs.get("undulator_z"),
"harmonic": self._obj.attrs.get("undulator_harmonic"),
"polarization": self._obj.attrs.get("undulator_polarization"),
"type": self._obj.attrs.get("undulator_type"),
}
@property
def analyzer_detail(self) -> Dict[str, Any]:
"""Details about the analyzer, its capabilities, and metadata."""
return {
"name": self._obj.attrs.get("analyzer_name"),
"parallel_deflectors": self._obj.attrs.get("parallel_deflectors"),
"perpendicular_deflectors": self._obj.attrs.get("perpendicular_deflectors"),
"type": self._obj.attrs.get("analyzer_type"),
"radius": self._obj.attrs.get("analyzer_radius"),
}
@property
def temp(self) -> Union[float, xr.DataArray]:
"""The temperature at which an experiment was performed."""
        preferred_attrs = [
"TA",
"ta",
"t_a",
"T_A",
"T_1",
"t_1",
"t1",
"T1",
"temp",
"temp_sample",
"temp_cryotip",
"temperature_sensor_b",
"temperature_sensor_a",
]
        for attr in preferred_attrs:
if attr in self._obj.attrs:
return float(self._obj.attrs[attr])
raise AttributeError("Could not read temperature off any standard attr")
@property
def condensed_attrs(self) -> Dict[str, Any]:
"""An attributes shortlist.
        Since we enforce camelCase on attributes, filtering to keys which begin with a lowercase
        letter gives a reasonable shortlist, excluding the ones we don't use very often.
"""
return {k: v for k, v in self._obj.attrs.items() if k[0].islower()}
@property
def referenced_scans(self) -> pd.DataFrame:
if self.spectrum_type == "map":
df = self._obj.attrs["df"]
return df[(df.spectrum_type != "map") & (df.ref_id == self._obj.id)]
else:
assert self.spectrum_type in {"ucut", "spem"}
df = self.df_until_type(
spectrum_type=(
"ucut",
"spem",
)
)
return df
def generic_fermi_surface(self, fermi_energy):
return self.fat_sel(eV=fermi_energy)
@property
def fermi_surface(self):
return self.fat_sel(eV=0)
def __init__(self, xarray_obj):
self._obj = xarray_obj
@staticmethod
def dict_to_html(d):
return """
<table>
<thead>
<tr>
<th>Key</th>
<th>Value</th>
</tr>
</thead>
<tbody>
{rows}
</tbody>
</table>
""".format(
rows="".join(["<tr><td>{}</td><td>{}</td></tr>".format(k, v) for k, v in d.items()])
)
def _repr_html_full_coords(self, coords):
def coordinate_dataarray_to_flat_rep(value):
if not isinstance(value, xr.DataArray):
return value
return "<span>{min:.3g}<strong> to </strong>{max:.3g}<strong> by </strong>{delta:.3g}</span>".format(
min=value.min().item(),
max=value.max().item(),
delta=value.values[1] - value.values[0],
)
return ARPESAccessorBase.dict_to_html(
{k: coordinate_dataarray_to_flat_rep(v) for k, v in coords.items()}
)
def _repr_html_spectrometer_info(self):
skip_keys = {
"dof",
}
ordered_settings = OrderedDict(self.spectrometer_settings)
ordered_settings.update({k: v for k, v in self.spectrometer.items() if k not in skip_keys})
return ARPESAccessorBase.dict_to_html(ordered_settings)
@staticmethod
def _repr_html_experimental_conditions(conditions):
transforms = {
"polarization": lambda p: {
"p": "Linear Horizontal",
"s": "Linear Vertical",
"rc": "Right Circular",
"lc": "Left Circular",
"s-p": "Linear Dichroism",
"p-s": "Linear Dichroism",
"rc-lc": "Circular Dichroism",
"lc-rc": "Circular Dichroism",
}.get(p, p),
"hv": "{} eV".format,
"temp": "{} Kelvin".format,
}
        identity = lambda x: x
        return ARPESAccessorBase.dict_to_html(
            {k: transforms.get(k, identity)(v) for k, v in conditions.items() if v is not None}
)
def _repr_html_(self):
skip_data_vars = {
"time",
}
if isinstance(self._obj, xr.Dataset):
to_plot = [k for k in self._obj.data_vars.keys() if k not in skip_data_vars]
to_plot = [k for k in to_plot if 1 <= len(self._obj[k].dims) < 3]
to_plot = to_plot[:5]
if to_plot:
_, ax = plt.subplots(
1,
len(to_plot),
figsize=(
len(to_plot) * 3,
3,
),
)
if len(to_plot) == 1:
ax = [ax]
for i, plot_var in enumerate(to_plot):
self._obj[plot_var].plot(ax=ax[i])
fancy_labels(ax[i])
ax[i].set_title(plot_var.replace("_", " "))
remove_colorbars()
else:
if 1 <= len(self._obj.dims) < 3:
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
self._obj.plot(ax=ax)
fancy_labels(ax)
ax.set_title("")
remove_colorbars()
wrapper_style = 'style="display: flex; flex-direction: row;"'
try:
name = self.df_index
        except Exception:
if "id" in self._obj.attrs:
name = "ID: " + str(self._obj.attrs["id"])[:9] + "..."
else:
name = "No name"
warning = ""
if len(self._obj.attrs) < 10:
warning = ': <span style="color: red;">Few Attributes, Data Is Summed?</span>'
return """
<header><strong>{name}{warning}</strong></header>
<div {wrapper_style}>
<details open>
<summary>Experimental Conditions</summary>
{conditions}
</details>
<details open>
<summary>Full Coordinates</summary>
{coordinates}
</details>
<details open>
<summary>Spectrometer</summary>
{spectrometer_info}
</details>
</div>
""".format(
name=name,
warning=warning,
wrapper_style=wrapper_style,
conditions=self._repr_html_experimental_conditions(self.experimental_conditions),
coordinates=self._repr_html_full_coords(
{k: v for k, v in self.full_coords.items() if v is not None}
),
spectrometer_info=self._repr_html_spectrometer_info(),
)
@xr.register_dataarray_accessor("S")
class ARPESDataArrayAccessor(ARPESAccessorBase):
"""Spectrum related accessor for `xr.DataArray`."""
def plot(self, *args, rasterized=True, **kwargs):
"""Utility delegate to `xr.DataArray.plot` which rasterizes."""
if len(self._obj.dims) == 2:
kwargs["rasterized"] = rasterized
with plt.rc_context(rc={"text.usetex": False}):
self._obj.plot(*args, **kwargs)
def show(self, detached=False, **kwargs):
"""Opens the Qt based image tool."""
import arpes.plotting.qt_tool
arpes.plotting.qt_tool.qt_tool(self._obj, detached=detached, **kwargs)
def show_d2(self, **kwargs):
"""Opens the Bokeh based second derivative image tool."""
from arpes.plotting.all import CurvatureTool
curve_tool = CurvatureTool(**kwargs)
return curve_tool.make_tool(self._obj)
def show_band_tool(self, **kwargs):
"""Opens the Bokeh based band placement tool."""
from arpes.plotting.all import BandTool
band_tool = BandTool(**kwargs)
return band_tool.make_tool(self._obj)
def fs_plot(self, pattern="{}.png", **kwargs):
"""Provides a reference plot of the approximate Fermi surface."""
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = pattern.format("{}_fs".format(self.label))
kwargs["out"] = out
return plotting.labeled_fermi_surface(self._obj, **kwargs)
def fermi_edge_reference_plot(self, pattern="{}.png", **kwargs) -> plt.Axes:
"""Provides a reference plot for a Fermi edge reference."""
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = pattern.format("{}_fermi_edge_reference".format(self.label))
kwargs["out"] = out
return plotting.fermi_edge.fermi_edge_reference(self._obj, **kwargs)
def _referenced_scans_for_spatial_plot(self, use_id=True, pattern="{}.png", **kwargs):
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
out = pattern.format("{}_reference_scan_fs".format(label))
kwargs["out"] = out
return plotting.reference_scan_spatial(self._obj, **kwargs)
def _referenced_scans_for_map_plot(self, use_id=True, pattern="{}.png", **kwargs):
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
out = pattern.format("{}_reference_scan_fs".format(label))
kwargs["out"] = out
return plotting.reference_scan_fermi_surface(self._obj, **kwargs)
def _referenced_scans_for_hv_map_plot(self, use_id=True, pattern="{}.png", **kwargs):
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
out = pattern.format("{}_hv_reference_scan".format(label))
out = "{}_hv_reference_scan.png".format(label)
kwargs["out"] = out
return plotting.hv_reference_scan(self._obj, **kwargs)
def _simple_spectrum_reference_plot(self, use_id=True, pattern="{}.png", **kwargs):
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
out = pattern.format("{}_spectrum_reference".format(label))
kwargs["out"] = out
return plotting.fancy_dispersion(self._obj, **kwargs)
def cut_nan_coords(self) -> xr.DataArray:
"""Selects data where coordinates are not `nan`.
Returns:
The subset of the data where coordinates are not `nan`.
"""
slices = dict()
for cname, cvalue in self._obj.coords.items():
try:
end_ind = np.where(np.isnan(cvalue.values))[0][0]
end_ind = None if end_ind == -1 else end_ind
slices[cname] = slice(None, end_ind)
except IndexError:
pass
return self._obj.isel(**slices)
def nan_to_num(self, x: Any = 0) -> xr.DataArray:
"""Provides an `xarray` version of `numpy.nan_to_num`.
Args:
x: The fill value
Returns:
A copy of the data with nans filled in.
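
        Example:
            A minimal sketch, assuming hypothetical `data`:

            >>> filled = data.S.nan_to_num(0)  # doctest: +SKIP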
"""
data = self._obj.copy(deep=True)
assert isinstance(data, xr.DataArray)
data.values[np.isnan(data.values)] = x
return data
def reference_plot(self, **kwargs) -> plt.Axes:
"""Generates a reference plot for this piece of data according to its spectrum type.
Raises:
NotImplementedError: If there is no standard approach for plotting this data.
Returns:
The axes which were used for plotting.
"""
if self.spectrum_type == "map":
return self._referenced_scans_for_map_plot(**kwargs)
elif self.spectrum_type == "hv_map":
return self._referenced_scans_for_hv_map_plot(**kwargs)
elif self.spectrum_type == "spectrum":
return self._simple_spectrum_reference_plot(**kwargs)
elif self.spectrum_type in {"ucut", "spem"}:
return self._referenced_scans_for_spatial_plot(**kwargs)
else:
raise NotImplementedError
NORMALIZED_DIM_NAMES = ["x", "y", "z", "w"]
@xr.register_dataset_accessor("G")
@xr.register_dataarray_accessor("G")
class GenericAccessorTools:
_obj = None
def round_coordinates(self, coords, as_indices: bool = False):
data = self._obj
rounded = {
k: v.item()
for k, v in data.sel(**coords, method="nearest").coords.items()
if k in coords
}
if as_indices:
            # resolve the rounded coordinate values back to integer indices
            rounded = {k: data.indexes[k].get_loc(v) for k, v in rounded.items()}
return rounded
def argmax_coords(self):
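        """Returns the coordinate values at the global maximum of the data.

        A sketch of typical output, assuming hypothetical 2D `data`:

        >>> data.G.argmax_coords()  # doctest: +SKIP
        {'eV': -0.05, 'phi': 0.13}
        """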
data = self._obj
raveled_idx = data.argmax().item()
flat_indices = np.unravel_index(raveled_idx, data.values.shape)
max_coords = {d: data.coords[d][flat_indices[i]].item() for i, d in enumerate(data.dims)}
return max_coords
def apply_over(self, fn, copy=True, **selections):
data = self._obj
if copy:
data = data.copy(deep=True)
try:
transformed = fn(data.sel(**selections))
        except Exception:
transformed = fn(data.sel(**selections).values)
if isinstance(transformed, xr.DataArray):
transformed = transformed.values
data.loc[selections] = transformed
return data
def to_unit_range(self, percentile=None):
if percentile is None:
norm = self._obj - self._obj.min()
return norm / norm.max()
percentile = min(percentile, 100 - percentile)
low, high = np.percentile(
self._obj,
(
percentile,
100 - percentile,
),
)
norm = self._obj - low
return norm / (high - low)
def extent(self, *args, dims=None) -> Tuple[float, float, float, float]:
"""Returns an "extent" array that can be used to draw with plt.imshow."""
if dims is None:
if not args:
dims = self._obj.dims
else:
dims = args
        assert len(dims) == 2, "You must supply exactly two dims to `.G.extent` not {}".format(dims)
        return (
            self._obj.coords[dims[0]][0].item(),
            self._obj.coords[dims[0]][-1].item(),
            self._obj.coords[dims[1]][0].item(),
            self._obj.coords[dims[1]][-1].item(),
        )
def drop_nan(self):
assert len(self._obj.dims) == 1
mask = np.logical_not(np.isnan(self._obj.values))
return self._obj.isel(**dict([[self._obj.dims[0], mask]]))
def shift_coords(self, dims, shift):
if not isinstance(shift, np.ndarray):
shift = np.ones((len(dims),)) * shift
def transform(data):
new_shift = shift
for _ in range(len(dims)):
new_shift = np.expand_dims(new_shift, 0)
return data + new_shift
return self.transform_coords(dims, transform)
def scale_coords(self, dims, scale):
if not isinstance(scale, np.ndarray):
n_dims = len(dims)
scale = np.identity(n_dims) * scale
elif len(scale.shape) == 1:
scale = np.diag(scale)
return self.transform_coords(dims, scale)
def transform_coords(
self, dims: List[str], transform: Union[np.ndarray, Callable]
) -> xr.DataArray:
"""Transforms the given coordinate values according to an arbitrary function.
        The transformation should either be a function from an array of shape
        (size of raveled coordinates) x len(dims) to an array of the same shape, or a linear
        transformation given as a matrix which is multiplied into such an array.

        Args:
            dims: A list of dimensions that should be transformed
transform: The transformation to apply, can either be a function, or a matrix
Returns:
An identical valued array over new coordinates.
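
        Example:
            A minimal sketch, flipping the sign of `ky` in a hypothetical coordinate dataset
            `kcoords` with data_vars 'kx' and 'ky':

            >>> flipped = kcoords.G.transform_coords(["kx", "ky"], np.diag([1.0, -1.0]))  # doctest: +SKIP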
"""
as_array = np.stack([self._obj.data_vars[d].values for d in dims], axis=-1)
if isinstance(transform, np.ndarray):
transformed = np.dot(as_array, transform)
else:
transformed = transform(as_array)
copied = self._obj.copy(deep=True)
for d, arr in zip(dims, np.split(transformed, transformed.shape[-1], axis=-1)):
copied.data_vars[d].values = np.squeeze(arr, axis=-1)
return copied
def filter_vars(self, f):
return xr.Dataset(
data_vars={k: v for k, v in self._obj.data_vars.items() if f(v, k)},
attrs=self._obj.attrs,
)
def coordinatize(self, as_coordinate_name: Optional[str] = None) -> xr.DataArray:
"""Copies data into a coordinate's data, with an optional renaming.
If you think of an array as a function c => f(c) from coordinates to values at
those coordinates, this function replaces f by the identity to give c => c
Remarkably, `coordinatize` is a word.
For the most part, this is only useful when converting coordinate values into
k-space "forward".
Args:
as_coordinate_name: A new coordinate name for the only dimension. Defaults to None.
Returns:
An array which consists of the mapping c => c.
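
        Example:
            >>> arr = xr.DataArray(np.linspace(0.0, 1.0, 5), dims=["phi"])  # doctest: +SKIP
            >>> arr.G.coordinatize("kp").coords["kp"].values  # doctest: +SKIP
            array([0.  , 0.25, 0.5 , 0.75, 1.  ])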
"""
assert len(self._obj.dims) == 1
d = self._obj.dims[0]
if as_coordinate_name is None:
as_coordinate_name = d
o = self._obj.rename(dict([[d, as_coordinate_name]])).copy(deep=True)
o.coords[as_coordinate_name] = o.values
return o
def ravel(self) -> Dict[str, xr.DataArray]:
"""Converts to a flat representation where the coordinate values are also present.
Extremely valuable for plotting a dataset with coordinates, X, Y and values Z(X,Y)
on a scatter plot in 3D.
By default the data is listed under the key 'data'.
Returns:
A dictionary mapping between coordinate names and their coordinate arrays.
Additionally, there is a key "data" which maps to the `.values` attribute of the array.
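
        Example:
            A sketch of 3D scatter plotting, assuming hypothetical 2D `data` with dims 'X'
            and 'Y' and a matplotlib 3D axis `ax`:

            >>> r = data.G.ravel()                     # doctest: +SKIP
            >>> ax.scatter(r["X"], r["Y"], r["data"])  # doctest: +SKIP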
"""
assert isinstance(self._obj, xr.DataArray)
dims = self._obj.dims
coords_as_list = [self._obj.coords[d].values for d in dims]
raveled_coordinates = dict(zip(dims, [cs.ravel() for cs in np.meshgrid(*coords_as_list)]))
assert "data" not in raveled_coordinates
raveled_coordinates["data"] = self._obj.values.ravel()
return raveled_coordinates
def meshgrid(self, as_dataset=False):
assert isinstance(self._obj, xr.DataArray)
dims = self._obj.dims
coords_as_list = [self._obj.coords[d].values for d in dims]
meshed_coordinates = dict(zip(dims, [cs for cs in np.meshgrid(*coords_as_list)]))
assert "data" not in meshed_coordinates
meshed_coordinates["data"] = self._obj.values
if as_dataset:
# this could use a bit of cleaning up
faked = ["x", "y", "z", "w"]
meshed_coordinates = {
k: (faked[: len(v.shape)], v) for k, v in meshed_coordinates.items() if k != "data"
}
return xr.Dataset(meshed_coordinates)
return meshed_coordinates
def to_arrays(self) -> Tuple[np.ndarray, np.ndarray]:
"""Converts a (1D) `xr.DataArray` into two plain ``ndarray``s of their coordinate and data.
Useful for rapidly converting into a format than can be `plt.scatter`ed
or similar.
Example:
We can use this to quickly scatter a 1D dataset where one axis is the coordinate value.
>>> plt.scatter(*data.G.as_arrays(), marker='s') # doctest: +SKIP
Returns:
A tuple of the coordinate array (first index) and the data array (second index)
"""
assert len(self._obj.dims) == 1
return (self._obj.coords[self._obj.dims[0]].values, self._obj.values)
    def clean_outliers(self, clip=0.5):
        """Clips the data into the [clip, 100 - clip] percentile window."""
low, high = np.percentile(self._obj.values, [clip, 100 - clip])
copied = self._obj.copy(deep=True)
copied.values[copied.values < low] = low
copied.values[copied.values > high] = high
return copied
    def as_movie(self, time_dim=None, pattern="{}.png", **kwargs):
        """Renders the data as a movie, animating over `time_dim` (the last dim by default)."""
if time_dim is None:
time_dim = self._obj.dims[-1]
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = pattern.format("{}_animation".format(self._obj.S.label))
kwargs["out"] = out
return plotting.plot_movie(self._obj, time_dim, **kwargs)
def filter_coord(
self, coordinate_name: str, sieve: Callable[[Any, xr.DataArray], bool]
) -> xr.DataArray:
"""Filters a dataset along a coordinate.
Sieve should be a function which accepts a coordinate value and the slice
of the data along that dimension.
Internally, the predicate function `sieve` is applied to the coordinate and slice to generate
a mask. The mask is used to select from the data after iteration.
An improvement here would support filtering over several coordinates.
Args:
coordinate_name: The coordinate which should be filtered.
sieve: A predicate to be applied to the coordinate and data at that coordinate.
Returns:
A subset of the data composed of the slices which make the `sieve` predicate `True`.
"""
        mask = np.array(
            [
                i
                for i, c in enumerate(self._obj.coords[coordinate_name])
                if sieve(c, self._obj.isel(**{coordinate_name: i}))
            ]
        )
        return self._obj.isel(**{coordinate_name: mask})
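    # Usage sketch (illustrative; `data` is a hypothetical array with a
    # "temperature" dimension): keep only the slices measured below 40 K.
    #
    #     cold = data.G.filter_coord("temperature", lambda c, _: c < 40)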
    def iterate_axis(self, axis_name_or_axes):
        """Iterates over one or several axes, yielding (coordinate dict, marginal) pairs.

        Accepts a dimension name, an integer dimension index, or a list of names.
        """
if isinstance(axis_name_or_axes, int):
axis_name_or_axes = self._obj.dims[axis_name_or_axes]
if isinstance(axis_name_or_axes, str):
axis_name_or_axes = [axis_name_or_axes]
coord_iterators = [self._obj.coords[d].values for d in axis_name_or_axes]
for indices in itertools.product(*[range(len(c)) for c in coord_iterators]):
cut_coords = [cs[index] for cs, index in zip(coord_iterators, indices)]
coords_dict = dict(zip(axis_name_or_axes, cut_coords))
yield coords_dict, self._obj.sel(method="nearest", **coords_dict)
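    # Usage sketch (illustrative; `data` is a hypothetical array over
    # ("cycle", "eV")): visit each cycle's marginal in turn.
    #
    #     for coords, marginal in data.G.iterate_axis("cycle"):
    #         print(coords["cycle"], float(marginal.mean()))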
    def map_axes(self, axes, fn, dtype=None, **kwargs):
        """Maps `fn` across the marginals along `axes`, writing the results into a copy.

        `fn` receives (marginal, coordinate dict) pairs; pass `dtype` when the
        results have a different dtype than the input data.
        """
if isinstance(self._obj, xr.Dataset):
raise TypeError(
"map_axes can only work on xr.DataArrays for now because of "
"how the type inference works"
)
obj = self._obj.copy(deep=True)
if dtype is not None:
obj.values = np.ndarray(shape=obj.values.shape, dtype=dtype)
type_assigned = False
for coord, value in self.iterate_axis(axes):
new_value = fn(value, coord)
if dtype is None:
if not type_assigned:
obj.values = np.ndarray(shape=obj.values.shape, dtype=new_value.data.dtype)
type_assigned = True
obj.loc[coord] = new_value.values
else:
obj.loc[coord] = new_value
return obj
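    # Usage sketch (illustrative; `data` is a hypothetical array over
    # ("cycle", "eV")): normalize each cycle's marginal to unit maximum.
    #
    #     normalized = data.G.map_axes("cycle", lambda marginal, c: marginal / marginal.max())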
def transform(
self,
axes: Union[str, List[str]],
transform_fn: Callable,
dtype: DTypeLike = None,
*args,
**kwargs,
):
"""Applies a vectorized operation across a subset of array axes.
Transform has similar semantics to matrix multiplication, the dimensions of the
output can grow or shrink depending on whether the transformation is size preserving,
grows the data, shinks the data, or leaves in place.
Examples:
As an example, let us suppose we have a function which takes the mean and
variance of the data:
[dimension], coordinate_value -> [{'mean', 'variance'}]
And a dataset with dimensions [X, Y]. Then calling transform
maps to a dataset with the same dimension X but where Y has been replaced by
the length 2 label {'mean', 'variance'}. The full dimensions in this case are
['X', {'mean', 'variance'}].
>>> data.G.transform('X', f).dims # doctest: +SKIP
["X", "mean", "variance"]
Please note that the transformed axes always remain in the data because they
are iterated over and cannot therefore be modified.
The transform function `transform_fn` must accept the coordinate of the
marginal at the currently iterated point.
Args:
axes: Dimension/axis or set of dimensions to iterate over
transform_fn: Transformation function that takes a DataArray into a new DataArray
dtype: An optional type hint for the transformed data. Defaults to None.
args: args to pass into transform_fn
kwargs: kwargs to pass into transform_fn
Raises:
            TypeError: When the underlying object is an `xr.Dataset` instead of an `xr.DataArray`.
This is due to a constraint related to type inference with a single passed dtype.
Returns:
The data consisting of applying `transform_fn` across the specified axes.
"""
if isinstance(self._obj, xr.Dataset):
raise TypeError(
"transform can only work on xr.DataArrays for now because of "
"how the type inference works"
)
dest = None
for coord, value in self.iterate_axis(axes):
            new_value = transform_fn(value, coord, *args, **kwargs)
            if dest is None:
                # allocate the destination on the first pass, once the shape and
                # dtype of the transformed marginal are known
original_dims = [d for d in self._obj.dims if d not in value.dims]
original_shape = [self._obj.shape[self._obj.dims.index(d)] for d in original_dims]
original_coords = {k: v for k, v in self._obj.coords.items() if k not in value.dims}
full_shape = original_shape + list(new_value.shape)
new_coords = original_coords
new_coords.update(
{k: v for k, v in new_value.coords.items() if k not in original_coords}
)
new_dims = original_dims + list(new_value.dims)
dest = xr.DataArray(
np.zeros(full_shape, dtype=dtype or new_value.data.dtype),
coords=new_coords,
dims=new_dims,
)
dest.loc[coord] = new_value
return dest
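    # A sketch of the docstring example above (names are illustrative): the
    # transformed marginal carries the new "statistic" dimension, which replaces
    # the marginal's original dimensions in the output.
    #
    #     def mean_and_variance(marginal, coord):
    #         return xr.DataArray(
    #             np.array([marginal.values.mean(), marginal.values.var()]),
    #             coords={"statistic": ["mean", "variance"]},
    #             dims=["statistic"],
    #         )
    #
    #     summary = data.G.transform("X", mean_and_variance)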
    def map(self, fn, **kwargs):
        """Applies `fn` elementwise to the data by way of `np.vectorize`."""
return apply_dataarray(self._obj, np.vectorize(fn, **kwargs))
    def enumerate_iter_coords(self):
        """Iterates over every point, yielding (integer indices, coordinate dict) pairs."""
coords_list = [self._obj.coords[d].values for d in self._obj.dims]
for indices in itertools.product(*[range(len(c)) for c in coords_list]):
cut_coords = [cs[index] for cs, index in zip(coords_list, indices)]
yield indices, dict(zip(self._obj.dims, cut_coords))
    def iter_coords(self, dim_names=None):
        """Iterates over the coordinates of `dim_names` (all dims by default), yielding dicts."""
        if dim_names is None:
            dim_names = self._obj.dims
        for ts in itertools.product(*[self._obj.coords[d].values for d in dim_names]):
            yield dict(zip(dim_names, ts))
    def range(self, generic_dim_names=True):
        """Collects the (min, max) coordinate range along each dimension."""
indexed_coords = [self._obj.coords[d] for d in self._obj.dims]
indexed_ranges = [(np.min(coord.values), np.max(coord.values)) for coord in indexed_coords]
dim_names = self._obj.dims
if generic_dim_names:
dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)]
return dict(zip(dim_names, indexed_ranges))
    def stride(self, *args, generic_dim_names=True):
        """Collects the coordinate spacing along each dimension, assuming uniform spacing.

        Positional arguments can be used to select the stride for one or several
        named dimensions only.
        """
indexed_coords = [self._obj.coords[d] for d in self._obj.dims]
indexed_strides = [coord.values[1] - coord.values[0] for coord in indexed_coords]
dim_names = self._obj.dims
if generic_dim_names:
dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)]
result = dict(zip(dim_names, indexed_strides))
if args:
if len(args) == 1:
if not isinstance(args[0], str):
# if passed list of strs as argument
result = [result[selected_names] for selected_names in args[0]]
else:
# if passed single name as argument
result = result[args[0]]
else:
# if passed several names as arguments
result = [result[selected_names] for selected_names in args]
return result
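    # Usage sketch (illustrative; `data` is a hypothetical array over
    # ("eV", "phi")):
    #
    #     data.G.stride()                               # {"x": d_eV, "y": d_phi}
    #     data.G.stride("eV", generic_dim_names=False)  # d_eV only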
    def shift_by(self, other: xr.DataArray, shift_axis=None, zero_nans=True, shift_coords=False):
        """Shifts the data along `shift_axis`, compensating the per-slice offsets in `other`.

        For now we only support shifting by a one dimensional array; its single
        dimension determines which slice of the data receives which shift.
        """
data = self._obj.copy(deep=True)
by_axis = other.dims[0]
assert len(other.dims) == 1
assert len(other.coords[by_axis]) == len(data.coords[by_axis])
if shift_coords:
mean_shift = np.mean(other.values)
            other = other - mean_shift  # subtract without mutating the caller's array
if shift_axis is None:
option_dims = list(data.dims)
option_dims.remove(by_axis)
assert len(option_dims) == 1
shift_axis = option_dims[0]
shift_amount = -other.values / data.G.stride(generic_dim_names=False)[shift_axis]
shifted_data = arpes.utilities.math.shift_by(
data.values,
shift_amount,
axis=list(data.dims).index(shift_axis),
by_axis=list(data.dims).index(by_axis),
order=1,
)
if zero_nans:
shifted_data[np.isnan(shifted_data)] = 0
built_data = xr.DataArray(
shifted_data,
data.coords,
data.dims,
attrs=data.attrs.copy(),
)
if shift_coords:
            built_data = built_data.assign_coords(
                **{shift_axis: data.coords[shift_axis] + mean_shift}
            )
return built_data
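    # Usage sketch (illustrative): flattening a Fermi edge. `edge_positions` is
    # a hypothetical 1D DataArray over "phi" holding the fitted edge position at
    # each angle; the spectrum is shifted along "eV" to align the edge.
    #
    #     flattened = spectrum.G.shift_by(edge_positions, shift_axis="eV")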
def __init__(self, xarray_obj: DataType):
self._obj = xarray_obj
@xr.register_dataarray_accessor("X")
class SelectionToolAccessor:
_obj = None
def __init__(self, xarray_obj: DataType):
self._obj = xarray_obj
    def last_exceeding(self, dim, value: float, relative=False) -> xr.DataArray:
        """As `first_exceeding`, but scanning from the far end of `dim`."""
        return self.first_exceeding(dim, value, relative=relative, reverse=True)
@xr.register_dataset_accessor("S")
class ARPESDatasetAccessor(ARPESAccessorBase):
"""Spectrum related accessor for `xr.Dataset`."""
def __getattr__(self, item: str) -> Any:
"""Forward attribute access to the spectrum, if necessary.
Args:
item: Attribute name
Returns:
The attribute after lookup on the default spectrum
"""
return getattr(self._obj.S.spectrum.S, item)
def polarization_plot(self, **kwargs) -> plt.Axes:
"""Creates a spin polarization plot.
Returns:
The axes which were plotted onto for customization.
"""
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = "{}_spin_polarization.png".format(self.label)
kwargs["out"] = out
return plotting.spin_polarized_spectrum(self._obj, **kwargs)
@property
def is_spatial(self) -> bool:
"""Predicate indicating whether the dataset is a spatial scanning dataset.
Returns:
True if the dataset has dimensions indicating it is a spatial scan.
False otherwise
"""
try:
return self.spectrum.S.is_spatial
except Exception:
# self.spectrum may be None, in which case it is not a spatial scan
return False
@property
def spectrum(self) -> Optional[xr.DataArray]:
"""Isolates a single spectrum from a dataset.
This is a convenience method which is typically used in startup for
tools and analysis routines which need to operate on a single
piece of data. As an example, the image browser `qt_tool` needs
an `xr.DataArray` to operate but will accept an `xr.Dataset`
which it will attempt to resolve to a single spectrum.
In practice, we filter data variables by whether they contain "spectrum"
in the name before selecting the one with the largest pixel volume.
This is a heuristic which tries to guarantee we select ARPES data
above XPS data, if they were collected together.
Returns:
A spectrum found in the dataset, if one can be isolated.
In the case that several candidates are found, a single spectrum
is selected among the candidates.
Attributes from the parent dataset are assigned onto the selected
array as a convenience.
"""
spectrum = None
if "spectrum" in self._obj.data_vars:
spectrum = self._obj.spectrum
elif "raw" in self._obj.data_vars:
spectrum = self._obj.raw
elif "__xarray_dataarray_variable__" in self._obj.data_vars:
spectrum = self._obj.__xarray_dataarray_variable__
else:
candidates = self.spectra
if candidates:
spectrum = candidates[0]
best_volume = np.prod(spectrum.shape)
for c in candidates[1:]:
volume = np.prod(c.shape)
if volume > best_volume:
spectrum = c
best_volume = volume
if spectrum is not None and "df" not in spectrum.attrs:
spectrum.attrs["df"] = self._obj.attrs.get("df", None)
return spectrum
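    # Usage sketch (illustrative; `ds` is a hypothetical dataset): resolve to a
    # single DataArray before using tools which require one.
    #
    #     arr = ds.S.spectrum
    #     if arr is not None:
    #         print(arr.dims, arr.shape)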
@property
def spectra(self) -> List[xr.DataArray]:
"""Collects the variables which are likely spectra.
Returns:
The subset of the data_vars which have dimensions indicating
that they are spectra.
"""
spectra = []
for dv in list(self._obj.data_vars):
if "spectrum" in dv:
spectra.append(self._obj[dv])
return spectra
@property
def spectrum_type(self) -> str:
"""Gives a heuristic estimate of what kind of data is contained by the spectrum.
Returns:
The kind of data, coarsely
"""
try:
# this isn't the smartest thing in the world,
# but it should allow some old code to keep working on datasets transparently
return self.spectrum.S.spectrum_type
except Exception:
return "dataset"
@property
def degrees_of_freedom(self) -> Set[str]:
"""The collection of all degrees of freedom.
Equivalently, dimensions on a piece of data.
Returns:
All degrees of freedom as a set.
"""
return set(self.spectrum.dims)
@property
def spectrum_degrees_of_freedom(self) -> Set[str]:
"""Collects the spectrometer degrees of freedom.
Spectrometer degrees of freedom are any which would be collected by an ARToF
and their momentum equivalents.
Returns:
The collection of spectrum degrees of freedom.
"""
return self.degrees_of_freedom.intersection({"eV", "phi", "pixel", "kx", "kp", "ky"})
@property
def scan_degrees_of_freedom(self) -> Set[str]:
"""Collects the scan degrees of freedom.
Scan degrees of freedom are all of the degrees of freedom which are not recorded
by the spectrometer but are "scanned over". This includes spatial axes,
temperature, etc.
Returns:
The collection of scan degrees of freedom represented in the array.
"""
return self.degrees_of_freedom.difference(self.spectrum_degrees_of_freedom)
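    # Usage sketch (illustrative; `ds` is a hypothetical temperature scan whose
    # spectrum has dims ("temperature", "phi", "eV")):
    #
    #     ds.S.spectrum_degrees_of_freedom  # {"phi", "eV"}
    #     ds.S.scan_degrees_of_freedom      # {"temperature"}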
def reference_plot(self, **kwargs):
"""Creates reference plots for a dataset.
A bit of a misnomer because this actually makes many plots. For full datasets,
the relevant components are:
#. Temperature as function of scan DOF
#. Photocurrent as a function of scan DOF
#. Photocurrent normalized + unnormalized figures, in particular
#. The reference plots for the photocurrent normalized spectrum
#. The normalized total cycle intensity over scan DoF, i.e. cycle vs scan DOF
integrated over E, phi
#. For delay scans:
#. Fermi location as a function of scan DoF, integrated over phi
#. Subtraction scans
#. For spatial scans:
#. energy/angle integrated spatial maps with subsequent measurements indicated
#. energy/angle integrated FS spatial maps with subsequent measurements indicated
Args:
kwargs: Passed to plotting routines to provide user control
"""
        scan_dofs_integrated = self._obj.sum(list(self.scan_degrees_of_freedom))
original_out = kwargs.get("out")
# make figures for temperature, photocurrent, delay
make_figures_for = ["T", "IG_nA", "current", "photocurrent"]
name_normalization = {
"T": "T",
"IG_nA": "photocurrent",
"current": "photocurrent",
}
for figure_item in make_figures_for:
if figure_item not in self._obj.data_vars:
continue
name = name_normalization.get(figure_item, figure_item)
data_var = self._obj[figure_item]
out = "{}_{}_spec_integrated_reference.png".format(self.label, name)
plotting.scan_var_reference_plot(data_var, title="Reference {}".format(name), out=out)
# may also want to make reference figures summing over cycle, or summing over beta
# make photocurrent normalized figures
try:
normalized = self._obj / self._obj.IG_nA
normalized.S.make_spectrum_reference_plots(prefix="norm_PC_", out=True)
        except Exception:
            # the dataset may not record a photocurrent (IG_nA) channel
            pass
self.make_spectrum_reference_plots(out=True)
def make_spectrum_reference_plots(self, prefix: str = "", **kwargs):
"""Creates photocurrent normalized + unnormalized figures.
Creates:
#. The reference plots for the photocurrent normalized spectrum
#. The normalized total cycle intensity over scan DoF, i.e. cycle vs scan DOF integrated over E, phi
#. For delay scans:
#. Fermi location as a function of scan DoF, integrated over phi
#. Subtraction scans
Args:
prefix: A prefix inserted into filenames to make them unique.
kwargs: Passed to plotting routines to provide user control over plotting
behavior
"""
self.spectrum.S.reference_plot(pattern=prefix + "{}.png", **kwargs)
if self.is_spatial:
referenced = self.referenced_scans
if "cycle" in self._obj.coords:
            integrated_over_scan = self._obj.sum(list(self.spectrum_degrees_of_freedom))
integrated_over_scan.S.spectrum.S.reference_plot(
pattern=prefix + "sum_spec_DoF_{}.png", **kwargs
)
if "delay" in self._obj.coords:
dims = self.spectrum_degrees_of_freedom
dims.remove("eV")
            angle_integrated = self._obj.sum(list(dims))
# subtraction scan
self.spectrum.S.subtraction_reference_plots(pattern=prefix + "{}.png", **kwargs)
angle_integrated.S.fermi_edge_reference_plots(pattern=prefix + "{}.png", **kwargs)
def __init__(self, xarray_obj: xr.Dataset):
"""Initialization hook for xarray.
This should never need to be called directly.
Args:
xarray_obj: The parent object which this is an accessor for
"""
super().__init__(xarray_obj)
self._spectrum = None