//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2026 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package tokenizer

import (
	"os"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode"

	entcfg "github.com/weaviate/weaviate/entities/config"

	"github.com/go-ego/gse"
	koDict "github.com/ikawaha/kagome-dict-ko"
	"github.com/ikawaha/kagome-dict/dict"
	"github.com/ikawaha/kagome-dict/ipa"
	kagomeTokenizer "github.com/ikawaha/kagome/v2/tokenizer"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/usecases/monitoring"
)

var (
	// gse segmenters, created by init_gse ("ja") and init_gse_ch ("zh").
	// They are assigned and used while holding gseLock.
	gseTokenizer   *gse.Segmenter  // Japanese
	gseTokenizerCh *gse.Segmenter  // Chinese
	gseLock        = &sync.Mutex{} // guards gseTokenizer and gseTokenizerCh

	// UseGse and UseGseCh are set by InitOptionalTokenizers when the Japanese
	// and Chinese gse tokenizers are enabled through environment variables.
	UseGse   bool
	UseGseCh bool

	// ApacTokenizerThrottle limits how many of the (potentially memory-hungry)
	// gse/kagome tokenizers run in parallel: callers send before tokenizing and
	// receive when done. It is created in init() as a buffered channel.
	ApacTokenizerThrottle chan struct{}

	// tokenizers holds the default Kagome tokenizers; kagomeInitLock serializes
	// their initialization in InitOptionalTokenizers.
	tokenizers     KagomeTokenizers
	kagomeInitLock sync.Mutex

	// customTokenizers maps a class name to a *KagomeTokenizers value that
	// TokenizeForClass prefers over the default tokenizers.
	customTokenizers sync.Map
)

// KagomeTokenizers bundles a Korean and a Japanese Kagome tokenizer.
// Either field may be nil; code in this file checks for nil before using a
// tokenizer taken from this struct.
type KagomeTokenizers struct {
	Korean   *kagomeTokenizer.Tokenizer // used with kagomeTokenizer.Normal mode in this file
	Japanese *kagomeTokenizer.Tokenizer // used with kagomeTokenizer.Search mode in this file
}

// Tokenizations lists the names of the available tokenizations. It starts
// with the tokenizers that are always available; InitOptionalTokenizers
// appends the optional ones when they are enabled through environment
// variables: 'USE_GSE' or 'ENABLE_TOKENIZER_GSE' (gse, Japanese),
// 'ENABLE_TOKENIZER_GSE_CH' (gse, Chinese), 'ENABLE_TOKENIZER_KAGOME_KR' and
// 'ENABLE_TOKENIZER_KAGOME_JA'.
var Tokenizations = []string{
	models.PropertyTokenizationWord,
	models.PropertyTokenizationLowercase,
	models.PropertyTokenizationWhitespace,
	models.PropertyTokenizationField,
	models.PropertyTokenizationTrigram,
}

// init sizes ApacTokenizerThrottle, enables the optional tokenizers selected
// via environment variables, and resets customTokenizers.
//
// The throttle capacity defaults to GOMAXPROCS and can be overridden with a
// positive integer in TOKENIZER_CONCURRENCY_COUNT. Non-positive or unparsable
// values are ignored. A capacity of 0 would give an unbuffered channel, so
// every throttled send would block forever. A negative capacity makes make()
// panic.
func init() {
	numParallel := runtime.GOMAXPROCS(0)
	if raw := strings.TrimSpace(os.Getenv("TOKENIZER_CONCURRENCY_COUNT")); raw != "" {
		if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 {
			numParallel = parsed
		}
	}
	ApacTokenizerThrottle = make(chan struct{}, numParallel)
	InitOptionalTokenizers()
	customTokenizers = sync.Map{}
}

