# Source code for arpes.plotting.band_tool
"""An interactive band selection tool used to initialize curve fits."""
import numpy as np
from bokeh import events
import xarray as xr
from arpes.analysis.band_analysis import fit_patterned_bands
from arpes.exceptions import AnalysisError
from arpes.models import band
from arpes.plotting.interactive_utils import CursorTool, SaveableTool
from arpes.utilities import normalize_to_spectrum
__all__ = ("BandTool",)
class BandTool(SaveableTool, CursorTool):
    """Two dimensional fitting band tool.

    Presents a two dimensional spectrum as an image and lets the user click
    out named "bands" (polylines of points). The clicked points, together with
    a band type and an optional center constraint, are packed into the band
    description handed to `fit_patterned_bands` so that they can seed a curve fit.

    All mutable application state (bands, fit mode, fit direction, ...) is kept
    in `self.app_context`; the closures `pack_bands` and `fit` are also published
    there under the keys ``"pack_bands"`` and ``"fit"``.
    """

    auto_zero_nans = False
    auto_rebin = False

    def __init__(self, **kwargs):
        """Load plot sizes and standard settings from user overrides.

        Args:
            **kwargs: ``name`` (optional) is popped and forwarded to the parent
                constructor; everything else is passed to `load_settings`.
                The settings ``main_width`` (default 600) and ``marginal_width``
                (default 300) control the plot sizes.
        """
        super().__init__(kwargs.pop("name", None))

        self.load_settings(**kwargs)
        self.app_main_size = self.settings.get("main_width", 600)
        self.app_marginal_size = self.settings.get("marginal_width", 300)

        # Name of the band that receives clicked points; None until one is added/selected.
        self.active_band = None
        # Either "band" (clicks append points to the active band) or "cursor".
        self.pointer_mode = "band"

    def tool_handler(self, doc):
        """Sets up widgets for the Bokeh application.

        Builds the main image figure, the band overlay, all control widgets and
        their callbacks, adds the resulting layout to `doc`, and finally restores
        and re-saves any previously stored application state.

        Args:
            doc: The Bokeh document the layout is added to.

        Raises:
            AnalysisError: If the data being examined is not two dimensional.
        """
        from bokeh.layouts import row, column
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError("Cannot use the band tool on non image-like spectra")

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]
        data_range = {
            "x": (np.min(x_coords.values), np.max(x_coords.values)),
            "y": (np.min(y_coords.values), np.max(y_coords.values)),
        }

        self.app_context.update(
            {
                "bands": {},
                "center_float": None,
                "data": arr,
                "data_range": data_range,
                "direction_normal": True,
                "fit_mode": "mdc",
            }
        )

        # NOTE(review): "figures", "plots" and "color_maps" are expected to already
        # exist in app_context (presumably created by a base class) - confirm.
        figures = self.app_context["figures"]
        plots = self.app_context["plots"]
        color_maps = self.app_context["color_maps"]

        # Start with the cursor in the middle of the data.
        self.cursor = [np.mean(data_range["x"]), np.mean(data_range["y"])]

        color_maps["main"] = LinearColorMapper(
            self.default_palette,
            low=np.min(arr.values),
            high=np.max(arr.values),
            nan_color="black",
        )

        main_tools = ["wheel_zoom", "tap", "reset"]

        # Best effort: label the figure with the spectrum label when one is
        # available, otherwise keep the warning title.
        main_title = "Band Tool: WARNING Unidentified"
        try:
            main_title = "Band Tool: {}".format(arr.S.label[:60])
        except Exception:
            pass

        figures["main"] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=50,
            toolbar_location="left",
            x_axis_location="below",
            y_axis_location="right",
            title=main_title,
            x_range=data_range["x"],
            y_range=data_range["y"],
        )
        figures["main"].xaxis.axis_label = arr.dims[0]
        figures["main"].yaxis.axis_label = arr.dims[1]
        figures["main"].toolbar.logo = None
        figures["main"].background_fill_color = "#fafafa"

        # The data is transposed so that the first dimension runs along x.
        plots["main"] = figures["main"].image(
            [arr.values.T],
            x=data_range["x"][0],
            y=data_range["y"][0],
            dw=data_range["x"][1] - data_range["x"][0],
            dh=data_range["y"][1] - data_range["y"][0],
            color_mapper=color_maps["main"],
        )

        # add lines
        self.add_cursor_lines(figures["main"])
        band_lines = figures["main"].multi_line(xs=[], ys=[], line_color="white", line_width=1)

        def append_point_to_band():
            """Append the current cursor location to the active band, if there is one."""
            cursor = self.cursor
            if self.active_band in self.app_context["bands"]:
                self.app_context["bands"][self.active_band]["points"].append(list(cursor))
                update_band_display()

        def click_main_image(event):
            """Move the cursor to the clicked point and, in band mode, record it."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == "band":
                append_point_to_band()

        update_main_colormap = self.update_colormap_for("main")

        # Dropdown menus as (label, value) pairs.
        POINTER_MODES = [
            ("Cursor", "cursor"),
            ("Band", "band"),
        ]
        FIT_MODES = [
            ("EDC", "edc"),
            ("MDC", "mdc"),
        ]
        DIRECTIONS = [
            ("From Bottom/Left", "forward"),
            ("From Top/Right", "reverse"),
        ]
        BAND_TYPES = [
            ("Lorentzian", "Lorentzian"),
            ("Voigt", "Voigt"),
            ("Gaussian", "Gaussian"),
        ]

        # Maps the band type chosen in the UI onto the class used for fitting.
        band_classes = {
            "Lorentzian": band.Band,
            "Gaussian": band.BackgroundBand,
            "Voigt": band.VoigtBand,
        }

        self.app_context["band_options"] = []

        def pack_bands():
            """Convert the interactively defined bands into fit descriptions.

            Returns:
                A dict mapping band name to a description holding the band class,
                the data dimensions, parameter hints and the clicked points.

            Raises:
                AnalysisError: If any band has no points.
            """
            packed_bands = {}
            for band_name, band_description in self.app_context["bands"].items():
                if not band_description["points"]:
                    raise AnalysisError("Band {} is empty.".format(band_name))

                # Prefer the band's own center constraint, then the global one;
                # if neither parses as a float, no "stray" parameter is set.
                stray = None
                try:
                    stray = float(band_description["center_float"])
                except (KeyError, ValueError, TypeError):
                    try:
                        stray = float(self.app_context["center_float"])
                    except (KeyError, ValueError, TypeError):
                        pass

                packed_bands[band_name] = {
                    "name": band_name,
                    "band": band_classes.get(band_description["type"], band.Band),
                    "dims": self.arr.dims,
                    "params": {
                        "amplitude": {"min": 0},
                    },
                    "points": band_description["points"],
                }
                if stray is not None:
                    packed_bands[band_name]["params"]["stray"] = stray

            return packed_bands

        def fit(override_data=None):
            """Fit the packed bands against the tool's data or against `override_data`.

            Args:
                override_data: Optional replacement data. An `xr.Dataset` is first
                    passed through `normalize_to_spectrum`.

            Returns:
                Whatever `fit_patterned_bands` returns.
            """
            packed_bands = pack_bands()

            # The non-"eV" dimension is used as the fit direction for MDC fits.
            dims = list(self.arr.dims)
            if "eV" in dims:
                dims.remove("eV")
            angular_direction = dims[0]

            if isinstance(override_data, xr.Dataset):
                override_data = normalize_to_spectrum(override_data)

            return fit_patterned_bands(
                override_data if override_data is not None else self.arr,
                packed_bands,
                fit_direction="eV" if self.app_context["fit_mode"] == "edc" else angular_direction,
                direction_normal=self.app_context["direction_normal"],
            )

        self.app_context["pack_bands"] = pack_bands
        self.app_context["fit"] = fit

        self.pointer_dropdown = widgets.Dropdown(
            label="Pointer Mode", button_type="primary", menu=POINTER_MODES
        )
        self.direction_dropdown = widgets.Dropdown(
            label="Fit Direction", button_type="primary", menu=DIRECTIONS
        )
        self.band_dropdown = widgets.Dropdown(
            label="Active Band", button_type="primary", menu=self.app_context["band_options"]
        )
        self.fit_mode_dropdown = widgets.Dropdown(
            label="Mode", button_type="primary", menu=FIT_MODES
        )
        self.band_type_dropdown = widgets.Dropdown(
            label="Band Type", button_type="primary", menu=BAND_TYPES
        )

        self.band_name_input = widgets.TextInput(placeholder="Band name...")
        self.center_float_widget = widgets.TextInput(placeholder="Center Constraint")
        self.center_float_copy = widgets.Button(label="Copy to all...")
        self.add_band_button = widgets.Button(label="Add Band")
        self.clear_band_button = widgets.Button(label="Clear Band")
        self.remove_band_button = widgets.Button(label="Remove Band")

        self.main_color_range_slider = widgets.RangeSlider(
            start=0,
            end=100,
            value=(
                0,
                100,
            ),
            title="Color Range",
        )

        def add_band(band_name):
            """Register a new, empty Lorentzian band unless the name is already taken."""
            if band_name not in self.app_context["bands"]:
                self.app_context["band_options"].append(
                    (
                        band_name,
                        band_name,
                    )
                )
                self.band_dropdown.menu = self.app_context["band_options"]
                self.app_context["bands"][band_name] = {
                    "type": "Lorentzian",
                    "points": [],
                    "name": band_name,
                    "center_float": None,
                }

                # The first band added becomes the active one automatically.
                if self.active_band is None:
                    self.active_band = band_name

                self.save_app()

        def on_copy_center_float():
            """Copy the global center constraint onto every band."""
            for band_description in self.app_context["bands"].values():
                band_description["center_float"] = self.app_context["center_float"]
            self.save_app()

        def on_change_active_band(event):
            """Select which band receives subsequent edits."""
            self.app_context["active_band"] = event.item
            self.active_band = event.item

        def on_change_pointer_mode(event):
            """Switch between cursor mode and band (point adding) mode."""
            self.app_context["pointer_mode"] = event.item
            self.pointer_mode = event.item

        def set_center_float_value(attr, old, new):
            """Store the text of the center constraint input.

            Bokeh property-change callbacks receive ``(attr, old, new)``; `new` is
            the raw text of the widget and is only converted to a float in
            `pack_bands`.
            """
            self.app_context["center_float"] = new
            if self.active_band in self.app_context["bands"]:
                self.app_context["bands"][self.active_band]["center_float"] = new
            self.save_app()

        def set_fit_direction(event):
            """Record whether fitting proceeds in the "forward" direction."""
            new_direction = event.item
            self.app_context["direction_normal"] = new_direction == "forward"
            self.save_app()

        def set_fit_mode(event):
            """Record the fit mode ("edc" or "mdc")."""
            new_mode = event.item
            self.app_context["fit_mode"] = new_mode
            self.save_app()

        def set_band_type(event):
            """Change the band type of the active band, if there is one."""
            if self.active_band in self.app_context["bands"]:
                self.app_context["bands"][self.active_band]["type"] = event.item
            self.save_app()

        def update_band_display():
            """Redraw every band as a polyline through its points and persist state."""
            all_bands = self.app_context["bands"]
            band_names = list(all_bands)
            band_lines.data_source.data = {
                "xs": [[p[0] for p in all_bands[b]["points"]] for b in band_names],
                "ys": [[p[1] for p in all_bands[b]["points"]] for b in band_names],
            }
            self.save_app()

        # Exposed on the instance because `deserialize` needs to trigger a redraw.
        self.update_band_display = update_band_display

        def on_clear_band():
            """Remove all points from the active band, keeping the band itself."""
            if self.active_band in self.app_context["bands"]:
                self.app_context["bands"][self.active_band]["points"] = []
                update_band_display()

        def on_remove_band():
            """Delete the active band and its dropdown entry, leaving no band active."""
            if self.active_band in self.app_context["bands"]:
                del self.app_context["bands"][self.active_band]
                new_band_options = [
                    b for b in self.app_context["band_options"] if b[0] != self.active_band
                ]
                self.band_dropdown.menu = new_band_options
                self.app_context["band_options"] = new_band_options
                self.active_band = None
                update_band_display()

        # Attach callbacks
        self.main_color_range_slider.on_change("value", update_main_colormap)
        figures["main"].on_event(events.Tap, click_main_image)
        self.band_dropdown.on_click(on_change_active_band)
        self.pointer_dropdown.on_click(on_change_pointer_mode)
        self.add_band_button.on_click(lambda: add_band(self.band_name_input.value))
        self.clear_band_button.on_click(on_clear_band)
        self.remove_band_button.on_click(on_remove_band)
        self.center_float_copy.on_click(on_copy_center_float)
        # `on_change` needs the name of the watched property as its first argument.
        self.center_float_widget.on_change("value", set_center_float_value)
        self.direction_dropdown.on_click(set_fit_direction)
        self.fit_mode_dropdown.on_click(set_fit_mode)
        self.band_type_dropdown.on_click(set_band_type)

        layout = row(
            column(figures["main"]),
            column(
                column(
                    self.pointer_dropdown,
                    self.band_dropdown,
                    self.fit_mode_dropdown,
                    self.band_type_dropdown,
                    self.direction_dropdown,
                ),
                row(
                    self.band_name_input,
                    self.add_band_button,
                ),
                row(
                    self.clear_band_button,
                    self.remove_band_button,
                ),
                row(self.center_float_widget, self.center_float_copy),
                column(
                    self._cursor_info,
                    self.main_color_range_slider,
                ),
            ),
        )

        doc.add_root(layout)
        doc.title = "Band Tool"

        # Restore any previously saved state, then write the current state back.
        self.load_app()
        self.save_app()

    def serialize(self):
        """Saves application state so it can be recovered later.

        See `.deserialize` for information on what is retained.

        Returns:
            A dict with the active band, the band dropdown options, the band
            definitions, the center constraint, the cursor location, the fit
            direction flag and the fit mode.
        """
        context = self.app_context
        return {
            "active_band": self.active_band,
            "band_options": context["band_options"],
            "bands": context["bands"],
            "center_float": context["center_float"],
            "cursor": self.cursor,
            "direction_normal": context["direction_normal"],
            "fit_mode": context["fit_mode"],
        }

    def deserialize(self, json_data):
        """Loads a variety of application state so it can be recovered later.

        Mostly we load:

        * Band definitions
        * Modes for fitting
        * The cursor location

        Args:
            json_data: A dict as produced by `.serialize`. Missing keys fall
                back to the defaults used when the tool is first set up.
        """
        self.cursor = json_data.get("cursor", [0, 0])

        # "active_band" is saved by `serialize` but is not restored here, so a
        # band has to be selected again from the dropdown after loading.
        self.app_context["center_float"] = json_data.get("center_float", None)
        self.app_context["bands"] = json_data.get("bands", {}) or {}
        self.app_context["direction_normal"] = json_data.get("direction_normal", True)
        self.app_context["fit_mode"] = json_data.get("fit_mode", "mdc")
        self.app_context["band_options"] = json_data.get("band_options", [])

        # Bring the widgets and the band overlay in line with the restored state;
        # the center constraint input shows 0.1 when no value was stored.
        self.center_float_widget.value = str(self.app_context["center_float"] or 0.1)
        self.band_dropdown.menu = self.app_context["band_options"]
        self.update_band_display()