/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import java.util.UUID

import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkSQLException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession

/**
 * Tests for the Spark Connect session manager.
 *
 * Covers creation and lookup of isolated sessions, session-id validation, inactivity-based
 * cleanup, bookkeeping of closed sessions, pipeline-execution cache cleanup on close, and
 * base session initialization.
 */
class SparkConnectSessionManagerSuite extends SharedSparkSession {

  // Start every test from a clean state: invalidate all sessions held by the global session
  // manager and (re)initialize its base session from the shared SparkContext.
  override def beforeEach(): Unit = {
    super.beforeEach()
    SparkConnectService.sessionManager.invalidateAllSessions()
    SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
  }

  test("sessionId needs to be an UUID") {
    val key = SessionKey("user", "not an uuid")
    val exGetOrCreate = intercept[SparkSQLException] {
      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    }
    assert(exGetOrCreate.getCondition == "INVALID_HANDLE.FORMAT")
  }

  test(
    "getOrCreateIsolatedSession/getIsolatedSession/getIsolatedSessionIfPresent " +
      "gets the existing session") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)

    val sessionGetOrCreate =
      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    assert(sessionGetOrCreate === sessionHolder)

    val sessionGet = SparkConnectService.sessionManager.getIsolatedSession(key, None)
    assert(sessionGet === sessionHolder)

    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
    assert(sessionGetIfPresent.get === sessionHolder)
  }

  test("client-observed session id validation works") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    // Works if the client doesn't set the observed session id.
    SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    // Works with the correct existing session id.
    SparkConnectService.sessionManager.getOrCreateIsolatedSession(
      key,
      Some(sessionHolder.session.sessionUUID))
    // Fails with the different session id.
    val exGet = intercept[SparkSQLException] {
      SparkConnectService.sessionManager.getOrCreateIsolatedSession(
        key,
        Some(sessionHolder.session.sessionUUID + "invalid"))
    }
    assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CHANGED")
  }

  test(
    "getOrCreateIsolatedSession/getIsolatedSession/getIsolatedSessionIfPresent " +
      "doesn't recreate closed session") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    // The session is created only so that it can be closed; its holder is not needed.
    SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    SparkConnectService.sessionManager.closeSession(key)

    val exGetOrCreate = intercept[SparkSQLException] {
      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    }
    assert(exGetOrCreate.getCondition == "INVALID_HANDLE.SESSION_CLOSED")

    val exGet = intercept[SparkSQLException] {
      SparkConnectService.sessionManager.getIsolatedSession(key, None)
    }
    assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CLOSED")

    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
    assert(sessionGetIfPresent.isEmpty)
  }

  test("getIsolatedSession/getIsolatedSessionIfPresent when session doesn't exist") {
    val key = SessionKey("user", UUID.randomUUID().toString)

    val exGet = intercept[SparkSQLException] {
      SparkConnectService.sessionManager.getIsolatedSession(key, None)
    }
    assert(exGet.getCondition == "INVALID_HANDLE.SESSION_NOT_FOUND")

    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
    assert(sessionGetIfPresent.isEmpty)
  }

  test("SessionHolder with custom expiration time is not cleaned up due to inactivity") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)

    assert(
      SparkConnectService.sessionManager.listActiveSessions.exists(
        _.sessionId == sessionHolder.sessionId))
    // A per-session timeout of 5 days should take precedence over the default timeout below.
    sessionHolder.setCustomInactiveTimeoutMs(Some(5.days.toMillis))

    // clean up with inactivity timeout of 0.
    SparkConnectService.sessionManager.periodicMaintenance(defaultInactiveTimeoutMs = 0L)
    // session should still be there.
    assert(
      SparkConnectService.sessionManager.listActiveSessions.exists(
        _.sessionId == sessionHolder.sessionId))

    sessionHolder.setCustomInactiveTimeoutMs(None)
    // it will be cleaned up now.
    SparkConnectService.sessionManager.periodicMaintenance(defaultInactiveTimeoutMs = 0L)
    assert(SparkConnectService.sessionManager.listActiveSessions.isEmpty)
    assert(
      SparkConnectService.sessionManager.listClosedSessions.exists(
        _.sessionId == sessionHolder.sessionId))
  }

  test("SessionHolder is recorded with status closed after close") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    val activeSessionInfo = SparkConnectService.sessionManager.listActiveSessions.find(
      _.sessionId == sessionHolder.sessionId)
    assert(activeSessionInfo.isDefined)
    assert(activeSessionInfo.get.status == SessionStatus.Started)
    assert(activeSessionInfo.get.closedTimeMs.isEmpty)

    SparkConnectService.sessionManager.closeSession(sessionHolder.key)

    assert(SparkConnectService.sessionManager.listActiveSessions.isEmpty)
    val closedSessionInfo = SparkConnectService.sessionManager.listClosedSessions.find(
      _.sessionId == sessionHolder.sessionId)
    assert(closedSessionInfo.isDefined)
    assert(closedSessionInfo.get.status == SessionStatus.Closed)
    assert(closedSessionInfo.get.closedTimeMs.isDefined)
  }

  test("Pipeline execution cache is cleared when the session holder is closed") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    val graphId = "test_graph"
    // An empty graph with a no-op event callback is enough to exercise the cache.
    val pipelineUpdateContext = new PipelineUpdateContextImpl(
      new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
      (_: PipelineEvent) => None,
      storageRoot = "file:///test_storage_root")
    sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
    assert(
      sessionHolder.getPipelineExecution(graphId).nonEmpty,
      "pipeline execution was not cached")
    SparkConnectService.sessionManager.closeSession(sessionHolder.key)
    assert(
      sessionHolder.getPipelineExecution(graphId).isEmpty,
      "pipeline execution was not removed")
  }

  test("baseSession allows creating sessions after default session is cleared") {
    // Use a standalone session manager so this test controls base session initialization.
    val sessionManager = new SparkConnectSessionManager()

    // Initialize the base session with the test SparkContext
    sessionManager.initializeBaseSession(spark.sparkContext)

    // The default and active sessions are JVM-global. Remember them so they can be restored;
    // leaving them cleared would leak into the other tests that share this JVM.
    val previousDefaultSession = SparkSession.getDefaultSession
    val previousActiveSession = SparkSession.getActiveSession
    try {
      // Clear the default and active sessions to simulate the scenario where
      // SparkSession.active or SparkSession.getDefaultSession would fail
      SparkSession.clearDefaultSession()
      SparkSession.clearActiveSession()

      // Create an isolated session - this should still work because we have baseSession
      val key = SessionKey("user", UUID.randomUUID().toString)
      val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
      try {
        // Verify the session was created successfully
        assert(sessionHolder != null)
        assert(sessionHolder.session != null)
      } finally {
        // Close the session even if an assertion above failed.
        sessionManager.closeSession(key)
      }
    } finally {
      // Put the global default/active sessions back exactly as they were before the test.
      previousDefaultSession match {
        case Some(session) => SparkSession.setDefaultSession(session)
        case None => SparkSession.clearDefaultSession()
      }
      previousActiveSession match {
        case Some(session) => SparkSession.setActiveSession(session)
        case None => SparkSession.clearActiveSession()
      }
    }
  }

  test("initializeBaseSession is idempotent") {
    // Create a new session manager to test initialization
    val sessionManager = new SparkConnectSessionManager()

    // First initialization, then create a session from it.
    sessionManager.initializeBaseSession(spark.sparkContext)
    val key1 = SessionKey("user1", UUID.randomUUID().toString)
    val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)

    // Initialize again - should not change the base session
    sessionManager.initializeBaseSession(spark.sparkContext)
    val key2 = SessionKey("user2", UUID.randomUUID().toString)
    val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)

    // Both sessions should be isolated from each other
    assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
    // ...while still running on the same SparkContext before and after re-initialization.
    assert(sessionHolder1.session.sparkContext eq sessionHolder2.session.sparkContext)

    // Clean up
    sessionManager.closeSession(key1)
    sessionManager.closeSession(key2)
  }
}
