package com.aliyun.openservices.iot.api.http2.connection.impl;

import com.aliyun.openservices.iot.api.exception.IotClientException;
import com.aliyun.openservices.iot.api.http2.callback.Http2StreamListener;
import com.aliyun.openservices.iot.api.http2.connection.Connection;
import com.aliyun.openservices.iot.api.http2.connection.ConnectionListener;
import com.aliyun.openservices.iot.api.http2.connection.ConnectionStatus;
import com.aliyun.openservices.iot.api.http2.connection.StreamWriteOperation;
import com.aliyun.openservices.iot.api.http2.netty.NettyHttp2Handler;
import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.*;
import io.netty.handler.codec.http2.Http2Connection.PropertyKey;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

/**
 * wrapper for NettyHttp2Handler
 *
 * @author brhao
 * @date 23/03/2018
 */
@Slf4j
public class ConnectionImpl implements Connection {
    private final PropertyKey STREAM_LISTENER_KEY;

    private Http2Connection http2Connection;
    @Getter
    private ChannelHandlerContext ctx;
    private Http2ConnectionDecoder decoder;
    private Http2ConnectionEncoder encoder;
    private ConnectionListener connectionListener;
    private ConnectionStatus status;
    private Map<String, PropertyKey> propertyKeyMap;

    public ConnectionImpl(NettyHttp2Handler nettyHttp2Handler, ChannelHandlerContext ctx) {
        this.http2Connection = nettyHttp2Handler.connection();
        this.decoder = nettyHttp2Handler.decoder();
        this.encoder = nettyHttp2Handler.encoder();
        this.ctx = ctx;
        STREAM_LISTENER_KEY = http2Connection.newKey();
        propertyKeyMap = Maps.newConcurrentMap();
    }

    private void setStreamListener(Http2Stream stream, Http2StreamListener http2StreamListener) {
        log.debug("set stream listener for streamId:{}", stream.id());
        stream.setProperty(STREAM_LISTENER_KEY, http2StreamListener);
    }

    private Http2Stream stream(int id) {
        return http2Connection.stream(id);
    }

