Definition of custom network architectures

This example shows how to define a custom complex-valued restricted Boltzmann machine (RBM) architecture, i.e., a network that computes the logarithmic wave function amplitude

\[\log\psi(\mathbf s) = \sum_i \log\big[\cosh\big(b_i + \sum_j W_{ij} s_j\big)\big]\]

with complex parameters \(b_i, W_{ij}\in\mathbb C\) and spin variables \(s_j\in\{-1,+1\}\).
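To make the formula concrete, the following sketch evaluates \(\log\psi(\mathbf s)\) directly with `jax.numpy` for explicitly given complex parameters, without any network library. The function name `rbm_log_psi` and the parameter values are hypothetical and only illustrate the formula; as in the network code, the input is assumed in 0/1 form and mapped to \(s_j=\pm1\) first.

```python
import jax.numpy as jnp


def rbm_log_psi(s01, b, W):
    """log psi(s) = sum_i log cosh(b_i + sum_j W_ij s_j) for one configuration.

    s01: shape (L,), entries in {0, 1}
    b:   shape (M,), complex hidden biases
    W:   shape (M, L), complex couplings (row i belongs to hidden unit i)
    """
    s = 2 * s01 - 1        # map 0/1 to -1/+1
    theta = b + W @ s      # effective field of each hidden unit
    return jnp.sum(jnp.log(jnp.cosh(theta)))


# Hypothetical example: M = 2 hidden units, L = 3 spins
b = jnp.array([0.1 + 0.2j, -0.3 + 0.0j])
W = jnp.array([[0.5 + 0.1j, -0.2 + 0.0j, 0.0 + 0.3j],
               [0.1 + 0.0j, 0.4 - 0.2j, -0.1 + 0.0j]])
s01 = jnp.array([1, 0, 1])

print(rbm_log_psi(s01, b, W))
```

Since the result is a complex logarithm, it is defined only up to multiples of \(2\pi i\), which does not affect \(\psi(\mathbf s)\) itself.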

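import jax
import flax.linen
import jVMC


class MyNet(flax.linen.Module):
    numHidden: int  # number of hidden units

    @flax.linen.compact
    def __call__(self, s):

        s = 2 * s - 1  # Go from 0/1 representation to 1/-1

        # Dense layer computes b_i + sum_j W_ij s_j for every hidden unit i.
        # param_dtype makes the weights themselves complex; in current flax
        # versions, dtype alone only sets the dtype of the computation.
        h = flax.linen.Dense(features=self.numHidden,
                             dtype=jVMC.global_defs.tCpx,
                             param_dtype=jVMC.global_defs.tCpx)(s)

        h = jax.numpy.log(jax.numpy.cosh(h))

        return jax.numpy.sum(h)


L = 4  # system size

# Initialize custom net
net = MyNet(numHidden=7)

# Create the variational quantum state
psi = jVMC.vqs.NQS(net, seed=1234)

# Create a set of 13 random input configurations
# (leading axis: number of devices, as expected by jVMC)
configs = jax.random.bernoulli(jax.random.PRNGKey(4321), shape=(1, 13, L))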