// NOTE: This is mostly a copy-paste of the standalone S3 implementation; it will be cleaned up later.
//! Object storage backend implementation using object_store.
//!
//! ## ETag Implementation Note
//! The `UpdateVersion` struct is serialized to JSON and stored as an ETag string.
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use chroma_config::{registry::Registry, Configurable};
use chroma_error::ChromaError;
use chroma_tracing::util::Stopwatch;
use chroma_types::Cmek;
use futures::stream::{self, StreamExt, TryStreamExt};
use object_store::{
    client::{
        HttpClient, HttpConnector, HttpError, HttpErrorKind, HttpRequest, HttpResponse,
        HttpService, ReqwestConnector,
    },
    gcp::GoogleCloudStorageBuilder,
    ClientOptions, GetRange, HeaderValue, ObjectMeta, ObjectStore, UpdateVersion,
};
use serde::{Deserialize, Serialize};

use crate::{
    config::{ObjectStorageConfig, ObjectStorageProvider, StorageConfig},
    metrics::StorageMetrics,
    s3::DeletedObjects,
    ETag, GetOptions, PutMode, PutOptions, StorageConfigError, StorageError,
};

/// Name of the HTTP request header that `HttpClientWrapper` sets to the
/// `Cmek::Gcp` resource string when a customer managed encryption key is
/// attached to a request.
const GCP_CMEK_HEADER: &str = "x-goog-encryption-kms-key-name";

/// [`HttpService`] adapter that forwards every request to an inner
/// [`HttpClient`], first translating a [`Cmek`] request extension (if any)
/// into the matching HTTP header.
#[derive(Debug)]
struct HttpClientWrapper {
    /// Client that actually executes the (possibly amended) request.
    reqwest_client: HttpClient,
}

#[async_trait::async_trait]
impl HttpService for HttpClientWrapper {
    /// Executes `req` on the inner client.
    ///
    /// If the request carries a [`Cmek`] extension it is removed and turned
    /// into a header before the request is sent.
    ///
    /// # Errors
    /// Returns an [`HttpError`] of kind [`HttpErrorKind::Request`] when the key
    /// resource string is not a valid header value, or whatever error the inner
    /// client reports.
    async fn call(&self, mut req: HttpRequest) -> Result<HttpResponse, HttpError> {
        if let Some(cmek) = req.extensions_mut().remove::<Cmek>() {
            // Resolve the header name/value pair first; the match stays
            // exhaustive so a new `Cmek` variant forces a decision here.
            let (header_name, header_value) = match cmek {
                Cmek::Gcp(resource) => {
                    let value = HeaderValue::from_str(&resource)
                        .map_err(|err| HttpError::new(HttpErrorKind::Request, err))?;
                    (GCP_CMEK_HEADER, value)
                }
            };
            req.headers_mut().insert(header_name, header_value);
        }
        self.reqwest_client.execute(req).await
    }
}

/// [`HttpConnector`] that builds the default reqwest-backed client and wraps
/// it in [`HttpClientWrapper`].
#[derive(Debug)]
struct ChromaHttpConnector;

impl HttpConnector for ChromaHttpConnector {
    /// Creates the default [`ReqwestConnector`] client for `options` and wraps it.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying connector.
    fn connect(&self, options: &ClientOptions) -> object_store::Result<HttpClient> {
        ReqwestConnector::default()
            .connect(options)
            .map(|reqwest_client| HttpClient::new(HttpClientWrapper { reqwest_client }))
    }
}

