MPI wrapper module

The jVMC.mpi_wrapper module wraps the MPI communication routines that are typically required in jVMC, using the mpi4py package. Most importantly, this covers the statistical evaluation of Monte Carlo samples across MPI processes. For this purpose, it can interact with the sampler classes from the jVMC.sampler module as follows:

Example:

Assume that sampler is an instance of a sampler class, psi is a variational quantum state, and op is an instance of a class derived from the Operator class, associated with a quantum operator \(\hat O\). Then we can get samples and the corresponding local estimator \(O_{loc}(\mathbf s)\) via

>>> s, logPsi, p = sampler.sample()
>>> sPrime, _ = op.get_s_primes(s)
>>> logPsiOffd = psi(sPrime)
>>> Oloc = op.get_O_loc(logPsi, logPsiOffd)

Now, on each MPI process, Oloc is a two-dimensional array of size (number of devices) \(\times\) (number of samples per device), and p holds the probabilities associated with the samples. To get, for example, the Monte Carlo estimate of the expectation value of \(\hat O\),

\(\langle\hat O\rangle\approx\frac{1}{N_S}\sum_{j=1}^{N_S}O_{loc}(\mathbf s_j)\)

we can use the jVMC.mpi_wrapper.global_mean function

>>> Omean = jVMC.mpi_wrapper.global_mean(Oloc, p)

This yields the mean computed across all MPI processes and local devices.

jVMC.mpi_wrapper.bcast_unknown_size(data, root=0)

Broadcast a one-dimensional array.

This function broadcasts the input data array from the root process to all other MPI processes, which do not need to know the size of the array in advance.

Arguments:
  • data: One-dimensional array of datatype np.float64.

  • root: Rank of root process.

Returns:

On each MPI process the data received from the root process.

jVMC.mpi_wrapper.distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) → int

Distribute sampling tasks across processes and devices.

For a desired total number of samples this function determines how many samples should be generated by each Monte Carlo chain.

It is assumed that a given number of MC chains runs in parallel on each device and that each MPI process can utilize multiple devices. Since the number of samples per chain has to be identical across the devices of one MPI process, the resulting total number of samples can slightly exceed the requested number of samples.

Arguments:
  • numSamples: Total number of samples.

  • localDevices: Number of devices per MPI process.

  • numChainsPerDevice: Number of chains run in parallel on each device.

Returns:

Number of samples to be generated per device to reach the desired total number of samples.
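The following sketch illustrates the kind of rounding described above under a hypothetical scheme: the requested total is split as evenly as possible across MPI processes, and each process share is rounded up so that every chain of that process produces the same number of samples. The function samples_per_device_sketch and its explicit rank and commSize arguments are illustrative assumptions; jVMC's distribute_sampling may split the samples differently.

```python
def samples_per_device_sketch(numSamples, rank, commSize, localDevices=1, numChainsPerDevice=1):
    """Hypothetical scheme: samples generated by one device of MPI process `rank`."""
    # Split the requested total as evenly as possible across MPI processes;
    # the first (numSamples % commSize) processes take one extra sample.
    share = numSamples // commSize + (1 if rank < numSamples % commSize else 0)
    # All chains of one process must produce the same number of samples,
    # so the per-chain count is rounded up (ceiling division).
    chainsPerProcess = localDevices * numChainsPerDevice
    samplesPerChain = -(-share // chainsPerProcess)
    return samplesPerChain * numChainsPerDevice


if __name__ == "__main__":
    commSize, localDevices = 3, 2
    perDevice = [
        samples_per_device_sketch(1000, r, commSize, localDevices, numChainsPerDevice=4)
        for r in range(commSize)
    ]
    print(perDevice)                                # [168, 168, 168]
    print(sum(n * localDevices for n in perDevice))  # 1008, slightly above 1000
```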

jVMC.mpi_wrapper.gather(data)

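Gathers data from all MPI processes and distributes the concatenated result to every process (all-to-all gather). The returned data is therefore of shape (commSize * data.shape[0],) + data.shape[1:], where commSize is the number of MPI processes.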

Arguments:
  • data: Array of input data.

Returns:

Data from all MPI processes, concatenated along the leading axis.

jVMC.mpi_wrapper.global_covariance(data, p)

Computes the covariance matrix of input data across MPI processes and device/batch dimensions.

On each MPI process, the input data is assumed to be a jax.numpy.array with a leading device dimension, followed by a batch dimension and one data dimension. The data is reduced by computing the covariance matrix along the device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2] \(\times\) data.shape[2].

The expectation values are computed using the given probabilities, i.e.,

\(\text{Cov}(X)=\sum_{j=1}^{N_S} p_jX_j\cdot X_j^\dagger - \bigg(\sum_{j=1}^{N_S} p_jX_j\bigg)\cdot\bigg(\sum_{j=1}^{N_S}p_jX_j^\dagger\bigg)\)

Arguments:
  • data: Array of input data.

  • p: Probabilities associated with the given data.

Returns:

Covariance matrix of data across MPI processes and device/batch dimensions.
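The single-process part of this reduction can be sketched in NumPy as below, assuming data of shape (devices, batch, n) and weights p of shape (devices, batch) that are normalized over all samples. local_covariance_sketch is a hypothetical helper; in a multi-process run, the two weighted moment sums would additionally have to be summed across processes (for example with an MPI Allreduce) before the outer product of the means is subtracted.

```python
import numpy as np


def local_covariance_sketch(data, p):
    """Weighted covariance over the leading (device, batch) axes, single process.

    data: shape (devices, batch, n); p: shape (devices, batch), summing to one.
    """
    # First moment  sum_j p_j X_j            -> shape (n,)
    mean = np.einsum("db,dbi->i", p, data)
    # Second moment sum_j p_j X_j X_j^dagger -> shape (n, n)
    second = np.einsum("db,dbi,dbj->ij", p, data, data.conj())
    # Cov(X) = <X X^dagger> - <X><X^dagger>
    return second - np.outer(mean, mean.conj())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(2, 5, 3))   # 2 devices, 5 samples each, 3 components
    p = rng.random((2, 5))
    p /= p.sum()                     # normalize the weights over all samples
    print(local_covariance_sketch(X, p).shape)
```

For real data and normalized weights, the result agrees with np.cov(..., rowvar=False, aweights=..., bias=True) applied to the flattened samples.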

jVMC.mpi_wrapper.global_mean(data, p)

Computes the mean of input data across MPI processes and device/batch dimensions.

On each MPI process, the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by computing the mean along the device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].

The mean is computed using the given probabilities, i.e.,

\(\langle X\rangle=\sum_{j=1}^{N_S} p_jX_j\)

Arguments:
  • data: Array of input data.

  • p: Probabilities associated with the given data.

Returns:

Mean of data across MPI processes and device/batch dimensions.
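On a single process, the weighted mean amounts to contracting p with the device and batch axes of data. The following NumPy sketch (local_mean_sketch is a hypothetical name) assumes that p is normalized over all samples; with uniform weights \(p_j=1/N_S\) it reduces to the plain Monte Carlo average from the module example. Across processes, the local results would additionally be summed, e.g. with an MPI Allreduce.

```python
import numpy as np


def local_mean_sketch(data, p):
    """Weighted mean over the leading (device, batch) axes on a single process."""
    # Contract p (devices, batch) with the first two axes of data,
    # leaving an array of shape data.shape[2:].
    return np.tensordot(p, data, axes=([0, 1], [0, 1]))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.normal(size=(2, 4, 3))   # 2 devices, 4 samples each, 3 components
    p = np.full((2, 4), 1 / 8)       # uniform weights p_j = 1/N_S
    print(local_mean_sketch(X, p))
```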

jVMC.mpi_wrapper.global_sum(data)

Computes the sum of input data across MPI processes and device/batch dimensions.

On each MPI process, the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by summing up along the device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].

Arguments:
  • data: Array of input data.

Returns:

Sum of data across MPI processes and device/batch dimensions.

jVMC.mpi_wrapper.global_variance(data, p)

Computes the variance of input data across MPI processes and device/batch dimensions.

On each MPI process, the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by computing the variance along the device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].

The mean and the variance are computed using the given probabilities, i.e.,

\(\text{Var}(X)=\sum_{j=1}^{N_S} p_j |X_j-\langle X\rangle|^2\)

Arguments:
  • data: Array of input data.

  • p: Probabilities associated with the given data.

Returns:

Variance of data across MPI processes and device/batch dimensions.
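A single-process NumPy sketch of the weighted variance formula above, assuming data of shape (devices, batch, ...) and weights p normalized over all samples; local_variance_sketch is a hypothetical name. Because the formula needs \(\langle X\rangle\) first, a multi-process version would either reduce the mean across processes before the second pass or use the equivalent one-pass form \(\sum_j p_j|X_j|^2-|\langle X\rangle|^2\).

```python
import numpy as np


def local_variance_sketch(data, p):
    """Weighted variance over the leading (device, batch) axes, single process."""
    # <X> = sum_j p_j X_j, shape data.shape[2:]
    mean = np.tensordot(p, data, axes=([0, 1], [0, 1]))
    # Var(X) = sum_j p_j |X_j - <X>|^2, real even for complex data
    return np.tensordot(p, np.abs(data - mean) ** 2, axes=([0, 1], [0, 1]))


if __name__ == "__main__":
    rng = np.random.default_rng(2)
    Z = rng.normal(size=(3, 4, 2)) + 1j * rng.normal(size=(3, 4, 2))
    p = np.full((3, 4), 1 / 12)      # uniform weights over 12 samples
    print(local_variance_sketch(Z, p))
```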