func InitOptionalTokenizers() {
	if entcfg.Enabled(os.Getenv("USE_GSE")) || entcfg.Enabled(os.Getenv("ENABLE_TOKENIZER_GSE")) {
		UseGse = true
		Tokenizations = append(Tokenizations, models.PropertyTokenizationGse)
		init_gse()
	}
	if entcfg.Enabled(os.Getenv("ENABLE_TOKENIZER_GSE_CH")) {
		Tokenizations = append(Tokenizations, models.PropertyTokenizationGseCh)
		UseGseCh = true
		init_gse_ch()
	}
	if entcfg.Enabled(os.Getenv("ENABLE_TOKENIZER_KAGOME_KR")) && tokenizers.Korean == nil {
		func() {
			kagomeInitLock.Lock()
			defer kagomeInitLock.Unlock()
			Tokenizations = append(Tokenizations, models.PropertyTokenizationKagomeKr)
			tokenizers.Korean, _ = initializeKagomeTokenizerKr(nil)
		}()
	}
	if entcfg.Enabled(os.Getenv("ENABLE_TOKENIZER_KAGOME_JA")) && tokenizers.Japanese == nil {
		func() {
			kagomeInitLock.Lock()
			defer kagomeInitLock.Unlock()
			Tokenizations = append(Tokenizations, models.PropertyTokenizationKagomeJa)
			tokenizers.Japanese, _ = initializeKagomeTokenizerJa(nil)
		}()
	}
}

// init_gse creates the Japanese ("ja") gse segmenter once, under gseLock, and
// records how long that took. If gse.New fails the error is dropped and
// gseTokenizer stays nil.
func init_gse() {
	gseLock.Lock()
	defer gseLock.Unlock()

	if gseTokenizer != nil {
		return
	}

	began := time.Now()
	segmenter, err := gse.New("ja")
	if err != nil {
		return
	}
	gseTokenizer = &segmenter

	elapsed := time.Since(began).Seconds()
	monitoring.GetMetrics().TokenizerInitializeDuration.WithLabelValues("gse").Observe(elapsed)
}

// init_gse_ch creates the Chinese ("zh") gse segmenter once, under gseLock,
// and records how long that took (under the "gse" metrics label). If gse.New
// fails the error is dropped and gseTokenizerCh stays nil.
func init_gse_ch() {
	gseLock.Lock()
	defer gseLock.Unlock()

	if gseTokenizerCh != nil {
		return
	}

	began := time.Now()
	segmenter, err := gse.New("zh")
	if err != nil {
		return
	}
	gseTokenizerCh = &segmenter

	elapsed := time.Since(began).Seconds()
	monitoring.GetMetrics().TokenizerInitializeDuration.WithLabelValues("gse").Observe(elapsed)
}

func TokenizeForClass(tokenization string, in string, class string) []string {
	tokenizer, ok := customTokenizers.Load(class)
	if tokenization == models.PropertyTokenizationKagomeKr && ok && tokenizer.(*KagomeTokenizers).Korean != nil {
		ApacTokenizerThrottle <- struct{}{}
		defer func() { <-ApacTokenizerThrottle }()
		return tokenizeKagome(tokenizer.(*KagomeTokenizers).Korean, kagomeTokenizer.Normal, models.PropertyTokenizationKagomeKr, in)
	} else if tokenization == models.PropertyTokenizationKagomeJa && ok && tokenizer.(*KagomeTokenizers).Japanese != nil {
		defer func() { <-ApacTokenizerThrottle }()
		return tokenizeKagome(tokenizer.(*KagomeTokenizers).Japanese, kagomeTokenizer.Search, models.PropertyTokenizationKagomeJa, in)
	} else {
		return Tokenize(tokenization, in)
	}
}

