# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""nn.Module containing operators to work on RGB-Depth images."""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from kornia.filters.sobel import spatial_gradient
from kornia.geometry.grid import create_meshgrid

from .camera import PinholeCamera, cam2pixel, pixel2cam, project_points, unproject_points
from .conversions import normalize_pixel_coordinates, normalize_points_with_intrinsics
from .linalg import convert_points_to_homogeneous, transform_points

"""nn.Module containing operators to work on RGB-Depth images."""

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "DepthWarper",
    "depth_from_disparity",
    "depth_from_plane_equation",
    "depth_to_3d",
    "depth_to_3d_v2",
    "depth_to_normals",
    "depth_warp",
    "unproject_meshgrid",
    "warp_frame_depth",
]


def unproject_meshgrid(
    height: int,
    width: int,
    camera_matrix: torch.Tensor,
    normalize_points: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Compute the 3d point at unit depth along the ray of every pixel, given the camera intrinsics.

    Every pixel ``(u, v)`` is mapped to ``(x, y, 1)``, where ``(x, y)`` are its coordinates normalized by the
    intrinsics. Multiplying the result by a depth map gives the point cloud.

    .. tip::

        This function should be used in conjunction with :py:func:`kornia.geometry.depth.depth_to_3d_v2` to cache
        the meshgrid computation when warping multiple frames with the same camera intrinsics.

    Args:
        height: height of image.
        width: width of image.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(*, 3, 3)`; a single
          :math:`(3, 3)` matrix is also accepted.
        normalize_points: whether to normalize the pointcloud to unit-length rays. This must be set to `True` when
          the depth is represented as the Euclidean ray length from the camera position.
        device: device to place the result on.
        dtype: dtype of the result.

    Return:
        tensor with a 3d point per pixel with shape :math:`(*, H, W, 3)`, where ``*`` are the leading (batch)
        dimensions of ``camera_matrix`` (none for a single :math:`(3, 3)` matrix).

    """
    KORNIA_CHECK_SHAPE(camera_matrix, ["*", "3", "3"])

    # create base coordinates grid. Index away only the leading 1 of the 1xHxWx2 grid: a bare ``squeeze()`` would
    # also remove H or W when they equal 1 and break the broadcasting below.
    points_uv: torch.Tensor = create_meshgrid(
        height, width, normalized_coordinates=False, device=device, dtype=dtype
    )[0]  # HxWx2

    # Insert two singleton dims in front of the 3x3 matrix so that any batch dims broadcast against HxW:
    # (*)x3x3 -> (*)x1x1x3x3. A lone 3x3 matrix is passed as is; its intrinsics are single values that broadcast
    # over the HxW grid (NOTE(review): relies on normalize_points_with_intrinsics broadcasting, confirm).
    camera_matrix_tmp: torch.Tensor = camera_matrix[..., None, None, :, :] if camera_matrix.dim() > 2 else camera_matrix

    # project pixels to the normalized camera plane
    points_xy = normalize_points_with_intrinsics(points_uv, camera_matrix_tmp)  # (*)xHxWx2

    # lift to 3d points at unit depth
    points_xyz = convert_points_to_homogeneous(points_xy)  # (*)xHxWx3

    if normalize_points:
        points_xyz = F.normalize(points_xyz, dim=-1, p=2)

    return points_xyz


def depth_to_3d_v2(
    depth: torch.Tensor,
    camera_matrix: torch.Tensor,
    normalize_points: bool = False,
    xyz_grid: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute a 3d point per pixel given its depth value and the camera intrinsics.

    .. note::

        This is a channels-last alternative to :py:func:`kornia.geometry.depth.depth_to_3d` that can reuse a
        precomputed grid of unit-depth points (see :py:func:`kornia.geometry.depth.unproject_meshgrid`) instead of
        rebuilding it on every call.

    Args:
        depth: image tensor containing a depth value per pixel with shape :math:`(*, H, W)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(*, 3, 3)`.
        normalize_points: whether to normalise the pointcloud. This must be set to `True` when the depth is
          represented as the Euclidean ray length from the camera position.
        xyz_grid: optional precomputed 3d points at unit depth with shape :math:`(*, H, W, 3)`. When given,
          ``camera_matrix`` and ``normalize_points`` are not used to build the grid.

    Return:
        tensor with a 3d point per pixel of the same resolution as the input :math:`(*, H, W, 3)`.

    Example:
        >>> depth = torch.rand(4, 4)
        >>> K = torch.eye(3).repeat(2,1,1)
        >>> depth_to_3d_v2(depth, K).shape
        torch.Size([2, 4, 4, 3])

    """
    # NOTE: when this replaces the `depth_to_3d` behaviour, a deprecated function should be added here, instead
    # of just replace the other function.
    KORNIA_CHECK_SHAPE(depth, ["*", "H", "W"])
    KORNIA_CHECK_SHAPE(camera_matrix, ["*", "3", "3"])

    # Create the base grid if not provided. An explicit ``is None`` test is required: ``xyz_grid or ...`` calls
    # ``bool()`` on the tensor, which raises for any tensor with more than one element.
    if xyz_grid is None:
        height, width = depth.shape[-2:]
        points_xyz: torch.Tensor = unproject_meshgrid(
            height, width, camera_matrix, normalize_points, depth.device, depth.dtype
        )
    else:
        points_xyz = xyz_grid

    KORNIA_CHECK_SHAPE(points_xyz, ["*", "H", "W", "3"])

    # scale each unit-depth point by its pixel's depth
    return points_xyz * depth[..., None]  # (*)xHxWx3


def depth_to_3d(depth: torch.Tensor, camera_matrix: torch.Tensor, normalize_points: bool = False) -> torch.Tensor:
    """Compute a 3d point per pixel given its depth value and the camera intrinsics.

    .. note::

        This implementation builds a pixel meshgrid on every call and works channels-first. See
        :py:func:`kornia.geometry.depth.depth_to_3d_v2` for a channels-last variant that can reuse a precomputed
        grid.

    Args:
        depth: image tensor containing a depth value per pixel with shape :math:`(B, 1, H, W)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B, 3, 3)`.
        normalize_points: whether to normalise the pointcloud. This must be set to `True` when the depth is
          represented as the Euclidean ray length from the camera position.

    Return:
        tensor with a 3d point per pixel of the same resolution as the input :math:`(B, 3, H, W)`.

    Example:
        >>> depth = torch.rand(1, 1, 4, 4)
        >>> K = torch.eye(3)[None]
        >>> depth_to_3d(depth, K).shape
        torch.Size([1, 3, 4, 4])

    """
    KORNIA_CHECK_IS_TENSOR(depth)
    KORNIA_CHECK_IS_TENSOR(camera_matrix)
    KORNIA_CHECK_SHAPE(depth, ["B", "1", "H", "W"])
    KORNIA_CHECK_SHAPE(camera_matrix, ["B", "3", "3"])

    # create base coordinates grid
    _, _, height, width = depth.shape
    points_2d: torch.Tensor = create_meshgrid(
        height, width, normalized_coordinates=False, device=depth.device, dtype=depth.dtype
    )  # 1xHxWx2

    # move the depth channel last so it lines up with the grid: Bx1xHxW -> BxHxWx1
    points_depth: torch.Tensor = depth.permute(0, 2, 3, 1)  # BxHxWx1

    # project pixels to camera frame; the two singleton dims broadcast each camera over HxW
    camera_matrix_tmp: torch.Tensor = camera_matrix[:, None, None]  # Bx1x1x3x3
    points_3d: torch.Tensor = unproject_points(
        points_2d, points_depth, camera_matrix_tmp, normalize=normalize_points
    )  # BxHxWx3

    return points_3d.permute(0, 3, 1, 2)  # Bx3xHxW


def depth_to_normals(depth: torch.Tensor, camera_matrix: torch.Tensor, normalize_points: bool = False) -> torch.Tensor:
    """Estimate a unit surface normal for every pixel of a depth map.

    The depth map is back-projected to a point cloud with :py:func:`depth_to_3d`. The point cloud is then
    differentiated along both image axes with :py:func:`kornia.filters.spatial_gradient`. The normal is the
    L2-normalized cross product of those two gradient vectors.

    Args:
        depth: image tensor containing a depth value per pixel with shape :math:`(B, 1, H, W)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B, 3, 3)`.
        normalize_points: whether to normalize the pointcloud. This must be set to `True` when the depth is
          represented as the Euclidean ray length from the camera position.

    Return:
        tensor with a normal surface vector per pixel of the same resolution as the input :math:`(B, 3, H, W)`.

    Example:
        >>> depth = torch.rand(1, 1, 4, 4)
        >>> K = torch.eye(3)[None]
        >>> depth_to_normals(depth, K).shape
        torch.Size([1, 3, 4, 4])

    """
    KORNIA_CHECK_IS_TENSOR(depth)
    KORNIA_CHECK_IS_TENSOR(camera_matrix)
    KORNIA_CHECK_SHAPE(depth, ["B", "1", "H", "W"])
    KORNIA_CHECK_SHAPE(camera_matrix, ["B", "3", "3"])

    point_cloud = depth_to_3d(depth, camera_matrix, normalize_points)  # Bx3xHxW

    # the two first-order derivatives are stacked on dim 2 (Bx3x2xHxW); split them into two Bx3xHxW tensors
    first_axis_grad, second_axis_grad = spatial_gradient(point_cloud).unbind(dim=2)

    surface_normals = torch.linalg.cross(first_axis_grad, second_axis_grad, dim=1)
    return F.normalize(surface_normals, p=2, dim=1)


def depth_from_plane_equation(
    plane_normals: torch.Tensor,
    plane_offsets: torch.Tensor,
    points_uv: torch.Tensor,
    camera_matrix: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Compute depth values from plane equations and pixel coordinates.

    Each pixel is back-projected to the ray ``r = (x, y, 1)`` of intrinsics-normalized coordinates. The ray is
    intersected with the plane ``n^T X = d``, which gives ``depth = d / (n^T r)``.

    Args:
        plane_normals: plane normal vectors with shape :math:`(B, 3)`.
        plane_offsets: plane offsets with shape :math:`(B, 1)`.
        points_uv: pixel coordinates with shape :math:`(B, N, 2)`.
        camera_matrix: camera intrinsic matrix with shape :math:`(B, 3, 3)`.
        eps: epsilon for numerical stability. The denominator ``n^T r`` is clamped to ``±eps`` when its
          magnitude is below ``eps``, i.e. when the ray is (nearly) parallel to the plane.

    Returns:
        computed depth values at the given pixels with shape :math:`(B, N)`.

    """
    KORNIA_CHECK_SHAPE(plane_normals, ["B", "3"])
    KORNIA_CHECK_SHAPE(plane_offsets, ["B", "1"])
    KORNIA_CHECK_SHAPE(points_uv, ["B", "N", "2"])
    KORNIA_CHECK_SHAPE(camera_matrix, ["B", "3", "3"])

    # Normalize pixel coordinates and lift them to rays with unit z-component
    points_xy = normalize_points_with_intrinsics(points_uv, camera_matrix)  # (B, N, 2)
    rays = convert_points_to_homogeneous(points_xy)  # (B, N, 3)

    # Denominator of the depth equation: n^T r for every ray (plane normals broadcast over N)
    denom = torch.sum(rays * plane_normals.unsqueeze(1), dim=-1)  # (B, N)

    # Clamp near-zero denominators to +/-eps, keeping their sign. torch.sign(0) is 0, so an exactly-zero
    # denominator would otherwise be "replaced" by 0 and the division below would produce inf/nan; treat 0 as
    # positive instead.
    sign = torch.sign(denom)
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    denom = torch.where(torch.abs(denom) < eps, eps * sign, denom)

    # Compute depth from plane equation: plane_offsets (B, 1) broadcasts against denom (B, N)
    return plane_offsets / denom


def warp_frame_depth(
    image_src: torch.Tensor,
    depth_dst: torch.Tensor,
    src_trans_dst: torch.Tensor,
    camera_matrix: torch.Tensor,
    normalize_points: bool = False,
) -> torch.Tensor:
    """Warp a tensor from a source to destination frame by the depth in the destination.

    Compute 3d points from the depth, transform them using given transformation, then project the point cloud to an
    image plane.

    Args:
        image_src: image tensor in the source frame with shape :math:`(B,D,H,W)`.
        depth_dst: depth tensor in the destination frame with shape :math:`(B,1,H,W)`.
        src_trans_dst: transformation matrix from destination to source with shape :math:`(B,4,4)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B,3,3)`, used for both the
           unprojection in the destination frame and the projection into the source frame.
        normalize_points: whether to normalize the pointcloud. This must be set to ``True`` when the depth
           is represented as the Euclidean ray length from the camera position.

    Return:
        ``image_src`` resampled onto the pixel grid of ``depth_dst``, with shape :math:`(B,D,H,W)`.

    """
    KORNIA_CHECK_SHAPE(image_src, ["B", "D", "H", "W"])
    KORNIA_CHECK_SHAPE(depth_dst, ["B", "1", "H", "W"])
    KORNIA_CHECK_SHAPE(src_trans_dst, ["B", "4", "4"])
    KORNIA_CHECK_SHAPE(camera_matrix, ["B", "3", "3"])

    # unproject destination pixels to 3d points in the destination camera frame
    points_3d_dst: torch.Tensor = depth_to_3d(depth_dst, camera_matrix, normalize_points)  # Bx3xHxW

    # channels-last layout for transform_points
    points_3d_dst = points_3d_dst.permute(0, 2, 3, 1)  # BxHxWx3

    # move the 3d points from the destination to the source frame
    points_3d_src = transform_points(src_trans_dst[:, None], points_3d_dst)  # BxHxWx3

    # project back to pixels of the source camera
    camera_matrix_tmp: torch.Tensor = camera_matrix[:, None, None]  # Bx1x1x3x3
    points_2d_src: torch.Tensor = project_points(points_3d_src, camera_matrix_tmp)  # BxHxWx2

    # Normalize points to [-1, 1] w.r.t. the image that is sampled: grid_sample interprets the grid relative to
    # ``image_src``'s own resolution (identical to the depth resolution when the shapes match).
    height_src, width_src = image_src.shape[-2:]
    points_2d_src_norm: torch.Tensor = normalize_pixel_coordinates(points_2d_src, height_src, width_src)  # BxHxWx2

    return F.grid_sample(image_src, points_2d_src_norm, align_corners=True)


class DepthWarper(nn.Module):
    r"""Warp a patch by depth.

    .. math::
        P_{src}^{dst} = K_{dst} * T_{src}^{dst}

        I_{src} = \omega(I_{dst}, P_{src}^{dst}, D_{src})

    Args:
        pinhole_dst: the pinhole model for the destination frame.
        height: the height of the image to warp.
        width: the width of the image to warp.
        mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: interpolation flag.

    Raises:
        TypeError: if ``pinhole_dst`` is not a :class:`PinholeCamera`.

    """

    def __init__(
        self,
        pinhole_dst: PinholeCamera,
        height: int,
        width: int,
        mode: str = "bilinear",
        padding_mode: str = "zeros",
        align_corners: bool = True,
    ) -> None:
        super().__init__()
        self.width: int = width
        self.height: int = height
        self.mode: str = mode
        self.padding_mode: str = padding_mode
        # NOTE(review): not read by any method of this class; kept as a public attribute for compatibility.
        self.eps = 1e-6
        self.align_corners: bool = align_corners

        # state members
        if not isinstance(pinhole_dst, PinholeCamera):
            raise TypeError(f"Expected pinhole_dst as PinholeCamera, got {type(pinhole_dst)}")
        self._pinhole_dst: PinholeCamera = pinhole_dst
        # both set by compute_projection_matrix
        self._pinhole_src: None | PinholeCamera = None
        self._dst_proj_src: None | torch.Tensor = None

        # Homogeneous pixel grid (1xHxWx3), built once per instance. It is a plain attribute, not a registered
        # buffer, so warp_grid casts it to the depth's device/dtype on every call.
        self.grid: torch.Tensor = self._create_meshgrid(height, width)

    @staticmethod
    def _create_meshgrid(height: int, width: int) -> torch.Tensor:
        """Return the 1xHxWx3 grid of homogeneous pixel coordinates ``(u, v, 1)``."""
        grid: torch.Tensor = create_meshgrid(height, width, normalized_coordinates=False)  # 1xHxWx2
        return convert_points_to_homogeneous(grid)  # append ones to last dim

    def compute_projection_matrix(self, pinhole_src: PinholeCamera) -> DepthWarper:
        """Compute the projection matrix from the source to destination frame.

        Stores ``K_dst @ E_dst @ inv(E_src)`` on the instance for later use by :meth:`warp_grid`.

        Args:
            pinhole_src: the pinhole model of the source (reference) frame.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            TypeError: if the destination or source camera is not a :class:`PinholeCamera`.

        """
        # isinstance, consistent with the check in __init__
        if not isinstance(self._pinhole_dst, PinholeCamera):
            raise TypeError(
                f"Member self._pinhole_dst expected to be of class PinholeCamera. Got {type(self._pinhole_dst)}"
            )
        if not isinstance(pinhole_src, PinholeCamera):
            raise TypeError(f"Argument pinhole_src expected to be of class PinholeCamera. Got {type(pinhole_src)}")

        src_extrinsics = pinhole_src.extrinsics
        dst_extrinsics = self._pinhole_dst.extrinsics

        # Invert the source extrinsics as a rigid transform: inv([R | t]) = [R^T | -R^T t]
        inv_rmat = torch.transpose(src_extrinsics[..., :3, :3], -1, -2)
        inv_tvec = torch.matmul(-inv_rmat, src_extrinsics[..., :3, 3:])

        # Compose with dst extrinsics: dst_trans_src = E_dst @ inv(E_src)
        dst_rmat = dst_extrinsics[..., :3, :3]
        composed_rmat = torch.matmul(dst_rmat, inv_rmat)
        composed_tvec = torch.matmul(dst_rmat, inv_tvec) + dst_extrinsics[..., :3, 3:]

        # Assemble the 4x4 matrix with cat (instead of writing into a src-shaped identity) so that the batch
        # dimensions of src and dst broadcast against each other.
        top = torch.cat([composed_rmat, composed_tvec], dim=-1)  # (*, 3, 4)
        bottom = torch.zeros_like(top[..., :1, :])  # (*, 1, 4)
        bottom[..., 0, 3] = 1.0
        dst_trans_src = torch.cat([top, bottom], dim=-2)  # (*, 4, 4)

        # intrinsics (Nx4x4, as PinholeCamera stores them) @ extrinsics (Nx4x4)
        dst_proj_src = torch.matmul(self._pinhole_dst.intrinsics, dst_trans_src)

        self._pinhole_src = pinhole_src
        self._dst_proj_src = dst_proj_src
        return self

    def _compute_projection(self, x: float, y: float, invd: float) -> torch.Tensor:
        """Project the homogeneous point ``(x, y, invd, 1)`` with the stored matrix; returns an ``Nx2`` tensor."""
        if self._dst_proj_src is None or self._pinhole_src is None:
            raise ValueError("Please, call compute_projection_matrix.")

        point = torch.tensor(
            [[[x], [y], [invd], [1.0]]], device=self._dst_proj_src.device, dtype=self._dst_proj_src.dtype
        )  # 1x4x1
        flow = torch.matmul(self._dst_proj_src, point)  # Nx4x1
        z = 1.0 / flow[:, 2]
        _x = flow[:, 0] * z
        _y = flow[:, 1] * z
        return torch.cat([_x, _y], 1)

    def compute_subpixel_step(self) -> torch.Tensor:
        """Compute the inverse depth step for sub pixel accurate sampling of the depth cost volume, per camera.

        Szeliski, Richard, and Daniel Scharstein. "Symmetric sub-pixel stereo matching." European Conference on Computer
        Vision. Springer Berlin Heidelberg, 2002.
        """
        if self._dst_proj_src is None:
            raise RuntimeError("Expected torch.Tensor, but got None Type from the projection matrix")

        delta_d = 0.01
        center_x = self.width / 2
        center_y = self.height / 2

        # Project the image centre at the two inverse depths 1 -/+ delta_d with a single matmul:
        # each column of ``points`` is one homogeneous point (x, y, invd, 1).
        invds = (1.0 - delta_d, 1.0 + delta_d)
        points = (
            torch.tensor(
                [[center_x, center_y, invds[0], 1.0], [center_x, center_y, invds[1], 1.0]],
                dtype=self._dst_proj_src.dtype,
                device=self._dst_proj_src.device,
            )
            .transpose(0, 1)
            .unsqueeze(0)
        )  # (1, 4, 2)
        # the leading singleton broadcasts the points over all N projection matrices
        flow = torch.matmul(self._dst_proj_src, points)  # (N, 4, 2)
        zs = 1.0 / flow[:, 2]  # (N, 2)
        xs = flow[:, 0] * zs
        ys = flow[:, 1] * zs
        xys = torch.stack((xs, ys), dim=-1)  # (N, 2, 2): xys[:, k] is the pixel for invds[k]
        dxy = torch.norm(xys[:, 1] - xys[:, 0], p=2, dim=1) / 2.0
        dxdd = dxy / delta_d
        # half pixel sampling, min for all cameras
        return torch.min(0.5 / dxdd)

    def warp_grid(self, depth_src: torch.Tensor) -> torch.Tensor:
        """Compute a grid for warping a given the depth from the reference pinhole camera.

        The function `compute_projection_matrix` has to be called beforehand in order to have precomputed the relative
        projection matrices encoding the relative pose and the intrinsics between the reference and a non reference
        camera.

        Args:
            depth_src: depth in the reference frame with shape :math:`(B, 1, H, W)`.

        Returns:
            sampling grid normalized to ``[-1, 1]`` with shape :math:`(B, H, W, 2)`.

        Raises:
            ValueError: if the projection matrix has not been computed or ``depth_src`` is not 4-dimensional.

        """
        # TODO: add type and value checkings
        if self._dst_proj_src is None or self._pinhole_src is None:
            raise ValueError("Please, call compute_projection_matrix.")

        if len(depth_src.shape) != 4:
            raise ValueError(f"Input depth_src has to be in the shape of Bx1xHxW. Got {depth_src.shape}")

        # unpack depth attributes
        batch_size, _, _, _ = depth_src.shape
        device: torch.device = depth_src.device
        dtype: torch.dtype = depth_src.dtype

        # expand the base coordinate grid according to the input batch size
        pixel_coords: torch.Tensor = self.grid.to(device=device, dtype=dtype).expand(batch_size, -1, -1, -1)  # BxHxWx3

        # reproject the pixel coordinates to the camera frame
        cam_coords_src: torch.Tensor = pixel2cam(
            depth_src, self._pinhole_src.intrinsics_inverse().to(device=device, dtype=dtype), pixel_coords
        )  # BxHxWx3

        # reproject the camera coordinates to the pixel
        pixel_coords_src: torch.Tensor = cam2pixel(
            cam_coords_src, self._dst_proj_src.to(device=device, dtype=dtype)
        )  # BxHxWx2

        # normalize between -1 and 1 the coordinates
        pixel_coords_src_norm: torch.Tensor = normalize_pixel_coordinates(pixel_coords_src, self.height, self.width)
        return pixel_coords_src_norm

    def forward(self, depth_src: torch.Tensor, patch_dst: torch.Tensor) -> torch.Tensor:
        """Warp a tensor from destination frame to reference given the depth in the reference frame.

        Args:
            depth_src: the depth in the reference frame. The tensor must have a shape :math:`(B, 1, H, W)`.
            patch_dst: the patch in the destination frame. The tensor must have a shape :math:`(B, C, H, W)`.

        Return:
            the warped patch from destination frame to reference.

        Shape:
            - Output: :math:`(B, C, H, W)` where C = number of channels.

        Example:
            >>> # pinholes camera models
            >>> pinhole_dst = PinholeCamera(torch.randn(1, 4, 4), torch.randn(1, 4, 4),
            ... torch.tensor([32]), torch.tensor([32]))
            >>> pinhole_src = PinholeCamera(torch.randn(1, 4, 4), torch.randn(1, 4, 4),
            ... torch.tensor([32]), torch.tensor([32]))
            >>> # create the depth warper, compute the projection matrix
            >>> warper = DepthWarper(pinhole_dst, 32, 32)
            >>> _ = warper.compute_projection_matrix(pinhole_src)
            >>> # warp the destination frame to reference by depth
            >>> depth_src = torch.ones(1, 1, 32, 32)  # Nx1xHxW
            >>> image_dst = torch.rand(1, 3, 32, 32)  # NxCxHxW
            >>> image_src = warper(depth_src, image_dst)  # NxCxHxW

        """
        return F.grid_sample(
            patch_dst,
            self.warp_grid(depth_src),
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )


def depth_warp(
    pinhole_dst: PinholeCamera,
    pinhole_src: PinholeCamera,
    depth_src: torch.Tensor,
    patch_dst: torch.Tensor,
    height: int,
    width: int,
    align_corners: bool = True,
) -> torch.Tensor:
    """Warp a tensor from destination frame to reference given the depth in the reference frame.

    Functional wrapper around :class:`DepthWarper`. It builds a warper for ``pinhole_dst``, computes its
    projection matrix w.r.t. ``pinhole_src`` and applies it once. To warp several tensors with the same pair of
    cameras, create a :class:`DepthWarper` once and reuse it instead.

    Args:
        pinhole_dst: pinhole model of the destination frame.
        pinhole_src: pinhole model of the source (reference) frame.
        depth_src: depth in the reference frame with shape :math:`(B, 1, H, W)`.
        patch_dst: tensor in the destination frame with shape :math:`(B, C, H, W)`.
        height: height of the image to warp.
        width: width of the image to warp.
        align_corners: interpolation flag forwarded to :class:`DepthWarper`.

    Return:
        the warped tensor with shape :math:`(B, C, H, W)`.

    Example:
        >>> # pinholes camera models
        >>> pinhole_dst = PinholeCamera(torch.randn(1, 4, 4), torch.randn(1, 4, 4),
        ... torch.tensor([32]), torch.tensor([32]))
        >>> pinhole_src = PinholeCamera(torch.randn(1, 4, 4), torch.randn(1, 4, 4),
        ... torch.tensor([32]), torch.tensor([32]))
        >>> # warp the destination frame to reference by depth
        >>> depth_src = torch.ones(1, 1, 32, 32)  # Nx1xHxW
        >>> image_dst = torch.rand(1, 3, 32, 32)  # NxCxHxW
        >>> image_src = depth_warp(pinhole_dst, pinhole_src, depth_src, image_dst, 32, 32)  # NxCxHxW

    """
    one_off_warper = DepthWarper(pinhole_dst, height, width, align_corners=align_corners)
    # compute_projection_matrix returns the warper itself, so the call can be chained
    configured_warper = one_off_warper.compute_projection_matrix(pinhole_src)
    return configured_warper(depth_src, patch_dst)


def depth_from_disparity(
    disparity: torch.Tensor,
    baseline: float | torch.Tensor,
    focal: float | torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Compute depth from disparity as ``baseline * focal / (disparity + eps)``.

    Args:
        disparity: Disparity tensor of shape :math:`(*, H, W)`.
        baseline: int/float/tensor of shape :math:`(1,)` containing the distance between the two lenses.
        focal: int/float/tensor of shape :math:`(1,)` containing the focal length.
        eps: value added to the disparity to avoid division by zero.

    Return:
        Depth map of the shape :math:`(*, H, W)`.

    Example:
        >>> disparity = torch.rand(4, 1, 4, 4)
        >>> baseline = torch.rand(1)
        >>> focal = torch.rand(1)
        >>> depth_from_disparity(disparity, baseline, focal).shape
        torch.Size([4, 1, 4, 4])

    """
    KORNIA_CHECK_IS_TENSOR(disparity, f"Input disparity type is not a torch.Tensor. Got {type(disparity)}.")
    KORNIA_CHECK_SHAPE(disparity, ["*", "H", "W"])
    # plain Python ints are accepted as well: the arithmetic below handles them like floats
    KORNIA_CHECK(
        isinstance(baseline, (int, float, torch.Tensor)),
        f"Input baseline should be either an int, a float or torch.Tensor. Got {type(baseline)}",
    )
    KORNIA_CHECK(
        isinstance(focal, (int, float, torch.Tensor)),
        f"Input focal should be either an int, a float or torch.Tensor. Got {type(focal)}",
    )

    if isinstance(baseline, torch.Tensor):
        KORNIA_CHECK_SHAPE(baseline, ["1"])

    if isinstance(focal, torch.Tensor):
        KORNIA_CHECK_SHAPE(focal, ["1"])

    return baseline * focal / (disparity + eps)
