package cn.iinti.majora3.sdk.client;

import cn.iinti.majora3.sdk.codec.PackageType;
import cn.iinti.majora3.sdk.proto.IProto;
import cn.iinti.majora3.sdk.proto.SessionInit;
import cn.iinti.majora3.sdk.proto.SessionPayload;
import cn.iinti.majora3.sdk.proto.SessionSignal;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import lombok.Getter;

import java.util.HashMap;
import java.util.Map;

/**
 * Manages duplex tunnel sessions, mainly TCP traffic proxying, PTY control and the like.
 *
 * <p>Each session pairs a session id (taken from the {@link SessionInit} message) with an
 * upstream {@link Channel} created by a registered {@link SessionFactory}. Mutations of the
 * session table are submitted through {@code MajoraClient#doOnMainThead}.
 */
public class SessionManager {

    /** Registered session factories, keyed by the package type that opens the session. */
    private static final Map<PackageType, SessionFactory<?>> factories = new HashMap<>();


    /**
     * Registers the factory used to open sessions for the given package type.
     *
     * @param packageType package type whose peer class must be a {@link SessionInit} subtype
     * @param factory     factory that creates the upstream channel for that session type
     * @throws IllegalArgumentException if the peer class of {@code packageType} is not
     *                                  assignable to {@link SessionInit}
     */
    public static void registerSessionFactory(PackageType packageType, SessionFactory<?> factory) {
        Class<? extends IProto> peerClass = packageType.getByteBufCreator().getPeerClass();
        if (!SessionInit.class.isAssignableFrom(peerClass)) {
            throw new IllegalArgumentException("Invalid peer class: " + peerClass.getName());
        }
        factories.put(packageType, factory);
    }

    static {
        // built-in session type: plain TCP proxy sessions
        registerSessionFactory(PackageType.SESSION_INIT_TCP, new TCPSessionFactory());
    }

    /**
     * Tells whether a session factory has been registered for the given package type.
     *
     * @param packageType package type to look up
     * @return {@code true} if a factory is registered for {@code packageType}
     */
    public static boolean supportSession(PackageType packageType) {
        return factories.containsKey(packageType);
    }

    /**
     * Creates the upstream channel backing one session.
     *
     * @param <T> concrete {@link SessionInit} type handled by this factory
     */
    public interface SessionFactory<T extends SessionInit> {
        /**
         * Starts creating the upstream channel for a session.
         *
         * @param client         the owning client
         * @param sessionInit    the session-init message describing the session
         * @param channelHandler handler that receives data read from the upstream channel
         * @return future completed when the upstream channel is ready, or failed
         */
        ChannelFuture createSession(MajoraClient client, T sessionInit, SimpleChannelInboundHandler<ByteBuf> channelHandler);
    }

    /**
     * Creates a session manager bound to one client and one connection.
     *
     * @param majoraClient client used for logging and for scheduling work via {@code doOnMainThead}
     * @param mConnection  connection that session payloads and signals are written to
     */
    public SessionManager(MajoraClient majoraClient, MajoraConnection mConnection) {
        this.majoraClient = majoraClient;
        this.mConnection = mConnection;
    }

    private final MajoraClient majoraClient;
    private final MajoraConnection mConnection;

    /** Live sessions keyed by session id. */
    private final Map<Integer, SessionHolder> sessions = new HashMap<>();


    /**
     * Binds a session id to its upstream channel and forwards upstream events to the
     * enclosing manager's connection.
     */
    public class SessionHolder {
        @Getter
        protected final int sessionId;
        protected final Channel channel;

        protected SessionHolder(int sessionId, Channel channel) {
            this.sessionId = sessionId;
            this.channel = channel;
        }

        /**
         * Wraps data read from the upstream channel into a {@link SessionPayload} and writes it
         * to the connection.
         *
         * @param upstreamPayload data read from the upstream channel; the reference is handed
         *                        over to the created payload
         */
        void onUpstreamSessionData(ByteBuf upstreamPayload) {
            SessionPayload sessionPayload = new SessionPayload(upstreamPayload, sessionId);
            mConnection.writeToMajora(sessionPayload);
        }

        /**
         * Writes a session-close signal for this session to the connection.
         *
         * @param reason human-readable close reason
         */
        void onUpstreamSessionClose(String reason) {
            SessionSignal.SessionClose sessionClose = new SessionSignal.SessionClose(sessionId, reason);
            mConnection.writeToMajora(sessionClose);
        }
    }

    /**
     * Reports a session that could not be created by writing a session-close signal.
     *
     * @param sessionId id of the session that failed
     * @param reason    human-readable failure reason
     */
    private void writeSessionCreateFailed(int sessionId, String reason) {
        SessionSignal.SessionClose sessionClose = new SessionSignal.SessionClose(sessionId, reason);
        mConnection.writeToMajora(sessionClose);
    }

    /**
     * Handles a close request for a session coming from the connection side: unregisters the
     * session and closes its upstream channel. Unknown session ids are ignored.
     *
     * <p>The holder is removed before the channel is closed, so the channel's close listener
     * finds no entry and does not write a close signal back.
     *
     * @param sessionId id of the session to close
     */
    public void onInboundStreamClose(int sessionId) {
        majoraClient.doOnMainThead(() -> {
            SessionHolder sessionHolder = sessions.remove(sessionId);
            if (sessionHolder == null) {
                return;
            }
            sessionHolder.channel.close();
        });
    }

