# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax functional core: Scopes."""
import contextlib
import dataclasses
import functools
import hashlib
import typing
from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional,
Sequence, Set, Tuple, TypeVar, Union)
from . import tracers
from flax import config
from flax import errors
from flax import struct
from flax import traceback_util
from .frozen_dict import freeze
from .frozen_dict import FrozenDict
from .frozen_dict import unfreeze
import jax
from jax import config as jax_config
from jax import numpy as jnp
from jax import random
import numpy as np
# Exclude this module from Flax's traceback filtering (presumably hides these
# frames from user-facing tracebacks; see `flax.traceback_util` to confirm).
traceback_util.register_exclusion(__file__)
# Generic value type carried by a `Variable`.
T = TypeVar('T')
# JAX PRNG keys and arrays are deliberately left untyped here.
PRNGKey = Any
Array = Any
# Maps an rng stream name (e.g. "params") to its PRNG key.
RNGSequences = Dict[str, PRNGKey]
# A filter over collection / rng-stream names: a bool (match all or nothing),
# a single name, a collection of names, or a `DenyList` (match all but `deny`).
Filter = Union[bool, str, typing.Collection[str], 'DenyList']
# When conditioning on filters we require explicit boolean comparisons.
# pylint: disable=g-bool-id-comparison
@dataclasses.dataclass(eq=True, frozen=True)
class DenyList:
  """An opt-out filter that matches every collection not matched by `deny`.

  Use it to make all collections mutable except a chosen few. For example, to
  make everything except the ``params`` collection mutable::

    nn.apply(fn, mutable=nn.DenyList(["params"]))

  Attributes:
    deny: Filter selecting the collections that are *not* matched (and hence
      not mutable when the DenyList is used as a mutability filter).
  """

  deny: Filter
# Filters over variable-collection names and over rng-stream names; both use
# the generic `Filter` semantics.
CollectionFilter = Filter
PRNGSequenceFilter = Filter
# A single collection: variable name -> value.
Collection = Mapping[str, Any]
MutableCollection = Dict[str, Any]
# A variable dict: collection name (e.g. "params") -> collection.
VariableDict = Mapping[str, Collection]
FrozenVariableDict = FrozenDict[str, Collection]
MutableVariableDict = Dict[str, MutableCollection]
# Static data that can be folded into a PRNG key (see `LazyRng`).
PRNGFoldable = Union[int, str]
class LazyRng(struct.PyTreeNode):
  """A JAX PRNGKey bundled with static data that is still to be folded in.

  Folding is deferred until `as_jax_rng` is called, at which point the whole
  `suffix` is hashed and folded into `rng` in a single step.
  """
  rng: PRNGKey
  suffix: Tuple[PRNGFoldable, ...] = struct.field(pytree_node=False)

  def as_jax_rng(self) -> PRNGKey:
    """Returns the concrete PRNG key with all of `suffix` folded in."""
    return _fold_in_static(self.rng, self.suffix)

  @staticmethod
  def create(rng: Union['LazyRng', PRNGKey],
             *suffix: PRNGFoldable) -> 'LazyRng':
    """Builds a `LazyRng` from a key (or another `LazyRng`) plus extra suffix.

    When ``config.flax_lazy_rng`` is disabled, the suffix is folded in eagerly
    with the legacy per-element scheme and the result has an empty suffix.
    """
    if config.flax_lazy_rng:
      if isinstance(rng, LazyRng):
        return LazyRng(rng.rng, rng.suffix + suffix)
      return LazyRng(rng, suffix)
    # Legacy (eager) mode: an incoming LazyRng must already be fully folded.
    base = rng
    if isinstance(base, LazyRng):
      assert not base.suffix
      base = base.rng
    return LazyRng(_legacy_rng_fold_in(base, suffix), ())
def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
  """Folds each element of `data` into `rng` one at a time (legacy scheme).

  A string is reduced to the first 4 bytes of its SHA-1 digest (read as a
  big-endian uint32) before folding; an int is folded in directly.

  Args:
    rng: the PRNG key to fold into.
    data: ints and/or strings to fold in, in order.

  Returns:
    The resulting PRNG key.

  Raises:
    ValueError: if an element of `data` is neither an int nor a str.
  """
  key = rng
  for item in data:
    if isinstance(item, str):
      digest = hashlib.sha1(item.encode('utf-8')).digest()
      as_uint = int.from_bytes(digest[:4], byteorder='big')
      key = random.fold_in(key, jnp.uint32(as_uint))
    elif isinstance(item, int):
      key = random.fold_in(key, item)
    else:
      raise ValueError(f'Expected int or string, got: {item}')
  return key
