Benchmark example
This example shows how the benchmark data were obtained. You can run the same script to collect benchmarking data on your own cluster.
import os

import jax
# Switch JAX to 64-bit precision right after importing it, before any other
# JAX-based module is imported or any array is created.
jax.config.update("jax_enable_x64", True)

import jax.random as random
import jax.numpy as jnp
import numpy as np

import jVMC
from jVMC.util.symmetries import LatticeSymmetry

L = 50    # number of lattice sites; operators below act on sites 0 .. L-1
g = -0.7  # prefactor of the single-site Sx terms in the Hamiltonian

# Initialize net
# net = jVMC.nets.CpxCNN(F=[15,], channels=[100], bias=False)
# The orbit is a stack of L matrices: the L x L identity cyclically shifted
# by l = 0 .. L-1 columns.
orbit = LatticeSymmetry(
    jnp.array([jnp.roll(jnp.identity(L, dtype=np.int32), l, axis=1) for l in range(L)])
)
net = jVMC.nets.RNNsym(orbit=orbit, hiddenSize=15, L=L, depth=5)

psi = jVMC.vqs.NQS(net, batchSize=500, seed=1234)  # Variational wave function
print(f"The variational ansatz has {psi.numParameters} parameters.")

# Set up hamiltonian
# For every site l: an Sz(l) Sz(l+1) term with prefactor -1 (the index wraps
# around at the end of the chain via "% L") and an Sx(l) term with prefactor g.
hamiltonian = jVMC.operator.BranchFreeOperator()
for l in range(L):
    hamiltonian.add(jVMC.operator.scal_opstr(-1., (jVMC.operator.Sz(l), jVMC.operator.Sz((l + 1) % L))))
    hamiltonian.add(jVMC.operator.scal_opstr(g, (jVMC.operator.Sx(l), )))

# Set up sampler
sampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(4321), updateProposer=jVMC.sampler.propose_spin_flip_Z2,
                                 numChains=50, sweepSteps=L,
                                 numSamples=300000, thermalizationSweeps=0)

# Set up TDVP
tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=1.,
                                   svdTol=1e-8, diagonalShift=10, makeReal='real')

stepper = jVMC.util.stepper.Euler(timeStep=1e-2)  # ODE integrator

# Set up OutputManager
wdir = "./benchmarks/"
if jVMC.mpi_wrapper.rank == 0:
    # Only rank 0 creates the output directory. exist_ok=True keeps a re-run
    # with an existing directory from being reported as a failure; the success
    # message is therefore also printed when the directory was already there.
    try:
        os.makedirs(wdir, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % wdir)
    else:
        print("Successfully created the directory %s " % wdir)
# Derive the file path from wdir so directory and file cannot drift apart
# (evaluates to "./benchmarks/data.hdf5").
outp = jVMC.util.OutputManager(os.path.join(wdir, "data.hdf5"), append=False)

res = []  # NOTE(review): not used anywhere in the visible listing
for n in range(3):

    # One integrator step starting at time 0; the returned parameters are
    # written back into the variational state.
    dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None, outp=outp)
    psi.set_parameters(dp)

    # Report the timings collected by the OutputManager: per key the running
    # average and the latest value, then the sum of the latest values.
    print("Benchmarking data")
    total = 0
    for key, value in outp.timings.items():
        print("\taverage and latest timings of ", key)
        print("\t", value["total"] / value["count"])
        print("\t", value["newest"])
        total += value["newest"]
    print("\t=== Total: ", total)