# Source code for arpes.plotting.mask_tool
"""Utilities for selecting and defining masks on data interactively."""
import numpy as np
from arpes.analysis.mask import apply_mask
from arpes.exceptions import AnalysisError
from arpes.plotting.interactive_utils import CursorTool, SaveableTool
from arpes.typing import DataType
from arpes.utilities import normalize_to_spectrum
# Only the ``mask`` entry point is exported from this module.
__all__ = ["mask"]
class MaskTool(SaveableTool, CursorTool):
    """Tool to allow masking data by drawing regions.

    The user builds one or more named polygonal regions on top of a 2D
    spectrum by clicking on the image. The polygons are accumulated in
    ``app_context["mask"]`` as ``{"dims": ..., "polys": [...]}``, with an
    optional ``"fermi"`` entry written by the edge-mask button. They can be
    applied later through ``app_context["perform_mask"]``, which forwards to
    :func:`arpes.analysis.mask.apply_mask`.
    """

    auto_zero_nans = False
    auto_rebin = False

    def __init__(self, **kwargs):
        """Configure the tool.

        Args:
            kwargs: ``name`` is forwarded to the base tool constructor. All
                remaining keyword arguments go to ``load_settings``. The
                settings keys ``app_main_size`` (default 600) and
                ``app_marginal_size`` (default 300) control figure sizes.
        """
        super().__init__(kwargs.pop("name", None))
        self.load_settings(**kwargs)

        self.app_main_size = self.settings.get("app_main_size", 600)
        # Read the marginal size from its own key. It previously read
        # "app_main_size" by mistake.
        self.app_marginal_size = self.settings.get("app_marginal_size", 300)

        # By default, clicking the image appends points to the active region.
        self.pointer_mode = "region"

    def tool_handler(self, doc):
        """Build the Bokeh figure, widgets, and callbacks for the mask tool.

        Args:
            doc: Bokeh document that receives the layout.

        Raises:
            AnalysisError: If the data is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError("Cannot use mask tool on non image-like spectra")

        arr = self.arr
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update(
            {
                "region_options": [],
                "regions": {},
                "data": arr,
                "data_range": {
                    "x": (np.min(x_coords.values), np.max(x_coords.values)),
                    "y": (np.min(y_coords.values), np.max(y_coords.values)),
                },
            }
        )

        # Start the cursor at the centre of the data extent.
        self.cursor = [np.mean(self.data_range["x"]), np.mean(self.data_range["y"])]

        self.color_maps["main"] = LinearColorMapper(
            default_palette, low=np.min(arr.values), high=np.max(arr.values), nan_color="black"
        )

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = "Mask Tool: WARNING Unidentified"

        try:
            main_title = "Mask Tool: {}".format(arr.S.label[:60])
        except Exception:
            # The title is cosmetic. Keep the fallback title when no usable
            # label is available, but do not swallow KeyboardInterrupt or
            # SystemExit the way a bare ``except`` would.
            pass

        self.figures["main"] = figure(
            tools=main_tools,
            plot_width=self.app_main_size,
            plot_height=self.app_main_size,
            min_border=10,
            min_border_left=20,
            toolbar_location="left",
            x_axis_location="below",
            y_axis_location="right",
            title=main_title,
            x_range=self.data_range["x"],
            y_range=self.data_range["y"],
        )

        self.figures["main"].xaxis.axis_label = arr.dims[0]
        self.figures["main"].yaxis.axis_label = arr.dims[1]

        # Bokeh images are indexed (row=y, col=x), hence the transpose.
        self.plots["main"] = self.figures["main"].image(
            [np.asarray(arr.values.T)],
            x=self.data_range["x"][0],
            y=self.data_range["y"][0],
            dw=self.data_range["x"][1] - self.data_range["x"][0],
            dh=self.data_range["y"][1] - self.data_range["y"][0],
            color_mapper=self.color_maps["main"],
        )

        self.add_cursor_lines(self.figures["main"])
        region_patches = self.figures["main"].patches(
            xs=[], ys=[], color="white", alpha=0.35, line_width=1
        )

        def add_point_to_region():
            """Append the current cursor position to the active region, if any."""
            if self.active_region in self.regions:
                self.regions[self.active_region]["points"].append(list(self.cursor))
                update_region_display()

            self.save_app()

        def click_main_image(event):
            """Move the cursor to the tap location and, in region mode, record a point."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == "region":
                add_point_to_region()

        POINTER_MODES = [
            (
                "Cursor",
                "cursor",
            ),
            (
                "Region",
                "region",
            ),
        ]

        def perform_mask(data=None, **kwargs):
            """Apply the current mask to ``data`` (defaults to the displayed array)."""
            if data is None:
                data = arr

            data = normalize_to_spectrum(data)
            return apply_mask(data, self.app_context["mask"], **kwargs)

        self.app_context["perform_mask"] = perform_mask
        self.app_context["mask"] = None

        pointer_dropdown = widgets.Dropdown(
            label="Pointer Mode", button_type="primary", menu=POINTER_MODES
        )
        self.region_dropdown = widgets.Dropdown(
            label="Active Region", button_type="primary", menu=self.region_options
        )

        edge_mask_button = widgets.Button(label="Edge Mask")
        region_name_input = widgets.TextInput(placeholder="Region name...")
        add_region_button = widgets.Button(label="Add Region")

        clear_region_button = widgets.Button(label="Clear Region")
        remove_region_button = widgets.Button(label="Remove Region")

        main_color_range_slider = widgets.RangeSlider(
            start=0,
            end=100,
            value=(
                0,
                100,
            ),
            title="Color Range",
        )

        def on_click_edge_mask():
            """Extend the active region along the energy edge and record a Fermi cutoff.

            The region is padded with vertices that lie just outside the data
            range on the non-energy axis. These vertices span energies from
            below the minimum energy up to 0. The largest energy among the
            existing points is stored as ``mask["fermi"]``.
            """
            if self.active_region not in self.regions:
                return

            old_points = self.regions[self.active_region]["points"]
            if not old_points:
                # np.max of an empty sequence raises ValueError. There is
                # nothing to extend until the user has placed at least one point.
                return

            dims = [d for d in arr.dims if "eV" != d]
            energy_index = arr.dims.index("eV")
            max_energy = np.max([p[energy_index] for p in old_points])

            other_dim = dims[0]
            other_coord = arr.coords[other_dim].values
            min_other, max_other = np.min(other_coord), np.max(other_coord)
            min_e = np.min(arr.coords["eV"].values)

            # Vertex order depends on which axis the energy lives on.
            if energy_index == 0:
                before = [[min_e - 3, min_other - 1], [0, min_other - 1]]
                after = [[0, max_other + 1], [min_e - 3, max_other + 1]]
            else:
                before = [[min_other - 1, min_e - 3], [min_other - 1, 0]]
                after = [[max_other + 1, 0], [max_other + 1, min_e - 3]]

            self.regions[self.active_region]["points"] = before + old_points + after
            self.app_context["mask"] = self.app_context["mask"] or {}
            self.app_context["mask"]["fermi"] = max_energy
            update_region_display()

            self.save_app()

        def add_region(region_name):
            """Create an empty region named ``region_name`` if it does not exist yet."""
            if region_name not in self.regions:
                self.region_options.append(
                    (
                        region_name,
                        region_name,
                    )
                )
                self.region_dropdown.menu = self.region_options
                self.regions[region_name] = {
                    "points": [],
                    "name": region_name,
                }

                # The first region created becomes active automatically.
                if self.active_region is None:
                    self.active_region = region_name

                self.save_app()

        def on_change_active_region(event):
            """Switch the region that new points are added to."""
            region_id = event.item
            self.app_context["active_region"] = region_id
            self.active_region = region_id
            self.save_app()

        def on_change_pointer_mode(event):
            """Switch between plain cursor movement and region drawing."""
            pointer_mode = event.item
            self.app_context["pointer_mode"] = pointer_mode
            self.pointer_mode = pointer_mode
            self.save_app()

        def update_region_display():
            """Sync the mask polygons and the drawn patches with ``self.regions``."""
            region_names = self.regions.keys()

            if self.app_context["mask"] is None:
                self.app_context["mask"] = {}
            self.app_context["mask"].update(
                {"dims": arr.dims, "polys": [r["points"] for r in self.regions.values()]}
            )

            region_patches.data_source.data = {
                "xs": [[p[0] for p in self.regions[r]["points"]] for r in region_names],
                "ys": [[p[1] for p in self.regions[r]["points"]] for r in region_names],
            }
            self.save_app()

        # Exposed so that ``deserialize`` can refresh the display after loading.
        self.update_region_display = update_region_display

        def on_clear_region():
            """Remove all points from the active region, keeping the region itself."""
            if self.active_region in self.regions:
                self.regions[self.active_region]["points"] = []
                update_region_display()

        def on_remove_region():
            """Delete the active region and its dropdown entry, then deselect it."""
            if self.active_region in self.regions:
                del self.regions[self.active_region]
                new_region_options = [b for b in self.region_options if b[0] != self.active_region]
                self.region_dropdown.menu = new_region_options
                self.region_options = new_region_options
                self.active_region = None
                update_region_display()

        # Attach callbacks
        main_color_range_slider.on_change("value", self.update_colormap_for("main"))

        self.figures["main"].on_event(events.Tap, click_main_image)
        self.region_dropdown.on_click(on_change_active_region)
        pointer_dropdown.on_click(on_change_pointer_mode)
        add_region_button.on_click(lambda: add_region(region_name_input.value))
        edge_mask_button.on_click(on_click_edge_mask)
        clear_region_button.on_click(on_clear_region)
        remove_region_button.on_click(on_remove_region)

        # The edge-mask button only makes sense when there is an energy axis.
        layout = row(
            column(self.figures["main"]),
            column(
                *[
                    f
                    for f in [
                        column(
                            pointer_dropdown,
                            self.region_dropdown,
                        ),
                        row(
                            region_name_input,
                            add_region_button,
                        ),
                        edge_mask_button if "eV" in arr.dims else None,
                        row(
                            clear_region_button,
                            remove_region_button,
                        ),
                        column(
                            self._cursor_info,
                            main_color_range_slider,
                        ),
                    ]
                    if f is not None
                ]
            ),
        )

        doc.add_root(layout)
        doc.title = "Mask Tool"
        self.load_app()
        self.save_app()

    def serialize(self):
        """Return the persisted state: active region, region list, regions, cursor."""
        return {
            "active_region": self.active_region,
            "region_options": self.region_options,
            "regions": self.regions,
            "cursor": self.cursor,
        }

    def deserialize(self, json_data):
        """Restore state previously produced by ``serialize`` and redraw regions.

        ``active_region`` is restored too. Previously it was saved but never
        read back, so after a reload clicks were ignored until the user
        re-selected a region.
        """
        self.cursor = json_data.get("cursor", [0, 0])
        self.app_context["regions"] = json_data.get("regions", {}) or {}
        self.app_context["region_options"] = json_data.get("region_options", [])

        # Mirror on_change_active_region: keep app_context and attribute in sync.
        active_region = json_data.get("active_region")
        self.app_context["active_region"] = active_region
        self.active_region = active_region

        self.region_dropdown.menu = self.app_context["region_options"]
        self.update_region_display()
def mask(data: DataType, **kwargs):
    """Start an interactive mask selection tool.

    Args:
        data: Data to mask. It is passed through ``normalize_to_spectrum``
            before the tool is built.
        kwargs: Forwarded unchanged to the ``MaskTool`` constructor.

    Returns:
        The result of ``MaskTool.make_tool`` on the normalized data.
    """
    spectrum = normalize_to_spectrum(data)
    return MaskTool(**kwargs).make_tool(spectrum)