Battery design parameter estimation - Part 2

In this tutorial, we demonstrate how you can use Collimator models for optimization. For a pulse discharge test (discharge current applied as a pulse, as in the previous tutorial), we will first generate synthetic $v_t$ data from our battery model using a set of known parameters, i.e. known curves representing the dependence of $v_0$, $R_s$, $R_1$, and $C_1$ on the state of charge (SoC) $s$. We will then add noise to the terminal voltage output and treat the result as the experimental $v_t$ curve one would obtain in a pulse discharge test. The goal, then, is to estimate the $v_0(s)$, $R_s(s)$, $R_1(s)$, and $C_1(s)$ curves through optimization, i.e. by letting an optimizer vary the model parameters to minimize the discrepancy between the model's $v_t$ output and the synthetic experimental curve.

Generating synthetic data

We have packaged the battery model constructed in the previous tutorial as a [.code]Battery[.code] class in a [.code]battery.py[.code] Python file, so that we can import it here. Next, we use the same simulation technique as before to generate a $v_t$ output, but with a priori chosen $v_0(s)$, $R_s(s)$, $R_1(s)$, and $C_1(s)$ curves.

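import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import collimator
from collimator.library import Pulse

from battery import Battery  # Battery class built in the previous tutorial

# Generate synthetic v0, Rs, R1, C1 curves representing their dependency on soc.
# We label them as `true` as they will be the ground truth that the
# estimation/optimisation procedure will seek to find.
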
soc_points_true = jnp.linspace(0.0, 1.0, 11)
v0_points_true = 3.6 + (5.0 - 2.6) * soc_points_true
v0_points_true = v0_points_true.at[0].set(v0_points_true[0] - 1.0)
v0_points_true = v0_points_true.at[-1].set(v0_points_true[-1] + 1.0)

Rs_points_true = 15e-03 + (10e-03 - 15e-03) * soc_points_true
lamb = 2.0
R1_points_true = 10e-03 + (25e-03 - 10e-03) * (
    jnp.exp(-lamb * soc_points_true) - jnp.exp(-lamb)
) / (1 - jnp.exp(-lamb))
C1_points_true = 1.5e03 + (3.5e03 - 1.5e03) * (jnp.exp(lamb * soc_points_true) - 1) / (
    jnp.exp(lamb) - 1
)

fig, axs = plt.subplots(2, 2, figsize=(10, 6))
axs[0, 0].plot(soc_points_true, v0_points_true, "-ro", label=f"$v_0$: true")
axs[0, 1].plot(soc_points_true, Rs_points_true, "-go", label=f"$R_s$: true")
axs[1, 0].plot(soc_points_true, R1_points_true, "-bo", label=f"$R_1$: true")
axs[1, 1].plot(soc_points_true, C1_points_true, "-ko", label=f"$C_1$: true")
for ax in axs.flatten():
    ax.set_xlabel(r"$s$")
    ax.legend(loc="best")
plt.tight_layout()
plt.show()

The above code produces the following outputs:

Now that we have the true parameter curves, let's build the model that uses them:

builder = collimator.DiagramBuilder()

# Create a battery with the above generated parameters
battery = Battery(
    soc_points=soc_points_true,
    v0_points=v0_points_true,
    Rs_points=Rs_points_true,
    R1_points=R1_points_true,
    C1_points=C1_points_true,
    name="battery"
)

builder.add(battery)

# Pulse block to represent pulse discharge test
discharge_current = builder.add(
    Pulse(100.0, 6 / 16.0, 16 * 60.0, name="discharge_current")
)

# We have two blocks/LeafSystems: discharge_current and battery. We need to connect the output
# of the discharge_current to the input of the battery.
builder.connect(discharge_current.output_ports[0], battery.input_ports[0])

diagram = builder.build()
context = diagram.create_context()  # Create default context

# Specify which signals to record in the simulation output
recorded_signals = {
    "discharge_current": discharge_current.output_ports[0],
    "soc": battery.output_ports[0],
    "vt": battery.output_ports[1],
}

t_span = (0.0, 9600.0)  # span over which to simulate the system

# simulate the combination of diagram and context over t_span
sol = collimator.simulate(diagram, context, t_span, recorded_signals=recorded_signals)

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 6))
ax1.plot(sol.time, sol.outputs["discharge_current"], "-r", label=r"$i$")
ax2.plot(sol.time, sol.outputs["soc"], "-b", label=r"$s$")
ax3.plot(sol.time, sol.outputs["vt"], "-g", label=r"$v_t$")
ax1.set_ylabel("i [A]")
ax2.set_ylabel("soc [-]")
ax3.set_ylabel("$v_t$ [V]")
for ax in (ax1, ax2, ax3):
    ax.set_xlabel("time [s]")
    ax.legend(loc="best")
fig.tight_layout()
plt.show()

The result of running that simulation is the following output:

Note that the $v_t$ output goes negative, which is physically unrealistic. This is a consequence of the arbitrarily chosen curves for the dependence of $v_0$, $R_s$, $R_1$, and $C_1$ on $s$. We will ignore this oddity for now; in the next tutorial, we will work with real experimental data.

Now let's add Gaussian noise with a standard deviation of 0.03 V to the $v_t$ signal to mimic an experimental $v_t$ measurement.

# Add noise to terminal voltage to create synthetic measurement
seed = 42
key = jax.random.PRNGKey(seed)

t_exp_data = sol.time
vt_exp_data = sol.outputs["vt"] + 0.03 * jax.random.normal(key, sol.outputs["vt"].shape)

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(t_exp_data, vt_exp_data, label=r"$v_t$: exp")
ax.set_xlabel("time [s]")
ax.set_ylabel("$v_t$ [V]")
ax.legend(loc="best")
plt.tight_layout()
plt.show()

Computing the $L_2$ error

Given the synthetic experimental data ($v_t$ measurements), our goal is to estimate the parameters of the battery so that the output of the model matches that of the experiment. We can use a discrepancy metric such as the $L_2$ error for this. Thus our optimization problem becomes:

$$ \hat{\theta} = \text{arg}\,\min\limits_{\theta}\ \int_0^{T} \left(v_t^{exp} - v_t(\theta) \right)^2 dt, $$

where $\theta$ represents the model parameters. In our case, we have 44 model parameters: 11 values for each of the $v_0$, $R_s$, $R_1$, and $C_1$ curves.
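The optimization code later in this tutorial stacks the four curves into one flat vector of 44 entries (all 11 $v_0$ values, then $R_s$, $R_1$, and $C_1$) and unpacks it with a reshape. A minimal NumPy sketch of that layout, using hypothetical placeholder values rather than the battery curves, shows that [.code]reshape((4, 11))[.code] recovers one curve per row:

```python
import numpy as np

n_points = 11  # one value per SoC breakpoint, as in soc_points_true

# Hypothetical placeholder curves, only used to illustrate the layout
v0 = np.full(n_points, 3.0)
Rs = np.full(n_points, 15e-03)
R1 = np.full(n_points, 15e-03)
C1 = np.full(n_points, 2e03)

# Pack: concatenate the four curves into one flat parameter vector theta
theta = np.hstack([v0, Rs, R1, C1])
print(theta.shape)  # (44,)

# Unpack: row k of the (4, 11) array is the k-th curve, in packing order
curves = theta.reshape((4, n_points))
```

Because NumPy (and JAX) reshape in row-major order by default, the packing order and the row order always agree.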

To compute the loss function $\int_0^{T} \left(v_t^{exp} - v_t(\theta) \right)^2 dt$, we can use some additional primitive blocks to build a diagram that computes the integral. The integrand is a squared error. An [.code]Integrator[.code] block is already available in the Wildcat library, so all we need to do is compute the squared error, which can be achieved with [.code]Adder[.code] and [.code]Power[.code] blocks. Such a system, which takes two inputs $x$ and $y$ and computes $\int (x-y)^2\, dt$, can be created as follows:

from collimator.library import Adder, Power, Integrator, LookupTable1d, Clock

def make_l2_loss(name="l2_loss"):
    builder = collimator.DiagramBuilder()

    err = builder.add(
        Adder(2, operators="+-", name="err")
    )  # compute difference/error between the two input ports
    sq_err = builder.add(Power(2.0, name="sq_err"))  # square the above error
    sq_err_int = builder.add(Integrator(0.0, name="sq_err_int"))  # integrate

    builder.connect(err.output_ports[0], sq_err.input_ports[0])
    builder.connect(sq_err.output_ports[0], sq_err_int.input_ports[0])

    # Diagram level export inputs and outputs
    builder.export_input(err.input_ports[0])  # x
    builder.export_input(err.input_ports[1])  # y
    builder.export_output(sq_err_int.output_ports[0])  # \int (x-y)^2 dt

    return builder.build(name=name)
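When cross-checking the diagram's output, it can help to compute the same quantity offline from sampled arrays. The sketch below mirrors the Adder, Power, and Integrator chain in NumPy; it uses a cumulative trapezoidal rule as a stand-in, which is an assumption for illustration and not how the [.code]Integrator[.code] block integrates internally. The helper name [.code]running_l2_loss[.code] is hypothetical.

```python
import numpy as np

def running_l2_loss(t, x, y):
    """Approximate the running integral of (x - y)^2 dt on sampled data."""
    err = x - y        # like the Adder block with operators "+-"
    sq_err = err**2    # like the Power block with exponent 2
    # Cumulative trapezoidal rule starting from 0, like Integrator(0.0)
    increments = 0.5 * (sq_err[1:] + sq_err[:-1]) * np.diff(t)
    return np.concatenate([[0.0], np.cumsum(increments)])

# Check against a known integral: int_0^{2 pi} sin^2(t) dt = pi
t = np.linspace(0.0, 2.0 * np.pi, 2001)
loss = running_l2_loss(t, np.sin(t), np.zeros_like(t))
print(loss[-1])  # close to pi
```

As with the diagram, the last entry of the running integral is the loss value of interest.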

We also need a block that interpolates the experimental data, which the [.code]LookupTable1d[.code] block provides. With the experimental [.code]t_exp_data[.code] and [.code]vt_exp_data[.code] generated above as its parameters, the block output is a linear interpolation of this data at any simulation time [.code]t[.code], which is supplied by a [.code]Clock[.code] block.
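Linearly interpolating noisy samples has a consequence worth keeping in mind when reading cost values. Assuming independent, zero-mean noise with standard deviation $\sigma$ at each sample, the linear interpolant $n(t)$ between samples $a$ and $b$ spaced $h$ apart satisfies $\int_0^h n^2\, d\tau = h(a^2+ab+b^2)/3$, whose expectation is $2h\sigma^2/3$. So even a model that reproduced the noise-free $v_t$ exactly would leave a normalized cost $\frac{1}{T}\int_0^T (v_t^{exp}-v_t)^2\, dt$ of roughly $2\sigma^2/3 = 6\times10^{-4}$ for $\sigma = 0.03$ V, a rough floor for the normalized cost used later. The NumPy sketch below checks this by Monte Carlo on a hypothetical uniform 1 s grid (the actual [.code]sol.time[.code] grid may differ), with [.code]np.interp[.code] standing in for a linear lookup table:

```python
import numpy as np

sigma = 0.03   # noise standard deviation [V], as used for vt_exp_data
T = 9600.0     # simulation horizon [s]
rng = np.random.default_rng(0)

# Hypothetical uniform 1 s sampling grid with independent Gaussian noise samples
t = np.linspace(0.0, T, 9601)
noise = sigma * rng.standard_normal(t.size)

# Linear interpolant of the noise (what a linear lookup table outputs),
# evaluated on a 10x finer grid and integrated with the trapezoidal rule
t_fine = np.linspace(0.0, T, 10 * (t.size - 1) + 1)
n_fine = np.interp(t_fine, t, noise)
sq = n_fine**2
integral = np.sum(0.5 * (sq[1:] + sq[:-1]) * np.diff(t_fine))

normalized_cost = integral / T
print(f"normalized cost of pure noise: {normalized_cost:.2e}")
print(f"2*sigma^2/3 = {2 * sigma**2 / 3:.2e}, sigma^2 = {sigma**2:.2e}")
```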

The entire system can now be created as follows:

builder = collimator.DiagramBuilder()

discharge_current = builder.add(
    Pulse(100.0, 6 / 16.0, 16 * 60.0, name="discharge_current")
)

battery = builder.add(Battery(name="battery"))
l2_loss = builder.add(make_l2_loss(name="l2_loss"))

clock = builder.add(Clock(name="clock"))
vt_exp = builder.add(LookupTable1d(t_exp_data, vt_exp_data, "linear", name="vt_exp"))

builder.connect(discharge_current.output_ports[0], battery.input_ports[0])
builder.connect(clock.output_ports[0], vt_exp.input_ports[0])
builder.connect(battery.output_ports[1], l2_loss.input_ports[0])
builder.connect(vt_exp.output_ports[0], l2_loss.input_ports[1])

builder.export_output(l2_loss.output_ports[0])

diagram = builder.build()
context = diagram.create_context()

...and then we can create a function to plot the results for any given set of parameters (11 values for each of the $v_0$, $R_s$, $R_1$, and $C_1$ curves):

def forward_plot(v0_points, Rs_points, R1_points, C1_points, context):
    new_params = {
        "v0_points": v0_points,
        "Rs_points": Rs_points,
        "R1_points": R1_points,
        "C1_points": C1_points,
    }

    subcontext = context[diagram["battery"].system_id].with_parameters(new_params)
    context = context.with_subcontext(diagram["battery"].system_id, subcontext)

    recorded_signals = {
        "vt": battery.output_ports[1],
        "vt_exp": vt_exp.output_ports[0],
        "soc": battery.output_ports[0],
        "discharge_current": discharge_current.output_ports[0],
        "l2_loss": l2_loss.output_ports[0],
    }

    sol = collimator.simulate(
        diagram, context, (0.0, 9600.0), recorded_signals=recorded_signals
    )
    l2_loss_final = sol.outputs["l2_loss"][-1]
    print(f"{l2_loss_final=}")

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))
    ax1.plot(sol.time, sol.outputs["vt"], "-b", label="Terminal voltage: sim")
    ax1.plot(sol.time, sol.outputs["vt_exp"], "-r", label="Terminal voltage: exp")
    ax2.plot(sol.time, sol.outputs["l2_loss"], label=r"$\int e^2 dt$")
    ax1.legend()
    ax2.legend()
    fig.tight_layout()
    return fig

Running a simulation with arbitrarily chosen flat parameter curves is now simple:

# run the simulation with arbitrarily chosen flat curves of v_0, Rs, R1, and C1 and plot the results
fig_forward = forward_plot(
    v0_points=3.0 * jnp.ones(11),
    Rs_points=15e-03 * jnp.ones(11),
    R1_points=15e-03 * jnp.ones(11),
    C1_points=2e03 * jnp.ones(11),
    context=context,
)
plt.show()

Optimization setup

So now we have a way to obtain the $L_2$ error. For parameter estimation, we need to minimize this error. By eliminating the overhead of recording the solution and only computing the $L_2$ error of interest, we can write a more efficient function than the one above. Before we do that, we need to ensure that during the optimization process the physically positive parameters $v_0$, $R_s$, $R_1$, and $C_1$ do not become negative. One approach would be to use constrained, or at least box-constrained, optimizers. Alternatively, we can use unconstrained optimization but transform the parameters such that they never become negative. For the latter, we can use a logarithmic transformation: representing any of $v_0$, $R_s$, $R_1$, and $C_1$ generically as $\Psi$, we write:

$$ \Psi = \Psi_{\text{ref}}\, e^{\psi} $$

where $\Psi_{\text{ref}}$ is a fixed reference value that also sets the scale of each parameter. Thus, for example, $v_0 = v_{0_{\text{ref}}} \, e^{\psi_{v_0}}$, and the optimization variable is $\psi_{v_0}$ instead of $v_0$. Note that $\psi_{v_0}$ can be unconstrained while keeping the real parameter $v_0$ always positive. In the following code, the $\psi$ variables are labelled as [.code]log_params[.code]:

from collimator import SimulatorOptions

# Define reference values
v0_ref = 3.0
Rs_ref = 10e-03
R1_ref = 15e-03
C1_ref = 2e03

options = SimulatorOptions(max_minor_step_size=1.0, max_major_steps=100, rtol=1e-06, atol=1e-08)

@jax.jit
def forward(log_params, context):
    # transform the log_params to real parametric space and update context
    params = jnp.exp(log_params)
    params_arr = params.reshape((4, 11))

    new_params = {
        "v0_points": v0_ref * params_arr[0, :],
        "Rs_points": Rs_ref * params_arr[1, :],
        "R1_points": R1_ref * params_arr[2, :],
        "C1_points": C1_ref * params_arr[3, :],
    }

    subcontext = context[diagram["battery"].system_id].with_parameters(new_params)
    context = context.with_subcontext(diagram["battery"].system_id, subcontext)

    sol = collimator.simulate(diagram, context, (0.0, 9600.0), options=options)

    l2_loss = sol.context[diagram["l2_loss"]["sq_err_int"].system_id].continuous_state

    # normalise the l2_loss and add regularisation (optional)
    cost = (1.0 / 9600) * l2_loss  # + (1.0/44)*1e-03*jnp.sum(log_params**2)
    return cost

# For some set of parameters (also initial guesses represented by `_0` suffix) run the simulation and print cost
v0_points_0 = 3.0 * jnp.ones(11)
Rs_points_0 = 15e-03 * jnp.ones(11)
R1_points_0 = 15e-03 * jnp.ones(11)
C1_points_0 = 2e03 * jnp.ones(11)

log_params_0 = jnp.hstack(
    [
        jnp.log(v0_points_0 / v0_ref),
        jnp.log(Rs_points_0 / Rs_ref),
        jnp.log(R1_points_0 / R1_ref),
        jnp.log(C1_points_0 / C1_ref),
    ]
)
cost = forward(log_params_0, context)
print(f"{cost=}")

cost=Array(4.37318956, dtype=float64, weak_type=True)
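The log transformation used in [.code]forward[.code] can be sanity-checked in isolation: mapping positive values to $\psi = \ln(\Psi/\Psi_{\text{ref}})$ and back recovers them, and any real $\psi$, however negative, maps to a strictly positive $\Psi$. A minimal NumPy sketch with the reference values defined above; the helpers [.code]to_log[.code] and [.code]from_log[.code] are hypothetical and not part of Collimator:

```python
import numpy as np

# Reference values, as in the tutorial
v0_ref, Rs_ref, R1_ref, C1_ref = 3.0, 10e-03, 15e-03, 2e03
refs = np.array([v0_ref, Rs_ref, R1_ref, C1_ref])[:, None]  # shape (4, 1)

def to_log(curves):
    """Map a (4, 11) array of positive curves to a flat vector of psi values."""
    return np.log(curves / refs).ravel()

def from_log(log_params):
    """Map a flat vector of psi values back to a (4, 11) array of curves."""
    return refs * np.exp(log_params.reshape((4, 11)))

# Initial guesses used above, one curve per row
curves_0 = np.vstack([
    3.0 * np.ones(11), 15e-03 * np.ones(11), 15e-03 * np.ones(11), 2e03 * np.ones(11)
])
round_trip = from_log(to_log(curves_0))

# Even extreme, unconstrained psi values give strictly positive parameters
psi = np.linspace(-20.0, 20.0, 44)
extreme_curves = from_log(psi)
```

Choosing $\Psi_{\text{ref}}$ close to the expected magnitude of each parameter keeps all $\psi$ values near zero, so the four very differently scaled curves look alike to the optimizer.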

Note that Collimator models can be [.code]jit[.code] compiled because Collimator is built on the JAX framework. Compilation makes repeated runs fast, and JAX also lets us compute gradients automatically, which can be used for many purposes including optimization (as shown below).

Optimization

Now that we have the [.code]forward[.code] function, we can use JAX's automatic differentiation to compute its gradient with respect to [.code]log_params[.code].

grad_forward = jax.jit(jax.grad(forward))
grads_0 = grad_forward(log_params_0, context)
print(f"{grads_0.shape}")

(44,)
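When relying on automatically computed gradients, a common sanity check is to compare a few entries against central finite differences, $\partial f/\partial\theta_i \approx [f(\theta + h e_i) - f(\theta - h e_i)]/(2h)$. The NumPy sketch below does this for a toy linear least-squares loss whose hand-derived gradient stands in for [.code]grad_forward[.code]; all names in it are hypothetical. For the battery model, one would instead difference [.code]forward[.code] itself, at the cost of two simulations per checked entry.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((50, 4))   # toy model: prediction = A @ theta
y_obs = rng.standard_normal(50)    # toy measurements

def loss(theta):
    r = A @ theta - y_obs
    return np.mean(r**2)

def grad_loss(theta):
    # Analytic gradient of the mean squared residual
    return 2.0 * A.T @ (A @ theta - y_obs) / y_obs.size

def central_diff_grad(f, theta, h=1e-6):
    """Approximate the gradient of f at theta, one coordinate at a time."""
    g = np.zeros_like(theta)
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e[i] = h
        g[i] = (f(theta + e) - f(theta - e)) / (2.0 * h)
    return g

theta = rng.standard_normal(4)
max_err = np.max(np.abs(grad_loss(theta) - central_diff_grad(loss, theta)))
print(f"max abs difference: {max_err:.1e}")
```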

With the [.code]jit[.code] compiled [.code]forward[.code] and [.code]grad_forward[.code] functions created, we can leverage any optimization framework (scipy, optax, jaxopt, etc.). We demonstrate optimization with [.code]scipy.optimize[.code] first.

Scipy

from scipy.optimize import minimize
from functools import partial

res = minimize(
    partial(forward, context=context),
    log_params_0,
    jac=partial(grad_forward, context=context),
    method="BFGS",
)
print(f"Optimization succeeded: {res.success}")
print(f"Optimized function value: {res.fun}")

Optimization succeeded: False
Optimized function value: 0.0006845431040740013

The optimizer's [.code]success[.code] status is [.code]False[.code], meaning that BFGS stopped without meeting its convergence criterion; [.code]res.message[.code] reports the reason. It may have gotten stuck near a local minimum, a common issue with gradient-based optimizers. We can try different initial guesses, rescale the loss function, and/or re-parameterize the optimization variables. For now, however, let's proceed with the best solution found by the BFGS optimizer and inspect the gradients at the optimal values:

opt_log_params = jnp.array(res.x)
grad_forward(opt_log_params, context)

Array([ 4.56354999e-07,  1.70414089e-06,  1.25888861e-06, -2.67128084e-06,
        1.25510572e-06, -4.35678180e-06,  1.35288534e-04, -2.74137097e-06,
       -9.47457128e-07, -1.40904947e-05, -1.70302183e-06, -5.03830965e-06,
        2.69772030e-06, -6.22192953e-06,  1.29389009e-05,  3.96714786e-06,
        1.73411954e-06, -5.72306343e-06,  5.47048689e-06,  1.40880380e-06,
        5.64562874e-07,  2.61121869e-07,  2.82774118e-06, -3.22560244e-06,
        3.76466226e-06, -1.00973561e-05, -2.86161852e-06, -1.85306970e-06,
        8.61201772e-06, -5.12136796e-06,  2.13797362e-07,  2.33656516e-06,
       -2.12832262e-05,  6.26569048e-06,  3.56580454e-06, -9.38141567e-06,
        2.42977789e-06,  5.66475004e-06, -2.95595382e-06,  3.03574486e-06,
       -3.01631061e-06,  6.54811747e-06,  7.21253362e-07, -6.81386193e-06],      dtype=float64)

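All gradient components are small (the largest in magnitude is about $1.4\times10^{-4}$), which suggests that BFGS has reached a near-stationary point and a reasonable solution has been found. Next, we can plot the output $v_t$ with the optimal parameters and compare it with the experimental curve: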

opt_log_params = jnp.array(res.x)

# Convert from \psi-space to real space
opt_params = jnp.exp(opt_log_params)
opt_params_arr = opt_params.reshape((4, 11))
opt_v0_points = v0_ref * opt_params_arr[0, :]
opt_Rs_points = Rs_ref * opt_params_arr[1, :]
opt_R1_points = R1_ref * opt_params_arr[2, :]
opt_C1_points = C1_ref * opt_params_arr[3, :]

# Plot the results
fig_opt = forward_plot(
    v0_points=opt_v0_points,
    Rs_points=opt_Rs_points,
    R1_points=opt_R1_points,
    C1_points=opt_C1_points,
    context=context,
)
plt.show()

We see that the simulation output with the optimal parameter curves closely follows the experimental data. Note that, especially with added noise, the optimal parameters are not necessarily unique: in the absence of further data, the experimental $v_t$ curve is the only information available for estimating them. We can compare the optimal solution found against the true known parameters, a luxury only available for synthetic data.

fig, axs = plt.subplots(2, 2, figsize=(10, 6))
axs[0, 0].plot(soc_points_true, v0_points_true, "-ro", label=f"$v_0$: true")
axs[0, 1].plot(soc_points_true, Rs_points_true, "-go", label=f"$R_s$: true")
axs[1, 0].plot(soc_points_true, R1_points_true, "-bo", label=f"$R_1$: true")
axs[1, 1].plot(soc_points_true, C1_points_true, "-ko", label=f"$C_1$: true")

axs[0, 0].plot(soc_points_true, opt_v0_points, "--r*", label=f"$v_0$: opt")
axs[0, 1].plot(soc_points_true, opt_Rs_points, "--g*", label=f"$R_s$: opt")
axs[1, 0].plot(soc_points_true, opt_R1_points, "--b*", label=f"$R_1$: opt")
axs[1, 1].plot(soc_points_true, opt_C1_points, "--k*", label=f"$C_1$: opt")

for ax in axs.flatten():
    ax.set_xlabel(r"$s$")
    ax.legend(loc="best")
plt.tight_layout()
plt.show()

The optimizer does a reasonably good job of estimating the curves, particularly given that the parameters are not unique: many parameter combinations, including the one found by the optimizer, reproduce the experimental data. With [.code]callback[.code] functions (see the SciPy documentation), one can visualize how the parameters evolve during the optimization process. See the example below for the same problem with different initial guesses:

One can also try different initial guesses and/or different solvers. Next, we demonstrate how to use [.code]optax[.code] to optimize our model.
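[.code]scipy.optimize.minimize[.code] calls the [.code]callback[.code] after each iteration with the current iterate, which is enough to record the parameter history for later plotting. A minimal sketch on a toy quadratic; for the battery problem one would pass [.code]partial(forward, context=context)[.code] and its gradient exactly as in the BFGS call above. The names [.code]target[.code], [.code]history[.code], and [.code]record[.code] are illustrative.

```python
import numpy as np
from scipy.optimize import minimize

target = np.array([1.0, -2.0, 0.5])  # toy optimum

def f(x):
    return np.sum((x - target) ** 2)

def grad_f(x):
    return 2.0 * (x - target)

history = []  # parameter vectors visited by the optimizer

def record(xk):
    # BFGS calls this once per iteration with the current iterate
    history.append(np.copy(xk))

res = minimize(f, np.zeros(3), jac=grad_f, method="BFGS", callback=record)
history = np.array(history)  # shape (n_iterations, 3), ready to plot
print(res.success, history.shape)
```

For the battery model, each recorded row would be a 44-entry [.code]log_params[.code] vector that can be converted back to curves and plotted per iteration.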

Optax

With optax, we use the [.code]Adam[.code] optimizer together with our [.code]jit[.code] compiled [.code]grad_forward[.code] function, which supplies the gradients:

import optax

# Optax optimizer
optimizer = optax.adam(learning_rate=0.01)

# Initialize optimizer state
log_params = log_params_0
opt_state = optimizer.init(log_params)

# Gradient descent loop
num_epochs = 1000
for epoch in range(num_epochs):
    gradients = grad_forward(log_params, context)
    updates, opt_state = optimizer.update(gradients, opt_state)
    log_params = optax.apply_updates(log_params, updates)

    # Print the function value at the current parameters
    if (epoch + 1) % 50 == 0:
        current_function_value = forward(log_params, context)
        print(
            f"Epoch [{epoch+1}/{num_epochs}]: forward(log_params) = {current_function_value}"
        )

Epoch [50/1000]: forward(log_params) = 0.27251043220572385
Epoch [100/1000]: forward(log_params) = 0.013398332180523925
Epoch [150/1000]: forward(log_params) = 0.0049060524910408045
Epoch [200/1000]: forward(log_params) = 0.002560092166880761
Epoch [250/1000]: forward(log_params) = 0.0015999159535121869
Epoch [300/1000]: forward(log_params) = 0.0011705076890363313
Epoch [350/1000]: forward(log_params) = 0.0009641271450004881
Epoch [400/1000]: forward(log_params) = 0.0008576420903454343
Epoch [450/1000]: forward(log_params) = 0.0007988910413465778
Epoch [500/1000]: forward(log_params) = 0.0007645753239331456
Epoch [550/1000]: forward(log_params) = 0.0007435847426334666
Epoch [600/1000]: forward(log_params) = 0.0007302031624058201
Epoch [650/1000]: forward(log_params) = 0.0007211318539287574
Epoch [700/1000]: forward(log_params) = 0.0007149030603885657
Epoch [750/1000]: forward(log_params) = 0.000710292290209918
Epoch [800/1000]: forward(log_params) = 0.0007067809082700486
Epoch [850/1000]: forward(log_params) = 0.0007041046804926375
Epoch [900/1000]: forward(log_params) = 0.0007018041457500387
Epoch [950/1000]: forward(log_params) = 0.0006998696743720331
Epoch [1000/1000]: forward(log_params) = 0.0006982385532658905
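For reference, Adam keeps exponential moving averages of the gradient and of its elementwise square, corrects them for their zero initialization, and takes a per-parameter scaled step. The NumPy sketch below illustrates one such update using the same defaults as [.code]optax.adam[.code] ([.code]b1=0.9[.code], [.code]b2=0.999[.code], [.code]eps=1e-8[.code]); it is an illustration of the rule, not optax's code, and [.code]adam_step[.code] is a hypothetical name.

```python
import numpy as np

def adam_step(params, grads, state, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update; `state` holds the moment estimates and step count."""
    mu, nu, count = state
    count += 1
    mu = b1 * mu + (1.0 - b1) * grads      # first moment (mean of gradients)
    nu = b2 * nu + (1.0 - b2) * grads**2   # second moment (uncentered)
    mu_hat = mu / (1.0 - b1**count)        # bias corrections for zero init
    nu_hat = nu / (1.0 - b2**count)
    params = params - lr * mu_hat / (np.sqrt(nu_hat) + eps)
    return params, (mu, nu, count)

# Toy problem: minimize ||x - target||^2 starting from x = 0
target = np.array([0.5, -0.3])
x = np.zeros(2)
state = (np.zeros(2), np.zeros(2), 0)
for _ in range(1000):
    g = 2.0 * (x - target)
    x, state = adam_step(x, g, state)
print(x)  # approaches target
```

Because the step is normalized by $\sqrt{\hat\nu}$, early updates move each parameter by roughly the learning rate regardless of gradient magnitude, which is why the learning rate of 0.01 used above is meaningful in $\psi$-space.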

Similarly to the [.code]scipy[.code] case, we can plot the results of the optimal solution found:

opt_log_params = log_params

# Convert from \psi-space to real space
opt_params = jnp.exp(opt_log_params)
opt_params_arr = opt_params.reshape((4, 11))
opt_v0_points = v0_ref * opt_params_arr[0, :]
opt_Rs_points = Rs_ref * opt_params_arr[1, :]
opt_R1_points = R1_ref * opt_params_arr[2, :]
opt_C1_points = C1_ref * opt_params_arr[3, :]

# Plot the results
fig_opt = forward_plot(
    v0_points=opt_v0_points,
    Rs_points=opt_Rs_points,
    R1_points=opt_R1_points,
    C1_points=opt_C1_points,
    context=context,
)

fig, axs = plt.subplots(2, 2, figsize=(10, 6))
axs[0, 0].plot(soc_points_true, v0_points_true, "-ro", label=f"$v_0$: true")
axs[0, 1].plot(soc_points_true, Rs_points_true, "-go", label=f"$R_s$: true")
axs[1, 0].plot(soc_points_true, R1_points_true, "-bo", label=f"$R_1$: true")
axs[1, 1].plot(soc_points_true, C1_points_true, "-ko", label=f"$C_1$: true")

axs[0, 0].plot(soc_points_true, opt_v0_points, "--r*", label=f"$v_0$: opt")
axs[0, 1].plot(soc_points_true, opt_Rs_points, "--g*", label=f"$R_s$: opt")
axs[1, 0].plot(soc_points_true, opt_R1_points, "--b*", label=f"$R_1$: opt")
axs[1, 1].plot(soc_points_true, opt_C1_points, "--k*", label=f"$C_1$: opt")

for ax in axs.flatten():
    ax.set_xlabel(r"$s$")
    ax.legend(loc="best")
plt.tight_layout()
plt.show()

l2_loss_final=6.072426816424518

Check out Part 3 to see how to use real-world data to estimate your parameters.
