Activation functions

flax.experimental.nnx.celu(x, alpha=1.0)

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.

Parameters
  • x – input array

  • alpha – array or scalar (default: 1.0)

Returns

An array.
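To make the formula concrete, here is a minimal scalar reference in plain Python using only the standard math module. The name celu_ref is hypothetical; this is an illustrative sketch, not the library's implementation, and it handles one float at a time rather than arrays.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def celu_ref(x: float, alpha: float = 1.0) -> float:
    """Scalar reference for celu: x if x > 0, else alpha * (exp(x / alpha) - 1)."""
    if x > 0:
        return x
    # Scaling the exponent by 1/alpha makes the slope at 0 equal to 1 from both sides
    return alpha * (math.exp(x / alpha) - 1.0)


print(celu_ref(2.0))              # positive inputs pass through: 2.0
print(celu_ref(-1.0, alpha=2.0))  # 2 * (exp(-0.5) - 1)
```

With alpha = 1.0 both branches reduce to those of elu, so the two functions agree at their defaults.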

flax.experimental.nnx.elu(x, alpha=1.0)

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]

Parameters
  • x – input array

  • alpha – scalar or array of alpha values (default: 1.0)

Returns

An array.

See also

selu()
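A minimal scalar reference for the elu formula, in plain Python with the math module. elu_ref is a hypothetical name and this sketch is not the library's implementation; it only illustrates the two branches.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def elu_ref(x: float, alpha: float = 1.0) -> float:
    """Scalar reference for elu: x if x > 0, else alpha * (exp(x) - 1)."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


# Unlike celu, the exponent is not divided by alpha, so for alpha != 1 the two
# branches meet at 0 with different slopes (1 on the right, alpha on the left).
print(elu_ref(-10.0, alpha=0.5))  # close to -alpha: large negative inputs saturate
```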

flax.experimental.nnx.gelu(x, approximate=True)

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters
  • x – input array

  • approximate – whether to use the approximate (tanh-based) formulation or the exact (erf-based) formulation (default: True).
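The two formulations can be compared numerically with a short plain-Python sketch using math.erf and math.tanh. The function names are hypothetical; this only evaluates the two formulas above on scalars and is not the library's implementation.

```python
import math


# Illustrative scalar references (hypothetical names), not the library's implementation.
def gelu_exact(x: float) -> float:
    # x/2 * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for v in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(v, gelu_exact(v), gelu_tanh(v), abs(gelu_exact(v) - gelu_tanh(v)))
```

On the interval [-6, 6] the two forms differ by less than 1e-3.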

flax.experimental.nnx.glu(x, axis=-1)

Gated linear unit activation function.

Computes the function:

\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]

where the array is split into two along axis. The size of the axis dimension must be divisible by two.

Parameters
  • x – input array

  • axis – the axis along which the split should be computed (default: -1)

Returns

An array.

See also

sigmoid()
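The split-and-gate structure can be sketched for a 1-D list in plain Python, which corresponds to splitting along the only (last) axis. glu_ref and the sigmoid helper are hypothetical names; this is an illustration of the formula, not the library's implementation.

```python
import math


# Illustrative 1-D reference (hypothetical names), not the library's implementation.
def sigmoid(v: float) -> float:
    # Branch on the sign so math.exp never receives a large positive argument
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def glu_ref(x):
    """Split x in half; gate the first half by sigmoid of the second half."""
    n = len(x)
    if n % 2 != 0:
        raise ValueError("length along the split axis must be divisible by two")
    a, b = x[: n // 2], x[n // 2 :]
    return [ai * sigmoid(bi) for ai, bi in zip(a, b)]


print(glu_ref([1.0, 2.0, 0.0, 0.0]))  # gates are sigmoid(0) = 0.5 -> [0.5, 1.0]
```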

flax.experimental.nnx.hard_sigmoid(x)

Hard Sigmoid activation function.

Computes the element-wise function

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]

Parameters

x – input array

Returns

An array.

See also

relu6()
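Writing relu6 as a clamp to [0, 6] gives a short scalar sketch of the formula in plain Python. relu6_ref and hard_sigmoid_ref are hypothetical local helpers, not the library's functions.

```python
# Illustrative scalar references (hypothetical names), not the library's implementation.
def relu6_ref(x: float) -> float:
    # relu6 clamps its input to the interval [0, 6]
    return min(max(x, 0.0), 6.0)


def hard_sigmoid_ref(x: float) -> float:
    """relu6(x + 3) / 6: 0 for x <= -3, 1 for x >= 3, linear in between."""
    return relu6_ref(x + 3.0) / 6.0


print([hard_sigmoid_ref(v) for v in (-4.0, -3.0, 0.0, 3.0, 4.0)])  # [0.0, 0.0, 0.5, 1.0, 1.0]
```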

flax.experimental.nnx.hard_silu(x)

Hard SiLU (swish) activation function.

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

Both hard_silu() and hard_swish() are aliases for the same function.

Parameters

x – input array

Returns

An array.

See also

hard_sigmoid()
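Composing the two formulas gives a small plain-Python sketch that shows the three regimes of hard_silu. The function names are hypothetical, and the sketch is not the library's implementation.

```python
# Illustrative scalar references (hypothetical names), not the library's implementation.
def hard_sigmoid_ref(x: float) -> float:
    """relu6(x + 3) / 6, with relu6 written as a clamp to [0, 6]."""
    return min(max(x + 3.0, 0.0), 6.0) / 6.0


def hard_silu_ref(x: float) -> float:
    """x * hard_sigmoid(x): 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 in between."""
    return x * hard_sigmoid_ref(x)


print([hard_silu_ref(v) for v in (-4.0, 1.5, 4.0)])
```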

flax.experimental.nnx.hard_swish(x)

Hard SiLU (swish) activation function.

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

Both hard_silu() and hard_swish() are aliases for the same function.

Parameters

x – input array

Returns

An array.

See also

hard_sigmoid()

flax.experimental.nnx.hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]

Parameters

x – input array

Returns

An array.
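The three cases amount to clamping x to [-1, 1], which a one-line plain-Python sketch makes explicit. hard_tanh_ref is a hypothetical name, not the library's implementation.

```python
# Illustrative scalar reference (hypothetical name), not the library's implementation.
def hard_tanh_ref(x: float) -> float:
    """Clamp x to the interval [-1, 1]."""
    return max(-1.0, min(1.0, x))


print([hard_tanh_ref(v) for v in (-2.0, -0.5, 0.5, 2.0)])  # [-1.0, -0.5, 0.5, 1.0]
```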

flax.experimental.nnx.leaky_relu(x, negative_slope=0.01)

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.

Parameters
  • x – input array

  • negative_slope – array or scalar specifying the negative slope (default: 0.01)

Returns

An array.

See also

relu()
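A scalar sketch of the piecewise definition in plain Python, with \(\alpha\) passed as negative_slope. leaky_relu_ref is a hypothetical name and the sketch is not the library's implementation.

```python
# Illustrative scalar reference (hypothetical name), not the library's implementation.
def leaky_relu_ref(x: float, negative_slope: float = 0.01) -> float:
    """x for x >= 0, negative_slope * x for x < 0."""
    return x if x >= 0 else negative_slope * x


print(leaky_relu_ref(-2.0), leaky_relu_ref(-2.0, negative_slope=0.2))
```

Setting negative_slope to 0 recovers relu, while a small positive slope keeps a nonzero gradient for negative inputs.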

flax.experimental.nnx.log_sigmoid(x)

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]

Parameters

x – input array

Returns

An array.

See also

sigmoid()
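Evaluating \(-\log(1 + e^{-x})\) directly overflows for large negative x. The sketch below uses the standard rewrite \(x - \log(1 + e^{x})\) for negative inputs; it is a common stabilization technique shown in plain Python, and the library may compute log_sigmoid differently. log_sigmoid_ref is a hypothetical name.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def log_sigmoid_ref(x: float) -> float:
    """log(sigmoid(x)) = -log(1 + exp(-x)), evaluated without overflow."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    # For x < 0, -log(1 + e^-x) = x - log(1 + e^x), so exp only sees x < 0
    return x - math.log1p(math.exp(x))


print(log_sigmoid_ref(0.0))      # log(0.5)
print(log_sigmoid_ref(-1000.0))  # -1000.0; the naive form would overflow in exp(1000)
```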

flax.experimental.nnx.log_softmax(x, axis=-1, where=None, initial=None)

Log-Softmax function.

Computes the logarithm of the softmax function, which maps elements to the range \([-\infty, 0]\) (the value 0 is reached, for example, when the reduced axis has a single element).

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]

Parameters
  • x – input array

  • axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.

  • where – Elements to include in the log_softmax.

  • initial – The minimum value used to shift the input array. Must be present when where is not None.

Returns

An array.

See also

softmax()
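For a 1-D list, the formula simplifies to \(x_i - \log \sum_j \exp(x_j)\), and subtracting \(\max(x)\) before exponentiating avoids overflow without changing the result. The plain-Python sketch below (hypothetical name log_softmax_ref) illustrates this; it ignores the where and initial parameters and is not the library's implementation.

```python
import math


# Illustrative 1-D reference (hypothetical name), not the library's implementation.
def log_softmax_ref(x):
    """x_i - log(sum_j exp(x_j)), computed with a max shift for stability."""
    m = max(x)
    log_norm = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - log_norm for v in x]


out = log_softmax_ref([1.0, 2.0, 3.0])
print(out)
print(sum(math.exp(v) for v in out))  # exponentiated outputs sum to 1 (up to rounding)
```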

flax.experimental.nnx.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)

Compute the log of the sum of exponentials of input elements.

LAX-backend implementation of scipy.special.logsumexp().

Original docstring below.

Parameters
  • a (array_like) – Input array.

  • axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.

  • b (array-like, optional) – Scaling factor for exp(a); must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.

  • return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).

Returns

  • res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned. If return_sign is True, res contains the log of the absolute value of the argument.

  • sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res containing +1, 0, -1 (for real-valued inputs) or a complex phase (for complex inputs). This gives the sign of the argument of the logarithm in res. If return_sign is False, only one result is returned.
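The numerically stable evaluation described above can be sketched for 1-D lists in plain Python by shifting every exponent by \(\max(a)\). logsumexp_ref is a hypothetical name; it always returns the (log-magnitude, sign) pair that return_sign=True describes, ignores axis and keepdims, and is not the library's implementation.

```python
import math


# Illustrative 1-D reference (hypothetical name), not the library's implementation.
def logsumexp_ref(a, b=None):
    """Return (log|sum_i b_i * exp(a_i)|, sign of the sum) for 1-D lists."""
    if b is None:
        b = [1.0] * len(a)
    m = max(a)  # shift so the largest exponent becomes exp(0) = 1
    s = sum(bi * math.exp(ai - m) for ai, bi in zip(a, b))
    if s == 0.0:
        return -math.inf, 0.0  # log(0) = -inf; the sign is reported as 0
    return m + math.log(abs(s)), math.copysign(1.0, s)


print(logsumexp_ref([1000.0, 1000.0]))           # exp(1000) alone would overflow
print(logsumexp_ref([0.0, 1.0], b=[1.0, -1.0]))  # log|1 - e| with sign -1.0
```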

flax.experimental.nnx.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)

One-hot encodes the given indices.

Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

Indices outside the range [0, num_classes) will be encoded as zeros:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

Parameters
  • x – A tensor of indices.

  • num_classes – Number of classes in the one-hot dimension.

  • dtype – optional, a float dtype for the returned values (default jnp.float_).

  • axis – the axis or axes along which the function should be computed.
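The encoding rule, including the all-zeros rows for out-of-range indices, can be mirrored in a few lines of plain Python for a 1-D list of indices. one_hot_ref is a hypothetical name; it always places the class dimension last and is not the library's implementation.

```python
# Illustrative 1-D reference (hypothetical name), not the library's implementation.
def one_hot_ref(indices, num_classes):
    """Rows of length num_classes; indices outside [0, num_classes) give all zeros."""
    return [
        [1.0 if c == i else 0.0 for c in range(num_classes)]
        for i in indices
    ]


print(one_hot_ref([0, 1, 2], 3))  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(one_hot_ref([-1, 3], 3))    # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```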

flax.experimental.nnx.relu(x)

Rectified linear unit activation function.

Computes the element-wise function:

\[\mathrm{relu}(x) = \max(x, 0)\]

except under differentiation, we take:

\[\nabla \mathrm{relu}(0) = 0\]

For more information, see Numerical influence of ReLU’(0) on backpropagation.

Parameters

x – input array

Returns

An array.

Example

>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

See also

relu6()

flax.experimental.nnx.selu(x)

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters

x – input array

Returns

An array.

See also

elu()
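A scalar plain-Python sketch of selu with the \(\lambda\) and \(\alpha\) constants quoted above. It also shows that selu is \(\lambda\) times elu with \(\alpha = 1.6732\ldots\). The names are hypothetical and this is not the library's implementation.

```python
import math

# Constants quoted in the documentation above
SELU_LAMBDA = 1.0507009873554804934193349852946
SELU_ALPHA = 1.6732632423543772848170429916717


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def selu_ref(x: float) -> float:
    """lambda * x for x > 0, lambda * (alpha * e^x - alpha) for x <= 0."""
    return SELU_LAMBDA * (x if x > 0 else SELU_ALPHA * math.exp(x) - SELU_ALPHA)


print(selu_ref(1.0))    # lambda
print(selu_ref(-50.0))  # approaches -lambda * alpha for large negative inputs
```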

flax.experimental.nnx.sigmoid(x)

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]

Parameters

x – input array

Returns

An array.

See also

log_sigmoid()
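In plain Python, evaluating \(1/(1 + e^{-x})\) raises an overflow error for very negative x. A common fix is to use the equivalent form \(e^{x}/(1 + e^{x})\) when x < 0. The sketch below (hypothetical name sigmoid_ref) shows this standard technique; it is not necessarily how the library computes sigmoid.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def sigmoid_ref(x: float) -> float:
    """1 / (1 + e^-x), branching on the sign so exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so z is in (0, 1)
    return z / (1.0 + z)


print(sigmoid_ref(0.0))  # 0.5
```

The symmetry sigmoid(-x) = 1 - sigmoid(x) makes the two branches agree.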

flax.experimental.nnx.silu(x)

SiLU (aka swish) activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish() and silu() are both aliases for the same function.

Parameters

x – input array

Returns

An array.

See also

sigmoid()
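A scalar plain-Python sketch of \(x \cdot \mathrm{sigmoid}(x)\), using the same sign-based branching to avoid overflow in the exponential. silu_ref is a hypothetical name and the sketch is not the library's implementation.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def silu_ref(x: float) -> float:
    """x * sigmoid(x) = x / (1 + e^-x), evaluated without overflow."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so z is in (0, 1)
    return x * z / (1.0 + z)


print(silu_ref(1.0))
print(silu_ref(-1000.0))  # effectively zero for large negative inputs
```

Unlike relu, silu is smooth and dips slightly below zero for moderate negative inputs, with a minimum of about -0.28.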

flax.experimental.nnx.soft_sign(x)

Soft-sign activation function.

Computes the element-wise function

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]

Parameters

x – input array
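A one-line scalar sketch of the formula in plain Python (hypothetical name soft_sign_ref, not the library's implementation). Like tanh, its outputs stay inside (-1, 1), but they approach the bounds only polynomially.

```python
# Illustrative scalar reference (hypothetical name), not the library's implementation.
def soft_sign_ref(x: float) -> float:
    """x / (|x| + 1), bounded in (-1, 1)."""
    return x / (abs(x) + 1.0)


print([soft_sign_ref(v) for v in (-3.0, 0.0, 1.0, 3.0)])  # [-0.75, 0.0, 0.5, 0.75]
```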

flax.experimental.nnx.softmax(x, axis=-1, where=None, initial=None)

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

Parameters
  • x – input array

  • axis – the axis or axes along which the softmax should be computed; the softmax outputs along these dimensions sum to \(1\). Either an integer or a tuple of integers.

  • where – Elements to include in the softmax.

  • initial – The minimum value used to shift the input array. Must be present when where is not None.

Returns

An array.

See also

log_softmax()
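A 1-D plain-Python sketch of the formula. Subtracting \(\max(x)\) first leaves the result unchanged, because the shift cancels between numerator and denominator, and it prevents overflow. softmax_ref is a hypothetical name; the sketch ignores where and initial and is not the library's implementation.

```python
import math


# Illustrative 1-D reference (hypothetical name), not the library's implementation.
def softmax_ref(x):
    """exp(x_i) / sum_j exp(x_j), computed after subtracting max(x)."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


p = softmax_ref([1.0, 2.0, 3.0])
print(p, sum(p))  # the probabilities sum to 1 (up to rounding)
```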

flax.experimental.nnx.softplus(x)

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]

Parameters

x – input array
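Computing \(\log(1 + e^x)\) literally overflows for large x. The identity \(\log(1 + e^x) = \max(x, 0) + \log(1 + e^{-|x|})\) avoids this problem. The plain-Python sketch below uses that standard rewrite; softplus_ref is a hypothetical name, and the library may compute softplus differently.

```python
import math


# Illustrative scalar reference (hypothetical name), not the library's implementation.
def softplus_ref(x: float) -> float:
    """log(1 + e^x) rewritten as max(x, 0) + log1p(exp(-|x|)) to avoid overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


print(softplus_ref(0.0))     # log 2
print(softplus_ref(1000.0))  # 1000.0, where exp(1000) alone would overflow
```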

flax.experimental.nnx.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)

Normalizes an array along axis by subtracting the mean and dividing by \(\sqrt{\mathrm{variance} + \epsilon}\), where \(\epsilon\) is given by epsilon.
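A 1-D plain-Python sketch of standardization. It assumes that mean and variance are computed from x when they are not supplied, that the variance is the population (divide-by-n) variance, and that epsilon is added to the variance before the square root, as the epsilon parameter in the signature suggests. standardize_ref is a hypothetical name; it ignores axis and where and is not the library's implementation.

```python
import math


# Illustrative 1-D reference (hypothetical name), not the library's implementation.
# Assumption: population variance, with epsilon added before the square root.
def standardize_ref(x, epsilon=1e-5):
    """(x - mean) / sqrt(variance + epsilon) for a 1-D list."""
    n = len(x)
    mean = sum(x) / n
    variance = sum((v - mean) ** 2 for v in x) / n
    scale = 1.0 / math.sqrt(variance + epsilon)
    return [(v - mean) * scale for v in x]


z = standardize_ref([1.0, 2.0, 3.0, 4.0])
print(z)  # zero mean; variance slightly below 1 because of epsilon
```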

flax.experimental.nnx.swish(x)

SiLU (aka swish) activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish() and silu() are both aliases for the same function.

Parameters

x – input array

Returns

An array.

See also

sigmoid()

flax.experimental.nnx.tanh(x, /)

Compute hyperbolic tangent element-wise.

LAX-backend implementation of numpy.tanh().

Original docstring below.

Equivalent to np.sinh(x)/np.cosh(x) or -1j * np.tan(1j*x).

Parameters

x (array_like) – Input array.

Returns

y – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar.

Return type

ndarray
