Benchmark example

This example shows how the benchmark data were obtained. You can also run it to collect benchmarking data on your own cluster.

 1import os
 2
 3import jax
 4jax.config.update("jax_enable_x64", True)
 5
 6import jax.random as random
 7import jax.numpy as jnp
 8import numpy as np
 9
10import jVMC
11from jVMC.util.symmetries import LatticeSymmetry
12
13L = 50
14g = -0.7
15
16# Initialize net
17# net = jVMC.nets.CpxCNN(F=[15,], channels=[100], bias=False)
18orbit = LatticeSymmetry(jnp.array([jnp.roll(jnp.identity(L, dtype=np.int32), l, axis=1) for l in range(L)]))
19net = jVMC.nets.RNNsym(orbit=orbit, hiddenSize=15, L=L, depth=5)
20
21psi = jVMC.vqs.NQS(net, batchSize=500, seed=1234)  # Variational wave function
22print(f"The variational ansatz has {psi.numParameters} parameters.")
23
24# Set up hamiltonian
25hamiltonian = jVMC.operator.BranchFreeOperator()
26for l in range(L):
27    hamiltonian.add(jVMC.operator.scal_opstr(-1., (jVMC.operator.Sz(l), jVMC.operator.Sz((l + 1) % L))))
28    hamiltonian.add(jVMC.operator.scal_opstr(g, (jVMC.operator.Sx(l), )))
29
30# Set up sampler
31sampler = jVMC.sampler.MCSampler(psi, (L,), random.PRNGKey(4321), updateProposer=jVMC.sampler.propose_spin_flip_Z2,
32                                 numChains=50, sweepSteps=L,
33                                 numSamples=300000, thermalizationSweeps=0)
34
35# Set up TDVP
36tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=1.,
37                                   svdTol=1e-8, diagonalShift=10, makeReal='real')
38
39stepper = jVMC.util.stepper.Euler(timeStep=1e-2)  # ODE integrator
40
41# Set up OutputManager
42wdir = "./benchmarks/"
43if jVMC.mpi_wrapper.rank == 0:
44    try:
45        os.makedirs(wdir)
46    except OSError:
47        print("Creation of the directory %s failed" % wdir)
48    else:
49        print("Successfully created the directory %s " % wdir)
50outp = jVMC.util.OutputManager("./benchmarks/data.hdf5", append=False)
51
52res = []
53for n in range(3):
54
55    dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None, outp=outp)
56    psi.set_parameters(dp)
57
58    print("Benchmarking data")
59    total = 0
60    for key, value in outp.timings.items():
61        print("\taverage and latest timings of ", key)
62        print("\t", value["total"] / value["count"])
63        print("\t", value["newest"])
64        total += value["newest"]
65    print("\t=== Total: ", total)