
GaussianNLLLoss

class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[source]

Gaussian negative log likelihood loss.

The targets are treated as samples from Gaussian distributions with expectations and variances predicted by the neural network. For a D-dimensional target tensor modelled as having heteroscedastic Gaussian distributions with a D-dimensional tensor of expectations input and a D-dimensional tensor of positive variances var the loss is:

\text{loss} = \frac{1}{2}\sum_{i=1}^D \left(\log\left(\text{max}\left(\text{var}[i], \ \text{eps}\right)\right) + \frac{\left(\text{input}[i] - \text{target}[i]\right)^2}{\text{max}\left(\text{var}[i], \ \text{eps}\right)}\right) + \text{const.}

where eps is used for stability. The constant term of the loss function is omitted by default and included only if full is True. If var is a scalar (implying the target tensor has homoscedastic Gaussian distributions), it is broadcast to the same size as the input.
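As an illustration (not part of the official documentation), the expression above can be reproduced with plain tensor operations. The sketch below assumes the default eps of 1e-6, full=False, the per-sample summation shown in the formula, and a 'mean' reduction that averages the batch member losses; depending on the PyTorch version, the built-in reduction may instead be applied element-wise, so the manual value and the module output can differ by a constant factor.

>>> import torch
>>> from torch import nn
>>> eps = 1e-6
>>> input = torch.randn(5, 2)
>>> target = torch.randn(5, 2)
>>> var = torch.rand(5, 2) + 0.1                        # positive variances
>>> clamped = var.clamp(min=eps)                        # eps guards against tiny variances
>>> per_sample = 0.5 * (clamped.log() + (input - target) ** 2 / clamped).sum(dim=1)
>>> manual = per_sample.mean()                          # average of the batch member losses
>>> output = nn.GaussianNLLLoss()(input, target, var)   # module result, for comparison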

Parameters
  • full (bool, optional) – include the constant term in the loss calculation. Default: False.

  • eps (float, optional) – value used to clamp var (see note below), for stability. Default: 1e-6.

  • reduction (string, optional) – specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the output is the average of all batch member losses, 'sum': the output is the sum of all batch member losses. Default: 'mean'. See the sketch after this list.
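As a quick illustration (not from the original documentation) of how the reduction modes differ, the following sketch compares 'none', which per the Shape section below yields one loss value per batch member, with 'sum' and the default 'mean':

>>> import torch
>>> from torch import nn
>>> input = torch.randn(5, 2)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2)
>>> nn.GaussianNLLLoss(reduction='none')(input, target, var).shape  # one loss per batch member
>>> nn.GaussianNLLLoss(reduction='sum')(input, target, var)         # scalar: sum of those losses
>>> nn.GaussianNLLLoss(reduction='mean')(input, target, var)        # scalar: their average (default)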

Shape:
  • Input: (N, *) where * means any number of additional dimensions

  • Target: (N, *), same shape as the input

  • Var: (N, 1) or (N, *), same shape as the input

  • Output: scalar if reduction is 'mean' (default) or 'sum'. If reduction is 'none', then (N)

Examples:

>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()


>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

Note

The clamping of var is ignored with respect to autograd, and so the gradients are unaffected by it.

Reference:

Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138.
