# Source code for jax._src.nn.functions

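# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.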

"""Shared neural network activations and other functions."""

from functools import partial
import operator
import warnings
import numpy as np
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.ops.special import logsumexp as _logsumexp

# Loose alias used in the annotations of this module. It resolves to
# ``typing.Any``, so it documents intent only and enforces no array type.
Array = Any

# activations

@custom_jvp
@jax.jit
def relu(x: Array) -> Array:
  r"""Rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  except under differentiation, we take:

  .. math::
    \nabla \mathrm{relu}(0) = 0

  For more information see
  `Numerical influence of ReLU'(0) on backpropagation
  <https://openreview.net/forum?id=urrcVI-_jRm>`_.

  Args:
    x : input array

  Returns:
    An array.

  Example:
    >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
    Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

  See also:
    :func:`relu6`

  """
  return jnp.maximum(x, 0)


# Custom JVP rule: the tangent passes through only where x > 0, so the
# derivative at exactly 0 is defined to be 0.
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

@jax.jit
def softplus(x: Array) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array

  Returns:
    An array.
  """
  # log(1 + e^x) == logaddexp(x, 0); this form avoids overflowing exp(x)
  # for large positive inputs.
  return jnp.logaddexp(x, 0)

@jax.jit
def soft_sign(x: Array) -> Array:
  r"""Soft-sign activation function.

  Computes the element-wise function

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x : input array

  Returns:
    An array.
  """
  return x / (jnp.abs(x) + 1)

@jax.jit
def sigmoid(x: Array) -> Array:
  r"""Sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`log_sigmoid`

  """
  return lax.logistic(x)

@jax.jit
def silu(x: Array) -> Array:
  r"""SiLU (a.k.a. swish) activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  :func:`swish` and :func:`silu` are both aliases for the same function.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  return x * sigmoid(x)


# ``swish`` is another public name for the same function object.
swish = silu

@jax.jit
def log_sigmoid(x: Array) -> Array:
  r"""Log-sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  # -log(1 + e^{-x}) is exactly -softplus(-x).
  return -softplus(-x)

@jax.jit
def elu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Exponential linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x : input array
    alpha : scalar or array of alpha values (default: 1.0)

  Returns:
    An array.

  See also:
    :func:`selu`
  """
  # Replace positive entries by 0 before calling expm1. Those entries are not
  # selected by the outer ``where``, but evaluating expm1 on them could
  # overflow and contaminate gradients.
  safe_x = jnp.where(x > 0, 0., x)
  return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

@jax.jit
def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  where :math:`\alpha` = :code:`negative_slope`.

  Args:
    x : input array
    negative_slope : array or scalar specifying the negative slope (default: 0.01)

  Returns:
    An array.

  See also:
    :func:`relu`
  """
  return jnp.where(x >= 0, x, negative_slope * x)

@jax.jit
def hard_tanh(x: Array) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array

  Returns:
    An array.
  """
  return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))

@jax.jit
def celu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  For more information, see
  `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/pdf/1704.07483.pdf>`_.

  Args:
    x : input array
    alpha : array or scalar (default: 1.0)

  Returns:
    An array.
  """
  # The positive part and the (clamped) negative part are computed separately
  # and summed; exactly one of the two terms is non-zero for any input.
  return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)

@jax.jit
def selu(x: Array) -> Array:
  r"""Scaled exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  For more information, see
  `Self-Normalizing Neural Networks
  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`elu`
  """
  # Fixed constants from the formula above (alpha and lambda respectively).
  alpha = 1.6732632423543772848170429916717
  scale = 1.0507009873554804934193349852946
  return scale * elu(x, alpha)

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
def gelu(x: Array, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
      \frac{x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x : input array (Python scalars and other array-likes are converted)
    approximate: whether to use the approximate or exact formulation.

  Returns:
    An array with an inexact (floating or complex) dtype.
  """
  # Coerce first so inputs without ``.astype`` / ``.dtype`` (Python scalars,
  # lists) are accepted; this is a no-op for values that are already arrays.
  x = jnp.asarray(x)

  # Promote to nearest float-like dtype.
  x = x.astype(dtypes.to_inexact_dtype(x.dtype))

  if approximate:
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
    return x * cdf
  else:
    sqrt_2 = np.sqrt(2).astype(x.dtype)
    return jnp.array(x * (lax.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype)

@partial(jax.jit, static_argnames=("axis",))
def glu(x: Array, axis: int = -1) -> Array:
  r"""Gated linear unit activation function.

  Computes the function:

  .. math::
    \mathrm{glu}(x) =  x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
      \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
        \right)

  where the array is split into two along ``axis``. The size of the ``axis``
  dimension must be divisible by two.

  Args:
    x : input array
    axis: the axis along which the split should be computed (default: -1)

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  size = x.shape[axis]
  assert size % 2 == 0, "axis size must be divisible by 2"
  # First half is the value, second half is the gate.
  x1, x2 = jnp.split(x, 2, axis)
  return x1 * sigmoid(x2)

# other functions

# Expose the ``logsumexp`` implementation imported above (as ``_logsumexp``)
# under a public name in this module.
logsumexp = _logsumexp

@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
                axis: Optional[Union[int, tuple[int, ...]]] = -1,
                where: Optional[Array] = None,
                initial: Optional[Array] = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.
    where: Elements to include in the :code:`log_softmax`.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.

  Returns:
    An array.

  See also:
    :func:`softmax`
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  # Shift by the maximum for numerical stability. The shift cancels out
  # mathematically, so it is excluded from differentiation.
  shifted = x - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  result = shifted - shifted_logsumexp
  if where is not None:
    # Excluded elements get log-probability -inf.
    return jnp.where(where, result, -jnp.inf)
  return result

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
            axis: Optional[Union[int, tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    where: Elements to include in the :code:`softmax`.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.

  Returns:
    An array.

  See also:
    :func:`log_softmax`
  """
  # The config flag selects between the custom-JVP implementation and the
  # deprecated one.
  if jax.config.jax_softmax_custom_jvp:
    return _softmax(x, axis, where, initial)
  else:
    return _softmax_deprecated(x, axis, where, initial)

# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
x,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
unnormalized = jnp.exp(x - x_max)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result

@_softmax.defjvp
def _softmax_jvp(axis, primals, tangents):
(x, where, initial), (x_dot, _, _) = primals, tangents
y = _softmax(x, axis, where, initial)
return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))

def _softmax_deprecated(x, axis, where, initial):
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result

@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
                axis: Optional[Union[int, tuple[int, ...]]] = -1,
                mean: Optional[Array] = None,
                variance: Optional[Array] = None,
                epsilon: Array = 1e-5,
                where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Args:
    x: input array.
    axis: axis or axes over which ``mean`` and ``variance`` are computed when
      they are not supplied.
    mean: optional precomputed mean; computed from ``x`` if None.
    variance: optional precomputed variance; computed from ``x`` if None.
    epsilon: value added to the variance before the reciprocal square root.
    where: optional mask forwarded to the mean computations.

  Returns:
    ``(x - mean) * rsqrt(variance + epsilon)``.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True, where=where)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(
        jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
  return (x - mean) * lax.rsqrt(variance + epsilon)

def normalize(x: Array,
              axis: Optional[Union[int, tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5,
              where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Deprecated alias: emits a ``DeprecationWarning`` and forwards all arguments
  unchanged to :func:`standardize`.
  """
  # stacklevel=2 attributes the warning to the caller's line rather than to
  # this wrapper.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.",
                DeprecationWarning, stacklevel=2)
  return standardize(x, axis, mean, variance, epsilon, where)

@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
dtype: Any, axis: Union[int, AxisName]) -> Array:
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument num_classes.")
dtype = dtypes.canonicalize_dtype(dtype)
x = jnp.asarray(x)
try:
output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
except TypeError:
axis_size = lax.psum(1, axis)
if num_classes != axis_size:
raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x == axis_idx, dtype=dtype)
axis = operator.index(axis)  # type: ignore[arg-type]
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
return jnp.asarray(lhs == rhs, dtype=dtype)

def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    Array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    Array([[0., 0., 0.],
           [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
    axis: the axis or axes along which the function should be
      computed.

  Returns:
    An array of ``dtype`` holding the one-hot encoding of ``x``.
  """
  # Validate ``num_classes`` here, before the jitted helper is called.
  num_classes = core.concrete_dim_or_error(
      num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  return _one_hot(x, num_classes, dtype=dtype, axis=axis)

@jax.custom_jvp
@jax.jit
def relu6(x: Array) -> Array:
  r"""Rectified Linear Unit 6 activation function.

  Computes the element-wise function

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  except under differentiation, we take:

  .. math::
    \nabla \mathrm{relu6}(0) = 0

  and

  .. math::
    \nabla \mathrm{relu6}(6) = 0

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`relu`
  """
  return jnp.minimum(jnp.maximum(x, 0), 6.)


# Custom JVP rule: the tangent passes through only strictly inside (0, 6), so
# the derivative at both kinks (0 and 6) is defined to be 0.
relu6.defjvps(lambda g, ans, x:
              lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))

@jax.jit
def hard_sigmoid(x: Array) -> Array:
  r"""Hard Sigmoid activation function.

  Computes the element-wise function

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`relu6`
  """
  return relu6(x + 3.) / 6.

[docs]@jax.jit
def hard_silu(x: Array) -> Array:
r"""Hard SiLU (swish) activation function

Computes the element-wise function

.. math::
\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

Both :func:hard_silu and :func:hard_swish are aliases for the same
function.

Args:
x : input array

Returns:
An array.

:func:hard_sigmoid