"""
This module contains functions that implement the JOPT algorithm to
calculate optimal parameters for analytical control pulse sequences.
"""
import qutip as qt
from qutip import Qobj, QobjEvo
from diffrax import Dopri5, PIDController
import jax
from jax import custom_jvp
import jax.numpy as jnp
import qutip_jax # noqa: F401
@custom_jvp
def _abs(x):
return jnp.abs(x)
[docs]def _abs_jvp(primals, tangents):
"""
Custom jvp for absolute value of complex functions
"""
(x,) = primals
(t,) = tangents
abs_x = _abs(x)
res = jnp.where(
abs_x == 0,
0.0, # prevent division by zero
jnp.real(jnp.multiply(jnp.conj(x), t)) / abs_x,
)
return abs_x, res
# register custom jvp for absolut value of complex functions
_abs.defjvp(_abs_jvp)
[docs]class _JOPT:
"""
Class for storing a control problem and calculating
the fidelity error function and its gradient wrt the control parameters.
"""
def __init__(
self,
objective,
time_interval,
time_options,
control_parameters,
alg_kwargs,
guess_params,
**integrator_kwargs,
):
self._Hd = objective.H[0]
self._Hc_lst = objective.H[1:]
self._control_parameters = control_parameters
self._guess_params = guess_params
self._H = self._prepare_generator()
self._initial = objective.initial.to("jax")
self._target = objective.target.to("jax")
self._evo_time = time_interval.evo_time
self._var_t = "guess" in time_options
# inferred attributes
self._norm_fac = 1 / self._target.norm()
# integrator options
self._integrator_kwargs = integrator_kwargs
self._integrator_kwargs["method"] = "diffrax"
self._rtol = self._integrator_kwargs.get("rtol", 1e-5)
self._atol = self._integrator_kwargs.get("atol", 1e-5)
self._integrator_kwargs.setdefault(
"stepsize_controller", PIDController(rtol=self._rtol, atol=self._atol)
)
self._integrator_kwargs.setdefault("solver", Dopri5())
# choose solver and fidelity type according to problem
if self._Hd.issuper:
self._fid_type = alg_kwargs.get("fid_type", "TRACEDIFF")
self._solver = qt.MESolver(H=self._H, options=self._integrator_kwargs)
else:
self._fid_type = alg_kwargs.get("fid_type", "PSU")
self._solver = qt.SESolver(H=self._H, options=self._integrator_kwargs)
self.infidelity = jax.jit(self._infid)
self.gradient = jax.jit(jax.grad(self._infid))
[docs] def _prepare_generator(self):
"""
prepare Hamiltonian call signature
to only take one parameter vector 'p' for mesolve like:
qt.mesolve(H, psi0, tlist, args={'p': p})
"""
def helper(control, lower, upper):
# to fix parameter index in loop
return jax.jit(lambda t, p: control(t, p[lower:upper]))
H = QobjEvo(self._Hd)
idx = 0
for Hc, p_opt in zip(self._Hc_lst, self._control_parameters.values()):
hc, ctrl = Hc[0], Hc[1]
guess = p_opt.get("guess")
M = len(guess)
evo = QobjEvo(
[hc, helper(ctrl, idx, idx + M)], args={"p": self._guess_params}
)
H += evo
idx += M
return H.to("jax")
[docs] def _infid(self, params):
"""
calculate infidelity to be minimized
"""
# adjust integration time-interval, if time is parameter
evo_time = self._evo_time if self._var_t is False else params[-1]
X = self._solver.run(
self._initial, [0.0, evo_time], args={"p": params}
).final_state
if self._fid_type == "TRACEDIFF":
diff = X - self._target
# to prevent if/else in qobj.dag() and qobj.tr()
diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims)
g = 1 / 2 * (diff_dag * diff).data.trace()
infid = jnp.real(self._norm_fac * g)
else:
g = self._norm_fac * self._target.overlap(X)
if self._fid_type == "PSU": # f_PSU (drop global phase)
infid = 1 - _abs(g) # custom_jvp for abs
elif self._fid_type == "SU": # f_SU (incl global phase)
infid = 1 - jnp.real(g)
return infid