# Source code for qutip_qoc.result

"""
This module contains the Result class for storing and
reporting the results of a full pulse control optimization run.
"""
import pickle
import textwrap
import numpy as np
from inspect import signature
import warnings

import qutip as qt

try:
    import jax
    import jaxlib  # noqa: F401  (imported only to require jaxlib next to jax)

    # Remember the type of a jit-compiled callable so that JAX-defined
    # control functions can be recognised later via isinstance().
    _jit_probe = jax.jit(lambda x: x)
    _jitfun_type = type(_jit_probe)
    del _jit_probe
except ImportError:
    # JAX is optional; without it no control can be a jitted function.
    _jitfun_type = None

__all__ = ["Result"]


class _Stats:
    """
    Minimal stand-in for the ``stats`` attribute of qtrl results.

    Only for backward compatibility with qtrl: code that expects
    ``result.stats.report()`` keeps working.
    """

    def __init__(self, result):
        # the object whose string form ``report`` prints
        self._result = result

    def report(self):
        """Print the wrapped result's string representation."""
        wrapped = self._result
        print(wrapped)


class Result:
    """
    Class for storing the results of a pulse control optimization run.

    Attributes
    ----------
    objectives : list of :class:`qutip_qoc.Objective`
        List of objectives to be optimized.

    time_interval : :class:`qutip_qoc._TimeInterval`
        Time interval for the optimization.

    start_local_time : struct_time
        Time when the optimization started.

    end_local_time : struct_time
        Time when the optimization ended.

    total_seconds : float
        Total time in seconds the optimization took.
        Equal to the sum of iter_seconds.
        Equal to difference between end_local_time and start_local_time.

    n_iters : int
        Number of iterations until convergence.
        Equal to the length of iter_seconds.

    iter_seconds : list of float
        Seconds between each iteration.

    message : str
        Reason for termination.

    guess_params : list
        Initial parameter values, one entry per control.

    new_params : ndarray
        Most recent flat parameter vector, as set by ``_update``; it is
        split per control to build ``optimized_params``.

    optimized_params : list of ndarray
        List of optimized parameters.

    guess_controls : list of ndarray
        List of guess control pulses used to initialize the optimization.

    optimized_controls : list of ndarray
        List of optimized control pulses.

    optimized_H : list of :class:`qutip.QobjEvo`
        Time-dependent quantum objects with the optimized control
        amplitudes, one for each objective (see the ``H`` attribute of
        :class:`qutip_qoc.Objective`).

    final_states : list of :class:`qutip.Qobj`
        List of final states after the optimization.
        One for each objective.

    infidelity : float
        Final infidelity error after the optimization.

    var_time : bool
        Whether the optimization was performed with variable time.
        If True, the last parameter in optimized_params is the evolution time.

    qtrl_optimizers : list or None
        qtrl optimizer objects used by the discrete (GRAPE/CRAB) algorithms;
        falsy for the continuous (GOAT/JOPT) algorithms.
    """

    def __init__(
        self,
        objectives=None,
        time_interval=None,
        start_local_time=None,
        end_local_time=None,
        total_seconds=None,
        n_iters=None,
        iter_seconds=None,
        message=None,
        guess_controls=None,
        optimized_controls=None,
        optimized_H=None,
        final_states=None,
        guess_params=None,
        new_params=None,
        optimized_params=None,
        infidelity=np.inf,
        var_time=False,
        qtrl_optimizers=None,
    ):
        self.time_interval = time_interval
        self.objectives = objectives
        self.start_local_time = start_local_time
        self.end_local_time = end_local_time
        self._total_seconds = total_seconds
        self.n_iters = n_iters
        self.iter_seconds = iter_seconds
        self.message = message
        self._guess_controls = guess_controls
        self._optimized_controls = optimized_controls
        self._optimized_H = optimized_H
        self.guess_params = guess_params
        self.new_params = new_params
        self._optimized_params = optimized_params
        self._final_states = final_states
        self.infidelity = infidelity
        self.var_time = var_time
        self.qtrl_optimizers = qtrl_optimizers

        # qtrl result backward compatibility
        self.stats = _Stats(self)

    def __str__(self):
        time_optim_summary = (
            "- Optimized time parameter: " + str(self.optimized_params[-1])
            if self.var_time
            else ""
        )
        # Dedent the template *before* formatting: substituted values may
        # span several lines with little or no leading whitespace (e.g. a 2D
        # GRAPE parameter array), which would otherwise defeat the dedent and
        # leave the whole report indented.
        template = textwrap.dedent(
            r"""
            Control Optimization Result
            --------------------------
            - Started at {start_local_time}
            - Number of objectives: {n_objectives}
            - Final fidelity error: {final_infid}
            - Final parameters: {final_params}
            - Number of iterations: {n_iters}
            - Reason for termination: {message}
            {time_optim_summary}
            - Ended at {end_local_time} ({time_delta}s)
            """
        )
        return template.format(
            start_local_time=self.start_local_time,
            n_objectives=len(self.objectives),
            final_infid=self.infidelity,
            final_params=self.optimized_params,
            n_iters=self.n_iters,
            end_local_time=self.end_local_time,
            time_delta=self.total_seconds,
            time_optim_summary=time_optim_summary,
            message=self.message,
        ).strip()

    def __repr__(self):
        return self.__str__()

    @property
    def total_seconds(self):
        """
        Total time in seconds the optimization took.

        If no explicit total was given, this is the sum of ``iter_seconds``
        (cached on first access). Returns ``None`` when neither is available.
        """
        if self._total_seconds is None and self.iter_seconds is not None:
            self._total_seconds = sum(self.iter_seconds)
        return self._total_seconds

    @property
    def optimized_params(self):
        """
        Parameter values after optimization.
        """
        if self._optimized_params is None:
            # reshape (optimized) new_parameters array to match
            # shape and type of the guess_parameters list
            if self.qtrl_optimizers and len(self.guess_params[0]) == len(
                self.time_interval.tslots
            ):  # GRAPE
                amps = self.qtrl_optimizers[0]._get_ctrl_amps(self.new_params)
                opt_params = amps.T
            else:  # GOAT, JOPT, CRAB
                opt_params, idx = [], 0
                for guess in self.guess_params:
                    opt = self.new_params[idx : idx + len(guess)]
                    if isinstance(guess, list):
                        opt = opt.tolist()
                    opt_params.append(opt)
                    idx += len(guess)
            self._optimized_params = opt_params
        return self._optimized_params

    @optimized_params.setter
    def optimized_params(self, params):
        self._optimized_params = params

    def _tslots_or_default(self, prop_name):
        """
        Return ``time_interval.tslots``, or 100 evenly spaced points on
        ``[0, time_interval.evo_time]`` if the time slots are unavailable.

        ``prop_name`` is the result property asking for the points; it is
        used only in the printed notice.
        """
        try:
            return self.time_interval.tslots
        except Exception:
            # best effort: still sample continuous controls without n_tslots
            print(
                "time_interval.tslots not specified "
                "(probably missing n_tslots), defaulting to 100 "
                f"collocation points for result.{prop_name}"
            )
            return np.linspace(0.0, self.time_interval.evo_time, 100)

    @property
    def optimized_controls(self):
        """
        Control pulses after optimization.
        """
        if self._optimized_controls is None:
            opt_ctrl = []
            for j, H in enumerate(zip(self.objectives[0].H[1:], self.optimized_params)):
                Hc, xf = H
                control, cf = Hc[1], []
                if not self.qtrl_optimizers:
                    # continuous control as in JOPT/GOAT
                    for t in self._tslots_or_default("optimized_controls"):
                        cf.append(control(t, xf))
                else:
                    # discrete control as in GRAPE/CRAB
                    if len(xf) == len(self.time_interval.tslots):
                        cf = np.array(xf)
                    else:  # parameterized CRAB
                        pgen = self.qtrl_optimizers[0].pulse_generator[j]
                        pgen.set_optim_var_vals(np.array(xf))
                        cf = np.array(pgen.gen_pulse())
                opt_ctrl.append(cf)
            self._optimized_controls = opt_ctrl
        return self._optimized_controls

    @property
    def guess_controls(self):
        """
        Control pulses before the optimization.
        """
        if self._guess_controls is None:
            if self.qtrl_optimizers:
                qtrl_res = self.qtrl_optimizers[0]._create_result()
                gss_ctrl = qtrl_res.initial_amps.T
            else:
                gss_ctrl = []
                for j, H in enumerate(zip(self.objectives[0].H[1:], self.guess_params)):
                    Hc, xi = H
                    control, c0 = Hc[1], []
                    if callable(control):
                        # continuous control as in JOPT/GOAT
                        for t in self._tslots_or_default("guess_controls"):
                            c0.append(control(t, xi))
                    else:
                        # discrete control as in GRAPE/CRAB
                        if len(xi) == len(self.time_interval.tslots):
                            c0 = xi
                        else:  # parameterized CRAB
                            pgen = self.qtrl_optimizers[0].pulse_generator[j]
                            pgen.set_optim_var_vals(np.array(self.guess_params[j]))
                            c0 = pgen.gen_pulse()
                    gss_ctrl.append(c0)
            self._guess_controls = gss_ctrl
        return self._guess_controls

    @property
    def optimized_H(self):
        """
        Optimized Hamiltonians with optimized controls.
        """
        if self._optimized_H is None:
            opt_H = []
            for obj in self.objectives:
                # Create the optimized Hamiltonian with optimized controls
                if not self.qtrl_optimizers:  # GOAT, JOPT
                    H = obj.H
                else:
                    H = [obj.H[0]]  # drift
                    for Hc, cf in zip(obj.H[1:], self.optimized_controls):
                        if isinstance(Hc, qt.Qobj):  # parameterized CRAB
                            H.append([Hc, cf])
                        else:  # discrete control as in GRAPE, CRAB
                            H.append([Hc[0], cf])

                # Create the corresponding QobjEvo object
                para_keys = []
                args_dict = {}
                if not self.qtrl_optimizers:  # GOAT, JOPT
                    # extract parameter names from control functions f(t, para_key)
                    c_sigs = [signature(Hc[1]) for Hc in self.objectives[0].H[1:]]
                    c_keys = [sig.parameters.keys() for sig in c_sigs]
                    para_keys = [list(keys)[1] for keys in c_keys]
                    for key, val in zip(para_keys, self.optimized_params):
                        args_dict[key] = val

                H_evo = (
                    qt.QobjEvo(H, args=args_dict)
                    if args_dict  # GOAT, JOPT
                    else qt.QobjEvo(H, tlist=self.time_interval.tslots)
                )
                opt_H.append(H_evo)
            self._optimized_H = opt_H
        return self._optimized_H

    @property
    def final_states(self):
        """
        Evolved system states after optimization.
        """
        if self._final_states is None:
            states = []
            if self.var_time:  # last parameter is optimized time
                evo_time = self.optimized_params[-1][0]
            else:
                evo_time = self.time_interval.evo_time

            # choose solver method based on type of control function:
            # if jax is installed, _jitfun_type is the type of a jax-jitted
            # function, otherwise it is None
            if _jitfun_type is not None and isinstance(
                self.objectives[0].H[1][1], _jitfun_type
            ):
                method = "diffrax"  # for JAX defined controls
            else:
                method = "adams"

            for obj, opt_H in zip(self.objectives, self.optimized_H):
                # master-equation solver for superoperator dynamics,
                # Schroedinger-equation solver otherwise
                solver_cls = qt.MESolver if opt_H.issuper else qt.SESolver
                solver = solver_cls(
                    opt_H,
                    options={
                        "normalize_output": False,
                        "method": method,
                    },
                )
                states.append(  # compute evolution
                    solver.run(obj.initial, tlist=[0.0, evo_time]).final_state
                )
            self._final_states = states
        return self._final_states

    def _update(self, infidelity, parameters):
        """
        Used to update the result during optimization.
        """
        self.infidelity = infidelity
        self.new_params = parameters

    def dump(self, filename):
        """
        Save the result to a file (pickled).
        """
        with open(filename, "wb") as dump_fh:
            pickle.dump(self, dump_fh)

    @classmethod
    def load(cls, filename, objectives=None):
        """
        Load a result from a file written by :meth:`dump`.

        Parameters
        ----------
        filename : str or path-like
            File to read.
        objectives : list of :class:`qutip_qoc.Objective`, optional
            If given, replaces the objectives stored in the file; otherwise
            the pickled objectives are kept.

        Notes
        -----
        Uses :mod:`pickle`; only load files from trusted sources.
        """
        with open(filename, "rb") as dump_fh:
            result = pickle.load(dump_fh)
        if objectives is not None:
            result.objectives = objectives
        return result

    @property
    def evo_full_final(self):
        """
        Deprecated, use final_states[0] instead.
        """
        warnings.warn(
            "evo_full_final is deprecated, use final_states[0] instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.final_states[0]

    @property
    def fid_err(self):
        """
        Deprecated, use infidelity instead.
        """
        warnings.warn(
            "fid_err is deprecated, use infidelity instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.infidelity

    @property
    def grad_norm_final(self):
        """
        Deprecated, not supported.
        """
        warnings.warn(
            "grad_norm_final is deprecated, it is not supported",
            DeprecationWarning,
            stacklevel=2,
        )
        return None  # not supported

    @property
    def termination_reason(self):
        """
        Deprecated, use message instead.
        """
        warnings.warn(
            "termination_reason is deprecated, use message instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.message

    @property
    def num_iter(self):
        """
        Deprecated, use n_iters instead.
        """
        warnings.warn(
            "num_iter is deprecated, use n_iters instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.n_iters

    @property
    def wall_time(self):
        """
        Deprecated, use total_seconds instead.
        """
        warnings.warn(
            "wall_time is deprecated, use total_seconds instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.total_seconds