﻿// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text.Json;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlTypes;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.VectorData.ProviderServices;

namespace Microsoft.SemanticKernel.Connectors.SqlServer;

/// <summary>
/// Maps rows read from a <see cref="SqlDataReader"/> into instances of <typeparamref name="TRecord"/>,
/// using the key, data and vector properties described by the collection model.
/// </summary>
/// <typeparam name="TRecord">The record type being materialized.</typeparam>
/// <param name="model">The collection model describing the record's properties and their storage (column) names.</param>
internal sealed class SqlServerMapper<TRecord>(CollectionModel model)
{
    /// <summary>
    /// Creates a new record via the collection model and populates it from the current row of <paramref name="reader"/>.
    /// </summary>
    /// <param name="reader">The reader positioned on the row to map; columns are located by each property's storage name.</param>
    /// <param name="includeVectors">
    /// Whether vector properties should be read from the row. When <see langword="false"/>, vector properties
    /// are left exactly as produced by the model's record factory.
    /// </param>
    /// <returns>The populated record.</returns>
    /// <exception cref="InvalidOperationException">A column could not be located, read, or converted to its property's type.</exception>
    public TRecord MapFromStorageToDataModel(SqlDataReader reader, bool includeVectors)
    {
        var record = model.CreateRecord<TRecord>()!;

        PopulateValue(reader, model.KeyProperty, record);

        foreach (var property in model.DataProperties)
        {
            PopulateValue(reader, property, record);
        }

        if (includeVectors)
        {
            foreach (var property in model.VectorProperties)
            {
                try
                {
                    var ordinal = reader.GetOrdinal(property.StorageName);

                    // A NULL vector column leaves the property untouched (it keeps whatever the record factory assigned).
                    if (!reader.IsDBNull(ordinal))
                    {
                        var vector = reader.GetFieldValue<SqlVector<float>>(ordinal);

                        // Unwrap Nullable<T> (e.g. ReadOnlyMemory<float>?) so nullable vector properties are handled
                        // exactly like their non-nullable counterparts, mirroring the type dispatch in PopulateValue.
                        // Without this, a nullable struct vector type would fall through to UnreachableException.
                        var vectorType = Nullable.GetUnderlyingType(property.Type) ?? property.Type;

                        property.SetValueAsObject(record, vectorType switch
                        {
                            var t when t == typeof(SqlVector<float>) => vector,
                            var t when t == typeof(ReadOnlyMemory<float>) => vector.Memory,
                            var t when t == typeof(Embedding<float>) => new Embedding<float>(vector.Memory),

                            // Reuse the backing array when it exactly covers the vector's memory; otherwise copy.
                            var t when t == typeof(float[])
                                => MemoryMarshal.TryGetArray(vector.Memory, out ArraySegment<float> segment)
                                    && segment.Count == segment.Array!.Length
                                    ? segment.Array
                                    : vector.Memory.ToArray(),

                            _ => throw new UnreachableException()
                        });
                    }
                }
                catch (Exception e)
                {
                    throw new InvalidOperationException($"Failed to deserialize vector property '{property.ModelName}'.", e);
                }
            }
        }

        return record;

        // Reads a single key/data column into the corresponding property of the record.
        // A NULL column assigns null to the property. Any failure, including a NotSupportedException for a
        // property type with no mapping, is wrapped in an InvalidOperationException naming the property.
        static void PopulateValue(SqlDataReader reader, PropertyModel property, object record)
        {
            try
            {
                var ordinal = reader.GetOrdinal(property.StorageName);

                if (reader.IsDBNull(ordinal))
                {
                    property.SetValueAsObject(record, null);
                    return;
                }

                // Dispatch on the underlying type so that e.g. int? and int share the same reader call.
                switch (Nullable.GetUnderlyingType(property.Type) ?? property.Type)
                {
                    case var t when t == typeof(byte):
                        property.SetValue(record, reader.GetByte(ordinal)); // TINYINT
                        break;
                    case var t when t == typeof(short):
                        property.SetValue(record, reader.GetInt16(ordinal)); // SMALLINT
                        break;
                    case var t when t == typeof(int):
                        property.SetValue(record, reader.GetInt32(ordinal)); // INT
                        break;
                    case var t when t == typeof(long):
                        property.SetValue(record, reader.GetInt64(ordinal)); // BIGINT
                        break;

                    case var t when t == typeof(float):
                        property.SetValue(record, reader.GetFloat(ordinal)); // REAL
                        break;
                    case var t when t == typeof(double):
                        property.SetValue(record, reader.GetDouble(ordinal)); // FLOAT
                        break;
                    case var t when t == typeof(decimal):
                        property.SetValue(record, reader.GetDecimal(ordinal)); // DECIMAL
                        break;

                    case var t when t == typeof(string):
                        property.SetValue(record, reader.GetString(ordinal)); // NVARCHAR
                        break;
                    case var t when t == typeof(Guid):
                        property.SetValue(record, reader.GetGuid(ordinal)); // UNIQUEIDENTIFIER
                        break;
                    case var t when t == typeof(byte[]):
                        property.SetValueAsObject(record, reader.GetValue(ordinal)); // VARBINARY
                        break;
                    case var t when t == typeof(bool):
                        property.SetValue(record, reader.GetBoolean(ordinal)); // BIT
                        break;

                    case var t when t == typeof(DateTime):
                        property.SetValue(record, reader.GetDateTime(ordinal)); // DATETIME2
                        break;
                    case var t when t == typeof(DateTimeOffset):
                        property.SetValue(record, reader.GetDateTimeOffset(ordinal)); // DATETIMEOFFSET
                        break;
#if NET
                    case var t when t == typeof(DateOnly):
                        property.SetValue(record, reader.GetFieldValue<DateOnly>(ordinal)); // DATE
                        break;
                    case var t when t == typeof(TimeOnly):
                        property.SetValue(record, reader.GetFieldValue<TimeOnly>(ordinal)); // TIME
                        break;
#endif

                    // We map string[] and List<string> properties to SQL Server JSON columns, so deserialize from JSON here.
                    case var t when t == typeof(string[]):
                        property.SetValue(record, JsonSerializer.Deserialize<string[]>(
                            reader.GetString(ordinal),
                            SqlServerJsonSerializerContext.Default.StringArray));
                        break;
                    case var t when t == typeof(List<string>):
                        property.SetValue(record, JsonSerializer.Deserialize<List<string>>(
                            reader.GetString(ordinal),
                            SqlServerJsonSerializerContext.Default.ListString));
                        break;

                    default:
                        throw new NotSupportedException($"Unsupported type '{property.Type.Name}' for property '{property.ModelName}'.");
                }
            }
            catch (Exception ex)
            {
                throw new InvalidOperationException($"Failed to read property '{property.ModelName}' of type '{property.Type.Name}'.", ex);
            }
        }
    }
}