def _fold_in_static(rng: PRNGKey,
                    data: typing.Collection[PRNGFoldable]) -> PRNGKey:
  """Folds static data (strings & ints) into a PRNG key via one SHA-1 hash.

  All elements of `data` are hashed together. The first 4 bytes of the digest
  are then folded into `rng` with a single `jax.random.fold_in` call. This is
  cheaper than splitting the key, and it lets independent keys be derived in
  parallel from static paths.

  NOTE: elements are concatenated without a separator, so distinct sequences
  can hash identically. Examples are ``('ab', 'c')`` vs ``('a', 'bc')``, and
  ``('a',)`` vs ``('a', 0)`` (0 encodes to zero bytes). Changing the encoding
  would change every derived key, so it is kept as-is for reproducibility.

  Args:
    rng: the PRNG key to fold into.
    data: strings and non-negative ints to fold in, in order.

  Returns:
    The newly generated PRNG key, or `rng` itself when `data` is empty.

  Raises:
    ValueError: if an element is not a str or int, or is a negative int.
  """
  if not data:
    return rng
  m = hashlib.sha1()
  for x in data:
    if isinstance(x, str):
      m.update(x.encode('utf-8'))
    elif isinstance(x, int):
      if x < 0:
        # `int.to_bytes` would otherwise raise an opaque OverflowError for
        # negative values; report it like any other unsupported element.
        raise ValueError(f'Expected a non-negative int or string, got: {x}')
      m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
    else:
      raise ValueError(f'Expected int or string, got: {x}')
  d = m.digest()
  hash_int = int.from_bytes(d[:4], byteorder='big')
  return random.fold_in(rng, jnp.uint32(hash_int))
def is_filter_empty(filter_like: Filter) -> bool:
  """Returns True if `filter_like` matches no name at all.

  Args:
    filter_like: The filter to test.

  Returns:
    True for an empty collection, for the bool ``False``, or for a `DenyList`
    whose deny filter matches everything. A string filter always matches its
    own name and is therefore never empty.

  Raises:
    InvalidFilterError: if `filter_like` is not a supported filter type.
  """
  if isinstance(filter_like, bool):
    return filter_like is False
  if isinstance(filter_like, str):
    return False
  if isinstance(filter_like, typing.Collection):
    return not filter_like
  if isinstance(filter_like, DenyList):
    # If the deny filter matches an arbitrary internal stub name, it is taken
    # to match everything, so the DenyList itself matches nothing.
    return in_filter(filter_like.deny, '__flax_internal_stub__')
  raise errors.InvalidFilterError(filter_like)
def in_filter(filter_like: Filter, col: str) -> bool:
  """Checks whether the name `col` is selected by a filter.

  Works for both variable-collection filters and rng-sequence filters.

  Args:
    filter_like: a bool, a string, a collection of strings, or a `DenyList`.
    col: the name to test, e.g. a collection such as "params" or
      "batch_stats".

  Returns:
    For a bool, the bool itself. For a string, whether it equals `col`. For a
    collection, whether it contains `col`. For a `DenyList`, whether its deny
    filter does *not* select `col`.

  Raises:
    InvalidFilterError: if `filter_like` is not a supported filter type.
  """
  if isinstance(filter_like, bool):
    return filter_like
  if isinstance(filter_like, str):
    return filter_like == col
  if isinstance(filter_like, typing.Collection):
    return col in filter_like
  if isinstance(filter_like, DenyList):
    return not in_filter(filter_like.deny, col)
  raise errors.InvalidFilterError(filter_like)
