//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2026 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package geo

import (
	"context"
	"fmt"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw"
	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
	"github.com/weaviate/weaviate/entities/cyclemanager"
	"github.com/weaviate/weaviate/entities/filters"
	"github.com/weaviate/weaviate/entities/models"
	hnswent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
	"github.com/weaviate/weaviate/usecases/memwatch"
)

// DefaultHNSWEF is the ef value handed to the underlying vector index for
// range searches whenever Config.HNSWEF is not set to a positive number.
const DefaultHNSWEF = 800

// Index wraps another index to provide geo searches. This allows us to reuse
// the hnsw vector index, without making geo searches dependent on
// hnsw-specific features.
//
// In the future we could use this level of abstraction to provide a better
// suited geo-index if we deem it necessary.
type Index struct {
	// config is the configuration the index was created with; it is consulted
	// at query time for the ef search parameter. vectorIndex is the wrapped
	// index all operations are delegated to; it is reset to nil once Drop
	// has succeeded.
	config      Config
	vectorIndex vectorIndex
}

// vectorIndex represents the underlying vector index, typically hnsw. It
// lists only the operations the geo Index delegates to the wrapped index.
type vectorIndex interface {
	// Add is called by Index.Add with the vector derived from the
	// geo coordinates of the object with the given id.
	Add(ctx context.Context, id uint64, vector []float32) error
	// KnnSearchByVectorMaxDist is called by Index.WithinRange with the range's
	// distance as dist, the configured ef and a nil allow list.
	KnnSearchByVectorMaxDist(ctx context.Context, query []float32, dist float32, ef int,
		allowList helpers.AllowList) ([]uint64, error)
	// Delete is called by Index.Delete with a single id.
	Delete(id ...uint64) error
	// Dump is not called from the wrapper code in this file.
	Dump(...string)
	// Drop is called by Index.Drop, which forwards ctx and keepFiles unchanged.
	Drop(ctx context.Context, keepFiles bool) error
	// PostStartup is called by Index.PostStartup with the same ctx.
	PostStartup(ctx context.Context)
}

// Config is passed to the GeoIndex when it's created.
type Config struct {
	// ID and RootPath are forwarded to both the hnsw index and its commit
	// logger. The VectorForID method of CoordinatesForID is used as the hnsw
	// index's vector-for-id thunk. When DisablePersistence is true a noop
	// commit logger is used instead of a real one. Logger is handed to the
	// commit logger.
	ID                 string
	CoordinatesForID   CoordinatesForID
	DisablePersistence bool
	RootPath           string
	Logger             logrus.FieldLogger

	// HNSWEF is the ef value used for range searches. Values that are not
	// positive fall back to DefaultHNSWEF.
	HNSWEF int

	// The Snapshot* settings are forwarded to the hnsw index
	// (SnapshotDisabled, SnapshotOnStartup) and to the commit logger options
	// (SnapshotDisabled, SnapshotCreateInterval and the two MinDelta
	// settings). AllocChecker is forwarded to the hnsw index.
	SnapshotDisabled                         bool
	SnapshotOnStartup                        bool
	SnapshotCreateInterval                   time.Duration
	SnapshotMinDeltaCommitlogsNumer          int
	SnapshotMinDeltaCommitlogsSizePercentage int
	AllocChecker                             memwatch.AllocChecker
}

// hnswEF returns the ef value to use for searches: the configured HNSWEF when
// it is positive, DefaultHNSWEF in every other case.
func (c Config) hnswEF() int {
	if c.HNSWEF <= 0 {
		return DefaultHNSWEF
	}
	return c.HNSWEF
}

// NewIndex creates a geo Index on top of a freshly constructed hnsw index.
//
// The hnsw index is set up with the geo distance provider, fixed values for
// MaxConnections and EFConstruction, and the default cleanup interval. Its
// commit logger is derived from config and commitLogMaintenanceCallbacks;
// tombstoneCleanupCallbacks is passed on to the hnsw constructor. If the hnsw
// index cannot be created, the error is returned wrapped with
// "underlying hnsw index".
func NewIndex(config Config,
	commitLogMaintenanceCallbacks, tombstoneCleanupCallbacks cyclemanager.CycleCallbackGroup,
) (*Index, error) {
	hnswConfig := hnsw.Config{
		VectorForIDThunk:      config.CoordinatesForID.VectorForID,
		ID:                    config.ID,
		RootPath:              config.RootPath,
		MakeCommitLoggerThunk: makeCommitLoggerFromConfig(config, commitLogMaintenanceCallbacks),
		DistanceProvider:      distancer.NewGeoProvider(),
		DisableSnapshots:      config.SnapshotDisabled,
		SnapshotOnStartup:     config.SnapshotOnStartup,
		AllocChecker:          config.AllocChecker,
	}
	userConfig := hnswent.UserConfig{
		MaxConnections:         64,
		EFConstruction:         128,
		CleanupIntervalSeconds: hnswent.DefaultCleanupIntervalSeconds,
	}

	underlying, err := hnsw.New(hnswConfig, userConfig, tombstoneCleanupCallbacks, nil)
	if err != nil {
		return nil, errors.Wrap(err, "underlying hnsw index")
	}

	return &Index{config: config, vectorIndex: underlying}, nil
}

