TDVP for unitary time evolution

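This example shows a basic implementation of unitary (real-time) evolution of a variational wave function, obtained by integrating the equations of motion that follow from the time-dependent variational principle (TDVP).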

 1import os
 2
 3import jax
 4jax.config.update("jax_enable_x64", True)
 5
 6import numpy as np
 7
 8import time
 9
10import jVMC
11from jVMC.util import measure
12import jVMC.operator as op
13
14import matplotlib.pyplot as plt
15
16
17L = 6
18g = -0.7
19h = 0.1
20
21dt = 1e-3  # Initial time step
22integratorTol = 1e-4  # Adaptive integrator tolerance
23tmax = 2  # Final time
24
25# Set up variational wave function
26net = jVMC.nets.CpxRBM(numHidden=10, bias=True)
27
28psi = jVMC.vqs.NQS(net, seed=1234)  # Variational wave function
29
30# Set up hamiltonian
31hamiltonian = jVMC.operator.BranchFreeOperator()
32for l in range(L):
33    hamiltonian.add(op.scal_opstr(-1., (op.Sz(l), op.Sz((l + 1) % L))))
34    hamiltonian.add(op.scal_opstr(g, (op.Sx(l), )))
35    hamiltonian.add(op.scal_opstr(h, (op.Sz(l),)))
36
37# Set up observables
38observables = {
39    "energy": hamiltonian,
40    "X": jVMC.operator.BranchFreeOperator(),
41}
42for l in range(L):
43    observables["X"].add(op.scal_opstr(1. / L, (op.Sx(l), )))
44
45sampler = None
46# Set up exact sampler
47sampler = jVMC.sampler.ExactSampler(psi, L)
48
49# Set up TDVP
50tdvpEquation = jVMC.util.tdvp.TDVP(sampler, pinvTol=1e-8,
51                                   rhsPrefactor=1.j,
52                                   makeReal='imag')
53
54t = 0.0  # Initial time
55
56# Set up stepper
57stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=dt, tol=integratorTol)
58
59# Measure initial observables
60obs = measure(observables, psi, sampler)
61data = []
62data.append([t, obs["energy"]["mean"][0], obs["X"]["mean"][0]])
63
64plt.ion()
65plt.xlim(0, tmax)
66plt.ylim(0, 1)
67plt.legend()
68plt.ylabel(r"Transverse magnetization $\langle X\rangle$")
69plt.xlabel(r"Time $\langle Jt\rangle$")
70
71while t < tmax:
72    tic = time.perf_counter()
73    print(">  t = %f\n" % (t))
74
75    # TDVP step
76    dp, dt = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi)
77    psi.set_parameters(dp)
78    t += dt
79
80    # Measure observables
81    obs = measure(observables, psi, sampler)
82    data.append([t, obs["energy"]["mean"][0], obs["X"]["mean"][0]])
83
84    # Write some meta info to screen
85    print("   Time step size: dt = %f" % (dt))
86    tdvpErr, tdvpRes = tdvpEquation.get_residuals()
87    print("   Residuals: tdvp_err = %.2e, solver_res = %.2e" % (tdvpErr, tdvpRes))
88    print("    Energy = %f +/- %f" % (obs["energy"]["mean"], obs["energy"]["MC_error"]))
89    toc = time.perf_counter()
90    print("   == Total time for this step: %fs\n" % (toc - tic))
91
92    # Plot data
93    npdata = np.array(data)
94    plt.plot(npdata[:, 0], npdata[:, 2], c="red")
95    plt.pause(0.05)