def filter_to_set(x: Filter) -> Set[str]:
  """Converts a finite filter into the set of names it matches.

  Args:
    x: a filter that is ``False``, a string, or a collection of strings.

  Returns:
    The matched names as a new set.

  Raises:
    AssertionError: if `x` is ``True`` or a `DenyList`, since both match an
      unbounded set of names.
    InvalidFilterError: for any other unsupported filter type.
  """
  unbounded = x is True or isinstance(x, DenyList)
  assert not unbounded, 'Infinite set'
  if isinstance(x, str):
    return {x}
  if x is False:
    return set()
  if isinstance(x, typing.Collection):
    return set(x)
  raise errors.InvalidFilterError(x)
def union_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter matching anything matched by `a` or by `b` (logical or).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The union of the two filters, for instance
    ``union_filters('f1', ['f2']) == {'f1', 'f2'}``. The result is ``True``
    if either input is ``True``. Otherwise it is a `DenyList` if either input
    is one, and a set of names in all remaining cases.
  """
  if a is True or b is True:
    return True
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # not(A) | not(B) == not(A & B)
    return DenyList(intersect_filters(a.deny, b.deny))
  if a_denies or b_denies:
    deny, other = (a, b) if a_denies else (b, a)
    # not(D) | X == not(D - X)
    return DenyList(subtract_filters(deny.deny, other))
  return filter_to_set(a) | filter_to_set(b)
def subtract_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter matching what `a` matches but `b` does not.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The difference ``a - b`` as a filter: ``False`` if `b` is ``True``, a
    `DenyList` when the result is unbounded, and a set of names otherwise.
  """
  if b is True:
    return False
  if a is True:
    return DenyList(b)
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # not(A) - not(B) == B - A
    return subtract_filters(b.deny, a.deny)
  if a_denies:
    # not(A) - B == not(A | B)
    return DenyList(union_filters(a.deny, b))
  if b_denies:
    # A - not(B) == A & B
    return intersect_filters(a, b.deny)
  return filter_to_set(a) - filter_to_set(b)