impl From<object_store::Error> for StorageError {
    fn from(e: object_store::Error) -> Self {
        match e {
            object_store::Error::NotFound { path, source } => StorageError::NotFound {
                path,
                source: source.into(),
            },
            object_store::Error::AlreadyExists { path, source } => StorageError::AlreadyExists {
                path,
                source: source.into(),
            },
            object_store::Error::Precondition { path, source } => StorageError::Precondition {
                path,
                source: source.into(),
            },
            object_store::Error::NotModified { path, source } => StorageError::NotModified {
                path,
                source: source.into(),
            },
            object_store::Error::PermissionDenied { path, source } => {
                StorageError::PermissionDenied {
                    path,
                    source: source.into(),
                }
            }
            object_store::Error::Unauthenticated { path, source } => {
                StorageError::Unauthenticated {
                    path,
                    source: source.into(),
                }
            }
            object_store::Error::NotSupported { source } => StorageError::NotSupported {
                source: source.into(),
            },
            object_store::Error::InvalidPath { source } => StorageError::Generic {
                source: Arc::new(source),
            },
            err @ object_store::Error::Generic { .. } => StorageError::Generic {
                source: Arc::new(err),
            },
            object_store::Error::JoinError { source } => StorageError::Generic {
                source: Arc::new(source),
            },
            object_store::Error::UnknownConfigurationKey { store, key } => {
                StorageError::UnknownConfigurationKey { store, key }
            }
            err => StorageError::Generic {
                source: Arc::new(err),
            },
        }
    }
}

/// Serializable wrapper for UpdateVersion
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ObjectVersionTag {
    e_tag: Option<String>,
    version: Option<String>,
}

impl From<UpdateVersion> for ObjectVersionTag {
    fn from(uv: UpdateVersion) -> Self {
        Self {
            e_tag: uv.e_tag,
            version: uv.version,
        }
    }
}

impl From<ObjectVersionTag> for UpdateVersion {
    fn from(ovt: ObjectVersionTag) -> Self {
        Self {
            e_tag: ovt.e_tag,
            version: ovt.version,
        }
    }
}

/// Convert UpdateVersion to ETag via serialization
impl TryFrom<&UpdateVersion> for ETag {
    type Error = StorageError;

    fn try_from(uv: &UpdateVersion) -> Result<Self, Self::Error> {
        serde_json::to_string(&ObjectVersionTag::from(uv.clone()))
            .map(ETag)
            .map_err(|e| StorageError::Generic {
                source: Arc::new(e),
            })
    }
}

impl TryFrom<ObjectMeta> for ETag {
    type Error = StorageError;

    fn try_from(om: ObjectMeta) -> Result<Self, Self::Error> {
        (&UpdateVersion {
            e_tag: om.e_tag,
            version: om.version,
        })
            .try_into()
    }
}

/// Convert ETag to UpdateVersion via deserialization
impl TryFrom<&ETag> for UpdateVersion {
    type Error = StorageError;

    fn try_from(etag: &ETag) -> Result<Self, Self::Error> {
        let serializable: ObjectVersionTag =
            serde_json::from_str(&etag.0).map_err(|e| StorageError::Generic {
                source: Arc::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid ETag format: {}", e),
                )),
            })?;
        Ok(serializable.into())
    }
}

impl TryFrom<PutMode> for object_store::PutMode {
    type Error = StorageError;

    fn try_from(value: PutMode) -> Result<Self, Self::Error> {
        Ok(match value {
            PutMode::IfMatch(etag) => Self::Update((&etag).try_into()?),
            PutMode::IfNotExist => Self::Create,
            PutMode::Upsert => Self::Overwrite,
        })
    }
}

/// Storage backend that talks to an object store through the `object_store`
/// crate.
///
/// Cloning clones the `Arc` handle to the underlying store.
#[derive(Clone)]
pub struct ObjectStorage {
    /// Bucket name, copied from the configuration this instance was built from.
    pub(crate) bucket: String,
    /// Size of each ranged request issued by the parallel (multipart) download path.
    pub(super) download_part_size_bytes: u64,
    /// Handle to the underlying object store implementation.
    pub(super) store: Arc<dyn ObjectStore>,
    /// Size of each part in a multipart upload; payloads no larger than this
    /// are uploaded with a single request.
    pub(super) upload_part_size_bytes: u64,
    /// Counters and histograms recorded by the operations below (the metric
    /// names carry an `s3_` prefix).
    metrics: StorageMetrics,
}

