# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from torch import nn

__all__ = ["Hflip", "Rot180", "Vflip", "hflip", "rot180", "vflip"]


class Vflip(nn.Module):
    r"""Module that flips a torch.Tensor image (or a batch of them) upside down.

    The flip is delegated to :func:`vflip`, which reverses the height axis (the
    second-to-last dimension). The module has no constructor arguments; the tensor is
    given to :meth:`forward` and must have shape (C, H, W) or :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor passed to the forward call.

    Returns:
        The vertically flipped image torch.Tensor.

    Examples:
        >>> vflip = Vflip()
        >>> input = torch.tensor([[[
        ...    [0., 0., 0.],
        ...    [0., 0., 0.],
        ...    [0., 1., 1.]
        ... ]]])
        >>> vflip(input)
        tensor([[[[0., 1., 1.],
                  [0., 0., 0.],
                  [0., 0., 0.]]]])

    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stateless wrapper: all of the work happens in the functional version.
        flipped = vflip(input)
        return flipped

    def __repr__(self) -> str:
        # The module has no configuration, so the bare class name is a full description.
        return type(self).__name__


class Hflip(nn.Module):
    r"""Module that mirrors a torch.Tensor image (or a batch of them) left to right.

    The flip is delegated to :func:`hflip`, which reverses the width axis (the last
    dimension). The module has no constructor arguments; the tensor is given to
    :meth:`forward` and must have shape (C, H, W) or :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor passed to the forward call.

    Returns:
        The horizontally flipped image torch.Tensor.

    Examples:
        >>> hflip = Hflip()
        >>> input = torch.tensor([[[
        ...    [0., 0., 0.],
        ...    [0., 0., 0.],
        ...    [0., 1., 1.]
        ... ]]])
        >>> hflip(input)
        tensor([[[[0., 0., 0.],
                  [0., 0., 0.],
                  [1., 1., 0.]]]])

    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stateless wrapper: all of the work happens in the functional version.
        mirrored = hflip(input)
        return mirrored

    def __repr__(self) -> str:
        # The module has no configuration, so the bare class name is a full description.
        return type(self).__name__


class Rot180(nn.Module):
    r"""Rotate a torch.Tensor image or a batch of torch.Tensor images 180 degrees.

    The rotation is delegated to :func:`rot180`, which flips both the height and width
    axes (the last two dimensions). The module has no constructor arguments; the tensor is
    given to :meth:`forward` and must have shape (C, H, W) or a batch of tensors
    :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor passed to the forward call.

    Returns:
        The rotated image torch.Tensor.

    Examples:
        >>> rot180 = Rot180()
        >>> input = torch.tensor([[[
        ...    [0., 0., 0.],
        ...    [0., 0., 0.],
        ...    [0., 1., 1.]
        ... ]]])
        >>> rot180(input)
        tensor([[[[1., 1., 0.],
                  [0., 0., 0.],
                  [0., 0., 0.]]]])

    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stateless wrapper around the functional implementation.
        return rot180(input)

    def __repr__(self) -> str:
        # No configuration to report; the class name fully describes the module.
        return self.__class__.__name__


def rot180(input: torch.Tensor) -> torch.Tensor:
    r"""Rotate a torch.Tensor image or a batch of torch.Tensor images 180 degrees.

    .. image:: _static/img/rot180.png

    The rotation is implemented by flipping both the height (second-to-last) and width
    (last) dimensions. Input must be a torch.Tensor of shape (C, H, W) or a batch of
    tensors :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor.

    Returns:
        The rotated image torch.Tensor, in contiguous memory layout (consistent with
        :func:`hflip` and :func:`vflip`).

    """
    # ``torch.flip`` can hand back a non-contiguous result when the input is a
    # non-contiguous view (e.g. a transpose). ``.contiguous()`` makes the output layout
    # consistent with hflip/vflip, and it returns ``self`` when no copy is needed.
    return torch.flip(input, [-2, -1]).contiguous()


def hflip(input: torch.Tensor) -> torch.Tensor:
    r"""Mirror a torch.Tensor image, or a batch of them, along the width axis.

    .. image:: _static/img/hflip.png

    The last dimension is reversed, so column ``j`` of the result is column ``W - 1 - j``
    of the input. Accepts a torch.Tensor of shape (C, H, W) or a batch of tensors
    :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor.

    Returns:
        The horizontally flipped image torch.Tensor, in contiguous memory layout.

    """
    mirrored = torch.flip(input, dims=(-1,))
    # Force a contiguous layout regardless of the input's strides.
    return mirrored.contiguous()


def vflip(input: torch.Tensor) -> torch.Tensor:
    r"""Flip a torch.Tensor image, or a batch of them, upside down along the height axis.

    .. image:: _static/img/vflip.png

    The second-to-last dimension is reversed, so row ``i`` of the result is row
    ``H - 1 - i`` of the input. Accepts a torch.Tensor of shape (C, H, W) or a batch of
    tensors :math:`(*, C, H, W)`.

    Args:
        input: input torch.Tensor.

    Returns:
        The vertically flipped image torch.Tensor, in contiguous memory layout.

    """
    upside_down = torch.flip(input, dims=(-2,))
    # Force a contiguous layout regardless of the input's strides.
    return upside_down.contiguous()
