//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2026 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

//go:build integrationTest

package db

import (
	"context"
	"fmt"
	"testing"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	replicationTypes "github.com/weaviate/weaviate/cluster/replication/types"
	"github.com/weaviate/weaviate/entities/aggregation"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/filters"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/usecases/cluster"
	"github.com/weaviate/weaviate/usecases/memwatch"
	schemaUC "github.com/weaviate/weaviate/usecases/schema"
	"github.com/weaviate/weaviate/usecases/sharding"
)

// This test aims to prevent a regression on
// https://github.com/weaviate/weaviate/issues/1352
//
// It reuses the company-schema from the regular filters test, but runs it in
// isolation so as not to interfere with the existing tests. Two chained range
// filters on "price" must still honor a limit that is smaller than the number
// of matching objects.
func Test_LimitsOnChainedFilters(t *testing.T) {
	const companyCount = 100

	dirName := t.TempDir()

	logger := logrus.New()
	shardState := singleShardState()
	schemaGetter := &fakeSchemaGetter{
		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
		shardState: shardState,
	}
	// Single-node mocks: every shard lookup and replica lookup resolves to "node1".
	mockSchemaReader := schemaUC.NewMockSchemaReader(t)
	mockSchemaReader.EXPECT().Shards(mock.Anything).Return(shardState.AllPhysicalShards(), nil).Maybe()
	mockSchemaReader.EXPECT().Read(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(className string, retryIfClassNotFound bool, readFunc func(*models.Class, *sharding.State) error) error {
		class := &models.Class{Class: className}
		return readFunc(class, shardState)
	}).Maybe()
	mockSchemaReader.EXPECT().ReadOnlySchema().Return(models.Schema{Classes: nil}).Maybe()
	mockSchemaReader.EXPECT().ShardReplicas(mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockReplicationFSMReader := replicationTypes.NewMockReplicationFSMReader(t)
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasRead(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}).Maybe()
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasWrite(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockNodeSelector := cluster.NewMockNodeSelector(t)
	mockNodeSelector.EXPECT().LocalName().Return("node1").Maybe()
	mockNodeSelector.EXPECT().NodeHostname(mock.Anything).Return("node1", true).Maybe()
	repo, err := New(logger, "node1", Config{
		MemtablesFlushDirtyAfter:  60,
		RootPath:                  dirName,
		QueryMaximumResults:       10000,
		MaxImportGoroutinesFactor: 1,
	}, &FakeRemoteClient{}, mockNodeSelector, &FakeRemoteNodeClient{}, &FakeReplicationClient{}, nil, memwatch.NewDummyMonitor(),
		mockNodeSelector, mockSchemaReader, mockReplicationFSMReader)
	require.NoError(t, err)
	repo.SetSchemaGetter(schemaGetter)
	require.NoError(t, repo.WaitForStartup(testCtx()))
	defer repo.Shutdown(context.Background())

	migrator := NewMigrator(repo, logger, "node1")

	t.Run("creating the class", func(t *testing.T) {
		require.NoError(t, migrator.AddClass(context.Background(), productClass))
		require.NoError(t, migrator.AddClass(context.Background(), companyClass))

		// Assign directly instead of binding a local named "schema", which
		// would shadow the schema package.
		schemaGetter.schema = schema.Schema{
			Objects: &models.Schema{
				Classes: []*models.Class{productClass, companyClass},
			},
		}
	})

	data := chainedFilterCompanies(companyCount)

	t.Run("import companies", func(t *testing.T) {
		for i, company := range data {
			t.Run(fmt.Sprintf("importing company %d", i), func(t *testing.T) {
				require.NoError(t,
					repo.PutObject(context.Background(), company, []float32{0.1, 0.2, 0.01, 0.2}, nil, nil, nil, 0))
			})
		}
	})

	t.Run("combine two filters with a strict limit", func(t *testing.T) {
		const limit = 20

		// Prices are 0..companyCount-1, so 20 <= price < 100 matches 80
		// companies; the limit, not the filter, must cap the result size.
		filter := filterAnd(
			buildFilter("price", 20, gte, dtInt),
			buildFilter("price", 100, lt, dtInt),
		)

		res, err := repo.Search(context.Background(), dto.GetParams{
			ClassName: companyClass.Class,
			Filters:   filter,
			Pagination: &filters.Pagination{
				Limit: limit,
			},
			Properties: search.SelectProperties{{Name: "price"}},
		})

		require.NoError(t, err)
		assert.Len(t, res, limit)

		for _, obj := range res {
			// Checked assertions: an unexpected shape fails the test with a
			// message rather than panicking the whole test binary.
			props, ok := obj.Schema.(map[string]interface{})
			require.Truef(t, ok, "unexpected schema type %T", obj.Schema)
			price, ok := props["price"].(float64)
			require.Truef(t, ok, "unexpected price type %T", props["price"])

			assert.GreaterOrEqual(t, price, float64(20))
			assert.Less(t, price, float64(100))
		}
	})
}

// chainedFilterCompanies builds size objects of the company class for the
// chained-filter regression tests. Each object gets a new ID from
// mustNewUUID and a single "price" property equal to its index in the
// returned slice, so prices cover 0..size-1.
func chainedFilterCompanies(size int) []*models.Object {
	companies := make([]*models.Object, 0, size)

	for price := int64(0); price < int64(size); price++ {
		companies = append(companies, &models.Object{
			ID:         mustNewUUID(),
			Class:      companyClass.Class,
			Properties: map[string]interface{}{"price": price},
		})
	}

	return companies
}

// This test aims to prevent a regression on
// https://github.com/weaviate/weaviate/issues/1355
//
// It reuses the company-schema from the regular filters test, but runs it in
// isolation so as not to interfere with the existing tests. Every company is
// written twice so that the index accumulates deleted doc ids. The filtered
// searches must return every company both before and after those updates.
func Test_FilterLimitsAfterUpdates(t *testing.T) {
	const companyCount = 100

	dirName := t.TempDir()

	logger := logrus.New()
	shardState := singleShardState()
	schemaGetter := &fakeSchemaGetter{
		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
		shardState: shardState,
	}
	// Single-node mocks: every shard lookup and replica lookup resolves to "node1".
	mockSchemaReader := schemaUC.NewMockSchemaReader(t)
	mockSchemaReader.EXPECT().Shards(mock.Anything).Return(shardState.AllPhysicalShards(), nil).Maybe()
	mockSchemaReader.EXPECT().Read(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(className string, retryIfClassNotFound bool, readFunc func(*models.Class, *sharding.State) error) error {
		class := &models.Class{Class: className}
		return readFunc(class, shardState)
	}).Maybe()
	mockSchemaReader.EXPECT().ReadOnlySchema().Return(models.Schema{Classes: nil}).Maybe()
	mockSchemaReader.EXPECT().ShardReplicas(mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockReplicationFSMReader := replicationTypes.NewMockReplicationFSMReader(t)
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasRead(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}).Maybe()
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasWrite(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockNodeSelector := cluster.NewMockNodeSelector(t)
	mockNodeSelector.EXPECT().LocalName().Return("node1").Maybe()
	mockNodeSelector.EXPECT().NodeHostname(mock.Anything).Return("node1", true).Maybe()
	repo, err := New(logger, "node1", Config{
		MemtablesFlushDirtyAfter:  60,
		RootPath:                  dirName,
		QueryMaximumResults:       10000,
		MaxImportGoroutinesFactor: 1,
	}, &FakeRemoteClient{}, mockNodeSelector, &FakeRemoteNodeClient{}, &FakeReplicationClient{}, nil, memwatch.NewDummyMonitor(),
		mockNodeSelector, mockSchemaReader, mockReplicationFSMReader)
	require.NoError(t, err)
	repo.SetSchemaGetter(schemaGetter)
	require.NoError(t, repo.WaitForStartup(testCtx()))
	defer repo.Shutdown(context.Background())

	migrator := NewMigrator(repo, logger, "node1")

	t.Run("creating the class", func(t *testing.T) {
		require.NoError(t, migrator.AddClass(context.Background(), productClass))
		require.NoError(t, migrator.AddClass(context.Background(), companyClass))

		// Assign directly instead of binding a local named "schema", which
		// would shadow the schema package.
		schemaGetter.schema = schema.Schema{
			Objects: &models.Schema{
				Classes: []*models.Class{productClass, companyClass},
			},
		}
	})

	data := chainedFilterCompanies(companyCount)

	// putAll writes every company in data with its own copy of vector. The
	// second pass re-puts the same IDs, i.e. updates the existing objects.
	putAll := func(t *testing.T, action string, vector []float32) {
		for i, company := range data {
			t.Run(fmt.Sprintf("%s company %d", action, i), func(t *testing.T) {
				require.NoError(t, repo.PutObject(context.Background(), company,
					append([]float32(nil), vector...), nil, nil, nil, 0))
			})
		}
	}

	// None of the companies set makesProduct, so a ref-count == 0 filter must
	// match all of them.
	verifyRefCountZero := func(t *testing.T) {
		res, err := repo.Search(context.Background(), dto.GetParams{
			ClassName:  companyClass.Class,
			Filters:    buildFilter("makesProduct", 0, eq, dtInt),
			Pagination: &filters.Pagination{Limit: companyCount},
		})
		require.NoError(t, err)
		assert.Len(t, res, companyCount)
	}

	// Every price is >= 0, so this filter must also match all companies.
	verifyNonRefCountProp := func(t *testing.T) {
		res, err := repo.Search(context.Background(), dto.GetParams{
			ClassName:  companyClass.Class,
			Filters:    buildFilter("price", float64(0), gte, dtNumber),
			Pagination: &filters.Pagination{Limit: companyCount},
		})
		require.NoError(t, err)
		assert.Len(t, res, companyCount)
	}

	t.Run("import companies", func(t *testing.T) {
		putAll(t, "importing", []float32{0.1, 0.2, 0.01, 0.2})
	})

	t.Run("verify all with ref count 0 are found", verifyRefCountZero)
	t.Run("verify a non refcount prop", verifyNonRefCountProp)

	t.Run("perform updates on each company", func(t *testing.T) {
		// The changed vector itself doesn't matter. Re-putting every object is
		// what fills the index with deleted doc ids.
		putAll(t, "updating", []float32{0.1, 0.21, 0.01, 0.2})
	})

	t.Run("verify all with ref count 0 are found after updates", verifyRefCountZero)
	t.Run("verify a non refcount prop after updates", verifyNonRefCountProp)
}

// This test aims to prevent a regression on
// https://github.com/weaviate/weaviate/issues/1356
//
// It reuses the company-schema from the regular filters test, but runs it in
// isolation so as not to interfere with the existing tests. Every company is
// written twice so that the index accumulates deleted doc ids. A filtered
// meta-count aggregation must still count each company exactly once.
func Test_AggregationsAfterUpdates(t *testing.T) {
	const companyCount = 100

	dirName := t.TempDir()

	logger := logrus.New()
	shardState := singleShardState()
	schemaGetter := &fakeSchemaGetter{
		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
		shardState: shardState,
	}
	// Single-node mocks: every shard lookup and replica lookup resolves to "node1".
	mockSchemaReader := schemaUC.NewMockSchemaReader(t)
	mockSchemaReader.EXPECT().Shards(mock.Anything).Return(shardState.AllPhysicalShards(), nil).Maybe()
	mockSchemaReader.EXPECT().Read(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(className string, retryIfClassNotFound bool, readFunc func(*models.Class, *sharding.State) error) error {
		class := &models.Class{Class: className}
		return readFunc(class, shardState)
	}).Maybe()
	mockSchemaReader.EXPECT().ReadOnlySchema().Return(models.Schema{Classes: nil}).Maybe()
	mockSchemaReader.EXPECT().ShardReplicas(mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockReplicationFSMReader := replicationTypes.NewMockReplicationFSMReader(t)
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasRead(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}).Maybe()
	mockReplicationFSMReader.EXPECT().FilterOneShardReplicasWrite(mock.Anything, mock.Anything, mock.Anything).Return([]string{"node1"}, nil).Maybe()
	mockNodeSelector := cluster.NewMockNodeSelector(t)
	mockNodeSelector.EXPECT().LocalName().Return("node1").Maybe()
	mockNodeSelector.EXPECT().NodeHostname(mock.Anything).Return("node1", true).Maybe()
	repo, err := New(logger, "node1", Config{
		MemtablesFlushDirtyAfter:  60,
		RootPath:                  dirName,
		QueryMaximumResults:       10000,
		MaxImportGoroutinesFactor: 1,
	}, &FakeRemoteClient{}, mockNodeSelector, &FakeRemoteNodeClient{}, &FakeReplicationClient{}, nil, memwatch.NewDummyMonitor(),
		mockNodeSelector, mockSchemaReader, mockReplicationFSMReader)
	require.NoError(t, err)
	repo.SetSchemaGetter(schemaGetter)
	require.NoError(t, repo.WaitForStartup(testCtx()))
	defer repo.Shutdown(context.Background())

	migrator := NewMigrator(repo, logger, "node1")

	t.Run("creating the class", func(t *testing.T) {
		require.NoError(t, migrator.AddClass(context.Background(), productClass))
		require.NoError(t, migrator.AddClass(context.Background(), companyClass))

		// Assign directly instead of binding a local named "schema", which
		// would shadow the schema package.
		schemaGetter.schema = schema.Schema{
			Objects: &models.Schema{
				Classes: []*models.Class{productClass, companyClass},
			},
		}
	})

	data := chainedFilterCompanies(companyCount)

	// putAll writes every company in data with its own copy of vector. The
	// second pass re-puts the same IDs, i.e. updates the existing objects.
	putAll := func(t *testing.T, action string, vector []float32) {
		for i, company := range data {
			t.Run(fmt.Sprintf("%s company %d", action, i), func(t *testing.T) {
				require.NoError(t, repo.PutObject(context.Background(), company,
					append([]float32(nil), vector...), nil, nil, nil, 0))
			})
		}
	}

	// None of the companies set makesProduct, so a ref-count == 0 filter must
	// aggregate into a single group whose meta count is every company.
	verifyAggregatedCount := func(t *testing.T) {
		filter := buildFilter("makesProduct", 0, eq, dtInt)
		res, err := repo.Aggregate(context.Background(),
			aggregation.Params{
				ClassName:        schema.ClassName(companyClass.Class),
				Filters:          filter,
				IncludeMetaCount: true,
			}, nil)

		require.NoError(t, err)
		require.Len(t, res.Groups, 1)
		// testify convention: expected first, actual second.
		assert.Equal(t, companyCount, res.Groups[0].Count)
	}

	t.Run("import companies", func(t *testing.T) {
		putAll(t, "importing", []float32{0.1, 0.2, 0.01, 0.2})
	})

	t.Run("verify all with ref count 0 are correctly aggregated", verifyAggregatedCount)

	t.Run("perform updates on each company", func(t *testing.T) {
		// The changed vector itself doesn't matter. Re-putting every object is
		// what fills the index with deleted doc ids.
		putAll(t, "updating", []float32{0.1, 0.21, 0.01, 0.2})
	})

	t.Run("verify all with ref count 0 are correctly aggregated after updates", verifyAggregatedCount)
	// The original test ran this aggregation twice after the updates. The
	// second run is kept, presumably to confirm that an earlier aggregation
	// does not change the result of a later one. NOTE(review): confirm intent.
	t.Run("verify all with ref count 0 are correctly aggregated on a repeated query", verifyAggregatedCount)
}
