Nonlinear Fitting#
What about the case of fitting to a function where the fit parameters enter in a nonlinear fashion? For example:
\[f(x; {\bf a}) = a_0 e^{a_1 x}\]
One trick that is often used for something like this is to transform the data. So instead of fitting the data \((x_i, y_i)\), we fit \((x_i, \log y_i)\), and then our fitting function is:
\[\log f(x; {\bf a}) = \log a_0 + a_1 x\]
which is linear in the parameters \(\log a_0\) and \(a_1\).
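However, when there are errors associated with the \(y_i\), they do not transform simply under the logarithm: a symmetric (e.g., Gaussian) error in \(y_i\) becomes an asymmetric error in \(\log y_i\), so the weights in the linearized fit differ from those of the original problem.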
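To make the caveat about errors concrete, here is a minimal sketch of the log-transform fit that propagates the errors to first order, \(\sigma_{\log y_i} \approx \sigma_i / y_i\), and uses them as weights in a straight-line fit. The function name `fit_log_linear`, the synthetic data, and the small noise level are assumptions made for illustration; the approach only makes sense when all \(y_i > 0\) and the errors are small compared to \(y_i\).

```python
import numpy as np

def fit_log_linear(x, y, yerr):
    """Fit y = a0 * exp(a1 * x) via a weighted straight-line fit to log(y).

    Hypothetical helper: assumes y > 0 everywhere and yerr << y."""
    # first-order error propagation: sigma_{log y} ~ sigma_y / y
    log_err = yerr / y
    # np.polyfit takes weights of 1/sigma for Gaussian uncertainties and
    # returns the coefficients from highest degree to lowest
    slope, intercept = np.polyfit(x, np.log(y), 1, w=1.0 / log_err)
    return np.exp(intercept), slope

# synthetic data with small errors (illustrative values only)
rng = np.random.default_rng(12345)
x_demo = np.linspace(0.0, 4.0, 25)
yerr_demo = 0.05 * np.ones_like(x_demo)
y_demo = 2.5 * np.exp(2.0 / 3.0 * x_demo) + yerr_demo * rng.standard_normal(x_demo.size)

a0_demo, a1_demo = fit_log_linear(x_demo, y_demo, yerr_demo)
print(a0_demo, a1_demo)
```

With larger errors, or with any \(y_i \le 0\), this linearization breaks down, which is the motivation for fitting the nonlinear function directly.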
So let’s look at how we would fit directly to a nonlinear function.
We’ll minimize the same quantity, \(\chi^2\):
\[\chi^2({\bf a}) \equiv \sum_{i=1}^N \frac{(y_i - f(x_i; {\bf a}))^2}{\sigma_i^2}\]
with fitting parameters \({\bf a} = (a_1, \ldots, a_M)^\intercal\).
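Now we take the derivatives with respect to each parameter, \(a_k\), and set them to zero:
\[\frac{\partial \chi^2}{\partial a_k} = -2 \sum_{i=1}^N \frac{y_i - f(x_i; {\bf a})}{\sigma_i^2} \frac{\partial f(x_i; {\bf a})}{\partial a_k} = 0\]
Let’s define \(g_k \equiv {\partial \chi^2}/{\partial a_k}\), then we have
\[g_k({\bf a}) = 0, \quad k = 1, \ldots, M\]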
This is a nonlinear system of \(M\) equations and \(M\) unknowns. We can solve this using the same multivariate Newton’s method we looked at before:
Take an initial guess at the fit parameters, \({\bf a}^{(k)}\)
Solve the system \({\bf J}\delta {\bf a} = -{\bf g}\), where \(J_{ij} = \partial g_i/\partial a_j\) is the Jacobian
Correct the initial guess, \({\bf a}^{(k+1)} = {\bf a}^{(k)} + \delta {\bf a}\)
As we’ve seen with Newton’s method, convergence will be very sensitive to the initial guess.
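The steps above can be sketched as a small generic routine. This is only an illustration under stated assumptions: the name `newton_system`, the iteration cap `max_iter`, and the two-equation test system are hypothetical and are not part of the fitting problem itself.

```python
import numpy as np

def newton_system(g_func, jac_func, a_init, tol=1.e-10, max_iter=50):
    """Multivariate Newton iteration for g(a) = 0 (illustrative sketch)."""
    a = np.array(a_init, dtype=np.float64)
    for n in range(max_iter):
        # solve J da = -g for the correction
        da = np.linalg.solve(jac_func(a), -g_func(a))
        a += da
        # stop once the correction is small
        if np.max(np.abs(da)) < tol:
            return a, n + 1
    raise RuntimeError("Newton iteration did not converge")

# hypothetical test system: a0^2 + a1^2 = 5 and a0 * a1 = 2
def g_test(a):
    return np.array([a[0]**2 + a[1]**2 - 5.0, a[0] * a[1] - 2.0])

def jac_test(a):
    return np.array([[2.0 * a[0], 2.0 * a[1]],
                     [a[1], a[0]]])

root, niter = newton_system(g_test, jac_test, [2.5, 0.5])
print(root, niter)
```

Capping the number of iterations is a design choice made here so that a poor initial guess produces an error instead of an endless loop.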
Fitting an exponential#
Let’s try this out on data that is constructed to follow an exponential trend.
First let’s construct the data, and perturb it with some errors. We’ll take the form:
\[y = a_0 e^{a_1 x}\]
import numpy as np
import matplotlib.pyplot as plt
# make up some experimental data
a0 = 2.5
a1 = 2./3.
sigma = 4.0
N = 25
x = np.linspace(0.0, 4.0, N)
r = sigma * np.random.randn(N)
y = a0 * np.exp(a1 * x) + r
yerr = np.abs(r)
fig, ax = plt.subplots()
ax.errorbar(x, y, yerr=yerr, fmt="o")
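Now, let’s compute our vector \({\bf g}\) that we will zero:
\[g_0 \equiv \frac{\partial \chi^2}{\partial a_0} = -2 \sum_{i=1}^N \frac{e^{a_1 x_i} (y_i - a_0 e^{a_1 x_i})}{\sigma_i^2}\]
\[g_1 \equiv \frac{\partial \chi^2}{\partial a_1} = -2 a_0 \sum_{i=1}^N \frac{x_i e^{a_1 x_i} (y_i - a_0 e^{a_1 x_i})}{\sigma_i^2}\]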
We can divide out the \(-2\) in each expression. We’ll keep the overall \(a_0\) in the expression, to deal with the case where it might be \(0\).
Let’s write a function to compute this:
def g(x, y, yerr, a):
    """compute the nonlinear functions that we zero.  Here a is the vector
    of fit parameters"""

    a0, a1 = a

    g0 = np.sum(np.exp(a1 * x) * (y - a0 * np.exp(a1 * x)) / yerr**2)
    g1 = a0 * np.sum(x * np.exp(a1 * x) * (y - a0 * np.exp(a1 * x)) / yerr**2)

    return np.array([g0, g1])
We also need the Jacobian. We could either compute this numerically, via differencing, or analytically. We’ll do the latter.
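For comparison, here is a minimal sketch of the numerical route, used as a check on the analytic derivatives. The names `g_exp`, `jac_exp`, and `jac_numerical`, the centered-difference step `eps`, and the synthetic data are assumptions for illustration; `g_exp` and `jac_exp` restate the exponential-model expressions (with the factor of \(-2\) dropped) so that the block is self-contained.

```python
import numpy as np

def g_exp(x, y, yerr, a):
    """the two functions we zero for f = a0 * exp(a1 * x), factor of -2 dropped"""
    a0, a1 = a
    e = np.exp(a1 * x)
    r = (y - a0 * e) / yerr**2
    return np.array([np.sum(e * r), a0 * np.sum(x * e * r)])

def jac_exp(x, y, yerr, a):
    """analytic Jacobian of g_exp"""
    a0, a1 = a
    e = np.exp(a1 * x)
    d00 = -np.sum(e**2 / yerr**2)
    d01 = np.sum(x * e * (y - 2.0 * a0 * e) / yerr**2)
    d11 = a0 * np.sum(x**2 * e * (y - 2.0 * a0 * e) / yerr**2)
    return np.array([[d00, d01], [d01, d11]])

def jac_numerical(x, y, yerr, a, eps=1.e-6):
    """centered-difference approximation to the Jacobian of g_exp"""
    a = np.asarray(a, dtype=np.float64)
    J = np.zeros((a.size, a.size))
    for j in range(a.size):
        da = np.zeros_like(a)
        da[j] = eps * max(1.0, abs(a[j]))
        # column j holds the derivatives of every g_i with respect to a_j
        J[:, j] = (g_exp(x, y, yerr, a + da) - g_exp(x, y, yerr, a - da)) / (2.0 * da[j])
    return J

# synthetic data (illustrative values only)
rng = np.random.default_rng(1)
x_c = np.linspace(0.0, 4.0, 25)
yerr_c = np.ones_like(x_c)
y_c = 2.5 * np.exp(2.0 / 3.0 * x_c) + rng.standard_normal(x_c.size)

a_c = np.array([2.0, 1.0])
J_ana = jac_exp(x_c, y_c, yerr_c, a_c)
J_num = jac_numerical(x_c, y_c, yerr_c, a_c)
print(np.max(np.abs(J_num - J_ana) / np.abs(J_ana)))
```

A check like this is a cheap way to catch sign or factor errors in hand-derived derivatives before using them in Newton’s method.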
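Notice that the Jacobian is symmetric:
\[{\bf J} = \begin{pmatrix} \partial g_0/\partial a_0 & \partial g_0/\partial a_1 \\ \partial g_1/\partial a_0 & \partial g_1/\partial a_1 \end{pmatrix} = \begin{pmatrix} \partial^2 \chi^2/\partial a_0^2 & \partial^2 \chi^2/\partial a_0 \partial a_1 \\ \partial^2 \chi^2/\partial a_1 \partial a_0 & \partial^2 \chi^2/\partial a_1^2 \end{pmatrix}\]
This matrix of second derivatives of \(\chi^2\) is called the Hessian matrix. (Since we divide the \(-2\) out of \({\bf g}\), the Jacobian we compute in code is \(-1/2\) times the Hessian, which is still symmetric.)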
Let’s write this function:
def jac(x, y, yerr, a):
    """ compute the Jacobian of the function g"""

    a0, a1 = a

    dg0da0 = -np.sum(np.exp(2.0 * a1 * x) / yerr**2)
    dg0da1 = np.sum(x * np.exp(a1 * x) * (y - 2.0 * a0 * np.exp(a1 * x)) / yerr**2)
    dg1da0 = dg0da1
    dg1da1 = np.sum(a0 * x**2 * np.exp(a1 * x) * (y - 2.0 * a0 * np.exp(a1 * x)) / yerr**2)

    return np.array([[dg0da0, dg0da1],
                     [dg1da0, dg1da1]])
def fit(aguess, x, y, yerr, tol=1.e-5):
    """ aguess is the initial guess to our fit parameters.  x and y
    are the vector of points that we are fitting to, and yerr are
    the errors in y"""

    avec = aguess.copy()

    err = 1.e100
    while err > tol:
        # get the jacobian
        J = jac(x, y, yerr, avec)
        print("condition number of J: ", np.linalg.cond(J))

        # get the current function values
        gv = g(x, y, yerr, avec)

        # solve for the correction: J da = -g
        da = np.linalg.solve(J, -gv)

        avec += da
        err = np.max(np.abs(da))

    return avec
# initial guesses
aguess = np.array([2.0, 1.0])
# fit
afit = fit(aguess, x, y, yerr)
condition number of J: 153.76481549883144
condition number of J: 188.15847894481834
condition number of J: 284.4508368892106
condition number of J: 721.45460913468
condition number of J: 43981.4975463958
condition number of J: 90.49125173924254
condition number of J: 86.87485525849216
condition number of J: 81.76181129616644
condition number of J: 75.20983978761609
condition number of J: 67.98298745401907
condition number of J: 61.64759304528186
condition number of J: 57.93890003765046
condition number of J: 57.969212826396024
condition number of J: 62.772072442053194
condition number of J: 75.10571475327122
condition number of J: 97.08888424337464
condition number of J: 99.67912380606471
condition number of J: 143.8928462463668
condition number of J: 81.3142589034596
condition number of J: 33.39182658576079
condition number of J: 12.49418241560889
condition number of J: 5.2502017618210886
condition number of J: 2.2518856596746613
condition number of J: 1.1094590704715854
condition number of J: 2.936021717338727
condition number of J: 7.832524307631303
condition number of J: 20.600142118660433
condition number of J: 52.90691250880374
condition number of J: 130.84202534394086
condition number of J: 302.6805578180602
condition number of J: 614.8379478856145
condition number of J: 977.9745524352268
condition number of J: 1147.5527129356112
condition number of J: 1164.380742982674
condition number of J: 1164.50561627372
afit
array([ 4.06614613, -15.29281456])
Note that the converged \(a_1 \approx -15.3\) is far from the value \(a_1 = 2/3\) used to construct the data, so for this realization of the noise and this initial guess, Newton’s method did not recover the underlying trend.
ax.plot(x, afit[0] * np.exp(afit[1] *x))
fig
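As an independent cross-check on a hand-written Newton solver, a library least-squares routine can be run on the same kind of data. The sketch below uses `scipy.optimize.curve_fit`, which is a different algorithm from the Newton iteration in this section (by default it uses a Levenberg-Marquardt-type method for unconstrained problems). The names `model`, `x_s`, `y_s`, and `sigma_s`, the seed, and the choice of a uniform error bar are assumptions for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def model(x, a0, a1):
    """the exponential fitting function, f = a0 * exp(a1 * x)"""
    return a0 * np.exp(a1 * x)

# synthetic data built the same way as in the text, but with a uniform
# error bar on each point (an assumption made for this sketch)
rng = np.random.default_rng(2024)
x_s = np.linspace(0.0, 4.0, 25)
sigma_s = 4.0 * np.ones_like(x_s)
y_s = model(x_s, 2.5, 2.0 / 3.0) + sigma_s * rng.standard_normal(x_s.size)

# p0 is the initial guess; absolute_sigma=True treats sigma_s as real
# uncertainties in y, not just relative weights
popt, pcov = curve_fit(model, x_s, y_s, p0=[2.0, 1.0],
                       sigma=sigma_s, absolute_sigma=True)

print("fit parameters:", popt)
print("1-sigma uncertainties:", np.sqrt(np.diag(pcov)))
```

If the two approaches disagree substantially for the same data and initial guess, that is a hint that the Newton iteration has wandered to a different extremum.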
Is it a minimum?#
We just found an extremum. Let’s plot the surface around our fit parameters to see if it looks like a minimum.
npts = 100
a0v = np.linspace(0.5 * afit[0], 2.0 * afit[0], npts)
a1v = np.linspace(0.5 * afit[1], 2.0 * afit[1], npts)
def chisq(a0, a1, x, y, yerr):
    return np.sum((y - a0 * np.exp(a1 * x))**2 / yerr**2)

c2 = np.zeros((npts, npts), dtype=np.float64)

for i, a0 in enumerate(a0v):
    for j, a1 in enumerate(a1v):
        c2[i, j] = chisq(a0, a1, x, y, yerr)
c2.max()
np.float64(37738.404401884116)
Now we’ll plot (the log of) \(\chi^2\):
fig, ax = plt.subplots()
# we need to transpose to put a0 on the horizontal
# we use origin = lower to have the origin at the lower left
im = ax.imshow(np.log10(c2).T,
               origin="lower",
               extent=[a0v[0], a0v[-1], a1v[0], a1v[-1]])
fig.colorbar(im, ax=ax, orientation="horizontal")
ax.scatter([afit[0]], [afit[1]], color="r", marker="x")
ax.set_xlabel("$a_0$")
ax.set_ylabel("$a_1$")
It looks like there is a very broad minimum there.
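A plot is suggestive, but the usual quantitative test is that the Hessian of \(\chi^2\) is positive definite at the extremum, i.e., all of its eigenvalues are positive. Below is a minimal sketch for the exponential model; the names `hessian_chisq` and `is_local_minimum` and the noise-free test data are assumptions for illustration. With \({\bf g}\) defined without the factor of \(-2\), this Hessian equals \(-2 {\bf J}\).

```python
import numpy as np

def hessian_chisq(x, y, yerr, a):
    """Hessian of chi^2 for f = a0 * exp(a1 * x) (illustrative helper)"""
    a0, a1 = a
    e = np.exp(a1 * x)
    w = 1.0 / yerr**2
    h00 = 2.0 * np.sum(w * e**2)
    h01 = -2.0 * np.sum(w * x * e * (y - 2.0 * a0 * e))
    h11 = -2.0 * a0 * np.sum(w * x**2 * e * (y - 2.0 * a0 * e))
    return np.array([[h00, h01], [h01, h11]])

def is_local_minimum(H):
    """True if the symmetric matrix H is positive definite"""
    return bool(np.all(np.linalg.eigvalsh(H) > 0.0))

# noise-free test data, so the true parameters are exactly the minimum
x_h = np.linspace(0.0, 4.0, 25)
yerr_h = np.ones_like(x_h)
a_true = (2.5, 2.0 / 3.0)
y_h = a_true[0] * np.exp(a_true[1] * x_h)

H = hessian_chisq(x_h, y_h, yerr_h, a_true)
print(np.linalg.eigvalsh(H), is_local_minimum(H))
```

A mixed-sign set of eigenvalues would indicate a saddle point, and a nearly zero eigenvalue indicates a direction in parameter space along which \(\chi^2\) is almost flat.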
Troubles#
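Consider if we tried to add another parameter, fitting to a form such as:
\[f(x; {\bf a}) = a_0 e^{a_1 x + a_2}\]
Here \(a_2\) enters the same way as \(a_0\): only the product \(a_0 e^{a_2}\) affects the fit, so the two parameters cannot be determined separately. This would give a singular matrix and make our solution unstable.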
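To illustrate the degeneracy numerically, the sketch below assumes a three-parameter form \(f = a_0 e^{a_1 x + a_2}\) (chosen here as an example of a redundant parameter) and builds the matrix of derivatives \(\partial f(x_i)/\partial a_k\). The columns for \(a_0\) and \(a_2\) are proportional, so the matrix has rank 2 instead of 3 and the corresponding normal-equations matrix is singular to machine precision. All names and parameter values are hypothetical.

```python
import numpy as np

x_d = np.linspace(0.0, 4.0, 25)

# illustrative parameter values for the assumed form f = a0 * exp(a1 * x + a2)
a0_d, a1_d, a2_d = 2.5, 2.0 / 3.0, 0.1
f_d = a0_d * np.exp(a1_d * x_d + a2_d)

# columns are df/da0, df/da1, df/da2 evaluated at each x_i
D = np.column_stack([f_d / a0_d, x_d * f_d, f_d])

# the first and last columns differ only by the constant factor a0
rank = np.linalg.matrix_rank(D)
cond = np.linalg.cond(D.T @ D)

print("rank:", rank)
print("condition number:", cond)
```

A Newton solve with a matrix like this either fails outright or produces corrections dominated by roundoff, which is why redundant parameters should be removed or combined before fitting.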