impl ObjectStorage {
    /// Builds an `ObjectStorage` from `config`.
    ///
    /// For the GCS provider the client is configured from the environment and
    /// then given the bucket name, retry policy, request/connect timeouts and
    /// the CMEK-aware HTTP connector.
    ///
    /// # Errors
    /// * `StorageError::Message` if either part size in `config` is zero.
    /// * `StorageConfigError::FailedToCreateBucket` if the client cannot be built.
    pub async fn new(config: &ObjectStorageConfig) -> Result<Self, Box<dyn ChromaError>> {
        // A zero part size would make `partition` divide by zero later on.
        let has_zero_part_size = [
            config.download_part_size_bytes,
            config.upload_part_size_bytes,
        ]
        .contains(&0);
        if has_zero_part_size {
            return Err(StorageError::Message {
                message: "Cannot partition with zero chunk size".to_string(),
            }
            .boxed());
        }

        let request_timeout = Duration::from_millis(config.request_timeout_ms);
        let connect_timeout = Duration::from_millis(config.connect_timeout_ms);

        let store = match config.provider {
            ObjectStorageProvider::GCS => {
                let retry_config = object_store::RetryConfig {
                    max_retries: config.request_retry_count,
                    retry_timeout: request_timeout,
                    ..Default::default()
                };
                let client_options = ClientOptions::new()
                    .with_timeout(request_timeout)
                    .with_connect_timeout(connect_timeout);

                GoogleCloudStorageBuilder::from_env()
                    .with_bucket_name(&config.bucket)
                    .with_retry(retry_config)
                    .with_client_options(client_options)
                    .with_http_connector(ChromaHttpConnector)
                    .build()
                    .map_err(|e| -> Box<dyn ChromaError> {
                        Box::new(StorageConfigError::FailedToCreateBucket(format!(
                            "Failed to create GCS client: {}",
                            e
                        )))
                    })?
            }
        };

        Ok(ObjectStorage {
            bucket: config.bucket.clone(),
            download_part_size_bytes: config.download_part_size_bytes,
            store: Arc::new(store),
            upload_part_size_bytes: config.upload_part_size_bytes,
            metrics: StorageMetrics::default(),
        })
    }

    pub async fn head(&self, key: &str) -> Result<ObjectMeta, StorageError> {
        Ok(self.store.head(&key.into()).await?)
    }

    /// Returns `true` when the object currently stored at `key` has the same
    /// ETag string as `e_tag`.
    ///
    /// # Errors
    /// Fails if the metadata lookup or the ETag conversion fails.
    pub async fn confirm_same(&self, key: &str, e_tag: &ETag) -> Result<bool, StorageError> {
        let current = ETag::try_from(self.head(key).await?)?;
        Ok(current.0 == e_tag.0)
    }