// Tokenize splits in into terms according to tokenization, which is one of the
// models.PropertyTokenization* names. An unknown name yields an empty slice.
//
// The gse and kagome tokenizers run only while holding a slot in
// ApacTokenizerThrottle, which bounds how many of them run concurrently. The
// slot is released once the tokenizer has returned.
func Tokenize(tokenization string, in string) []string {
	switch tokenization {
	case models.PropertyTokenizationWord:
		return tokenizeWord(in)
	case models.PropertyTokenizationLowercase:
		return tokenizeLowercase(in)
	case models.PropertyTokenizationWhitespace:
		return tokenizeWhitespace(in)
	case models.PropertyTokenizationField:
		return tokenizeField(in)
	case models.PropertyTokenizationTrigram:
		return tokenizetrigram(in)
	case models.PropertyTokenizationGse,
		models.PropertyTokenizationGseCh,
		models.PropertyTokenizationKagomeKr,
		models.PropertyTokenizationKagomeJa:
		ApacTokenizerThrottle <- struct{}{}
		defer func() { <-ApacTokenizerThrottle }()

		switch tokenization {
		case models.PropertyTokenizationGse:
			return tokenizeGSE(in)
		case models.PropertyTokenizationGseCh:
			return tokenizeGseCh(in)
		case models.PropertyTokenizationKagomeKr:
			return tokenizeKagome(tokenizers.Korean, kagomeTokenizer.Normal, models.PropertyTokenizationKagomeKr, in)
		default: // models.PropertyTokenizationKagomeJa
			return tokenizeKagome(tokenizers.Japanese, kagomeTokenizer.Search, models.PropertyTokenizationKagomeJa, in)
		}
	default:
		return []string{}
	}
}

// TokenizeWithWildcardsForClass is like TokenizeForClass, but the word and
// trigram tokenizations keep the wildcard characters '?' and '*' instead of
// treating them as separators. All other tokenizations are delegated to
// TokenizeForClass unchanged.
func TokenizeWithWildcardsForClass(tokenization string, in string, class string) []string {
	if tokenization == models.PropertyTokenizationWord {
		return tokenizeWordWithWildcards(in)
	}
	if tokenization == models.PropertyTokenizationTrigram {
		return tokenizetrigramWithWildcards(in)
	}
	return TokenizeForClass(tokenization, in, class)
}

// removeEmptyStrings drops every "" and " " element from terms, preserving
// the order of the remaining elements. The filtering happens in place:
// the returned slice shares terms' backing array, and a nil input yields nil.
// Runs in O(n); the previous element-by-element removal was O(n²).
func removeEmptyStrings(terms []string) []string {
	kept := terms[:0]
	for _, term := range terms {
		if term != "" && term != " " {
			kept = append(kept, term)
		}
	}
	return kept
}

// tokenizeField returns the whole input as a single term with leading and
// trailing Unicode white space trimmed (former DataTypeString/Field), and
// records metrics under the "field" label.
func tokenizeField(in string) []string {
	const label = "field"
	began := time.Now()
	trimmed := strings.TrimFunc(in, unicode.IsSpace)
	result := []string{trimmed}

	metrics := monitoring.GetMetrics()
	metrics.TokenizerDuration.WithLabelValues(label).Observe(time.Since(began).Seconds())
	metrics.TokenCount.WithLabelValues(label).Add(float64(len(result)))
	metrics.TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(result)))
	return result
}

// tokenizeWhitespace splits in on Unicode white space without changing the
// casing of the terms (former DataTypeString/Word). It records metrics under
// the "whitespace" label.
func tokenizeWhitespace(in string) []string {
	const label = "whitespace"
	began := time.Now()
	fields := strings.FieldsFunc(in, unicode.IsSpace)

	metrics := monitoring.GetMetrics()
	metrics.TokenizerDuration.WithLabelValues(label).Observe(time.Since(began).Seconds())
	metrics.TokenCount.WithLabelValues(label).Add(float64(len(fields)))
	metrics.TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(fields)))
	return fields
}

// tokenizeLowercase splits on white spaces and lowercases the words
func tokenizeLowercase(in string) []string {
	startTime := time.Now()
	terms := tokenizeWhitespace(in)
	ret := lowercase(terms)
	monitoring.GetMetrics().TokenizerDuration.WithLabelValues("lowercase").Observe(float64(time.Since(startTime).Seconds()))
	return ret
}

