Helpers

class flax.experimental.nnx.Dict(*args, **kwargs) [source]

class flax.experimental.nnx.List(*args, **kwargs) [source]

class flax.experimental.nnx.Sequential(*args, **kwargs) [source]

class flax.experimental.nnx.TrainState(graphdef: 'GraphDef[M]', params: 'State', tx: 'optax.GradientTransformation', opt_state: 'optax.OptState', step: 'jax.Array') [source]

    replace(**updates)

        Returns a new object replacing the specified fields with new values.