from sympy.assumptions.ask import ask, Q
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.sympify import _sympify
from sympy.functions.special.tensor_functions import KroneckerDelta
from sympy.matrices.exceptions import NonInvertibleMatrixError
from .matexpr import MatrixExpr


class ZeroMatrix(MatrixExpr):
    """Symbolic ``m x n`` matrix of zeros, the additive identity.

    Examples
    ========

    >>> from sympy import MatrixSymbol, ZeroMatrix
    >>> A = MatrixSymbol('A', 3, 5)
    >>> Z = ZeroMatrix(3, 5)
    >>> A + Z
    A
    >>> Z*A.T
    0
    >>> Z.as_explicit()
    Matrix([
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0]])
    """
    is_ZeroMatrix = True

    def __new__(cls, m, n):
        # Sympify both dimensions first, then validate them in order.
        n_rows = _sympify(m)
        n_cols = _sympify(n)
        for dim in (n_rows, n_cols):
            cls._check_dim(dim)

        return super().__new__(cls, n_rows, n_cols)

    @property
    def shape(self):
        """The ``(rows, cols)`` pair this matrix was constructed with."""
        n_rows = self.args[0]
        n_cols = self.args[1]
        return (n_rows, n_cols)

    def _eval_power(self, exp):
        # The exponents -1, 0 and 1 are already handled before this hook
        # is reached.
        is_negative = exp < 0
        if is_negative == True:
            raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
        # Any other power of the zero matrix is returned as the zero matrix.
        return self

    def _eval_transpose(self):
        # Transposing swaps the two dimensions.
        n_cols, n_rows = self.cols, self.rows
        return ZeroMatrix(n_cols, n_rows)

    def _eval_adjoint(self):
        # Same shape swap as the transpose; every entry is zero.
        n_cols, n_rows = self.cols, self.rows
        return ZeroMatrix(n_cols, n_rows)

    def _eval_trace(self):
        return S.Zero

    def _eval_determinant(self):
        return S.Zero

    def _eval_inverse(self):
        raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")

    def _eval_as_real_imag(self):
        # Real and imaginary parts are both this same zero matrix.
        real_part = imag_part = self
        return (real_part, imag_part)

    def _eval_conjugate(self):
        return self

    def _entry(self, i, j, **kwargs):
        # Every entry is zero regardless of the indices.
        return S.Zero


class GenericZeroMatrix(ZeroMatrix):
    """
    A zero matrix whose shape is left unspecified.

    Its main purpose is to give ``MatAdd()`` called with no arguments
    something meaningful to return.
    """
    def __new__(cls):
        # Start the lookup *after* ZeroMatrix (rather than after
        # GenericZeroMatrix): ZeroMatrix.__new__ has a different signature
        # and cannot be called without dimensions.
        return super(ZeroMatrix, cls).__new__(cls)

    @property
    def rows(self):
        """Always raises ``TypeError``: there is no row count."""
        message = "GenericZeroMatrix does not have a specified shape"
        raise TypeError(message)

    @property
    def cols(self):
        """Always raises ``TypeError``: there is no column count."""
        message = "GenericZeroMatrix does not have a specified shape"
        raise TypeError(message)

    @property
    def shape(self):
        """Always raises ``TypeError``: there is no shape."""
        message = "GenericZeroMatrix does not have a specified shape"
        raise TypeError(message)

    # Equality is defined by type alone so that Matrix.__eq__, which might
    # call .shape, is avoided.
    def __eq__(self, other):
        return isinstance(other, GenericZeroMatrix)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ would otherwise disable hashing; reuse the
        # inherited implementation.
        return super(GenericZeroMatrix, self).__hash__()



class Identity(MatrixExpr):
    """Symbolic ``n x n`` identity matrix, the multiplicative identity.

    Examples
    ========

    >>> from sympy import Identity, MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 5)
    >>> I = Identity(3)
    >>> I*A
    A
    >>> I.as_explicit()
    Matrix([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]])
    """

    is_Identity = True

    def __new__(cls, n):
        size = _sympify(n)
        cls._check_dim(size)

        return super().__new__(cls, size)

    @property
    def rows(self):
        """Number of rows (the single size argument)."""
        return self.args[0]

    @property
    def cols(self):
        """Number of columns (the single size argument)."""
        return self.args[0]

    @property
    def shape(self):
        """The ``(n, n)`` shape built from the single size argument."""
        n_rows = self.args[0]
        n_cols = self.args[0]
        return (n_rows, n_cols)

    @property
    def is_square(self):
        return True

    def _eval_transpose(self):
        return self

    def _eval_trace(self):
        # n ones on the diagonal sum to n.
        return self.rows

    def _eval_inverse(self):
        return self

    def _eval_as_real_imag(self):
        # The real part is the identity itself; the imaginary part is a
        # zero matrix of the same shape.
        imag_part = ZeroMatrix(*self.shape)
        return (self, imag_part)

    def _eval_conjugate(self):
        return self

    def _eval_adjoint(self):
        return self

    def _entry(self, i, j, **kwargs):
        # Resolve to a concrete 0 or 1 when equality of the indices can be
        # decided; otherwise fall back to a symbolic Kronecker delta.
        same_index = Eq(i, j)
        if same_index is S.true:
            return S.One
        if same_index is S.false:
            return S.Zero
        last_index = self.cols - 1
        return KroneckerDelta(i, j, (0, last_index))

    def _eval_determinant(self):
        return S.One

    def _eval_power(self, exp):
        # Every power of the identity is returned as the identity.
        return self