// tokenizeWord splits in on every rune that is neither a letter nor a number
// and lowercases the resulting terms (former DataTypeText/Word). It records
// metrics under the "word" label.
func tokenizeWord(in string) []string {
	const label = "word"
	began := time.Now()
	isSeparator := func(r rune) bool {
		return !(unicode.IsLetter(r) || unicode.IsNumber(r))
	}
	words := lowercase(strings.FieldsFunc(in, isSeparator))

	metrics := monitoring.GetMetrics()
	metrics.TokenizerDuration.WithLabelValues(label).Observe(time.Since(began).Seconds())
	metrics.TokenCount.WithLabelValues(label).Add(float64(len(words)))
	metrics.TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(words)))
	return words
}

// tokenizetrigram drops every rune that is neither a letter nor a number,
// lowercases and concatenates what remains, and returns every run of three
// consecutive runes (trigrams). Inputs with fewer than three such runes yield
// nil. Metrics are recorded under the "trigram" label.
func tokenizetrigram(in string) []string {
	const label = "trigram"
	began := time.Now()

	isSeparator := func(r rune) bool {
		return !(unicode.IsLetter(r) || unicode.IsNumber(r))
	}
	joined := strings.Join(strings.FieldsFunc(in, isSeparator), "")
	letters := []rune(strings.ToLower(joined))

	var trigrams []string
	for start := 0; start+3 <= len(letters); start++ {
		trigrams = append(trigrams, string(letters[start:start+3]))
	}

	metrics := monitoring.GetMetrics()
	metrics.TokenizerDuration.WithLabelValues(label).Observe(time.Since(began).Seconds())
	metrics.TokenCount.WithLabelValues(label).Add(float64(len(trigrams)))
	metrics.TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(trigrams)))
	return trigrams
}

// tokenizeGSE uses the gse tokenizer to tokenise Japanese. It returns an
// empty slice if the Japanese gse tokenizer is disabled (UseGse is false) or
// failed to initialize (gseTokenizer is nil). Otherwise it returns the CutAll
// segments of in with "" and " " removed, and records metrics under the
// "gse" label.
func tokenizeGSE(in string) []string {
	if !UseGse {
		return []string{}
	}
	startTime := time.Now()
	gseLock.Lock()
	defer gseLock.Unlock()
	// UseGse is set before init_gse runs, and init_gse leaves gseTokenizer nil
	// when gse.New fails. Guard here rather than call CutAll on a nil segmenter.
	if gseTokenizer == nil {
		return []string{}
	}
	terms := gseTokenizer.CutAll(in)

	ret := removeEmptyStrings(terms)

	monitoring.GetMetrics().TokenizerDuration.WithLabelValues("gse").Observe(time.Since(startTime).Seconds())
	monitoring.GetMetrics().TokenCount.WithLabelValues("gse").Add(float64(len(ret)))
	monitoring.GetMetrics().TokenCountPerRequest.WithLabelValues("gse").Observe(float64(len(ret)))
	return ret
}

// tokenizeGseCh uses the gse tokenizer to tokenise Chinese. It returns an
// empty slice if the Chinese gse tokenizer is disabled (UseGseCh is false) or
// failed to initialize (gseTokenizerCh is nil). Otherwise it returns the
// CutAll segments of in with "" and " " removed, and records metrics under
// the "gse" label.
func tokenizeGseCh(in string) []string {
	if !UseGseCh {
		return []string{}
	}
	gseLock.Lock()
	defer gseLock.Unlock()
	// UseGseCh is set before init_gse_ch runs, and init_gse_ch leaves
	// gseTokenizerCh nil when gse.New fails. Guard here rather than call
	// CutAll on a nil segmenter.
	if gseTokenizerCh == nil {
		return []string{}
	}
	startTime := time.Now()
	terms := gseTokenizerCh.CutAll(in)
	ret := removeEmptyStrings(terms)

	monitoring.GetMetrics().TokenizerDuration.WithLabelValues("gse").Observe(time.Since(startTime).Seconds())
	monitoring.GetMetrics().TokenCount.WithLabelValues("gse").Add(float64(len(ret)))
	monitoring.GetMetrics().TokenCountPerRequest.WithLabelValues("gse").Observe(float64(len(ret)))
	return ret
}

