/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal.filter;

import io.kroxylicious.proxy.filter.DescribeClusterResponseFilter;
import io.kroxylicious.proxy.filter.FetchResponseFilter;
import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.FindCoordinatorResponseFilter;
import io.kroxylicious.proxy.filter.MetadataResponseFilter;
import io.kroxylicious.proxy.filter.ProduceResponseFilter;
import io.kroxylicious.proxy.filter.ResponseFilterResult;
import io.kroxylicious.proxy.internal.net.EndpointReconciler;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.service.HostPort;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.ObjIntConsumer;
import java.util.function.ToIntFunction;
import org.apache.kafka.common.message.DescribeClusterResponseData;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.FindCoordinatorResponseData;
import org.apache.kafka.common.message.MetadataResponseData;
import org.apache.kafka.common.message.ProduceResponseData;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.ApiMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Response filter that rewrites the broker addresses carried in Kafka responses.
 * <p>
 * Each upstream broker host/port is replaced with the address that
 * {@link VirtualCluster#getBrokerAddress} returns for the same node id. For Metadata and
 * DescribeCluster responses, the original (upstream) addresses are also passed to the
 * {@link EndpointReconciler}. The rewritten response is forwarded only after that
 * reconciliation completes.
 */
public class BrokerAddressFilter
        implements MetadataResponseFilter,
        FindCoordinatorResponseFilter,
        DescribeClusterResponseFilter,
        ProduceResponseFilter,
        FetchResponseFilter {
    private static final Logger LOGGER = LoggerFactory.getLogger(BrokerAddressFilter.class);
    private final VirtualCluster virtualCluster;
    private final EndpointReconciler reconciler;

    /**
     * Creates a filter bound to one virtual cluster.
     *
     * @param virtualCluster source of the downstream broker addresses written into responses
     * @param reconciler     receives the upstream node-id to address mapping seen in
     *                       Metadata and DescribeCluster responses
     */
    public BrokerAddressFilter(VirtualCluster virtualCluster, EndpointReconciler reconciler) {
        this.virtualCluster = virtualCluster;
        this.reconciler = reconciler;
    }

    /**
     * Rewrites every broker address in a Metadata response. The upstream addresses are
     * reconciled first, and the response is forwarded once that finishes.
     */
    public CompletionStage<ResponseFilterResult> onMetadataResponse(short apiVersion, ResponseHeaderData header, MetadataResponseData data, FilterContext context) {
        Map<Integer, HostPort> nodeMap = new HashMap<>();
        for (MetadataResponseData.MetadataResponseBroker broker : data.brokers()) {
            // Record the upstream address before apply() overwrites it in place.
            nodeMap.put(broker.nodeId(), new HostPort(broker.host(), broker.port()));
            apply(context, broker,
                    MetadataResponseData.MetadataResponseBroker::nodeId,
                    MetadataResponseData.MetadataResponseBroker::host,
                    MetadataResponseData.MetadataResponseBroker::port,
                    MetadataResponseData.MetadataResponseBroker::setHost,
                    MetadataResponseData.MetadataResponseBroker::setPort);
        }
        return doReconcileThenForwardResponse(header, data, context, nodeMap);
    }

    /**
     * Rewrites every broker address in a DescribeCluster response. The upstream addresses
     * are reconciled first, and the response is forwarded once that finishes.
     */
    public CompletionStage<ResponseFilterResult> onDescribeClusterResponse(short apiVersion, ResponseHeaderData header, DescribeClusterResponseData data, FilterContext context) {
        Map<Integer, HostPort> nodeMap = new HashMap<>();
        for (DescribeClusterResponseData.DescribeClusterBroker broker : data.brokers()) {
            // Record the upstream address before apply() overwrites it in place.
            nodeMap.put(broker.brokerId(), new HostPort(broker.host(), broker.port()));
            apply(context, broker,
                    DescribeClusterResponseData.DescribeClusterBroker::brokerId,
                    DescribeClusterResponseData.DescribeClusterBroker::host,
                    DescribeClusterResponseData.DescribeClusterBroker::port,
                    DescribeClusterResponseData.DescribeClusterBroker::setHost,
                    DescribeClusterResponseData.DescribeClusterBroker::setPort);
        }
        return doReconcileThenForwardResponse(header, data, context, nodeMap);
    }

    /**
     * Rewrites coordinator addresses in a FindCoordinator response, then forwards it.
     * <p>
     * Entries in the coordinators list with a negative node id are left untouched. The
     * top-level node id/host/port fields are rewritten only when they look populated:
     * non-negative node id, non-empty host, and positive port.
     */
    public CompletionStage<ResponseFilterResult> onFindCoordinatorResponse(short apiVersion, ResponseHeaderData header, FindCoordinatorResponseData data, FilterContext context) {
        for (FindCoordinatorResponseData.Coordinator coordinator : data.coordinators()) {
            if (coordinator.nodeId() < 0) {
                // A negative node id does not name a broker, so there is no address to map.
                continue;
            }
            apply(context, coordinator,
                    FindCoordinatorResponseData.Coordinator::nodeId,
                    FindCoordinatorResponseData.Coordinator::host,
                    FindCoordinatorResponseData.Coordinator::port,
                    FindCoordinatorResponseData.Coordinator::setHost,
                    FindCoordinatorResponseData.Coordinator::setPort);
        }
        boolean topLevelPopulated = data.nodeId() >= 0
                && data.host() != null
                && !data.host().isEmpty()
                && data.port() > 0;
        if (topLevelPopulated) {
            apply(context, data,
                    FindCoordinatorResponseData::nodeId,
                    FindCoordinatorResponseData::host,
                    FindCoordinatorResponseData::port,
                    FindCoordinatorResponseData::setHost,
                    FindCoordinatorResponseData::setPort);
        }
        return context.forwardResponse(header, data);
    }

    /**
     * Only Produce responses at API version 10 or later are handled. Presumably these are
     * the versions that carry {@code nodeEndpoints}; confirm against the Kafka message spec.
     */
    public boolean shouldHandleProduceResponse(short apiVersion) {
        return apiVersion >= 10;
    }

    /**
     * Rewrites any node endpoints in a Produce response, then forwards it.
     */
    public CompletionStage<ResponseFilterResult> onProduceResponse(short apiVersion, ResponseHeaderData header, ProduceResponseData response, FilterContext context) {
        if (response.nodeEndpoints() != null) {
            for (ProduceResponseData.NodeEndpoint endpoint : response.nodeEndpoints()) {
                apply(context, endpoint,
                        ProduceResponseData.NodeEndpoint::nodeId,
                        ProduceResponseData.NodeEndpoint::host,
                        ProduceResponseData.NodeEndpoint::port,
                        ProduceResponseData.NodeEndpoint::setHost,
                        ProduceResponseData.NodeEndpoint::setPort);
            }
        }
        return context.forwardResponse(header, response);
    }

    /**
     * Only Fetch responses at API version 16 or later are handled. Presumably these are
     * the versions that carry {@code nodeEndpoints}; confirm against the Kafka message spec.
     */
    public boolean shouldHandleFetchResponse(short apiVersion) {
        return apiVersion >= 16;
    }

    /**
     * Rewrites any node endpoints in a Fetch response, then forwards it.
     */
    public CompletionStage<ResponseFilterResult> onFetchResponse(short apiVersion, ResponseHeaderData header, FetchResponseData response, FilterContext context) {
        if (response.nodeEndpoints() != null) {
            for (FetchResponseData.NodeEndpoint endpoint : response.nodeEndpoints()) {
                apply(context, endpoint,
                        FetchResponseData.NodeEndpoint::nodeId,
                        FetchResponseData.NodeEndpoint::host,
                        FetchResponseData.NodeEndpoint::port,
                        FetchResponseData.NodeEndpoint::setHost,
                        FetchResponseData.NodeEndpoint::setPort);
            }
        }
        return context.forwardResponse(header, response);
    }

    /**
     * Overwrites the host and port of {@code broker}, in place, with the downstream
     * address the virtual cluster assigns to the broker's node id.
     *
     * @param context      filter context, used only for trace logging
     * @param broker       the response element to rewrite
     * @param nodeIdGetter extracts the node id used to look up the downstream address
     * @param hostGetter   reads the current (upstream) host, for logging
     * @param portGetter   reads the current (upstream) port, for logging
     * @param hostSetter   writes the downstream host
     * @param portSetter   writes the downstream port
     * @param <T>          the generated Kafka message type holding a broker address
     */
    private <T> void apply(FilterContext context, T broker, ToIntFunction<T> nodeIdGetter, Function<T, String> hostGetter, ToIntFunction<T> portGetter, BiConsumer<T, String> hostSetter, ObjIntConsumer<T> portSetter) {
        HostPort downstreamAddress = virtualCluster.getBrokerAddress(nodeIdGetter.applyAsInt(broker));
        if (LOGGER.isTraceEnabled()) {
            // Read the incoming address here, before the setters below overwrite it.
            LOGGER.trace("{}: Rewriting broker address in response {}:{} -> {}",
                    context, hostGetter.apply(broker), portGetter.applyAsInt(broker), downstreamAddress);
        }
        hostSetter.accept(broker, downstreamAddress.host());
        portSetter.accept(broker, downstreamAddress.port());
    }

    /**
     * Hands {@code nodeMap} to the endpoint reconciler and, once reconciliation completes,
     * forwards the (already rewritten) response.
     *
     * @param header  response header to forward
     * @param data    response body to forward
     * @param context filter context used to build the forward result
     * @param nodeMap upstream node id to upstream address, as seen in the response
     * @return a stage completing with the forward result after reconciliation finishes
     */
    private CompletionStage<ResponseFilterResult> doReconcileThenForwardResponse(ResponseHeaderData header, ApiMessage data, FilterContext context, Map<Integer, HostPort> nodeMap) {
        // Compose directly on the CompletionStage. Converting with toCompletableFuture() is
        // unnecessary and may be unsupported by some stage implementations.
        return reconciler.reconcile(virtualCluster, nodeMap).thenCompose(ignored -> {
            LOGGER.debug("Endpoint reconciliation complete for virtual cluster {}", virtualCluster);
            return context.responseFilterResultBuilder().forward(header, data).completed();
        });
    }
}

