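"""TDVP for dissipative time evolution in 2D.

This example shows how to use the POVM approach to simulate the Lindblad
(dissipative) dynamics of a spin system on a two-dimensional lattice with the
time-dependent variational principle (TDVP).
"""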
from collections.abc import Mapping
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jVMC

jax.config.update("jax_enable_x64", True)


def copy_dict(a):
    """Return a recursive copy of the mapping ``a`` made of plain ``dict`` objects.

    Every nested mapping is copied into a fresh ``dict``, so the result can be
    modified in place (e.g. to overwrite individual network parameters)
    without touching ``a``. Non-mapping leaves (arrays, scalars, ...) are
    shared with ``a``, not copied.

    Args:
        a: A mapping whose values may themselves be mappings.

    Returns:
        A new ``dict`` with the same keys and nesting structure as ``a``.
    """
    # Recurse into *any* Mapping, not only values whose type equals the type
    # of the top level: a read-only mapping nested inside a plain dict must
    # also become a mutable dict, otherwise item assignment on it fails.
    return {
        key: copy_dict(value) if isinstance(value, Mapping) else value
        for key, value in a.items()
    }


def norm_fun(v, df=lambda x: x):
    """Return Re(<v, df(v)>), the (semi-)norm of ``v`` induced by ``df``.

    With the default identity ``df`` this is the squared Euclidean norm of
    ``v``. In this script it is used with ``tdvpEquation.S_dot`` as the norm
    for adaptive step-size control (presumably the product with the TDVP
    metric; TODO confirm in jVMC).

    Args:
        v: Vector (array) whose norm is measured.
        df: Map applied to ``v`` before taking the inner product.

    Returns:
        Real part of ``conj(v)^T @ df(v)``.
    """
    v_dagger = jnp.conj(jnp.transpose(v))
    mapped = df(v)
    inner = v_dagger.dot(mapped)
    return jnp.real(inner)


def xy_to_id(x, y, L):
    """Convert lattice coordinates (x, y) to a flat site index.

    Sites are numbered along x first, i.e. ``index = x + L * y`` for a
    lattice of linear size ``L``.
    """
    flat_index = y * L + x
    return int(flat_index)


# Lattice and POVM settings shared by the network, the sampler and the POVM.
L = 2  # linear size of the L x L lattice
inputDim = 4  # local input dimension (presumably the number of POVM outcomes per site)
logProbFactor = 1
dim = "2D"

# Initialize net
sample_shape = (L, L)
psi = jVMC.util.util.init_net(
    {
        "batch_size": 5000,
        "net1": {
            "type": "RNN2D",
            "translation": True,
            "parameters": {
                # Reuse the module-level settings instead of repeating the
                # literals, so the network cannot drift out of sync with the
                # sampler (which is built from the same variables).
                "inputDim": inputDim,
                "logProbFactor": logProbFactor,
                "hiddenSize": 5,
                "L": L,
                "depth": 2,
                "cell": "RNN",
                "realValuedOutput": True,
                "realValuedParams": True,
            },
        },
    },
    sample_shape,
    1234,  # presumably the PRNG seed for parameter initialization
)
print(f"The variational ansatz has {psi.numParameters} parameters.")

# Set up the Lindbladian on the periodic L x L lattice: nearest-neighbour ZZ
# couplings, a single-site X term and single-site dephasing.
system_data = {"dim": dim, "L": L}
povm = jVMC.operator.POVM(system_data)
Lindbladian = jVMC.operator.POVMOperator(povm)
for x in range(L):
    for y in range(L):
        site = xy_to_id(x, y, L)
        right = xy_to_id((x + 1) % L, y, L)  # periodic neighbour along x
        up = xy_to_id(x, (y + 1) % L, L)  # periodic neighbour along y
        # NOTE(review): for L == 2 the wrap-around adds every bond twice
        # (e.g. (0, 1) from x = 0 and (1, 0) from x = 1), presumably doubling
        # its effective coupling.
        Lindbladian.add({"name": "ZZ", "strength": 1.0, "sites": (site, right)})
        Lindbladian.add({"name": "ZZ", "strength": 1.0, "sites": (site, up)})
        Lindbladian.add({"name": "X", "strength": 3, "sites": (site,)})
        Lindbladian.add({"name": "dephasing", "strength": 0.5, "sites": (site,)})

# Set up initial state as product state: every site gets the single-site POVM
# distribution of "z_up". The output-layer bias is set to the log-probabilities
# of outcomes 1..inputDim-1 relative to outcome 0, and the output kernel is
# scaled to ~0 so the output depends (almost) only on that bias.
# NOTE(review): this assumes the "outputDense" layer emits such relative
# log-probabilities -- confirm against the RNN2D implementation.
prob_dist = jVMC.operator.povm.get_1_particle_distributions("z_up", Lindbladian.povm)
prob_dist /= prob_dist[0]
biases = jnp.log(prob_dist[1:])

param_tree = copy_dict(psi._param_unflatten(psi.get_parameters()))
output_layer = param_tree["outputDense"]
output_layer["bias"] = biases
output_layer["kernel"] = 1e-15 * output_layer["kernel"]

# Flatten the modified tree back into the parameter vector expected by psi.
params = jnp.concatenate([leaf.ravel() for leaf in jax.tree_util.tree_leaves(param_tree)])
psi.set_parameters(params)

# Set up sampler. ExactSampler presumably enumerates all inputDim**(L*L)
# configurations, which is only feasible for small lattices such as L = 2.
sampler = jVMC.sampler.ExactSampler(psi, sample_shape, lDim=inputDim, logProbFactor=logProbFactor)
# For larger lattices use Monte Carlo sampling of POVM outcomes instead
# (note: the PRNG key comes from ``jax.random``):
# sampler = jVMC.sampler.MCSampler(psi, sample_shape, jax.random.PRNGKey(123),
#                                  updateProposer=jVMC.sampler.propose_POVM_outcome,
#                                  numSamples=1000)


# Set up TDVP. The SVD tolerance, diagonal shift and makeReal settings are
# passed straight to jVMC's TDVP solver (see its documentation for meaning).
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=-1.,
                                   svdTol=1e-6, diagonalShift=0, makeReal='real')

# ODE integrator: adaptive Heun scheme, presumably timeStep is the initial step
# size and tol the error tolerance for step-size control.
stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=1e-3, tol=1e-2)

# Observables recorded before every integration step: single-site expectation
# values and (presumably distance-1) correlators for each Pauli axis.
res = {"X": [], "Y": [], "Z": [], "X_corr_L1": [],
       "Y_corr_L1": [], "Z_corr_L1": []}

times = []
t = 0
t_final = 5.0  # total evolution time (in units of 1/J)

while t < t_final:
    times.append(t)
    result = jVMC.operator.povm.measure_povm(Lindbladian.povm, sampler)
    # Use ``axis`` rather than ``dim`` so the lattice dimension ``dim`` ("2D")
    # defined above is not overwritten by the loop.
    for axis in ["X", "Y", "Z"]:
        res[axis].append(result[axis]["mean"])
        res[axis + "_corr_L1"].append(result[axis + "_corr_L1"]["mean"])

    # One adaptive step. The current time ``t`` is passed (not a fixed 0) so a
    # time-dependent generator would be evaluated at the right time; step-size
    # control uses the norm induced by tdvpEquation.S_dot.
    dp, dt = stepper.step(t, tdvpEquation, psi.get_parameters(), hamiltonian=Lindbladian, psi=psi,
                          normFunction=partial(norm_fun, df=tdvpEquation.S_dot))
    t += dt
    psi.set_parameters(dp)
    print(f"t = {t:.3f}, \t dt = {dt:.2e}")


# Plot the recorded observables against time, then save and show the figure.
single_site_curves = (
    ("X", r"$\langle X \rangle$"),
    ("Y", r"$\langle Y \rangle$"),
    ("Z", r"$\langle Z \rangle$"),
)
for observable, curve_label in single_site_curves:
    plt.plot(times, res[observable], label=curve_label)
plt.plot(times, res["Z_corr_L1"], label=r"$\langle Z_iZ_{i+1} \rangle$", linestyle="--")
plt.xlabel(r"$Jt$")
plt.legend()
plt.grid()
plt.savefig('Lindblad_evolution.pdf')
plt.show()