Dtypes

flax.nnx.nn.dtypes.canonicalize_dtype(*args, dtype=None, inexact=True)

Canonicalize an optional dtype to the definitive dtype.

If dtype is None, this function infers the dtype from the input arguments using jnp.result_type. If it is not None, it is returned unmodified, or an exception is raised if the dtype is invalid.
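The inference step relies on jnp.result_type, and the inexact check amounts to asking whether a dtype is a subdtype of jnp.inexact. The snippet below shows both primitives in plain JAX, outside of Flax. The input values are illustrative, and skipping None before inference mirrors the "None values are ignored" rule for *args.

```python
import jax.numpy as jnp

# Two inputs with different dtypes plus a None, which is skipped before inference.
x = jnp.arange(4, dtype=jnp.int32)
w = jnp.ones((4,), dtype=jnp.float16)
inputs = [x, None, w]

# jnp.result_type computes the common dtype under JAX's type-promotion rules.
inferred = jnp.result_type(*[a for a in inputs if a is not None])
print(inferred)  # float16

# The inexact check: is the dtype a real or complex floating-point type?
print(jnp.issubdtype(inferred, jnp.inexact))   # True
print(jnp.issubdtype(jnp.int32, jnp.inexact))  # False
```

Under JAX's promotion rules an integer combined with a float takes the float's dtype, which is why int32 and float16 resolve to float16 here.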

Parameters
  • *args – JAX array compatible values. None values are ignored.

  • dtype – Optional dtype override. If specified, the arguments should be cast to this dtype instead, and dtype inference is disabled.

  • inexact – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers, such as taking a mean.

Returns

The dtype that *args should be cast to.
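To make the documented rules concrete, here is a minimal, hypothetical re-implementation called canonicalize_dtype_sketch. It is not Flax's source. One behavior is an explicit assumption not stated above: when inference yields a non-inexact dtype (for example int32) and inexact=True, the sketch promotes it to a float via jnp.promote_types(jnp.float32, ...) rather than raising.

```python
import jax.numpy as jnp


def canonicalize_dtype_sketch(*args, dtype=None, inexact=True):
    """Hypothetical re-implementation of the documented behavior, not Flax's source."""
    if dtype is None:
        # Infer the dtype from the non-None inputs with jnp.result_type.
        dtype = jnp.result_type(*[jnp.asarray(a) for a in args if a is not None])
        # ASSUMPTION: a non-inexact inferred dtype is promoted to a float when inexact=True.
        if inexact and not jnp.issubdtype(dtype, jnp.inexact):
            dtype = jnp.promote_types(jnp.float32, dtype)
    # A given dtype is returned unmodified, but it must still pass the inexact check.
    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
        raise ValueError(f"Dtype must be inexact: {dtype}")
    return dtype


x = jnp.arange(3, dtype=jnp.int32)
print(canonicalize_dtype_sketch(x))                 # float32 (by the assumed promotion)
print(canonicalize_dtype_sketch(x, inexact=False))  # int32
print(canonicalize_dtype_sketch(x, dtype=jnp.float16) is jnp.float16)  # True

try:
    # An explicit integer dtype fails the inexact check.
    canonicalize_dtype_sketch(x, dtype=jnp.int32)
except ValueError as err:
    print("rejected:", err)
```

Note the asymmetry in the sketch: an inferred integer dtype is quietly promoted (by assumption), while an explicitly requested integer dtype is rejected, matching "returned unmodified or an exception is raised".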

flax.nnx.nn.dtypes.promote_dtype(args, /, *, dtype=None, inexact=True)

Promotes input arguments to a specified or inferred dtype.

All args are cast to the same dtype. See canonicalize_dtype for how this dtype is determined.

The behavior of promote_dtype is mostly a convenience wrapper around jax.numpy.promote_types. The differences are that it automatically casts all inputs to the inferred dtype, allows inference to be overridden by a forced dtype, and has an optional check to guarantee that the resulting dtype is inexact.
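For contrast, jax.numpy.promote_types on its own only computes a dtype; it never casts anything. The plain-JAX snippet below (example arrays are illustrative) shows the manual casting step that promote_dtype is described as performing for you.

```python
import jax.numpy as jnp

a = jnp.arange(3, dtype=jnp.int32)
b = jnp.ones((3,), dtype=jnp.float16)

# jnp.promote_types only answers "which dtype?"; it does not touch the arrays.
common = jnp.promote_types(a.dtype, b.dtype)
print(common)            # float16
print(a.dtype, b.dtype)  # int32 float16 (unchanged)

# Casting every input to the common dtype is a separate, manual step.
a_cast, b_cast = a.astype(common), b.astype(common)
print(a_cast.dtype, b_cast.dtype)  # float16 float16
```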

Parameters
  • args – A tuple of JAX array compatible values. None values are returned as is.

  • dtype – Optional dtype override. If specified, the arguments are cast to the specified dtype instead, and dtype inference is disabled.

  • inexact – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers, such as taking a mean.

Returns

The arguments cast to arrays of the same dtype.
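A minimal sketch of the documented behavior, using a hypothetical promote_dtype_sketch that is not Flax's implementation. It assumes args is passed as a single tuple (as the positional-only signature args, / suggests) and reuses the same assumed float32 fallback for integer-only inputs as the canonicalize_dtype sketch. The x/kernel/bias names are illustrative only.

```python
import jax.numpy as jnp


def promote_dtype_sketch(args, /, *, dtype=None, inexact=True):
    """Hypothetical sketch of the documented behavior, not Flax's source."""
    if dtype is None:
        # Infer one dtype from the non-None inputs (see canonicalize_dtype).
        dtype = jnp.result_type(*[jnp.asarray(a) for a in args if a is not None])
        # ASSUMPTION: a non-inexact inferred dtype is promoted to a float when inexact=True.
        if inexact and not jnp.issubdtype(dtype, jnp.inexact):
            dtype = jnp.promote_types(jnp.float32, dtype)
    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
        raise ValueError(f"Dtype must be inexact: {dtype}")
    # Cast every input to the common dtype; None entries pass through unchanged.
    return tuple(None if a is None else jnp.asarray(a, dtype) for a in args)


x = jnp.arange(4, dtype=jnp.int32)
kernel = jnp.ones((4,), dtype=jnp.float16)

# Inferred dtype: all non-None inputs come back as float16, the None stays None.
x2, k2, b2 = promote_dtype_sketch((x, kernel, None))
print(x2.dtype, k2.dtype, b2)  # float16 float16 None

# Forced dtype: inference is skipped and everything is cast to float32.
forced = promote_dtype_sketch((x, kernel), dtype=jnp.float32)
print([t.dtype for t in forced])
```

Passing None through untouched lets a caller promote optional parameters, such as a missing bias, in the same call without special-casing them.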