//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2026 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package grpc

import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_sentry "github.com/johnbellone/grpc-middleware-sentry"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/adapters/handlers/rest/state"
	pbv0 "github.com/weaviate/weaviate/grpc/generated/protocol/v0"
	pbv1 "github.com/weaviate/weaviate/grpc/generated/protocol/v1"
	"github.com/weaviate/weaviate/usecases/auth/authentication/composer"
	authErrs "github.com/weaviate/weaviate/usecases/auth/authorization/errors"
	"github.com/weaviate/weaviate/usecases/config"
	"github.com/weaviate/weaviate/usecases/monitoring"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	_ "google.golang.org/grpc/encoding/gzip" // Install the gzip compressor
	"google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"

	v0 "github.com/weaviate/weaviate/adapters/handlers/grpc/v0"
	v1 "github.com/weaviate/weaviate/adapters/handlers/grpc/v1"
	"github.com/weaviate/weaviate/adapters/handlers/grpc/v1/auth"
	"github.com/weaviate/weaviate/adapters/handlers/grpc/v1/batch"
)

// CreateGRPCServer creates *grpc.Server with optional grpc.ServerOption values passed.
//
// It registers the v0 and v1 Weaviate services and the gRPC health service
// (served by the v1 service) on the new server. It returns the server together
// with the batch.Drain handle produced by v1.NewService.
//
// Server options are appended in this order: max receive/send message size
// (config.GRPC.MaxMsgSize), the caller-supplied options, TLS credentials (only
// when a cert or key file is configured), then the interceptors.
//
// Unary interceptors are chained with grpc.ChainUnaryInterceptor, so the first
// one appended is the outermost. Their order is:
// auth-error mapping, cluster Basic auth for FileReplicationService (only when
// enabled), Sentry (only when enabled), batch metrics (only when state.Metrics
// is set), client IP, operational mode, maintenance mode, OpenTelemetry tracing.
//
// Stream interceptors: the cluster Basic-auth interceptor (only when enabled)
// is installed with grpc.StreamInterceptor, which grpc-go runs before any
// chained stream interceptors; then principal authentication, then
// maintenance-mode rejection.
//
// NOTE: when cluster Basic auth is enabled, callers must not also pass a
// grpc.StreamInterceptor option — grpc-go panics if that option is set twice.
func CreateGRPCServer(state *state.State, options ...grpc.ServerOption) (*grpc.Server, batch.Drain) {
	o := []grpc.ServerOption{
		grpc.MaxRecvMsgSize(state.ServerConfig.Config.GRPC.MaxMsgSize),
		grpc.MaxSendMsgSize(state.ServerConfig.Config.GRPC.MaxMsgSize),
	}

	o = append(o, options...)

	// Add TLS creds for the GRPC connection, if defined.
	// Either file being set enables TLS. A missing partner file or any load
	// error terminates the process through Fatalf, because this function has no
	// error return.
	if len(state.ServerConfig.Config.GRPC.CertFile) > 0 || len(state.ServerConfig.Config.GRPC.KeyFile) > 0 {
		c, err := credentials.NewServerTLSFromFile(state.ServerConfig.Config.GRPC.CertFile,
			state.ServerConfig.Config.GRPC.KeyFile)
		if err != nil {
			state.Logger.WithField("action", "grpc_startup").
				Fatalf("grpc server TLS credential error: %s", err)
		}
		o = append(o, grpc.Creds(c))
	}

	var interceptors []grpc.UnaryServerInterceptor

	// Outermost: converts authErrs.Unauthenticated/Forbidden returned by any
	// inner interceptor or handler into gRPC status codes.
	interceptors = append(interceptors, makeAuthInterceptor())

	// Intra-cluster file replication is protected by Basic auth when configured,
	// for both unary and streaming methods.
	basicAuth := state.ServerConfig.Config.Cluster.AuthConfig.BasicAuth
	if basicAuth.Enabled() {
		interceptors = append(interceptors,
			basicAuthUnaryInterceptor("/weaviate.v1.FileReplicationService", basicAuth.Username, basicAuth.Password))

		o = append(o, grpc.StreamInterceptor(
			basicAuthStreamInterceptor("/weaviate.v1.FileReplicationService", basicAuth.Username, basicAuth.Password),
		))
	}

	// If sentry is enabled, add automatic spans on gRPC requests.
	if state.ServerConfig.Config.Sentry.Enabled {
		interceptors = append(interceptors, grpc_middleware.ChainUnaryServer(
			grpc_sentry.UnaryServerInterceptor(),
		))
	}

	if state.Metrics != nil {
		interceptors = append(interceptors, makeMetricsInterceptor(state.Logger, state.Metrics))
	}

	interceptors = append(interceptors, makeIPInterceptor())
	interceptors = append(interceptors, makeOperationalModeInterceptor(state))
	interceptors = append(interceptors, makeMaintenanceModeUnaryInterceptor(state.Cluster.MaintenanceModeEnabledForLocalhost))

	// Add OpenTelemetry tracing interceptors (innermost in the unary chain).
	interceptors = append(interceptors, monitoring.GRPCTracingInterceptor())

	// Always true here, because several interceptors are appended
	// unconditionally above; kept as a guard.
	if len(interceptors) > 0 {
		o = append(o, grpc.ChainUnaryInterceptor(interceptors...))
	}

	// The same authentication settings are shared by the stream auth
	// interceptor and the v1 service.
	allowAnonymous := state.ServerConfig.Config.Authentication.AnonymousAccess.Enabled
	authComposer := composer.New(
		state.ServerConfig.Config.Authentication,
		state.APIKey,
		state.OIDC,
	)

	// Chained stream interceptors run in append order: auth, then maintenance mode.
	o = append(o, grpc.ChainStreamInterceptor(makeAuthStreamInterceptor(auth.NewHandler(allowAnonymous, authComposer))))
	o = append(o, grpc.ChainStreamInterceptor(makeMaintenanceModeStreamInterceptor(state.Cluster.MaintenanceModeEnabledForLocalhost)))

	s := grpc.NewServer(o...)
	weaviateV0 := v0.NewService()
	weaviateV1, drainBatch := v1.NewService(allowAnonymous, authComposer, state)
	pbv0.RegisterWeaviateServer(s, weaviateV0)
	pbv1.RegisterWeaviateServer(s, weaviateV1)

	// The v1 service also implements the standard gRPC health-check API.
	grpc_health_v1.RegisterHealthServer(s, weaviateV1)

	return s, drainBatch
}

