.. _Parallelism:

Parallelism
-----------

The performance of the code relies on a few design choices, which enable efficient computation for typical use cases on a desktop device as well as on distributed multi-GPU clusters. An important manifestation of these choices is the array shape required when interfacing with `jVMC`: All data related to network evaluations has two leading dimensions, namely the `device dimension` and the `batch dimension`.

**Example:** Assume we are using a computer that has a single GPU accelerator and are working on a one-dimensional system of size :math:`L=7`. We have defined a suitable ``MCMCSampler`` object called ``sampler`` and an ``NQS`` object called ``psi``. Then we can inspect the shape of the configuration array generated by the sampler when asking for ``10`` configurations as follows:

>>> configs, _, _ = sampler.sample(psi, 10)
>>> configs.shape
(1, 10, 7)

The first dimension is the `device dimension`, which takes the value ``1`` because we have a single GPU. The second dimension is the `batch dimension`, which equals the number of samples we asked for. The remaining dimensions correspond to the physical system, in this case :math:`L=7`.

Now, let's evaluate the wave function on this batch of configurations:

>>> logPsi = psi(configs)
>>> logPsi.shape
(1, 10)

We obtain one wave function coefficient per configuration (the size of the `batch dimension`) and keep the leading `device dimension`.

Finally, we might need gradients:

>>> g = psi.gradients(configs)
>>> g.shape
(1, 10, 123)

Again, we see the known leading dimensions. The size of the last dimension shows that our variational wave function ``psi`` has 123 parameters.

The following sections explain this choice in more detail.

Intrinsic parallelism
^^^^^^^^^^^^^^^^^^^^^

In variational Monte Carlo, multiple levels of parallelism can be exploited.
First, Monte Carlo sampling is an "embarrassingly parallel" task that is straightforward to distribute among multiple processes. `jVMC` employs MPI to parallelize Monte Carlo sampling, which makes it possible to use multiple cores locally or many nodes of a cluster. At the lower level, the algorithm requires independent wave function evaluations on large numbers of computational basis states. This task is well suited for Single-Instruction Multiple-Data (SIMD) parallelization schemes, all the more so when the evaluation involves suitable operations such as matrix-vector products in the case of NQS. Therefore, `jVMC` batches operations wherever possible, relying on JAX just-in-time compilation to generate efficient code that exploits the local computing resources.

Multiple accelerators per node
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are different ways to deal with computers that provide multiple accelerators per compute node, namely:

(i) Launch one MPI process per accelerator.
(ii) Distribute computation across the available devices while working with a single process.

`jVMC` supports both options. The ``jVMC.global_defs.set_pmap_devices()`` function lets the user choose, at the beginning of a program, which subset of the available devices each MPI process works with.

For a homogeneous treatment of both choices, any data arrays passed to or obtained from the `jVMC` API have an additional leading `device dimension` to account for potential parallelization across devices. The size of this dimension corresponds to the number of devices used by the process, and any computation is distributed among these devices.

It is important to keep in mind that, when working with multiple devices, the `device dimension` is also `physically` distributed across the different devices. Hence, any computation on data with a `device dimension` larger than one should be performed on the respective devices to avoid memory transfer overheads.
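
The shape convention described above can be illustrated with plain JAX, independently of `jVMC`. The following is only a minimal sketch: ``log_amplitude`` is a hypothetical stand-in for a network evaluation on a single configuration and is not part of the `jVMC` API. ``jax.vmap`` vectorizes over the `batch dimension`, and ``jax.pmap`` distributes the leading `device dimension` across the local devices.

.. code-block:: python

    import jax
    import jax.numpy as jnp

    L = 7            # system size, as in the example at the top of this page
    batch_size = 10  # number of configurations per device

    # Number of devices visible to this process (1 on a single-GPU or CPU-only machine)
    n_dev = jax.local_device_count()

    def log_amplitude(config):
        # Hypothetical stand-in for a network evaluation on ONE configuration
        return -0.5 * jnp.sum(config.astype(jnp.float32))

    # vmap maps over the batch dimension, pmap over the device dimension
    evaluate = jax.pmap(jax.vmap(log_amplitude))

    # Input with leading (device, batch) dimensions followed by the system dimension
    configs = jnp.ones((n_dev, batch_size, L), dtype=jnp.int32)
    log_psi = evaluate(configs)
    print(log_psi.shape)  # one value per configuration: (n_dev, batch_size)

The output of ``jax.pmap`` remains on the devices that computed it, which is why follow-up computations are best expressed as further mapped functions instead of gathering the data on a single device first.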

Batching
^^^^^^^^

To guarantee high arithmetic intensity, suitable batching of computational tasks is crucial. This applies in particular to wave function evaluations. Therefore, any operation implemented in `jVMC` is performed on a batch of input data. This means that, following the leading `device dimension`, any data arrays passed to or obtained from the `jVMC` API have an additional `batch dimension`.

MPI
^^^

If you are running your simulation with multiple MPI processes, `jVMC` automatically parallelizes the Monte Carlo sampling tasks across them. If your machine has multiple GPUs attached to each node and all of them are visible, you may want to assign one GPU per MPI process. This is achieved by the following line:

.. code-block:: python

    jVMC.global_defs.set_pmap_devices(jax.devices()[jVMC.mpi_wrapper.rank % jax.device_count()])
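
The index arithmetic in the line above amounts to a round-robin assignment of MPI ranks to visible devices. It can be checked without MPI or GPUs using the small sketch below, where ``device_index`` is a hypothetical helper introduced only for illustration and the process and device counts are made-up values.

.. code-block:: python

    def device_index(rank, n_devices):
        """Round-robin assignment of an MPI rank to a visible device index."""
        return rank % n_devices

    # Hypothetical setup: 8 MPI processes and 4 visible GPUs
    assignment = [device_index(rank, 4) for rank in range(8)]
    print(assignment)  # [0, 1, 2, 3, 0, 1, 2, 3]

In this sketch, launching more processes than there are visible devices makes several ranks share a device, so launching exactly one process per GPU gives a one-to-one assignment.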