# Source code for flax.core.scope

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax functional core: Scopes."""

import collections
import contextlib
import dataclasses
import functools
import hashlib
import typing
from typing import (
  Any,
  Callable,
  Dict,
  Generic,
  Iterable,
  Literal,
  Mapping,
  Optional,
  Sequence,
  Set,
  Tuple,
  TypeVar,
  Union,
  cast,
  overload,
)

import jax
import numpy as np
from jax import numpy as jnp
from jax import random, tree_util

from flax import config as config
from flax import configurations as legacy_config  # only for flax_lazy_rng
from flax import errors, struct, traceback_util
from flax.ids import uuid
from flax.typing import (
  PRNGKey,
  Array,
  RNGSequences,
  Collection,
  MutableCollection,
  VariableDict,
  FrozenVariableDict as FrozenVariableDict,
  MutableVariableDict,
  PRNGFoldable,
)

from . import meta, partial_eval, tracers
from .frozen_dict import FrozenDict, freeze, unfreeze

traceback_util.register_exclusion(__file__)

T = TypeVar('T')


Filter = Union[bool, str, typing.Collection[str], 'DenyList']

# When conditioning on filters we require explicit boolean comparisons.
# pylint: disable=g-bool-id-comparison


@dataclasses.dataclass(eq=True, frozen=True)
class DenyList:
  """Opt-out mutability filter.

  A regular filter names the collections it selects; a ``DenyList`` instead
  names the collections it rejects and selects everything else.

  For example, to make every collection except ``params`` mutable::

    nn.apply(fn, mutable=nn.DenyList(["params"]))

  Attributes:
    deny: Filter describing the collections that are *not* mutable.
  """

  deny: Filter


CollectionFilter = Filter
PRNGSequenceFilter = Filter


class LazyRng(struct.PyTreeNode):
  """A PRNGKey paired with static data that is folded in only on demand.

  ``rng`` is the wrapped key; ``suffix`` is a tuple of static values that
  ``as_jax_rng`` folds into ``rng`` via ``_fold_in_static``. ``suffix`` is
  declared with ``pytree_node=False``.
  """

  rng: PRNGKey
  suffix: Tuple[PRNGFoldable, ...] = struct.field(pytree_node=False)

  def as_jax_rng(self) -> PRNGKey:
    """Folds the pending suffix into the key and returns the result."""
    return _fold_in_static(self.rng, self.suffix)

  @staticmethod
  def create(
    rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable
  ) -> 'LazyRng':
    """Builds a LazyRng from a key (or another LazyRng) plus extra suffix data.

    When ``legacy_config.flax_lazy_rng`` is set the suffix is only recorded;
    otherwise it is folded in immediately with the legacy scheme and the
    returned LazyRng carries an empty suffix.
    """
    wraps_lazy = isinstance(rng, LazyRng)
    if legacy_config.flax_lazy_rng:
      # Lazy mode: accumulate the suffix without touching the key.
      if wraps_lazy:
        return LazyRng(rng.rng, rng.suffix + suffix)
      return LazyRng(rng, suffix)
    # Eager (legacy) mode: an incoming LazyRng must not carry pending data.
    if wraps_lazy:
      assert not rng.suffix
      rng = rng.rng
    return LazyRng(_legacy_rng_fold_in(rng, suffix), ())


def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
  """Folds each item of ``data`` into ``rng``, one ``fold_in`` call per item.

  Strings are reduced to the first four bytes (big-endian) of their SHA-1
  digest before being folded in; ints are folded in directly.

  Raises:
    ValueError: if an item is neither a string nor an int.
  """
  key = rng
  for item in data:
    if isinstance(item, str):
      digest = hashlib.sha1(item.encode('utf-8')).digest()
      word = int.from_bytes(digest[:4], byteorder='big')
      key = random.fold_in(key, jnp.uint32(word))  # type: ignore
    elif isinstance(item, int):
      key = random.fold_in(key, item)
    else:
      raise ValueError(f'Expected int or string, got: {item}')
  return key


def _fold_in_static(
  rng: PRNGKey, data: typing.Collection[PRNGFoldable]
) -> PRNGKey:
  """Folds static data (strings & ints) into a PRNG key via one SHA-1 hash.

  All items are fed into a single SHA-1 hasher; the first four digest bytes
  (big-endian) are then folded into ``rng`` with one ``fold_in`` call.

  Args:
   rng: the rng to fold the data into.
   data: the strings and/or ints to be folded in.

  Returns:
   The newly generated PRNG key, or ``rng`` unchanged when ``data`` is empty.

  Raises:
   ValueError: if an item is neither a string nor an int.
  """
  if not data:
    return rng
  hasher = hashlib.sha1()
  for item in data:
    if config.flax_fix_rng_separator:
      # Separate the encoded items to avoid collisions such as
      # ("ab", "c") versus ("a", "bc").
      hasher.update(b'\00')
    if isinstance(item, str):
      chunk = item.encode('utf-8')
    elif isinstance(item, int):
      # Minimal-width big-endian encoding of the integer.
      chunk = item.to_bytes((item.bit_length() + 7) // 8, byteorder='big')
    else:
      raise ValueError(f'Expected int or string, got: {item}')
    hasher.update(chunk)
  word = int.from_bytes(hasher.digest()[:4], byteorder='big')
  return random.fold_in(rng, jnp.uint32(word))  # type: ignore


def is_filter_empty(filter_like: Filter) -> bool:
  """Returns True if `filter_like` is an empty filter.

  Args:
    filter_like: The filter to test.

  Returns:
    True for an empty collection, for the bool ``False``, or for a DenyList
    whose deny filter matches everything. A string filter is never empty.

  Raises:
    errors.InvalidFilterError: if `filter_like` is not a recognised filter.
  """
  if isinstance(filter_like, str):
    # Checked first: a str is itself a Collection but always names one entry.
    return False
  if isinstance(filter_like, bool):
    return not filter_like
  if isinstance(filter_like, typing.Collection):
    return not filter_like
  if not isinstance(filter_like, DenyList):
    raise errors.InvalidFilterError(filter_like)
  # If an arbitrary stub collection is denied, the deny filter matches
  # everything, so nothing is left for the DenyList to select.
  return in_filter(filter_like.deny, '__flax_internal_stub__')


def in_filter(filter_like: Filter, col: str) -> bool:
  """Tells whether collection `col` is selected by `filter_like`.

  Used for both collection filters and rng sequence filters.

  Args:
    filter_like: a filter (a boolean, a string, a collection of strings, or a
      DenyList).
    col: a collection name, for instance "params" or "batch_stats".

  Returns:
    True if `filter_like` is True, equals `col`, or is a collection containing
    `col`; for a DenyList, the negation of matching its `deny` filter.

  Raises:
    errors.InvalidFilterError: if `filter_like` is not a recognised filter.
  """
  if isinstance(filter_like, str):
    # Checked first: a str is itself a Collection, but must match exactly.
    return filter_like == col
  if isinstance(filter_like, bool):
    return filter_like
  if isinstance(filter_like, typing.Collection):
    return col in filter_like
  if not isinstance(filter_like, DenyList):
    raise errors.InvalidFilterError(filter_like)
  denied = in_filter(filter_like.deny, col)
  return not denied


def filter_to_set(x: Filter) -> Set[str]:
  """Turns a finite Filter into an explicit set of collection names.

  Args:
    x: a filter (``False``, a string, or a collection of strings). ``True``
      and DenyList describe an unbounded set and are rejected by an assert.

  Returns:
    The input filter represented as a set of strings.

  Raises:
    errors.InvalidFilterError: if `x` is not a recognised filter.
  """
  unbounded = x is True or isinstance(x, DenyList)
  assert not unbounded, 'Infinite set'
  if x is False:
    return set()
  if isinstance(x, str):
    return {x}
  if not isinstance(x, typing.Collection):
    raise errors.InvalidFilterError(x)
  return set(x)


def union_filters(a: Filter, b: Filter) -> Filter:
  """Computes the union of two filters (similar to a logical or).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    A filter selecting what either input selects. For instance,
    `union_filters('f1', ['f2']) = {'f1', 'f2'}`.
  """
  if a is True or b is True:
    return True
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # not(A) | not(B) == not(A & B)
    return DenyList(intersect_filters(a.deny, b.deny))
  if a_denies:
    # not(A) | b == not(A - b)
    return DenyList(subtract_filters(a.deny, b))
  if b_denies:
    # Mirror image of the case above.
    return DenyList(subtract_filters(b.deny, a))
  left = filter_to_set(a)
  right = filter_to_set(b)
  return left.union(right)


def subtract_filters(a: Filter, b: Filter) -> Filter:
  """Computes the filter `a` with everything matched by `b` removed.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    A filter matching the values in a that are not in b.
  """
  if b is True:
    # Removing everything leaves nothing.
    return False
  if a is True:
    return DenyList(b)
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies:
    if b_denies:
      # not(A) - not(B) == B - A
      return subtract_filters(b.deny, a.deny)
    # not(A) - b == not(A | b)
    return DenyList(union_filters(a.deny, b))
  if b_denies:
    # a - not(B) == a & B
    return intersect_filters(a, b.deny)
  minuend = filter_to_set(a)
  subtrahend = filter_to_set(b)
  return minuend - subtrahend


def intersect_filters(a: Filter, b: Filter) -> Filter:
  """Computes the intersection of two filters (similar to a logical and).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    A filter selecting only what both inputs select. For instance,
    `intersect_filters('f1', ['f1', 'f2']) = {'f1'}`.
  """
  if a is True:
    return b
  if b is True:
    return a
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # not(A) & not(B) == not(B | A)
    return DenyList(union_filters(b.deny, a.deny))
  if b_denies:
    # a & not(B) == a - B
    return subtract_filters(a, b.deny)
  if a_denies:
    # not(A) & b == b - A
    return subtract_filters(b, a.deny)
  left = filter_to_set(a)
  right = filter_to_set(b)
  return left.intersection(right)


def group_collections(
  xs: VariableDict, col_filters: Sequence[CollectionFilter]
) -> Sequence[MutableVariableDict]:
  """Partitions the collections of `xs` according to a list of filters.

  The filters are applied in order. A collection goes into the group of the
  first filter that matches it and is not offered to later filters, so each
  key of `xs` appears in at most one group.

  Args:
    xs: a dictionary of variables, keyed by collections (strings).
    col_filters: a list of collection filters.

  Returns:
    A tuple S with `len(S) == len(col_filters)`, where `S[i]` holds the
    collections matched by `col_filters[i]` among those not already claimed
    by an earlier filter.
  """
  pending: Iterable[str] = xs.keys()
  grouped = []
  for col_filter in col_filters:
    selected = {}
    leftover = []
    for col in pending:
      if not in_filter(col_filter, col):
        leftover.append(col)
        continue
      # Identity tree_map: the collection is passed through
      # jax.tree_util.tree_map rather than stored by reference.
      selected[col] = jax.tree_util.tree_map(lambda leaf: leaf, xs[col])
    grouped.append(selected)
    pending = leftover
  return tuple(grouped)


class Variable(Generic[T]):
  """A Variable object allows mutable access to a variable in a VariableDict.

  Variables are identified by a collection (e.g., "batch_stats") and a name
  (e.g., "moving_mean"). The value property gives access to the variable's
  content and can be assigned to for mutation.
  """

  def __init__(self, scope: 'Scope', collection: str, name: str, unbox: bool):
    """Initializes a variable.

    Args:
      scope: The scope in which the variable is stored.
      collection: The collection of the variable (e.g., "params").
      name: The name of the variable (e.g., "dense").
      unbox: Whether to unbox boxed values with metadata.
    """
    self._id = uuid()
    self.scope = scope
    self.collection = collection
    self.name = name
    self.unbox = unbox

  @property
  def value(self) -> T:
    """Returns the value of this Variable.

    The stored value is looked up in the scope on every access and passed
    through ``meta.unbox`` when ``unbox`` is set.
    """
    v = self.scope.get_variable(self.collection, self.name)
    return meta.unbox(v) if self.unbox else v

  @value.setter
  def value(self, value: T):
    """Updates the value of this Variable.

    When ``unbox`` is set and the currently stored value contains axis
    metadata with the same tree structure as ``value``, the new value is
    re-boxed with ``meta.replace_boxed`` before being written to the scope.
    """
    if self.unbox:
      cur = self.scope.get_variable(self.collection, self.name)
      cur_struct = tree_util.tree_structure(cur, is_leaf=meta.is_axis_metadata)
      value_struct = tree_util.tree_structure(
        value, is_leaf=meta.is_axis_metadata
      )
      has_meta = any(map(meta.is_axis_metadata, cur_struct.flatten_up_to(cur)))
      # Only re-box when the structures line up and there is metadata to keep.
      if cur_struct == value_struct and has_meta:
        value = meta.replace_boxed(cur, value)

    self.scope.put_variable(self.collection, self.name, value)

  def is_mutable(self) -> bool:
    """Checks if this Variable is mutable."""
    return self.scope.is_mutable_collection(self.collection)
class _ChildRNGSentinel:
  pass


# used to identify that an rng counter is meant for a child scope
child_rng_token = _ChildRNGSentinel()


class _DefaultSentinel:
  pass


# used to denote no default flag value on scope
no_flag = _DefaultSentinel()


class Scope:
  """A Scope allows easy access to variables and manages RNGS of a neural network layer.

  Scopes are purely functional and encapsulated in
  :class:`flax.linen.module.Module`, so users writing neural network code
  generally do not interact with ``Scopes`` directly.

  See `core design tests
  <https://github.com/google/flax/tree/main/tests/core/design>`_
  for a number of examples using ``Scopes``.
  """

  reservations: Dict[str, Set[Optional[str]]]

  def __init__(
    self,
    variables: MutableVariableDict,
    rngs: Optional[Union[RNGSequences, Dict[str, LazyRng]]] = None,
    name: Optional[str] = None,
    mutable: CollectionFilter = False,
    parent: Optional['Scope'] = None,
    path: Iterable[str] = (),
    debug_path: Iterable[str] = (),
    flags: Optional[Mapping] = None,
  ):
    """Initializes a Scope.

    Args:
      variables: VariableDict to initialize the Scope with.
      rngs: RNGs used in this scope or one of the child scopes.
      name: name of this scope.
      mutable: A CollectionFilter determining which variables are mutable.
      parent: The parent scope.
      path: The path in the variable tree from the root scope to this scope. It
        exactly matches the module path.
      debug_path: Similar to path but could contain transformation decorators.
        Falls back to ``path`` when empty.
      flags: internal flags.
    """
    rngs = {k: LazyRng.create(v) for k, v in rngs.items()} if rngs else {}
    self._variables = variables
    self.parent = parent
    self.name = name
    self.path = tuple(path)
    self.debug_path = tuple(debug_path) or self.path
    self.rngs = rngs
    self.mutable = mutable
    self.flags = freeze({} if flags is None else flags)

    self._root = parent.root if parent else None
    self.trace_level = tracers.trace_level(tracers.current_trace())

    self.rng_counters = {key: 0 for key in self.rngs}
    self.reservations = collections.defaultdict(set)

    self._invalid = False

  def __eq__(self, other: Any) -> bool:
    # If the root variable dict and path are the same, then two scopes behave
    # identically. Effectively, a scope is nothing more than a cursor into a
    # variable dict and an rng counter dict.
    if not isinstance(other, Scope):
      return False
    if self is other:
      return True
    return (
      self.root._variables is other.root._variables
      and self.path == other.path
      and self.rng_counters is other.rng_counters
    )

  def __hash__(self) -> int:
    # see __eq__
    return hash((id(self.root._variables), self.path, id(self.rng_counters)))

  @property
  def root(self) -> 'Scope':
    """The root scope of this scope's tree (the scope itself if it has no parent)."""
    return self._root or self

  @property
  def path_text(self) -> str:
    """Returns the debug path as a human readable string."""
    return '/' + '/'.join(self.debug_path)

  @property
  def invalid(self) -> bool:
    """Returns true if this scope is invalidated as a result of `Scope.temporary`."""
    return self._invalid

  def _check_valid(self):
    """Raises ``errors.InvalidScopeError`` if this scope has been invalidated."""
    if self._invalid:
      raise errors.InvalidScopeError(self.name)

  @contextlib.contextmanager
  def temporary(self):
    """Returns a context manager that will invalidate this Scope when leaving the context."""
    try:
      yield self
    finally:
      self.invalidate()

  def invalidate(self):
    """Invalidates the Scope."""
    self._invalid = True

  def mutable_variables(self) -> Union[VariableDict, Dict[str, Any]]:
    """Returns the mutable collections belonging to this Scope.

    A new top-level dict holding only the collections selected by
    ``self.mutable`` is built; it is frozen when
    ``config.flax_return_frozendict`` is set.
    """
    self._populate_collections()
    xs = {
      k: v for k, v in self._variables.items() if in_filter(self.mutable, k)
    }
    if config.flax_return_frozendict:
      return freeze(xs)
    return xs

  def variables(self) -> Union[VariableDict, Dict[str, Any]]:
    """Returns the variables belonging to this Scope.

    The result is frozen when ``config.flax_return_frozendict`` is set;
    otherwise the scope's own variable dict is returned as-is.
    """
    self._populate_collections()
    if config.flax_return_frozendict:
      return freeze(self._variables)
    return self._variables

  def _validate_trace_level(self):
    """Checks the trace level recorded at construction via ``tracers.check_trace_level``."""
    tracers.check_trace_level(self.trace_level)

  def rewound(self, rewind_rngs: bool = False) -> 'Scope':
    """Returns a rewound version of this Scope.

    Args:
      rewind_rngs: if true, reset the RNG counter of this scope.

    Returns:
      A rewound version of this scope, which means reservations are
      emptied, and the rng counter is optionally rewound.
    """
    self._check_valid()
    scope = Scope(
      self._variables,
      self.rngs,
      self.name,
      self.mutable,
      self.parent,
      path=self.path,
      debug_path=self.debug_path,
      flags=self.flags,
    )
    if not rewind_rngs:
      # Share (not copy) the counters so RNG consumption carries over.
      scope.rng_counters = self.rng_counters
    return scope

  def name_reserved(self, name: str, col: Optional[str] = None) -> bool:
    """Checks whether a name for a child Scope or Variable is taken.

    Args:
      name: the name to check for collision.
      col: if a variable, the collection used.
    """
    if name in self.reservations:
      # allow the same name for two variables in
      # different collections, otherwise raise error.
      if (
        None in self.reservations[name]
        or col is None
        or col in self.reservations[name]
      ):
        return True
    return False

  def reserve(self, name: str, col: Optional[str] = None):
    """Reserves a name for a child Scope or Variable.

    Throws an error if the name exists already.

    Args:
      name: the name to reserve.
      col: if a variable, the collection used.

    Raises:
      TypeError: if ``name`` is not a string.
      ValueError: if ``name`` is already reserved (see ``name_reserved``).
    """
    if not isinstance(name, str):
      # Both literals are f-strings so that ``name`` is actually interpolated.
      raise TypeError(
        f'The type of scope "{name}" should be string but '
        f'it is {type(name)}'
      )
    if self.name_reserved(name, col):
      raise ValueError(f'Duplicate use of scope name: "{name}"')
    self.reservations[name].add(col)

  def default_name(self, prefix: str) -> str:
    """Generates an unreserved name with the given prefix.

    Args:
      prefix: prefix to use for generating an unreserved name.

    Returns:
      The generated name.
    """
    i = 0
    while True:
      name = f'{prefix}{i}'
      if name not in self.reservations:
        return name
      i += 1

  def push(
    self, name: Optional[str] = None, prefix: str = '', reuse=False
  ) -> 'Scope':
    """Creates a child Scope.

    Args:
      name: optional name of the child.
      prefix: prefix used for generating the name if `name` is `None`.
      reuse: if True will return a pre-existing child scope with the given name
        instead of throwing an error.

    Returns:
      The child scope.
    """
    self._check_valid()
    self._validate_trace_level()
    if name is None:
      name = self.default_name(prefix)
    if not reuse or name not in self.reservations:
      self.reserve(name)
    rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()}
    # Child rng counters are stored in the parent's counter dict under a
    # (sentinel, name) key so that a re-pushed child continues counting.
    rng_key = (child_rng_token, name)
    if rng_key in self.rng_counters:
      rng_counters = self.rng_counters.get(rng_key)  # type: ignore
    else:
      rng_counters = {key: 0 for key in rngs}
      self.rng_counters[rng_key] = rng_counters  # type: ignore
    scope = Scope(
      {},
      name=name,
      rngs=rngs,
      parent=self,
      mutable=self.mutable,
      path=self.path + (name,),
      debug_path=self.debug_path + (name,),
      flags=self.flags,
    )
    scope.rng_counters = rng_counters
    return scope

  def child(
    self,
    fn: Callable[..., Any],
    name: Optional[str] = None,
    prefix: Optional[str] = None,
    named_call: bool = True,
    **partial_kwargs,
  ) -> Callable[..., Any]:
    """Partially applies a child scope to fn.

    When calling the returned function multiple times variables will be reused.

    Args:
      fn: the function to partially apply the child Scope to.
      name: optional name of the child.
      prefix: prefix used for generating name if it is `None`.
      named_call: if true, `fn` will be run under `jax.named_scope`. The XLA
        profiler will use this to name tag the computation.
      **partial_kwargs: additional kwargs partially applied to `fn`.

    Returns:
      The function with a partially applied scope.
    """
    if name is None:
      if prefix is None:
        prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
      name = self.default_name(prefix)
    scope = self.push(name)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      kwargs = dict(partial_kwargs, **kwargs)
      if named_call:
        with jax.named_scope(name):
          res = fn(scope.rewound(), *args, **kwargs)
      else:
        res = fn(scope.rewound(), *args, **kwargs)
      return res

    return wrapper

  def is_mutable_collection(self, col: str) -> bool:
    """Returns true if the collection `col` is mutable."""
    return in_filter(self.mutable, col)

  def is_collection_empty(self, col: str) -> bool:
    """Returns true if the collection is empty."""
    if col in self.root._variables:  # pylint: disable=protected-access
      return not self.root._variables[col]  # pylint: disable=protected-access
    return True

  def _mutable_collection(self, col: str) -> MutableCollection:
    """Returns the collection `col` as a mutable object."""
    assert self.is_mutable_collection(col), f'Collection {col} is not mutable'

    # The actual variable dict is stored in the root scope only, and subscopes
    # hold references to subtrees relevant to them. This function ensures that
    # the collections are created in the top-level Scope and we return the
    # correct reference.
    if col not in self._variables:
      if not self.parent:
        # If this is the top-level Scope, just add an empty collection.
        self._variables[col] = {}
      else:
        assert self.name is not None  # Only top-level Scope have name None.
        # Populate the parent collections recursively and obtain a reference to
        # the direct parent (which, by transitivity, is a reference to a
        # dict in the root Scope).
        parent_col = self.parent._mutable_collection(col)  # pylint: disable=protected-access
        if self.name not in parent_col:
          # If this Scope's name does not occur in the parent collection, add it
          # to the parent scope (updating the parent's variable dict).
          parent_col[self.name] = {}
        # Store a reference to the parent's scope collection in this scope's
        # variable dict.
        self._variables[col] = parent_col[self.name]

    return self._variables[col]

  def _collection(self, col: str) -> Collection:
    """Returns a collection of variables of collection `col`."""
    if col not in self._variables:
      if self.parent:
        assert self.name is not None
        parent_col = self.parent._collection(col)  # pylint: disable=protected-access
        if self.name not in parent_col:
          return FrozenDict()
        self._variables[col] = parent_col[self.name]
      else:
        return FrozenDict()
    return self._variables[col]

  def has_rng(self, name: str) -> bool:
    """Returns true if a PRNGSequence with name `name` exists."""
    return name in self.rngs

  def make_rng(self, name: str = 'params') -> PRNGKey:
    """Generates A PRNGKey from a PRNGSequence with name `name`.

    Falls back to the 'params' sequence when `name` is not available.

    Raises:
      errors.InvalidRngError: if neither `name` nor 'params' is available.
    """
    if not self.has_rng(name):
      if self.has_rng('params'):
        name = 'params'
      else:
        raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    self._check_valid()
    self._validate_trace_level()
    self.rng_counters[name] += 1
    return LazyRng.create(self.rngs[name], self.rng_counters[name]).as_jax_rng()

  def get_variable(self, col: str, name: str, default: Any = None) -> Any:
    """Retrieves the value of a Variable.

    Args:
      col: the variable collection.
      name: the name of the variable.
      default: the default value to return if the variable does not exist in
        this scope.

    Returns:
      The value of the input variable, or the default value if the variable
      doesn't exist in this scope.
    """
    variables = self._collection(col)
    if name in variables:
      return variables[name]
    else:
      return default

  def has_variable(self, col: str, name: str) -> bool:
    """Returns true if the given variable exists in this scope.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
    """
    variables = self._collection(col)
    return name in variables

  def put_variable(self, col: str, name: str, value: Any):
    """Updates the value of the given variable if it is mutable, or an error otherwise.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      value: the new value of the given variable.

    Raises:
      errors.ModifyScopeVariableError: if collection `col` is not mutable.
    """
    self._check_valid()
    self._validate_trace_level()
    if not self.is_mutable_collection(col):
      raise errors.ModifyScopeVariableError(col, name, self.path_text)
    variables = self._mutable_collection(col)

    # Make sure reference sharing of child variable dictionaries isn't broken.
    # See https://github.com/google/flax/issues/2022 for more details.
    def put(target, key, val):
      if (
        key in target
        and isinstance(target[key], dict)
        and isinstance(val, Mapping)
      ):
        # Update the existing dict in place instead of replacing it.
        for k, v in val.items():
          put(target[key], k, v)
      else:
        target[key] = val

    put(variables, name, value)

  @overload
  def variable(
    self,
    col: str,
    name: str,
    init_fn: Optional[Callable[..., T]] = None,
    *init_args,
  ) -> Variable[T]:
    ...

  @overload
  def variable(
    self,
    col: str,
    name: str,
    init_fn: Optional[Callable[..., T]] = None,
    *init_args,
    unbox: Literal[True],
    **init_kwargs,
  ) -> Variable[T]:
    ...

  @overload
  def variable(
    self,
    col: str,
    name: str,
    init_fn: Optional[Callable[..., T]] = None,
    *init_args,
    unbox: Literal[False],
    **init_kwargs,
  ) -> Variable[meta.AxisMetadata[T]]:
    ...

  @overload
  def variable(
    self,
    col: str,
    name: str,
    init_fn: Optional[Callable[..., T]] = None,
    *init_args,
    unbox: bool = True,
    **init_kwargs,
  ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
    ...

  def variable(
    self,
    col: str,
    name: str,  # pylint: disable=keyword-arg-before-vararg
    init_fn: Optional[Callable[..., T]] = None,
    *init_args,
    unbox: bool = True,
    **init_kwargs,
  ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
    """Creates a variable if it doesn't exist yet in this scope and returns it.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      init_fn: a function called as ``init_fn(*init_args, **init_kwargs)`` to
        produce the initial value. If None, the variable must already be
        initialized otherwise an error is raised.
      *init_args: the positional arguments to evaluate init_fn on lazily.
      unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed
        value, see ``flax.nn.meta.unbox`` (default: True).
      **init_kwargs: the key-word arguments to evaluate init_fn on lazily.

    Returns:
      The variable. Throws an error if the name is already reserved in this
      scope for the same collection.
    """
    self.reserve(name, col)
    if not self.has_variable(col, name):
      if not self.is_mutable_collection(col) or init_fn is None:
        if self.is_collection_empty(col):
          raise errors.ScopeCollectionNotFound(col, name, self.path_text)
        raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
      init_value = init_fn(*init_args, **init_kwargs)
      self.put_variable(col, name, init_value)
    # cast to make static analyzers happy
    return cast(
      Union[Variable[T], Variable[meta.AxisMetadata[T]]],
      Variable(self, col, name, unbox=unbox),
    )

  @overload
  def param(
    self,
    name: str,
    init_fn: Callable[..., T],
    *init_args,
  ) -> T:
    ...

  @overload
  def param(
    self,
    name: str,
    init_fn: Callable[..., T],
    *init_args,
    unbox: Literal[True],
    **init_kwargs,
  ) -> T:
    ...

  @overload
  def param(
    self,
    name: str,
    init_fn: Callable[..., T],
    *init_args,
    unbox: Literal[False],
    **init_kwargs,
  ) -> meta.AxisMetadata[T]:
    ...

  @overload
  def param(
    self,
    name: str,
    init_fn: Callable[..., T],
    *init_args,
    unbox: bool,
    **init_kwargs,
  ) -> Union[T, meta.AxisMetadata[T]]:
    ...

  def param(
    self,
    name: str,
    init_fn: Callable[..., T],
    *init_args,
    unbox: bool = True,
    **init_kwargs,
  ) -> Union[T, meta.AxisMetadata[T]]:
    """Creates a parameter if it doesn't exist yet in this scope and returns it.

    If the parameter exists already, the existing value is simply returned.

    Args:
      name: the name of the parameter.
      init_fn: a function taking a PRNGKey plus any other number of positional
        arguments.
      *init_args: the positional arguments to evaluate init_fn on lazily.
      unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed
        value, see ``flax.nn.meta.unbox`` (default: True).
      **init_kwargs: the key-word arguments to evaluate init_fn on lazily.

    Returns:
      The parameters. Throws an error if the name is already reserved in this
      scope.
    """
    self.reserve(name, 'params')
    if self.has_variable('params', name):
      value = self.get_variable('params', name)
      # Validate that the shape of the init_fn output is the same as the shape
      # of the existing parameter. This is to make sure that the hparams set up
      # in a Flax Module match the shapes coming in during apply, and if not,
      # catch it with an error message.
      # NOTE: We could consider moving this to `self.`
      abs_value = jax.eval_shape(
        lambda: init_fn(random.key(0), *init_args, **init_kwargs)
      )
      abs_value_flat = jax.tree_util.tree_leaves(abs_value)
      value_flat = jax.tree_util.tree_leaves(value)
      for val, abs_val in zip(value_flat, abs_value_flat):
        # NOTE: We could check dtype consistency here as well but its
        # usefulness is less obvious. We might intentionally change the dtype
        # for inference to a half float type for example.
        if jnp.shape(val) != jnp.shape(abs_val):
          raise errors.ScopeParamShapeError(
            name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
          )
    else:
      if not self.is_mutable_collection('params'):
        if self.is_collection_empty('params'):
          raise errors.ScopeCollectionNotFound('params', name, self.path_text)
        raise errors.ScopeParamNotFoundError(name, self.path_text)
      value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
      self.put_variable('params', name, value)
    if unbox:
      value = meta.unbox(value)
    return value

  def _populate_collections(self):
    """Resolves every root-level collection for this scope via ``_collection``."""
    # Local is deliberately not called ``collections`` to avoid shadowing the
    # imported module of the same name.
    root_cols = self.root._variables.keys()  # pylint: disable=protected-access
    for col in root_cols:
      self._collection(col)

  def has_flag(self, key) -> bool:
    """Returns true if flag `key` is set on this scope."""
    return key in self.flags

  def get_flag(self, key, default=no_flag) -> Any:
    """Returns the value of flag `key`.

    Args:
      key: the flag to look up.
      default: value returned when the flag is absent. When omitted, a missing
        flag is an error.

    Raises:
      ValueError: if the flag is absent and no default was given.
    """
    if key not in self.flags and default is no_flag:
      # The exception must be raised, not returned as a value.
      raise ValueError(f'Flag {key} not present on scope.')
    return self.flags.get(key, default)


def _unfreeze_variables(variables, mutable):
  """Returns a new dict in which the collections selected by `mutable` are unfrozen."""
  new_variables = {}
  for key, value in variables.items():
    if in_filter(mutable, key):
      new_variables[key] = unfreeze(value)
    else:
      new_variables[key] = value
  return new_variables


def bind(
  variables: VariableDict,
  rngs: Optional[RNGSequences] = None,
  mutable: CollectionFilter = False,
  flags: Optional[Mapping] = None,
):
  """Binds variables and rngs to a new ``Scope``.

  bind provides a ``Scope`` instance without transforming a function with
  ``apply``. This is particularly useful for debugging and interactive use cases
  like notebooks where a function would limit the ability to split up code into
  different cells.

  A ``Scope`` instance is a stateful object. Note that idiomatic JAX is
  functional and therefore a ``Scope`` does not mix well with vanilla JAX APIs.
  Therefore, we recommend using ``apply`` when code should be reusable and
  compatible across the JAX software ecosystem.

  Args:
    variables: Variable dictionary to bind.
    rngs: RNGs to bind.
    mutable: Which variable collections to treat as mutable.
    flags: internal flags.

  Returns:
    A new scope with the variables and rngs bound to it.

  Raises:
    errors.ApplyScopeInvalidVariablesTypeError: if `variables` is not a valid
      variable dict.
    errors.InvalidRngError: if `rngs` is given but is not a valid rng dict.
  """
  if not _is_valid_variables(variables):
    raise errors.ApplyScopeInvalidVariablesTypeError()
  if rngs is not None and not _is_valid_rngs(rngs):
    raise errors.InvalidRngError(
      'rngs should be a dictionary mapping strings to `jax.PRNGKey`.'
    )
  new_variables = _unfreeze_variables(variables, mutable)
  return Scope(new_variables, rngs=rngs, mutable=mutable, flags=flags)


def apply(
  fn: Callable[..., Any],
  mutable: CollectionFilter = False,
  flags: Optional[Mapping] = None,
) -> Callable[..., Any]:
  """Functionalize a `Scope` function.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.
    flags: internal flags.

  Returns:
    `fn` with the scope partially applied. When `mutable` is not ``False`` the
    returned function yields a ``(output, mutated_variables)`` tuple.
  """

  @functools.wraps(fn)
  def wrapper(
    variables: VariableDict,
    *args,
    rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
    **kwargs,
  ) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]:
    if rngs is not None:
      if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
        raise ValueError(
          'The ``rngs`` argument passed to an apply function should be a '
          '``jax.PRNGKey`` or a dictionary mapping strings to '
          '``jax.PRNGKey``.'
        )
      if not isinstance(rngs, (dict, FrozenDict)):
        # A single key is used as the 'params' sequence.
        rngs = {'params': rngs}

    # Try to detect if user accidentally passed {'params': {'params': ...}.
    if (
      'params' in variables
      and isinstance(variables['params'], (dict, FrozenDict))
      and 'params' in variables['params']
    ):
      raise errors.ApplyScopeInvalidVariablesStructureError(variables)

    with bind(
      variables, rngs=rngs, mutable=mutable, flags=flags
    ).temporary() as root:
      y = fn(root, *args, **kwargs)
    if mutable is not False:
      return y, root.mutable_variables()
    else:
      return y

  return wrapper


def init(
  fn: Callable[..., Any],
  mutable: CollectionFilter = True,
  flags: Optional[Mapping] = None,
) -> Callable[..., Any]:
  """Functionalize a `Scope` function for initialization.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.
    flags: internal flags. The ``'initializing'`` flag is always set to True.

  Returns:
    `fn` with the scope partially applied.
  """

  @functools.wraps(fn)
  def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
    if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
      raise ValueError(
        'First argument passed to an init function should be a '
        '``jax.PRNGKey`` or a dictionary mapping strings to '
        '``jax.PRNGKey``.'
      )
    if not isinstance(rngs, (dict, FrozenDict)):
      rngs = {'params': rngs}
    init_flags = {**(flags if flags is not None else {}), 'initializing': True}
    # Initialization starts from an empty variable dict.
    return apply(fn, mutable=mutable, flags=init_flags)(
      {}, *args, rngs=rngs, **kwargs
    )

  return wrapper


def lazy_init(
  fn: Callable[..., Any],
  mutable: CollectionFilter = True,
  flags: Optional[Mapping] = None,
) -> Callable[..., Any]:
  """Functionalizes a `Scope` function for lazy initialization.

  Similar to ``init`` except that the init function now accepts
  ``jax.ShapeDtypeStruct`` instances for arguments that do not
  affect the variable initialization (typically this is all the input data).

  Example::

    def f(scope, x):
      # the kernel init only uses the shape of x so we don't actually
      # need a value for x and can pass it as a ShapeDtypeStruct in lazy_init.
      k = scope.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1]))
      return x @ k
    init_fn = lazy_init(f)
    variables = init_fn(random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.
    flags: internal flags.

  Returns:
    `fn` with the scope partially applied. Unlike ``init`` which returns a tuple
    of function output and variables, the lazy init function only returns the
    variables.
  """
  return partial_eval.lazy_init(
    lambda *args, **kwargs: init(fn, mutable, flags)(*args, **kwargs)[1]
  )


def _is_valid_collection(col: VariableDict):
  """Checks that `col` is a dict or FrozenDict whose keys are all strings."""
  if not isinstance(col, (FrozenDict, dict)):
    return False
  for name in col.keys():
    # Any value can be stored in a collection so only keys can be verified.
    if not isinstance(name, str):
      return False
  return True


def _is_valid_variables(variables: VariableDict) -> bool:
  """Checks whether the given variable dict is valid.

  Args:
    variables: A variable dict.

  Returns:
    True if `variables` is a valid variable dict.
  """
  for name, col in variables.items():
    if not isinstance(name, str):
      return False
    if not _is_valid_collection(col):
      return False
  return True


def _is_valid_rng(rng: Array):
  """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs."""
  # This check is valid for either new-style or old-style PRNG keys
  if not isinstance(rng, (np.ndarray, jnp.ndarray)):
    return False

  # Handle new-style typed PRNG keys
  if hasattr(jax.dtypes, 'prng_key'):  # JAX 0.4.14 or newer
    if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key):
      return rng.shape == ()
  elif hasattr(jax.random, 'PRNGKeyArray'):  # Previous JAX versions
    if isinstance(rng, jax.random.PRNGKeyArray):
      return rng.shape == ()

  # Handle old-style raw PRNG keys
  expected_rng = jax.eval_shape(
    lambda s: jax.random.key_data(jax.random.key(s)), 0
  )
  if (rng.shape, rng.dtype) != (expected_rng.shape, expected_rng.dtype):
    return False
  return True


def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]):
  """Checks that `rngs` is a dict or FrozenDict mapping strings to valid PRNG keys."""
  if not isinstance(rngs, (FrozenDict, dict)):
    return False
  for key, val in rngs.items():
    if not isinstance(key, str):
      return False
    if not _is_valid_rng(val):
      return False
  return True