Basic neural network architectures

This module provides implementations of standard neural network architectures.

class jVMC.nets.SymNet(orbit: ~jVMC.util.symmetries.LatticeSymmetry, net: callable, avgFun: callable = <function avgFun_Coefficients_Exp>, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Wrapper module that incorporates lattice symmetries by symmetrization. The given plain ansatz \(\psi_\theta\) is symmetrized as

\(\Psi_\theta(s)=\frac{1}{|\mathcal S|}\sum_{\tau\in\mathcal S}\psi_\theta(\tau(s))\)

where \(\mathcal S\) denotes the set of symmetry operations (orbit in our nomenclature).

Initialization arguments:
  • orbit: Orbit defining the symmetry operations (instance of util.symmetries.LatticeSymmetry).

  • net: Flax module defining the plain ansatz.

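  • avgFun: Function that carries out the average over the orbit. Different choices change the details of the averaging; the default is avgFun_Coefficients_Exp.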
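
The symmetrization formula above can be illustrated without any library code. The following is a minimal sketch, assuming a one-dimensional chain whose orbit consists of all cyclic translations; the functions psi, translations, and psi_sym are hypothetical names chosen for this illustration and are not part of jVMC. It averages amplitudes directly, as in the formula, and does not reproduce how avgFun is applied inside SymNet.

```python
import math

def psi(s):
    """Toy plain ansatz (hypothetical): position dependent, so not translation invariant."""
    return math.exp(sum(0.1 * (i + 1) * x for i, x in enumerate(s)))

def translations(s):
    """Orbit of s under all cyclic translations of a 1D chain."""
    return [s[k:] + s[:k] for k in range(len(s))]

def psi_sym(s):
    """Symmetrized ansatz: average of psi over the orbit of s."""
    orbit = translations(s)
    return sum(psi(t) for t in orbit) / len(orbit)

s = (1, 0, 0, 1, 0)
shifted = s[2:] + s[:2]  # the same configuration translated by two sites
```

Here psi(s) and psi(shifted) differ, while psi_sym(s) and psi_sym(shifted) agree up to floating point rounding, because both configurations share the same orbit.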

class jVMC.nets.CpxRBM(numHidden: int = 2, bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Restricted Boltzmann machine with complex parameters.

Initialization arguments:
  • s: Computational basis configuration. This is the input passed when the network is evaluated, not a constructor argument.

  • numHidden: Number of hidden units.

  • bias: Boolean indicating whether to use bias.
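
For orientation, the textbook form of a restricted Boltzmann machine amplitude with the hidden units summed out is \(\log\psi(s)=\sum_j \log\cosh\big(\sum_i W_{ji}\sigma_i + b_j\big)\). The sketch below evaluates this expression with complex weights using only the standard library. It assumes the mapping \(\sigma_i = 2 s_i - 1\) from occupations to spins; whether CpxRBM uses exactly this form and mapping is an assumption, and rbm_log_amplitude is a hypothetical helper name.

```python
import cmath

def rbm_log_amplitude(s, weights, biases=None):
    """log psi(s) = sum_j log cosh(sum_i W[j][i] * sigma_i + b[j]), sigma_i = 2 s_i - 1."""
    sigma = [2 * x - 1 for x in s]  # assumed mapping from {0, 1} to {-1, +1}
    total = 0j
    for j, row in enumerate(weights):
        theta = sum(w * x for w, x in zip(row, sigma))
        if biases is not None:
            theta += biases[j]
        total += cmath.log(cmath.cosh(theta))
    return total

# Two hidden units (numHidden = 2) acting on three sites; the values are arbitrary.
W = [[0.1 + 0.2j, -0.3j, 0.05], [0.2, 0.1 - 0.1j, -0.15j]]
log_psi = rbm_log_amplitude((1, 0, 1), W)
```

Without biases the result is unchanged when all spins are flipped, since cosh is an even function. Supplying biases breaks this symmetry.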

class jVMC.nets.RBM(numHidden: int = 2, bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Restricted Boltzmann machine with real parameters.

Initialization arguments:
  • s: Computational basis configuration. This is the input passed when the network is evaluated, not a constructor argument.

  • numHidden: Number of hidden units.

  • bias: Boolean indicating whether to use bias.

class jVMC.nets.CpxCNN(F: ~typing.Sequence[int] = (8, ), channels: ~typing.Sequence[int] = (10, ), strides: ~typing.Sequence[int] = (1, ), actFun: ~typing.Sequence[callable] = (<function poly6>, ), bias: bool = True, firstLayerBias: bool = False, periodicBoundary: bool = True, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Convolutional neural network with complex parameters.

Initialization arguments:
  • F: Filter diameter

  • channels: Number of channels

  • strides: Number of pixels the filter shifts over

  • actFun: Non-linear activation function

  • bias: Whether to use biases

  • firstLayerBias: Whether to use biases in the first layer

  • periodicBoundary: Whether to use periodic boundary conditions
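
To make the roles of the filter diameter, the stride, and the periodic boundary concrete, here is a minimal sketch of a single-channel one-dimensional convolution in plain Python. It is an illustration under stated assumptions, not the layer used by CpxCNN; conv1d_periodic is a hypothetical helper name, and a complex kernel could be used in the same way.

```python
def conv1d_periodic(x, kernel, stride=1):
    """1D convolution with periodic boundaries: input indices wrap around modulo len(x)."""
    n, f = len(x), len(kernel)
    return [
        sum(kernel[k] * x[(i + k) % n] for k in range(f))
        for i in range(0, n, stride)  # the filter moves by `stride` sites per output
    ]

x = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0]
kernel = [0.5, -1.0, 0.25]  # filter diameter F = 3
y = conv1d_periodic(x, kernel)
```

With stride 1 and periodic boundaries the output has the same length as the input, and translating the input translates the output by the same amount. A stride of 2 halves the number of outputs in this example.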

class jVMC.nets.CNN(F: ~typing.Sequence[int] = (8, ), channels: ~typing.Sequence[int] = (10, ), strides: ~typing.Sequence[int] = (1, ), actFun: ~typing.Sequence[callable] = (<PjitFunction of <function elu>>, ), bias: bool = True, firstLayerBias: bool = False, periodicBoundary: bool = True, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Convolutional neural network with real parameters.

Initialization arguments:
  • F: Filter diameter

  • channels: Number of channels

  • strides: Number of pixels the filter shifts over

  • actFun: Non-linear activation function

  • bias: Whether to use biases

  • firstLayerBias: Whether to use biases in the first layer

  • periodicBoundary: Whether to use periodic boundary conditions

class jVMC.nets.FFN(layers: ~typing.Sequence[int] = (10, ), bias: bool = False, actFun: ~typing.Sequence[callable] = (<PjitFunction of <function elu>>, ), parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Feed forward network with real parameters.

Initialization arguments:
  • layers: Number of units in each layer.

  • bias: Boolean indicating whether to use bias.

  • actFun: Non-linear activation function.
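
The following sketch shows how a tuple of layer widths translates into a stack of dense layers. It assumes bias-free layers (matching the default bias = False), the ELU activation, a \(2s-1\) input mapping, and a final linear map to a single scalar; these details, as well as the names init_ffn and ffn, are assumptions made for illustration and are not taken from the jVMC source.

```python
import math
import random

def elu(x):
    """Exponential linear unit."""
    return x if x > 0 else math.exp(x) - 1.0

def init_ffn(num_inputs, layers, rng):
    """One weight matrix per entry of `layers`, plus a final map to a single output."""
    sizes = [num_inputs, *layers, 1]
    return [
        [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def ffn(s, params):
    """Evaluate the network on a configuration s with entries in {0, 1}."""
    h = [2 * x - 1 for x in s]  # assumed input mapping to {-1, +1}
    for W in params[:-1]:
        h = [elu(sum(w * x for w, x in zip(row, h))) for row in W]
    # Final linear layer producing a scalar output
    return sum(w * x for w, x in zip(params[-1][0], h))

params = init_ffn(num_inputs=4, layers=(10, 5), rng=random.Random(0))
out = ffn((1, 0, 1, 1), params)
```

In this sketch, layers=(10, 5) yields two hidden layers with 10 and 5 units, followed by the scalar output layer.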

class jVMC.nets.RNN1DGeneral(L: int = 10, hiddenSize: int = 10, depth: int = 1, inputDim: int = 2, actFun: callable = <PjitFunction of <function elu>>, initScale: float = 1.0, logProbFactor: float = 0.5, realValuedOutput: bool = False, realValuedParams: bool = True, cell: str | list = 'RNN', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Implementation of a multi-layer RNN for one-dimensional data with arbitrary cell.

The cell parameter can be a string (“RNN”, “LSTM”, or “GRU”) indicating a pre-implemented cell. Alternatively, a custom cell can be passed in the form of a tuple containing a flax module that implements the hidden state update rule and the initial value of the hidden state (i.e., the initial carry). The signature of the __call__ function of the cell flax module has to be (carry, state) -> (new_carry, output).

This model can produce real positive or complex valued output. In either case the output is normalized such that

\(\sum_s |RNN(s)|^{1/\kappa}=1\).

Here, \(\kappa\) corresponds to the initialization parameter logProbFactor. Thereby, the RNN can represent both probability distributions and wave functions. Real or complex valued output is chosen through the parameter realValuedOutput.

The RNN allows for autoregressive sampling through the sample member function.

Initialization arguments:
  • L: length of the spin chain

  • hiddenSize: size of the hidden state vector

  • depth: number of RNN-cells in the RNNCellStack

  • inputDim: dimension of the input

  • actFun: non-linear activation function

  • initScale: factor by which the initial parameters are scaled

  • logProbFactor: factor defining how output and associated sample probability are related. 0.5 for pure states and 1 for POVMs.

  • realValuedOutput: Boolean indicating whether the network output is a real or complex number.

  • realValuedParams: Boolean indicating whether the network parameters are real or complex parameters.

  • cell: String (“RNN”, “LSTM”, or “GRU”) or custom definition indicating which type of cell to use for hidden state transformations.
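
The normalization condition and the autoregressive sampling described above can be checked on a toy model. The sketch below assumes that the output factorizes as \(\mathrm{RNN}(s)=\prod_i p(s_i|s_{<i})^{\kappa}\) with normalized conditionals, which makes \(\sum_s |\mathrm{RNN}(s)|^{1/\kappa}=1\) hold by construction. The cell follows the (carry, state) -> (new_carry, output) signature, but it is a plain Python function with made-up coefficients rather than a Flax module, and all names here (cell, cond_prob_one, amplitude, sample) are hypothetical.

```python
import itertools
import math
import random

L = 4        # chain length
KAPPA = 0.5  # plays the role of logProbFactor

def cell(carry, x):
    """Toy cell with the (carry, input) -> (new_carry, output) signature."""
    new_carry = math.tanh(0.7 * carry + 0.9 * (2 * x - 1) + 0.1)
    return new_carry, new_carry

def cond_prob_one(output):
    """Probability that the next site is 1, given the cell output (sigmoid)."""
    return 1.0 / (1.0 + math.exp(-2.0 * output))

def amplitude(s):
    """RNN(s) = prod_i p(s_i | s_<i) ** kappa, accumulated in log space."""
    carry, x, log_p = 0.0, 0, 0.0  # initial carry and an arbitrary fixed start input
    for s_i in s:
        carry, out = cell(carry, x)
        p1 = cond_prob_one(out)
        log_p += math.log(p1 if s_i == 1 else 1.0 - p1)
        x = s_i
    return math.exp(KAPPA * log_p)

def sample(rng):
    """Autoregressive sampling: draw each site from its conditional distribution."""
    carry, x, s = 0.0, 0, []
    for _ in range(L):
        carry, out = cell(carry, x)
        x = 1 if rng.random() < cond_prob_one(out) else 0
        s.append(x)
    return tuple(s)

# Sum of |RNN(s)|^(1/kappa) over all 2^L configurations
norm = sum(amplitude(s) ** (1.0 / KAPPA) for s in itertools.product((0, 1), repeat=L))
config = sample(random.Random(0))
```

Because every conditional distribution is normalized, norm equals 1 up to rounding for any choice of cell, and sample draws configurations with probability \(|\mathrm{RNN}(s)|^{1/\kappa}\) without a Markov chain.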

class jVMC.nets.RNN2DGeneral(L: int = 10, hiddenSize: int = 10, depth: int = 1, inputDim: int = 2, actFun: callable = <PjitFunction of <function elu>>, initScale: float = 1.0, logProbFactor: float = 0.5, realValuedOutput: bool = False, realValuedParams: bool = True, cell: str | list = 'RNN', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Implementation of a multi-layer RNN for two-dimensional data with arbitrary cell. This implementation approximately follows the original proposal for RNN wave functions in Hibat-Allah et al., Phys. Rev. Research 2, 023358 (2020).

The cell parameter can be a string (“RNN”, “LSTM”, or “GRU”) indicating a pre-implemented cell. Alternatively, a custom cell can be passed in the form of a tuple containing a flax module that implements the hidden state update rule and the initial value of the hidden state (i.e., the initial carry). The signature of the __call__ function of the cell flax module has to be (carry, state) -> (new_carry, output).

This model can produce real positive or complex valued output. In either case the output is normalized such that

\(\sum_s |RNN(s)|^{1/\kappa}=1\).

Here, \(\kappa\) corresponds to the initialization parameter logProbFactor. Thereby, the RNN can represent both probability distributions and wave functions. Real or complex valued output is chosen through the parameter realValuedOutput.

The RNN allows for autoregressive sampling through the sample member function.

Initialization arguments:
  • L: linear extent of the two-dimensional lattice

  • hiddenSize: size of the hidden state vector

  • depth: number of RNN-cells in the RNNCellStack

  • inputDim: dimension of the input

  • actFun: non-linear activation function

  • initScale: factor by which the initial parameters are scaled

  • logProbFactor: factor defining how output and associated sample probability are related. 0.5 for pure states and 1 for POVMs.

  • realValuedOutput: Boolean indicating whether the network output is a real or complex number.

  • realValuedParams: Boolean indicating whether the network parameters are real or complex parameters.

  • cell: String (“RNN”, “LSTM”, or “GRU”) or custom definition indicating which type of cell to use for hidden state transformations.