TDVP for dissipative time evolution in 1D

This example shows how to use the POVM approach to simulate dissipative (Lindblad) time evolution in a one-dimensional setting.

  1import matplotlib.pyplot as plt
  2import jax.numpy as jnp
  3import jax
  4import jVMC
  5from functools import partial
  6jax.config.update("jax_enable_x64", True)
  7
  8
  9def copy_dict(a):
 10    b = {}
 11    for key, value in a.items():
 12        if type(value) == type(a):
 13            b[key] = copy_dict(value)
 14        else:
 15            b[key] = value
 16    return b
 17
 18
 19def norm_fun(v, df=lambda x: x):
 20    return jnp.real(jnp.conj(jnp.transpose(v)).dot(df(v)))
 21
 22
 23L = 4
 24dim = "1D"
 25logProbFactor = 1
 26inputDim = 4
 27
 28# Initialize net
 29sample_shape = (L,)
 30psi = jVMC.util.util.init_net({"batch_size": 5000, "net1":
 31                               {"type": "RNN",
 32                                "translation": True,
 33                                "parameters": {"inputDim": inputDim,
 34                                               "realValuedOutput": True,
 35                                               "realValuedParams": True,
 36                                               "logProbFactor": logProbFactor, "hiddenSize": 6, "L": L, "depth": 2}}},
 37                              sample_shape, 1234)
 38print(f"The variational ansatz has {psi.numParameters} parameters.")
 39
 40# Set up hamiltonian
 41system_data = {"dim": dim, "L": L}
 42povm = jVMC.operator.POVM(system_data)
 43Lindbladian = jVMC.operator.POVMOperator(povm)
 44for l in range(L):
 45    Lindbladian.add({"name": "ZZ", "strength": 1.0, "sites": (l, (l + 1) % L)})
 46    Lindbladian.add({"name": "X", "strength": 3.0, "sites": (l,)})
 47    Lindbladian.add({"name": "dephasing", "strength": 1.0, "sites": (l,)})
 48
 49# Set up initial state as product state
 50prob_dist = jVMC.operator.povm.get_1_particle_distributions("y_up", Lindbladian.povm)
 51prob_dist /= prob_dist[0]
 52biases = jnp.log(prob_dist[1:])
 53params = copy_dict(psi._param_unflatten(psi.get_parameters()))
 54
 55params["outputDense"]["bias"] = biases
 56params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"]
 57params = jnp.concatenate([p.ravel()
 58                          for p in jax.tree_util.tree_flatten(params)[0]])
 59psi.set_parameters(params)
 60
 61# Set up sampler
 62sampler = jVMC.sampler.ExactSampler(psi, (L,), lDim=4, logProbFactor=logProbFactor)
 63# sampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(123), updateProposer=jVMC.sampler.propose_POVM_outcome, numSamples=1000)
 64
 65# Set up TDVP
 66tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=-1.,
 67                                   svdTol=1e-6, diagonalShift=0, makeReal='real', crossValidation=False)
 68
 69stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-4)  # ODE integrator
 70
 71res = {"X": [], "Y": [], "Z": [], "X_corr_L1": [],
 72       "Y_corr_L1": [], "Z_corr_L1": []}
 73
 74times = []
 75t = 0
 76while t < 5 * 1e-0:
 77    times.append(t)
 78    result = jVMC.operator.povm.measure_povm(Lindbladian.povm, sampler)
 79    for dim in ["X", "Y", "Z"]:
 80        res[dim].append(result[dim]["mean"])
 81        res[dim + "_corr_L1"].append(result[dim + "_corr_L1"]["mean"])
 82
 83    dp, dt = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=Lindbladian, psi=psi, normFunction=partial(norm_fun, df=tdvpEquation.S_dot))
 84    t += dt
 85    psi.set_parameters(dp)
 86    print(f"t = {t:.3f}, \t dt = {dt:.2e}")
 87    if tdvpEquation.crossValidation:
 88        print(f"Cross-Validation-Factor_residual = {tdvpEquation.crossValidationFactor_residual:.3f}")
 89        print(f"Cross-Validation-Factor_tdvpErr = {tdvpEquation.crossValidationFactor_tdvpErr:.3f}")
 90
 91
 92plt.plot(times, res["X"], label=r"$\langle X \rangle$")
 93plt.plot(times, res["Y"], label=r"$\langle Y \rangle$")
 94plt.plot(times, res["Z"], label=r"$\langle Z \rangle$")
 95plt.plot(times, res["Z_corr_L1"], label=r"$\langle Z_iZ_{i+1} \rangle$", linestyle="--")
 96plt.xlabel(r"$Jt$")
 97plt.legend()
 98plt.grid()
 99plt.savefig('Lindblad_evolution.pdf')
100plt.show()