// makeMetricsInterceptor returns a unary interceptor that records batch metrics
// for "/weaviate.v1.Weaviate/BatchObjects" calls. Every other method is passed
// straight through to handler.
//
// For BatchObjects it measures the serialized request size (proto.Size), times
// the downstream handler, and logs both at debug level. It also observes
// metrics.BatchTime (labels "total_api_level_grpc", "n/a", "n/a"; in
// milliseconds) and metrics.BatchSizeBytes (label "grpc"; in bytes).
//
// Metrics are recorded whether or not the handler fails. The handler's
// response and error are returned unchanged.
func makeMetricsInterceptor(logger logrus.FieldLogger, metrics *monitoring.PrometheusMetrics) grpc.UnaryServerInterceptor {
	const batchObjectsMethod = "/weaviate.v1.Weaviate/BatchObjects"

	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		if info.FullMethod != batchObjectsMethod {
			return handler(ctx, req)
		}

		// Use a checked assertion. An unchecked req.(proto.Message) would panic
		// in the request goroutine if this method ever received a non-protobuf
		// request. Such a request is passed through unmeasured instead.
		msg, ok := req.(proto.Message)
		if !ok {
			return handler(ctx, req)
		}

		// For now only Batch has specific metrics (in line with the HTTP API).
		startTime := time.Now()
		reqSizeBytes := float64(proto.Size(msg))
		reqSizeMB := reqSizeBytes / (1024 * 1024)

		resp, err := handler(ctx, req)
		duration := time.Since(startTime)

		logger.WithFields(logrus.Fields{
			"action":             "grpc_batch_objects",
			"method":             info.FullMethod,
			"request_size_bytes": reqSizeBytes,
			"duration":           duration,
		}).Debugf("grpc BatchObjects request (%fMB) took %s", reqSizeMB, duration)

		// The metric uses the non-standard base unit ms, kept for backwards compatibility.
		metrics.BatchTime.WithLabelValues("total_api_level_grpc", "n/a", "n/a").
			Observe(float64(duration.Milliseconds()))
		metrics.BatchSizeBytes.WithLabelValues("grpc").Observe(reqSizeBytes)

		return resp, err
	}
}

