TDVP for unitary time evolution
This example shows a basic implementation of unitary (real-time) evolution based on the time-dependent variational principle (TDVP).
import os

import jax
jax.config.update("jax_enable_x64", True)

import numpy as np

import time

import jVMC
from jVMC.util import measure
import jVMC.operator as op

import matplotlib.pyplot as plt


# Model parameters for a periodic chain of L spins:
# coupling -1 on neighbouring Sz Sz pairs, field g on Sx, field h on Sz
# (see the Hamiltonian assembled below).
L = 6
g = -0.7
h = 0.1

dt = 1e-3  # Initial time step (the adaptive stepper may change it)
integratorTol = 1e-4  # Adaptive integrator tolerance
tmax = 2  # Final time

# Set up variational wave function
net = jVMC.nets.CpxRBM(numHidden=10, bias=True)

psi = jVMC.vqs.NQS(net, seed=1234)  # Variational wave function

# Set up hamiltonian
hamiltonian = jVMC.operator.BranchFreeOperator()
for l in range(L):
    hamiltonian.add(op.scal_opstr(-1., (op.Sz(l), op.Sz((l + 1) % L))))
    hamiltonian.add(op.scal_opstr(g, (op.Sx(l), )))
    hamiltonian.add(op.scal_opstr(h, (op.Sz(l),)))

# Set up observables: the energy and the chain-averaged Sx
observables = {
    "energy": hamiltonian,
    "X": jVMC.operator.BranchFreeOperator(),
}
for l in range(L):
    observables["X"].add(op.scal_opstr(1. / L, (op.Sx(l), )))

# Set up exact sampler
# NOTE(review): presumably enumerates the full basis, so only viable for small L — confirm.
sampler = jVMC.sampler.ExactSampler(psi, L)

# Set up TDVP
# NOTE(review): rhsPrefactor=1j with makeReal='imag' is assumed to select
# real-time (unitary) evolution — confirm against the jVMC TDVP docs.
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, pinvTol=1e-8,
                                   rhsPrefactor=1.j,
                                   makeReal='imag')

t = 0.0  # Initial time

# Set up stepper
stepper = jVMC.util.stepper.AdaptiveHeun(timeStep=dt, tol=integratorTol)

# Measure initial observables; each row of `data` is [t, energy, X]
obs = measure(observables, psi, sampler)
data = []
data.append([t, obs["energy"]["mean"][0], obs["X"]["mean"][0]])

plt.ion()
plt.xlim(0, tmax)
plt.ylim(0, 1)
plt.ylabel(r"Transverse magnetization $\langle X\rangle$")
plt.xlabel(r"Time $Jt$")
# Create ONE line up front and update its data each step, instead of adding
# a new full-history line to the axes on every iteration.
line, = plt.plot([], [], c="red", label=r"$\langle X\rangle$")
# Legend is created after a labelled artist exists (otherwise it is empty).
plt.legend()

while t < tmax:
    tic = time.perf_counter()
    print("> t = %f\n" % (t))

    # TDVP step: returns the new parameters and the step size actually taken.
    # Pass the current time so time-dependent right-hand sides see the right t.
    dp, dt = stepper.step(t, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi)
    psi.set_parameters(dp)
    t += dt

    # Measure observables
    obs = measure(observables, psi, sampler)
    data.append([t, obs["energy"]["mean"][0], obs["X"]["mean"][0]])

    # Write some meta info to screen
    print(" Time step size: dt = %f" % (dt))
    tdvpErr, tdvpRes = tdvpEquation.get_residuals()
    print(" Residuals: tdvp_err = %.2e, solver_res = %.2e" % (tdvpErr, tdvpRes))
    # Index [0] like the data rows above, so %f receives a scalar rather than an array.
    print(" Energy = %f +/- %f" % (obs["energy"]["mean"][0], obs["energy"]["MC_error"][0]))
    toc = time.perf_counter()
    print(" == Total time for this step: %fs\n" % (toc - tic))

    # Plot data: refresh the existing line with the full history
    npdata = np.array(data)
    line.set_data(npdata[:, 0], npdata[:, 2])
    plt.pause(0.05)