Definition of custom network architectures
This example shows how to define a custom complex RBM architecture, i.e., a network that encodes the log-amplitude
\[\log\psi(\mathbf s) = \sum_i \log\big[\cosh\big(b_i + \sum_j W_{ij} s_j\big)\big]\]
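with \(b_i, W_{ij}\in\mathbb C\) and spins \(s_j\in\{-1,1\}\). The index \(i\) runs over the numHidden hidden units. Since the input configurations are given in the 0/1 representation, the network first maps them to \(\pm1\) via \(s\mapsto 2s-1\).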
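As a quick numerical illustration of the formula, the following sketch evaluates \(\log\psi(\mathbf s)\) for a single \(\pm1\) configuration. It uses plain NumPy instead of jVMC so that the formula can be checked in isolation. The function name log_psi and the random parameter values are hypothetical and serve only as an illustration.

```python
import numpy as np


def log_psi(s, b, W):
    """Evaluate log psi(s) = sum_i log cosh(b_i + sum_j W_ij s_j).

    s: spin configuration with entries +1/-1, shape (L,)
    b: complex hidden biases, shape (M,)
    W: complex weights, shape (M, L)
    """
    theta = b + W @ s  # effective angles theta_i, one per hidden unit
    return np.sum(np.log(np.cosh(theta)))


# Hypothetical small random complex parameters for L = 4 spins, M = 7 hidden units
rng = np.random.default_rng(0)
L, M = 4, 7
b = 0.1 * (rng.normal(size=M) + 1j * rng.normal(size=M))
W = 0.1 * (rng.normal(size=(M, L)) + 1j * rng.normal(size=(M, L)))

s = np.array([1.0, -1.0, -1.0, 1.0])
print(log_psi(s, b, W))
```

Because cosh is an even function, setting all \(b_i=0\) gives \(\log\psi(\mathbf s)=\log\psi(-\mathbf s)\). In that case the ansatz is invariant under a global spin flip, and the biases are what break this symmetry.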
```python
import jax
import flax.linen
import jVMC


class MyNet(flax.linen.Module):
    numHidden: int  # number of hidden units (range of the index i)

    @flax.linen.compact
    def __call__(self, s):

        s = 2 * s - 1  # Go from 0/1 representation to 1/-1

        # Complex dense layer: h_i = b_i + sum_j W_ij s_j
        # (param_dtype makes the parameters themselves complex in recent flax versions)
        h = flax.linen.Dense(features=self.numHidden,
                             dtype=jVMC.global_defs.tCpx,
                             param_dtype=jVMC.global_defs.tCpx)(s)

        h = jax.numpy.log(jax.numpy.cosh(h))

        return jax.numpy.sum(h)


L = 4  # system size

# Initialize custom net
net = MyNet(numHidden=7)

# Create the variational quantum state
psi = jVMC.vqs.NQS(net, seed=1234)

# Create a set of 13 random input configurations
configs = jax.random.bernoulli(jax.random.PRNGKey(4321), shape=(1, 13, L))
```
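The module above sums over all entries of h, so it is written for a single configuration of shape (L,). As we understand it, the NQS wrapper takes care of the (1, 13, L) layout of configs, where the leading axis is the device dimension. The following plain-NumPy sketch mirrors the forward pass for such a batch and may help when reasoning about shapes or cross-checking outputs. The function name forward_reference and the random kernel and bias are hypothetical stand-ins, not the parameters jVMC initializes. Note that a flax Dense layer computes s @ kernel + bias, so its kernel corresponds to the transpose of \(W_{ij}\) in the formula.

```python
import numpy as np


def forward_reference(configs, kernel, bias):
    """NumPy mirror of MyNet.__call__, applied to a batch of configurations.

    configs: 0/1 array of shape (..., L)
    kernel:  complex array of shape (L, numHidden), i.e. W transposed
    bias:    complex array of shape (numHidden,)
    Returns one log-amplitude per configuration, shape configs.shape[:-1].
    """
    s = 2 * configs.astype(np.float64) - 1       # 0/1 -> -1/+1
    h = s @ kernel + bias                        # dense layer, shape (..., numHidden)
    return np.sum(np.log(np.cosh(h)), axis=-1)   # sum over hidden units only


# Hypothetical random complex parameters for L = 4 and numHidden = 7
rng = np.random.default_rng(4321)
L, num_hidden = 4, 7
kernel = 0.1 * (rng.normal(size=(L, num_hidden)) + 1j * rng.normal(size=(L, num_hidden)))
bias = 0.1 * (rng.normal(size=num_hidden) + 1j * rng.normal(size=num_hidden))

# Same layout as in the example above: (devices, batch, L) = (1, 13, L)
configs = rng.integers(0, 2, size=(1, 13, L))
log_amps = forward_reference(configs, kernel, bias)
print(log_amps.shape)  # (1, 13)
```

Summing only over the last axis keeps one log-amplitude per configuration. This is the batched counterpart of the jax.numpy.sum(h) call in MyNet, which sees one configuration at a time.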