"""Helper functions for coordinate transformations and user/analysis API.
All the functions here assume standard polar angles, as given in the
`data model documentation <https://arpes.readthedocs.io/spectra>`_.
Functions here must accept constants or numpy arrays as valid inputs,
so all standard math functions have been replaced by their equivalents out
of numpy. Array broadcasting should handle any issues or weirdnesses that
would encourage the use of direct iteration, but in case you need to write
a conversion directly, be aware that any functions here must work on arrays
as well for consistency with client code.
Everywhere:
Kinetic energy -> 'kinetic_energy'
Binding energy -> 'eV', for convenience (negative below 0)
Photon energy -> 'hv'
Better facilities should be added for ToFs to do simultaneous (timing, angle) to (binding energy, k-space).
"""
from arpes.utilities.conversion.grids import (
determine_axis_type,
determine_momentum_axes_from_measurement_axes,
is_dimension_unconvertible,
)
from .fast_interp import Interpolator
from arpes.trace import traceable
import collections
import warnings
import numpy as np
import scipy.interpolate
import xarray as xr
from arpes.provenance import provenance, update_provenance
from arpes.utilities import normalize_to_spectrum
from typing import Callable, Optional, Union
from .kx_ky_conversion import ConvertKxKy, ConvertKp
from .kz_conversion import ConvertKpKz
# Names exported by `from <this module> import *`; the remaining helpers in
# this file stay importable by explicit name.
__all__ = ["convert_to_kspace", "slice_along_path"]
@traceable
def grid_interpolator_from_dataarray(
    arr: xr.DataArray,
    fill_value=0.0,
    method="linear",
    bounds_error=False,
    trace: Callable = None,
):
    """Builds a regular-grid interpolator over the contents of an xarray.DataArray.

    Axes whose coordinate decreases (judged from its first two entries) are
    reversed, together with the data along that axis, so that the grid handed
    to the interpolator is ascending on every axis.

    Args:
        arr: The gridded data to interpolate.
        fill_value: Forwarded to `scipy.interpolate.RegularGridInterpolator`.
            Not used when `method == "linear"`.
        method: `"linear"` selects `fast_interp.Interpolator`; any other value
            is forwarded to `scipy.interpolate.RegularGridInterpolator`.
        bounds_error: Forwarded to `scipy.interpolate.RegularGridInterpolator`.
            Not used when `method == "linear"`.
        trace: Logging callable used to report progress.
            NOTE(review): presumably injected by the `traceable` decorator - confirm.

    Returns:
        An `Interpolator` for linear interpolation, otherwise a
        `scipy.interpolate.RegularGridInterpolator`.
    """
    # Collect the names of every axis that runs high -> low.
    descending = {
        name
        for name in arr.dims
        if len(arr.coords[name]) > 1 and arr.coords[name][1] - arr.coords[name][0] < 0
    }

    data = arr.values

    trace("Flipping axes")
    for name in descending:
        data = np.flip(data, arr.dims.index(name))

    # Reverse the coordinate vectors of the flipped axes to match the data.
    grid_axes = []
    for name in arr.dims:
        axis_values = arr.coords[name].values
        grid_axes.append(axis_values[::-1] if name in descending else axis_values)

    trace_size = [len(axis_values) for axis_values in grid_axes]

    if method != "linear":
        trace(f"Calling scipy.interpolate.RegularGridInterpolator: size {trace_size}")
        return scipy.interpolate.RegularGridInterpolator(
            points=grid_axes,
            values=data,
            bounds_error=bounds_error,
            fill_value=fill_value,
            method=method,
        )

    trace(f"Using fast_interp.Interpolator: size {trace_size}")
    return Interpolator.from_arrays(grid_axes, data)
