Source code for flax.nnx.helpers

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import inspect
import typing as tp

import jax
import jax.numpy as jnp
import optax

from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import ApplyCaller
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.training.train_state import struct

A = tp.TypeVar('A')
M = tp.TypeVar('M', bound=Module)
TS = tp.TypeVar('TS', bound='TrainState')


class Dict(Module, tp.Mapping[str, A]):
  """A Module whose attributes are also reachable through the Mapping API.

  Every key/value pair given to the constructor is stored as an attribute
  on the instance, so ``d['x']`` and ``d.x`` refer to the same value.

  Iteration and ``len`` only report user-visible keys: the
  ``'_object__state'`` entry of ``vars(self)`` is skipped (presumably
  internal ``Module`` bookkeeping -- confirm against ``Module``).
  """

  @tp.overload
  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...

  @tp.overload
  def __init__(
    self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
  ): ...

  def __init__(self, *args, **kwargs):
    """Populate the Dict from the same arguments accepted by ``dict``.

    Args:
      *args: at most one positional argument; a mapping, an iterable of
        ``(key, value)`` pairs, or ``None`` (treated as "no mapping").
      **kwargs: additional key/value pairs.
    """
    # The overload above advertises ``mapping=None``; ``dict(None)`` would
    # raise TypeError, so an explicit ``None`` is treated as "no mapping".
    if len(args) == 1 and args[0] is None:
      args = ()
    for name, value in dict(*args, **kwargs).items():
      setattr(self, name, value)

  def __getitem__(self, key) -> A:
    """Return the attribute named ``key``."""
    return getattr(self, key)

  def __setitem__(self, key, value):
    """Store ``value`` as the attribute named ``key``."""
    setattr(self, key, value)

  def __getattr__(self, key) -> A:
    # Only reached after normal attribute lookup has failed; delegating to
    # ``__getattribute__`` re-raises the regular AttributeError. The
    # annotated return type lets type checkers treat attributes as ``A``.
    return super().__getattribute__(key)

  def __setattr__(self, key, value):
    super().__setattr__(key, value)

  def __iter__(self) -> tp.Iterator[str]:
    """Iterate over the stored keys, skipping ``'_object__state'``."""
    return (k for k in vars(self) if k != '_object__state')

  def __len__(self) -> int:
    """Return the number of keys produced by ``__iter__``."""
    # Count through ``__iter__`` so that ``len(d) == len(list(d))``;
    # ``len(vars(self))`` would also count the filtered-out entry.
    return sum(1 for _ in self)
class Sequential(Module):
  """Calls a list of callables in order, piping each output into the next.

  The first callable receives the arguments given to ``__call__``. For each
  later callable the previous output is adapted as follows:

  * a ``tuple`` is unpacked into positional arguments,
  * a ``dict`` is unpacked into keyword arguments,
  * anything else is passed as a single positional argument.

  If ``rngs`` is given, it is forwarded as the ``rngs`` keyword argument to
  every callable for which ``has_keyword_arg(f, 'rngs')`` is true.
  """

  def __init__(self, *fns: tp.Callable[..., tp.Any]):
    """Store ``fns`` (in call order) in ``self.layers``."""
    self.layers = list(fns)

  def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
    """Run the layers in sequence.

    Args:
      *args: positional arguments for the first layer.
      rngs: optional value forwarded as ``rngs=`` to layers that accept it.
      **kwargs: keyword arguments for the first layer.

    Returns:
      The output of the last layer, or ``None`` if there are no layers.

    Raises:
      TypeError: if an entry of ``self.layers`` is not callable.
    """
    output: tp.Any = None

    for i, f in enumerate(self.layers):
      if not callable(f):
        raise TypeError(f'Sequence[{i}] is not callable: {f}')
      if i > 0:
        if isinstance(output, tuple):
          args = output
          kwargs = {}
        elif isinstance(output, dict):
          args = ()
          # Shallow-copy: 'rngs' may be inserted into ``kwargs`` below, and
          # that must not mutate the dict object the previous layer returned
          # (the layer may still hold a reference to it).
          kwargs = dict(output)
        else:
          args = (output,)
          kwargs = {}
      if rngs is not None and has_keyword_arg(f, 'rngs'):
        kwargs['rngs'] = rngs

      output = f(*args, **kwargs)

    return output
class ModuleDefApply(tp.Protocol, tp.Generic[M]):
  """Structural type for an "apply" callable over one or more ``State`` objects.

  Any object whose ``__call__`` matches the signature below satisfies this
  protocol.
  """

  # NOTE(review): the annotated tuple order here is (State, GraphDef), while
  # ``TrainState.apply`` in this file is annotated as returning
  # ``ApplyCaller[tuple[GraphDef[M], State]]`` -- confirm which order is
  # intended.
  def __call__(
    self,
    state: State,
    *states: State,
  ) -> ApplyCaller[tuple[State, GraphDef[M]]]:
    """Take at least one ``State`` and return an ``ApplyCaller``."""
    ...
