# Source code for flax.nnx.nn.dtypes
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from flax.typing import Dtype
from jax import numpy as jnp
# Generic type variable bounded by ``tuple``. It is not referenced in the
# visible part of this module; presumably used by other code — verify before
# removing.
T = tp.TypeVar('T', bound=tuple)
[docs]def canonicalize_dtype(
*args, dtype: Dtype | None = None, inexact: bool = True
) -> Dtype:
"""Canonicalize an optional dtype to the definitive dtype.
If the ``dtype`` is None this function will infer the dtype. If it is not
None it will be returned unmodified or an exceptions is raised if the dtype
is invalid.
from the input arguments using ``jnp.result_type``.
Args:
*args: JAX array compatible values. None values
are ignored.
dtype: Optional dtype override. If specified the arguments are cast to
the specified dtype instead and dtype inference is disabled.
inexact: When True, the output dtype must be a subdtype
of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
is useful when you want to apply operations that don't work directly on
integers like taking a mean for example.
Returns:
The dtype that *args should be cast to.
"""
if dtype is None:
args_filtered = [jnp.asarray(x) for x in args if x is not None]
dtype = jnp.result_type(*args_filtered)
if inexact and not jnp.issubdtype(dtype, jnp.inexact):
dtype = jnp.promote_types(jnp.float32, dtype)
if inexact and not jnp.issubdtype(dtype, jnp.inexact):
raise ValueError(f'Dtype must be inexact: {dtype}')
return dtype