def slice_along_path(
    arr: xr.DataArray,
    interpolation_points=None,
    axis_name=None,
    resolution=None,
    shift_gamma=True,
    n_points: Optional[int] = None,
    extend_to_edge=False,
    **kwargs,
):
    """Gets a cut along a path specified by waypoints in an array.

    Interpolates along a path through a volume. If the volume is higher dimensional than the
    desired path, the interpolation is broadcasted along the free dimensions. This allows one to
    specify a k-space path and receive the band structure along this path in k-space.

    Points can either be specified by coordinates, or by reference to symmetry points, should
    they exist in the source array (`arr.attrs["symmetry_points"]`). These symmetry points are
    translated to regular coordinates immediately, but are provided as a convenience.

    If not all points specify the same set of coordinates, an attempt will be made to unify the
    coordinates. As an example, if the specified path is (kx=0, ky=0, T=20) -> (kx=1, ky=1), the
    path will be made between (kx=0, ky=0, T=20) -> (kx=1, ky=1, T=20). On the other hand, the
    path (kx=0, ky=0, T=20) -> (kx=1, ky=1, T=40) -> (kx=0, ky=1) will result in an error because
    there is no way to break the ambiguity on the temperature for the last coordinate.

    A reasonable value will be chosen for the resolution, near the maximum resolution of any of
    the interpolated axes by default.

    This function transparently handles the entire path. An alternate approach would be to
    convert each segment separately and concatenate the interpolated axis with xarray.

    If the sentinel value 'G' for the Gamma point is included in the interpolation points, the
    coordinate axis of the interpolated coordinate will be shifted so that its value at the Gamma
    point is 0. You can opt out of this with the parameter 'shift_gamma'.

    The waypoint dictionaries passed in (and those stored in `arr.attrs`) are copied before any
    missing coordinates are filled in, so they are never modified.

    Args:
        arr: Source data
        interpolation_points: Path vertices, each either a mapping of dimension name to
            coordinate value or the name of a symmetry point.
        axis_name: Label for the interpolated axis. Under special circumstances a reasonable
            name will be chosen, such as when the interpolation dimensions are kx and ky: in
            this case the interpolated dimension will be labeled kp. In mixed or ambiguous
            situations the axis will be labeled by the default value 'inter'.
        resolution: Requested resolution along the interpolated axis. When given, it determines
            the number of output points (and takes precedence over `n_points`).
        shift_gamma: Controls whether the interpolated axis is shifted to a value of 0 at Gamma.
        n_points: The number of desired points along the output path. This will be inferred
            approximately based on resolution if not provided.
        extend_to_edge: Controls whether or not to scale the vector S - G for symmetry point S
            so that you interpolate to the edge of the available data.
        **kwargs: Accepted for compatibility; not used.

    Raises:
        ValueError: If no path is given, if fewer than two waypoints are given, or if a
            coordinate missing from a waypoint cannot be filled in unambiguously.

    Returns:
        xr.Dataset whose "data" variable contains the interpolated data; the original
        coordinates along the path are included as additional variables.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in `collections.abc`.
    from collections.abc import Iterable

    if interpolation_points is None:
        raise ValueError("You must provide points specifying an interpolation path")

    def extract_symmetry_point(name):
        raw_point = arr.attrs["symmetry_points"][name]

        if not extend_to_edge or name == "G":
            return raw_point

        # Gamma is only needed (and only required to exist) when extending to the edge.
        gamma = arr.attrs["symmetry_points"]["G"]

        # scale the point so that it reaches the edge of the dataset
        point_dims = [d for d in arr.dims if d in raw_point]
        S = np.array([raw_point[d] for d in point_dims])
        G = np.array([gamma[d] for d in point_dims])

        scale_factor = np.inf
        for i, d in enumerate(point_dims):
            dS = (S - G)[i]
            if np.abs(dS) < 0.001:
                continue

            coord = arr.coords[d]
            edge = np.min(coord) if dS < 0 else np.max(coord)
            required_scale = (edge - G[i]) / dS
            if required_scale < scale_factor:
                scale_factor = float(required_scale)

        if not np.isfinite(scale_factor):
            # The point coincides with Gamma along every axis: there is no direction to extend.
            return raw_point

        S = (S - G) * scale_factor + G
        return dict(zip(point_dims, S))

    # Copy every waypoint: missing coordinates are filled in below and the
    # caller's dictionaries / `arr.attrs` must not be modified as a side effect.
    parsed_interpolation_points = [
        dict(x)
        if isinstance(x, Iterable) and not isinstance(x, str)
        else dict(extract_symmetry_point(x))
        for x in interpolation_points
    ]

    if len(parsed_interpolation_points) < 2:
        raise ValueError("At least two waypoints are required to specify an interpolation path")

    free_coordinates = list(arr.dims)
    seen_coordinates = collections.defaultdict(set)
    for point in parsed_interpolation_points:
        for coord, value in point.items():
            seen_coordinates[coord].add(value)
            if coord in free_coordinates:
                free_coordinates.remove(coord)

    # Broadcast coordinates that only some waypoints specify, when unambiguous.
    for point in parsed_interpolation_points:
        for coord, values in seen_coordinates.items():
            if coord not in point:
                if len(values) != 1:
                    raise ValueError(
                        "Ambiguous interpolation waypoint broadcast at dimension {}".format(coord)
                    )
                point[coord] = list(values)[0]

    if axis_name is None:
        try:
            axis_name = determine_axis_type(seen_coordinates.keys())
        except KeyError:
            axis_name = "inter"

        if axis_name in ("angle", "inter"):
            warnings.warn(
                "Interpolating along axes with different dimensions "
                "will not include Jacobian correction factor."
            )

    converted_dims = free_coordinates + [axis_name]

    path_segments = list(zip(parsed_interpolation_points, parsed_interpolation_points[1:]))

    def element_distance(waypoint_a, waypoint_b):
        delta = np.array([waypoint_a[k] - waypoint_b[k] for k in waypoint_a.keys()])
        return np.linalg.norm(delta)

    def required_sampling_density(waypoint_a, waypoint_b):
        ks = waypoint_a.keys()
        dist = element_distance(waypoint_a, waypoint_b)
        delta = np.array([waypoint_a[k] - waypoint_b[k] for k in ks])
        delta_idx = [abs(d / (arr.coords[k][1] - arr.coords[k][0])) for d, k in zip(delta, ks)]
        return dist / np.max(delta_idx)

    # Approximate how many points we should use
    segment_lengths = [element_distance(*segment) for segment in path_segments]
    path_length = sum(segment_lengths)

    gamma_offset = 0  # offset the gamma point to a k coordinate of 0 if possible
    if "G" in interpolation_points and shift_gamma:
        gamma_offset = sum(segment_lengths[0 : interpolation_points.index("G")])

    if resolution is not None:
        # An explicit resolution fixes the point count.
        n_points = int(path_length / resolution)
    elif n_points is None:
        resolution = np.min([required_sampling_density(*segment) for segment in path_segments])
        n_points = int(path_length / resolution)
    # Otherwise only `n_points` was given and is used directly.

    last_segment_index = len(segment_lengths) - 1

    def converter_for_coordinate_name(name):
        def raw_interpolator(*coordinates):
            return coordinates[free_coordinates.index(name)]

        if name in free_coordinates:
            return raw_interpolator

        # Conversion involves the interpolated coordinates
        def interpolated_coordinate_to_raw(*coordinates):
            # Coordinate order is [*free_coordinates, interpolated]
            interpolated = coordinates[len(free_coordinates)] + gamma_offset

            # Start with empty array that we will mask writes onto
            # We need to go with a masking approach rather than a concatenation based one
            # because the coordinates come from np.meshgrid
            dest_coordinate = np.zeros(shape=interpolated.shape)

            start = 0
            for i, length in enumerate(segment_lengths):
                end = start + length
                normalized = (interpolated - start) / length
                seg_start, seg_end = path_segments[i]
                dim_start, dim_end = seg_start[name], seg_end[name]

                if i == last_segment_index:
                    # The final segment is closed on the right so that the end of the path
                    # maps to the last waypoint instead of being left at 0. Clamping guards
                    # against the end point landing marginally past 1 through rounding.
                    mask = normalized >= 0
                    normalized = np.minimum(normalized, 1)
                else:
                    mask = np.logical_and(normalized >= 0, normalized < 1)

                dest_coordinate[mask] = (
                    dim_start * (1 - normalized[mask]) + dim_end * normalized[mask]
                )
                start = end

            return dest_coordinate

        return interpolated_coordinate_to_raw

    converted_coordinates = {d: arr.coords[d].values for d in free_coordinates}

    # Adjust this coordinate under special circumstances
    converted_coordinates[axis_name] = np.linspace(0, path_length, n_points) - gamma_offset

    converted_ds = convert_coordinates(
        arr,
        converted_coordinates,
        {
            "dims": converted_dims,
            "transforms": dict(zip(arr.dims, [converter_for_coordinate_name(d) for d in arr.dims])),
        },
        as_dataset=True,
    )

    if axis_name in arr.dims and len(parsed_interpolation_points) == 2:
        if parsed_interpolation_points[1][axis_name] < parsed_interpolation_points[0][axis_name]:
            # swap the sign on this axis as a convenience to the caller
            converted_ds.coords[axis_name].data = -converted_ds.coords[axis_name].data

    if "id" in converted_ds.attrs:
        del converted_ds.attrs["id"]
        provenance(
            converted_ds,
            arr,
            {
                "what": "Slice along path",
                "by": "slice_along_path",
                "parsed_interpolation_points": parsed_interpolation_points,
                "interpolation_points": interpolation_points,
            },
        )

    return converted_ds
@update_provenance("Automatically k-space converted")
@traceable
def convert_to_kspace(
    arr: xr.DataArray,
    bounds=None,
    resolution=None,
    calibration=None,
    coords=None,
    allow_chunks: bool = False,
    trace: Callable = None,
    **kwargs,
):
    """Converts volumetric data to momentum space ("backwards"). Typically what you want.

    Works in general by regridding the data into the new coordinate space and then
    interpolating back into the original data.

    For forward conversion, see sibling methods. Forward conversion works by
    converting the coordinates, rather than by interpolating
    the data. As a result, the data will be totally unchanged by the conversion
    (if we do not apply a Jacobian correction), but the coordinates will no
    longer have equal spacing.

    This is only really useful for zero and one dimensional data because for two dimensional data,
    the coordinates must become two dimensional in order to fully specify every data point
    (this is true in generality, in 3D the coordinates must become 3D as well).

    The only exception to this is if the extra axes do not need to be k-space converted. As is the
    case where one of the dimensions is `cycle` or `delay`, for instance.

    You can request a particular resolution for the new data with the `resolution=` parameter,
    or a specific set of bounds with the `bounds=` parameter.

    Examples:
        Convert a 2D cut with automatically inferred range and resolution.

        >>> convert_to_kspace(arpes.io.load_example_data())  # doctest: +SKIP
        xr.DataArray(...)

        Convert a 3D map with a specified momentum window

        >>> convert_to_kspace(  # doctest: +SKIP
        ...     fermi_surface_map,
        ...     kx=np.linspace(-1, 1, 200),
        ...     ky=np.linspace(-1, 1, 350),
        ... )
        xr.DataArray(...)

    Args:
        arr (xr.DataArray): The data to convert. A Dataset is accepted with a warning; its
            spectrum is extracted with `normalize_to_spectrum` and the attributes copied over.
        bounds (optional): Forwarded to the converter's `get_coordinates`. Defaults to None.
        resolution (optional): Forwarded to the converter's `get_coordinates`. Defaults to None.
        calibration (optional): Forwarded to the converter constructor. Defaults to None.
        coords (dict, optional): Explicit target coordinates, keyed by output dimension name,
            which override the inferred ones. The passed dictionary is not modified.
            Defaults to None.
        allow_chunks (bool, optional): Convert in chunks along `eV` when the array has more
            than 50 energies and is large enough to need more than one chunk.
            Defaults to False.
        trace (Callable, optional): Controls whether to use execution tracing. Defaults to None.
            Pass `True` to enable.
        **kwargs: Additional target coordinates, merged into `coords`.

    Raises:
        NotImplementedError: If there is no converter for the combination of dimensions in `arr`.
        ValueError: If `coords`/`kwargs` name a coordinate that is not produced by the conversion.

    Returns:
        xr.DataArray: The converted data, or `arr` itself if none of its dimensions
        require conversion.
    """
    # Work on a private copy: the caller's dictionary must not pick up the kwargs.
    coords = dict(coords) if coords is not None else {}
    coords.update(kwargs)

    trace("Normalizing to spectrum")
    if isinstance(arr, xr.Dataset):
        warnings.warn(
            "Remember to use a DataArray not a Dataset, attempting to extract spectrum and copy attributes."
        )
        attrs = arr.attrs.copy()
        arr = normalize_to_spectrum(arr)
        arr.attrs.update(attrs)

    has_eV = "eV" in arr.dims

    # Chunking logic
    if allow_chunks and has_eV and len(arr.eV) > 50:
        DESIRED_CHUNK_SIZE = 1000 * 1000 * 20
        n_chunks = int(np.prod(arr.shape) // DESIRED_CHUNK_SIZE)
        if n_chunks > 100:
            warnings.warn("Input array is very large. Please consider resampling.")

        # Arrays below the chunk size give n_chunks == 0 (division by zero below) and a
        # single chunk gains nothing, so only split when more than one chunk is needed.
        if n_chunks > 1:
            chunk_thickness = max(len(arr.eV) // n_chunks, 1)

            trace(f"Chunking along energy: {n_chunks}, thickness {chunk_thickness}")

            finished = []
            for low_idx in range(0, len(arr.eV), chunk_thickness):
                chunk = arr.isel(eV=slice(low_idx, low_idx + chunk_thickness))

                if len(chunk.eV) == 1:
                    chunk = chunk.squeeze("eV")

                kchunk = convert_to_kspace(
                    chunk,
                    bounds=bounds,
                    resolution=resolution,
                    calibration=calibration,
                    coords=coords,
                    allow_chunks=False,
                    trace=trace,
                    **kwargs,
                )

                if "eV" not in kchunk.dims:
                    kchunk = kchunk.expand_dims("eV")

                finished.append(kchunk)

            return xr.concat(finished, dim="eV")

    # Chunking is finished here

    # TODO be smarter about the resolution inference
    trace("Determining dimensions and resolution")
    removed = [d for d in arr.dims if is_dimension_unconvertible(d)]
    old_dims = [d for d in arr.dims if not is_dimension_unconvertible(d)]

    # Energy gets put at the front as a standardization
    if "eV" in removed:
        removed.remove("eV")

    old_dims.sort()

    if not old_dims:
        # No need to convert, might be XPS or similar. Return before the coordinates
        # are swapped for index-like ones so the data comes back untouched.
        return arr

    trace("Replacing dummy coordinates with index-like ones.")
    # temporarily reassign coordinates for dimensions we will not
    # convert to "index-like" dimensions
    restore_index_like_coordinates = {r: arr.coords[r].values for r in removed}
    new_index_like_coordinates = {r: np.arange(len(arr.coords[r].values)) for r in removed}
    arr = arr.assign_coords(**new_index_like_coordinates)

    converted_dims = (
        (["eV"] if has_eV else [])
        + determine_momentum_axes_from_measurement_axes(old_dims)
        + removed
    )

    convert_cls = {
        ("phi",): ConvertKp,
        ("beta", "phi"): ConvertKxKy,
        ("phi", "theta"): ConvertKxKy,
        ("phi", "psi"): ConvertKxKy,
        # ('chi', 'phi',): ConvertKxKy,
        ("hv", "phi"): ConvertKpKz,
    }.get(tuple(old_dims))

    if convert_cls is None:
        raise NotImplementedError(
            "Cannot automatically convert dimensions {} to momentum space".format(tuple(old_dims))
        )

    converter = convert_cls(arr, converted_dims, calibration=calibration)

    trace("Converting coordinates")
    converted_coordinates = converter.get_coordinates(resolution=resolution, bounds=bounds)

    if not set(coords.keys()).issubset(converted_coordinates.keys()):
        extra = set(coords.keys()).difference(converted_coordinates.keys())
        raise ValueError("Unexpected passed coordinates: {}".format(extra))

    converted_coordinates.update(coords)

    trace("Calling convert_coordinates")
    result = convert_coordinates(
        arr,
        converted_coordinates,
        {
            "dims": converted_dims,
            "transforms": dict(zip(arr.dims, [converter.conversion_for(d) for d in arr.dims])),
        },
        trace=trace,
    )

    trace("Reassigning index-like coordinates.")
    result = result.assign_coords(**restore_index_like_coordinates)
    trace("Finished.")
    return result
@traceable
def convert_coordinates(
    arr: xr.DataArray,
    target_coordinates,
    coordinate_transform,
    as_dataset=False,
    trace: Callable = None,
):
    """Resamples `arr` onto a new coordinate grid by pulling the grid back into `arr`'s coordinates.

    The target grid is meshed, each target point is mapped to the source
    coordinates with the supplied transforms, and the source data is
    interpolated at those locations.

    Args:
        arr: The source data.
        target_coordinates: Mapping from output dimension name to the 1D coordinate values
            for that dimension.
        coordinate_transform: Dictionary with two entries: "dims", the ordered output
            dimension names, and "transforms", a mapping from each dimension of `arr` to a
            callable. Each callable receives the raveled, meshed target coordinates as
            positional arguments (in "dims" order) and returns the source coordinate.
        as_dataset: If True, return a Dataset which also contains the source coordinates
            evaluated on the target grid.
        trace: Logging callable used to report progress.
            NOTE(review): presumably injected by the `traceable` decorator - confirm.

    Returns:
        The interpolated xr.DataArray, or, if `as_dataset` is set, an xr.Dataset holding it
        as "data" alongside one variable for each source dimension that is not a key of
        `target_coordinates`.
    """
    ordered_source_dimensions = arr.dims
    target_dims = coordinate_transform["dims"]
    transforms = coordinate_transform["transforms"]

    trace("Instantiating grid interpolator.")
    grid_interpolator = grid_interpolator_from_dataarray(
        arr.transpose(*ordered_source_dimensions),
        fill_value=float("nan"),
        trace=trace,
    )
    trace("Finished instantiating grid interpolator.")

    # Skip the Jacobian correction for now
    # Convert the raw coordinate axes to a set of gridded points
    output_shape = [len(target_coordinates[d]) for d in target_dims]
    trace(f"Calling meshgrid: {output_shape}")
    meshed_coordinates = np.meshgrid(*[target_coordinates[dim] for dim in target_dims], indexing="ij")
    trace("Raveling coordinates")
    meshed_coordinates = [meshed_coord.ravel() for meshed_coord in meshed_coordinates]

    if "eV" not in arr.dims:
        # Without an energy axis, pass the energy offset (when one can be looked up)
        # to the transforms as their first argument.
        try:
            meshed_coordinates = [arr.S.lookup_offset_coord("eV")] + meshed_coordinates
        except ValueError:
            pass

    old_coord_names = [dim for dim in arr.dims if dim not in target_coordinates]
    old_coordinate_transforms = [transforms[dim] for dim in old_coord_names]

    trace("Calling coordinate transforms")

    def compute_coordinate(transform):
        return np.reshape(
            transform(*meshed_coordinates),
            output_shape,
            order="C",
        )

    old_dimensions = []
    for tr in old_coordinate_transforms:
        trace(f"Running transform {tr}")
        old_dimensions.append(compute_coordinate(tr))

    trace("Done running transforms.")

    ordered_transformations = [transforms[dim] for dim in arr.dims]

    trace("Pulling back coordinates")
    transformed_coordinates = []
    for tr in ordered_transformations:
        trace(f"Running transform {tr}")
        transformed_coordinates.append(tr(*meshed_coordinates))

    if not isinstance(grid_interpolator, Interpolator):
        # Otherwise the interpolator is the scipy one, which is given a single
        # (n_points, n_dims) array rather than a list of per-dimension arrays.
        transformed_coordinates = np.array(transformed_coordinates).T

    trace("Calling grid interpolator")
    converted_volume = grid_interpolator(transformed_coordinates)

    # Wrap it all up
    def acceptable_coordinate(c: Union[np.ndarray, xr.DataArray]) -> bool:
        # Currently we do this to filter out coordinates that are functions of the old angular
        # dimensions, we could forward convert these, but right now we do not
        dims = getattr(c, "dims", None)
        if dims is None:
            # Plain arrays carry no dimension names and are always kept.
            return True
        return set(dims).issubset(target_dims)

    trace("Bundling into DataArray")
    target_coordinates = {k: v for k, v in target_coordinates.items() if acceptable_coordinate(v)}
    data = xr.DataArray(
        np.reshape(converted_volume, output_shape, order="C"),
        target_coordinates,
        target_dims,
        attrs=arr.attrs,
    )

    if as_dataset:
        old_mapped_coords = [
            xr.DataArray(values, target_coordinates, target_dims, attrs=arr.attrs)
            for values in old_dimensions
        ]
        data_vars = {"data": data}
        data_vars.update(dict(zip(old_coord_names, old_mapped_coords)))
        return xr.Dataset(data_vars, attrs=arr.attrs)

    trace("Finished")
    return data