Preface#
The following is an alpha API preview, so things might break. The surface syntax of the API’s features is not fixed in stone, and we welcome feedback on any point.
Useful links#
⟶ Slides for the core ideas of the new Functional Core and Linen
⟶ “Design tests” guided our design process. Many are available for the functional core and some for the proposed Module abstraction
⟶ Ported examples: ImageNet and WMT (to the proposed Module abstraction). TODO: Port to functional core.
⟶ Our new discussion forums
Install and Import#
# Install the newest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
import functools
from typing import Any, Callable, Sequence, Optional
import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
Invoking Modules#
Let’s instantiate a Dense layer. Modules are actually objects in this API, so we provide constructor arguments when instantiating the Module. In this case, we only have to provide the output features dimension.
model = nn.Dense(features=3)
We need to initialize the Module variables; these include the parameters of the Module as well as any other state variables.
We call the init method on the instantiated Module. If the Module __call__ method has args (self, *args, **kwargs), then we call init with (rngs, *args, **kwargs), so in this case just (rng, input):
# Make RNG Keys and a fake input.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
# provide key and fake input to get initialized variables
init_variables = model.init(key2, x)
init_variables
FrozenDict({
params: {
kernel: Array([[ 0.61506 , -0.22728713, 0.6054702 ],
[-0.29617992, 1.1232013 , -0.879759 ],
[-0.35162622, 0.3806491 , 0.6893246 ],
[-0.1151355 , 0.04567898, -1.091212 ]], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
})
We call the apply method on the instantiated Module. If the Module __call__ method has args (self, *args, **kwargs), then we call apply with (variables, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs), where:
<RNGS>: the optional call-time RNGs for things like dropout. For simple Modules this is just a single key, but if your Module uses multiple kinds of randomness, it’s a dictionary of rng keys per kind, e.g. {'params': key0, 'dropout': key1} for a Module with dropout layers.
<MUTABLEKINDS>: an optional list of names of kinds that are expected to be mutated during the call, e.g. ['batch_stats'] for a layer updating batchnorm statistics.
So in this case, just (variables, input):
y = model.apply(init_variables, x)
y
Array([[-0.02996204, 1.102088 , -0.6660265 ],
[-0.31092793, 0.6323942 , -0.53678817],
[ 0.01424007, 0.9424717 , -0.6356147 ],
[ 0.36818963, 0.3586519 , -0.00459214]], dtype=float32)
Additional points:
If you want to init or apply a Module using a method other than the default __call__, you need to provide the method= kwarg to init and apply, e.g. method='encode' or method='decode' to apply the encode/decode methods of an autoencoder.
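For example, here is a minimal sketch of that pattern; this AutoEncoder Module is hypothetical and not defined elsewhere in this guide:
# A hypothetical autoencoder, just to illustrate the method= kwarg.
class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(2)
    self.decoder = nn.Dense(4)
  def encode(self, x):
    return self.encoder(x)
  def decode(self, z):
    return self.decoder(z)
  def __call__(self, x):
    return self.decode(self.encode(x))

ae = AutoEncoder()
ae_variables = ae.init(random.PRNGKey(42), x)          # init via the default __call__
latent = ae.apply(ae_variables, x, method='encode')    # apply a non-default method
recon = ae.apply(ae_variables, latent, method='decode')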
Defining Basic Modules#
Composing submodules#
We support declaring submodules in setup(); they still benefit from shape inference, because lazy initialization sets up variables the first time the Module is called.
class ExplicitMLP(nn.Module):
features: Sequence[int]
def setup(self):
# we automatically know what to do with lists, dicts of submodules
self.layers = [nn.Dense(feat) for feat in self.features]
# for single submodules, we would just write:
# self.layer1 = nn.Dense(feat1)
def __call__(self, inputs):
x = inputs
for i, lyr in enumerate(self.layers):
x = lyr(x)
if i != len(self.layers) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[ 0. 0. 0. 0. 0. ]
[ 0.0072379 -0.00810347 -0.02550939 0.02151716 -0.01261241]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
Here we show the equivalent compact form of the MLP that declares the submodules inline using the @compact decorator.
class SimpleMLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
x = nn.Dense(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
# providing a name is optional though!
# the default autonames would be "Dense_0", "Dense_1", ...
# x = nn.Dense(feat)(x)
return x
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[ 0. 0. 0. 0. 0. ]
[ 0.0072379 -0.00810347 -0.02550939 0.02151716 -0.01261241]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
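As the comments in the cell above note, the name= argument is optional; a quick sketch (hypothetical AutoNamedMLP, reusing key2 and x from above) shows the default autonames:
class AutoNamedMLP(nn.Module):
  features: Sequence[int]
  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat)(x)  # no name= -> autonames "Dense_0", "Dense_1", ...
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

auto_variables = AutoNamedMLP(features=[3, 4, 5]).init(key2, x)
print(sorted(auto_variables['params'].keys()))  # ['Dense_0', 'Dense_1', 'Dense_2']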
Declaring and using variables#
Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available at the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls.
For declaring parameters that aren’t mutated inside the model, but rather by gradient descent, we use the syntax:
self.param(parameter_name, parameter_init_fn, *init_args)
with arguments:
parameter_name: just the name, a string.
parameter_init_fn: a function taking an RNG key and a variable number of other arguments, i.e. fn(rng, *args). Typically those in nn.initializers take an rng and a shape argument.
*init_args: the remaining arguments to feed to the init function when initializing.
Again, we’ll demonstrate declaring things inline as we typically do, using the @compact decorator.
class SimpleDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros_init()
@nn.compact
def __call__(self, inputs):
kernel = self.param('kernel',
self.kernel_init, # RNG passed implicitly.
(inputs.shape[-1], self.features)) # shape info.
y = lax.dot_general(inputs, kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),)
bias = self.param('bias', self.bias_init, (self.features,))
y = y + bias
return y
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleDense(features=3)
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameters:\n', init_variables)
print('output:\n', y)
initialized parameters:
FrozenDict({
params: {
kernel: Array([[ 0.61506 , -0.22728713, 0.6054702 ],
[-0.29617992, 1.1232013 , -0.879759 ],
[-0.35162622, 0.3806491 , 0.6893246 ],
[-0.1151355 , 0.04567898, -1.091212 ]], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
})
output:
[[-0.02996204 1.102088 -0.6660265 ]
[-0.31092793 0.6323942 -0.53678817]
[ 0.01424007 0.9424717 -0.6356147 ]
[ 0.36818963 0.3586519 -0.00459214]]
We can also declare variables in setup(), though in doing so you can’t take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names.
class ExplicitDense(nn.Module):
features_in: int # <-- explicit input shape
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros_init()
def setup(self):
self.kernel = self.param('kernel',
self.kernel_init,
(self.features_in, self.features))
self.bias = self.param('bias', self.bias_init, (self.features,))
def __call__(self, inputs):
y = lax.dot_general(inputs, self.kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),)
y = y + self.bias
return y
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitDense(features_in=4, features=3)
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameters:\n', init_variables)
print('output:\n', y)
initialized parameters:
FrozenDict({
params: {
kernel: Array([[ 0.61506 , -0.22728713, 0.6054702 ],
[-0.29617992, 1.1232013 , -0.879759 ],
[-0.35162622, 0.3806491 , 0.6893246 ],
[-0.1151355 , 0.04567898, -1.091212 ]], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
})
output:
[[-0.02996204 1.102088 -0.6660265 ]
[-0.31092793 0.6323942 -0.53678817]
[ 0.01424007 0.9424717 -0.6356147 ]
[ 0.36818963 0.3586519 -0.00459214]]
General Variables#
For declaring generally mutable variables that may be mutated inside the model we use the call:
self.variable(variable_kind, variable_name, variable_init_fn, *init_args)
with arguments:
variable_kind: the “kind” of state this variable is, i.e. the name of the nested-dict collection that it will be stored in inside the top Module’s variables, e.g. batch_stats for the moving statistics of a batch norm layer or cache for autoregressive cache data. Note that parameters also have a kind, but they’re set to the default param kind.
variable_name: just the name, a string.
variable_init_fn: a function taking a variable number of arguments, i.e. fn(*args). Note that we don’t assume the need for an RNG; if you do want an RNG, provide it via a self.make_rng(variable_kind) call in the provided arguments.
*init_args: the remaining arguments to feed to the init function when initializing.
⚠️ Unlike parameters, we expect these to be mutated, so self.variable returns not a constant but a reference to the variable. To get the raw value, you’d write myvariable.value, and to set it, myvariable.value = new_value.
class Counter(nn.Module):
@nn.compact
def __call__(self):
# easy pattern to detect if we're initializing
is_initialized = self.has_variable('counter', 'count')
counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
if is_initialized:
counter.value += 1
return counter.value
key1 = random.PRNGKey(0)
model = Counter()
init_variables = model.init(key1)
print('initialized variables:\n', init_variables)
y, mutated_variables = model.apply(init_variables, mutable=['counter'])
print('mutated variables:\n', mutated_variables)
print('output:\n', y)
initialized variables:
FrozenDict({
counter: {
count: Array(0, dtype=int32),
},
})
mutated variables:
FrozenDict({
counter: {
count: Array(1, dtype=int32),
},
})
output:
1
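A minimal sketch of threading this mutable state through repeated calls (reusing the Counter model and init_variables from above):
variables = init_variables
for _ in range(3):
  # each apply returns the new 'counter' collection, which we feed back in
  y, variables = model.apply(variables, mutable=['counter'])
print(y)  # 3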
Another Mutability and RNGs Example#
Let’s make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables:
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, inputs):
x = nn.Dense(self.features)(inputs)
x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)
x = nn.BatchNorm(use_running_average=not self.training)(x)
return x
key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3,4,4))
model = Block(features=3, training=True)
init_variables = model.init({'params': key2, 'dropout': key3}, x)
_, init_params = init_variables.pop('params')
# Calling `apply` with mutable kinds returns a pair of
# (output, mutated_variables).
y, mutated_variables = model.apply(
init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])
# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
updated_variables = freeze(dict(params=init_params,
**mutated_variables))
print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
jax.tree_util.tree_map(jnp.shape, init_variables))
print('output:\n', y)
# Let's run these model variables during "evaluation":
eval_model = Block(features=3, training=False)
y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value.
print('eval output:\n', y)
updated variables:
FrozenDict({
params: {
Dense_0: {
kernel: Array([[-0.28613907, -0.70721805, -0.6914534 ],
[-0.6641256 , 0.68615615, 0.44648856],
[ 0.49008575, 0.35440823, -0.21784972],
[-0.20484754, -0.09779654, 0.5040372 ]], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
BatchNorm_0: {
scale: Array([1., 1., 1.], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
},
batch_stats: {
BatchNorm_0: {
mean: Array([-0.00313114, 0.00078917, -0.00079298], dtype=float32),
var: Array([0.992465 , 0.99116045, 0.9916556 ], dtype=float32),
},
},
})
initialized variable shapes:
FrozenDict({
batch_stats: {
BatchNorm_0: {
mean: (3,),
var: (3,),
},
},
params: {
BatchNorm_0: {
bias: (3,),
scale: (3,),
},
Dense_0: {
bias: (3,),
kernel: (4, 3),
},
},
})
output:
[[[ 0.6306426 -0.23165356 0.19488382]
[ 0.6306426 0.79969025 0.10451133]
[-0.35419682 -0.23165356 0.19488382]
[-1.6535418 -0.23165356 -0.45335415]]
[[ 0.7730165 -1.5991257 -1.4089029 ]
[ 0.6306426 -0.23165356 0.19488382]
[ 0.6306426 -0.23165356 -1.390066 ]
[-2.1826966 -0.23165356 0.19488382]]
[[ 0.6306426 -0.23165356 0.19488382]
[-0.9970797 -0.23165356 -1.2379806 ]
[ 0.6306426 2.8843176 2.1083667 ]
[ 0.6306426 -0.23165356 1.3030066 ]]]
eval output:
[[[-0.69251037 0.1369832 0.3395594 ]
[-0.55603373 0.17566003 -0.01766703]
[-0.24226873 0.3135934 -0.1597346 ]
[-0.56605196 0.3020906 -0.13164042]]
[[ 0.03862109 -0.23475364 -0.32686153]
[-0.12492569 0.10663299 0.08993971]
[-0.12356766 -0.38863623 -0.32301313]
[-0.6979118 -0.00074816 0.1433272 ]]
[[-0.21574032 0.4600567 0.45498124]
[-0.4024684 -0.11036564 -0.29194167]
[-0.3182572 0.5323191 0.39172593]
[-0.5571195 0.43794093 0.22718874]]]
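To sketch how these pieces fit into training, here is a hypothetical train_step with a dummy squared-output loss and plain SGD (not an API from Flax); it reuses the Block model, init_params, updated_variables, key4, and x from above:
@jax.jit
def train_step(params, batch_stats, dropout_key, x):
  def loss_fn(params):
    y, updates = model.apply({'params': params, 'batch_stats': batch_stats},
                             x, rngs={'dropout': dropout_key},
                             mutable=['batch_stats'])
    return jnp.mean(y ** 2), updates['batch_stats']  # dummy loss, new stats
  (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  # plain SGD step; a real loop would use an optimizer here
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
  return new_params, new_batch_stats, loss

new_params, new_batch_stats, loss = train_step(
    init_params, updated_variables['batch_stats'], key4, x)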
JAX transformations inside modules#
JIT#
It’s not immediately clear what use this has, but you can compile specific submodules if there’s a reason to.
Known Gotcha: at the moment, the decorator changes the RNG stream slightly, so comparing jitted and unjitted initializations will look different.
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
# JIT the Module (it's __call__ fn by default.)
x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4,4))
model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[-0.7031405 0.23289034 -0.23026259 -0.5699682 -0.52668184]
[-0.05363148 0.01776353 -0.0175631 -0.04347387 -0.04017224]
[-1.3031061 0.43160754 -0.42673773 -1.0563024 -0.97608125]
[-0.8637256 0.28607848 -0.28285062 -0.70013916 -0.64696693]]
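We can also wrap a whole Module class instead of jitting submodules inline; a small sketch reusing MLP, key2, and x from above:
JittedMLP = nn.jit(MLP)                  # compile the Module's __call__ as one unit
jit_model = JittedMLP(features=[3, 4, 5])
jit_variables = jit_model.init(key2, x)  # the RNG-stream gotcha above applies here too
y = jit_model.apply(jit_variables, x)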
Remat#
For memory-expensive computations, we can remat our method to recompute a Module’s output during the backward pass rather than storing intermediate activations.
Known Gotcha: at the moment, the decorator changes the RNG stream slightly, so comparing remat’d and undecorated initializations will look different.
class RematMLP(nn.Module):
features: Sequence[int]
# For all transforms, we can annotate a method, or wrap an existing
# Module class. Here we annotate the method.
@nn.remat
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
x = nn.Dense(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4,4))
model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[-0.7031404 0.23289026 -0.23026255 -0.5699681 -0.5266818 ]
[-0.05363144 0.01776351 -0.01756308 -0.04347384 -0.04017221]
[-1.3031058 0.43160748 -0.4267377 -1.0563023 -0.97608113]
[-0.8637254 0.2860784 -0.28285056 -0.700139 -0.6469668 ]]
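The alternative mentioned in the comment above, wrapping an existing Module class, looks like this sketch (reusing key2 and x):
RematDense = nn.remat(nn.Dense)  # behaves like nn.Dense, but recomputes in the backward pass
remat_layer = RematDense(features=3)
remat_variables = remat_layer.init(key2, x)
y = remat_layer.apply(remat_variables, x)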
Vmap#
You can now vmap Modules inside other Modules. The transform takes a lot of arguments; it accepts the usual jax vmap args:
in_axes: an integer or None for each input argument.
out_axes: an integer or None for each output argument.
axis_size: the axis size, if you need to give it explicitly.
In addition, we provide the axis rules for each kind of variable:
variable_axes: a dict from variable kinds to an integer or None, specifying the axis along which that kind’s variables are mapped (None broadcasts the variables across the mapped axis).
split_rngs: a dict from RNG kinds to a bool, specifying whether to split the rng along the axis.
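Before the full attention example below, here is a minimal sketch of these arguments: an ensemble of Dense layers with independent parameters, all applied to the same (unmapped) input. The names here are hypothetical.
EnsembleDense = nn.vmap(
    nn.Dense,
    in_axes=None, out_axes=0,      # broadcast the input, stack member outputs
    axis_size=4,                   # 4 ensemble members
    variable_axes={'params': 0},   # each member gets its own parameters
    split_rngs={'params': True})   # ...initialized from different RNGs

ens_key1, ens_key2 = random.split(random.PRNGKey(1))
ens_x = random.uniform(ens_key1, (8, 16))
ens_model = EnsembleDense(features=3)
ens_variables = ens_model.init(ens_key2, ens_x)
ens_y = ens_model.apply(ens_variables, ens_x)  # ens_y.shape == (4, 8, 3)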
Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation.
class RawDotProductAttention(nn.Module):
attn_dropout_rate: float = 0.1
train: bool = False
@nn.compact
def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
assert key.ndim == query.ndim
assert key.ndim == value.ndim
n = query.ndim
attn_weights = lax.dot_general(
query, key,
(((n-1,), (n - 1,)), ((), ())))
if bias is not None:
attn_weights += bias
norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim))
attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)
attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights,
deterministic=not self.train)
attn_weights = attn_weights.astype(dtype)
contract_dims = (
tuple(range(n - 1, attn_weights.ndim)),
tuple(range(0, n - 1)))
y = lax.dot_general(
attn_weights, value,
(contract_dims, ((), ())))
return y
class DotProductAttention(nn.Module):
qkv_features: Optional[int] = None
out_features: Optional[int] = None
train: bool = False
@nn.compact
def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
qkv_features = self.qkv_features or inputs_q.shape[-1]
out_features = self.out_features or inputs_q.shape[-1]
QKVDense = functools.partial(
nn.Dense, features=qkv_features, use_bias=False, dtype=dtype)
query = QKVDense(name='query')(inputs_q)
key = QKVDense(name='key')(inputs_kv)
value = QKVDense(name='value')(inputs_kv)
y = RawDotProductAttention(train=self.train)(
query, key, value, bias=bias, dtype=dtype)
y = nn.Dense(features=out_features, dtype=dtype, name='out')(y)
return y
class MultiHeadDotProductAttention(nn.Module):
qkv_features: Optional[int] = None
out_features: Optional[int] = None
batch_axes: Sequence[int] = (0,)
num_heads: int = 1
broadcast_dropout: bool = False
train: bool = False
@nn.compact
def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
qkv_features = self.qkv_features or inputs_q.shape[-1]
out_features = self.out_features or inputs_q.shape[-1]
# Make multiheaded attention from single-headed dimension.
Attn = nn.vmap(DotProductAttention,
in_axes=(None, None, None),
out_axes=2,
axis_size=self.num_heads,
variable_axes={'params': 0},
split_rngs={'params': True,
'dropout': not self.broadcast_dropout})
# Vmap across batch dimensions.
for axis in reversed(sorted(self.batch_axes)):
Attn = nn.vmap(Attn,
in_axes=(axis, axis, axis),
out_axes=axis,
variable_axes={'params': None},
split_rngs={'params': False, 'dropout': False})
# Run the vmap'd class on inputs.
y = Attn(qkv_features=qkv_features // self.num_heads,
out_features=out_features,
train=self.train,
name='attention')(inputs_q, inputs_kv, bias)
return y.mean(axis=-2)
key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3, 13, 64))
model = functools.partial(
MultiHeadDotProductAttention,
broadcast_dropout=False,
num_heads=2,
batch_axes=(0,))
init_variables = model(train=False).init({'params': key2}, x, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
initialized parameter shapes:
{'params': {'attention': {'key': {'kernel': (2, 64, 32)}, 'out': {'bias': (2, 64), 'kernel': (2, 32, 64)}, 'query': {'kernel': (2, 64, 32)}, 'value': {'kernel': (2, 64, 32)}}}}
output:
(3, 13, 2)
Scan#
Scan allows us to apply lax.scan to Modules, including their parameters and mutable variables. To use it we have to specify how we want each “kind” of variable to be transformed. For scanned variables we specify, similarly to vmap:
variable_broadcast: variable kinds to broadcast across the scan steps as constants.
variable_axes: a dict from variable kinds to an axis, to scan that kind along the given axis, e.g. for unique parameters at each step.
Or we specify that the variable kind is to be treated like a “carry” by passing it to the variable_carry argument.
Further, via split_rngs we specify, per RNG kind, whether or not to split the rng at each step.
class SimpleScan(nn.Module):
@nn.compact
def __call__(self, xs):
dummy_rng = random.PRNGKey(0)
init_carry = nn.LSTMCell.initialize_carry(dummy_rng,
xs.shape[:1],
xs.shape[-1])
LSTM = nn.scan(nn.LSTMCell,
in_axes=1, out_axes=1,
variable_broadcast='params',
split_rngs={'params': False})
return LSTM(name="lstm_cell")(init_carry, xs)
key1, key2 = random.split(random.PRNGKey(0), 2)
xs = random.uniform(key1, (1, 5, 2))
model = SimpleScan()
init_variables = model.init(key2, xs)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
y = model.apply(init_variables, xs)
print('output:\n', y)
initialized parameter shapes:
{'params': {'lstm_cell': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}}
output:
((Array([[-0.21091282, 0.16342616]], dtype=float32), Array([[-0.16439663, 0.08159931]], dtype=float32)), Array([[[ 0.01701735, 0.04418677],
[-0.11112795, 0.04402397],
[-0.13384211, 0.10686103],
[-0.16457099, 0.0799703 ],
[-0.16439663, 0.08159931]]], dtype=float32))
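As with lax.scan, the scanned cell returns a pair of (final carry, stacked per-step outputs); a brief sketch of unpacking the result:
final_carry, ys = model.apply(init_variables, xs)
print(ys.shape)  # (1, 5, 2): per-step outputs stacked along axis 1 (out_axes=1)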