Basic neural network architectures
This module provides implementations of standard neural network architectures.
- class jVMC.nets.SymNet(orbit: ~jVMC.util.symmetries.LatticeSymmetry, net: callable, avgFun: callable = <function avgFun_Coefficients_Exp>, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Wrapper module for symmetrization, i.e., for incorporating lattice symmetries into a given plain ansatz \(\psi_\theta\). The plain ansatz is symmetrized as
\(\Psi_\theta(s)=\frac{1}{|\mathcal S|}\sum_{\tau\in\mathcal S}\psi_\theta(\tau(s))\)
where \(\mathcal S\) denotes the set of symmetry operations (the orbit, in our nomenclature).
- Initialization arguments:
  - orbit: Orbit defining the symmetry operations (instance of util.symmetries.LatticeSymmetry).
  - net: Flax module defining the plain ansatz \(\psi_\theta\).
  - avgFun: Function selecting how the average over the orbit is carried out (default: avgFun_Coefficients_Exp).
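The averaging formula above can be illustrated with a small NumPy sketch. It assumes, as an illustration only, that the plain ansatz returns a complex log-amplitude \(\log\psi_\theta(s)\) and that the orbit consists of all cyclic translations of a chain. The names plain_log_psi, translation_orbit and symmetrized_log_psi are hypothetical and are not jVMC's API; jVMC's avgFun options may average differently.

```python
import numpy as np

def plain_log_psi(s, W):
    """Hypothetical plain ansatz: complex log-amplitude of a small RBM-like model."""
    sigma = 2 * np.asarray(s) - 1                # map {0, 1} -> {-1, +1}
    return np.sum(np.log(np.cosh(W @ sigma)))

def translation_orbit(L):
    """All cyclic translations on a chain of length L as index permutations:
    indexing s with perm_k gives s[(i + k) % L]."""
    return [np.roll(np.arange(L), -k) for k in range(L)]

def symmetrized_log_psi(s, W, orbit):
    """log Psi(s) = log[(1/|S|) sum_tau psi(tau(s))], with psi = exp(log psi)."""
    s = np.asarray(s)
    logs = np.array([plain_log_psi(s[perm], W) for perm in orbit])
    shift = np.max(logs.real)                    # factor out the largest modulus for stability
    return shift + np.log(np.mean(np.exp(logs - shift)))

rng = np.random.default_rng(0)
L = 6
W = 0.3 * (rng.normal(size=(3, L)) + 1j * rng.normal(size=(3, L)))
orbit = translation_orbit(L)
s = np.array([1, 0, 0, 1, 1, 0])

a = symmetrized_log_psi(s, W, orbit)
b = symmetrized_log_psi(np.roll(s, 2), W, orbit)  # a translated configuration
```

Because a translated configuration has the same orbit, exp(a) and exp(b) coincide even though the plain ansatz itself is not translation invariant.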
- class jVMC.nets.CpxRBM(numHidden: int = 2, bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Restricted Boltzmann machine with complex parameters.
- Initialization arguments:
  - numHidden: Number of hidden units.
  - bias: Boolean indicating whether to use bias.
  The computational basis configuration s is not an initialization argument; it is the input passed when the network is evaluated.
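A minimal NumPy sketch of the standard RBM log-amplitude with the hidden units traced out, \(\log\psi(s)=\sum_j \log\cosh\big(b_j+\sum_i W_{ji}\sigma_i\big)\) with \(\sigma_i = 2s_i-1\). It assumes the network returns a log-amplitude and omits visible biases and the constant factor 2 per hidden unit; whether jVMC's CpxRBM uses exactly this form is an assumption. The function names are hypothetical. The brute-force check makes the tracing-out step explicit.

```python
import itertools
import numpy as np

def cpx_rbm_log_psi(s, W, b=None):
    """log psi(s) = sum_j log cosh(theta_j) with theta = W sigma (+ b).

    s: configuration in {0, 1}^L; W: complex (numHidden, L) weights;
    b: optional complex hidden biases (the bias=True case).
    """
    sigma = 2 * np.asarray(s) - 1                # map {0, 1} -> {-1, +1}
    theta = W @ sigma                            # effective field on each hidden unit
    if b is not None:
        theta = theta + b
    return np.sum(np.log(np.cosh(theta)))

def brute_force_log_psi(s, W, b):
    """Same amplitude from an explicit sum over hidden states h in {-1, +1}^numHidden."""
    sigma = 2 * np.asarray(s) - 1
    theta = W @ sigma + b
    total = sum(np.exp(np.dot(h, theta))
                for h in itertools.product([-1, 1], repeat=len(theta)))
    return np.log(total)

rng = np.random.default_rng(1)
numHidden, L = 3, 5
W = 0.2 * (rng.normal(size=(numHidden, L)) + 1j * rng.normal(size=(numHidden, L)))
b = 0.2 * (rng.normal(size=numHidden) + 1j * rng.normal(size=numHidden))
s = np.array([0, 1, 1, 0, 1])

# Both agree up to the constant factor 2**numHidden, since
# sum_h exp(h . theta) = prod_j 2 cosh(theta_j).
lhs = np.exp(brute_force_log_psi(s, W, b))
rhs = 2**numHidden * np.exp(cpx_rbm_log_psi(s, W, b))
```

Complex weights let \(\log\cosh\) carry a phase, so the amplitude can have a sign structure.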
- class jVMC.nets.RBM(numHidden: int = 2, bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Restricted Boltzmann machine with real parameters.
- Initialization arguments:
  - numHidden: Number of hidden units.
  - bias: Boolean indicating whether to use bias.
  The computational basis configuration s is not an initialization argument; it is the input passed when the network is evaluated.
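With real parameters, \(\log\cosh\) is real, so in the sketch below the amplitude \(\exp(\log\psi)\) is always positive and cannot encode a sign structure. The sketch also shows a numerically stable \(\log\cosh\), which matters for large effective fields where cosh overflows. It assumes the same log-amplitude form as for the complex RBM; log_cosh and rbm_log_psi are hypothetical helpers, not jVMC functions.

```python
import numpy as np

def log_cosh(x):
    """Stable log(cosh(x)) for real x, using cosh(x) = e^{|x|} (1 + e^{-2|x|}) / 2."""
    ax = np.abs(x)
    return ax + np.log1p(np.exp(-2.0 * ax)) - np.log(2.0)

def rbm_log_psi(s, W, b=None):
    """Real log-amplitude: sum_j log cosh(theta_j) with theta = W sigma (+ b)."""
    sigma = 2 * np.asarray(s) - 1                # map {0, 1} -> {-1, +1}
    theta = W @ sigma
    if b is not None:
        theta = theta + b
    return np.sum(log_cosh(theta))

rng = np.random.default_rng(2)
W = 50.0 * rng.normal(size=(4, 6))               # large weights: naive log(cosh) may overflow
s = np.array([1, 0, 1, 1, 0, 0])
lp = rbm_log_psi(s, W)                           # finite and real
```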
- class jVMC.nets.CpxCNN(F: ~typing.Sequence[int] = (8, ), channels: ~typing.Sequence[int] = (10, ), strides: ~typing.Sequence[int] = (1, ), actFun: ~typing.Sequence[callable] = (<function poly6>, ), bias: bool = True, firstLayerBias: bool = False, periodicBoundary: bool = True, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Convolutional neural network with complex parameters.
- Initialization arguments:
  - F: Filter diameter
  - channels: Number of channels
  - strides: Number of pixels the filter shifts over
  - actFun: Non-linear activation function
  - bias: Whether to use biases
  - firstLayerBias: Whether to use biases in the first layer
  - periodicBoundary: Whether to use periodic boundary conditions
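The effect of F, strides and periodicBoundary on a single layer can be sketched with a one-channel 1D convolution (a cross-correlation, as in most neural-network libraries). This is an illustration under assumptions: wrap-around padding is used when periodic=True and no padding otherwise; how jVMC pads internally is not stated here. conv1d is a hypothetical helper.

```python
import numpy as np

def conv1d(x, kernel, stride=1, periodic=True):
    """Single-channel 1D convolution with filter diameter F = len(kernel).

    periodic=True: the input is wrapped around, so every site starts a full
    window and there are ceil(L / stride) outputs.
    periodic=False: only windows fully inside the chain are used.
    """
    x = np.asarray(x)
    F = len(kernel)
    if periodic:
        x = np.concatenate([x, x[:F - 1]])       # append the first F-1 sites at the end
    n_windows = len(x) - F + 1
    return np.array([np.dot(x[i:i + F], kernel) for i in range(0, n_windows, stride)])

x = np.array([1.0, -1.0, -1.0, 1.0, 1.0, -1.0])  # spins sigma = 2s - 1, L = 6
k = np.array([0.5 + 0.5j, -1.0, 0.25j])          # complex filter, F = 3

y_per = conv1d(x, k)                             # 6 outputs
y_open = conv1d(x, k, periodic=False)            # 4 outputs
y_str = conv1d(x, k, stride=2)                   # 3 outputs
```

With periodic boundaries the last window wraps around to the first sites, so the output keeps the lattice's translation structure.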
- class jVMC.nets.CNN(F: ~typing.Sequence[int] = (8, ), channels: ~typing.Sequence[int] = (10, ), strides: ~typing.Sequence[int] = (1, ), actFun: ~typing.Sequence[callable] = (<PjitFunction of <function elu>>, ), bias: bool = True, firstLayerBias: bool = False, periodicBoundary: bool = True, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Convolutional neural network with real parameters.
- Initialization arguments:
  - F: Filter diameter
  - channels: Number of channels
  - strides: Number of pixels the filter shifts over
  - actFun: Non-linear activation function
  - bias: Whether to use biases
  - firstLayerBias: Whether to use biases in the first layer
  - periodicBoundary: Whether to use periodic boundary conditions
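A practical consequence of periodic boundaries is that a periodic convolution followed by a sum over sites yields a translation-invariant output. The sketch below shows this for one real-valued layer with several channels and the ELU activation (the default actFun in the signature). The final sum over sites and channels is an assumption of this sketch, not a statement about how jVMC reduces the CNN output; elu and cnn_log_psi are hypothetical helpers.

```python
import numpy as np

def elu(x):
    """ELU activation: x for x > 0, exp(x) - 1 otherwise."""
    return np.where(x > 0, x, np.expm1(x))

def cnn_log_psi(s, kernels, bias):
    """One periodic conv layer with kernels.shape[0] channels, ELU, then a sum.

    kernels: (channels, F) real filters; bias: (channels,) real biases.
    """
    sigma = 2 * np.asarray(s, dtype=float) - 1   # map {0, 1} -> {-1, +1}
    F = kernels.shape[1]
    # windows[i, k] = sigma[(i + k) % L]: periodic filter window starting at site i
    windows = np.stack([np.roll(sigma, -k) for k in range(F)], axis=1)
    pre = windows @ kernels.T + bias             # shape (L, channels)
    return np.sum(elu(pre))

rng = np.random.default_rng(3)
kernels = rng.normal(size=(10, 3))               # channels = 10, F = 3
bias = rng.normal(size=10)
s = np.array([1, 0, 0, 1, 0, 1, 1, 0])

a = cnn_log_psi(s, kernels, bias)
b = cnn_log_psi(np.roll(s, 3), kernels, bias)    # translated configuration, same value
```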
- class jVMC.nets.FFN(layers: ~typing.Sequence[int] = (10, ), bias: bool = False, actFun: ~typing.Sequence[callable] = (<PjitFunction of <function elu>>, ), parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Feed-forward network with real parameters.
- Initialization arguments:
  - layers: Sequence giving the number of units in each hidden layer.
  - bias: Boolean indicating whether to use bias.
  - actFun: Non-linear activation function(s).
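The role of the layers argument can be sketched in NumPy: each entry sets the width of one dense hidden layer, followed here by a linear layer with a single output. Several details are assumptions of this sketch, not confirmed jVMC behavior: the input encoding \(2s-1\), the linear scalar output, and reusing the last activation when actFun has fewer entries than there are hidden layers. init_ffn and ffn_forward are hypothetical helpers.

```python
import numpy as np

def elu(x):
    """ELU activation: x for x > 0, exp(x) - 1 otherwise."""
    return np.where(x > 0, x, np.expm1(x))

def init_ffn(layers, L, rng, bias=False):
    """Random weights for hidden widths `layers` plus a 1-unit output layer."""
    sizes = [L, *layers, 1]
    weights = [rng.normal(scale=1 / np.sqrt(m), size=(n, m))
               for m, n in zip(sizes[:-1], sizes[1:])]
    biases = [np.zeros(n) for n in sizes[1:]] if bias else None
    return weights, biases

def ffn_forward(s, weights, biases=None, act_funs=(elu,)):
    """Hidden layers with activations, then a linear scalar output."""
    x = 2 * np.asarray(s, dtype=float) - 1       # map {0, 1} -> {-1, +1}
    n_hidden = len(weights) - 1
    # Reuse the last activation if fewer activations than hidden layers are given.
    funs = [act_funs[min(k, len(act_funs) - 1)] for k in range(n_hidden)]
    for k in range(n_hidden):
        x = weights[k] @ x
        if biases is not None:
            x = x + biases[k]
        x = funs[k](x)
    out = weights[-1] @ x
    if biases is not None:
        out = out + biases[-1]
    return out[0]

rng = np.random.default_rng(4)
weights, biases = init_ffn((10, 5), L=8, rng=rng, bias=True)
y = ffn_forward(np.array([0, 1, 1, 0, 1, 0, 0, 1]), weights, biases)
```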
- class jVMC.nets.RNN1DGeneral(L: int = 10, hiddenSize: int = 10, depth: int = 1, inputDim: int = 2, actFun: callable = <PjitFunction of <function elu>>, initScale: float = 1.0, logProbFactor: float = 0.5, realValuedOutput: bool = False, realValuedParams: bool = True, cell: str | list = 'RNN', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Implementation of a multi-layer RNN for one-dimensional data with an arbitrary cell.
The cell parameter can be a string (“RNN”, “LSTM”, or “GRU”) selecting a pre-implemented cell. Alternatively, a custom cell can be passed as a tuple containing a Flax module that implements the hidden-state update rule and the initial value of the hidden state (i.e., the initial carry). The __call__ function of the custom cell module must have the signature (carry, state) -> (new_carry, output).
This model can produce real positive or complex-valued output. In either case, the output is normalized such that
\(\sum_s |\mathrm{RNN}(s)|^{1/\kappa}=1\).
Here, \(\kappa\) corresponds to the initialization parameter logProbFactor. This way, the RNN can represent both probability distributions and wave functions. Real or complex output is selected through the parameter realValuedOutput.
The RNN allows for autoregressive sampling through the sample member function.
- Initialization arguments:
  - L: length of the spin chain
  - hiddenSize: size of the hidden state vector
  - depth: number of RNN cells in the RNNCellStack
  - inputDim: dimension of the input
  - actFun: non-linear activation function
  - initScale: factor by which the initial parameters are scaled
  - logProbFactor: factor defining how the output and the associated sample probability are related; 0.5 for pure states and 1 for POVMs
  - realValuedOutput: Boolean indicating whether the network output is a real or complex number
  - realValuedParams: Boolean indicating whether the network parameters are real or complex
  - cell: string (“RNN”, “LSTM”, or “GRU”) or custom definition indicating which type of cell to use for hidden-state transformations
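Why the normalization \(\sum_s |\mathrm{RNN}(s)|^{1/\kappa}=1\) holds, and how autoregressive sampling works, can be sketched with a toy single-layer RNN in NumPy. The sketch assumes the modulus is built as \(p(s)^\kappa\) from normalized conditionals \(p(s_i\mid s_{<i})\); a complex output would multiply this by a phase, which leaves \(|\mathrm{RNN}(s)|\) unchanged. ToyRNN is hypothetical and does not reproduce jVMC's cells or parameterization.

```python
import itertools
import numpy as np

def softmax(z):
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

class ToyRNN:
    """Hypothetical vanilla RNN defining p(s) = prod_i p(s_i | s_<i)."""

    def __init__(self, L, hiddenSize=4, inputDim=2, seed=0):
        rng = np.random.default_rng(seed)
        self.L, self.inputDim = L, inputDim
        self.Wh = rng.normal(scale=0.5, size=(hiddenSize, hiddenSize))
        self.Wx = rng.normal(scale=0.5, size=(hiddenSize, inputDim))
        self.Wo = rng.normal(scale=0.5, size=(inputDim, hiddenSize))
        self.h0 = np.zeros(hiddenSize)           # initial carry

    def _step(self, h, prev):
        # Input is the one-hot encoding of the previous site (zeros at site 0).
        x = np.zeros(self.inputDim) if prev is None else np.eye(self.inputDim)[prev]
        h = np.tanh(self.Wh @ h + self.Wx @ x)
        return h, softmax(self.Wo @ h)           # conditional p(s_i | s_<i)

    def log_prob(self, s):
        h, prev, logp = self.h0, None, 0.0
        for si in s:
            h, p = self._step(h, prev)
            logp += np.log(p[si])
            prev = si
        return logp

    def sample(self, rng):
        # Autoregressive sampling: draw each site from its conditional in turn.
        h, prev, s = self.h0, None, []
        for _ in range(self.L):
            h, p = self._step(h, prev)
            prev = rng.choice(self.inputDim, p=p)
            s.append(prev)
        return np.array(s)

kappa = 0.5                                      # logProbFactor for pure states
rnn = ToyRNN(L=4)
log_amp = {s: kappa * rnn.log_prob(s) for s in itertools.product(range(2), repeat=4)}
total = sum(np.exp(la / kappa) for la in log_amp.values())  # sum_s |RNN(s)|^(1/kappa)

samples = [rnn.sample(np.random.default_rng(1)) for _ in range(5)]
```

Sampling needs no Markov chain: each configuration is drawn exactly from \(p(s)=|\mathrm{RNN}(s)|^{1/\kappa}\).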
- class jVMC.nets.RNN2DGeneral(L: int = 10, hiddenSize: int = 10, depth: int = 1, inputDim: int = 2, actFun: callable = <PjitFunction of <function elu>>, initScale: float = 1.0, logProbFactor: float = 0.5, realValuedOutput: bool = False, realValuedParams: bool = True, cell: str | list = 'RNN', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Implementation of a multi-layer RNN for two-dimensional data with an arbitrary cell. This implementation approximately follows the original proposal for RNN wave functions in Hibat-Allah et al., Phys. Rev. Research 2, 023358 (2020).
The cell parameter can be a string (“RNN”, “LSTM”, or “GRU”) selecting a pre-implemented cell. Alternatively, a custom cell can be passed as a tuple containing a Flax module that implements the hidden-state update rule and the initial value of the hidden state (i.e., the initial carry). The __call__ function of the custom cell module must have the signature (carry, state) -> (new_carry, output).
This model can produce real positive or complex-valued output. In either case, the output is normalized such that
\(\sum_s |\mathrm{RNN}(s)|^{1/\kappa}=1\).
Here, \(\kappa\) corresponds to the initialization parameter logProbFactor. This way, the RNN can represent both probability distributions and wave functions. Real or complex output is selected through the parameter realValuedOutput.
The RNN allows for autoregressive sampling through the sample member function.
- Initialization arguments:
  - L: linear size of the two-dimensional lattice
  - hiddenSize: size of the hidden state vector
  - depth: number of RNN cells in the RNNCellStack
  - inputDim: dimension of the input
  - actFun: non-linear activation function
  - initScale: factor by which the initial parameters are scaled
  - logProbFactor: factor defining how the output and the associated sample probability are related; 0.5 for pure states and 1 for POVMs
  - realValuedOutput: Boolean indicating whether the network output is a real or complex number
  - realValuedParams: Boolean indicating whether the network parameters are real or complex
  - cell: string (“RNN”, “LSTM”, or “GRU”) or custom definition indicating which type of cell to use for hidden-state transformations
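An autoregressive model on a 2D lattice needs a one-dimensional ordering of the sites. In the cited proposal, the sites are visited along a zig-zag (snake) path, and each hidden state receives input from the preceding site in the same row and from the neighboring site in the row above. The sketch below builds such an ordering and checks that every conditioning neighbor is visited before the site it feeds, which is what keeps the model autoregressive. It loosely follows Hibat-Allah et al.; the ordering and neighbor set used by RNN2DGeneral may differ, and snake_order and conditioning_neighbours are hypothetical helpers.

```python
import numpy as np

def snake_order(L):
    """Zig-zag ordering of an L x L lattice: left-to-right on even rows,
    right-to-left on odd rows. Returns a list of (row, col) pairs."""
    order = []
    for r in range(L):
        cols = range(L) if r % 2 == 0 else range(L - 1, -1, -1)
        order.extend((r, c) for c in cols)
    return order

def conditioning_neighbours(r, c, L):
    """Previously visited nearest neighbours feeding the hidden state at (r, c):
    the preceding site along the row and the site in the row above."""
    prev_col = c - 1 if r % 2 == 0 else c + 1
    nbrs = []
    if 0 <= prev_col < L:
        nbrs.append((r, prev_col))
    if r > 0:
        nbrs.append((r - 1, c))
    return nbrs

L = 4
order = snake_order(L)
pos = {site: i for i, site in enumerate(order)}  # position of each site along the path
```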