# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import torch

import kornia

from testing.base import BaseTester


class TestResize(BaseTester):
    """Tests for ``kornia.geometry.transform.resize`` and ``Resize``.

    Each shape test checks a 4D input and then 2D, 3D and higher-rank inputs. In every
    case only the last two dimensions are expected to change.
    """

    def test_smoke(self, device, dtype):
        # Resizing to the input's own spatial size must be (numerically) a no-op.
        inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 4), align_corners=False)
        self.assert_close(inp, out, atol=1e-4, rtol=1e-4)

        # 2D
        inp = torch.rand(3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 4), align_corners=False)
        self.assert_close(inp, out, atol=1e-4, rtol=1e-4)

        # 3D
        inp = torch.rand(3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 4), align_corners=False)
        self.assert_close(inp, out, atol=1e-4, rtol=1e-4)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 4), align_corners=False)
        self.assert_close(inp, out, atol=1e-4, rtol=1e-4)

    def test_upsize(self, device, dtype):
        inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (6, 8), align_corners=False)
        assert out.shape == (1, 3, 6, 8)

        # 2D
        inp = torch.rand(3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (6, 8), align_corners=False)
        assert out.shape == (6, 8)

        # 3D
        inp = torch.rand(3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (6, 8), align_corners=False)
        assert out.shape == (3, 6, 8)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (6, 8), align_corners=False)
        assert out.shape == (1, 2, 3, 2, 1, 3, 6, 8)

    def test_downsize(self, device, dtype):
        inp = torch.rand(1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 1), align_corners=False)
        assert out.shape == (1, 3, 3, 1)

        # 2D
        inp = torch.rand(5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 1), align_corners=False)
        assert out.shape == (3, 1)

        # 3D
        inp = torch.rand(3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 1), align_corners=False)
        assert out.shape == (3, 3, 1)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (3, 1), align_corners=False)
        assert out.shape == (1, 2, 3, 2, 1, 3, 3, 1)

    def test_downsizeAA(self, device, dtype):
        inp = torch.rand(1, 3, 10, 8, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (5, 3), align_corners=False, antialias=True)
        assert out.shape == (1, 3, 5, 3)

        inp = torch.rand(1, 1, 20, 10, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (15, 8), align_corners=False, antialias=True)
        assert out.shape == (1, 1, 15, 8)

        # 2D
        inp = torch.rand(10, 8, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (5, 3), align_corners=False, antialias=True)
        assert out.shape == (5, 3)

        # 3D
        inp = torch.rand(3, 10, 8, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (5, 3), align_corners=False, antialias=True)
        assert out.shape == (3, 5, 3)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 10, 8, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, (5, 3), align_corners=False, antialias=True)
        assert out.shape == (1, 2, 3, 2, 1, 3, 5, 3)

    def test_one_param(self, device, dtype):
        # With a single int, the default side maps the short edge (2) to 10, keeping aspect ratio.
        inp = torch.rand(1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False)
        assert out.shape == (1, 3, 25, 10)

        # 2D
        inp = torch.rand(5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False)
        assert out.shape == (25, 10)

        # 3D
        inp = torch.rand(3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False)
        assert out.shape == (3, 25, 10)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False)
        assert out.shape == (1, 2, 3, 2, 1, 3, 25, 10)

    def test_one_param_long(self, device, dtype):
        # side="long": the long edge (5) is mapped to 10.
        inp = torch.rand(1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="long")
        assert out.shape == (1, 3, 10, 4)

        # 2D
        inp = torch.rand(5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="long")
        assert out.shape == (10, 4)

        # 3D
        inp = torch.rand(3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="long")
        assert out.shape == (3, 10, 4)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="long")
        assert out.shape == (1, 2, 3, 2, 1, 3, 10, 4)

    def test_one_param_vert(self, device, dtype):
        # side="vert": the height is mapped to 10.
        inp = torch.rand(1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="vert")
        assert out.shape == (1, 3, 10, 4)

        # 2D
        inp = torch.rand(5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="vert")
        assert out.shape == (10, 4)

        # 3D
        inp = torch.rand(3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="vert")
        assert out.shape == (3, 10, 4)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 5, 2, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="vert")
        assert out.shape == (1, 2, 3, 2, 1, 3, 10, 4)

    def test_one_param_horz(self, device, dtype):
        # side="horz": the width is mapped to 10.
        inp = torch.rand(1, 3, 2, 5, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="horz")
        assert out.shape == (1, 3, 4, 10)

        # 2D
        inp = torch.rand(2, 5, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="horz")
        assert out.shape == (4, 10)

        # 3D
        inp = torch.rand(3, 2, 5, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="horz")
        assert out.shape == (3, 4, 10)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 2, 5, device=device, dtype=dtype)
        out = kornia.geometry.transform.resize(inp, 10, align_corners=False, side="horz")
        assert out.shape == (1, 2, 3, 2, 1, 3, 4, 10)

    def test_gradcheck(self, device):
        # test parameters
        new_size = 4
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.Resize(new_size, align_corners=False), (inp,))

    @pytest.mark.parametrize("anti_alias", [True, False])
    def test_dynamo(self, device, dtype, anti_alias, torch_optimizer):
        new_size = (5, 6)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        op = torch_optimizer(kornia.geometry.transform.resize)
        out = op(inp, new_size, align_corners=False, antialias=anti_alias)
        assert out.shape == (1, 2, 5, 6)
        # Reference is the eager implementation. Comparing `op` against itself would
        # always pass and would not detect a compiler-induced discrepancy.
        expected = kornia.geometry.transform.resize(inp, new_size, align_corners=False, antialias=anti_alias)
        self.assert_close(out, expected)


