TDVP for dissipative time evolution in 2D

This example shows how to use the POVM approach to simulate dissipative time evolution with the TDVP in a two-dimensional setting.

  1import matplotlib.pyplot as plt
  2import jax.numpy as jnp
  3import jax
  4import jVMC
  5jax.config.update("jax_enable_x64", True)
  6from functools import partial
  7
  8
  9def copy_dict(a):
 10    b = {}
 11    for key, value in a.items():
 12        if type(value) == type(a):
 13            b[key] = copy_dict(value)
 14        else:
 15            b[key] = value
 16    return b
 17
 18
 19def norm_fun(v, df=lambda x: x):
 20    return jnp.real(jnp.conj(jnp.transpose(v)).dot(df(v)))
 21
 22
 23def xy_to_id(x, y, L):
 24    return int(x + L * y)
 25
 26
 27L = 2
 28inputDim = 4
 29logProbFactor = 1
 30dim = "2D"
 31
 32# Initialize net
 33sample_shape = (L, L)
 34psi = jVMC.util.util.init_net({"batch_size": 5000, "net1": {"type": "RNN2D",
 35                                        "translation": True,
 36                                        "parameters": {"inputDim": 4, "logProbFactor": 1, "hiddenSize": 5, "L": L, "depth": 2, "cell": "RNN",
 37                                                       "realValuedOutput": True,
 38                                                       "realValuedParams": True}}},
 39                              sample_shape, 1234)
 40print(f"The variational ansatz has {psi.numParameters} parameters.")
 41
 42# Set up hamiltonian
 43system_data = {"dim": dim, "L": L}
 44povm = jVMC.operator.POVM(system_data)
 45Lindbladian = jVMC.operator.POVMOperator(povm)
 46for x in range(L):
 47    for y in range(L):
 48        Lindbladian.add({"name": "ZZ", "strength": 1.0, "sites": (xy_to_id(x, y, L), xy_to_id((x + 1) % L, y, L))})
 49        Lindbladian.add({"name": "ZZ", "strength": 1.0, "sites": (xy_to_id(x, y, L), xy_to_id(x, (y + 1) % L, L))})
 50        Lindbladian.add({"name": "X", "strength": 3, "sites": (xy_to_id(x, y, L),)})
 51        Lindbladian.add({"name": "dephasing", "strength": .5, "sites": (xy_to_id(x, y, L),)})
 52
 53# Set up initial state as product state
 54prob_dist = jVMC.operator.povm.get_1_particle_distributions("z_up", Lindbladian.povm)
 55prob_dist /= prob_dist[0]
 56biases = jnp.log(prob_dist[1:])
 57params = copy_dict(psi._param_unflatten(psi.get_parameters()))
 58
 59params["outputDense"]["bias"] = biases
 60params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"]
 61params = jnp.concatenate([p.ravel()
 62                          for p in jax.tree_util.tree_flatten(params)[0]])
 63psi.set_parameters(params)
 64
 65# Set up sampler
 66sampler = jVMC.sampler.ExactSampler(psi, sample_shape, lDim=inputDim, logProbFactor=logProbFactor)
 67# sampler = jVMC.sampler.MCSampler(psi, sample_shape, random.PRNGKey(123), updateProposer=jVMC.sampler.propose_POVM_outcome, numSamples=1000)
 68
 69
 70# Set up TDVP
 71tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=-1.,
 72                                   svdTol=1e-6, diagonalShift=0, makeReal='real')
 73
 74stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-2)  # ODE integrator
 75
 76res = {"X": [], "Y": [], "Z": [], "X_corr_L1": [],
 77       "Y_corr_L1": [], "Z_corr_L1": []}
 78
 79times = []
 80t = 0
 81
 82while t < 5 * 1e-0:
 83    times.append(t)
 84    result = jVMC.operator.povm.measure_povm(Lindbladian.povm, sampler)
 85    for dim in ["X", "Y", "Z"]:
 86        res[dim].append(result[dim]["mean"])
 87        res[dim + "_corr_L1"].append(result[dim + "_corr_L1"]["mean"])
 88
 89    dp, dt = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=Lindbladian, psi=psi, normFunction=partial(norm_fun, df=tdvpEquation.S_dot))
 90    t += dt
 91    psi.set_parameters(dp)
 92    print(f"t = {t:.3f}, \t dt = {dt:.2e}")
 93
 94
 95plt.plot(times, res["X"], label=r"$\langle X \rangle$")
 96plt.plot(times, res["Y"], label=r"$\langle Y \rangle$")
 97plt.plot(times, res["Z"], label=r"$\langle Z \rangle$")
 98plt.plot(times, res["Z_corr_L1"], label=r"$\langle Z_iZ_{i+1} \rangle$", linestyle="--")
 99plt.xlabel(r"$Jt$")
100plt.legend()
101plt.grid()
102plt.savefig('Lindblad_evolution.pdf')
103plt.show()