    /**
     * Forwards a payload received from the connection to the upstream channel of its session.
     * Payloads for unknown sessions are logged and dropped.
     *
     * <p>NOTE(review): unlike the other entry points this reads {@code sessions} directly rather
     * than through {@code doOnMainThead}; presumably the caller is already on the main thread —
     * TODO confirm.
     *
     * @param sessionPayload payload to forward; its content is retained before being written, so
     *                       the caller keeps its own reference
     */
    public void onInboundStreamData(SessionPayload sessionPayload) {
        SessionHolder sessionHolder = sessions.get(sessionPayload.getSessionId());
        if (sessionHolder == null) {
            majoraClient.getLogger().log(() -> "Got none existing SessionPayload from: " + sessionPayload.getSessionId());
            return;
        }
        sessionHolder.channel.writeAndFlush(sessionPayload.content().retain());
    }

    /**
     * Closes every upstream channel and empties the session table.
     */
    public void destroy() {
        majoraClient.doOnMainThead(() -> {
            // Snapshot and clear BEFORE closing: each channel's close listener removes its own
            // entry from `sessions`, and if that listener ever runs synchronously inside close()
            // it would modify the map while it is being iterated (ConcurrentModificationException).
            SessionHolder[] snapshot = sessions.values().toArray(new SessionHolder[0]);
            sessions.clear();
            for (SessionHolder holder : snapshot) {
                holder.channel.close();
            }
        });
    }

    /**
     * Creates a session for the given init message.
     *
     * <p>If no factory matches, a session-close signal is written immediately. Otherwise the
     * factory is asked to create the upstream channel; on success the session is registered, a
     * session-ready signal is written, and a close listener is installed that unregisters the
     * session and writes a session-close signal. On failure a session-close signal carrying the
     * cause is written.
     *
     * @param client      client passed to the factory and used for logging inside the handler
     * @param sessionInit init message describing the session to create
     * @param <T>         concrete {@link SessionInit} type
     */
    public <T extends SessionInit> void createSession(MajoraClient client, T sessionInit) {
        int sessionId = sessionInit.getSessionId();
        SessionFactory<T> sessionFactory = findSessionFactory(sessionInit);
        if (sessionFactory == null) {
            writeSessionCreateFailed(sessionId, "can not handle this session because there is no session factory registered:"
                    + sessionInit.getClass()
            );
            return;
        }
        ChannelFuture channelFuture = sessionFactory.createSession(client, sessionInit, new SimpleChannelInboundHandler<ByteBuf>() {
            @Override
            protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
                // NOTE(review): this reads `sessions` from the upstream channel's handler thread;
                // data arriving before the session is registered below is treated as orphaned and
                // closes the channel — confirm the factory cannot deliver reads that early.
                SessionHolder sessionHolder = sessions.get(sessionId);
                if (sessionHolder == null) {
                    client.getLogger().log(() -> "no session holder registered for session " + sessionId
                            + ", closing upstream channel: " + sessionInit.getClass().getName());
                    ctx.close();
                    return;
                }
                // SimpleChannelInboundHandler releases msg after channelRead0 returns, so retain
                // it for the payload that is handed to the connection.
                sessionHolder.onUpstreamSessionData(msg.retain());
            }
        });
        channelFuture.addListener((ChannelFutureListener) future -> {
            if (!future.isSuccess()) {
                majoraClient.getLogger().log(() -> "create session failed:" + sessionInit.desc(), future.cause());
                writeSessionCreateFailed(sessionId, "create session failed:" + future.cause());
                return;
            }
            majoraClient.getLogger().log(() -> "session create:" + sessionInit.desc());
            majoraClient.doOnMainThead(() -> {
                Channel channel = future.channel();
                // register session
                sessions.put(sessionId, new SessionHolder(sessionId, channel));

                // notify session create
                mConnection.writeToMajora(new SessionSignal.SessionReady(sessionId));

                // monitor session close to unregister; a holder that was already removed
                // (inbound close or destroy) produces no close signal
                channel.closeFuture().addListener((ChannelFutureListener) future1 -> majoraClient.doOnMainThead(() -> {
                    majoraClient.getLogger().log(() -> "session closed:" + sessionInit.desc());
                    SessionHolder sessionHolder = sessions.remove(sessionId);
                    if (sessionHolder != null) {
                        sessionHolder.onUpstreamSessionClose("upstream remote session closed");
                    }
                }));
            });
        });
    }

    /**
     * Finds the registered factory whose package-type peer class is assignable from the runtime
     * class of the given init message.
     *
     * @param sessionInit init message to find a factory for
     * @param <T>         concrete {@link SessionInit} type
     * @return the matching factory, or {@code null} if none is registered
     */
    @SuppressWarnings("unchecked")
    private <T extends SessionInit> SessionFactory<T> findSessionFactory(T sessionInit) {
        for (Map.Entry<PackageType, SessionFactory<?>> entry : factories.entrySet()) {
            Class<? extends IProto> peerClass = entry.getKey().getByteBufCreator().getPeerClass();
            if (peerClass.isAssignableFrom(sessionInit.getClass())) {
                // unchecked: registration only guarantees the peer class is a SessionInit subtype;
                // the factory's type argument itself is not verified
                return (SessionFactory<T>) entry.getValue();
            }
        }
        return null;
    }
}
