Fit Accessor
The FitAccessor is a pandas DataFrame accessor that provides a simple interface for fitting data.
- class ezfit.fit.FitAccessor(df: DataFrame)
Bases: object
Accessor for fitting data in a pandas DataFrame to a given model.
- __call__(model: Callable[..., Any], x: str, y: str, yerr: str | None = None, plot: bool = True, method: FitMethod = 'curve_fit', fit_kwargs: FitKwargs | None = None, residuals: Literal['none', 'res', 'percent', 'rmse'] = 'res', color_error: str = 'C0', color_model: str = 'C3', color_residuals: str = 'C0', fmt_error: str = '.', ls_model: str = '-', ls_residuals: str = '', marker_residuals: str = '.', err_kws: dict[str, Any] | None = None, mod_kws: dict[str, Any] | None = None, res_kws: dict[str, Any] | None = None, **parameters: dict[str, Any]) → tuple[Model, Axes | None, Axes | None]
Fit the data to the model and optionally plot the results.
Calls the [FitAccessor.fit](#fitaccessorfit) and [FitAccessor.plot](#fitaccessorplot) methods in sequence.
- Parameters:
model (Callable[..., Any]) – The model function to fit the data to.
x (str) – The name of the column in the DataFrame for the independent variable.
y (str) – The name of the column in the DataFrame for the dependent variable.
yerr (str | None, optional) – The name of the column for the error on the dependent variable, by default None.
plot (bool, optional) – Whether to plot the results, by default True.
method (FitMethod, optional) – The fitting method to use, by default “curve_fit”. Available methods: ‘curve_fit’, ‘minimize’, ‘differential_evolution’, ‘shgo’, ‘dual_annealing’, ‘emcee’, ‘bayesian_ridge’, ‘ridge’, ‘lasso’, ‘elasticnet’, ‘polynomial’. ‘bayesian_ridge’ requires scikit-learn and is only valid for linear models. ‘emcee’ requires emcee. Methods other than ‘curve_fit’ and ‘bayesian_ridge’ may require sigma (yerr).
fit_kwargs (FitKwargs | None, optional) – Keyword arguments passed to the fitting function (e.g., scipy.optimize.curve_fit, scipy.optimize.minimize, etc.), by default None.
residuals (Literal["none", "res", "percent", "rmse"], optional) – The type of residuals to plot. Set to “none” to disable residuals plot, by default “res”.
color_error (str, optional) – Color for data points/error bars, by default “C0”.
color_model (str, optional) – Color for the fitted model line, by default “C3”.
color_residuals (str, optional) – Color for the residuals plot, by default “C0”.
fmt_error (str, optional) – Marker style for data points, by default “.”.
ls_model (str, optional) – Line style for the model line, by default “-”.
ls_residuals (str, optional) – Line style for residuals, by default “”.
marker_residuals (str, optional) – Marker style for residuals, by default “.”.
err_kws (dict[str, Any] | None, optional) – Additional keyword arguments for data/error bar plotting (plt.errorbar), by default None.
mod_kws (dict[str, Any] | None, optional) – Additional keyword arguments for model line plotting (plt.plot), by default None.
res_kws (dict[str, Any] | None, optional) – Additional keyword arguments for residuals plotting (plt.plot), by default None.
**parameters (dict[str, Any]) – Specification of model parameters (initial values, bounds, fixed). Passed as keyword arguments, one dictionary per parameter, e.g., param_name={"value": 1, "min": 0}.
- Returns:
A tuple containing the fitted Model object, the main plot Axes (or None), and the residuals plot Axes (or None).
- Return type:
tuple[Model, Axes | None, Axes | None]
- Raises:
ColumnNotFoundError – If a specified column (x, y, yerr) is not found.
ImportError – If a required library (e.g., scikit-learn, emcee) is not installed for the chosen method.
ValueError – If an invalid method is chosen, if required arguments (like sigma) are missing for a method, or if the fit fails.
TypeError – If the model is not callable, if a parameter specification is not a dictionary, or if the parameters are not valid for the model.
- fit(model: Callable[..., Any], x: str, y: str, yerr: str | None = None, method: FitMethod = 'curve_fit', fit_kwargs: FitKwargs | None = None, **parameters: dict[str, Any]) → Model
Fit the data to the model.
- Parameters:
model (Callable[..., Any]) – The model function to fit the data to.
x (str) – The name of the column for the independent variable.
y (str) – The name of the column for the dependent variable.
yerr (str | None, optional) – The name of the column for the error on the dependent variable, by default None.
method (FitMethod, optional) – The fitting method to use, by default “curve_fit”. Available methods: ‘curve_fit’, ‘minimize’, ‘differential_evolution’, ‘shgo’, ‘dual_annealing’, ‘emcee’, ‘bayesian_ridge’, ‘ridge’, ‘lasso’, ‘elasticnet’, ‘polynomial’. ‘bayesian_ridge’ requires scikit-learn and is only valid for linear models. ‘emcee’ requires emcee. Methods other than ‘curve_fit’ and ‘bayesian_ridge’ may require sigma (yerr).
fit_kwargs (FitKwargs | None, optional) – Keyword arguments passed to the underlying fitting function (e.g., scipy.optimize.curve_fit, scipy.optimize.minimize, sklearn.linear_model.BayesianRidge, emcee.EnsembleSampler), by default None.
**parameters (dict[str, Any]) – Specification of model parameters (initial values, bounds, fixed).
- Returns:
The fitted Model object.
- Return type:
Model
- Raises:
ColumnNotFoundError – If a specified column (x, y, yerr) is not found.
ImportError – If a required library (e.g., scikit-learn, emcee) is not installed for the chosen method.
ValueError – If an invalid method is chosen, if required arguments (like sigma) are missing for a method, or if the fit fails.
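The role of yerr as sigma can be illustrated with a weighted least-squares objective. The sketch below is not taken from ezfit's source: it assumes that the optimizer-based methods minimize a chi-square of this general form, and chi_square is a hypothetical helper name used only for illustration.

```python
from typing import Callable, Sequence


def line(x: float, m: float, b: float) -> float:
    """Straight-line model, the same shape as the example model on this page."""
    return m * x + b


def chi_square(
    model: Callable[..., float],
    params: Sequence[float],
    x: Sequence[float],
    y: Sequence[float],
    sigma: Sequence[float],
) -> float:
    """Sum of squared residuals, each scaled by its uncertainty.

    Hypothetical helper: an assumption about the kind of objective a
    sigma-dependent method could minimize, not ezfit's actual code.
    """
    total = 0.0
    for xi, yi, si in zip(x, y, sigma):
        total += ((yi - model(xi, *params)) / si) ** 2
    return total


x = [0.0, 1.0, 2.0, 3.0]
y = [1.0, 3.0, 5.0, 7.0]      # lies exactly on y = 2x + 1
sigma = [0.5, 0.5, 0.5, 0.5]  # per-point uncertainties (the yerr column)

exact = chi_square(line, (2.0, 1.0), x, y, sigma)   # true parameters
offset = chi_square(line, (2.0, 0.0), x, y, sigma)  # intercept off by 1
```

Without sigma every point would carry equal weight, which is one plausible reason a method might insist on a yerr column.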
- plot(x: str, y: str, model: Model, yerr: str | None = None, ax: Axes | None = None, plot_options: PlotOptions | None = None, residuals: Literal['none', 'res', 'percent', 'rmse'] | None = None, color_error: str | None = None, color_model: str | None = None, color_residuals: str | None = None, fmt_error: str | None = None, ls_model: str | None = None, ls_residuals: str | None = None, marker_residuals: str | None = None, err_kws: dict[str, Any] | None = None, mod_kws: dict[str, Any] | None = None, res_kws: dict[str, Any] | None = None) → Axes | tuple[Axes, Axes]
Plot the data, model, and residuals.
- Parameters:
x (str) – The name of the column for the independent variable.
y (str) – The name of the column for the dependent variable.
model (Model) – The fitted Model object containing the function and parameters.
yerr (str | None, optional) – The name of the column for the error on the dependent variable, by default None.
ax (Axes | None, optional) – An existing Matplotlib Axes object to plot on. If None, a new figure/axes is created, by default None.
plot_options (PlotOptions | None, optional) – PlotOptions object bundling the plotting parameters. Any individual plotting argument that is not None takes precedence over the corresponding value in plot_options, by default None.
residuals (Literal["none", "res", "percent", "rmse"] | None, optional) – The type of residuals to plot. Set to “none” to disable residuals plot. Overrides plot_options if provided, by default None (defaults to “res”).
color_error (str | None, optional) – Color for data points/error bars. Overrides plot_options if provided, by default None.
color_model (str | None, optional) – Color for the fitted model line. Overrides plot_options if provided, by default None.
color_residuals (str | None, optional) – Color for the residuals plot. Overrides plot_options if provided, by default None.
fmt_error (str | None, optional) – Marker style for data points. Overrides plot_options if provided, by default None.
ls_model (str | None, optional) – Line style for the model line. Overrides plot_options if provided, by default None.
ls_residuals (str | None, optional) – Line style for residuals. Overrides plot_options if provided, by default None.
marker_residuals (str | None, optional) – Marker style for residuals. Overrides plot_options if provided, by default None.
err_kws (dict[str, Any] | None, optional) – Additional keyword arguments for plt.errorbar. Overrides plot_options if provided, by default None.
mod_kws (dict[str, Any] | None, optional) – Additional keyword arguments for model line plt.plot. Overrides plot_options if provided, by default None.
res_kws (dict[str, Any] | None, optional) – Additional keyword arguments for residuals plt.plot. Overrides plot_options if provided, by default None.
- Returns:
The main plot Axes object, or a tuple of (main Axes, residuals Axes) if residuals are plotted.
- Return type:
Axes | tuple[Axes, Axes]
- Raises:
ColumnNotFoundError – If a specified column (x, y, yerr) is not found.
ValueError – If an invalid residuals metric is specified or model has no parameters.
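The residuals options can be made concrete with a small sketch. The definitions below are assumptions rather than ezfit's documented formulas: "res" is taken as data minus model, "percent" as that difference relative to the data value, and the RMSE as the conventional scalar root-mean-square error. How "rmse" is drawn on the residuals axes is not specified here, so the sketch does not attempt it. The names residuals_of and rms_error are hypothetical.

```python
import math
from typing import Sequence


def residuals_of(y: Sequence[float], y_fit: Sequence[float], kind: str = "res") -> list[float]:
    """Per-point residuals under the assumed definitions (hypothetical helper)."""
    if kind == "none":
        return []
    if kind == "res":
        return [yi - fi for yi, fi in zip(y, y_fit)]
    if kind == "percent":
        # Assumes no data value is zero; a zero would raise ZeroDivisionError.
        return [100.0 * (yi - fi) / yi for yi, fi in zip(y, y_fit)]
    raise ValueError(f"invalid residuals metric: {kind!r}")


def rms_error(y: Sequence[float], y_fit: Sequence[float]) -> float:
    """Root-mean-square error over all points."""
    return math.sqrt(sum((yi - fi) ** 2 for yi, fi in zip(y, y_fit)) / len(y))


y = [2.0, 4.0, 8.0]      # measured values
y_fit = [1.0, 5.0, 8.0]  # model evaluated at the same x values

res = residuals_of(y, y_fit, "res")          # [1.0, -1.0, 0.0]
percent = residuals_of(y, y_fit, "percent")  # [50.0, -25.0, 0.0]
error = rms_error(y, y_fit)
```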
Examples
Basic Usage
import pandas as pd
import ezfit  # makes the df.fit accessor available

df = pd.read_csv("data.csv")

def line(x, m, b):
    return m * x + b

# Fit, then plot the data, the model, and the residuals
model, ax, ax_res = df.fit(line, "x", "y", "yerr")
With Parameter Bounds
model, ax, _ = df.fit(
    line, "x", "y", "yerr",
    m={"value": 1.0, "min": 0, "max": 10},
    b={"value": 0.0, "min": -5, "max": 5},
)
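To show what such parameter dictionaries amount to, here is a minimal sketch that converts them into the p0 list and (lower, upper) bounds pair that scipy.optimize.curve_fit accepts. This is an illustration, not ezfit's internal code: to_p0_and_bounds is a hypothetical helper, it handles only the "value", "min", and "max" keys shown above, and the fallback initial value of 1.0 mirrors curve_fit's own default rather than anything documented for ezfit.

```python
import inspect
import math
from typing import Any, Callable


def line(x, m, b):
    return m * x + b


def to_p0_and_bounds(model: Callable[..., Any], **parameters: dict[str, Any]):
    """Translate per-parameter dicts into curve_fit-style p0 and bounds.

    Hypothetical helper for illustration only.
    """
    # The first argument of the model is the independent variable; the rest are fit parameters.
    names = list(inspect.signature(model).parameters)[1:]
    unknown = set(parameters) - set(names)
    if unknown:
        raise TypeError(f"parameters not accepted by the model: {sorted(unknown)}")

    p0, lower, upper = [], [], []
    for name in names:
        spec = parameters.get(name, {})
        p0.append(spec.get("value", 1.0))          # assumed default initial guess
        lower.append(spec.get("min", -math.inf))   # unbounded below unless "min" is given
        upper.append(spec.get("max", math.inf))    # unbounded above unless "max" is given
    return p0, (lower, upper)


p0, bounds = to_p0_and_bounds(
    line,
    m={"value": 1.0, "min": 0, "max": 10},
    b={"value": 0.0, "min": -5, "max": 5},
)
# p0 is [1.0, 0.0]; bounds is ([0, -5], [10, 5])
```

Reading the parameter names from the model signature is one way to catch a misspelled keyword early, in line with the TypeError documented for parameters that are not valid for the model.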
Using Different Methods
# Default: curve_fit
model, ax, _ = df.fit(line, "x", "y", "yerr")
# Use MCMC
model, ax, _ = df.fit(
    line, "x", "y", "yerr",
    method="emcee",
    fit_kwargs={"nwalkers": 50, "nsteps": 1000},
)
# Use scikit-learn
model, ax, _ = df.fit(line, "x", "y", method="ridge")