// makeAuthInterceptor returns a unary interceptor that runs the handler and
// translates authorization-layer failures into gRPC status errors:
// authErrs.Unauthenticated maps to codes.Unauthenticated and authErrs.Forbidden
// maps to codes.PermissionDenied. Both are matched with errors.As, so wrapped
// errors are recognized too. When a failure is translated, the handler's
// response is discarded. Every other result is returned exactly as produced.
func makeAuthInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		out, handlerErr := handler(ctx, req)

		switch {
		case errors.As(handlerErr, &authErrs.Unauthenticated{}):
			return nil, status.Error(codes.Unauthenticated, handlerErr.Error())
		case errors.As(handlerErr, &authErrs.Forbidden{}):
			return nil, status.Error(codes.PermissionDenied, handlerErr.Error())
		default:
			return out, handlerErr
		}
	}
}

// makeAuthStreamInterceptor returns a stream interceptor that resolves the
// caller's principal from the stream context before the stream handler runs.
// If resolution fails, the stream is rejected with codes.Unauthenticated
// carrying the resolver's error message. The resolved principal is not used
// here.
func makeAuthStreamInterceptor(authHandler *auth.Handler) grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if _, principalErr := authHandler.PrincipalFromContext(ss.Context()); principalErr != nil {
			return status.Error(codes.Unauthenticated, principalErr.Error())
		}
		return handler(srv, ss)
	}
}

// makeIPInterceptor returns a unary interceptor that stores the caller's IP
// address, as resolved by getRealClientIP, in the request context under the
// plain string key "sourceIp" before invoking the handler.
//
// NOTE(review): a built-in string as a context key is discouraged (staticcheck
// SA1029) because it can collide with keys from other packages. It is kept
// as-is, presumably because code elsewhere looks the value up by this exact
// key. Confirm all readers before switching to a typed key.
func makeIPInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		withIP := context.WithValue(ctx, "sourceIp", getRealClientIP(ctx))
		return handler(withIP, req)
	}
}

// makeOperationalModeInterceptor returns a unary interceptor that enforces the
// node's operational mode. The mode is read from appState on every call:
//   - READ_ONLY rejects methods for which config.IsGRPCWrite is true, with
//     config.ErrReadOnlyModeEnabled;
//   - SCALE_OUT rejects the same write methods, with config.ErrScaleOutModeEnabled;
//   - WRITE_ONLY rejects methods for which config.IsGRPCRead is true, with
//     config.ErrWriteOnlyModeEnabled;
//   - any other mode lets every call through.
//
// A rejected call fails with codes.Unavailable and the sentinel's message, and
// the handler is not invoked.
func makeOperationalModeInterceptor(appState *state.State) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		mode := appState.ServerConfig.Config.OperationalMode.Get()
		method := info.FullMethod

		// The && guards short-circuit, so each classifier is called only for
		// the mode that needs it.
		var rejection error
		switch {
		case mode == config.READ_ONLY && config.IsGRPCWrite(method):
			rejection = config.ErrReadOnlyModeEnabled
		case mode == config.SCALE_OUT && config.IsGRPCWrite(method):
			rejection = config.ErrScaleOutModeEnabled
		case mode == config.WRITE_ONLY && config.IsGRPCRead(method):
			rejection = config.ErrWriteOnlyModeEnabled
		}

		if rejection != nil {
			return nil, status.Error(codes.Unavailable, rejection.Error())
		}
		return handler(ctx, req)
	}
}