// initializeKagomeTokenizerKr builds a Kagome tokenizer over the Korean
// dictionary (kagome-dict-ko), optionally extended with userDict, and records
// the initialization time. On error it returns nil and the error.
func initializeKagomeTokenizerKr(userDict *models.TokenizerUserDictConfig) (*kagomeTokenizer.Tokenizer, error) {
	began := time.Now()

	korean, err := initializeKagomeTokenizer(koDict.Dict(), userDict)
	if err != nil {
		return nil, err
	}

	elapsed := time.Since(began).Seconds()
	monitoring.GetMetrics().TokenizerInitializeDuration.WithLabelValues(models.PropertyTokenizationKagomeKr).Observe(elapsed)
	return korean, nil
}

// initializeKagomeTokenizerJa builds a Kagome tokenizer over the IPA
// (Japanese) dictionary, optionally extended with userDict, and records the
// initialization time. On error it returns nil and the error.
func initializeKagomeTokenizerJa(userDict *models.TokenizerUserDictConfig) (*kagomeTokenizer.Tokenizer, error) {
	began := time.Now()

	japanese, err := initializeKagomeTokenizer(ipa.Dict(), userDict)
	if err != nil {
		return nil, err
	}

	elapsed := time.Since(began).Seconds()
	monitoring.GetMetrics().TokenizerInitializeDuration.WithLabelValues(models.PropertyTokenizationKagomeJa).Observe(elapsed)
	return japanese, nil
}

// initializeKagomeTokenizer creates a Kagome tokenizer over dictInstance with
// BOS/EOS tokens omitted. When userDict is non-nil, it is converted with
// NewUserDictFromModel and, if that yields a dictionary, attached as a user
// dictionary. Any error is returned with a nil tokenizer.
func initializeKagomeTokenizer(dictInstance *dict.Dict, userDict *models.TokenizerUserDictConfig) (*kagomeTokenizer.Tokenizer, error) {
	opts := []kagomeTokenizer.Option{kagomeTokenizer.OmitBosEos()}

	if userDict != nil {
		// Named so it does not shadow the imported dict package.
		extraDict, err := NewUserDictFromModel(userDict)
		if err != nil {
			return nil, err
		}
		if extraDict != nil {
			opts = append(opts, kagomeTokenizer.UserDict(extraDict))
		}
	}

	built, err := kagomeTokenizer.New(dictInstance, opts...)
	if err != nil {
		return nil, err
	}
	return built, nil
}

// tokenizeKagome analyzes in with tokenizer in the given mode and returns the
// resulting terms with "" and " " removed. When a token carries UserExtra
// data, its extra.Tokens are emitted in place of the token's surface form.
// label is used as the metrics label.
//
// A nil tokenizer (for example one that is not enabled or failed to
// initialize) yields an empty slice, whatever the label.
func tokenizeKagome(tokenizer *kagomeTokenizer.Tokenizer, mode kagomeTokenizer.TokenizeMode, label string, in string) []string {
	// Guard on the tokenizer alone: a label-specific check would still call
	// Analyze on a nil tokenizer for any other label.
	if tokenizer == nil {
		return []string{}
	}
	startTime := time.Now()
	kagomeTokens := tokenizer.Analyze(in, mode)
	var terms []string
	for _, token := range kagomeTokens {
		if extra := token.UserExtra(); extra != nil {
			terms = append(terms, extra.Tokens...)
		} else {
			terms = append(terms, token.Surface)
		}
	}

	ret := removeEmptyStrings(terms)
	monitoring.GetMetrics().TokenizerDuration.WithLabelValues(label).Observe(time.Since(startTime).Seconds())
	monitoring.GetMetrics().TokenCount.WithLabelValues(label).Add(float64(len(ret)))
	monitoring.GetMetrics().TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(ret)))
	return ret
}

