/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.federated;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.File;
import java.io.Serializable;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.federated.FederatedLookupTable;
import org.apache.sysds.runtime.controlprogram.federated.FederatedReadCache;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandler;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkloadAnalyzer;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageItem;

/**
 * Netty-based server that accepts federated requests on a TCP port and
 * dispatches them to a {@link FederatedWorkerHandler}.
 *
 * <p>Constructing an instance immediately starts the server: the constructor
 * calls {@link #run()}, which blocks until the server channel is closed or
 * the startup/serving thread fails or is interrupted.
 */
public class FederatedWorker {
    protected static Logger log = Logger.getLogger(FederatedWorker.class);

    /** Port used when the caller passes {@code -1}. */
    private static final int DEFAULT_PORT = 4040;

    private final int _port;
    private final FederatedLookupTable _flt;
    private final FederatedReadCache _frc;
    /** Workload analyzer; {@code null} unless the compression config enables workload analysis. */
    private final FederatedWorkloadAnalyzer _fan;
    private final boolean _debug;
    private final Timing networkTimer = new Timing();

    /**
     * Creates the worker state, applies the lineage cache settings taken from
     * {@link DMLScript}, and then starts the server (blocking).
     *
     * @param port  port to bind; {@code -1} selects the default port 4040
     * @param debug if {@code true}, failures are logged together with their stack trace
     */
    public FederatedWorker(int port, boolean debug) {
        this._flt = new FederatedLookupTable();
        this._frc = new FederatedReadCache();
        this._fan = ConfigurationManager.getCompressConfig().isWorkload() ? new FederatedWorkloadAnalyzer() : null;
        this._port = port == -1 ? DEFAULT_PORT : port;
        this._debug = debug;
        LineageCacheConfig.setConfig(DMLScript.LINEAGE_REUSE);
        LineageCacheConfig.setCachePolicy(DMLScript.LINEAGE_POLICY);
        LineageCacheConfig.setEstimator(DMLScript.LINEAGE_ESTIMATE);
        this.run();
    }

    /**
     * Binds the server socket and blocks until the server channel closes.
     *
     * <p>The number of worker event loops is taken from the
     * {@code sysds.federated.par_conn} setting when it is positive, otherwise
     * from {@link InfrastructureAnalyzer#getLocalParallelism()}. Both event
     * loop groups and the backing thread pool are shut down on exit, whether
     * the server stopped normally or with an error.
     */
    private void run() {
        log.info("Setting up Federated Worker on port " + this._port);
        final int parConn = ConfigurationManager.getDMLConfig().getIntValue("sysds.federated.par_conn");
        final int eventLoopThreads = parConn > 0 ? parConn : InfrastructureAnalyzer.getLocalParallelism();
        final NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
        final ThreadPoolExecutor workerTPE = new ThreadPoolExecutor(1, Integer.MAX_VALUE, 10L, TimeUnit.SECONDS,
            new SynchronousQueue<Runnable>(true));
        final NioEventLoopGroup workerGroup = new NioEventLoopGroup(eventLoopThreads, (Executor) workerTPE);
        final boolean ssl = ConfigurationManager.isFederatedSSL();
        try {
            final ServerBootstrap b = new ServerBootstrap();
            b.group((EventLoopGroup) bossGroup, (EventLoopGroup) workerGroup);
            b.channel(NioServerSocketChannel.class);
            b.childHandler(this.createChannel(ssl));
            b.option(ChannelOption.SO_BACKLOG, 128);
            b.childOption(ChannelOption.SO_KEEPALIVE, true);
            log.info("Starting Federated Worker server at port: " + this._port);
            final ChannelFuture f = b.bind(this._port).sync();
            log.info("Started Federated Worker at port: " + this._port);
            // Block here for the lifetime of the server.
            f.channel().closeFuture().sync();
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers can observe the interruption.
            Thread.currentThread().interrupt();
            log.info("Federated worker interrupted");
            if (this._debug) {
                log.error(e.getMessage(), e);
            }
        }
        catch (Exception e) {
            // Not an interruption (e.g. bind failure): always report it, with
            // the stack trace only in debug mode.
            final String failure = "Federated worker failed: " + e.getMessage();
            if (this._debug) {
                log.error(failure, e);
            }
            else {
                log.error(failure);
            }
        }
        finally {
            log.info("Federated Worker Shutting down.");
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
            // The event loops do not own the externally supplied executor; without
            // this its core thread would stay alive after the loops terminate.
            // shutdown() lets the already running event-loop tasks finish.
            workerTPE.shutdown();
        }
    }