// basicAuthUnaryInterceptor returns a unary interceptor that enforces HTTP
// Basic authentication on every method whose full name starts with
// servicePrefix. Calls to other methods are passed through untouched.
//
// Credentials are read from the first "authorization" metadata value, which
// must have the form "Basic <base64(username:password)>". Missing metadata, a
// missing or non-Basic header, bad base64, or wrong credentials are all
// rejected with codes.Unauthenticated.
//
// Username and password are compared in constant time, so response timing
// does not reveal how much of a guessed credential matched.
// subtle.ConstantTimeCompare still returns early on a length mismatch, so only
// the lengths can leak.
func basicAuthUnaryInterceptor(servicePrefix, expectedUsername, expectedPassword string) grpc.UnaryServerInterceptor {
	return func(
		ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
	) (any, error) {
		if !strings.HasPrefix(info.FullMethod, servicePrefix) {
			return handler(ctx, req)
		}

		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, status.Error(codes.Unauthenticated, "missing metadata")
		}

		authHeader := md["authorization"]
		if len(authHeader) == 0 || !strings.HasPrefix(authHeader[0], "Basic ") {
			return nil, status.Error(codes.Unauthenticated, "missing or invalid auth header")
		}

		// Decode and validate Basic Auth credentials.
		payload, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader[0], "Basic "))
		if err != nil {
			return nil, status.Error(codes.Unauthenticated, "invalid base64 encoding")
		}

		username, password, found := strings.Cut(string(payload), ":")
		// Both comparisons always run, combined with a bitwise & rather than &&,
		// so the username check cannot short-circuit the password check.
		usernameOK := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUsername))
		passwordOK := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPassword))
		if !found || usernameOK&passwordOK != 1 {
			return nil, status.Error(codes.Unauthenticated, "invalid username or password")
		}

		return handler(ctx, req)
	}
}

// makeMaintenanceModeUnaryInterceptor returns a unary interceptor that rejects
// every call with codes.Unavailable while inMaintenance reports true. The
// predicate is evaluated on each call, so a change in its result applies to
// the next request without rebuilding the server.
func makeMaintenanceModeUnaryInterceptor(inMaintenance func() bool) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		if !inMaintenance() {
			return handler(ctx, req)
		}
		return nil, status.Error(codes.Unavailable, "server is in maintenance mode")
	}
}

// makeMaintenanceModeStreamInterceptor is the streaming counterpart of the
// unary maintenance-mode interceptor. It refuses to open any stream, returning
// codes.Unavailable, while inMaintenance reports true. The predicate is
// evaluated once per stream, when the stream starts.
func makeMaintenanceModeStreamInterceptor(inMaintenance func() bool) grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if !inMaintenance() {
			return handler(srv, ss)
		}
		return status.Error(codes.Unavailable, "server is in maintenance mode")
	}
}

// getRealClientIP returns a best-effort IP address for the caller of the
// request carried by ctx. Sources are tried in this order:
//  1. the first "x-real-ip" metadata value, if non-blank;
//  2. the left-most entry of the first "x-forwarded-for" value, if non-blank;
//  3. the host part of the transport peer address ("::1" is reported as
//     "127.0.0.1");
//  4. the literal "unknown" when none of the above is available.
//
// Blank header values are skipped, so a present-but-empty header falls through
// to the next source instead of producing an empty IP.
//
// The forwarding headers are supplied by the client (or a proxy in front of
// the server) and are not authenticated here. The result can therefore be
// spoofed and must not be used for access control.
func getRealClientIP(ctx context.Context) string {
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		if xRealIP := md.Get("x-real-ip"); len(xRealIP) > 0 {
			if ip := strings.TrimSpace(xRealIP[0]); ip != "" {
				return ip
			}
		}

		if xForwardedFor := md.Get("x-forwarded-for"); len(xForwardedFor) > 0 {
			// X-Forwarded-For can hold a comma-separated chain
			// ("client, proxy1, proxy2"). The left-most entry is the originating client.
			first, _, _ := strings.Cut(xForwardedFor[0], ",")
			if ip := strings.TrimSpace(first); ip != "" {
				return ip
			}
		}
	}

	// Fall back to the transport peer address.
	if p, ok := peer.FromContext(ctx); ok && p != nil && p.Addr != nil {
		addr := p.Addr.String()
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// Not in host:port form; use the address verbatim.
			return convertIP6ToIP4Loopback(addr)
		}
		return convertIP6ToIP4Loopback(host)
	}

	return "unknown"
}

