jVMC: Versatile and performant variational Monte Carlo
This package, available on GitHub, provides a versatile and efficient implementation of variational quantum Monte Carlo in Python. It builds on Google's JAX library to benefit from automatic differentiation and just-in-time compilation on the available computing resources. The package is designed as a transparent implementation of the core computational tasks that ensures efficiency while at the same time offering great flexibility.
In particular, jVMC provides a framework for working with arbitrary variational wave functions and quantum operators. The code was written mainly with neural quantum states (NQS) as variational wave functions in mind, but it is not restricted to them; the ansatz wave functions can be arbitrary parametrized programs. Nonetheless, throughout the documentation we refer to the variational ansätze as “networks”.
Design choices
Variational wave functions
A core part of this codebase is the NQS class, an abstract wrapper class for variational wave functions, which provides an interface that other parts of the code rely on. At initialization, the specific variational wave function is passed to NQS in the form of a Flax module. Flax is a library that supplements JAX with a class structure, enabling a simple implementation of neural networks (and more) based on modules, as familiar from PyTorch.
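To illustrate the idea of a uniform wrapper around an arbitrary parametrized program, here is a minimal pure-Python sketch. It is not jVMC code: `ToyNetwork` and `ToyNQS` are hypothetical names, the ansatz is an arbitrary choice with real log-amplitudes, and in jVMC the network would instead be a Flax module wrapped by `NQS`.

```python
from typing import Callable, List, Sequence

Config = Sequence[int]  # spin configuration with entries 0 or 1


class ToyNetwork:
    """Hypothetical ansatz: log psi(s) = sum_i a_i z_i + b * sum_i z_i z_{i+1}, with z_i = 2 s_i - 1."""

    def __init__(self, a: List[float], b: float):
        self.a = a  # single-site parameters
        self.b = b  # nearest-neighbour coupling parameter

    def __call__(self, s: Config) -> float:
        z = [2 * si - 1 for si in s]  # map {0, 1} -> {-1, +1}
        field = sum(ai * zi for ai, zi in zip(self.a, z))
        coupling = sum(z[i] * z[i + 1] for i in range(len(z) - 1))
        return field + self.b * coupling  # real log-amplitude


class ToyNQS:
    """Hypothetical wrapper: evaluates any callable network on a batch of configurations."""

    def __init__(self, net: Callable[[Config], float]):
        self.net = net

    def __call__(self, batch: List[Config]) -> List[float]:
        return [self.net(s) for s in batch]


psi = ToyNQS(ToyNetwork(a=[0.1, -0.2, 0.3], b=0.5))
print(psi([[0, 0, 0], [1, 0, 1]]))
```

Code that only needs log-amplitudes of a batch (sampling, local estimators) can then work with `ToyNQS` without knowing anything about the ansatz inside it.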
Parallelism
The performance of the code relies on a few design choices that enable efficient computation for typical use cases on a desktop machine as well as on distributed multi-GPU clusters. An important consequence of these choices is the required array layout when interfacing with jVMC: all data related to network evaluations has two leading dimensions, namely the device dimension and the batch dimension. Distributed computing is enabled by MPI through the mpi4py package. See Parallelism for details.
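The two leading dimensions can be illustrated without any jVMC code. The sketch below uses a hypothetical helper `to_device_batch` to split a flat list of configurations into a nested structure of shape (devices, batch, sites), assuming for simplicity that the number of samples divides evenly across devices.

```python
from typing import List


def to_device_batch(samples: List[List[int]], n_devices: int) -> List[List[List[int]]]:
    """Split a flat list of configurations into shape (n_devices, batch_size, L)."""
    if len(samples) % n_devices != 0:
        raise ValueError("number of samples must be divisible by n_devices")
    batch_size = len(samples) // n_devices
    return [samples[d * batch_size:(d + 1) * batch_size] for d in range(n_devices)]


def shape(x) -> tuple:
    """Shape of a regular nested list, analogous to array.shape."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


# 8 configurations of L = 3 spins, distributed over 2 devices
flat = [[i % 2, (i // 2) % 2, (i // 4) % 2] for i in range(8)]
data = to_device_batch(flat, n_devices=2)
print(shape(data))  # (2, 4, 3)
```

With JAX arrays the same split amounts to a single `reshape(n_devices, -1, L)` of a flat `(n_samples, L)` array.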
Example
The core task in Variational Monte Carlo is sampling from the Born distribution \(|\psi_\theta(s)|^2\) of a variational wave function \(\psi_\theta(s)\) and computing the mean of local estimators
\(O_{loc}(s)=\sum_{s'}O_{s,s'}\frac{\psi_\theta(s')}{\psi_\theta(s)}\) ,
where \(O_{s,s'}=\langle s|\hat O|s'\rangle\) are the matrix elements of some quantum operator \(\hat O\) and the \(|s\rangle\) denote computational basis states.
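As a sanity check of this formula, the local estimator can be evaluated exactly on a tiny system. The plain-Python sketch below (not jVMC code) takes two spins, a hypothetical unnormalized real wave function `psi`, and the operator \(\hat O=\hat\sigma^x_1\hat\sigma^x_2+\hat\sigma^z_1\hat\sigma^z_2\), chosen only for illustration. It verifies that averaging \(O_{loc}(s)\) over the Born distribution reproduces \(\langle\psi|\hat O|\psi\rangle/\langle\psi|\psi\rangle\).

```python
import itertools

# Computational basis of two spins: s = (s1, s2) with s_i in {+1, -1} (sigma^z eigenvalues)
basis = list(itertools.product([1, -1], repeat=2))

# Hypothetical unnormalized wave function (real and nonzero for simplicity)
psi = {(1, 1): 0.8, (1, -1): 0.3, (-1, 1): -0.5, (-1, -1): 0.1}


def matrix_element(s, sp):
    """<s| X1 X2 + Z1 Z2 |s'>: X1 X2 flips both spins, Z1 Z2 is diagonal."""
    value = 0.0
    if sp == (-s[0], -s[1]):
        value += 1.0
    if sp == s:
        value += s[0] * s[1]
    return value


def o_loc(s):
    """O_loc(s) = sum_{s'} O_{s,s'} psi(s') / psi(s)."""
    return sum(matrix_element(s, sp) * psi[sp] / psi[s] for sp in basis)


norm = sum(psi[s] ** 2 for s in basis)
born = {s: psi[s] ** 2 / norm for s in basis}  # |psi(s)|^2 / <psi|psi>

estimate = sum(born[s] * o_loc(s) for s in basis)  # Born-weighted mean of O_loc
exact = sum(psi[s] * matrix_element(s, sp) * psi[sp] for s in basis for sp in basis) / norm
print(estimate, exact)
```

In a Monte Carlo simulation the exact Born average is replaced by the sample mean of \(O_{loc}\) over configurations drawn from \(|\psi_\theta(s)|^2\), so the normalization of \(\psi_\theta\) is never needed.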
Assume that op is an Operator object (see Operator module) corresponding to \(\hat O\) and psi is an NQS object implementing the variational wave function (see Variational quantum state module). Moreover, assume that sampler is a suitable Sampler object (see Sampler module). Then, estimating an operator expectation value with the jVMC framework boils down to:
```python
import jVMC

# op, psi and sampler are assumed to be set up as described above
s, logPsi, _ = sampler.sample()                 # Get samples (parallelized)
sPrime, _ = op.get_s_primes(s)                  # Get s', where O_{s,s'}!=0
logPsiOffd = psi(sPrime)                        # Evaluate wave function on s'
Oloc = op.get_O_loc(logPsi, logPsiOffd)         # Compute local estimator
Omean = jVMC.mpi_wrapper.get_global_mean(Oloc)  # Compute mean of all processes
```
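Conceptually, the local estimator follows from the log-amplitudes as \(O_{loc}(s)=\sum_{s'}O_{s,s'}\exp[\log\psi_\theta(s')-\log\psi_\theta(s)]\), and the global mean pools the samples of all processes. The sketch below reproduces these two steps in plain Python with hypothetical helpers `o_loc_from_logs` and `global_mean`. It is not jVMC's implementation, which works on batched JAX arrays; here the MPI communication is replaced by a list of per-process results.

```python
import cmath
import math
from typing import List, Sequence


def o_loc_from_logs(log_psi_s: complex, log_psi_sp: Sequence[complex],
                    mat_els: Sequence[complex]) -> complex:
    """O_loc(s) = sum_{s'} O_{s,s'} exp(log psi(s') - log psi(s))."""
    return sum(m * cmath.exp(lp - log_psi_s) for m, lp in zip(mat_els, log_psi_sp))


def global_mean(per_process: List[List[complex]]) -> complex:
    """Mean over all samples of all processes (stand-in for an MPI reduction of sums and counts)."""
    total = sum(sum(local) for local in per_process)
    count = sum(len(local) for local in per_process)
    return total / count


# One sample s with psi(s) = 2 and three connected entries s' (hypothetical numbers):
# the diagonal term, one off-diagonal term with psi(s') = 1, and a padded entry
# whose matrix element is zero.
oloc = o_loc_from_logs(math.log(2.0), [math.log(2.0), math.log(1.0), 0.0], [1.0, -0.5, 0.0])
print(oloc)  # approximately 0.75 = 1 * 1 - 0.5 * (1 / 2)

print(global_mean([[1.0, 3.0], [2.0]]))  # 2.0
```

Working with differences of log-amplitudes keeps the ratios \(\psi_\theta(s')/\psi_\theta(s)\) finite even when the individual amplitudes would under- or overflow, which is one reason the network returns \(\log\psi_\theta\) rather than \(\psi_\theta\).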
Computing the variational derivatives \(\partial_{\theta_k}\log\psi_\theta(s)\) is also straightforward with the NQS class:
```python
grad_psi = psi.gradients(s)  # derivatives of log(psi) w.r.t. all parameters, for each sample
```
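For a small ansatz the derivatives \(\partial_{\theta_k}\log\psi_\theta(s)\) can be written down by hand, which gives a useful cross-check. The sketch below assumes a hypothetical restricted Boltzmann machine with a single hidden unit, \(\log\psi_\theta(s)=\sum_k a_k s_k+\log\cosh(b+\sum_k w_k s_k)\) with \(s_k\in\{\pm1\}\), and compares its analytic gradient with central finite differences. It is plain Python; jVMC obtains these derivatives by automatic differentiation in JAX instead.

```python
import math
from typing import List, Sequence


def log_psi(theta: Sequence[float], s: Sequence[int]) -> float:
    """Hypothetical one-hidden-unit RBM; theta = (a_1..a_L, b, w_1..w_L)."""
    L = len(s)
    a, b, w = theta[:L], theta[L], theta[L + 1:]
    x = b + sum(wk * sk for wk, sk in zip(w, s))
    return sum(ak * sk for ak, sk in zip(a, s)) + math.log(math.cosh(x))


def grad_exact(theta: Sequence[float], s: Sequence[int]) -> List[float]:
    """Analytic gradient: d/da_k = s_k, d/db = tanh(x), d/dw_k = s_k * tanh(x)."""
    L = len(s)
    b, w = theta[L], theta[L + 1:]
    t = math.tanh(b + sum(wk * sk for wk, sk in zip(w, s)))
    return [float(sk) for sk in s] + [t] + [sk * t for sk in s]


def grad_fd(theta: Sequence[float], s: Sequence[int], eps: float = 1e-5) -> List[float]:
    """Central finite differences of log psi with respect to each parameter."""
    grads = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += eps
        down[k] -= eps
        grads.append((log_psi(up, s) - log_psi(down, s)) / (2 * eps))
    return grads


theta = [0.1, -0.2, 0.3, 0.05, 0.4, -0.6, 0.2]  # a = (0.1, -0.2, 0.3), b = 0.05, w = (0.4, -0.6, 0.2)
s = [1, -1, 1]
print(grad_exact(theta, s))
print(grad_fd(theta, s))  # agrees with the analytic gradient to high precision
```

The same comparison against finite differences is a cheap way to test a custom ansatz before relying on its gradients in an optimization.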
See the Example applications section for more elaborate applications.