class GenericIdentity(Identity):
    """
    An identity matrix whose shape is left unspecified.

    Its main purpose is to give ``MatMul()`` called with no arguments
    something meaningful to return.
    """
    def __new__(cls):
        # Start the lookup *after* Identity (rather than after
        # GenericIdentity): Identity.__new__ has a different signature and
        # cannot be called without a size.
        return super(Identity, cls).__new__(cls)

    @property
    def rows(self):
        """Always raises ``TypeError``: there is no row count."""
        message = "GenericIdentity does not have a specified shape"
        raise TypeError(message)

    @property
    def cols(self):
        """Always raises ``TypeError``: there is no column count."""
        message = "GenericIdentity does not have a specified shape"
        raise TypeError(message)

    @property
    def shape(self):
        """Always raises ``TypeError``: there is no shape."""
        message = "GenericIdentity does not have a specified shape"
        raise TypeError(message)

    @property
    def is_square(self):
        return True

    # Equality is defined by type alone so that Matrix.__eq__, which might
    # call .shape, is avoided.
    def __eq__(self, other):
        return isinstance(other, GenericIdentity)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ would otherwise disable hashing; reuse the
        # inherited implementation.
        return super(GenericIdentity, self).__hash__()


class OneMatrix(MatrixExpr):
    """
    Matrix in which every entry equals one.

    Also known as the "matrix of ones" or "all-ones matrix".

    https://en.wikipedia.org/wiki/Matrix_of_ones

    Examples
    ========

    >>> from sympy.matrices.expressions import OneMatrix
    >>> O = OneMatrix(3, 4)
    >>> O.shape
    (3, 4)
    >>> O.as_explicit()
    Matrix([
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1]])
    """
    def __new__(cls, m, n, evaluate=False):
        m, n = _sympify(m), _sympify(n)
        for dim in (m, n):
            cls._check_dim(dim)

        # On request, collapse a matrix known to be 1x1 into Identity(1).
        if evaluate and (Eq(m, 1) & Eq(n, 1)) == True:
            return Identity(1)

        return super().__new__(cls, m, n)

    @property
    def shape(self):
        """The ``(rows, cols)`` pair, which is exactly the argument tuple."""
        return self._args

    @property
    def is_Identity(self):
        """True only when the matrix is known to be 1x1."""
        known_1x1 = self._is_1x1()
        return known_1x1 == True

    def as_explicit(self):
        """Return an explicit immutable dense matrix filled with ones."""
        from sympy.matrices.immutable import ImmutableDenseMatrix
        return ImmutableDenseMatrix.ones(*self.shape)

    def doit(self, **hints):
        """Rebuild with ``evaluate=True``, evaluating the arguments first
        unless ``deep=False`` is passed."""
        if hints.get('deep', True):
            new_args = [arg.doit(**hints) for arg in self.args]
        else:
            new_args = self.args
        return self.func(*new_args, evaluate=True)

    def _eval_power(self, exp):
        # The exponents -1, 0 and 1 are already handled before this hook
        # is reached.
        if self._is_1x1() == True:
            return Identity(1)
        is_negative = exp < 0
        if is_negative == True:
            raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
        if not ask(Q.integer(exp)):
            # Not known to be an integer: defer to the generic handling.
            return super()._eval_power(exp)
        # Integer power: the row count raised to (exp - 1), times the
        # all-ones matrix of the same shape.
        n_rows, n_cols = self.shape
        scale = n_rows ** (exp - 1)
        return scale * OneMatrix(n_rows, n_cols)

    def _eval_transpose(self):
        n_cols, n_rows = self.cols, self.rows
        return OneMatrix(n_cols, n_rows)

    def _eval_adjoint(self):
        # All entries are the real number one, so only the shape swaps.
        n_cols, n_rows = self.cols, self.rows
        return OneMatrix(n_cols, n_rows)

    def _eval_trace(self):
        # One per diagonal entry.
        n_rows = self.rows
        return S.One*n_rows

    def _is_1x1(self):
        """Returns true if the matrix is known to be 1x1"""
        n_rows, n_cols = self.shape
        return Eq(n_rows, 1) & Eq(n_cols, 1)

    def _eval_determinant(self):
        known_1x1 = self._is_1x1()
        if known_1x1 == True:
            return S.One
        if known_1x1 == False:
            return S.Zero
        # Undecided shape: leave the determinant unevaluated.
        from sympy.matrices.expressions.determinant import Determinant
        return Determinant(self)

    def _eval_inverse(self):
        known_1x1 = self._is_1x1()
        if known_1x1 == True:
            return Identity(1)
        if known_1x1 == False:
            raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
        # Undecided shape: leave the inverse unevaluated.
        from .inverse import Inverse
        return Inverse(self)

    def _eval_as_real_imag(self):
        # The real part is the matrix itself; the imaginary part is a zero
        # matrix of the same shape.
        imag_part = ZeroMatrix(*self.shape)
        return self, imag_part

    def _eval_conjugate(self):
        return self

    def _entry(self, i, j, **kwargs):
        # Every entry is one regardless of the indices.
        return S.One