def intersect_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter matching what both `a` and `b` match (logical and).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The intersection of the two filters, for instance
    ``intersect_filters('f1', ['f1', 'f2']) == {'f1'}``. If one input is
    ``True`` the other input is returned unchanged.
  """
  if a is True:
    return b
  if b is True:
    return a
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # not(A) & not(B) == not(B | A)
    return DenyList(union_filters(b.deny, a.deny))
  if a_denies or b_denies:
    deny, other = (a, b) if a_denies else (b, a)
    # not(D) & X == X - D
    return subtract_filters(other, deny.deny)
  return filter_to_set(a) & filter_to_set(b)
def group_collections(
    xs: VariableDict,
    col_filters: Sequence[CollectionFilter]) -> Sequence[MutableVariableDict]:
  """Partitions the collections of `xs` according to a list of filters.

  Filters are applied in order, and each collection goes to the first filter
  that matches it, so no collection appears in more than one group.
  Collections that no filter matches are left out of the result.

  Args:
    xs: a dictionary of variables, keyed by collection name.
    col_filters: the collection filters, in priority order.

  Returns:
    A tuple ``S`` with ``len(S) == len(col_filters)``. Each ``S[i]`` holds the
    collections selected by ``col_filters[i]`` among those not claimed by an
    earlier filter. Each collection is passed through an identity
    `jax.tree_map`.
  """
  unclaimed = list(xs.keys())
  result = []
  for col_filter in col_filters:
    group = {}
    leftover = []
    for col in unclaimed:
      if in_filter(col_filter, col):
        group[col] = jax.tree_map(lambda leaf: leaf, xs[col])
      else:
        leftover.append(col)
    unclaimed = leftover
    result.append(group)
  return tuple(result)
class Variable(Generic[T]):
  """A Variable object allows mutable access to a variable in a VariableDict.

  Variables are identified by a collection (e.g., "batch_stats") and a name
  (e.g., "moving_mean"). The `value` property gives access to the variable's
  content and can be assigned to for mutation. A Variable stores no data
  itself: every read and write is delegated to its scope.
  """

  def __init__(self, scope: 'Scope', collection: str, name: str):
    """Initializes a variable.

    Args:
      scope: The scope in which the variable is stored.
      collection: The collection of the variable (e.g., "params").
      name: The name of the variable (e.g., "dense").
    """
    self.scope = scope
    self.collection = collection
    self.name = name

  @property
  def value(self) -> T:
    """Returns the value of this Variable, as currently stored in its scope."""
    return self.scope.get_variable(self.collection, self.name)

  @value.setter
  def value(self, value: T):
    """Updates the value of this Variable via `Scope.put_variable`.

    The scope raises an error if the variable's collection is not mutable.
    """
    self.scope.put_variable(self.collection, self.name, value)

  def is_mutable(self) -> bool:
    """Checks if this Variable is mutable."""
    return self.scope.is_mutable_collection(self.collection)
class _ChildRNGSentinel:
  """Private marker type whose single instance is `child_rng_token`."""


# Unique token used in `Scope.rng_counters` keys of the form
# ``(child_rng_token, child_name)``. It marks counters that belong to a child
# scope rather than to an rng stream of the current scope.
child_rng_token = _ChildRNGSentinel()
class Scope:
  """A Scope allows easy access to variables and manages RNGs of a neural network layer.

  Scopes are purely functional and encapsulated in
  :class:`flax.linen.module.Module`, so users writing neural network code
  usually do not interact with ``Scopes`` directly.

  See `core design tests
  <https://github.com/google/flax/tree/main/tests/core/design>`_
  for a number of examples using ``Scopes``.
  """
  # Names already claimed in this scope by child scopes or variables.
  reservations: Set[str]

  def __init__(self,
               variables: MutableVariableDict,
               rngs: Optional[Dict[str, Union[PRNGKey, LazyRng]]] = None,
               name: Optional[str] = None,
               mutable: CollectionFilter = False,
               parent: Optional['Scope'] = None,
               path: Iterable[str] = ()):
    """Initializes a Scope.

    Args:
      variables: VariableDict to initialize the Scope with.
      rngs: RNGs used in this scope or one of the child scopes.
      name: name of this scope.
      mutable: A CollectionFilter determining which variables are mutable.
      parent: The parent scope.
      path: The path in the variable tree from the root scope to this scope.
    """
    # Normalize every rng to a LazyRng so that folding can be deferred.
    rngs = {k: LazyRng.create(v) for k, v in rngs.items()} if rngs else {}
    self._variables = variables
    self.parent = parent
    self.name = name
    self.path = tuple(path)
    self.rngs = rngs
    self.mutable = mutable
    self._root = parent.root if parent else None
    # Record the trace level at creation; `_validate_trace_level` checks
    # against it before mutations and rng draws.
    self.trace_level = tracers.trace_level(tracers.current_trace())
    self.rng_counters = {key: 0 for key in self.rngs}
    self.reservations = set()
    self._invalid = False

  def __eq__(self, other: Any) -> bool:
    # If the root variable dict and path are the same, then two scopes behave
    # identically. Effectively, a scope is nothing more than a cursor into a
    # variable dict and an rng counter dict.
    if not isinstance(other, Scope):
      return False
    if self is other:
      return True
    return (self.root._variables is other.root._variables and
            self.path == other.path and
            self.rng_counters is other.rng_counters)

  def __hash__(self) -> int:
    # Must stay consistent with __eq__: hash the same identities and path.
    return hash((id(self.root._variables), self.path, id(self.rng_counters)))

  @property
  def root(self) -> 'Scope':
    """Returns the root scope (this scope itself if it has no parent)."""
    return self._root or self

  @property
  def path_text(self) -> str:
    """Returns the path as a human readable string with slashes between parts."""
    return '/' + '/'.join(self.path)

  @property
  def invalid(self) -> bool:
    """Returns true if this scope is invalidated as a result of `Scope.temporary`."""
    return self._invalid

  def _check_valid(self):
    """Raises `InvalidScopeError` if this scope has been invalidated."""
    if self._invalid:
      raise errors.InvalidScopeError(self.name)

  @contextlib.contextmanager
  def temporary(self):
    """Returns a context manager that will invalidate this Scope when leaving the context."""
    try:
      yield self
    finally:
      self.invalidate()

  def invalidate(self):
    """Invalidates the Scope."""
    self._invalid = True

  def mutable_variables(self) -> VariableDict:
    """Returns an immutable copy of the mutable variables belonging to this Scope."""
    self._populate_collections()
    xs = {k: v for k, v in self._variables.items()
          if in_filter(self.mutable, k)}
    return freeze(xs)

  def variables(self) -> VariableDict:
    """Returns an immutable copy of the variables belonging to this Scope."""
    self._populate_collections()
    return freeze(self._variables)

  def _validate_trace_level(self):
    """Delegates to `tracers.check_trace_level` with this scope's trace level."""
    tracers.check_trace_level(self.trace_level)

  def rewound(self, rewind_rngs: bool = False) -> 'Scope':
    """Returns a rewound version of this Scope.

    Args:
      rewind_rngs: if true, reset the RNG counter of this scope.

    Returns:
      A new Scope over the same variables, rngs, name, mutability, parent and
      path, with no reservations. Unless `rewind_rngs` is true, it shares this
      scope's rng counters.
    """
    self._check_valid()
    # `path` must be forwarded. Otherwise the rewound scope reports the root
    # path in error messages and compares unequal to this scope.
    scope = Scope(self._variables, self.rngs, self.name, self.mutable,
                  self.parent, path=self.path)
    if not rewind_rngs:
      scope.rng_counters = self.rng_counters
    return scope

  def reserve(self, name: str):
    """Reserves a name for a child Scope or Variable.

    Args:
      name: the name to reserve.

    Raises:
      TypeError: if `name` is not a string.
      ValueError: if `name` is already reserved in this scope.
    """
    if not isinstance(name, str):
      raise TypeError(f'The type of scope "{name}" should be string but '
                      f'it is {type(name)}')
    if name in self.reservations:
      raise ValueError(f'Duplicate use of scope name: "{name}"')
    self.reservations.add(name)

  def default_name(self, prefix: str) -> str:
    """Generates an unreserved name with the given prefix.

    Args:
      prefix: prefix to use for generating an unreserved name.

    Returns:
      The first name ``f'{prefix}{i}'`` (i = 0, 1, ...) not yet reserved.
    """
    i = 0
    while True:
      name = f'{prefix}{i}'
      if name not in self.reservations:
        return name
      i += 1

  def push(self,
           name: Optional[str] = None,
           prefix: str = '',
           reuse=False) -> 'Scope':
    """Creates a child Scope.

    Args:
      name: optional name of the child.
      prefix: prefix used for generating the name if `name` is `None`.
      reuse: if True will return a pre-existing child scope with the given name
        instead of throwing an error.

    Returns:
      The child scope.
    """
    self._check_valid()
    self._validate_trace_level()
    if name is None:
      name = self.default_name(prefix)
    if not reuse or name not in self.reservations:
      self.reserve(name)
    # The child's rngs are this scope's rngs with the child name appended.
    rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()}
    # The child's counters are stored in this scope's counter dict, so a child
    # pushed again under the same name continues its rng sequences.
    rng_key = (child_rng_token, name)
    if rng_key in self.rng_counters:
      rng_counters = self.rng_counters.get(rng_key)
    else:
      rng_counters = {key: 0 for key in rngs}
      self.rng_counters[rng_key] = rng_counters
    scope = Scope({},
                  name=name,
                  rngs=rngs,
                  parent=self,
                  mutable=self.mutable,
                  path=self.path + (name,))
    scope.rng_counters = rng_counters
    return scope

  def child(self,
            fn: Callable[..., Any],
            name: Optional[str] = None,
            prefix: Optional[str] = None,
            named_call: bool = True,
            **partial_kwargs) -> Callable[..., Any]:
    """Partially applies a child scope to fn.

    When calling the returned function multiple times variables will be reused.

    Args:
      fn: the function to partially apply the child Scope to.
      name: optional name of the child.
      prefix: prefix used for generating name if it is `None`.
      named_call: if true, `fn` will be wrapped with `lift.named_call`. The XLA
        profiler will use this to name tag the computation.
      **partial_kwargs: additional kwargs partially applied to `fn`.

    Returns:
      The function with a partially applied scope.
    """
    if name is None:
      if prefix is None:
        prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
      name = self.default_name(prefix)
    scope = self.push(name)
    if named_call:
      # We import named_call at runtime to avoid a circular import issue.
      from . import lift  # pylint: disable=g-import-not-at-top
      fn = lift.named_call(fn, name)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      kwargs = dict(partial_kwargs, **kwargs)
      # Each call gets a fresh rewound scope: reservations are cleared, so the
      # same variable names can be looked up again on every call.
      return fn(scope.rewound(), *args, **kwargs)

    return wrapper

  def is_mutable_collection(self, col: str) -> bool:
    """Returns true if the collection `col` is mutable."""
    return in_filter(self.mutable, col)

  def is_collection_empty(self, col: str) -> bool:
    """Returns true if the collection is absent or empty in the root variables."""
    if col in self.root._variables:  # pylint: disable=protected-access
      return not self.root._variables[col]  # pylint: disable=protected-access
    return True

  def _mutable_collection(self, col: str) -> MutableCollection:
    """Returns the collection `col` as a mutable object.

    Missing dicts are created along the parent chain so that writes land in
    the root variable dict.
    """
    assert self.is_mutable_collection(col), f'Collection {col} is not mutable'
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._mutable_collection(col)  # pylint: disable=protected-access
        if self.name not in parent_col:
          parent_col[self.name] = {}
        self._variables[col] = parent_col[self.name]
      else:
        self._variables[col] = {}
    return self._variables[col]

  def _collection(self, col: str) -> Collection:
    """Returns a collection of variables of collection `col`.

    Returns an empty FrozenDict (without caching it) if the collection does
    not exist for this scope.
    """
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._collection(col)  # pylint: disable=protected-access
        if self.name not in parent_col:
          return FrozenDict()
        self._variables[col] = parent_col[self.name]
      else:
        return FrozenDict()
    return self._variables[col]

  def has_rng(self, name: str) -> bool:
    """Returns true if a PRNGSequence with name `name` exists."""
    return name in self.rngs

  def make_rng(self, name: str) -> PRNGKey:
    """Generates a PRNGKey from the PRNGSequence with name `name`.

    Each call increments this scope's counter for `name` and folds the new
    counter value into the sequence's key.

    Raises:
      InvalidRngError: if this scope has no PRNGSequence named `name`.
    """
    if not self.has_rng(name):
      raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    self._check_valid()
    self._validate_trace_level()
    self.rng_counters[name] += 1
    return LazyRng.create(self.rngs[name], self.rng_counters[name]).as_jax_rng()

  def get_variable(self, col: str, name: str, default: Any = None) -> Any:
    """Retrieves the value of a Variable.

    Args:
      col: the variable collection.
      name: the name of the variable.
      default: the default value to return if the variable does not exist in
        this scope.

    Returns:
      The value of the input variable, or the default value if the variable
      doesn't exist in this scope.
    """
    variables = self._collection(col)
    if name in variables:
      return variables[name]
    else:
      return default

  def has_variable(self, col: str, name: str) -> bool:
    """Returns true if the given variable exists in this scope.

    Args:
      col: the collection of the variable.
      name: the name of the variable.

    Returns:
      Whether `name` is present in collection `col` of this scope.
    """
    variables = self._collection(col)
    return name in variables

  def put_variable(self, col: str, name: str, value: Any):
    """Updates the value of the given variable if it is mutable, or an error otherwise.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      value: the new value of the given variable.

    Raises:
      ModifyScopeVariableError: if collection `col` is not mutable.
    """
    self._check_valid()
    self._validate_trace_level()
    if not self.is_mutable_collection(col):
      raise errors.ModifyScopeVariableError(col, name, self.path_text)
    variables = self._mutable_collection(col)
    variables[name] = value

  def variable(self, col: str, name: str,  # pylint: disable=keyword-arg-before-vararg
               init_fn: Optional[Callable[..., T]] = None,
               *init_args) -> Variable[T]:
    """Creates a variable if it doesn't exist yet in this scope and returns it.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      init_fn: a function called as ``init_fn(*init_args)`` to compute the
        initial value when the variable does not exist yet. Unlike `param`,
        no PRNGKey is passed implicitly. If None, the variable must already
        exist, otherwise an error is raised.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The variable.

    Raises:
      ValueError: if `name` is already reserved in this scope.
      ScopeCollectionNotFound: if the variable is missing, cannot be created,
        and collection `col` is empty.
      ScopeVariableNotFoundError: if the variable is missing and cannot be
        created (immutable collection or no `init_fn`).
    """
    self.reserve(name)
    if not self.has_variable(col, name):
      if not self.is_mutable_collection(col) or init_fn is None:
        if self.is_collection_empty(col):
          raise errors.ScopeCollectionNotFound(col, name, self.path_text)
        raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
      init_value = init_fn(*init_args)
      self.put_variable(col, name, init_value)
    return Variable(self, col, name)

  def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Creates a parameter if it doesn't exist yet in this scope and returns it.

    If the parameter exists already, the existing value is simply returned,
    after checking that its leaf shapes match what `init_fn` would produce.

    Args:
      name: the name of the parameter.
      init_fn: a function called as ``init_fn(rng, *init_args)``, where `rng`
        is drawn from the "params" PRNGSequence. For an existing parameter it
        is only shape-evaluated with `jax.eval_shape`.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The parameters.

    Raises:
      ScopeParamShapeError: if an existing parameter's shape does not match.
      ScopeCollectionNotFound: if the parameter is missing, "params" is not
        mutable and the "params" collection is empty.
      ScopeParamNotFoundError: if the parameter is missing and "params" is not
        mutable.
    """
    self.reserve(name)
    if self.has_variable('params', name):
      abs_rng = jax.ShapeDtypeStruct(random.default_prng_impl().key_shape,
                                     jnp.uint32)
      value = self.get_variable('params', name)
      # Validate that the shape of the init_fn output is the same as the shape
      # of the existing parameter. This is to make sure that the hparams set up
      # in a Flax Module match the shapes coming in during apply, and if not,
      # catch it with an error message.
      # NOTE: We could consider moving this to `self.`
      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
      abs_value_flat = jax.tree_leaves(abs_value)
      value_flat = jax.tree_leaves(value)
      for val, abs_val in zip(value_flat, abs_value_flat):
        # NOTE: We could check dtype consistency here as well but its
        # usefulness is less obvious. We might intentionally change the dtype
        # for inference to a half float type for example.
        if jnp.shape(val) != jnp.shape(abs_val):
          raise errors.ScopeParamShapeError(name, self.path_text,
                                            jnp.shape(val), jnp.shape(abs_val))
    else:
      if not self.is_mutable_collection('params'):
        if self.is_collection_empty('params'):
          raise errors.ScopeCollectionNotFound('params', name, self.path_text)
        raise errors.ScopeParamNotFoundError(name, self.path_text)
      value = init_fn(self.make_rng('params'), *init_args)
      self.put_variable('params', name, value)
    return value

  def _populate_collections(self):
    """Loads every collection present in the root into this scope's cache."""
    collections = self.root._variables.keys()  # pylint: disable=protected-access
    for col in collections:
      self._collection(col)
