/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.wlm;

import org.opensearch.common.util.concurrent.ThreadContextStatePropagator;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Propagates workload-group related thread-context entries.
 * <p>
 * {@link #transients(Map)} selects the entries to carry over into a new request thread context.
 * {@link #headers(Map)} selects the entries, as string headers, to be propagated across nodes.
 * Only the keys listed in {@link #PROPAGATED_HEADERS} are considered.
 */
public class WorkloadGroupThreadContextStatePropagator implements ThreadContextStatePropagator {

    /**
     * Names of the thread-context entries this propagator copies. Currently only
     * {@link WorkloadGroupTask#WORKLOAD_GROUP_ID_HEADER}.
     * <p>
     * The field is {@code final} so the reference cannot be reassigned at runtime.
     * The list itself is unmodifiable because it is created with {@link List#of}.
     */
    public static final List<String> PROPAGATED_HEADERS = List.of(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER);

    /**
     * Selects the workload-group entries from the current transient headers that should be carried over
     * to a new request thread context.
     *
     * @param source current context transient headers
     * @return a new mutable map holding every {@link #PROPAGATED_HEADERS} key whose value in {@code source}
     *         is non-null; absent or null-valued keys are omitted
     */
    @Override
    // NOTE(review): presumably suppresses a warning that the overridden interface method is deprecated for removal — confirm
    @SuppressWarnings("removal")
    public Map<String, Object> transients(Map<String, Object> source) {
        final Map<String, Object> transientHeaders = new HashMap<>();
        for (String headerName : PROPAGATED_HEADERS) {
            final Object value = source.get(headerName);
            // Only copy entries that are actually present so the result never contains null values.
            if (value != null) {
                transientHeaders.put(headerName, value);
            }
        }
        return transientHeaders;
    }

    /**
     * Selects the workload-group entries from the current context that should be propagated across nodes
     * as string headers.
     *
     * @param source current context headers
     * @return a new mutable map holding every {@link #PROPAGATED_HEADERS} key whose value in {@code source}
     *         is non-null, cast to {@link String}; absent or null-valued keys are omitted
     * @throws ClassCastException if a propagated entry in {@code source} holds a non-String value
     */
    @Override
    // NOTE(review): presumably suppresses a warning that the overridden interface method is deprecated for removal — confirm
    @SuppressWarnings("removal")
    public Map<String, String> headers(Map<String, Object> source) {
        final Map<String, String> propagatedHeaders = new HashMap<>();
        for (String headerName : PROPAGATED_HEADERS) {
            // NOTE(review): assumes the stored value is a String; any other type fails fast here.
            final String value = (String) source.get(headerName);
            if (value != null) {
                propagatedHeaders.put(headerName, value);
            }
        }
        return propagatedHeaders;
    }
}
