MPI wrapper module
The jVMC.mpi_wrapper module wraps the MPI communication that is typically required, using the mpi4py package.
This concerns in particular the statistical evaluation of Monte Carlo samples. For this purpose, it can interact with the sampler classes from the jVMC.sampler module as follows:
Example:
Assume that sampler is an instance of a sampler class, psi is a variational quantum state, and op is an instance of a class derived from the Operator class, associated with a quantum operator \(\hat O\). Then we can get samples and the corresponding \(O_{loc}(\mathbf s)\) via
>>> s, logPsi, _ = sampler.sample()
>>> sPrime, _ = op.get_s_primes(s)
>>> logPsiOffd = psi(sPrime)
>>> Oloc = op.get_O_loc(logPsi, logPsiOffd)
Now, on each MPI process Oloc is a two-dimensional array of size (number of devices) \(\times\) (number of samples per device). To get, for example, the Monte Carlo estimate of the expectation value of \(\hat O\),
\(\langle\hat O\rangle\approx\frac{1}{N_S}\sum_{j=1}^{N_S}O_{loc}(\mathbf s_j)\)
we can use the jVMC.mpi_wrapper.get_global_mean function
>>> Omean = jVMC.mpi_wrapper.get_global_mean(Oloc)
Thereby, we obtain the mean computed across all MPI processes and local devices.
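The reduction behind this estimate can be illustrated without MPI. The following is a minimal NumPy sketch, not jVMC's implementation: the MPI processes are simulated by a plain list, and `oloc_per_process` is a hypothetical stand-in for the `Oloc` array held by each process. The global mean is formed from local sums and local sample counts, which is the information a sum-reduction across processes would provide.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: 3 simulated MPI processes, each holding an array of
# shape (number of devices, number of samples per device) = (2, 4).
oloc_per_process = [rng.normal(size=(2, 4)) for _ in range(3)]

# Each process contributes its local sum and its local sample count.
local_sums = [o.sum() for o in oloc_per_process]
local_counts = [o.size for o in oloc_per_process]

# Adding up these contributions (the role of an MPI sum-reduction) gives
# the estimate (1/N_S) * sum_j O_loc(s_j) over all N_S samples.
global_mean = sum(local_sums) / sum(local_counts)

# Reference value: mean over all samples collected in one place.
reference = np.concatenate([o.ravel() for o in oloc_per_process]).mean()
```

Both values agree up to floating-point rounding, because the mean over all samples depends only on the total sum and the total count.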
- jVMC.mpi_wrapper.bcast_unknown_size(data, root=0)
Broadcast a one-dimensional array.
This function broadcasts the input data array to all MPI processes.
- Arguments:
data: One-dimensional array of datatype np.float64.
root: Rank of root process.
- Returns:
On each MPI process the data received from the root process.
- jVMC.mpi_wrapper.distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) → int
Distribute sampling tasks across processes and devices.
For a desired total number of samples this function determines how many samples should be generated by each Monte Carlo chain.
It is assumed that a given number of MC chains is running in parallel on each device, and that each MPI process can potentially utilize multiple devices. Since the number of samples per chain has to be identical across the devices of one MPI process, the resulting total number of samples can slightly exceed the requested number of samples.
- Arguments:
numSamples: Total number of samples.
localDevices: Number of devices per MPI process.
numChainsPerDevice: Number of chains run in parallel on each device.
- Returns:
Number of samples to be generated per device to reach the desired total number of samples.
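The rounding effect described above can be sketched with a simplified model. This is an assumption-laden illustration, not jVMC's actual distribution rule: the helper `samples_per_chain` is hypothetical, every process is assumed to use the same number of devices, and every chain is assigned the same number of samples, rounded up.

```python
def samples_per_chain(num_samples, num_processes, local_devices, num_chains_per_device):
    """Hypothetical helper: equal number of samples per chain, rounded up."""
    total_chains = num_processes * local_devices * num_chains_per_device
    # Ceiling division, so the generated total is at least num_samples.
    return (num_samples + total_chains - 1) // total_chains


num_samples = 1000
num_processes, local_devices, num_chains_per_device = 3, 2, 4

per_chain = samples_per_chain(num_samples, num_processes, local_devices, num_chains_per_device)

# Total number of samples actually generated by all chains together.
total_generated = per_chain * num_processes * local_devices * num_chains_per_device
```

With 24 chains in total, each chain generates 42 samples, so 1008 samples are produced instead of the requested 1000.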
- jVMC.mpi_wrapper.gather(data)
Gathers and distributes data all-to-all. The returned data is therefore of shape (commSize * data.shape[0],) + data.shape[1:], where commSize is the number of MPI processes.
- Arguments:
data: Array of input data.
- Returns:
Concatenated data from all devices.
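The shape rule can be checked with a small simulation. The sketch below does not call MPI: the processes are modeled as a list of equally shaped NumPy arrays, `comm_size` is a hypothetical process count, and the gathering is assumed to concatenate along the leading axis, as the stated output shape suggests.

```python
import numpy as np

comm_size = 3  # hypothetical number of MPI processes

# Each simulated process holds an array of identical shape (2, 5),
# filled with its own rank so the origin of each row stays visible.
local_data = [np.full((2, 5), fill_value=rank, dtype=np.float64)
              for rank in range(comm_size)]

# All-to-all gathering: every process ends up with the concatenation
# of all local arrays along the leading axis.
gathered = np.concatenate(local_data, axis=0)

# Shape predicted by the rule (commSize * data.shape[0],) + data.shape[1:].
expected_shape = (comm_size * local_data[0].shape[0],) + local_data[0].shape[1:]
```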
- jVMC.mpi_wrapper.global_covariance(data, p)
Computes the covariance matrix of input data across MPI processes and device/batch dimensions.
On each MPI process the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension and one data dimension. The data is reduced by computing the covariance matrix along device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2] \(\times\) data.shape[2].
The covariance is computed using the given probabilities, i.e.,
\(\text{Cov}(X)=\sum_{j=1}^{N_S} p_jX_j\cdot X_j^\dagger - \bigg(\sum_{j=1}^{N_S} p_jX_j\bigg)\cdot\bigg(\sum_{j=1}^{N_S}p_jX_j^\dagger\bigg)\)
- Arguments:
data: Array of input data.
p: Probabilities associated with the given data.
- Returns:
Covariance matrix of data across MPI processes and device/batch dimensions.
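For a single process, the covariance formula above can be written out in NumPy as follows. This is a sketch under assumptions, not the library's code: `p` is assumed to have shape (devices, batch) and to be normalized to one, and the reduction across MPI processes is left out.

```python
import numpy as np

rng = np.random.default_rng(1)

num_devices, batch_size, dim = 2, 5, 3
shape = (num_devices, batch_size, dim)
data = rng.normal(size=shape) + 1j * rng.normal(size=shape)

# Assumed layout: one probability per sample, normalized to sum to 1.
p = rng.random(size=(num_devices, batch_size))
p /= p.sum()

# Merge device and batch dimensions into a single sample index j.
X = data.reshape(-1, dim)
w = p.reshape(-1)

# sum_j p_j X_j
mean = np.einsum("j,ja->a", w, X)
# sum_j p_j X_j X_j^dagger (outer product of each sample with its conjugate)
second_moment = np.einsum("j,ja,jb->ab", w, X, X.conj())

cov = second_moment - np.outer(mean, mean.conj())
```

The result has shape (dim, dim) and is Hermitian, as expected for a covariance matrix of complex data.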
- jVMC.mpi_wrapper.global_mean(data, p)
Computes the mean of input data across MPI processes and device/batch dimensions.
On each MPI process the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by computing the mean along device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].
The mean is computed using the given probabilities, i.e.,
\(\langle X\rangle=\sum_{j=1}^{N_S} p_jX_j\)
- Arguments:
data: Array of input data.
p: Probabilities associated with the given data.
- Returns:
Mean of data across MPI processes and device/batch dimensions.
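The shape statement for the weighted mean can be illustrated on one process. In this NumPy sketch, `p` is assumed to have shape (devices, batch); uniform probabilities are used so that the result can be compared with a plain average. The MPI reduction is not included.

```python
import numpy as np

rng = np.random.default_rng(2)

# (devices, batch, trailing data dimension)
data = rng.normal(size=(2, 4, 3))

# Uniform probabilities over all 2 * 4 samples (assumed layout).
p = np.full((2, 4), 1.0 / 8)

# Contract the device and batch axes of data against p:
# result[a] = sum_{d,b} p[d, b] * data[d, b, a]
weighted_mean = np.tensordot(p, data, axes=([0, 1], [0, 1]))
```

The device and batch axes are summed out, so the result has shape data.shape[2:].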
- jVMC.mpi_wrapper.global_sum(data)
Computes the sum of input data across MPI processes and device/batch dimensions.
On each MPI process the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by summing up along device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].
- Arguments:
data: Array of input data.
- Returns:
Sum of data across MPI processes and device/batch dimensions.
- jVMC.mpi_wrapper.global_variance(data, p)
Computes the variance of input data across MPI processes and device/batch dimensions.
On each MPI process the input data is assumed to be a jax.numpy.array with a leading device dimension followed by a batch dimension. The data is reduced by computing the variance along device and batch dimensions as well as across MPI processes. Hence, the result is an array of shape data.shape[2:].
The variance is computed using the given probabilities, i.e.,
\(\text{Var}(X)=\sum_{j=1}^{N_S} p_j |X_j-\langle X\rangle|^2\)
- Arguments:
data: Array of input data.
p: Probabilities associated with the given data.
- Returns:
Variance of data across MPI processes and device/batch dimensions.
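The variance formula can be evaluated directly for data on one process. The following NumPy sketch assumes that `p` has the same (devices, batch) shape as the data and sums to one; the reduction across MPI processes is omitted. It also compares the result with the equivalent form \(\sum_j p_j|X_j|^2-|\langle X\rangle|^2\).

```python
import numpy as np

rng = np.random.default_rng(3)

# Complex data of shape (devices, batch).
data = rng.normal(size=(2, 6)) + 1j * rng.normal(size=(2, 6))

# Assumed layout: one probability per sample, normalized to sum to 1.
p = rng.random(size=(2, 6))
p /= p.sum()

# <X> = sum_j p_j X_j
mean = np.sum(p * data)

# Var(X) = sum_j p_j |X_j - <X>|^2
variance = np.sum(p * np.abs(data - mean) ** 2)

# Equivalent form: sum_j p_j |X_j|^2 - |<X>|^2
variance_alt = np.sum(p * np.abs(data) ** 2) - np.abs(mean) ** 2
```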