# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import typing as tp
import jax
import jax.numpy as jnp
import optax
from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import ApplyCaller
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.training.train_state import struct
A = tp.TypeVar('A')
M = tp.TypeVar('M', bound=Module)
TS = tp.TypeVar('TS', bound='TrainState')
[docs]class Dict(Module, tp.Mapping[str, A]):
@tp.overload
def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...
@tp.overload
def __init__(
self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
): ...
def __init__(self, *args, **kwargs):
for name, value in dict(*args, **kwargs).items():
setattr(self, name, value)
def __getitem__(self, key) -> A:
return getattr(self, key)
def __setitem__(self, key, value):
setattr(self, key, value)
def __getattr__(self, key) -> A:
return super().__getattribute__(key)
def __setattr__(self, key, value):
super().__setattr__(key, value)
def __iter__(self) -> tp.Iterator[str]:
return (k for k in vars(self) if k != '_object__state')
def __len__(self) -> int:
return len(vars(self))
[docs]class Sequential(Module):
def __init__(self, *fns: tp.Callable[..., tp.Any]):
self.layers = list(fns)
def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
output: tp.Any = None
for i, f in enumerate(self.layers):
if not callable(f):
raise TypeError(f'Sequence[{i}] is not callable: {f}')
if i > 0:
if isinstance(output, tuple):
args = output
kwargs = {}
elif isinstance(output, dict):
args = ()
kwargs = output
else:
args = (output,)
kwargs = {}
if rngs is not None and has_keyword_arg(f, 'rngs'):
kwargs['rngs'] = rngs
output = f(*args, **kwargs)
return output
class ModuleDefApply(tp.Protocol, tp.Generic[M]):
def __call__(
self, state: State, *states: State
) -> ApplyCaller[tuple[State, GraphDef[M]]]: ...
[docs]class TrainState(tp.Generic[M], struct.PyTreeNode):
graphdef: GraphDef[M]
params: State
opt_state: optax.OptState
step: jax.Array
tx: optax.GradientTransformation = struct.field(pytree_node=False)
@classmethod
def create(
cls,
graphdef: GraphDef[M],
*,
params: State,
tx: optax.GradientTransformation,
step: int = 0,
**kwargs,
):
return cls(
graphdef=graphdef,
params=params,
opt_state=tx.init(params),
step=jnp.asarray(step),
tx=tx,
**kwargs,
)
if tp.TYPE_CHECKING:
def __getattr__(self, key: str) -> tp.Any: ...
def apply(
self, state: tp.Union[State, str], *states: tp.Union[State, str]
) -> ApplyCaller[tuple[GraphDef[M], State]]:
states = (state, *states)
_states: list[State] = []
for _state in states:
if isinstance(_state, str):
_state_key = _state
_state = getattr(self, _state_key)
if not isinstance(_state, State):
raise TypeError(
f'Expected {self.__class__.__name__}.{_state_key} to be a State, got {type(_state)}'
)
_states.append(_state)
return self.graphdef.apply(*_states)
def apply_gradients(self: TS, grads: State, **kwargs) -> TS:
updates, opt_state = self.tx.update(grads, self.opt_state, self.params)
params = optax.apply_updates(self.params, updates) # type: ignore
step = self.step + 1
return self.replace(
params=params,
opt_state=opt_state,
step=step,
**kwargs,
)
def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool:
"""Return True if func has keyword-only arguments with the given name."""
return any(
param.name == name
and param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD)
for param in inspect.signature(func).parameters.values()
)