# Source code for arpes.plotting.path_tool
"""A utility for selecting paths on a marginal of your data."""
import numpy as np
import xarray as xr
from arpes.analysis.path import select_along_path
from arpes.exceptions import AnalysisError
from arpes.plotting.interactive_utils import CursorTool, SaveableTool
from arpes.typing import DataType
from arpes.utilities import normalize_to_spectrum
__all__ = ["path_tool"]
class PathTool(SaveableTool, CursorTool):
    """Tool to allow drawing paths on data, creating selections, and masking regions around paths.

    Integrates with the tools in arpes.analysis.path.
    """

    auto_zero_nans = False
    auto_rebin = False

    def __init__(self, **kwargs):
        """Configure the tool from keyword settings.

        Args:
            **kwargs: ``name`` (optional) is popped and handed to the parent
                initializer; everything else is forwarded to ``load_settings``.
                The settings ``app_main_size`` (default 600) and
                ``app_marginal_size`` (default 300) are read afterwards.
        """
        super().__init__(kwargs.pop("name", None))
        self.load_settings(**kwargs)

        self.app_main_size = self.settings.get("app_main_size", 600)
        # The marginal size has its own settings key; it must not reuse "app_main_size".
        self.app_marginal_size = self.settings.get("app_marginal_size", 300)
        self.pointer_mode = "path"

    def tool_handler(self, doc):
        """Build the Bokeh layout for the path tool and attach it to ``doc``.

        Creates the main image figure, the path overlay, the control widgets and
        their callbacks, then adds the layout as a root of ``doc``. The helpers
        ``to_xarray`` and ``select`` are published through ``self.app_context``.

        Args:
            doc: The Bokeh document the layout is added to.

        Raises:
            AnalysisError: If ``self.arr`` is not two dimensional.
        """
        # ``warnings.warn`` below is the standard-library function, so import the
        # standard-library module rather than relying on a name leaking from bokeh.
        import warnings

        from bokeh import events
        from bokeh.layouts import column, row
        from bokeh.models import widgets
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError("Cannot use path tool on non image-like spectra")

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]
        default_palette = self.default_palette

        self.app_context.update(
            {
                "path_options": [],
                "active_path": None,
                "paths": {},
                "data": arr,
                "data_range": {
                    "x": (np.min(x_coords.values), np.max(x_coords.values)),
                    "y": (np.min(y_coords.values), np.max(y_coords.values)),
                },
            }
        )

        # Start with the cursor in the middle of the data range.
        self.cursor = [np.mean(self.data_range["x"]), np.mean(self.data_range["y"])]

        self.color_maps["main"] = LinearColorMapper(
            default_palette, low=np.min(arr.values), high=np.max(arr.values), nan_color="black"
        )

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = "Path Tool: WARNING Unidentified"

        try:
            main_title = "Path Tool: {}".format(arr.S.label[:60])
        except Exception:
            # Best effort only: keep the fallback title if no label can be produced.
            pass

        self.figures["main"] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location="left",
            x_axis_location="below",
            y_axis_location="right",
            title=main_title,
            x_range=self.data_range["x"],
            y_range=self.data_range["y"],
        )
        self.figures["main"].xaxis.axis_label = arr.dims[0]
        self.figures["main"].yaxis.axis_label = arr.dims[1]

        # The image is transposed so that the first dimension runs along x.
        self.plots["main"] = self.figures["main"].image(
            [np.asarray(arr.values.T)],
            x=self.data_range["x"][0],
            y=self.data_range["y"][0],
            dw=self.data_range["x"][1] - self.data_range["x"][0],
            dh=self.data_range["y"][1] - self.data_range["y"][0],
            color_mapper=self.color_maps["main"],
        )
        self.plots["paths"] = self.figures["main"].multi_line(
            xs=[], ys=[], line_color="white", line_width=2
        )
        self.add_cursor_lines(self.figures["main"])

        def add_point_to_path():
            """Append the current cursor position to the active path, if there is one."""
            if self.active_path in self.paths:
                self.paths[self.active_path]["points"].append(list(self.cursor))
                update_path_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor to the tapped location; in "path" mode also add a point."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == "path":
                add_point_to_path()

        POINTER_MODES = [
            ("Cursor", "cursor"),
            ("Path", "path"),
        ]

        def convert_to_xarray():
            """Creates a Dataset consisting of one array for each path.

            For each of the paths, we will create a dataset which has an index dimension,
            and datavariables for each of the coordinate dimensions.

            Returns:
                A dict mapping each path name to its ``xr.Dataset``.
            """

            def convert_single_path_to_xarray(points):
                coords = {
                    "index": np.arange(len(points)),
                }
                # One variable per data dimension, holding that coordinate of every point.
                per_dim = {
                    d: np.array([p[i] for p in points]) for i, d in enumerate(self.arr.dims)
                }
                data_vars = {
                    k: xr.DataArray(v, coords=coords, dims=["index"]) for k, v in per_dim.items()
                }
                return xr.Dataset(data_vars=data_vars, coords=coords)

            return {
                path["name"]: convert_single_path_to_xarray(path["points"])
                for path in self.paths.values()
            }

        def select(data=None, **kwargs):
            """Select from ``data`` along the first path drawn in the tool.

            Args:
                data: Data to select from; defaults to the array shown in the tool.
                **kwargs: Forwarded to ``select_along_path``.

            Raises:
                AnalysisError: If no path has been defined yet.
            """
            if data is None:
                data = self.arr

            path_datasets = convert_to_xarray()
            if not path_datasets:
                # Previously this surfaced as a bare IndexError from indexing an empty list.
                raise AnalysisError("Cannot select along a path: no paths have been defined.")

            if len(self.paths) > 1:
                warnings.warn("Only using first path.")

            return select_along_path(path=next(iter(path_datasets.values())), data=data, **kwargs)

        self.app_context["to_xarray"] = convert_to_xarray
        self.app_context["select"] = select

        pointer_dropdown = widgets.Dropdown(
            label="Pointer Mode", button_type="primary", menu=POINTER_MODES
        )
        self.path_dropdown = widgets.Dropdown(
            label="Active Path", button_type="primary", menu=self.path_options
        )

        path_name_input = widgets.TextInput(placeholder="Path name...")
        add_path_button = widgets.Button(label="Add Path")
        clear_path_button = widgets.Button(label="Clear Path")
        remove_path_button = widgets.Button(label="Remove Path")

        main_color_range_slider = widgets.RangeSlider(
            start=0,
            end=100,
            value=(
                0,
                100,
            ),
            title="Color Range",
        )

        def add_path(path_name):
            """Register a new, empty path; it becomes active if no path is active yet."""
            if path_name not in self.paths:
                self.path_options.append(
                    (
                        path_name,
                        path_name,
                    )
                )
                self.path_dropdown.menu = self.path_options
                self.paths[path_name] = {
                    "points": [],
                    "name": path_name,
                }

                if self.active_path is None:
                    self.active_path = path_name

                self.save_app()

        def on_change_active_path(event):
            """Make the path picked in the dropdown the active one."""
            path_id = event.item
            self.debug_text = path_id
            self.app_context["active_path"] = path_id
            self.active_path = path_id
            self.save_app()

        def on_change_pointer_mode(event):
            """Switch between the "cursor" and "path" pointer modes."""
            pointer_mode = event.item
            self.app_context["pointer_mode"] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_path_display():
            """Push the points of every path to the overlay glyph and save the app."""
            self.plots["paths"].data_source.data = {
                "xs": [[point[0] for point in p["points"]] for p in self.paths.values()],
                "ys": [[point[1] for point in p["points"]] for p in self.paths.values()],
            }
            self.save_app()

        # Exposed on the instance because ``deserialize`` needs to redraw the paths.
        self.update_path_display = update_path_display

        def on_clear_path():
            """Remove all points from the active path."""
            if self.active_path in self.paths:
                self.paths[self.active_path]["points"] = []
                update_path_display()

        def on_remove_path():
            """Delete the active path and its dropdown entry."""
            if self.active_path in self.paths:
                del self.paths[self.active_path]
                new_path_options = [b for b in self.path_options if b[0] != self.active_path]
                self.path_dropdown.menu = new_path_options
                self.path_options = new_path_options
                self.active_path = None
                update_path_display()

        # Attach callbacks
        main_color_range_slider.on_change("value", self.update_colormap_for("main"))
        self.figures["main"].on_event(events.Tap, click_main_image)
        self.path_dropdown.on_click(on_change_active_path)
        pointer_dropdown.on_click(on_change_pointer_mode)
        add_path_button.on_click(lambda: add_path(path_name_input.value))
        clear_path_button.on_click(on_clear_path)
        remove_path_button.on_click(on_remove_path)

        layout = row(
            column(self.figures["main"]),
            column(
                column(
                    pointer_dropdown,
                    self.path_dropdown,
                ),
                row(
                    path_name_input,
                    add_path_button,
                ),
                row(
                    clear_path_button,
                    remove_path_button,
                ),
                column(
                    self._cursor_info,
                    main_color_range_slider,
                ),
                self.debug_div,
            ),
        )

        doc.add_root(layout)
        doc.title = "Path Tool"
        self.load_app()
        self.save_app()

    def serialize(self):
        """Return the tool state to persist: active path, path options, paths and cursor."""
        return {
            "active_path": self.active_path,
            "path_options": self.path_options,
            "paths": self.paths,
            "cursor": self.cursor,
        }

    def deserialize(self, json_data):
        """Restore tool state from a mapping shaped like the output of ``serialize``.

        Missing keys fall back to an empty state. Expects ``tool_handler`` to have
        run already, since it uses ``self.path_dropdown`` and
        ``self.update_path_display`` which are created there.

        Args:
            json_data: Mapping with optional keys ``cursor``, ``paths``,
                ``path_options`` and ``active_path``.
        """
        self.cursor = json_data.get("cursor", [0, 0])
        self.app_context["paths"] = json_data.get("paths", {}) or {}
        self.app_context["path_options"] = json_data.get("path_options", [])

        # ``serialize`` stores the active path, so restore it too, but only when it
        # still names one of the restored paths.
        # NOTE(review): written to app_context like the other restored keys; confirm
        # this is how the base class resolves ``self.active_path``.
        saved_active_path = json_data.get("active_path")
        if isinstance(saved_active_path, str) and saved_active_path in self.app_context["paths"]:
            self.app_context["active_path"] = saved_active_path

        self.path_dropdown.menu = self.app_context["path_options"]
        self.update_path_display()
def path_tool(data: DataType, **kwargs):
    """Launch an interactive ``PathTool`` on the given data.

    Args:
        data: Input data; it is passed through ``normalize_to_spectrum`` before
            being handed to the tool.
        **kwargs: Forwarded unchanged to the ``PathTool`` constructor.

    Returns:
        Whatever ``PathTool.make_tool`` returns for the normalized data.
    """
    spectrum = normalize_to_spectrum(data)
    return PathTool(**kwargs).make_tool(spectrum)