    /// Splits `0..total_size` into consecutive half-open `(start, end)` ranges
    /// of at most `chunk_size` bytes each.
    ///
    /// Every range except possibly the last spans exactly `chunk_size`; the
    /// last one ends at `total_size`. A `total_size` of zero yields no ranges.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    pub fn partition(total_size: u64, chunk_size: u64) -> impl Iterator<Item = (u64, u64)> {
        let chunk_count = total_size.div_ceil(chunk_size);
        (0..chunk_count).map(move |index| {
            let start = index * chunk_size;
            // Only the final chunk can reach past `total_size`; clamp it.
            let end = start.saturating_add(chunk_size).min(total_size);
            (start, end)
        })
    }

    async fn multipart_get(&self, key: &str) -> Result<(Bytes, ETag), StorageError> {
        let metadata = self.head(key).await?;
        let object_size = metadata.size;
        let etag = metadata.try_into()?;
        if object_size == 0 {
            return Ok((Bytes::new(), etag));
        }

        let chunk_ranges = Self::partition(object_size, self.download_part_size_bytes)
            .map(|(start, end)| start..end);

        let mut buffer = BytesMut::zeroed(object_size as usize);
        let stopwatch = Stopwatch::new(
            &self.metrics.s3_get_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );
        let get_part_futures = buffer
            .chunks_mut(self.download_part_size_bytes as usize)
            .zip(chunk_ranges)
            .map(|(byte_buffer, byte_range)| async move {
                let bytes = self
                    .store
                    .get_opts(
                        &key.into(),
                        object_store::GetOptions {
                            range: Some(GetRange::Bounded(byte_range)),
                            ..Default::default()
                        },
                    )
                    .await?
                    .bytes()
                    .await?;
                if bytes.len() != byte_buffer.len() {
                    return Err(StorageError::Message {
                        message: format!(
                            "Expected {} bytes in part, got {} bytes",
                            byte_buffer.len(),
                            bytes.len()
                        ),
                    });
                }
                byte_buffer.copy_from_slice(&bytes);
                Ok(())
            })
            .collect::<Vec<_>>();

        let chunk_count = get_part_futures.len();
        self.metrics.s3_get_count.add(chunk_count as u64, &[]);
        stream::iter(get_part_futures)
            .buffer_unordered(chunk_count)
            .try_collect::<Vec<_>>()
            .await?;

        drop(stopwatch);
        Ok((buffer.freeze(), etag))
    }

    async fn oneshot_get(&self, key: &str) -> Result<(Bytes, ETag), StorageError> {
        self.metrics.s3_get_count.add(1, &[]);
        let _stopwatch = Stopwatch::new(
            &self.metrics.s3_get_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );
        let result = self.store.get_opts(&key.into(), Default::default()).await?;
        let update_version = UpdateVersion {
            e_tag: result.meta.e_tag.clone(),
            version: result.meta.version.clone(),
        };
        let etag = (&update_version).try_into()?;
        Ok((result.bytes().await?, etag))
    }

    /// Fetches the object at `key`.
    ///
    /// With `options.request_parallelism` set, the object is downloaded as
    /// concurrent ranged requests; otherwise a single request is used.
    pub async fn get(&self, key: &str, options: GetOptions) -> Result<(Bytes, ETag), StorageError> {
        if !options.request_parallelism {
            return self.oneshot_get(key).await;
        }
        self.multipart_get(key).await
    }

    /// Uploads `bytes` to `key` as a multipart upload split into
    /// `upload_part_size_bytes` parts, all uploaded concurrently.
    ///
    /// Only `options.cmek` is honoured here; `options.mode` is not applied
    /// (see `is_oneshot_upload`, which routes conditional writes elsewhere).
    ///
    /// If any part upload or the final completion fails, the multipart upload
    /// is aborted on a best-effort basis so that already-uploaded parts are not
    /// left behind, and the failure is counted in `s3_put_error_count`.
    ///
    /// # Errors
    /// Returns the first store error encountered while creating the upload,
    /// uploading parts or completing it, or the ETag conversion error.
    async fn multipart_put(
        &self,
        key: &str,
        bytes: Bytes,
        options: PutOptions,
    ) -> Result<ETag, StorageError> {
        let mut put_options = object_store::PutMultipartOptions::default();

        // Apply customer managed encryption key
        if let Some(cmek) = options.cmek {
            put_options.extensions.insert(cmek);
        }

        let total_size_bytes = bytes.len() as u64;
        self.metrics.s3_put_count.add(1, &[]);
        self.metrics.s3_put_bytes.record(total_size_bytes, &[]);
        let stopwatch = Stopwatch::new(
            &self.metrics.s3_put_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );
        let chunk_ranges = Self::partition(total_size_bytes, self.upload_part_size_bytes)
            .map(|(start, end)| start as usize..end as usize);
        let mut upload_handle = match self
            .store
            .put_multipart_opts(&key.into(), put_options)
            .await
        {
            Ok(handle) => handle,
            Err(err) => {
                self.metrics.s3_put_error_count.add(1, &[]);
                return Err(err.into());
            }
        };
        let upload_parts = chunk_ranges
            .map(|range| {
                self.metrics
                    .s3_upload_part_bytes
                    .record(range.end.saturating_sub(range.start) as u64, &[]);
                upload_handle.put_part(bytes.slice(range).into())
            })
            .collect::<Vec<_>>();
        let chunk_count = upload_parts.len();

        let parts_outcome = stream::iter(upload_parts)
            .buffer_unordered(chunk_count)
            .try_collect::<Vec<_>>()
            .await;
        if let Err(err) = parts_outcome {
            // Best-effort cleanup: the part failure is the error worth
            // reporting, so a failure to abort is deliberately ignored.
            let _ = upload_handle.abort().await;
            self.metrics.s3_put_error_count.add(1, &[]);
            return Err(err.into());
        }

        self.metrics
            .s3_multipart_upload_parts
            .record(chunk_count as u64, &[]);

        let result = match upload_handle.complete().await {
            Ok(result) => result,
            Err(err) => {
                // Same best-effort cleanup as above; aborting an upload that
                // the server did manage to complete simply fails and is ignored.
                let _ = upload_handle.abort().await;
                self.metrics.s3_put_error_count.add(1, &[]);
                return Err(err.into());
            }
        };

        let update_version = UpdateVersion {
            e_tag: result.e_tag,
            version: result.version,
        };

        let res = (&update_version).try_into();
        if res.is_err() {
            self.metrics.s3_put_error_count.add(1, &[]);
        }
        let duration = stopwatch.finish();
        // Track how many bytes were involved in uploads slower than one second.
        if duration > Duration::from_secs(1) {
            self.metrics.s3_put_bytes_slow.record(total_size_bytes, &[]);
        }
        res
    }

    /// Uploads `bytes` to `key` with a single request.
    ///
    /// Honours both `options.cmek` (attached as a request extension) and
    /// `options.mode` (translated into the store's conditional put mode).
    ///
    /// A failed upload is counted in `s3_put_error_count`, in addition to a
    /// failed ETag conversion.
    ///
    /// # Errors
    /// Fails if the `IfMatch` ETag cannot be decoded, if the store rejects or
    /// fails the request (including unmet preconditions), or if the resulting
    /// version cannot be turned into an `ETag`.
    pub async fn oneshot_put(
        &self,
        key: &str,
        bytes: Bytes,
        options: PutOptions,
    ) -> Result<ETag, StorageError> {
        let mut put_options = object_store::PutOptions::default();

        // Apply customer managed encryption key
        if let Some(cmek) = options.cmek {
            put_options.extensions.insert(cmek);
        }

        // Apply conditional operations
        put_options.mode = options.mode.try_into()?;

        let total_size_bytes = bytes.len() as u64;
        self.metrics.s3_put_count.add(1, &[]);
        self.metrics.s3_put_bytes.record(total_size_bytes, &[]);
        let stopwatch = Stopwatch::new(
            &self.metrics.s3_put_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );

        let result = match self
            .store
            .put_opts(&key.into(), bytes.into(), put_options)
            .await
        {
            Ok(result) => result,
            Err(err) => {
                // Count the failed upload before propagating it; previously
                // this path skipped the error counter altogether.
                self.metrics.s3_put_error_count.add(1, &[]);
                return Err(err.into());
            }
        };

        let update_version = UpdateVersion {
            e_tag: result.e_tag,
            version: result.version,
        };

        let res = (&update_version).try_into();
        if res.is_err() {
            self.metrics.s3_put_error_count.add(1, &[]);
        }
        let duration = stopwatch.finish();
        // Track how many bytes were involved in uploads slower than one second.
        if duration > Duration::from_secs(1) {
            self.metrics.s3_put_bytes_slow.record(total_size_bytes, &[]);
        }
        res
    }

    /// Decides whether an upload of `total_size_bytes` should use a single
    /// request instead of a multipart upload.
    ///
    /// Returns `true` when the payload fits in one part or when the write is
    /// conditional (`IfMatch` / `IfNotExist`).
    pub fn is_oneshot_upload(&self, total_size_bytes: u64, options: &PutOptions) -> bool {
        // NOTE(sicheng): GCS has no support for conditional multipart upload
        // https://docs.cloud.google.com/storage/docs/multipart-uploads
        let is_conditional = matches!(
            options.mode,
            crate::PutMode::IfMatch(_) | crate::PutMode::IfNotExist
        );
        let fits_in_one_part = total_size_bytes <= self.upload_part_size_bytes;
        fits_in_one_part || is_conditional
    }

    /// Uploads `bytes` to `key`, choosing between a single-request upload and
    /// a multipart upload according to `is_oneshot_upload`.
    pub async fn put(
        &self,
        key: &str,
        bytes: Bytes,
        options: PutOptions,
    ) -> Result<ETag, StorageError> {
        let total_size_bytes = bytes.len() as u64;
        if !self.is_oneshot_upload(total_size_bytes, &options) {
            return self.multipart_put(key, bytes, options).await;
        }
        self.oneshot_put(key, bytes, options).await
    }

    // TODO(sicheng): This was used for hnsw files on disk and should be cleaned up
    // because we directly load hnsw to memory now
    pub async fn put_file(
        &self,
        key: &str,
        path: &str,
        options: PutOptions,
    ) -> Result<ETag, StorageError> {
        let bytes = tokio::fs::read(path)
            .await
            .map_err(|e| StorageError::Generic {
                source: Arc::new(e),
            })?;
        self.put(key, bytes.into(), options).await
    }

    pub async fn delete(&self, key: &str) -> Result<(), StorageError> {
        self.metrics.s3_delete_count.add(1, &[]);
        self.store.delete(&key.into()).await?;
        Ok(())
    }

    /// Deletes every key in `keys`, running up to 32 deletions concurrently.
    ///
    /// This call itself always returns `Ok`: keys that were removed are
    /// listed in `deleted`, and the error of each failed deletion is pushed
    /// to `errors`.
    pub async fn delete_many<S: AsRef<str> + std::fmt::Debug, I: IntoIterator<Item = S>>(
        &self,
        keys: I,
    ) -> Result<DeletedObjects, StorageError> {
        let pending: Vec<S> = keys.into_iter().collect();

        // Execute deletes in parallel, remembering which key each outcome belongs to
        let outcomes = stream::iter(pending)
            .map(|key| async move {
                let name = key.as_ref().to_string();
                let outcome = self.delete(key.as_ref()).await;
                (name, outcome)
            })
            .buffer_unordered(32)
            .collect::<Vec<_>>()
            .await;

        self.metrics
            .s3_delete_many_count
            .add(outcomes.len() as u64, &[]);

        let mut summary = DeletedObjects::default();
        for (name, outcome) in outcomes {
            if let Err(err) = outcome {
                summary.errors.push(err);
            } else {
                summary.deleted.push(name);
            }
        }

        Ok(summary)
    }

    pub async fn rename(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> {
        self.metrics.s3_rename_count.add(1, &[]);
        let _stopwatch = Stopwatch::new(
            &self.metrics.s3_rename_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );

        self.store.rename(&src_key.into(), &dst_key.into()).await?;
        Ok(())
    }

    pub async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> {
        self.metrics.s3_copy_count.add(1, &[]);
        let _stopwatch = Stopwatch::new(
            &self.metrics.s3_copy_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );
        self.store.copy(&src_key.into(), &dst_key.into()).await?;
        Ok(())
    }

    pub async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
        let prefix_path = if prefix.is_empty() {
            None
        } else {
            Some(prefix.into())
        };

        self.metrics.s3_list_count.add(1, &[]);
        let _stopwatch = Stopwatch::new(
            &self.metrics.s3_list_latency_ms,
            &[],
            chroma_tracing::util::StopWatchUnit::Millis,
        );

        let list_stream = self.store.list(prefix_path.as_ref());

        let keys = list_stream
            .map_ok(|meta| meta.location.to_string())
            .try_collect()
            .await?;

        Ok(keys)
    }
}

/// Builds an [`ObjectStorage`] from the generic storage configuration.
#[async_trait]
impl Configurable<StorageConfig> for ObjectStorage {
    /// Accepts only `StorageConfig::Object`; any other variant is rejected
    /// with `StorageConfigError::InvalidStorageConfig`. The registry is unused.
    async fn try_from_config(
        config: &StorageConfig,
        _registry: &Registry,
    ) -> Result<Self, Box<dyn ChromaError>> {
        if let StorageConfig::Object(object_config) = config {
            ObjectStorage::new(object_config).await
        } else {
            Err(Box::new(StorageConfigError::InvalidStorageConfig))
        }
    }
}
