jVMC
  • jVMC: Versatile and performant variational Monte Carlo

Design choices

  • Parallelism

API documentation

  • Variational quantum state module
  • Operator module
  • Sampler module
  • Basic neural network architectures
  • MPI wrapper module
  • Sample statistics module
  • Utilities

Examples

  • Example applications
    • Stochastic reconfiguration for ground state search
    • TDVP for unitary time evolution
    • Definition of custom network architectures
    • Benchmark example
    • TDVP for dissipative time evolution in 1D
    • TDVP for dissipative time evolution in 2D

Installation

  • Installation
jVMC
  • Example applications
  • Stochastic reconfiguration for ground state search
  • View page source

Stochastic reconfiguration for ground state search

This example shows a basic implementation of a ground-state search for the periodic transverse-field Ising chain using stochastic reconfiguration. You can try this example in Google Colab.

 1#!/usr/bin/env python
 2# coding: utf-8
 3
 4import jax
 5jax.config.update("jax_enable_x64", True)
 6
 7import jax.random as random
 8import jax.numpy as jnp
 9import flax.linen as nn
10
11import numpy as np
12import matplotlib.pyplot as plt
13
14import jVMC
15
# Number of spins in the chain and transverse-field coupling g.
L = 10
g = -0.7

# Check whether a GPU backend is available.
# jax.default_backend() is the public replacement for the deprecated
# jax.lib.xla_bridge.get_backend().platform query (removed in newer JAX).
GPU_avail = (jax.default_backend() == "gpu")

# Initialize net
if GPU_avail:
    # reproduces results in Fig. 3 of the paper
    # estimated run_time in colab (GPU enabled): ~26 minutes
    def myActFun(x):
        # Shifted ELU: 1 + elu(x) is used as the CNN activation.
        return 1 + nn.elu(x)
    net = jVMC.nets.CNN(F=(L,), channels=(16,), strides=(1,), periodicBoundary=True, actFun=(myActFun,))
    n_steps = 1000
    n_Samples = 40000
else:
    # may be used to obtain results on Laptop CPUs
    # estimated run_time: ~100 seconds
    net = jVMC.nets.CpxRBM(numHidden=8, bias=False)
    n_steps = 300
    n_Samples = 5000


psi = jVMC.vqs.NQS(net, seed=1234)  # Variational wave function
39
40
def energy_single_p_mode(h_t, P):
    """Return sqrt(1 + h_t**2 - 2*h_t*cos(P)), the energy of the mode with momentum P.

    Only NumPy ufuncs are used, so ``P`` (or ``h_t``) may be an array and the
    result is then computed element-wise.
    """
    cos_p = np.cos(P)
    radicand = 1 + h_t**2 - 2 * h_t * cos_p
    return np.sqrt(radicand)
43
44
def ground_state_energy_per_site(h_t, N):
    """Return the exact ground-state energy per site, -(1/N) * sum_P energy_single_p_mode(h_t, P).

    The sum runs over the N momenta P = (2*pi/N) * m with
    m = -(N-1)/2, -(N-1)/2 + 1, ..., (N-1)/2 (half-integers when N is even).

    NOTE(review): these half-integer momenta presumably select the even-parity
    sector of a periodic transverse-field Ising chain of even length; confirm
    before relying on this for odd N.

    Args:
        h_t: transverse-field coupling (scalar).
        N: number of sites.

    Returns:
        The energy per site as a NumPy float.
    """
    # m = -(N-1)/2, ..., (N-1)/2 in unit steps -> exactly N momenta.
    Ps = 0.5 * np.arange(- (N - 1), N - 1 + 2, 2)
    Ps = Ps * 2 * np.pi / N
    # energy_single_p_mode is built from NumPy ufuncs, so evaluate all modes at once
    # instead of looping in Python and re-wrapping the list in an array.
    energies_p_modes = energy_single_p_mode(h_t, Ps)
    return - 1 / N * np.sum(energies_p_modes)
50
51
# Exact reference energy per site for this chain length and field.
exact_energy = ground_state_energy_per_site(g, L)
print(exact_energy)

# Hamiltonian on a ring of L sites (site L-1 couples back to site 0):
#   H = -sum_j Sz_j Sz_{j+1} + g * sum_j Sx_j
hamiltonian = jVMC.operator.BranchFreeOperator()
for site in range(L):
    right_neighbor = (site + 1) % L
    coupling_term = jVMC.operator.scal_opstr(-1., (jVMC.operator.Sz(site), jVMC.operator.Sz(right_neighbor)))
    hamiltonian.add(coupling_term)
    field_term = jVMC.operator.scal_opstr(g, (jVMC.operator.Sx(site), ))
    hamiltonian.add(field_term)

# Monte Carlo sampler using the spin-flip proposal `propose_spin_flip_Z2`.
sampler = jVMC.sampler.MCSampler(
    psi,
    (L,),
    random.PRNGKey(4321),
    updateProposer=jVMC.sampler.propose_spin_flip_Z2,
    numChains=100,
    sweepSteps=L,
    numSamples=n_Samples,
    thermalizationSweeps=25,
)

# TDVP equation solved with a regularized (SVD cutoff + diagonal shift) linear system.
tdvpEquation = jVMC.util.tdvp.TDVP(
    sampler,
    rhsPrefactor=1.,
    svdTol=1e-8,
    diagonalShift=10,
    makeReal='real',
)

# Explicit Euler integrator with a fixed step size.
stepper = jVMC.util.stepper.Euler(timeStep=1e-2)
71
# Optimization loop: each iteration advances the variational parameters by one
# Euler step of the TDVP equation configured above and records energy/variance.
res = []
for n in range(n_steps):

    dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None)
    psi.set_parameters(dp)

    # Per-site energy and variance read back from the TDVP object after the step
    # (presumably estimated on the state used for this step -- confirm in jVMC docs).
    # Computed once and reused for both logging and the results table.
    energy_per_site = jnp.real(tdvpEquation.ElocMean0) / L
    variance_per_site = tdvpEquation.ElocVar0 / L

    print(n, energy_per_site, variance_per_site)

    res.append([n, energy_per_site, variance_per_site])

# Rows are [iteration, energy per site, variance per site].
res = np.array(res)
83
# Two stacked panels sharing the iteration axis:
# top = energy error per site, bottom = energy variance per site (both log scale).
fig, ax = plt.subplots(2, 1, sharex=True, figsize=[4.8, 4.8])

iterations = res[:, 0]

ax[0].semilogy(iterations, res[:, 1] - exact_energy, '-', label=f"$L={L}$")
ax[0].set_ylabel(r'$(E-E_0)/L$')
ax[0].legend()

ax[1].semilogy(iterations, res[:, 2], '-')
ax[1].set_ylabel(r'Var$(E)/L$')
ax[1].set_xlabel('iteration')

plt.tight_layout()
plt.savefig('gs_search.pdf')
Previous Next

© Copyright 2020, Markus Schmitt.

Built with Sphinx using a theme provided by Read the Docs.