class MatrixUnit(MatrixExpr):
    """
    Matrix with only one nonzero entry with value 1. Also called single-entry matrix.

    https://en.wikipedia.org/wiki/Matrix_unit

    The arguments are ``(rows, cols, i, j)``: the shape of the matrix followed
    by the zero-based row and column index of the unit entry.

    Examples
    ========

    Create a matrix unit of shape `(3, 4)` with unit entry at the second row
    and third column, i.e. at `(1, 2)`

    >>> from sympy.matrices.expressions.special import MatrixUnit
    >>> E = MatrixUnit(3, 4, 1, 2)
    >>> E.shape
    (3, 4)
    >>> E.as_explicit()
    Matrix([
    [0, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 0]])
    >>> E[1, 2]
    1
    >>> E[1, 1]
    0

    The transposition of a matrix unit is a different matrix unit:

    >>> E.T
    MatrixUnit(4, 3, 2, 1)
    >>> E.T.as_explicit()
    Matrix([
    [0, 0, 0],
    [0, 0, 0],
    [0, 1, 0],
    [0, 0, 0]])

    Both shape and position of unit entry may be symbolic

    >>> from sympy import symbols
    >>> a, b, i, j, x, y = symbols("a b i j x y")
    >>> M = MatrixUnit(a, b, i, j)
    >>> M.shape
    (a, b)
    >>> M[x, y]
    KroneckerDelta(i, x, (0, a - 1))*KroneckerDelta(j, y, (0, b - 1))

    """
    def __new__(cls, rows, cols, i, j):
        rows, cols = _sympify(rows), _sympify(cols)
        i, j = _sympify(i), _sympify(j)
        # Validate the dimensions the same way the other special matrices
        # in this module do.
        cls._check_dim(rows)
        cls._check_dim(cols)

        obj = MatrixExpr.__new__(cls, rows, cols, i, j)
        # Kept for backward compatibility; the methods below read the
        # position from ``self.args`` instead.
        obj._i = i
        obj._j = j
        return obj

    @property
    def shape(self):
        """The ``(rows, cols)`` pair this matrix was constructed with."""
        return self._args[0], self._args[1]

    @property
    def rows(self):
        return self._args[0]

    @property
    def cols(self):
        return self._args[1]

    @property
    def is_square(self):
        """``True``/``False`` when squareness can be decided, else ``None``.

        A plain bool (or ``None``) is returned instead of a relational so
        that ``if M.is_square`` does not raise for symbolic shapes.
        """
        same_dims = Eq(self.rows, self.cols)
        if same_dims is S.true:
            return True
        if same_dims is S.false:
            return False
        return None

    def _eval_transpose(self):
        # Swap the shape and the position of the unit entry.
        row_pos, col_pos = self._args[2], self._args[3]
        return MatrixUnit(self.cols, self.rows, col_pos, row_pos)

    def _eval_trace(self):
        """Trace of a square matrix unit: 1 if the unit entry lies on the
        diagonal, 0 otherwise.

        Returns ``None`` when the matrix is not known to be square, leaving
        the decision to the caller.
        """
        if self.is_square is not True:
            return None
        row_pos, col_pos = self._args[2], self._args[3]
        # Evaluates to 0 or 1 for decidable positions and stays symbolic
        # otherwise; the range matches the one used by ``_entry``.
        return KroneckerDelta(row_pos, col_pos, (0, self.rows - 1))

    def _eval_inverse(self):
        # Only the 1x1 case is treated as invertible.
        if self.shape == (1, 1):
            return self
        raise NonInvertibleMatrixError("Cannot invert this MatrixUnit")

    def _eval_as_real_imag(self):
        # The real part is the matrix itself; the imaginary part is a zero
        # matrix of the same shape.
        return self, ZeroMatrix(*self.shape)

    def _eval_conjugate(self):
        return self

    def _eval_adjoint(self):
        return self._eval_conjugate()._eval_transpose()

    def _entry(self, i, j, **kwargs):
        # Product of two deltas: nonzero only when both the row and the
        # column index match the position of the unit entry.
        row_pos, col_pos = self._args[2], self._args[3]
        row_delta = KroneckerDelta(i, row_pos, (0, self.rows - 1))
        col_delta = KroneckerDelta(j, col_pos, (0, self.cols - 1))
        return row_delta * col_delta
