Source code for arpes.plotting.interactive
"""Implements a 2D and 3D data browser via Bokeh."""
import copy
import warnings
import numpy as np
import colorcet as cc
from arpes.fits import ExponentialDecayCModel, GStepBModel
from .interactive_utils import CursorTool, SaveableTool
__all__ = ("ImageTool",)
# TODO Implement alignment tool
[docs]class ImageTool(SaveableTool, CursorTool):
"""Implements a 2D and 3D data browser via Bokeh."""
def __init__(self, curs=None, **kwargs):
    """Initialize the tool, load its settings, and record the panel sizes.

    Args:
        curs: Optional initial cursor position. When given it is stored as
            ``cursor_default``; otherwise ``cursor_default`` is left untouched.
        **kwargs: ``name`` is popped and forwarded to the parent initializer;
            everything else is passed to ``load_settings``.
    """
    tool_name = kwargs.pop("name", None)
    super().__init__(name=tool_name)
    self.load_settings(**kwargs)

    # Pixel sizes of the main image and of the marginal panels, with fallbacks
    # used when the loaded settings do not provide them.
    self.app_main_size = self.settings.get("main_width", 600)
    self.app_marginal_size = self.settings.get("marginal_width", 300)

    if curs is None:
        return
    self.cursor_default = curs