// convertIP6ToIP4Loopback rewrites the IPv6 loopback literal "::1" as the IPv4
// loopback "127.0.0.1", so local callers are reported consistently. Any other
// input, including other textual spellings of ::1, is returned unchanged.
func convertIP6ToIP4Loopback(ip string) string {
	switch ip {
	case "::1":
		return "127.0.0.1"
	default:
		return ip
	}
}

// basicAuthStreamInterceptor is the streaming counterpart of the unary Basic
// auth interceptor. It enforces HTTP Basic authentication on every stream
// whose method name starts with servicePrefix. Other streams are passed
// through untouched.
//
// Credentials are read from the first "authorization" metadata value of the
// stream context, in the form "Basic <base64(username:password)>". Any
// missing, malformed, or mismatching credentials are rejected with
// codes.Unauthenticated before the stream handler runs.
//
// Username and password are compared in constant time, so response timing
// does not reveal how much of a guessed credential matched.
// subtle.ConstantTimeCompare still returns early on a length mismatch, so only
// the lengths can leak.
func basicAuthStreamInterceptor(servicePrefix, expectedUsername, expectedPassword string) grpc.StreamServerInterceptor {
	return func(
		srv interface{},
		ss grpc.ServerStream,
		info *grpc.StreamServerInfo,
		handler grpc.StreamHandler,
	) error {
		if !strings.HasPrefix(info.FullMethod, servicePrefix) {
			return handler(srv, ss) // no auth needed
		}

		md, ok := metadata.FromIncomingContext(ss.Context())
		if !ok {
			return status.Error(codes.Unauthenticated, "missing metadata")
		}

		authHeader := md["authorization"]
		if len(authHeader) == 0 || !strings.HasPrefix(authHeader[0], "Basic ") {
			return status.Error(codes.Unauthenticated, "missing or invalid auth header")
		}

		decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader[0], "Basic "))
		if err != nil {
			return status.Error(codes.Unauthenticated, "invalid base64 encoding")
		}

		username, password, found := strings.Cut(string(decoded), ":")
		// Both comparisons always run, combined with a bitwise & rather than &&,
		// so the username check cannot short-circuit the password check.
		usernameOK := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUsername))
		passwordOK := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPassword))
		if !found || usernameOK&passwordOK != 1 {
			return status.Error(codes.Unauthenticated, "invalid username or password")
		}

		return handler(srv, ss)
	}
}

// StartAndListen binds a TCP listener to the port configured in
// ServerConfig.Config.GRPC.Port (no host, so all local addresses) and serves s
// on it. The call blocks until the server stops serving.
//
// A bind failure is returned unwrapped. A failure from s.Serve is wrapped with
// "failed to serve". A nil return means Serve returned without error.
func StartAndListen(s *grpc.Server, appState *state.State) error {
	address := fmt.Sprintf(":%d", appState.ServerConfig.Config.GRPC.Port)
	listener, listenErr := net.Listen("tcp", address)
	if listenErr != nil {
		return listenErr
	}

	appState.Logger.WithField("action", "grpc_startup").Infof("grpc server listening at %v", listener.Addr())

	if serveErr := s.Serve(listener); serveErr != nil {
		return fmt.Errorf("failed to serve: %w", serveErr)
	}
	return nil
}