// Drop drops the underlying vector index, forwarding ctx and keepFiles
// unchanged. On success the reference to the underlying index is cleared; on
// failure the error is returned as-is and the reference is kept.
//
// Calling Drop again after a successful Drop is a no-op and returns nil.
func (i *Index) Drop(ctx context.Context, keepFiles bool) error {
	// A previous successful Drop has already cleared the reference. Invoking a
	// method on the nil interface value would panic, so there is nothing left
	// to do.
	if i.vectorIndex == nil {
		return nil
	}

	if err := i.vectorIndex.Drop(ctx, keepFiles); err != nil {
		return err
	}

	i.vectorIndex = nil
	return nil
}

// PostStartup forwards the post-startup hook, together with ctx, to the
// underlying vector index.
func (i *Index) PostStartup(ctx context.Context) {
	underlying := i.vectorIndex
	underlying.PostStartup(ctx)
}

// makeCommitLoggerFromConfig returns the commit logger factory for the hnsw
// index. With persistence disabled it is the noop factory; otherwise it is a
// factory that builds a commit logger from the config's root path, id, logger
// and snapshot settings, using the given maintenance callbacks.
func makeCommitLoggerFromConfig(config Config, maintenanceCallbacks cyclemanager.CycleCallbackGroup,
) hnsw.MakeCommitLogger {
	if config.DisablePersistence {
		return hnsw.MakeNoopCommitLogger
	}

	return func() (hnsw.CommitLogger, error) {
		return hnsw.NewCommitLogger(config.RootPath, config.ID, config.Logger, maintenanceCallbacks,
			hnsw.WithSnapshotDisabled(config.SnapshotDisabled),
			hnsw.WithSnapshotCreateInterval(config.SnapshotCreateInterval),
			hnsw.WithSnapshotMinDeltaCommitlogsNumer(config.SnapshotMinDeltaCommitlogsNumer),
			hnsw.WithSnapshotMinDeltaCommitlogsSizePercentage(config.SnapshotMinDeltaCommitlogsSizePercentage),
		)
	}
}

// Add extends the index with the specified GeoCoordinates. It is thread-safe
// and can be called concurrently.
//
// The coordinates are converted to a vector and stored under id in the
// underlying vector index. An error prefixed with "invalid arguments" is
// returned when coordinates is nil or cannot be converted; errors from the
// underlying index are returned unchanged.
func (i *Index) Add(ctx context.Context, id uint64, coordinates *models.GeoCoordinates) error {
	// Mirror the guard in WithinRange: reject a nil pointer up front instead
	// of handing it to the conversion helper.
	if coordinates == nil {
		return errors.New("invalid arguments: GeoCoordinates must be set")
	}

	v, err := geoCoordiantesToVector(coordinates)
	if err != nil {
		return errors.Wrap(err, "invalid arguments")
	}

	return i.vectorIndex.Add(ctx, id, v)
}

// WithinRange searches the index by the specified range. It is thread-safe
// and can be called concurrently.
//
// The range's coordinates are converted to a query vector, and the underlying
// vector index is searched with the range's distance as the maximum distance,
// the configured ef and no allow list. An "invalid arguments" error is
// returned when the coordinates are missing or cannot be converted.
func (i *Index) WithinRange(ctx context.Context,
	geoRange filters.GeoRange,
) ([]uint64, error) {
	coords := geoRange.GeoCoordinates
	if coords == nil {
		return nil, fmt.Errorf("invalid arguments: GeoCoordinates in range must be set")
	}

	queryVector, err := geoCoordiantesToVector(coords)
	if err != nil {
		return nil, errors.Wrap(err, "invalid arguments")
	}

	maxDist, ef := geoRange.Distance, i.config.hnswEF()
	return i.vectorIndex.KnnSearchByVectorMaxDist(ctx, queryVector, maxDist, ef, nil)
}

// Delete removes the entry with the given id by delegating to the underlying
// vector index and returns whatever error that index reports.
func (i *Index) Delete(id uint64) error {
	underlying := i.vectorIndex
	return underlying.Delete(id)
}