// tokenizeWordWithWildcards splits in on every rune that is not a letter, a
// number, or one of the wildcard symbols '?' and '*', and lowercases the
// resulting terms. Metrics are recorded under the "word_with_wildcards" label.
func tokenizeWordWithWildcards(in string) []string {
	const label = "word_with_wildcards"
	began := time.Now()

	keep := func(r rune) bool {
		return unicode.IsLetter(r) || unicode.IsNumber(r) || r == '?' || r == '*'
	}
	words := lowercase(strings.FieldsFunc(in, func(r rune) bool { return !keep(r) }))

	metrics := monitoring.GetMetrics()
	metrics.TokenizerDuration.WithLabelValues(label).Observe(time.Since(began).Seconds())
	metrics.TokenCount.WithLabelValues(label).Add(float64(len(words)))
	metrics.TokenCountPerRequest.WithLabelValues(label).Observe(float64(len(words)))
	return words
}

// tokenizetrigramWithWildcards splits on any non-alphanumerical rune except the
// wildcards '?' and '*', lowercases the words, joins them together, and then
// groups them into trigrams of three consecutive runes. Inputs with fewer than
// three such runes yield nil. Metrics are recorded under the
// "trigram_with_wildcards" label.
// This is unlikely to be useful, but is included for completeness.
func tokenizetrigramWithWildcards(in string) []string {
	startTime := time.Now()
	terms := tokenizeWordWithWildcards(in)
	// Slice by rune, not by byte. Byte offsets would cut multi-byte UTF-8
	// characters into invalid fragments, and would disagree with
	// tokenizetrigram, which builds trigrams from runes.
	runes := []rune(strings.Join(terms, ""))
	var trigrams []string
	for i := 0; i+3 <= len(runes); i++ {
		trigrams = append(trigrams, string(runes[i:i+3]))
	}
	monitoring.GetMetrics().TokenizerDuration.WithLabelValues("trigram_with_wildcards").Observe(time.Since(startTime).Seconds())
	monitoring.GetMetrics().TokenCount.WithLabelValues("trigram_with_wildcards").Add(float64(len(trigrams)))
	monitoring.GetMetrics().TokenCountPerRequest.WithLabelValues("trigram_with_wildcards").Observe(float64(len(trigrams)))
	return trigrams
}

// lowercase lowercases every element of terms in place, records token-count
// metrics under the "lowercase" label, and returns the same slice.
func lowercase(terms []string) []string {
	for i, term := range terms {
		terms[i] = strings.ToLower(term)
	}

	count := float64(len(terms))
	metrics := monitoring.GetMetrics()
	metrics.TokenCount.WithLabelValues("lowercase").Add(count)
	metrics.TokenCountPerRequest.WithLabelValues("lowercase").Observe(count)
	return terms
}

// TokenizeAndCountDuplicatesForClass tokenizes in with TokenizeForClass and
// returns the distinct terms together with how often each occurred:
// boosts[i] is the number of occurrences of unique[i]. The order of the pairs
// follows map iteration and is therefore unspecified. Both slices are
// non-nil, even when there are no terms.
func TokenizeAndCountDuplicatesForClass(tokenization string, in string, class string) ([]string, []int) {
	frequency := make(map[string]int)
	for _, token := range TokenizeForClass(tokenization, in, class) {
		frequency[token]++
	}

	unique := make([]string, 0, len(frequency))
	boosts := make([]int, 0, len(frequency))
	for token, occurrences := range frequency {
		unique = append(unique, token)
		boosts = append(boosts, occurrences)
	}
	return unique, boosts
}
