Sampler module

The sampler module provides two ways to sample variational quantum states: Markov chain Monte Carlo (MCMC) and direct sampling. Direct sampling requires the variational ansatz to provide a sample() member function. If the variational wave function provides such a function, direct sampling is used; otherwise, MCMC is employed. In addition, the ExactSampler class works with all basis states instead of samples, which can be helpful for quick troubleshooting.

class jVMC.sampler.ExactSampler(net, sampleShape, lDim=2, logProbFactor=0.5)

Class for full enumeration of basis states.

This class generates a full basis of the many-body Hilbert space. It thereby allows sums over the full Hilbert space to be evaluated exactly instead of being estimated by stochastic sampling.

Initialization arguments:
  • net: Network defining the probability distribution.

  • sampleShape: Shape of computational basis states.

  • lDim: Local Hilbert space dimension.

  • logProbFactor: Factor for the log-probabilities, equivalent to the exponent for the probability distribution. For pure wave functions this should be 0.5, and 1.0 for POVMs.

sample(parameters=None, numSamples=None, multipleOf=None)

Return all computational basis states.

Sampling is automatically distributed across MPI processes and available devices.

Arguments:
  • parameters: Dummy argument to provide the same interface as the MCSampler class.

  • numSamples: Dummy argument to provide the same interface as the MCSampler class.

  • multipleOf: Dummy argument to provide the same interface as the MCSampler class.

Returns:

configs, logPsi, p: All computational basis configurations, corresponding wave function coefficients, and probabilities \(|\psi(s)|^2\) (normalized).
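The idea of full enumeration can be illustrated with a small NumPy sketch. It does not use jVMC itself: the helper names enumerate_basis, normalized_probabilities, and toy_log_psi are hypothetical, and the relation log p(s) = Re(logPsi(s)) / logProbFactor is an assumed reading of the logProbFactor description above, not a confirmed implementation detail.

```python
import itertools
import numpy as np

def enumerate_basis(sample_shape, l_dim=2):
    """Return all l_dim**N basis configurations, each with shape sample_shape."""
    n_sites = int(np.prod(sample_shape))
    flat = np.array(list(itertools.product(range(l_dim), repeat=n_sites)),
                    dtype=np.int32)
    return flat.reshape((-1,) + tuple(sample_shape))

def normalized_probabilities(log_psi, log_prob_factor=0.5):
    """Assumed convention: log p(s) = Re(log_psi(s)) / log_prob_factor."""
    log_p = np.real(log_psi) / log_prob_factor
    log_p = log_p - np.max(log_p)  # shift for numerical stability
    p = np.exp(log_p)
    return p / np.sum(p)

def toy_log_psi(configs):
    """Hypothetical stand-in for a network: log-amplitude 0.3 per occupied site."""
    return 0.3 * configs.reshape(configs.shape[0], -1).sum(axis=1).astype(np.float64)

configs = enumerate_basis((3,), l_dim=2)   # 2**3 = 8 configurations
log_psi = toy_log_psi(configs)
p = normalized_probabilities(log_psi)      # pure state: p(s) ~ |psi(s)|^2

# Exact expectation value of the number of occupied sites: a sum over the
# whole basis weighted by p, with no sampling noise.
n_occupied = configs.reshape(configs.shape[0], -1).sum(axis=1)
mean_occupied = float(np.sum(p * n_occupied))
```

Because the number of configurations grows as lDim to the power of the number of sites, this approach is only practical for small systems, which is consistent with its use for troubleshooting.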

class jVMC.sampler.MCSampler(net, sampleShape, key, updateProposer=None, numChains=1, updateProposerArg=None, numSamples=100, thermalizationSweeps=10, sweepSteps=10, initState=None, mu=2, logProbFactor=0.5)

A sampler class.

This class provides functionality to sample computational basis states from the distribution

\(p_{\mu}(s)=\frac{|\psi(s)|^{\mu}}{\sum_s|\psi(s)|^{\mu}}\).

For \(\mu=2\) this corresponds to sampling from the Born distribution. \(0\leq\mu<2\) can be used to perform importance sampling (see [arXiv:2108.08631]).
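The following NumPy sketch illustrates how expectation values with respect to the Born distribution can be recovered from \(p_{\mu}(s)\) by reweighting with \(|\psi(s)|^{2-\mu}\). The amplitudes and the observable are random toy data, the function name p_mu is hypothetical, and the sums are evaluated exactly over a small basis rather than estimated from samples, so this only demonstrates the underlying identity.

```python
import numpy as np

rng = np.random.default_rng(0)
psi = rng.normal(size=16) + 1j * rng.normal(size=16)  # toy amplitudes psi(s)
obs = rng.normal(size=16)                             # toy diagonal observable O(s)

def p_mu(psi, mu):
    """p_mu(s) = |psi(s)|^mu / sum_s |psi(s)|^mu."""
    w = np.abs(psi) ** mu
    return w / w.sum()

# Reference: expectation value under the Born distribution (mu = 2).
born_value = np.sum(p_mu(psi, 2.0) * obs)

# The same quantity expressed through p_mu with mu < 2 and weights |psi|^(2 - mu).
mu = 1.0
p = p_mu(psi, mu)
weights = np.abs(psi) ** (2.0 - mu)
reweighted_value = np.sum(p * weights * obs) / np.sum(p * weights)
```

In a sampling setting, the two sums in the last line would be replaced by averages over configurations drawn from \(p_{\mu}(s)\).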

Sampling is automatically distributed across MPI processes and locally available devices.

Initializer arguments:
  • net: Network defining the probability distribution.

  • sampleShape: Shape of computational basis configurations.

  • key: An instance of jax.random.PRNGKey. Alternatively, an int that will be used as seed to initialize a PRNGKey.

  • updateProposer: A function to propose updates for the MCMC algorithm. It is called as updateProposer(key, config, **kwargs), where key is an instance of jax.random.PRNGKey, config is a computational basis configuration, and **kwargs are optional additional arguments.

  • numChains: Number of Markov chains that are run in parallel.

  • updateProposerArg: An optional argument that will be passed to the updateProposer as kwargs.

  • numSamples: Default number of samples to be returned by the sample() member function.

  • thermalizationSweeps: Number of sweeps to perform for thermalization of the Markov chain.

  • sweepSteps: Number of proposed updates per sweep.

  • mu: Parameter for the distribution \(p_{\mu}(s)\), see above.

  • logProbFactor: Factor for the log-probabilities, equivalent to the exponent for the probability distribution. For pure wave functions this should be 0.5, and 1.0 for POVMs. In the POVM case, the mu parameter must be set to 1.0 to sample the unchanged POVM distribution.
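As an illustration of the updateProposer call signature described above, here is a minimal JAX sketch of a single-site flip proposal. The function name propose_single_flip is hypothetical, and the sketch assumes configurations with entries 0 or 1 (lDim=2); it is not necessarily how the library's default proposer is implemented.

```python
import jax
import jax.numpy as jnp

def propose_single_flip(key, config, **kwargs):
    """Flip one uniformly chosen site of a configuration with entries 0 or 1.

    Follows the documented call pattern updateProposer(key, config, **kwargs);
    the optional keyword arguments are accepted but unused in this sketch.
    """
    flat = config.reshape(-1)
    # Draw a site index uniformly from [0, number of sites).
    idx = jax.random.randint(key, (), 0, flat.shape[0])
    flat = flat.at[idx].set(1 - flat[idx])
    return flat.reshape(config.shape)

key = jax.random.PRNGKey(0)
config = jnp.zeros((2, 3), dtype=jnp.int32)
proposal = propose_single_flip(key, config)
```

A function of this form could then be supplied through the updateProposer argument, with any extra settings passed via updateProposerArg.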

acceptance_ratio()

Get acceptance ratio.

Returns:

Acceptance ratio observed in the last call to sample().
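To make the meaning of the acceptance ratio, thermalizationSweeps, and sweepSteps concrete, the following is a minimal single-chain Metropolis sketch in NumPy. It is an illustration under stated assumptions, not the jVMC implementation: the names metropolis_sample and toy_log_psi are hypothetical, the proposal is a single-site flip, log_psi is taken to return the real log-amplitude log|psi(s)|, and the ratio is counted over the sampling phase only.

```python
import numpy as np

def metropolis_sample(log_psi, n_sites, num_samples, mu=2.0,
                      thermalization_sweeps=10, sweep_steps=10, seed=0):
    """Sample p_mu(s) ~ |psi(s)|^mu; return (samples, acceptance ratio)."""
    rng = np.random.default_rng(seed)
    state = rng.integers(0, 2, size=n_sites)
    log_amp = log_psi(state)

    def sweep(state, log_amp):
        n_accepted = 0
        for _ in range(sweep_steps):
            proposal = state.copy()
            site = rng.integers(n_sites)
            proposal[site] = 1 - proposal[site]
            new_log_amp = log_psi(proposal)
            # Accept with probability min(1, |psi(s')|^mu / |psi(s)|^mu).
            accept_prob = np.exp(min(0.0, mu * (new_log_amp - log_amp)))
            if rng.random() < accept_prob:
                state, log_amp = proposal, new_log_amp
                n_accepted += 1
        return state, log_amp, n_accepted

    # Thermalization: sweeps whose configurations are discarded.
    for _ in range(thermalization_sweeps):
        state, log_amp, _unused = sweep(state, log_amp)

    # Sampling: one configuration is recorded after every sweep.
    samples, accepted = [], 0
    for _ in range(num_samples):
        state, log_amp, n_acc = sweep(state, log_amp)
        accepted += n_acc
        samples.append(state.copy())
    return np.array(samples), accepted / (num_samples * sweep_steps)

def toy_log_psi(state):
    """Hypothetical real log-amplitude: 0.3 per occupied site."""
    return 0.3 * float(np.sum(state))

samples, ratio = metropolis_sample(toy_log_psi, n_sites=4, num_samples=200)
```

A very low ratio indicates that most proposals are rejected and the chain moves slowly, which is why this number is worth inspecting after sampling.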

get_last_number_of_samples()

Return last number of samples.

This function is required because the actual number of samples might exceed the requested number of samples when sampling is distributed across multiple processes or devices.

Returns:

Number of samples generated by last call to sample() member function.

sample(parameters=None, numSamples=None, multipleOf=1)

Generate random samples from wave function.

If supported by net, direct sampling is performed. Otherwise, MCMC is run to generate the desired number of samples. For direct sampling the real part of net needs to provide a sample() member function that generates samples from \(p_{\mu}(s)\).

Sampling is automatically distributed across MPI processes and available devices. In that case the number of samples returned might exceed numSamples.

Arguments:
  • parameters: Network parameters to use for sampling.

  • numSamples: Number of samples to generate. When running multiple processes or on multiple devices per process, the number of samples returned is numSamples or more. If None, the default number of samples is returned (see set_number_of_samples() member function).

  • multipleOf: Rounds the number of samples returned up to the smallest multiple of multipleOf larger than numSamples. This is useful for distributing a total number of samples across multiple processors such that every processor receives the same number of samples.

Returns:

A sample of computational basis configurations drawn from \(p_{\mu}(s)\).
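The rounding implied by multipleOf, and the reason the returned number of samples can exceed numSamples, can be sketched with plain integer arithmetic. The helper round_up_to_multiple is hypothetical, and the sketch assumes that a request that is already an exact multiple is left unchanged; the library's internal distribution scheme may differ in detail.

```python
def round_up_to_multiple(num_samples, multiple_of=1):
    """Smallest multiple of multiple_of that is not below num_samples."""
    # Ceiling division written with integer arithmetic only.
    return -(-num_samples // multiple_of) * multiple_of

# Requesting 100 samples for 8 workers: 104 samples in total, 13 per worker.
total = round_up_to_multiple(100, multiple_of=8)
per_worker = total // 8

# With multiple_of=1 the requested number is returned unchanged.
unchanged = round_up_to_multiple(100)
```

After a call to sample(), the number actually generated can be queried with get_last_number_of_samples().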

set_number_of_samples(N)

Set default number of samples.

Arguments:
  • N: Number of samples.

set_random_key(key)

Set key for pseudo random number generator.

Arguments:
  • key: Key (jax.random.PRNGKey)