Source code for arpes.plotting.fit_inspection_tool

"""Provides Bokeh based utilities for inspecting fits and fit quality."""
import numpy as np
from bokeh import events

import xarray as xr
from arpes.plotting.interactive_utils import BokehInteractiveTool, CursorTool

__all__ = ("FitCheckTool",)


[docs]class FitCheckTool(BokehInteractiveTool, CursorTool): """Interactive verification of fit quality for broadcast fits.""" auto_zero_nans = False auto_rebin = False
[docs] def __init__(self, **kwargs): """Loads marginal sizes and configures initial application settings.""" super().__init__() self.load_settings(**kwargs) self.app_main_size = self.settings.get("main_width", 600) self.app_marginal_size = self.settings.get("marginal_width", 300) self.selected_data = "data" self.use_dataset = True self.remove_outliers = True self.outlier_clip = 1
def tool_handler(self, doc): """Defines the application widgets and UI interactions.""" from bokeh.layouts import row, column, Spacer from bokeh.models.mappers import LinearColorMapper from bokeh.models import widgets from bokeh.models.widgets.markups import Div from bokeh.plotting import figure self.arr = self.arr.copy(deep=True) if not isinstance(self.arr, xr.Dataset): self.use_dataset = False residual = None if self.use_dataset: raw_data = self.arr.data raw_data.values[np.isnan(raw_data.values)] = 0 fit_results = self.arr.results residual = self.arr.residual residual.values[np.isnan(residual.values)] = 0 else: raw_data = self.arr.attrs["original_data"] fit_results = self.arr fit_direction = [d for d in raw_data.dims if d not in fit_results.dims] fit_direction = fit_direction[0] two_dimensional = False if len(raw_data.dims) != 2: two_dimensional = True x_coords, y_coords = ( fit_results.coords[fit_results.dims[0]], fit_results.coords[fit_results.dims[1]], ) z_coords = raw_data.coords[fit_direction] else: x_coords, y_coords = ( raw_data.coords[raw_data.dims[0]], raw_data.coords[raw_data.dims[1]], ) if two_dimensional: self.settings["palette"] = "coolwarm" default_palette = self.default_palette self.app_context.update( { "data": raw_data, "fits": fit_results, "residual": residual, "original": self.arr, "data_range": { "x": (np.min(x_coords.values), np.max(x_coords.values)), "y": (np.min(y_coords.values), np.max(y_coords.values)), }, } ) if two_dimensional: self.app_context["data_range"]["z"] = (np.min(z_coords.values), np.max(z_coords.values)) figures, plots, app_widgets = ( self.app_context["figures"], self.app_context["plots"], self.app_context["widgets"], ) self.cursor_dims = raw_data.dims if two_dimensional: self.cursor = [ np.mean(self.data_range["x"]), np.mean(self.data_range["y"]), np.mean(self.data_range["z"]), ] else: self.cursor = [np.mean(self.data_range["x"]), np.mean(self.data_range["y"])] app_widgets["fit_info_div"] = Div(text="") self.app_context["color_maps"]["main"] = LinearColorMapper( default_palette, low=np.min(raw_data.values), high=np.max(raw_data.values), nan_color="black", ) main_tools = ["wheel_zoom", "tap", "reset", "save"] main_title = "Fit Inspection Tool: WARNING Unidentified" try: main_title = "Fit Inspection Tool: {}".format(raw_data.S.label[:60]) except: pass figures["main"] = figure( tools=main_tools, plot_width=self.app_main_size, plot_height=self.app_main_size, min_border=10, min_border_left=50, toolbar_location="left", x_axis_location="below", y_axis_location="right", title=main_title, x_range=self.data_range["x"], y_range=self.app_context["data_range"]["y"], ) figures["main"].xaxis.axis_label = raw_data.dims[0] figures["main"].yaxis.axis_label = raw_data.dims[1] figures["main"].toolbar.logo = None figures["main"].background_fill_color = "#fafafa" data_for_main = raw_data if two_dimensional: data_for_main = data_for_main.sel( **dict([[fit_direction, self.cursor[2]]]), method="nearest" ) plots["main"] = figures["main"].image( [data_for_main.values.T], x=self.app_context["data_range"]["x"][0], y=self.app_context["data_range"]["y"][0], dw=self.app_context["data_range"]["x"][1] - self.app_context["data_range"]["x"][0], dh=self.app_context["data_range"]["y"][1] - self.app_context["data_range"]["y"][0], color_mapper=self.app_context["color_maps"]["main"], ) band_centers = [b.center for b in fit_results.F.bands.values()] bands_xs = [b.coords[b.dims[0]].values for b in band_centers] bands_ys = [b.values for b in band_centers] if fit_results.dims[0] == raw_data.dims[1]: bands_ys, bands_xs = bands_xs, bands_ys plots["band_locations"] = figures["main"].multi_line( xs=bands_xs, ys=bands_ys, line_color="white", line_width=1, line_dash="dashed" ) # add cursor lines cursor_lines = self.add_cursor_lines(figures["main"]) # marginals if not two_dimensional: figures["bottom"] = figure( plot_width=self.app_main_size, plot_height=self.app_marginal_size, min_border=10, title=None, x_range=figures["main"].x_range, x_axis_location="above", toolbar_location=None, tools=[], ) else: figures["bottom"] = Spacer(width=self.app_main_size, height=self.app_marginal_size) right_y_range = figures["main"].y_range if two_dimensional: right_y_range = self.data_range["z"] figures["right"] = figure( plot_width=self.app_marginal_size, plot_height=self.app_main_size, min_border=10, title=None, y_range=right_y_range, y_axis_location="left", toolbar_location=None, tools=[], ) marginal_line_width = 2 if not two_dimensional: bottom_data = raw_data.sel( **dict([[raw_data.dims[1], self.cursor[1]]]), method="nearest" ) right_data = raw_data.sel( **dict([[raw_data.dims[0], self.cursor[0]]]), method="nearest" ) plots["bottom"] = figures["bottom"].line( x=bottom_data.coords[raw_data.dims[0]].values, y=bottom_data.values, line_width=marginal_line_width, ) plots["bottom_residual"] = figures["bottom"].line( x=[], y=[], line_color="red", line_width=marginal_line_width ) plots["bottom_fit"] = figures["bottom"].line( x=[], y=[], line_color="blue", line_width=marginal_line_width, line_dash="dashed" ) plots["bottom_init_fit"] = figures["bottom"].line( x=[], y=[], line_color="green", line_width=marginal_line_width, line_dash="dotted" ) plots["right"] = figures["right"].line( y=right_data.coords[raw_data.dims[1]].values, x=right_data.values, line_width=marginal_line_width, ) plots["right_residual"] = figures["right"].line( x=[], y=[], line_color="red", line_width=marginal_line_width ) plots["right_fit"] = figures["right"].line( x=[], y=[], line_color="blue", line_width=marginal_line_width, line_dash="dashed" ) plots["right_init_fit"] = figures["right"].line( x=[], y=[], line_color="green", line_width=marginal_line_width, line_dash="dotted" ) else: right_data = raw_data.sel( **{k: v for k, v in self.cursor_dict.items() if k != fit_direction}, method="nearest" ) plots["right"] = figures["right"].line( y=right_data.coords[right_data.dims[0]].values, x=right_data.values, line_width=marginal_line_width, ) plots["right_residual"] = figures["right"].line( x=[], y=[], line_color="red", line_width=marginal_line_width ) plots["right_fit"] = figures["right"].line( x=[], y=[], line_color="blue", line_width=marginal_line_width, line_dash="dashed" ) plots["right_init_fit"] = figures["right"].line( x=[], y=[], line_color="green", line_width=marginal_line_width, line_dash="dotted" ) def on_change_main_view(event): data_source = event.item self.selected_data = data_source data = None if data_source == "data": data = raw_data.sel( **{k: v for k, v in self.cursor_dict.items() if k == fit_direction}, method="nearest" ) elif data_source == "residual": data = residual.sel( **{k: v for k, v in self.cursor_dict.items() if k == fit_direction}, method="nearest" ) elif two_dimensional: data = fit_results.F.s(data_source) data.values[np.isnan(data.values)] = 0 if data is not None: if self.remove_outliers: data = data.G.clean_outliers(clip=self.outlier_clip) plots["main"].data_source.data = { "image": [data.values.T], } update_main_colormap(None, None, main_color_range_slider.value) def update_fit_display(): target = "right" if fit_results.dims[0] == raw_data.dims[1]: target = "bottom" if two_dimensional: target = "right" current_fit = fit_results.sel( **{k: v for k, v in self.cursor_dict.items() if k != fit_direction}, method="nearest" ).item() coord_vals = raw_data.coords[fit_direction].values else: current_fit = fit_results.sel( **dict([[fit_results.dims[0], self.cursor[0 if target == "right" else 1]]]), method="nearest" ).item() coord_vals = raw_data.coords[raw_data.dims[0 if target == "bottom" else 1]].values if current_fit is not None: app_widgets["fit_info_div"].text = current_fit._repr_html_( short=True ) # pylint: disable=protected-access else: app_widgets["fit_info_div"].text = "No fit here." plots["{}_residual".format(target)].data_source.data = { "x": [], "y": [], } plots["{}_fit".format(target)].data_source.data = { "x": [], "y": [], } plots["{}_init_fit".format(target)].data_source.data = { "x": [], "y": [], } return if target == "bottom": residual_x = coord_vals residual_y = current_fit.residual init_fit_x = coord_vals init_fit_y = current_fit.init_fit fit_x = coord_vals fit_y = current_fit.best_fit else: residual_y = coord_vals residual_x = current_fit.residual init_fit_y = coord_vals init_fit_x = current_fit.init_fit fit_y = coord_vals fit_x = current_fit.best_fit plots["{}_residual".format(target)].data_source.data = { "x": residual_x, "y": residual_y, } plots["{}_fit".format(target)].data_source.data = { "x": fit_x, "y": fit_y, } plots["{}_init_fit".format(target)].data_source.data = { "x": init_fit_x, "y": init_fit_y, } def click_right_marginal(event): self.cursor = [self.cursor[0], self.cursor[1], event.y] on_change_main_view(None, None, self.selected_data) def click_main_image(event): if two_dimensional: self.cursor = [event.x, event.y, self.cursor[2]] else: self.cursor = [event.x, event.y] if not two_dimensional: right_marginal_data = raw_data.sel( **dict([[raw_data.dims[0], self.cursor[0]]]), method="nearest" ) bottom_marginal_data = raw_data.sel( **dict([[raw_data.dims[1], self.cursor[1]]]), method="nearest" ) plots["bottom"].data_source.data = { "x": bottom_marginal_data.coords[raw_data.dims[0]].values, "y": bottom_marginal_data.values, } else: right_marginal_data = raw_data.sel( **{k: v for k, v in self.cursor_dict.items() if k != fit_direction}, method="nearest" ) plots["right"].data_source.data = { "y": right_marginal_data.coords[right_marginal_data.dims[0]].values, "x": right_marginal_data.values, } update_fit_display() def on_change_outlier_clip(attr, old, new): self.outlier_clip = new on_change_main_view(None, None, self.selected_data) def set_remove_outliers(should_remove_outliers): if self.remove_outliers != should_remove_outliers: self.remove_outliers = should_remove_outliers on_change_main_view(None, None, self.selected_data) update_main_colormap = self.update_colormap_for("main") MAIN_CONTENT_OPTIONS = [ ("Residual", "residual"), ("Data", "data"), ] if two_dimensional: available_parameters = fit_results.F.parameter_names for param_name in available_parameters: MAIN_CONTENT_OPTIONS.append( ( param_name, param_name, ) ) remove_outliers_toggle = widgets.Toggle( label="Remove Outliers", button_type="primary", active=self.remove_outliers ) remove_outliers_toggle.on_click(set_remove_outliers) outlier_clip_slider = widgets.Slider( title="Clip", start=0, end=10, value=self.outlier_clip, callback_throttle=150, step=0.2 ) outlier_clip_slider.on_change("value", on_change_outlier_clip) main_content_select = widgets.Dropdown( label="Main Content", button_type="primary", menu=MAIN_CONTENT_OPTIONS ) main_content_select.on_click(on_change_main_view) # Widgety things main_color_range_slider = widgets.RangeSlider( start=0, end=100, value=( 0, 100, ), title="Color Range", ) # Attach callbacks main_color_range_slider.on_change("value", update_main_colormap) figures["main"].on_event(events.Tap, click_main_image) if two_dimensional: figures["right"].on_event(events.Tap, click_right_marginal) layout = row( column(figures["main"], figures.get("bottom")), column(figures["right"], app_widgets["fit_info_div"]), column( column( *[ widget for widget in [ self._cursor_info, main_color_range_slider, main_content_select, remove_outliers_toggle if two_dimensional else None, outlier_clip_slider if two_dimensional else None, ] if widget is not None ] ), ), ) update_fit_display() doc.add_root(layout) doc.title = "Band Tool"