def _unfreeze_variables(variables, mutable):
  """Copies `variables`, unfreezing mutable collections and freezing the rest.

  Args:
    variables: a mapping from collection name to collection.
    mutable: the collection filter selecting which collections to unfreeze.

  Returns:
    A new plain dict with the same keys. Collections matched by `mutable` are
    passed through `unfreeze`; all others are passed through `freeze`.
  """
  return {
      col: unfreeze(value) if in_filter(mutable, col) else freeze(value)
      for col, value in variables.items()
  }
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
  """Binds variables and rngs to a new ``Scope``.

  ``bind`` provides a ``Scope`` instance without transforming a function with
  ``apply``. This is particularly useful for debugging and interactive use
  cases such as notebooks, where wrapping code in a function would limit the
  ability to split it across cells.

  A ``Scope`` instance is a stateful object. Idiomatic JAX is functional, so a
  ``Scope`` does not mix well with vanilla JAX APIs. We therefore recommend
  ``apply`` when code should be reusable and compatible across the JAX
  software ecosystem.

  Args:
    variables: Variable dictionary to bind.
    rngs: RNGs to bind.
    mutable: Which variable collections to treat as mutable.

  Returns:
    A new scope with the variables and rngs bound to it.

  Raises:
    ApplyScopeInvalidVariablesTypeError: if `variables` is not a valid
      variable dict.
    InvalidRngError: if `rngs` is given but is not a valid rng dict.
  """
  if not _is_valid_variables(variables):
    raise errors.ApplyScopeInvalidVariablesTypeError()
  if rngs is not None and not _is_valid_rngs(rngs):
    raise errors.InvalidRngError(
        'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
  return Scope(
      _unfreeze_variables(variables, mutable), rngs=rngs, mutable=mutable)
def apply(fn: Callable[..., Any],
          mutable: CollectionFilter = False) -> Callable[..., Any]:
  """Functionalizes a `Scope` function.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.

  Returns:
    A wrapper ``wrapper(variables, *args, rngs=None, **kwargs)``. It binds a
    temporary root scope and calls ``fn(scope, *args, **kwargs)``. If
    `mutable` is ``False`` it returns just the output. Otherwise it returns
    ``(output, mutable_variables)``.
  """

  @functools.wraps(fn)
  def wrapper(variables: VariableDict,
              *args,
              rngs: Optional[RNGSequences] = None,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    # Catch the common mistake of passing {'params': {'params': ...}}.
    params = variables['params'] if 'params' in variables else None
    if isinstance(params, (dict, FrozenDict)) and 'params' in params:
      raise errors.ApplyScopeInvalidVariablesStructureError(variables)

    with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
      y = fn(root, *args, **kwargs)
    if mutable is False:
      return y
    return y, root.mutable_variables()

  return wrapper
def init(fn: Callable[..., Any],
         mutable: CollectionFilter = True) -> Callable[..., Any]:
  """Functionalize a `Scope` function for initialization.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.

  Returns:
    `fn` with the scope partially applied. The wrapper takes either a single
    PRNGKey (used as the "params" rng) or a dict/FrozenDict mapping rng names
    to PRNGKeys. It runs `fn` on empty variables and returns
    ``(output, variables)``, unless `mutable` is ``False``.

  Raises:
    ValueError: (from the wrapper) if the first argument is neither a PRNGKey
      nor a valid mapping of names to PRNGKeys.
  """

  @functools.wraps(fn)
  def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
    if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
      raise ValueError('First argument passed to an init function should be a '
                       '`jax.PRNGKey` or a dictionary mapping strings to '
                       '`jax.PRNGKey`.')
    # `_is_valid_rngs` also accepts a FrozenDict, so treat it as an rng
    # mapping too. Otherwise it would be wrapped as {'params': FrozenDict}
    # and then rejected by `bind`.
    if not isinstance(rngs, (dict, FrozenDict)):
      rngs = {'params': rngs}
    return apply(fn, mutable=mutable)({}, *args, rngs=rngs, **kwargs)

  return wrapper
def _is_valid_collection(col: VariableDict):
  """Returns True if `col` is a dict or FrozenDict whose keys are all strings.

  Values are not inspected, since a collection may store arbitrary values.
  """
  return isinstance(col, (FrozenDict, dict)) and all(
      isinstance(name, str) for name in col.keys())
def _is_valid_variables(variables: VariableDict) -> bool:
  """Checks whether the given variable dict is valid.

  A variable dict is valid when every key is a string and every value passes
  `_is_valid_collection`.

  Args:
    variables: A variable dict.

  Returns:
    True if `variables` is a valid variable dict.
  """
  return all(
      isinstance(name, str) and _is_valid_collection(col)
      for name, col in variables.items())
def _is_valid_rng(rng: Array):
  """Checks whether `rng` is a valid JAX PRNGKey, also handling custom PRNGs.

  With ``jax_enable_custom_prng`` set, only new-style
  ``jax.random.KeyArray`` instances are accepted. Otherwise the key must be a
  NumPy/JAX array with the default PRNG implementation's key shape and dtype
  ``uint32``.
  """
  if jax_config.jax_enable_custom_prng:
    return isinstance(rng, jax.random.KeyArray)
  # Old-style JAX PRNGKeys are plain uint32 arrays.
  return (isinstance(rng, (np.ndarray, jnp.ndarray)) and
          rng.shape == random.default_prng_impl().key_shape and
          rng.dtype == jnp.uint32)
def _is_valid_rngs(rngs: RNGSequences):
  """Returns True if `rngs` is a dict/FrozenDict of string -> valid PRNGKey."""
  if not isinstance(rngs, (FrozenDict, dict)):
    return False
  return all(
      isinstance(stream, str) and _is_valid_rng(key)
      for stream, key in rngs.items())