Parallelism

The performance of the code relies on a few design choices that enable efficient computation both for typical use cases on a desktop machine and on distributed multi-GPU clusters. An important consequence of these choices is the required array layout when interfacing with jVMC: all data related to network evaluations has two leading dimensions, namely the device dimension and the batch dimension.

Example:

Assume we are using a computer with a single GPU accelerator and working on a one-dimensional system of size \(L=7\). We have defined a suitable MCMCSampler object called sampler and an NQS object called psi. Then we can inspect the shape of the configuration array that the sampler generates when we ask for 10 configurations:

>>> configs, _, _ = sampler.sample(psi, 10)
>>> configs.shape
(1, 10, 7)

The first dimension is the device dimension, which has size 1 because we use a single GPU. The second dimension is the batch dimension, which equals the number of samples we asked for. The remaining dimensions describe the physical system, in this case a one-dimensional system with \(L=7\) sites.

Now, let’s evaluate the wave function on this batch of configurations:

>>> logPsi = psi(configs)
>>> logPsi.shape
(1, 10)

We obtain one value per configuration, namely the logarithm of the corresponding wave function coefficient, so the batch dimension is preserved along with the leading device dimension.
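To see why the evaluation keeps both leading dimensions, consider a toy stand-in for an NQS. The following minimal sketch uses NumPy instead of jVMC; the function toy_log_psi and its linear form \(\log\psi(s)=\sum_i w_i s_i\) are hypothetical assumptions chosen for illustration only. Because only the trailing system axis is contracted, an input of shape (device, batch, L) yields an output of shape (device, batch).

```python
import numpy as np

L = 7  # system size, as in the example above


def toy_log_psi(weights, configs):
    """Hypothetical log-amplitude log psi(s) = sum_i w_i s_i.

    Only the trailing system axis is contracted, so any leading
    (device, batch) axes of `configs` are kept in the result.
    """
    return np.einsum("...i,i->...", configs, weights)


rng = np.random.default_rng(seed=0)
weights = rng.normal(size=L)
# Spin configurations with entries in {-1, +1}, laid out as (device, batch, L)
configs = rng.choice([-1.0, 1.0], size=(1, 10, L))

log_psi = toy_log_psi(weights, configs)
print(log_psi.shape)  # (1, 10)
```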

Finally, we might need gradients:

>>> g = psi.gradients(configs)
>>> g.shape
(1, 10, 123)

Again, the two leading dimensions appear. The size of the last dimension tells us that our variational wave function psi has 123 parameters.
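The gradient array has the same structure, with one extra trailing axis for the parameters. As a minimal sketch under stated assumptions, the hypothetical NumPy model below has a bias \(b\) and weights \(w_i\), so that \(\log\psi(s)=b+\sum_i w_i s_i\) with \(L+1=8\) parameters; its derivatives are \(1\) with respect to \(b\) and \(s_i\) with respect to \(w_i\). A JAX-based library would typically obtain such derivatives by automatic differentiation instead; the sketch only illustrates the resulting shape.

```python
import numpy as np

L = 7


def toy_log_psi(params, configs):
    # Hypothetical model: params[0] is a bias b, params[1:] are weights w_i
    return params[0] + np.einsum("...i,i->...", configs, params[1:])


def toy_gradients(params, configs):
    """Analytic derivatives of toy_log_psi with respect to each parameter.

    d log psi / d b   = 1
    d log psi / d w_i = s_i
    `params` is unused because the model is linear in its parameters.
    The result has shape (device, batch, number_of_parameters).
    """
    ones = np.ones(configs.shape[:-1] + (1,))
    return np.concatenate([ones, configs], axis=-1)


rng = np.random.default_rng(seed=1)
params = rng.normal(size=L + 1)  # 8 parameters in total
configs = rng.choice([-1.0, 1.0], size=(1, 10, L))

g = toy_gradients(params, configs)
print(g.shape)  # (1, 10, 8)
```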

The following sections explain this design choice in more detail.

Intrinsic parallelism

Variational Monte Carlo offers multiple levels of parallelism that can be exploited.

First, Monte Carlo sampling is an “embarrassingly parallel” task that is straightforward to distribute among multiple processes. jVMC uses MPI to parallelize Monte Carlo sampling, which makes it possible to use multiple cores of a single machine or many nodes of a cluster.
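The following plain-Python sketch illustrates one way such work could be divided among MPI processes: each rank draws its share of the samples independently, and the per-rank averages are then combined with weights given by the sample counts, which is what a sum-reduction over all processes would compute. The helpers local_sample_count and combine_means are hypothetical; jVMC's actual distribution scheme may differ.

```python
def local_sample_count(total_samples, num_processes, rank):
    """Hypothetical helper: number of samples drawn by process `rank`.

    Splits total_samples as evenly as possible; the first
    total_samples % num_processes ranks draw one extra sample.
    """
    base, remainder = divmod(total_samples, num_processes)
    return base + (1 if rank < remainder else 0)


def combine_means(local_means, local_counts):
    """Hypothetical helper: merge per-process sample means into a global mean.

    Each local mean is weighted by its sample count, which matches summing
    mean * count and count over all processes (e.g. via an MPI reduction)
    and dividing the two sums.
    """
    total = sum(local_counts)
    return sum(m * n for m, n in zip(local_means, local_counts)) / total


counts = [local_sample_count(10, 4, rank) for rank in range(4)]
print(counts)  # [3, 3, 2, 2]
print(combine_means([1.0, 1.0, 4.0, 4.0], counts))  # 2.2
```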

At the lower level, the algorithm requires independent wave function evaluations on large numbers of computational basis states. This task is well suited for Single-Instruction Multiple-Data (SIMD) parallelization schemes, even more so when the evaluation consists of operations like matrix-vector products, as is the case for NQS. Therefore, jVMC batches operations wherever possible and relies on JAX just-in-time compilation to generate efficient code that exploits the local computing resources.
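The benefit of batching can be sketched with a simplified RBM-style log-amplitude, \(\log\psi(s)=\sum_j \log\cosh\big(b_j+\sum_i W_{ji}s_i\big)\), chosen here as an assumed example ansatz. Evaluating configurations one by one requires one matrix-vector product each, whereas the batched version performs a single matrix-matrix product for the whole batch. The sketch uses NumPy as a stand-in; with JAX, the batched function could additionally be compiled with jax.jit.

```python
import numpy as np


def log_psi_single(W, b, s):
    """Simplified RBM-style log-amplitude for one configuration s of shape (L,)."""
    return np.sum(np.log(np.cosh(b + W @ s)))


def log_psi_batched(W, b, configs):
    """Same quantity for a whole batch of shape (batch, L).

    One matrix-matrix product replaces `batch` matrix-vector products.
    """
    theta = configs @ W.T + b                        # shape (batch, hidden)
    return np.sum(np.log(np.cosh(theta)), axis=-1)  # shape (batch,)


rng = np.random.default_rng(seed=2)
L, hidden, batch = 7, 14, 1000
W = 0.1 * rng.normal(size=(hidden, L))
b = 0.1 * rng.normal(size=hidden)
configs = rng.choice([-1.0, 1.0], size=(batch, L))

looped = np.array([log_psi_single(W, b, s) for s in configs])
batched = log_psi_batched(W, b, configs)
print(np.allclose(looped, batched))  # True
```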

Multiple accelerators per node

There are two ways to make use of multiple accelerators per compute node:

  1. Launch one MPI process per accelerator.

  2. Distribute computation across the available devices, while working with a single process.

jVMC supports both options. At the beginning of a program, the function jVMC.global_defs.set_pmap_devices() lets the user choose which subset of the available devices each MPI process works with. To treat both options uniformly, all data arrays passed to or obtained from the jVMC API have an additional leading device dimension that accounts for potential parallelization across devices. The size of this dimension equals the number of devices used by the process, and computations are distributed among these devices.
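The name set_pmap_devices hints at jax.pmap, which maps a function over the leading axis of its inputs with one slice per device, so a leading device dimension is a natural layout. The minimal NumPy sketch below shows how a flat batch could be split into such a layout and merged back; the helpers add_device_dimension and remove_device_dimension are hypothetical and not part of the jVMC API, and the sketch assumes the batch size is divisible by the number of devices.

```python
import numpy as np


def add_device_dimension(batch, num_devices):
    """Hypothetical helper: (batch, *system) -> (devices, batch // devices, *system)."""
    total = batch.shape[0]
    if total % num_devices != 0:
        raise ValueError(f"batch size {total} is not divisible by {num_devices} devices")
    return batch.reshape((num_devices, total // num_devices) + batch.shape[1:])


def remove_device_dimension(data):
    """Hypothetical inverse: merge the device and batch axes again."""
    return data.reshape((-1,) + data.shape[2:])


configs = np.arange(40 * 7).reshape(40, 7)  # 40 configurations of L = 7 sites
per_device = add_device_dimension(configs, num_devices=4)
print(per_device.shape)  # (4, 10, 7)

# Round trip: merging the axes restores the original flat batch
assert np.array_equal(remove_device_dimension(per_device), configs)
```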

Keep in mind that, when working with multiple devices, the data along the device dimension is physically distributed across the different devices. Hence, any computation on data with a device dimension larger than one should be performed on the respective devices to avoid the overhead of memory transfers.

Batching

Achieving high arithmetic intensity requires suitable batching of computational tasks. This applies in particular to wave function evaluations. Therefore, every operation implemented in jVMC is performed on a batch of input data. Consequently, all data arrays passed to or obtained from the jVMC API have a batch dimension directly following the leading device dimension.
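When a batch is too large to fit into device memory at once, a common compromise is to evaluate it in chunks of fixed size, so that each call remains batched while peak memory stays bounded. The sketch below assumes a generic batched function; evaluate_in_chunks and toy_fn are hypothetical names, and the NumPy code only illustrates the idea.

```python
import numpy as np


def evaluate_in_chunks(fn, configs, chunk_size):
    """Apply a batched function `fn` to `configs` (shape (batch, ...)) chunk by chunk.

    Each call to `fn` processes at most `chunk_size` configurations, which
    bounds peak memory while keeping every call batched. The per-chunk
    results are concatenated along the batch axis.
    """
    results = [fn(configs[start:start + chunk_size])
               for start in range(0, configs.shape[0], chunk_size)]
    return np.concatenate(results, axis=0)


def toy_fn(chunk):
    # Stand-in for a batched wave function evaluation: one number per configuration
    return chunk.sum(axis=-1)


configs = np.random.default_rng(seed=3).choice([-1.0, 1.0], size=(1003, 7))
out = evaluate_in_chunks(toy_fn, configs, chunk_size=256)
print(out.shape)  # (1003,)
```

With jax.jit, every new input shape triggers a recompilation, so in practice the final, shorter chunk is often padded to chunk_size; this detail is omitted here.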

MPI

If you run your simulation with multiple MPI processes, jVMC automatically parallelizes the Monte Carlo sampling across them. If each node of your machine has multiple GPUs attached and all of them are visible to every process, you may want to assign one GPU per MPI process. This is achieved by the following line:

jVMC.global_defs.set_pmap_devices(jax.devices()[jVMC.mpi_wrapper.rank % jax.device_count()])
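The modulo expression assigns devices in a round-robin fashion. The plain-Python sketch below evaluates the same index arithmetic for a hypothetical setup of 8 MPI processes on 2 nodes with 4 GPUs each; assigned_device_index is a hypothetical name. It assumes that the launcher places consecutive ranks on the same node and starts as many processes per node as there are GPUs; with a different placement, two processes on one node can end up sharing a GPU.

```python
def assigned_device_index(rank, device_count):
    """Hypothetical helper: device index chosen by rank % device_count."""
    return rank % device_count


# Consecutive placement: ranks 0-3 on node A, ranks 4-7 on node B
assignment = {rank: assigned_device_index(rank, 4) for rank in range(8)}
print(assignment)  # {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}

# Interleaved placement: even ranks on node A, odd ranks on node B
node_a = [assigned_device_index(rank, 4) for rank in range(0, 8, 2)]
print(node_a)  # [0, 2, 0, 2]: two processes per GPU, two GPUs idle
```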