Source code for arpes.plotting.dyn_tool
"""Allows for making any function of a spectrum into a dynamic tool with Bokeh."""
import inspect
import numpy as np
import typing
from arpes.exceptions import AnalysisError
from arpes.plotting.interactive_utils import BokehInteractiveTool, CursorTool
from arpes.typing import DataType
from arpes.utilities import Debounce, normalize_to_spectrum
# Names re-exported by ``from arpes.plotting.dyn_tool import *``.
__all__ = ("DynamicTool", "dyn")
class DynamicTool(BokehInteractiveTool, CursorTool):
    """Presents a utility to rerun a function with different arguments and see the result of the function.

    The first parameter of ``analysis_fn`` receives the spectrum; each remaining
    parameter is paired, in signature order, with one entry of
    ``widget_specification`` and exposed as a slider in the Bokeh app. Moving a
    slider reruns ``analysis_fn`` and redraws the image and its marginals.
    """

    def __init__(self, analysis_fn, widget_specification, **kwargs):
        """Initialize the tool and load settings from the user specified ones.

        Args:
            analysis_fn: Callable whose first argument is the spectrum and whose
                remaining arguments are driven by the generated widgets.
            widget_specification: Sequence of dicts, one per parameter of
                ``analysis_fn`` after the first. Each dict has a ``"type"`` key
                (``int`` or ``float``) plus ``"start"``, ``"end"`` and ``"value"``;
                float specs also need ``"step"`` (optional for ints, default 1).
                ``None`` is treated as an empty specification.
            **kwargs: Settings forwarded to ``load_settings``; ``main_width`` and
                ``marginal_width`` are read from the resulting settings here.
        """
        super().__init__()
        self.load_settings(**kwargs)

        self.analysis_fn = analysis_fn
        self.widget_specification = widget_specification

        self.app_main_size = self.settings.get("main_width", 600)
        self.app_marginal_size = self.settings.get("marginal_width", 300)

    @staticmethod
    def _cut(data, dim, value):
        """Return ``data`` selected at the coordinate along ``dim`` nearest to ``value``."""
        return data.sel(**{dim: value}, method="nearest")

    def tool_handler(self, doc):
        """Configures widgets for the dynamic tool.

        In order to accomplish this, we need to inspect the type signature
        for the function and generate inputs for it dynamically.

        Args:
            doc: The Bokeh document the tool's layout is added to.

        Raises:
            AnalysisError: If the spectrum is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError("Cannot use the dynamic tool on non image-like spectra")

        self.data_for_display = self.arr
        x_dim, y_dim = self.data_for_display.dims[0], self.data_for_display.dims[1]
        x_coords = self.data_for_display.coords[x_dim]
        y_coords = self.data_for_display.coords[y_dim]

        default_palette = self.default_palette

        self.app_context.update(
            {
                "data": self.arr,
                "data_range": {
                    "x": (np.min(x_coords.values), np.max(x_coords.values)),
                    "y": (np.min(y_coords.values), np.max(y_coords.values)),
                },
            }
        )

        figures, plots = self.app_context["figures"], self.app_context["plots"]

        # Start the cursor in the middle of the image.
        self.cursor = [np.mean(self.data_range["x"]), np.mean(self.data_range["y"])]

        self.color_maps["main"] = LinearColorMapper(
            default_palette,
            low=np.min(self.data_for_display.values),
            high=np.max(self.data_for_display.values),
            nan_color="black",
        )

        # Callables such as functools.partial have no __name__.
        fn_name = getattr(self.analysis_fn, "__name__", type(self.analysis_fn).__name__)

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = "{} Tool: WARNING Unidentified".format(fn_name)
        try:
            main_title = "{} Tool: {}".format(fn_name, self.data_for_display.S.label[:60])
        except Exception:
            # Best effort: keep the fallback title if the label cannot be read.
            pass

        figures["main"] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=50,
            toolbar_location="left",
            x_axis_location="below",
            y_axis_location="right",
            title=main_title,
            x_range=self.data_range["x"],
            y_range=self.data_range["y"],
        )
        figures["main"].xaxis.axis_label = x_dim
        figures["main"].yaxis.axis_label = y_dim
        figures["main"].toolbar.logo = None
        figures["main"].background_fill_color = "#fafafa"

        data_range = self.app_context["data_range"]
        plots["main"] = figures["main"].image(
            [self.data_for_display.values.T],
            x=data_range["x"][0],
            y=data_range["y"][0],
            dw=data_range["x"][1] - data_range["x"][0],
            dh=data_range["y"][1] - data_range["y"][0],
            color_mapper=self.app_context["color_maps"]["main"],
        )

        # Bottom marginal: cut at the cursor's y position, plotted against x.
        bottom_marginal = self._cut(self.data_for_display, y_dim, self.cursor[1])
        bottom_marginal_original = self._cut(self.arr, y_dim, self.cursor[1])

        figures["bottom_marginal"] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures["main"].x_range,
            y_range=(np.min(bottom_marginal.values), np.max(bottom_marginal.values)),
            x_axis_location="above",
            toolbar_location=None,
            tools=[],
        )
        plots["bottom_marginal"] = figures["bottom_marginal"].line(
            x=bottom_marginal.coords[x_dim].values, y=bottom_marginal.values
        )
        plots["bottom_marginal_original"] = figures["bottom_marginal"].line(
            x=bottom_marginal_original.coords[self.arr.dims[0]].values,
            y=bottom_marginal_original.values,
            line_color="red",
        )

        # Right marginal: cut at the cursor's x position, plotted against y.
        right_marginal = self._cut(self.data_for_display, x_dim, self.cursor[0])
        right_marginal_original = self._cut(self.arr, x_dim, self.cursor[0])

        figures["right_marginal"] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures["main"].y_range,
            x_range=(np.min(right_marginal.values), np.max(right_marginal.values)),
            y_axis_location="left",
            toolbar_location=None,
            tools=[],
        )
        plots["right_marginal"] = figures["right_marginal"].line(
            y=right_marginal.coords[y_dim].values, x=right_marginal.values
        )
        plots["right_marginal_original"] = figures["right_marginal"].line(
            y=right_marginal_original.coords[y_dim].values,
            x=right_marginal_original.values,
            line_color="red",
        )

        # add lines
        self.add_cursor_lines(figures["main"])
        _ = figures["main"].multi_line(xs=[], ys=[], line_color="white", line_width=1)  # band lines

        # Prep the widgets for the analysis function. The first parameter has to be
        # the input data, so it is dropped; revisit if this is too limiting.
        signature = inspect.signature(self.analysis_fn)
        parameter_names = list(signature.parameters)[1:]
        # A missing (None) specification means "no widgets" rather than a crash in zip.
        named_widgets = dict(zip(parameter_names, self.widget_specification or ()))
        built_widgets = {}

        def update_marginals():
            """Re-cut displayed and original data at the cursor and refresh both marginals."""
            dim_x = self.data_for_display.dims[0]
            dim_y = self.data_for_display.dims[1]

            right_display = self._cut(self.data_for_display, dim_x, self.cursor[0])
            bottom_display = self._cut(self.data_for_display, dim_y, self.cursor[1])

            plots["bottom_marginal"].data_source.data = {
                "x": bottom_display.coords[dim_x].values,
                "y": bottom_display.values,
            }
            plots["right_marginal"].data_source.data = {
                "y": right_display.coords[dim_y].values,
                "x": right_display.values,
            }

            right_original = self._cut(self.arr, dim_x, self.cursor[0])
            bottom_original = self._cut(self.arr, dim_y, self.cursor[1])

            plots["bottom_marginal_original"].data_source.data = {
                "x": bottom_original.coords[dim_x].values,
                "y": bottom_original.values,
            }
            plots["right_marginal_original"].data_source.data = {
                "y": right_original.coords[dim_y].values,
                "x": right_original.values,
            }

            # Fit the marginal axes to both curves so the processed trace is not
            # clipped when its values fall outside the original data's range.
            bottom_values = np.concatenate(
                [np.ravel(bottom_display.values), np.ravel(bottom_original.values)]
            )
            right_values = np.concatenate(
                [np.ravel(right_display.values), np.ravel(right_original.values)]
            )
            figures["bottom_marginal"].y_range.start = np.min(bottom_values)
            figures["bottom_marginal"].y_range.end = np.max(bottom_values)
            figures["right_marginal"].x_range.start = np.min(right_values)
            figures["right_marginal"].x_range.end = np.max(right_values)

        def click_main_image(event):
            self.cursor = [event.x, event.y]
            update_marginals()

        error_msg = widgets.Div(text="")

        @Debounce(0.25)
        def update_data_for_display():
            """Rerun the analysis function with the current widget values and redraw."""
            try:
                # Pass by keyword so a parameter without a widget does not shift
                # the remaining values onto the wrong parameters.
                self.data_for_display = self.analysis_fn(
                    self.arr,
                    **{p: built_widgets[p].value for p in parameter_names if p in built_widgets},
                )
                error_msg.text = ""
            except Exception as e:
                # Show the failure in the app instead of killing the server callback.
                error_msg.text = "{}".format(e)

            # flush + update
            update_marginals()
            plots["main"].data_source.data = {"image": [self.data_for_display.values.T]}

        def update_data_change_wrapper(attr, old, new):
            if old != new:
                update_data_for_display()

        for parameter_name, specification in named_widgets.items():
            widget_type = specification["type"]
            if widget_type == int:
                widget = widgets.Slider(
                    start=specification["start"],
                    end=specification["end"],
                    value=specification["value"],
                    step=specification.get("step", 1),
                    title=parameter_name,
                )
            elif widget_type == float:
                widget = widgets.Slider(
                    start=specification["start"],
                    end=specification["end"],
                    value=specification["value"],
                    step=specification["step"],
                    title=parameter_name,
                )
            else:
                # Unsupported types get no widget; the parameter is then not passed.
                continue

            built_widgets[parameter_name] = widget
            widget.on_change("value", update_data_change_wrapper)

        update_main_colormap = self.update_colormap_for("main")

        # NOTE(review): identity placeholder; its consumer is not visible in this module.
        self.app_context["run"] = lambda x: x

        main_color_range_slider = widgets.RangeSlider(
            start=0, end=100, value=(0, 100), title="Color Range"
        )

        # Attach callbacks
        main_color_range_slider.on_change("value", update_main_colormap)
        figures["main"].on_event(events.Tap, click_main_image)

        layout = row(
            column(figures["main"], figures["bottom_marginal"]),
            column(figures["right_marginal"]),
            column(
                column(*[built_widgets[p] for p in parameter_names if p in built_widgets]),
                column(
                    self._cursor_info,
                    main_color_range_slider,
                    error_msg,
                ),
            ),
        )

        doc.add_root(layout)
        doc.title = "Dynamic Tool"
def dyn(
    dynamic_function: typing.Callable,
    data: DataType,
    widget_specifications=None,
    **kwargs,
):
    """Starts the dynamic tool using `dynamic_function` and widgets for each arg.

    Args:
        dynamic_function: Callable whose first argument is the spectrum; each
            further parameter is driven by one widget.
        data: Input data, normalized with ``normalize_to_spectrum`` first.
        widget_specifications: One widget specification per parameter of
            ``dynamic_function`` after the first. ``None`` means no widgets.
        **kwargs: Extra settings forwarded to ``DynamicTool``.

    Returns:
        Whatever ``DynamicTool.make_tool`` returns for the normalized data.
    """
    data = normalize_to_spectrum(data)
    # The tool zips parameter names with the specifications, so never hand it None.
    if widget_specifications is None:
        widget_specifications = []
    tool = DynamicTool(dynamic_function, widget_specifications, **kwargs)
    return tool.make_tool(data)