class TrainState(tp.Generic[M], struct.PyTreeNode):
  """Bundle of a graph definition, parameters and optax optimizer state.

  Extra fields can be supplied by subclasses; ``create`` and
  ``apply_gradients`` forward ``**kwargs`` to the constructor and to
  ``replace`` respectively.
  """

  # Graph definition; ``apply`` delegates to ``graphdef.apply``.
  graphdef: GraphDef[M]
  # Parameters passed to ``tx.init`` / ``tx.update`` / ``optax.apply_updates``.
  params: State
  # Optimizer state produced by ``tx.init`` and advanced by ``tx.update``.
  opt_state: optax.OptState
  # Step counter; incremented by one in ``apply_gradients``.
  step: jax.Array
  # Gradient transformation; declared with ``pytree_node=False``.
  tx: optax.GradientTransformation = struct.field(pytree_node=False)

  @classmethod
  def create(
    cls,
    graphdef: GraphDef[M],
    *,
    params: State,
    tx: optax.GradientTransformation,
    step: int = 0,
    **kwargs,
  ):
    """Build a new instance, initialising ``opt_state`` via ``tx.init(params)``.

    Args:
      graphdef: graph definition to store.
      params: parameters to store and to initialise the optimizer with.
      tx: the optax gradient transformation.
      step: initial step count; converted with ``jnp.asarray``.
      **kwargs: additional fields forwarded to the constructor.

    Returns:
      A new instance of ``cls``.
    """
    initial_opt_state = tx.init(params)
    step_array = jnp.asarray(step)
    return cls(
      graphdef=graphdef,
      params=params,
      opt_state=initial_opt_state,
      step=step_array,
      tx=tx,
      **kwargs,
    )

  if tp.TYPE_CHECKING:
    # Declared for type checkers only, so attribute access on extra
    # (subclass-defined) fields type-checks as ``Any``.
    def __getattr__(self, key: str) -> tp.Any: ...

  def apply(
    self, state: tp.Union[State, str], *states: tp.Union[State, str]
  ) -> ApplyCaller[tuple[GraphDef[M], State]]:
    """Call ``self.graphdef.apply`` with the given states.

    Args:
      state: a ``State``, or the name of an attribute of ``self`` holding one.
      *states: further states or attribute names, resolved the same way.

    Returns:
      Whatever ``self.graphdef.apply`` returns for the resolved states.

    Raises:
      TypeError: if an attribute referenced by name is not a ``State``.
    """
    resolved: list[State] = []

    for entry in (state, *states):
      # Non-string entries are passed through unchanged (and unchecked).
      if not isinstance(entry, str):
        resolved.append(entry)
        continue

      attr_name = entry
      attr_value = getattr(self, attr_name)
      if not isinstance(attr_value, State):
        raise TypeError(
          f'Expected {self.__class__.__name__}.{attr_name} to be a State, got {type(attr_value)}'
        )
      resolved.append(attr_value)

    return self.graphdef.apply(*resolved)

  def apply_gradients(self: TS, grads: State, **kwargs) -> TS:
    """Return a copy with one optimizer update applied.

    Args:
      grads: gradients passed to ``self.tx.update``.
      **kwargs: additional fields forwarded to ``self.replace``.

    Returns:
      The result of ``self.replace`` with new ``params`` and ``opt_state``
      and with ``step`` increased by one.
    """
    updates, next_opt_state = self.tx.update(
      grads, self.opt_state, self.params
    )
    next_params = optax.apply_updates(self.params, updates)  # type: ignore
    next_step = self.step + 1
    return self.replace(
      params=next_params,
      opt_state=next_opt_state,
      step=next_step,
      **kwargs,
    )
def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool:
  """Return True if ``func`` declares a parameter ``name`` passable by keyword.

  A parameter qualifies when it is keyword-only or positional-or-keyword.
  Positional-only parameters and ``**kwargs`` catch-alls do not count.

  Args:
    func: the callable to inspect.
    name: the parameter name to look for.

  Returns:
    True if such a parameter exists. False otherwise, including when
    ``inspect.signature`` cannot provide a signature for ``func``.
  """
  try:
    parameters = inspect.signature(func).parameters
  except ValueError:
    # ``inspect.signature`` raises ValueError when no signature can be
    # provided (e.g. some builtins); there is then no evidence that
    # ``name`` is accepted, so report False instead of crashing the caller.
    return False

  # ``parameters`` is keyed by parameter name, so a direct lookup suffices.
  param = parameters.get(name)
  return param is not None and param.kind in (
    param.KEYWORD_ONLY,
    param.POSITIONAL_OR_KEYWORD,
  )