# TODO select path in image
def prep_image(self, image_arr):
    """Return the pixel values to display for ``image_arr``.

    In the ``"linear"`` color mode the raw ``image_arr.values`` are returned
    unchanged. In any other mode the values are passed through scikit-image's
    adaptive histogram equalization. Warnings raised while preparing the image
    are suppressed.

    Args:
        image_arr: Object exposing the image data through ``.values``.

    Returns:
        The raw or equalized image values.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        wants_equalization = self.app_context["color_mode"] != "linear"
        if wants_equalization:
            # avoid dependency conflict with numpy v0.16 for now
            from skimage import exposure  # pylint: disable=import-error

            return exposure.equalize_adapthist(image_arr.values, clip_limit=0.03)

        return image_arr.values
def tool_handler(self, doc):
    """Dispatch to the 3D builder for three-dimensional data, else the 2D one.

    Args:
        doc: Document object forwarded unchanged to the selected builder.

    Returns:
        Whatever the selected builder returns.
    """
    is_volume = len(self.arr.shape) == 3
    build_app = self.tool_handler_3d if is_volume else self.tool_handler_2d
    return build_app(doc)
def tool_handler_2d(self, doc):
    """Application definition and widgets for the 2D data browser.

    Builds the main image, the two line-cut marginals taken at the cursor,
    the settings / info / preparation tabs, and wires up the Bokeh callbacks
    before adding the layout to ``doc``.

    Args:
        doc: The Bokeh document that receives the layout as a root.
    """
    from bokeh import events
    from bokeh.layouts import row, column, Spacer
    from bokeh.models import ColumnDataSource, widgets
    from bokeh.models.mappers import LinearColorMapper
    from bokeh.models.widgets.markups import Div
    from bokeh.plotting import figure

    arr = self.arr
    x_dim, y_dim = arr.dims[0], arr.dims[1]

    # Set up the data
    x_coords, y_coords = arr.coords[x_dim], arr.coords[y_dim]

    # Styling
    default_palette = self.default_palette
    if arr.S.is_subtracted:
        default_palette = cc.coolwarm

    error_alpha = 0.3
    error_fill = "#3288bd"

    # Application Organization
    self.app_context.update(
        {
            "data": arr,
            "data_range": {
                "x": (np.min(x_coords.values), np.max(x_coords.values)),
                "y": (np.min(y_coords.values), np.max(y_coords.values)),
            },
            "show_stat_variation": False,
            "color_mode": "linear",
        }
    )
    data_range = self.app_context["data_range"]

    figures, plots, app_widgets = (
        self.app_context["figures"],
        self.app_context["plots"],
        self.app_context["widgets"],
    )

    def slice_at_cursor(data, *axes):
        """Select from ``data`` at the current cursor along the given axis indices."""
        return data.sel(**{arr.dims[i]: self.cursor[i] for i in axes}, method="nearest")

    def stats_patch_from_data(data, subsampling_rate=None):
        """Build the closed polygon outlining ``values +/- sqrt(values)``."""
        if subsampling_rate is None:
            subsampling_rate = int(min(data.values.shape[0] / 50, 5))
        if subsampling_rate == 0:
            subsampling_rate = 1

        x_values = data.coords[data.dims[0]].values[::subsampling_rate]
        values = data.values[::subsampling_rate]
        sq = np.sqrt(values)
        lower, upper = values - sq, values + sq

        # Walk the lower edge forwards and the upper edge backwards so the
        # points trace a single closed outline.
        return {
            "x": np.append(x_values, x_values[::-1]),
            "y": np.append(lower, upper[::-1]),
        }

    def update_stat_variation(plot_name, data):
        """Refresh the error patch belonging to the ``plot_name`` marginal."""
        patch_data = stats_patch_from_data(data)
        if plot_name != "right":  # the right plot is on transposed axes
            plots[plot_name + "_marginal_err"].data_source.data = patch_data
        else:
            plots[plot_name + "_marginal_err"].data_source.data = {
                "x": patch_data["y"],
                "y": patch_data["x"],
            }

    if self.cursor_default is not None and len(self.cursor_default) == 2:
        self.cursor = self.cursor_default
    else:
        # try a sensible default
        self.cursor = [np.mean(data_range["x"]), np.mean(data_range["y"])]

    # create the main inset plot
    prepped_main_image = self.prep_image(arr)
    self.app_context["color_maps"]["main"] = LinearColorMapper(
        default_palette,
        low=np.min(prepped_main_image),
        high=np.max(prepped_main_image),
        nan_color="black",
    )

    main_tools = ["wheel_zoom", "tap", "reset", "save"]
    main_title = "Bokeh Tool: WARNING Unidentified"
    try:
        main_title = "Bokeh Tool: %s" % arr.S.label[:60]
    except Exception:  # best effort: fall back to the warning title
        pass

    figures["main"] = figure(
        tools=main_tools,
        plot_width=self.app_main_size,
        plot_height=self.app_main_size,
        min_border=10,
        min_border_left=50,
        toolbar_location="left",
        x_axis_location="below",
        y_axis_location="right",
        title=main_title,
        x_range=data_range["x"],
        y_range=data_range["y"],
    )
    figures["main"].xaxis.axis_label = x_dim
    figures["main"].yaxis.axis_label = y_dim
    figures["main"].toolbar.logo = None
    figures["main"].background_fill_color = "#fafafa"
    plots["main"] = figures["main"].image(
        [prepped_main_image.T],
        x=data_range["x"][0],
        y=data_range["y"][0],
        dw=data_range["x"][1] - data_range["x"][0],
        dh=data_range["y"][1] - data_range["y"][0],
        color_mapper=self.app_context["color_maps"]["main"],
    )

    app_widgets["info_div"] = Div(text="", width=self.app_marginal_size, height=100)

    # Create the bottom marginal plot
    bottom_marginal = slice_at_cursor(arr, 1)
    figures["bottom_marginal"] = figure(
        plot_width=self.app_main_size,
        plot_height=200,
        title=None,
        x_range=figures["main"].x_range,
        y_range=(np.min(bottom_marginal.values), np.max(bottom_marginal.values)),
        x_axis_location="above",
        toolbar_location=None,
        tools=[],
    )
    plots["bottom_marginal"] = figures["bottom_marginal"].line(
        x=bottom_marginal.coords[x_dim].values, y=bottom_marginal.values
    )
    plots["bottom_marginal_err"] = figures["bottom_marginal"].patch(
        x=[], y=[], color=error_fill, fill_alpha=error_alpha, line_color=None
    )

    # Create the right marginal plot
    right_marginal = slice_at_cursor(arr, 0)
    figures["right_marginal"] = figure(
        plot_width=200,
        plot_height=self.app_main_size,
        title=None,
        y_range=figures["main"].y_range,
        x_range=(np.min(right_marginal.values), np.max(right_marginal.values)),
        y_axis_location="left",
        toolbar_location=None,
        tools=[],
    )
    plots["right_marginal"] = figures["right_marginal"].line(
        y=right_marginal.coords[y_dim].values, x=right_marginal.values
    )
    plots["right_marginal_err"] = figures["right_marginal"].patch(
        x=[], y=[], color=error_fill, fill_alpha=error_alpha, line_color=None
    )

    self.add_cursor_lines(figures["main"])

    # Attach tools and callbacks
    toggle = widgets.Toggle(label="Show Stat. Variation", button_type="success", active=False)

    def set_show_stat_variation(should_show):
        """Show or hide the error patches on both marginals."""
        self.app_context["show_stat_variation"] = should_show

        if should_show:
            update_stat_variation("bottom", slice_at_cursor(arr, 1))
            update_stat_variation("right", slice_at_cursor(arr, 0))
            plots["bottom_marginal_err"].visible = True
            plots["right_marginal_err"].visible = True
        else:
            plots["bottom_marginal_err"].visible = False
            plots["right_marginal_err"].visible = False

    toggle.on_click(set_show_stat_variation)

    scan_keys = ["x", "y", "z", "pass_energy", "hv", "location", "id", "probe_pol", "pump_pol"]
    present_keys = [k for k in scan_keys if k in arr.attrs]
    scan_info_source = ColumnDataSource(
        {
            "keys": present_keys,
            "values": [
                str(v) if isinstance(v, float) and np.isnan(v) else v
                for v in (arr.attrs[k] for k in present_keys)
            ],
        }
    )
    scan_info_columns = [
        widgets.TableColumn(field="keys", title="Attr."),
        widgets.TableColumn(field="values", title="Value"),
    ]

    POINTER_MODES = [
        ("Cursor", "cursor"),
        ("Path", "path"),
    ]

    COLOR_MODES = [
        ("Adaptive Hist. Eq. (Slow)", "adaptive_equalization"),
        # ('Histogram Eq.', 'equalization',), # not implemented
        ("Linear", "linear"),
        # ('Log', 'log',), # not implemented
    ]

    def on_change_color_mode(event):
        """Re-render the main image when a different color mode is picked."""
        # Capture the mode in effect *before* overwriting it so the two can
        # be compared.
        previous_color_mode = self.app_context["color_mode"]
        new_color_mode = event.item
        self.app_context["color_mode"] = new_color_mode

        if previous_color_mode != new_color_mode:
            # The 2D browser has a single image plot; its marginals are line
            # plots of raw values and are therefore not re-prepared here.
            plots["main"].data_source.data = {"image": [self.prep_image(arr).T]}
            update_main_colormap(None, None, main_color_range_slider.value)

    color_mode_dropdown = widgets.Dropdown(
        label="Color Mode", button_type="primary", menu=COLOR_MODES
    )
    color_mode_dropdown.on_click(on_change_color_mode)

    symmetry_point_name_input = widgets.TextInput(title="Symmetry Point Name", value="G")
    snap_checkbox = widgets.CheckboxButtonGroup(labels=["Snap Axes"], active=[])
    place_symmetry_point_at_cursor_button = widgets.Button(
        label="Place Point", button_type="primary"
    )

    def place_symmetry_point():
        """Record the cursor position as a named symmetry point on ``arr.attrs``."""
        skip_dimensions = {"eV", "delay", "cycle"}
        if "symmetry_points" not in arr.attrs:
            arr.attrs["symmetry_points"] = {}

        # Largest distance, per axis, at which the cursor snaps onto an
        # already-recorded symmetry point.
        snap_distance = {
            "phi": 2,
            "beta": 2,
            "kx": 0.01,
            "ky": 0.01,
            "kz": 0.01,
            "kp": 0.01,
            "hv": 4,
        }

        cursor_dict = {
            k: v for k, v in zip(arr.dims, self.cursor) if k not in skip_dimensions
        }
        snapped = copy.copy(cursor_dict)

        if "Snap Axes" in [snap_checkbox.labels[i] for i in snap_checkbox.active]:
            for axis, value in cursor_dict.items():
                tolerance = snap_distance.get(axis)
                if tolerance is None:
                    # No snapping tolerance is defined for this axis.
                    continue
                options = [
                    point[axis]
                    for point in arr.attrs["symmetry_points"].values()
                    if axis in point
                ]
                if not options:
                    continue
                nearest = min(options, key=lambda x: np.abs(x - value))
                if np.abs(nearest - value) < tolerance:
                    snapped[axis] = nearest

        arr.attrs["symmetry_points"][symmetry_point_name_input.value] = snapped

    place_symmetry_point_at_cursor_button.on_click(place_symmetry_point)

    main_color_range_slider = widgets.RangeSlider(
        start=0,
        end=100,
        value=(0, 100),
        title="Color Range (Main)",
    )

    layout = row(
        column(figures["main"], figures["bottom_marginal"]),
        column(figures["right_marginal"], Spacer(width=200, height=200)),
        column(
            column(
                widgets.Dropdown(
                    label="Pointer Mode", button_type="primary", menu=POINTER_MODES
                )
            ),
            widgets.Tabs(
                tabs=[
                    widgets.Panel(
                        child=column(
                            Div(text="<h2>Colorscale:</h2>"),
                            color_mode_dropdown,
                            main_color_range_slider,
                            Div(text='<h2 style="padding-top: 30px;">General Settings:</h2>'),
                            toggle,
                            self._cursor_info,
                            sizing_mode="scale_width",
                        ),
                        title="Settings",
                    ),
                    widgets.Panel(
                        child=column(
                            app_widgets["info_div"],
                            Div(
                                text='<h2 style="padding-top: 30px; padding-bottom: 10px;">Scan Info</h2>'
                            ),
                            widgets.DataTable(
                                source=scan_info_source,
                                columns=scan_info_columns,
                                width=400,
                                height=400,
                            ),
                            sizing_mode="scale_width",
                            width=400,
                        ),
                        title="Info",
                    ),
                    widgets.Panel(
                        child=column(
                            Div(text="<h2>Preparation</h2>"),
                            symmetry_point_name_input,
                            snap_checkbox,
                            place_symmetry_point_at_cursor_button,
                            sizing_mode="scale_width",
                        ),
                        title="Preparation",
                    ),
                ],
                width=400,
            ),
        ),
    )

    # Referenced by on_change_color_mode above; that callback can only fire
    # after this assignment has run.
    update_main_colormap = self.update_colormap_for("main")

    def click_main_image(event):
        """Move the cursor to the clicked point and redraw both marginals."""
        self.cursor = [event.x, event.y]

        right_marginal_data = slice_at_cursor(arr, 0)
        bottom_marginal_data = slice_at_cursor(arr, 1)
        plots["bottom_marginal"].data_source.data = {
            "x": bottom_marginal_data.coords[x_dim].values,
            "y": bottom_marginal_data.values,
        }
        plots["right_marginal"].data_source.data = {
            "y": right_marginal_data.coords[y_dim].values,
            "x": right_marginal_data.values,
        }
        if self.app_context["show_stat_variation"]:
            update_stat_variation("right", right_marginal_data)
            update_stat_variation("bottom", bottom_marginal_data)

        # Rescale the intensity axes to the newly selected cuts.
        figures["bottom_marginal"].y_range.start = np.min(bottom_marginal_data.values)
        figures["bottom_marginal"].y_range.end = np.max(bottom_marginal_data.values)
        figures["right_marginal"].x_range.start = np.min(right_marginal_data.values)
        figures["right_marginal"].x_range.end = np.max(right_marginal_data.values)

        self.save_app()

    figures["main"].on_event(events.Tap, click_main_image)
    main_color_range_slider.on_change("value", update_main_colormap)

    doc.add_root(layout)
    doc.title = "Bokeh Tool"
    self.load_app()
    self.save_app()
def serialize(self):
    """Return the cursor state as a dict so it can be restored later.

    Returns:
        A dict with the keys ``"cursor_dict"`` and ``"cursor"`` holding the
        attributes of the same names.
    """
    saved_state = {}
    saved_state["cursor_dict"] = self.cursor_dict
    saved_state["cursor"] = self.cursor
    return saved_state
def deserialize(self, json_data):
    """Restore the cursor from previously serialized tool state.

    Args:
        json_data: Mapping of saved state. Only the ``"cursor"`` entry is
            read; when it is absent the current cursor is left unchanged.
    """
    if "cursor" not in json_data:
        return
    self.cursor = json_data["cursor"]
def tool_handler_3d(self, doc):
    """Application and widget definitions for the 3D data browser.

    Builds the main image (a slice along the third dimension), the two image
    marginals through the cursor, the line-cut marginals, the z-selector with
    an optional fit overlay, and the settings / info / preparation tabs, then
    wires up the Bokeh callbacks and adds the layout to ``doc``.

    Args:
        doc: The Bokeh document that receives the layout as a root.
    """
    from bokeh import events
    from bokeh.layouts import row, column, Spacer
    from bokeh.models import ColumnDataSource, HoverTool, widgets
    from bokeh.models.mappers import LinearColorMapper
    from bokeh.models.widgets.markups import Div
    from bokeh.plotting import figure

    arr = self.arr
    x_dim, y_dim, z_dim = arr.dims[0], arr.dims[1], arr.dims[2]

    # Set up the data
    x_coords, y_coords, z_coords = arr.coords[x_dim], arr.coords[y_dim], arr.coords[z_dim]

    # HTML templates for the info panel, keyed by the name of the third dimension.
    info_formatters = {
        "eV": (
            "<div>\n"
            "<h2>Fermi Edge Info:</h2>\n"
            "<p>Gap: <b>{:.1f} meV</b></p>\n"
            "<p>Edge Width: <b>{:.1f} meV</b></p>\n"
            "</div>\n"
        ),
        "delay": (
            "<div>\n"
            "<h2>Delay Scan Info:</h2>\n"
            "<p>t0: <b>{:.1f} fs</b></p>\n"
            "<p>Decay time: <b>{:.1f} fs</b></p>\n"
            "</div>\n"
        ),
    }

    # Styling
    default_palette = self.default_palette
    if arr.S.is_subtracted:
        default_palette = cc.coolwarm

    error_alpha = 0.3
    error_fill = "#3288bd"

    # Application Organization
    self.app_context.update(
        {
            "data": arr,
            "data_range": {
                "x": (np.min(x_coords.values), np.max(x_coords.values)),
                "y": (np.min(y_coords.values), np.max(y_coords.values)),
                "z": (np.min(z_coords.values), np.max(z_coords.values)),
            },
            "show_stat_variation": False,
            "color_mode": "linear",
        }
    )
    data_range = self.app_context["data_range"]

    figures, plots, app_widgets = (
        self.app_context["figures"],
        self.app_context["plots"],
        self.app_context["widgets"],
    )

    def slice_at_cursor(data, *axes):
        """Select from ``data`` at the current cursor along the given axis indices."""
        return data.sel(**{arr.dims[i]: self.cursor[i] for i in axes}, method="nearest")

    def make_color_mapper(prepped_image):
        """Create a linear color mapper spanning the range of ``prepped_image``."""
        return LinearColorMapper(
            default_palette,
            low=np.min(prepped_image),
            high=np.max(prepped_image),
            nan_color="black",
        )

    def stats_patch_from_data(data, subsampling_rate=None):
        """Build the closed polygon outlining ``values +/- sqrt(values)``."""
        if subsampling_rate is None:
            subsampling_rate = int(min(data.values.shape[0] / 50, 5))
        if subsampling_rate == 0:
            subsampling_rate = 1

        x_values = data.coords[data.dims[0]].values[::subsampling_rate]
        values = data.values[::subsampling_rate]
        sq = np.sqrt(values)
        lower, upper = values - sq, values + sq

        # Walk the lower edge forwards and the upper edge backwards so the
        # points trace a single closed outline.
        return {
            "x": np.append(x_values, x_values[::-1]),
            "y": np.append(lower, upper[::-1]),
        }

    def update_stat_variation(plot_name, data):
        """Refresh the error patch belonging to the ``plot_name`` marginal."""
        patch_data = stats_patch_from_data(data)
        if plot_name != "right":  # the right plot is on transposed axes
            plots[plot_name + "_marginal_err"].data_source.data = patch_data
        else:
            plots[plot_name + "_marginal_err"].data_source.data = {
                "x": patch_data["y"],
                "y": patch_data["x"],
            }

    if self.cursor_default is not None and len(self.cursor_default) == 3:
        self.cursor = self.cursor_default
    else:
        # try a sensible default: the midpoint of the range stored just above
        self.cursor = [
            np.mean(data_range["x"]),
            np.mean(data_range["y"]),
            np.mean(data_range["z"]),
        ]

    # create the main inset plot
    main_image = slice_at_cursor(arr, 2)
    prepped_main_image = self.prep_image(main_image)
    self.app_context["color_maps"]["main"] = make_color_mapper(prepped_main_image)

    main_title = "Bokeh Tool: WARNING Unidentified"
    try:
        main_title = "Bokeh Tool: %s" % arr.S.label[:60]
    except Exception:  # best effort: fall back to the warning title
        pass

    main_tools = ["wheel_zoom", "tap", "reset", "save"]
    figures["main"] = figure(
        tools=main_tools,
        plot_width=self.app_main_size,
        plot_height=self.app_main_size,
        min_border=10,
        min_border_left=50,
        toolbar_location="left",
        x_axis_location="below",
        y_axis_location="right",
        title=main_title,
        x_range=data_range["x"],
        y_range=data_range["y"],
    )
    figures["main"].xaxis.axis_label = x_dim
    figures["main"].yaxis.axis_label = y_dim
    figures["main"].toolbar.logo = None
    figures["main"].background_fill_color = "#fafafa"
    plots["main"] = figures["main"].image(
        [prepped_main_image.T],
        x=data_range["x"][0],
        y=data_range["y"][0],
        dw=data_range["x"][1] - data_range["x"][0],
        dh=data_range["y"][1] - data_range["y"][0],
        color_mapper=self.app_context["color_maps"]["main"],
    )

    # Create the z-selector
    z_marginal_data = slice_at_cursor(arr, 0, 1)
    z_hover_tool = HoverTool(
        tooltips=[
            ("x", "@x{1.111}"),
            ("Int. (arb.)", "@y{1.1}"),
        ],
        mode="vline",
    )
    z_tools = [z_hover_tool, "wheel_zoom", "tap", "reset"]
    figures["z_marginal"] = figure(
        plot_width=self.app_marginal_size,
        plot_height=self.app_marginal_size,
        x_range=data_range["z"],
        y_range=(np.min(z_marginal_data.values), np.max(z_marginal_data.values)),
        x_axis_location="above",
        toolbar_location="below",
        y_axis_location="right",
        tools=z_tools,
    )
    figures["z_marginal"].xaxis.major_label_text_font_size = "0pt"
    figures["z_marginal"].xaxis.axis_label = z_dim
    figures["z_marginal"].toolbar.logo = None
    plots["z_marginal"] = figures["z_marginal"].line(
        x=z_coords.values, y=z_marginal_data.values
    )
    plots["z_marginal_err"] = figures["z_marginal"].patch(
        x=[], y=[], color=error_fill, fill_alpha=error_alpha, line_color=None
    )

    info_formatter = info_formatters.get(z_dim, "")
    app_widgets["info_div"] = Div(text="", width=self.app_marginal_size, height=100)

    if z_dim == "eV":
        # Try to fit a Fermi edge and display it in the plot
        try:
            fit_data = z_marginal_data.sel(eV=slice(-0.25, 0.2))
            z_fit = GStepBModel().guess_fit(fit_data)
            plots["z_fit"] = figures["z_marginal"].line(
                x=fit_data.coords["eV"].values,
                y=z_fit.best_fit,
                line_dash="dashed",
                line_color="red",
            )
            self.app_context["z_model"] = GStepBModel
            # Fit parameters are scaled by 1000 for the meV-labelled template.
            app_widgets["info_div"].text = info_formatter.format(
                z_fit.params["center"].value * 1000,
                z_fit.params["width"].value * 1000,
            )
        except Exception:  # the fit overlay is optional; continue without it
            pass
    elif z_dim == "delay" and "t0" in arr.attrs:
        # Try to fit a decay constant to the data after t0
        t0 = float(arr.attrs["t0"])
        plots["t0_marker"] = figures["z_marginal"].line(
            x=[t0, t0],
            y=[0, 1000000],
            line_color="black",
            line_dash="dashed",
        )
        self.app_context["z_model"] = ExponentialDecayCModel
        try:
            after_t0 = z_marginal_data.sel(delay=slice(t0 - 0.2, None))
            exp_model = ExponentialDecayCModel()
            z_fit = exp_model.guess_fit(after_t0, params={"t0": {"value": t0}})
            plots["z_fit"] = figures["z_marginal"].line(
                x=after_t0.coords["delay"].values,
                y=z_fit.best_fit,
                line_dash="dashed",
                line_color="red",
            )
        except Exception:
            # Keep an empty fit line so later clicks have a renderer to update.
            plots["z_fit"] = figures["z_marginal"].line(
                x=[], y=[], line_dash="dashed", line_color="red"
            )

    # Create the bottom marginal plot
    bottom_image = slice_at_cursor(arr, 1)
    prepped_bottom_image = self.prep_image(bottom_image)
    self.app_context["color_maps"]["bottom"] = make_color_mapper(prepped_bottom_image)
    figures["bottom"] = figure(
        plot_width=self.app_main_size,
        plot_height=self.app_marginal_size,
        title=None,
        x_range=figures["main"].x_range,
        y_range=figures["z_marginal"].x_range,
        x_axis_location="above",
        toolbar_location=None,
        tools=[],
    )
    figures["bottom"].xaxis.major_label_text_font_size = "0pt"
    plots["bottom"] = figures["bottom"].image(
        [prepped_bottom_image.T],
        x=data_range["x"][0],
        y=data_range["z"][0],
        dw=data_range["x"][1] - data_range["x"][0],
        dh=data_range["z"][1] - data_range["z"][0],
        color_mapper=self.app_context["color_maps"]["bottom"],
    )

    bottom_marginal = slice_at_cursor(bottom_image, 2)
    figures["bottom_marginal"] = figure(
        plot_width=self.app_main_size,
        plot_height=200,
        title=None,
        x_range=figures["main"].x_range,
        y_range=(np.min(bottom_marginal.values), np.max(bottom_marginal.values)),
        x_axis_location="above",
        toolbar_location=None,
        tools=[],
    )
    plots["bottom_marginal"] = figures["bottom_marginal"].line(
        x=bottom_marginal.coords[x_dim].values, y=bottom_marginal.values
    )
    plots["bottom_marginal_err"] = figures["bottom_marginal"].patch(
        x=[], y=[], color=error_fill, fill_alpha=error_alpha, line_color=None
    )

    # Create the right marginal plot
    right_image = slice_at_cursor(arr, 0)
    prepped_right_image = self.prep_image(right_image)
    self.app_context["color_maps"]["right"] = make_color_mapper(prepped_right_image)
    figures["right"] = figure(
        plot_width=self.app_marginal_size,
        plot_height=self.app_main_size,
        title=None,
        x_range=figures["z_marginal"].x_range,
        y_range=figures["main"].y_range,
        toolbar_location=None,
        tools=[],
    )
    figures["right"].yaxis.major_label_text_font_size = "0pt"
    plots["right"] = figures["right"].image(
        [prepped_right_image],
        x=data_range["z"][0],
        y=data_range["y"][0],
        dw=data_range["z"][1] - data_range["z"][0],
        dh=data_range["y"][1] - data_range["y"][0],
        color_mapper=self.app_context["color_maps"]["right"],
    )

    right_marginal = slice_at_cursor(right_image, 2)
    figures["right_marginal"] = figure(
        plot_width=200,
        plot_height=self.app_main_size,
        title=None,
        y_range=figures["main"].y_range,
        x_range=(np.min(right_marginal.values), np.max(right_marginal.values)),
        y_axis_location="left",
        toolbar_location=None,
        tools=[],
    )
    plots["right_marginal"] = figures["right_marginal"].line(
        y=right_marginal.coords[y_dim].values, x=right_marginal.values
    )
    plots["right_marginal_err"] = figures["right_marginal"].patch(
        x=[], y=[], color=error_fill, fill_alpha=error_alpha, line_color=None
    )

    self.add_cursor_lines(figures["main"])

    # Attach tools and callbacks
    toggle = widgets.Toggle(label="Show Stat. Variation", button_type="success", active=False)

    def set_show_stat_variation(should_show):
        """Show or hide the error patches on the three line marginals."""
        self.app_context["show_stat_variation"] = should_show

        if should_show:
            main_image_data = slice_at_cursor(arr, 2)
            update_stat_variation("z", slice_at_cursor(arr, 0, 1))
            update_stat_variation("bottom", slice_at_cursor(main_image_data, 1))
            update_stat_variation("right", slice_at_cursor(main_image_data, 0))
            plots["z_marginal_err"].visible = True
            plots["bottom_marginal_err"].visible = True
            plots["right_marginal_err"].visible = True
        else:
            plots["z_marginal_err"].visible = False
            plots["bottom_marginal_err"].visible = False
            plots["right_marginal_err"].visible = False

    toggle.on_click(set_show_stat_variation)

    scan_keys = ["x", "y", "z", "pass_energy", "hv", "location", "id", "probe_pol", "pump_pol"]
    present_keys = [k for k in scan_keys if k in arr.attrs]
    scan_info_source = ColumnDataSource(
        {
            "keys": present_keys,
            "values": [
                str(v) if isinstance(v, float) and np.isnan(v) else v
                for v in (arr.attrs[k] for k in present_keys)
            ],
        }
    )
    scan_info_columns = [
        widgets.TableColumn(field="keys", title="Attr."),
        widgets.TableColumn(field="values", title="Value"),
    ]

    POINTER_MODES = [
        ("Cursor", "cursor"),
        ("Path", "path"),
    ]

    COLOR_MODES = [
        ("Adaptive Hist. Eq. (Slow)", "adaptive_equalization"),
        # ('Histogram Eq.', 'equalization',), # not implemented
        ("Linear", "linear"),
        # ('Log', 'log',), # not implemented
    ]

    def on_change_color_mode(event):
        """Re-render all three images when a different color mode is picked."""
        # Capture the mode in effect *before* overwriting it so the two can
        # be compared.
        previous_color_mode = self.app_context["color_mode"]
        new_color_mode = event.item
        self.app_context["color_mode"] = new_color_mode

        if previous_color_mode != new_color_mode:
            prepped_right = self.prep_image(slice_at_cursor(arr, 0))
            prepped_bottom = self.prep_image(slice_at_cursor(arr, 1))
            prepped_main = self.prep_image(slice_at_cursor(arr, 2))
            plots["right"].data_source.data = {"image": [prepped_right]}
            plots["bottom"].data_source.data = {"image": [prepped_bottom.T]}
            plots["main"].data_source.data = {"image": [prepped_main.T]}
            update_right_colormap(None, None, right_color_range_slider.value)
            update_bottom_colormap(None, None, bottom_color_range_slider.value)
            update_main_colormap(None, None, main_color_range_slider.value)

    color_mode_dropdown = widgets.Dropdown(
        label="Color Mode", button_type="primary", menu=COLOR_MODES
    )
    color_mode_dropdown.on_click(on_change_color_mode)

    symmetry_point_name_input = widgets.TextInput(title="Symmetry Point Name", value="G")
    snap_checkbox = widgets.CheckboxButtonGroup(labels=["Snap Axes"], active=[])
    place_symmetry_point_at_cursor_button = widgets.Button(
        label="Place Point", button_type="primary"
    )

    def place_symmetry_point():
        """Record the cursor position as a named symmetry point on ``arr.attrs``."""
        skip_dimensions = {"eV", "delay", "cycle"}
        if "symmetry_points" not in arr.attrs:
            arr.attrs["symmetry_points"] = {}

        # Largest distance, per axis, at which the cursor snaps onto an
        # already-recorded symmetry point.
        snap_distance = {
            "phi": 2,
            "beta": 2,
            "kx": 0.01,
            "ky": 0.01,
            "kz": 0.01,
            "kp": 0.01,
            "hv": 4,
        }

        cursor_dict = {
            k: v for k, v in zip(arr.dims, self.cursor) if k not in skip_dimensions
        }
        snapped = copy.copy(cursor_dict)

        if "Snap Axes" in [snap_checkbox.labels[i] for i in snap_checkbox.active]:
            for axis, value in cursor_dict.items():
                tolerance = snap_distance.get(axis)
                if tolerance is None:
                    # No snapping tolerance is defined for this axis.
                    continue
                options = [
                    point[axis]
                    for point in arr.attrs["symmetry_points"].values()
                    if axis in point
                ]
                if not options:
                    continue
                nearest = min(options, key=lambda x: np.abs(x - value))
                if np.abs(nearest - value) < tolerance:
                    snapped[axis] = nearest

        arr.attrs["symmetry_points"][symmetry_point_name_input.value] = snapped

    place_symmetry_point_at_cursor_button.on_click(place_symmetry_point)

    main_color_range_slider = widgets.RangeSlider(
        start=0,
        end=100,
        value=(0, 100),
        title="Color Range (Main)",
    )
    right_color_range_slider = widgets.RangeSlider(
        start=0,
        end=100,
        value=(0, 100),
        title="Color Range (%s Marginal)" % arr.dims[1],
    )
    bottom_color_range_slider = widgets.RangeSlider(
        start=0,
        end=100,
        value=(0, 100),
        title="Color Range (%s Marginal)" % arr.dims[0],
    )

    layout = row(
        column(figures["main"], figures["bottom"], figures["bottom_marginal"]),
        column(
            figures["right"],
            figures["z_marginal"],
            Spacer(width=self.app_marginal_size, height=200),
        ),
        column(
            figures["right_marginal"],
            Spacer(width=200, height=self.app_marginal_size),
            Spacer(width=200, height=200),
        ),
        column(
            column(
                widgets.Dropdown(
                    label="Pointer Mode", button_type="primary", menu=POINTER_MODES
                )
            ),
            widgets.Tabs(
                tabs=[
                    widgets.Panel(
                        child=column(
                            Div(text="<h2>Colorscale:</h2>"),
                            color_mode_dropdown,
                            main_color_range_slider,
                            right_color_range_slider,
                            bottom_color_range_slider,
                            Div(text='<h2 style="padding-top: 30px;">General Settings:</h2>'),
                            toggle,
                            self._cursor_info,
                            sizing_mode="scale_width",
                        ),
                        title="Settings",
                    ),
                    widgets.Panel(
                        child=column(
                            app_widgets["info_div"],
                            Div(
                                text='<h2 style="padding-top: 30px; padding-bottom: 10px;">Scan Info</h2>'
                            ),
                            widgets.DataTable(
                                source=scan_info_source,
                                columns=scan_info_columns,
                                width=400,
                                height=400,
                            ),
                            sizing_mode="scale_width",
                            width=400,
                        ),
                        title="Info",
                    ),
                    widgets.Panel(
                        child=column(
                            Div(text="<h2>Preparation</h2>"),
                            symmetry_point_name_input,
                            snap_checkbox,
                            place_symmetry_point_at_cursor_button,
                            sizing_mode="scale_width",
                        ),
                        title="Preparation",
                    ),
                ],
                width=400,
            ),
        ),
    )

    # Referenced by the callbacks above and below; those can only fire after
    # these assignments have run.
    update_main_colormap = self.update_colormap_for("main")
    update_bottom_colormap = self.update_colormap_for("bottom")
    update_right_colormap = self.update_colormap_for("right")

    def click_z_marginal(event):
        """Move the cursor along the third dimension and redraw the main image."""
        self.cursor = [self.cursor[0], self.cursor[1], event.x]

        new_main_image = slice_at_cursor(arr, 2)
        plots["main"].data_source.data = {"image": [self.prep_image(new_main_image).T]}
        update_main_colormap(None, None, main_color_range_slider.value)

        right_marginal_data = slice_at_cursor(new_main_image, 0)
        bottom_marginal_data = slice_at_cursor(new_main_image, 1)
        plots["bottom_marginal"].data_source.data = {
            "x": bottom_marginal_data.coords[x_dim].values,
            "y": bottom_marginal_data.values,
        }
        plots["right_marginal"].data_source.data = {
            "y": right_marginal_data.coords[y_dim].values,
            "x": right_marginal_data.values,
        }
        if self.app_context["show_stat_variation"]:
            update_stat_variation("right", right_marginal_data)
            update_stat_variation("bottom", bottom_marginal_data)

        # Rescale the intensity axes to the newly selected cuts.
        figures["bottom_marginal"].y_range.start = np.min(bottom_marginal_data.values)
        figures["bottom_marginal"].y_range.end = np.max(bottom_marginal_data.values)
        figures["right_marginal"].x_range.start = np.min(right_marginal_data.values)
        figures["right_marginal"].x_range.end = np.max(right_marginal_data.values)

    def click_main_image(event):
        """Move the cursor in the image plane and redraw every dependent plot."""
        self.cursor = [event.x, event.y, self.cursor[2]]

        right_image_data = slice_at_cursor(arr, 0)
        bottom_image_data = slice_at_cursor(arr, 1)
        plots["right"].data_source.data = {"image": [self.prep_image(right_image_data)]}
        plots["bottom"].data_source.data = {"image": [self.prep_image(bottom_image_data).T]}
        update_right_colormap(None, None, right_color_range_slider.value)
        update_bottom_colormap(None, None, bottom_color_range_slider.value)

        right_marginal_data = slice_at_cursor(right_image_data, 2)
        bottom_marginal_data = slice_at_cursor(bottom_image_data, 2)
        z_data = slice_at_cursor(arr, 0, 1)

        plots["z_marginal"].data_source.data = {
            "x": z_coords.values,
            "y": z_data.values,
        }
        plots["bottom_marginal"].data_source.data = {
            "x": bottom_marginal_data.coords[x_dim].values,
            "y": bottom_marginal_data.values,
        }
        plots["right_marginal"].data_source.data = {
            "y": right_marginal_data.coords[y_dim].values,
            "x": right_marginal_data.values,
        }
        if self.app_context["show_stat_variation"]:
            update_stat_variation("z", z_data)
            update_stat_variation("right", right_marginal_data)
            update_stat_variation("bottom", bottom_marginal_data)

        # Rescale the intensity axes to the newly selected cuts.
        figures["z_marginal"].y_range.start = np.min(z_data.values)
        figures["z_marginal"].y_range.end = np.max(z_data.values)
        figures["bottom_marginal"].y_range.start = np.min(bottom_marginal_data.values)
        figures["bottom_marginal"].y_range.end = np.max(bottom_marginal_data.values)
        figures["right_marginal"].x_range.start = np.min(right_marginal_data.values)
        figures["right_marginal"].x_range.end = np.max(right_marginal_data.values)

        if "z_fit" in plots and z_dim == "eV":
            # Refit the Fermi edge at the new position. As with the initial
            # fit, a failure must not break the click handler.
            try:
                z_fit_data = z_data.sel(eV=slice(-0.25, 0.2))
                new_z_fit = GStepBModel().guess_fit(z_fit_data)
            except Exception:
                plots["z_fit"].data_source.data = {"x": [], "y": []}
                app_widgets["info_div"].text = ""
            else:
                plots["z_fit"].data_source.data = {
                    "x": z_fit_data.coords["eV"].values,
                    "y": new_z_fit.best_fit,
                }
                app_widgets["info_div"].text = info_formatter.format(
                    new_z_fit.params["center"].value * 1000,
                    new_z_fit.params["width"].value * 1000,
                )

        if "z_fit" in plots and z_dim == "delay":
            try:
                t0 = float(arr.attrs.get("t0", 0))
                after_t0 = z_data.sel(delay=slice(t0 - 0.2, None))
                exp_model = ExponentialDecayCModel()
                z_fit = exp_model.guess_fit(after_t0, params={"t0": {"value": t0}})
                plots["z_fit"].data_source.data = {
                    "x": after_t0.coords["delay"].values,
                    "y": z_fit.best_fit,
                }
            except Exception:
                plots["z_fit"].data_source.data = {
                    "x": [],
                    "y": [],
                }

        self.save_app()

    figures["z_marginal"].on_event(events.Tap, click_z_marginal)
    figures["main"].on_event(events.Tap, click_main_image)
    main_color_range_slider.on_change("value", update_main_colormap)
    bottom_color_range_slider.on_change("value", update_bottom_colormap)
    right_color_range_slider.on_change("value", update_right_colormap)

    doc.add_root(layout)
    doc.title = "Bokeh Tool"
    self.load_app()
    self.save_app()