Source code for arpes.fits.utilities

"""Provides broadcasted and process parallel curve fitting for PyARPES.

The core of this module is `broadcast_model` which is a serious workhorse in PyARPES for
analyses based on curve fitting. This allows simple multidimensional curve fitting by
iterative fitting across one or many axes. Currently basic strategies are implemented,
but in the future we would like to provide:

1. Passing xr.DataArray values to parameter guesses and bounds, which can be interpolated/selected
   to allow changing conditions throughout the curve fitting session.
2. A strategy allowing retries with initial guess taken from the previous fit. This is similar
   to some adaptive curve fitting routines that have been proposed in the literature.

import os
import dill
from packaging import version

import lmfit
import numpy as np
from tqdm import tqdm_notebook

import arpes.fits.fit_models
from typing import Callable, Union, Tuple, Any, Dict, List

import xarray as xr
from arpes.trace import traceable
from arpes.provenance import update_provenance
from arpes.typing import DataType
from arpes.utilities import normalize_to_spectrum
from . import mp_fits

__all__ = ("broadcast_model", "result_to_hints")

TypeIterable = Union[List[type], Tuple[type]]

XARRAY_REQUIRES_VALUES_WRAPPING = version.parse(xr.__version__) > version.parse("0.10.0")

def wrap_for_xarray_values_unpacking(item):
    """This is a shim for"""
        return np.array(item, dtype=object)

    return item

[docs]def result_to_hints(m: lmfit.model.ModelResult, defaults=None) -> Dict[str, Dict[str, Any]]: """Turns an `lmfit.model.ModelResult` into a dictionary with initial guesses. Args: m: The model result to extract parameters from defaults: Returned if `m` is None, useful for cell re-evaluation in Jupyter Returns: A dict containing parameter specifications in key-value rathan than `lmfit.Parameter` format, as you might pass as `params=` to PyARPES fitting code. """ if m is None: return defaults return {k: {"value": m.params[k].value} for k in m.params}
def parse_model(model): """Takes a model string and turns it into a tokenized version. 1. ModelClass -> ModelClass 2. [ModelClass] -> [ModelClass] 3. str -> [<ModelClass, operator as string>] i.e. A + (B + C) * D -> [A, '(', B, '+', C, ')', '*', D] Args: model: The model specification Returns: A tokenized specification of the model suitable for passing to the curve fitting routine. """ if not isinstance(model, str): return model pad_all = ["+", "-", "*", "/", "(", ")"] for pad in pad_all: model = model.replace(pad, " {} ".format(pad)) special = set(pad_all) def read_token(token): if token in special: return token try: token = float(token) return token except ValueError: try: return arpes.fits.fit_models.__dict__[token] except KeyError: raise ValueError("Could not find model: {}".format(token)) return [read_token(token) for token in model.split()]
[docs]@update_provenance("Broadcast a curve fit along several dimensions") @traceable def broadcast_model( model_cls: Union[type, TypeIterable], data: DataType, broadcast_dims, params=None, progress=True, weights=None, safe=False, prefixes=None, window=None, parallelize=None, trace: Callable = None, ): """Perform a fit across a number of dimensions. Allows composite models as well as models defined and compiled through strings. Args: model_cls: The model specification data: The data to curve fit broadcast_dims: Which dimensions of the input should be iterated across as opposed to fit across params: Parameter hints, consisting of plain values or arrays for interpolation progress: Whether to show a progress bar weights: Weights to apply when curve fitting. Should have the same shape as the input data safe: Whether to mask out nan values window: A specification of cuts/windows to apply to each curve fit parallelize: Whether to parallelize curve fits, defaults to True if unspecified and more than 20 fits were requested trace: Controls whether execution tracing/timestamping is used for performance investigation Returns: An `xr.Dataset` containing the curve fitting results. These are data vars: - "results": Containing an `xr.DataArray` of the `lmfit.model.ModelResult` instances - "residual": The residual array, with the same shape as the input - "data": The original data used for fitting - "norm_residual": The residual array normalized by the data, i.e. the fractional error """ if params is None: params = {} if isinstance(broadcast_dims, str): broadcast_dims = [broadcast_dims] trace("Normalizing to spectrum") data = normalize_to_spectrum(data) cs = {} for dim in broadcast_dims: cs[dim] = data.coords[dim] other_axes = set(data.dims).difference(set(broadcast_dims)) template = data.sum(list(other_axes)) template.values = np.ndarray(template.shape, dtype=np.object) n_fits = if parallelize is None: parallelize = n_fits > 20 trace("Copying residual") residual = data.copy(deep=True) residual.values = np.zeros(residual.shape) trace("Parsing model") model = parse_model(model_cls) wrap_progress = lambda x, *_, **__: x if progress: wrap_progress = tqdm_notebook serialize = parallelize fitter = mp_fits.MPWorker( data=data, uncompiled_model=model, prefixes=prefixes, params=params, safe=safe, serialize=serialize, weights=weights, window=window, ) if parallelize: trace(f"Running fits (nfits={n_fits}) in parallel (n_threads={os.cpu_count()})") print("Running on multiprocessing pool... this may take a while the first time.") from .hot_pool import hot_pool pool = hot_pool.pool exe_results = list( wrap_progress( pool.imap(fitter, template.G.iter_coords()), total=n_fits, desc="Fitting on pool..." ) ) else: trace(f"Running fits (nfits={n_fits}) serially") exe_results = [] for _, cut_coords in wrap_progress( template.G.enumerate_iter_coords(), desc="Fitting", total=n_fits ): exe_results.append(fitter(cut_coords)) if serialize: trace("Deserializing...") print("Deserializing...") def unwrap(result_data): # using the lmfit deserialization and serialization seems slower than double pickling with dill # result = lmfit.model.ModelResult(compiled_model, compiled_model.make_params()) # return result.loads(result_data) return dill.loads(result_data) exe_results = [(unwrap(res), residual, cs) for res, residual, cs in exe_results] print("Finished deserializing") trace(f"Finished running fits Collating") for fit_result, fit_residual, coords in exe_results: template.loc[coords] = wrap_for_xarray_values_unpacking(fit_result) residual.loc[coords] = fit_residual trace("Bundling into dataset") return xr.Dataset( { "results": template, "data": data, "residual": residual, "norm_residual": residual / data, }, residual.coords, )