    /**
     * Builds the initializer that assembles the pipeline of every accepted
     * connection: optional TLS handler, traffic counter, object
     * decoder/encoder, federation decoder, response encoder and finally the
     * {@link FederatedWorkerHandler}.
     *
     * <p>A self-signed certificate is generated for TLS on every call,
     * regardless of whether SSL ends up being used.
     *
     * @param ssl whether SSL was enabled when the server was set up
     * @return the channel initializer for child channels
     * @throws DMLRuntimeException if the certificate or SSL context cannot be created
     */
    private ChannelInitializer<SocketChannel> createChannel(final boolean ssl) {
        try {
            final SelfSignedCertificate cert = new SelfSignedCertificate();
            final SslContext sslContext = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()).build();
            return new ChannelInitializer<SocketChannel>() {

                @Override
                public void initChannel(SocketChannel ch) {
                    final ChannelPipeline cp = ch.pipeline();
                    // Install exactly one TLS handler if either the captured flag or the
                    // configuration visible to this thread enables SSL. Adding one per
                    // source would stack two handlers, and the second would wait for a
                    // nested TLS handshake that the peer never starts.
                    final boolean useSsl = ssl
                        || ConfigurationManager.getDMLConfig().getBooleanValue("sysds.federated.ssl");
                    if (useSsl) {
                        cp.addLast(sslContext.newHandler(ch.alloc()));
                    }
                    cp.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logWorkerTraffic));
                    // NOTE(review): this decodes Java-serialized objects from the network with
                    // no size limit (Integer.MAX_VALUE); only expose the port to trusted peers.
                    cp.addLast("ObjectDecoder", new ObjectDecoder(Integer.MAX_VALUE,
                        ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
                    cp.addLast("ObjectEncoder", new ObjectEncoder());
                    cp.addLast(FederationUtils.decoder(), new FederatedResponseEncoder());
                    cp.addLast(new FederatedWorkerHandler(FederatedWorker.this._flt, FederatedWorker.this._frc,
                        FederatedWorker.this._fan, FederatedWorker.this.networkTimer));
                }
            };
        }
        catch (CertificateException | SSLException e) {
            throw new DMLRuntimeException("Failed creating channel SSL", e);
        }
    }

    /**
     * Object encoder that pre-sizes its output buffer for
     * {@link FederatedResponse} messages and, when lineage reuse is enabled,
     * reuses or stores the serialized bytes through {@link LineageCache}.
     */
    public static class FederatedResponseEncoder
    extends ObjectEncoder {
        /**
         * Allocates the output buffer, using the response's own size estimate
         * as the initial capacity for {@link FederatedResponse} messages and
         * 256 bytes otherwise.
         *
         * @param ctx          channel context providing the allocator
         * @param msg          message about to be encoded
         * @param preferDirect whether an I/O buffer is preferred over a heap buffer
         * @return the buffer the message is encoded into
         */
        @Override
        protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, Serializable msg, boolean preferDirect) throws Exception {
            int initCapacity = 256;
            if (msg instanceof FederatedResponse) {
                final FederatedResponse response = (FederatedResponse) msg;
                try {
                    initCapacity = Math.toIntExact(response.estimateSerializationBufferSize());
                }
                catch (ArithmeticException ae) {
                    // Estimate does not fit into an int: fall back to the largest capacity.
                    initCapacity = Integer.MAX_VALUE;
                }
            }
            return preferDirect ? ctx.alloc().ioBuffer(initCapacity) : ctx.alloc().heapBuffer(initCapacity);
        }

        /**
         * Encodes {@code msg} into {@code out}.
         *
         * <p>If lineage reuse is enabled and the message is a
         * {@link FederatedResponse} whose first data element is a
         * {@link CacheBlock}, previously cached bytes for its lineage item are
         * written directly when available. Otherwise the message is encoded
         * normally and, if a lineage item was obtained, the produced bytes are
         * stored in the lineage cache together with the encoding time.
         *
         * @param ctx channel context
         * @param msg message to encode
         * @param out destination buffer
         */
        @Override
        protected void encode(ChannelHandlerContext ctx, Serializable msg, ByteBuf out) throws Exception {
            LineageItem objLI = null;
            boolean linReusePossible = !LineageCacheConfig.ReuseCacheType.isNone() && msg instanceof FederatedResponse;
            if (linReusePossible) {
                final FederatedResponse response = (FederatedResponse) msg;
                if (response.getData() != null && response.getData().length != 0
                    && response.getData()[0] instanceof CacheBlock) {
                    objLI = response.getLineageItem();
                    final byte[] cachedBytes = LineageCache.reuseSerialization(objLI);
                    if (cachedBytes != null) {
                        out.writeBytes(cachedBytes);
                        return;
                    }
                }
            }

            // Only cache the serialized form if a lineage item was actually obtained.
            linReusePossible &= objLI != null;

            final int startIdx = linReusePossible ? out.writerIndex() : 0;
            final long t0 = linReusePossible ? System.nanoTime() : 0L;
            super.encode(ctx, msg, out);
            final long t1 = linReusePossible ? System.nanoTime() : 0L;

            if (linReusePossible) {
                // Absolute read: copies the freshly written bytes without moving the
                // buffer's reader index (and without depending on its marked index).
                final byte[] dst = new byte[out.writerIndex() - startIdx];
                out.getBytes(startIdx, dst);
                LineageCache.putSerializedObject(dst, objLI, t1 - t0);
            }
        }
    }
}

