/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LocalRelation, LogicalPlan, UnionLoop, UnionLoopRef}
import org.apache.spark.sql.types.BooleanType

/**
 * Tests for [[NormalizePlan]]: each test builds a baseline plan and a test plan that differ only
 * in details normalization is expected to erase (expression IDs, project/aggregate list ordering,
 * timezones inside [[InheritAnalysisRules]] parameters, common-expression IDs, random seeds,
 * recursive-CTE/loop IDs) and checks that the normalized plans compare equal.
 */
class NormalizePlanSuite extends SparkFunSuite with SQLConfHelper {

  test("Normalize Project") {
    val baselineCol1 = $"col1".int
    val testCol1 = baselineCol1.newInstance()
    val baselinePlan = LocalRelation(baselineCol1).select(baselineCol1)
    val testPlan = LocalRelation(testCol1).select(testCol1)

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize ordering in a project list of an inner Project under Project") {
    val baselinePlan =
      LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2").select($"col1")
    val testPlan =
      LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1").select($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize ordering in a project list of an inner Project under Aggregate") {
    val baselinePlan =
      LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2").groupBy($"col1")($"col1")
    val testPlan =
      LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1").groupBy($"col1")($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize ordering in an aggregate list of an inner Aggregate under Project") {
    val baselinePlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col1", $"col2")
      .select($"col1")
    val testPlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col2", $"col1")
      .select($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize ordering in an aggregate list of an inner Aggregate under Project and Filter") {
    val baselinePlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col1", $"col2")
      .where($"col1" === 1)
      .select($"col1")
    val testPlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col2", $"col1")
      .where($"col1" === 1)
      .select($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize ordering in an aggregate list of an inner Aggregate under Project and Sort") {
    val baselinePlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col1", $"col2")
      .orderBy(SortOrder($"col1", Ascending))
      .select($"col1")
    val testPlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col2", $"col1")
      .orderBy(SortOrder($"col1", Ascending))
      .select($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test(
    "Normalize ordering in an aggregate list of an inner Aggregate under Project Sort and Filter"
  ) {
    val baselinePlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col1", $"col2")
      .where($"col1" === 1)
      .orderBy(SortOrder($"col1", Ascending))
      .select($"col1")
    val testPlan = LocalRelation($"col1".int, $"col2".string)
      .groupBy($"col1", $"col2")($"col2", $"col1")
      .where($"col1" === 1)
      .orderBy(SortOrder($"col1", Ascending))
      .select($"col1")

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize InheritAnalysisRules expressions") {
    val castWithoutTimezone =
      Cast(child = Literal(1), dataType = BooleanType, ansiEnabled = conf.ansiEnabled)
    val castWithTimezone = castWithoutTimezone.withTimeZone(conf.sessionLocalTimeZone)

    val baselineExpression = AssertTrue(castWithTimezone)
    val baselinePlan = LocalRelation().select(baselineExpression)

    val testExpression = AssertTrue(castWithoutTimezone)
    val testPlan = LocalRelation().select(testExpression)

    // Before calling [[setTimezoneForAllExpression]], [[AssertTrue]] node will look like:
    //
    // AssertTrue(Cast(Literal(1)), message, If(Cast(Literal(1)), Literal(null), error))
    //
    // Calling [[setTimezoneForAllExpression]] will only apply timezone to the second Cast node
    // because [[InheritAnalysisRules]] only sees replacement expression as its child. This will
    // cause the difference when comparing [[resolvedBaselinePlan]] and [[resolvedTestPlan]],
    // therefore we need normalization.

    // Before applying timezone, no timezone is set. The AssertTrue must actually be found,
    // otherwise the checks below would pass vacuously.
    val castsBeforeTimezone = collectAssertTrueCasts(testPlan)
    assert(castsBeforeTimezone.nonEmpty)
    castsBeforeTimezone.foreach { case (firstCast, secondCast) =>
      assert(firstCast.timeZoneId.isEmpty)
      assert(secondCast.timeZoneId.isEmpty)
    }

    val resolvedBaselinePlan = setTimezoneForAllExpression(baselinePlan)
    val resolvedTestPlan = setTimezoneForAllExpression(testPlan)

    // After applying timezone, only the second cast gets timezone.
    val castsAfterTimezone = collectAssertTrueCasts(resolvedTestPlan)
    assert(castsAfterTimezone.nonEmpty)
    castsAfterTimezone.foreach { case (firstCast, secondCast) =>
      assert(firstCast.timeZoneId.isEmpty)
      assert(secondCast.timeZoneId.isDefined)
    }

    // However, plans are still different.
    assert(resolvedBaselinePlan != resolvedTestPlan)
    assert(NormalizePlan(resolvedBaselinePlan) == NormalizePlan(resolvedTestPlan))
  }

  test("Normalize CommonExpressionId") {
    val baselineCommonExpressionRef =
      CommonExpressionRef(id = new CommonExpressionId, dataType = BooleanType, nullable = false)
    val baselineCommonExpressionDef = CommonExpressionDef(child = Literal(0))
    val testCommonExpressionRef =
      CommonExpressionRef(id = new CommonExpressionId, dataType = BooleanType, nullable = false)
    val testCommonExpressionDef = CommonExpressionDef(child = Literal(0))

    val baselinePlanRef = LocalRelation().select(baselineCommonExpressionRef)
    val testPlanRef = LocalRelation().select(testCommonExpressionRef)

    assert(baselinePlanRef != testPlanRef)
    assert(NormalizePlan(baselinePlanRef) == NormalizePlan(testPlanRef))

    val baselinePlanDef = LocalRelation().select(baselineCommonExpressionDef)
    val testPlanDef = LocalRelation().select(testCommonExpressionDef)

    assert(baselinePlanDef != testPlanDef)
    assert(NormalizePlan(baselinePlanDef) == NormalizePlan(testPlanDef))
  }

  test("Normalize non-deterministic expressions") {
    // A fixed seed keeps the test reproducible across runs; the two seeds drawn from it must
    // still differ so the un-normalized plans are distinguishable.
    val random = new Random(42)
    val baselineSeed = random.nextLong()
    val testSeed = random.nextLong()
    assert(baselineSeed != testSeed)

    val baselineExpression = rand(baselineSeed)
    val testExpression = rand(testSeed)

    val baselinePlan = LocalRelation().select(baselineExpression)
    val testPlan = LocalRelation().select(testExpression)

    assert(baselinePlan != testPlan)
    assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
  }

  test("Normalize UnionLoopRef IDs") {
    val col1 = $"col1".int
    val col2 = col1.newInstance()

    // Create two UnionLoopRefs with different loopIds
    val baselineLoopRef = UnionLoopRef(
      loopId = 100L,
      output = Seq(col2),
      accumulated = false
    )

    val testLoopRef = UnionLoopRef(
      loopId = 200L,
      output = Seq(col2),
      accumulated = false
    )

    // Before normalization, plans are different
    assert(baselineLoopRef != testLoopRef)

    // After normalization, they should be equal (loopIds normalized)
    assert(NormalizePlan(baselineLoopRef) == NormalizePlan(testLoopRef))
  }

  test("Normalize UnionLoop IDs and outputAttrIds and UnionLoopRefIds") {
    val col1 = $"col1".int
    val col2 = col1.newInstance()
    val anchor = LocalRelation(col1)

    // Create two UnionLoops with different IDs and different outputAttrIds
    val unionLoop1 = UnionLoop(
      id = 100L,
      anchor = anchor,
      recursion = UnionLoopRef(loopId = 100L, output = Seq(col2), accumulated = false),
      outputAttrIds = Seq(ExprId(1), ExprId(2))
    )

    val unionLoop2 = UnionLoop(
      id = 200L,
      anchor = anchor,
      recursion = UnionLoopRef(loopId = 200L, output = Seq(col2), accumulated = false),
      outputAttrIds = Seq(ExprId(1), ExprId(2))
    )

    // Before normalization, plans are different
    assert(unionLoop1 != unionLoop2)

    // After normalization, they should be equal (IDs normalized, outputAttrIds zeroed)
    assert(NormalizePlan(unionLoop1) == NormalizePlan(unionLoop2))
  }

  test("Normalize rCTEs") {
    val col1 = $"col1".int
    val col2 = $"col2".int
    val anchor = LocalRelation(col1)

    // Create two full recursive CTEs - CTERelationDef with a UnionLoop and UnionLoopRef with the
    // same id
    val recursiveCTE1 = CTERelationDef(
      child = UnionLoop(
        id = 100L,
        anchor = anchor,
        recursion = UnionLoopRef(loopId = 100L, output = Seq(col2), accumulated = false),
        outputAttrIds = Seq(ExprId(1), ExprId(2))
      ),
      id = 100L
    )

    val recursiveCTE2 = CTERelationDef(
      child = UnionLoop(
        id = 200L,
        anchor = anchor,
        recursion = UnionLoopRef(loopId = 200L, output = Seq(col2), accumulated = false),
        outputAttrIds = Seq(ExprId(1), ExprId(2))
      ),
      id = 200L
    )

    val normalizedRecursiveCTE = CTERelationDef(
      child = UnionLoop(
        id = 0L,
        anchor = LocalRelation(col1.withExprId(ExprId(0))),
        recursion = UnionLoopRef(
          loopId = 0L,
          output = Seq(col2.withExprId(ExprId(0))),
          accumulated = false
        ),
        outputAttrIds = Seq(ExprId(0), ExprId(0))
      ),
      id = 0L
    )

    // Before normalization, plans are different
    assert(recursiveCTE1 != recursiveCTE2)

    // After normalization, they should be equal to the normalized plan
    assert(NormalizePlan(recursiveCTE1) == normalizedRecursiveCTE)
    assert(NormalizePlan(recursiveCTE2) == normalizedRecursiveCTE)
  }

  /**
   * Collects, for every [[AssertTrue]] whose first argument is a [[Cast]] and whose replacement is
   * an [[If]] conditioned on a [[Cast]], the pair (first-argument cast, replacement-condition
   * cast). The search covers each expression tree of `plan` recursively, so a wrapper around the
   * AssertTrue (e.g. an alias that `select` may add for a non-named expression) does not hide it.
   */
  private def collectAssertTrueCasts(plan: LogicalPlan): Seq[(Cast, Cast)] = {
    plan.expressions.flatMap(_.collect {
      case AssertTrue(firstCast: Cast, _, If(secondCast: Cast, _, _)) => (firstCast, secondCast)
    })
  }

  /**
   * Sets the session-local timezone on every [[TimeZoneAwareExpression]] reachable through
   * `transformAllExpressions` that does not already have one.
   */
  private def setTimezoneForAllExpression(plan: LogicalPlan): LogicalPlan = {
    plan.transformAllExpressions {
      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
        e.withTimeZone(conf.sessionLocalTimeZone)
    }
  }
}