class TestRescale(BaseTester):
    """Tests for ``kornia.geometry.transform.rescale`` and ``Rescale``."""

    @staticmethod
    def _outer_ramp(device, dtype):
        """Build a (1, 1, 20, 20) image whose pixel (i, j) equals (i / 20) * (j / 20)."""
        ramp = torch.arange(20, device=device, dtype=dtype) / 20.0
        # (20, 1) @ (1, 20) -> (20, 20) outer product, then add batch and channel dims.
        return (ramp[None].T @ ramp[None])[None, None]

    def test_smoke(self, device, dtype):
        # A unit scale factor must leave the input unchanged.
        inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.rescale(inp, (1.0, 1.0), align_corners=False)
        self.assert_close(inp, out, atol=1e-4, rtol=1e-4)

    def test_upsize(self, device, dtype):
        inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.rescale(inp, (3.0, 2.0), align_corners=False)
        assert out.shape == (1, 3, 9, 8)

    def test_downsize(self, device, dtype):
        inp = torch.rand(1, 3, 9, 8, device=device, dtype=dtype)
        out = kornia.geometry.transform.rescale(inp, (1.0 / 3.0, 1.0 / 2.0), align_corners=False)
        assert out.shape == (1, 3, 3, 4)

    def test_downscale_values(self, device, dtype):
        inp = self._outer_ramp(device, dtype)
        out = kornia.geometry.transform.rescale(inp, (0.25, 0.25), antialias=False, align_corners=False)
        expected = torch.tensor(
            [
                [
                    [
                        [0.0056, 0.0206, 0.0356, 0.0506, 0.0656],
                        [0.0206, 0.0756, 0.1306, 0.1856, 0.2406],
                        [0.0356, 0.1306, 0.2256, 0.3206, 0.4156],
                        [0.0506, 0.1856, 0.3206, 0.4556, 0.5906],
                        [0.0656, 0.2406, 0.4156, 0.5906, 0.7656],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        self.assert_close(out, expected, atol=1e-3, rtol=1e-3)

    def test_downscale_values_AA(self, device, dtype):
        inp = self._outer_ramp(device, dtype)
        out = kornia.geometry.transform.rescale(inp, (0.25, 0.25), antialias=True, align_corners=False)
        expected = torch.tensor(
            [
                [
                    [
                        [0.0074, 0.0237, 0.0409, 0.0581, 0.0743],
                        [0.0237, 0.0756, 0.1306, 0.1856, 0.2376],
                        [0.0409, 0.1306, 0.2256, 0.3206, 0.4104],
                        [0.0581, 0.1856, 0.3206, 0.4556, 0.5832],
                        [0.0743, 0.2376, 0.4104, 0.5832, 0.7464],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        self.assert_close(out, expected, atol=1e-3, rtol=1e-3)

    def test_one_param(self, device, dtype):
        # A scalar factor is applied to both spatial dimensions.
        inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
        out = kornia.geometry.transform.rescale(inp, 2.0, align_corners=False)
        assert out.shape == (1, 3, 6, 8)

    def test_gradcheck(self, device):
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.Rescale(2.0, align_corners=False), (inp,), nondet_tol=1e-8)


class TestRotate(BaseTester):
    """Tests for ``kornia.geometry.transform.rotate`` and ``Rotate``."""

    def test_angle90(self, device, dtype):
        # prepare input data: a single-channel 4x2 image
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype)
        # after a 90 degree rotation the first and last output rows are zero
        expected = torch.tensor([[[0.0, 0.0], [4.0, 6.0], [3.0, 5.0], [0.0, 0.0]]], device=device, dtype=dtype)
        # prepare transformation
        angle = torch.tensor([90.0], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Rotate(angle, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_angle90_batch2(self, device, dtype):
        # prepare input data
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype).repeat(
            2, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[0.0, 0.0], [4.0, 6.0], [3.0, 5.0], [0.0, 0.0]]], [[[0.0, 0.0], [5.0, 3.0], [6.0, 4.0], [0.0, 0.0]]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation: one angle per batch sample
        angle = torch.tensor([90.0, -90.0], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Rotate(angle, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_angle90_batch2_broadcast(self, device, dtype):
        # prepare input data
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype).repeat(
            2, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[0.0, 0.0], [4.0, 6.0], [3.0, 5.0], [0.0, 0.0]]], [[[0.0, 0.0], [4.0, 6.0], [3.0, 5.0], [0.0, 0.0]]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation: a single angle broadcast over the batch
        angle = torch.tensor([90.0], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Rotate(angle, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device):
        # test parameters
        angle = torch.tensor([90.0], device=device, dtype=torch.float64)

        # evaluate function gradient
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.rotate, (inp, angle))

    @pytest.mark.skip("Need deep look into it since crashes everywhere.")
    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self, device, dtype):
        angle = torch.tensor([90.0], device=device, dtype=dtype)
        batch_size, channels, height, width = 2, 3, 64, 64
        img = torch.ones(batch_size, channels, height, width, device=device, dtype=dtype)
        rot = kornia.geometry.transform.Rotate(angle)
        # Trace the very module under comparison so both paths share one configuration.
        rot_traced = torch.jit.trace(rot, img)
        self.assert_close(rot(img), rot_traced(img))


class TestTranslate(BaseTester):
    """Tests for ``kornia.geometry.transform.translate`` and ``Translate``."""

    def test_dxdy(self, device, dtype):
        # prepare input data: a single-channel 4x2 image
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype)
        # shifting by one pixel along x moves the first column into the second and zero-fills the first
        expected = torch.tensor([[[0.0, 1.0], [0.0, 3.0], [0.0, 5.0], [0.0, 7.0]]], device=device, dtype=dtype)
        # prepare transformation
        translation = torch.tensor([[1.0, 0.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Translate(translation, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_dxdy_batch(self, device, dtype):
        # prepare input data
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype).repeat(
            2, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[0.0, 1.0], [0.0, 3.0], [0.0, 5.0], [0.0, 7.0]]], [[[0.0, 0.0], [0.0, 1.0], [0.0, 3.0], [0.0, 5.0]]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation: one translation per batch sample
        translation = torch.tensor([[1.0, 0.0], [1.0, 1.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Translate(translation, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_dxdy_batch_broadcast(self, device, dtype):
        # prepare input data
        inp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]], device=device, dtype=dtype).repeat(
            2, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[0.0, 1.0], [0.0, 3.0], [0.0, 5.0], [0.0, 7.0]]], [[[0.0, 1.0], [0.0, 3.0], [0.0, 5.0], [0.0, 7.0]]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation: a single translation broadcast over the batch
        translation = torch.tensor([[1.0, 0.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Translate(translation, align_corners=True)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device):
        # test parameters
        translation = torch.tensor([[1.0, 0.0]], device=device, dtype=torch.float64)

        # evaluate function gradient (w.r.t. the image only)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.translate, (inp, translation), requires_grad=(True, False))

    @pytest.mark.skip("Need deep look into it since crashes everywhere.")
    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self, device, dtype):
        translation = torch.tensor([[1.0, 0.0]], device=device, dtype=dtype)
        batch_size, channels, height, width = 2, 3, 64, 64
        img = torch.ones(batch_size, channels, height, width, device=device, dtype=dtype)
        trans = kornia.geometry.transform.Translate(translation)
        # Trace the very module under comparison so both paths share one configuration.
        trans_traced = torch.jit.trace(trans, img)
        self.assert_close(trans(img), trans_traced(img), atol=1e-4, rtol=1e-4)


class TestScale(BaseTester):
    """Tests for ``kornia.geometry.transform.scale`` and ``Scale``."""

    def test_scale_factor_2(self, device, dtype):
        # prepare input data: a 2x2 block of ones centered in a 4x4 zero image
        inp = torch.tensor(
            [[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation
        scale_factor = torch.tensor([[2.0, 2.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Scale(scale_factor)
        self.assert_close(transform(inp).sum().item(), 12.25, atol=1e-4, rtol=1e-4)

    def test_scale_factor_05(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )
        # shrinking by half leaves ones only in the central 2x2 block
        expected = torch.tensor(
            [[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        )
        # prepare transformation
        scale_factor = torch.tensor([[0.5, 0.5]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Scale(scale_factor)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_scale_factor_05_batch2(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)
        expected = torch.tensor(
            [[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)
        # prepare transformation: one scale factor per batch sample. The single-factor case is
        # covered by the `_broadcast` test below.
        scale_factor = torch.tensor([[0.5, 0.5], [0.5, 0.5]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Scale(scale_factor)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_scale_factor_05_batch2_broadcast(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)
        expected = torch.tensor(
            [[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)
        # prepare transformation: a single scale factor broadcast over the batch
        scale_factor = torch.tensor([[0.5, 0.5]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Scale(scale_factor)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device):
        # test parameters
        scale_factor = torch.tensor([[0.5, 0.5]], device=device, dtype=torch.float64)

        # evaluate function gradient (w.r.t. the image only)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.scale, (inp, scale_factor), requires_grad=(True, False))

    @pytest.mark.skip("Need deep look into it since crashes everywhere.")
    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self, device, dtype):
        scale_factor = torch.tensor([[0.5, 0.5]], device=device, dtype=dtype)
        batch_size, channels, height, width = 2, 3, 64, 64
        img = torch.ones(batch_size, channels, height, width, device=device, dtype=dtype)
        trans = kornia.geometry.transform.Scale(scale_factor)
        # Trace the very module under comparison (the original referenced `kornia.Scale`,
        # inconsistent with the `kornia.geometry.transform` namespace used everywhere else).
        trans_traced = torch.jit.trace(trans, img)
        self.assert_close(trans(img), trans_traced(img), atol=1e-4, rtol=1e-4)


class TestShear(BaseTester):
    """Tests for ``kornia.geometry.transform.shear`` and ``Shear``."""

    def test_shear_x(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )
        expected = torch.tensor(
            [[[0.75, 1.0, 1.0, 1.0], [0.25, 1.0, 1.0, 1.0], [0.0, 0.75, 1.0, 1.0], [0.0, 0.25, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        # prepare transformation
        shear = torch.tensor([[0.5, 0.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Shear(shear, align_corners=False)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_shear_y(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )
        expected = torch.tensor(
            [[[0.75, 0.25, 0.0, 0.0], [1.0, 1.0, 0.75, 0.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        # prepare transformation
        shear = torch.tensor([[0.0, 0.5]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Shear(shear, align_corners=False)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_shear_batch2(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)

        # sample 0 matches test_shear_x, sample 1 matches test_shear_y
        expected = torch.tensor(
            [
                [[[0.75, 1.0, 1.0, 1.0], [0.25, 1.0, 1.0, 1.0], [0.0, 0.75, 1.0, 1.0], [0.0, 0.25, 1.0, 1.0]]],
                [[[0.75, 0.25, 0.0, 0.0], [1.0, 1.0, 0.75, 0.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        # prepare transformation: one shear per batch sample
        shear = torch.tensor([[0.5, 0.0], [0.0, 0.5]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Shear(shear, align_corners=False)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_shear_batch2_broadcast(self, device, dtype):
        # prepare input data
        inp = torch.tensor(
            [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)

        expected = torch.tensor(
            [[[[0.75, 1.0, 1.0, 1.0], [0.25, 1.0, 1.0, 1.0], [0.0, 0.75, 1.0, 1.0], [0.0, 0.25, 1.0, 1.0]]]],
            device=device,
            dtype=dtype,
        ).repeat(2, 1, 1, 1)

        # prepare transformation: a single shear broadcast over the batch
        shear = torch.tensor([[0.5, 0.0]], device=device, dtype=dtype)
        transform = kornia.geometry.transform.Shear(shear, align_corners=False)
        self.assert_close(transform(inp), expected, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device):
        # test parameters
        shear = torch.tensor([[0.5, 0.0]], device=device, dtype=torch.float64)

        # evaluate function gradient (w.r.t. the image only)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=torch.float64)
        self.gradcheck(kornia.geometry.transform.shear, (inp, shear), requires_grad=(True, False))

    @pytest.mark.skip("Need deep look into it since crashes everywhere.")
    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self, device, dtype):
        shear = torch.tensor([[0.5, 0.0]], device=device, dtype=dtype)
        batch_size, channels, height, width = 2, 3, 64, 64
        img = torch.ones(batch_size, channels, height, width, device=device, dtype=dtype)
        trans = kornia.geometry.transform.Shear(shear, align_corners=False)
        # Trace `trans` itself. The original traced `Shear(shear)` with the default
        # align_corners, so the two outputs came from different configurations.
        trans_traced = torch.jit.trace(trans, img)
        self.assert_close(trans(img), trans_traced(img), atol=1e-4, rtol=1e-4)


class TestAffine2d(BaseTester):
    """Tests for ``kornia.geometry.transform.Affine``, ``affine`` and ``warp_affine``."""

    def test_affine_no_args(self):
        with pytest.raises(RuntimeError):
            kornia.geometry.transform.Affine()

    def test_affine_batch_size_mismatch(self, device, dtype):
        # Build the parameters outside `pytest.raises` so only the constructor can satisfy it.
        angle = torch.rand(1, device=device, dtype=dtype)
        translation = torch.rand(2, 2, device=device, dtype=dtype)
        with pytest.raises(RuntimeError):
            kornia.geometry.transform.Affine(angle, translation)

    def test_affine_rotate(self, device, dtype):
        # TODO: Remove when #666 is implemented
        if device.type == "cuda":
            pytest.skip("Currently breaks in CUDA.See https://github.com/kornia/kornia/issues/666")
        torch.manual_seed(0)
        angle = torch.rand(1, device=device, dtype=dtype) * 90.0
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        transform = kornia.geometry.transform.Affine(angle=angle).to(device=device, dtype=dtype)
        actual = transform(inp)
        expected = kornia.geometry.transform.rotate(inp, angle)
        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_affine_translate(self, device, dtype):
        # TODO: Remove when #666 is implemented
        if device.type == "cuda":
            pytest.skip("Currently breaks in CUDA.See https://github.com/kornia/kornia/issues/666")
        torch.manual_seed(0)
        translation = torch.rand(1, 2, device=device, dtype=dtype) * 2.0
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        transform = kornia.geometry.transform.Affine(translation=translation).to(device=device, dtype=dtype)
        actual = transform(inp)
        expected = kornia.geometry.transform.translate(inp, translation)
        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_affine_scale(self, device, dtype):
        # TODO: Remove when #666 is implemented
        if device.type == "cuda":
            pytest.skip("Currently breaks in CUDA.See https://github.com/kornia/kornia/issues/666")
        torch.manual_seed(0)
        _scale_factor = torch.rand(1, device=device, dtype=dtype) * 2.0
        # isotropic (1, 2) scale factor
        scale_factor = torch.stack([_scale_factor, _scale_factor], dim=1)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        transform = kornia.geometry.transform.Affine(scale_factor=scale_factor).to(device=device, dtype=dtype)
        actual = transform(inp)
        expected = kornia.geometry.transform.scale(inp, scale_factor)
        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    @pytest.mark.skip(
        "_compute_shear_matrix and get_affine_matrix2d yield different results. "
        "See https://github.com/kornia/kornia/issues/629 for details."
    )
    def test_affine_shear(self, device, dtype):
        torch.manual_seed(0)
        shear = torch.rand(1, 2, device=device, dtype=dtype)
        inp = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        transform = kornia.geometry.transform.Affine(shear=shear).to(device, dtype)
        actual = transform(inp)
        expected = kornia.geometry.transform.shear(inp, shear)
        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_affine_rotate_translate(self, device, dtype):
        # TODO: Remove when #666 is implemented
        if device.type == "cuda":
            pytest.skip("Currently breaks in CUDA.See https://github.com/kornia/kornia/issues/666")
        batch_size = 2

        inp = torch.tensor(
            [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        ).repeat(batch_size, 1, 1, 1)

        angle = torch.tensor(180.0, device=device, dtype=dtype).repeat(batch_size)
        translation = torch.tensor([1.0, 0.0], device=device, dtype=dtype).repeat(batch_size, 1)

        expected = torch.tensor(
            [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]],
            device=device,
            dtype=dtype,
        ).repeat(batch_size, 1, 1, 1)

        transform = kornia.geometry.transform.Affine(angle=angle, translation=translation, align_corners=True).to(
            device=device, dtype=dtype
        )
        actual = transform(inp)
        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_compose_affine_matrix_3x3(self, device, dtype):
        """Compare ``get_affine_matrix2d`` with torchvision's inverse affine matrix.

        To get the reference parameters and matrix::

            import random

            import numpy as np
            import torch
            import torchvision as tv

            img_size = (96, 96)
            seed = 42
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
            np.random.seed(seed)  # Numpy module.
            random.seed(seed)  # Python random module.
            torch.manual_seed(seed)
            tfm = tv.transforms.RandomAffine(
                degrees=(-25.0, 25.0), scale=(0.6, 1.4), translate=(0, 0.1), shear=(-25.0, 25.0, -20.0, 20.0)
            )
            angle, translations, scale, shear = tfm.get_params(
                tfm.degrees, tfm.translate, tfm.scale, tfm.shear, img_size
            )
            print(angle, translations, scale, shear)
            center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)

            matrix = tv.transforms.functional._get_inverse_affine_matrix(center, angle, translations, scale, shear)
            matrix = np.array(matrix).reshape(2, 3)
            print(matrix)
        """
        import math

        from torch import Tensor as T

        # NOTE(review): the device/dtype fixtures are not used here; all tensors are built
        # with torch's default device and dtype.
        batch_size, _, height, width = 1, 1, 96, 96
        angle, translations = 6.971339922894188, (0.0, -4.0)
        scale, shear = [0.7785685905190581, 0.7785685905190581], [11.8235607082617, 7.06797949691645]
        matrix_expected = T([[1.27536969, 4.26828945e-01, -3.2349e01], [2.18297196e-03, 1.29424165e00, -9.1996e00]])
        # Same center convention as the torchvision recipe above.
        center = T([float(width), float(height)]).view(1, 2) / 2.0 + 0.5
        center = center.expand(batch_size, -1)
        matrix_kornia = kornia.geometry.transform.get_affine_matrix2d(
            T(translations).view(-1, 2),
            center,
            T([scale]).view(-1, 2),
            T([angle]).view(-1),
            T([math.radians(shear[0])]).view(-1, 1),
            T([math.radians(shear[1])]).view(-1, 1),
        )
        # torchvision returns the inverse 2x3 matrix, so invert before comparing.
        matrix_kornia = matrix_kornia.inverse()[0, :2].detach().cpu()
        self.assert_close(matrix_kornia, matrix_expected, atol=1e-4, rtol=1e-4)

    def test_broadcasting_issue_3176(self, device, dtype):
        # Issue 3176: RandomRotation with multi-channel masks caused crash
        # Scenario: Tensor is (B=1, C, H, W) but Matrix is (B=8, 2, 3)
        # This implies applying N different transformations to a single image
        B_mat, B_ten = 8, 1
        C, H, W = 3, 64, 64

        input_tensor = torch.rand(B_ten, C, H, W, device=device, dtype=dtype)

        matrix = torch.rand(B_mat, 2, 3, device=device, dtype=dtype)

        output = kornia.geometry.transform.affine(input_tensor, matrix, mode="bilinear")

        assert output.shape == (B_mat, C, H, W)

    def test_warp_affine_fill_value(self, device, dtype):
        # Feature: Support padding_mode="fill" with custom fill_value
        # Scenario: 1-channel mask, fill with 1.0
        B, C, H, W = 1, 1, 10, 10
        src = torch.zeros(B, C, H, W, device=device, dtype=dtype)

        # Identity matrix plus a translation of 5 pixels along x.
        M = torch.eye(2, 3, device=device, dtype=dtype).unsqueeze(0)
        M[..., 0, 2] = 5.0

        # Fill with 1.0
        fill_val = torch.tensor([1.0], device=device, dtype=dtype)

        out = kornia.geometry.transform.warp_affine(src, M, (H, W), padding_mode="fill", fill_value=fill_val)

        # Left column is sampled from outside the source -> fill value.
        assert out[0, 0, 0, 0] == 1.0

        # Right column still maps inside the (all-zero) source.
        assert out[0, 0, 0, 9] == 0.0


class TestGetAffineMatrix(BaseTester):
    """Tests for ``kornia.geometry.get_affine_matrix2d`` combined with ``warp_affine``."""

    def test_smoke(self, device, dtype):
        height = width = 5
        opts = {"device": device, "dtype": dtype}

        shift = torch.zeros(1, 2, **opts)
        # NOTE: ideally the center should be [W * 0.5, H * 0.5]
        pivot = torch.tensor([[width // 2, height // 2]], **opts)
        # First scale component 0.5, second 1.0: only the central columns stay lit below.
        zoom = torch.tensor([[0.5, 1.0]], **opts)
        no_rotation = torch.zeros(1, **opts)
        affine_mat = kornia.geometry.get_affine_matrix2d(shift, pivot, zoom, no_rotation)

        image = torch.ones(1, 1, height, width, **opts)
        target = torch.zeros_like(image)
        target[..., 1:4] = 1.0

        # warp_affine takes the top 2x3 part of the 3x3 homogeneous matrix.
        warped = kornia.geometry.transform.warp_affine(image, affine_mat[:, :2], (height, width))
        self.assert_close(warped, target)


class TestGetShearMatrix(BaseTester):
    """Tests for ``kornia.geometry.transform.get_shear_matrix2d``."""

    def test_get_shear_matrix2d_with_controlled_values(self, device, dtype):
        opts = {"device": device, "dtype": dtype}

        # Fixed shear parameters about the origin, so the result can be compared to constants.
        shear_x = torch.tensor([0.5], **opts)
        shear_y = torch.tensor([0.25], **opts)
        origin = torch.zeros(1, 2, **opts)

        actual = kornia.geometry.transform.get_shear_matrix2d(origin, sx=shear_x, sy=shear_y)

        # Reference 3x3 matrix. The off-diagonal terms equal -tan(0.5) and -tan(0.25), and
        # 1.1395 equals 1 + tan(0.5) * tan(0.25).
        expected = torch.tensor([[[1.0, -0.5463, 0.0], [-0.2553, 1.1395, 0.0], [0.0, 0.0, 1.0]]], **opts)

        self.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