    @Override
    public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) {
        Http2Stream stream = stream(streamId);
        int readableSize = data.readableBytes();
        byte[] bytes;
        if (padding == 0 && data.hasArray()) {
            bytes = data.array();
        } else {
            int dataSize = readableSize - padding;
            byte[] temp = new byte[readableSize];
            data.readBytes(temp, 0, readableSize);
            if (padding == 0) {
                bytes = temp;
            } else {
                bytes = Arrays.copyOf(temp, dataSize);
            }
        }

        streamCallbackApply(stream, streamService -> {
            streamService.onDataRead(this, stream, bytes, endOfStream);
        });

        return readableSize;
    }

    @Override
    public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
                              short weight, boolean exclusive, int padding, boolean endOfStream) {
        Http2Stream stream = stream(streamId);

        if (!streamListener(stream).isPresent()) {
            defaultStreamListener().ifPresent(l -> setStreamListener(stream, l));
        }

        boolean r = streamCallbackApply(stream, streamService ->
                streamService.onHeadersRead(this, stream, headers, endOfStream));

        if (!r) {
            writeGoAway(streamId, 2, ("no handler for stream " + streamId).getBytes());
        }
    }

    private Optional<Http2StreamListener> defaultStreamListener() {
        return streamListener(http2Connection.connectionStream());
    }

    @Override
    public void setConnectionListener(ConnectionListener listener) {
        this.connectionListener = listener;
        if (listener != null) {
            listener.onStatusChange(getStatus(), this);
        }
    }

    @Override
    public void removeConnectListener() {
        setConnectionListener(null);
    }

    @Override
    public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) {
        if (connectionListener != null) {
            connectionListener.onSettingReceive(this, settings);
        }
    }

    @Override
    public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) {

    }

    @Override
    public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) {
        streamCallbackApply(stream(streamId), l -> l.onStreamError(this, stream(streamId),
                new IOException("rst frame received, code : " + errorCode)));
    }

    @Override
    public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
                               ByteBuf payload) {
        streamCallbackApply(stream(streamId), l -> l.onStreamError(this, stream(streamId),
                new IOException("unknown frame received, hex dump: " + ByteBufUtil.hexDump(payload))));
    }

    @Override
    public void onConnectionClosed() {
        setStatus(ConnectionStatus.CLOSED);
    }

    @Override
    public void setStatus(ConnectionStatus status) {
        if (connectionListener != null) {
            connectionListener.onStatusChange(status, this);
        }
        this.status = status;
    }

    @Override
    public ConnectionStatus getStatus() {
        return this.status;
    }

    @Override
    public boolean isAuthorized() {
        return this.status.equals(ConnectionStatus.AUTHORIZED);
    }

    @Override
    public String toString() {
        return ctx.channel().id().asShortText();
    }

    private Optional<Http2StreamListener> streamListener(Http2Stream stream) {
        return Optional.ofNullable(stream.getProperty(STREAM_LISTENER_KEY));
    }

    @Override
    public void onError(ChannelHandlerContext ctx, boolean outbound, Throwable cause) {
        try {
            http2Connection.forEachActiveStream(stream -> {
                streamCallbackApply(stream, l -> l.onStreamError(
                        ConnectionImpl.this, stream, new IOException(cause)));
                return true;
            });
        } catch (Http2Exception e) {
            log.error("error occurs when notify listener. exception: ", e);
        }
    }

    @Override
    public PropertyKey getPropertyKey(String keyName) {
        return propertyKeyMap.computeIfAbsent(keyName, name -> http2Connection.newKey());
    }

    @Override
    public void setProperty(PropertyKey key, Object object) {
        http2Connection.connectionStream().setProperty(key, object);
    }

    @Override
    public Object getProperty(PropertyKey key) {
        return http2Connection.connectionStream().getProperty(key);
    }

    @Override
    public void setDefaultStreamListener(Http2StreamListener http2StreamListener) {
        setStreamListener(http2Connection.connectionStream(), http2StreamListener);
    }

    @Override
    public void close() {
        ctx.close().syncUninterruptibly();
    }

    private boolean streamCallbackApply(Http2Stream stream, Consumer<Http2StreamListener> f) {
        Optional<Http2StreamListener> optional = streamListener(stream);

        if (!optional.isPresent()) {
            return false;
        }

        Http2StreamListener streamService = optional.get();
        f.accept(streamService);
        return true;
    }

    @Override
    public CompletableFuture<StreamWriteOperation> writeHeaders(Http2Headers headers, boolean endStream,
                                                                Http2StreamListener http2StreamListener) {
        return doInEventLoop((cf, channelPromise) -> {
            log.debug("write headers {}", headers);
            int streamId = http2Connection.local().incrementAndGetNextStreamId();
            Http2Stream stream = http2Connection.stream(streamId);
            if (stream == null) {
                try {
                    stream = http2Connection.local().createStream(streamId, false);
                } catch (Http2Exception e) {
                    throw new IotClientException(e);
                }
            }
            if (http2StreamListener != null) {
                setStreamListener(stream, http2StreamListener);
            }
            cf.setResult(new StreamWriteOperation(stream, this));
            encoder.writeHeaders(ctx, streamId, headers, 0, endStream, channelPromise);
            ctx.pipeline().flush();
        });
    }

    @Override
    public CompletableFuture<StreamWriteOperation> writeData(int streamId, byte[] data, boolean endStream) {
        return doInEventLoop((cf, channelPromise) -> {
            log.info("write data on connection {}, stream id: {}, size : {}",
                    ctx.channel().id(), streamId, data.length);
            cf.setResult(new StreamWriteOperation(http2Connection.stream(streamId), this));
            encoder.writeData(ctx, streamId, Unpooled.wrappedBuffer(data), 0, endStream, channelPromise);
            ctx.pipeline().flush();
        });
    }

    @Override
    public CompletableFuture<Connection> writeRst(int streamId, int errorCode) {
        return doInEventLoop((cf, channelPromise) -> {
            log.info("write data on connection {}, stream id: {}, error code: {}",
                    ctx.channel().id(), streamId, errorCode);
            cf.setResult(this);
            encoder.writeRstStream(ctx, streamId, errorCode, channelPromise);
            ctx.pipeline().flush();
        });
    }

    @Override
    public CompletableFuture<Connection> writeGoAway(int lastStreamId, int errorCode, byte[] debugData) {
        return doInEventLoop((cf, channelPromise) -> {
            log.info("write goaway on connection {}, stream id: {}, size : {}",
                    ctx.channel().id(), lastStreamId, debugData.length);
            cf.setResult(this);
            encoder.writeGoAway(ctx, lastStreamId, errorCode, Unpooled.wrappedBuffer(debugData), channelPromise);
            ctx.pipeline().flush();
        });
    }

    private <R> CompletableFuture<R> doInEventLoop(BiConsumer<CompletableFutureBridge<R>, ChannelPromise> consumer) {
        CompletableFutureBridge<R> cf = new CompletableFutureBridge<>();

        ChannelPromise promise = ctx.newPromise();

        promise.addListener(future -> {
            if (future.isSuccess()) {
                cf.complete();
            } else {
                cf.completeExceptionally(future.cause());
            }
        });

        if (ctx.channel().eventLoop().inEventLoop()) {
            consumer.accept(cf, promise);
            return cf;
        }

        CompletableFuture.runAsync(() -> consumer.accept(cf, promise), ctx.channel().eventLoop())
                .whenComplete((f, t) -> {
                    if (t != null) {
                        cf.completeExceptionally(t);
                    }
                });

        return cf;
    }

    class CompletableFutureBridge<T> extends CompletableFuture<T> {
        @Setter
        private T result;

        void complete() {
            this.complete(result);
        }
    }
}
