/*
 * Decompiled with CFR 0.152.
 */
package com.alipay.sofa.jraft.rpc.impl;

import com.alipay.sofa.jraft.ReplicatorGroup;
import com.alipay.sofa.jraft.entity.PeerId;
import com.alipay.sofa.jraft.error.InvokeTimeoutException;
import com.alipay.sofa.jraft.error.RemotingException;
import com.alipay.sofa.jraft.option.RpcOptions;
import com.alipay.sofa.jraft.rpc.InvokeCallback;
import com.alipay.sofa.jraft.rpc.InvokeContext;
import com.alipay.sofa.jraft.rpc.RpcClient;
import com.alipay.sofa.jraft.rpc.RpcUtils;
import com.alipay.sofa.jraft.rpc.impl.GrpcRaftRpcFactory;
import com.alipay.sofa.jraft.rpc.impl.ManagedChannelHelper;
import com.alipay.sofa.jraft.rpc.impl.MarshallerRegistry;
import com.alipay.sofa.jraft.util.DirectExecutor;
import com.alipay.sofa.jraft.util.Endpoint;
import com.alipay.sofa.jraft.util.Requires;
import com.alipay.sofa.jraft.util.SystemPropertyUtil;
import com.google.protobuf.Message;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RpcClient} implementation backed by gRPC.
 *
 * <p>One {@link ManagedChannel} is kept per {@link Endpoint} in {@code managedChannelPool}. Every request is sent as
 * a unary call whose full method name is derived from the request's class name and the fixed method name
 * {@code "_call"}. The request marshaller comes from the default instance registered in {@code parserClasses}. The
 * response marshaller comes from the {@link MarshallerRegistry}.
 *
 * <p>Each {@code TRANSIENT_FAILURE} or {@code SHUTDOWN} state seen by a connectivity check counts as a failure.
 * After {@code RESET_CONN_THRESHOLD} failures the channel is removed from the pool and shut down.
 */
public class GrpcClient implements RpcClient {

    private static final Logger LOG = LoggerFactory.getLogger(GrpcClient.class);

    /**
     * Number of failed connectivity checks (channel seen in TRANSIENT_FAILURE or SHUTDOWN) after which the channel
     * is evicted from the pool. Configurable via the {@code jraft.grpc.max.conn.failures.to_reset} system property;
     * defaults to 2.
     */
    private static final int RESET_CONN_THRESHOLD = SystemPropertyUtil.getInt("jraft.grpc.max.conn.failures.to_reset", 2);

    /** Channels keyed by the remote endpoint they connect to. */
    private final Map<Endpoint, ManagedChannel> managedChannelPool = new ConcurrentHashMap<>();
    /** Per-endpoint count of failed connectivity checks; cleared when a channel becomes READY or is evicted. */
    private final Map<Endpoint, AtomicInteger> transientFailures = new ConcurrentHashMap<>();
    /** Request class name -> default request instance, used to build the request marshaller. */
    private final Map<String, Message> parserClasses;
    /** Source of the default response instance for each request class name. */
    private final MarshallerRegistry marshallerRegistry;
    /** Notified (via {@code checkReplicator}) when a channel becomes READY; may be {@code null}. */
    private volatile ReplicatorGroup replicatorGroup;

    /**
     * @param parserClasses      request class name -> default request message instance
     * @param marshallerRegistry registry used to look up the response default instance for a request class name
     */
    public GrpcClient(Map<String, Message> parserClasses, MarshallerRegistry marshallerRegistry) {
        this.parserClasses = parserClasses;
        this.marshallerRegistry = marshallerRegistry;
    }

    /**
     * No-op initialisation; the supplied options are not used.
     *
     * @return always {@code true}
     */
    public boolean init(RpcOptions opts) {
        return true;
    }

    /** Shuts down every pooled channel and forgets all recorded connection failures. */
    public void shutdown() {
        this.closeAllChannels();
        this.transientFailures.clear();
    }

    /**
     * Checks whether a usable channel to {@code endpoint} exists, without creating one.
     *
     * @return {@code true} if a pooled channel exists and passes the connectivity check
     */
    public boolean checkConnection(Endpoint endpoint) {
        return this.checkConnection(endpoint, false);
    }

    /**
     * Checks whether a usable channel to {@code endpoint} exists.
     *
     * @param endpoint       target endpoint, must not be {@code null}
     * @param createIfAbsent whether to create (and pool) a channel when none exists yet
     * @return {@code true} if a channel exists and passes the connectivity check
     */
    public boolean checkConnection(Endpoint endpoint, boolean createIfAbsent) {
        Requires.requireNonNull(endpoint, "endpoint");
        return this.checkChannel(endpoint, createIfAbsent);
    }

    /** Removes the channel for {@code endpoint} from the pool (if any) and shuts it down. */
    public void closeConnection(Endpoint endpoint) {
        Requires.requireNonNull(endpoint, "endpoint");
        this.closeChannel(endpoint);
    }

    /** Registers the group whose replicators are re-checked whenever a channel becomes READY. */
    public void registerConnectEventListener(ReplicatorGroup replicatorGroup) {
        this.replicatorGroup = replicatorGroup;
    }

    /**
     * Sends {@code request} and blocks for the response.
     *
     * @return the response message
     * @throws InvokeTimeoutException if no response arrives within {@code timeoutMs}
     * @throws RemotingException      on any other failure; if the calling thread was interrupted while waiting,
     *                                its interrupt flag is restored before this is thrown
     */
    public Object invokeSync(Endpoint endpoint, Object request, InvokeContext ctx, long timeoutMs) throws RemotingException {
        final CompletableFuture<Object> future = new CompletableFuture<>();
        this.invokeAsync(endpoint, request, ctx, (result, err) -> {
            if (err == null) {
                future.complete(result);
            } else {
                future.completeExceptionally(err);
            }
        }, timeoutMs);
        try {
            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (final TimeoutException e) {
            future.cancel(true);
            throw new InvokeTimeoutException(e);
        } catch (final InterruptedException e) {
            future.cancel(true);
            // Preserve the interruption so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RemotingException(e);
        } catch (final Throwable t) {
            future.cancel(true);
            throw new RemotingException(t);
        }
    }

    /**
     * Sends {@code request} as a unary gRPC call with a deadline of {@code timeoutMs}.
     *
     * <p>{@code callback} is completed on {@code callback.executor()}, or on the calling/gRPC thread when that is
     * {@code null}. If no usable channel is available the callback receives a {@link RemotingException}.
     * {@code ctx} is currently not used.
     */
    public void invokeAsync(Endpoint endpoint, Object request, InvokeContext ctx, final InvokeCallback callback, long timeoutMs) {
        Requires.requireNonNull(endpoint, "endpoint");
        Requires.requireNonNull(request, "request");
        Requires.requireNonNull(callback, "callback");
        final Executor executor = callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;
        final ManagedChannel ch = this.getCheckedChannel(endpoint);
        if (ch == null) {
            executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint)));
            return;
        }
        final MethodDescriptor<Message, Message> method = this.getCallMethod(request);
        final CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
        ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver<Message>() {

            public void onNext(final Message value) {
                executor.execute(() -> callback.complete(value, null));
            }

            public void onError(final Throwable throwable) {
                executor.execute(() -> callback.complete(null, throwable));
            }

            public void onCompleted() {
                // Nothing to do: the response was handed to the callback in onNext (errors go through onError).
            }
        });
    }

    /**
     * Builds the unary method descriptor for {@code request}'s class.
     *
     * @throws NullPointerException if no default instance is registered for the request class name
     */
    private MethodDescriptor<Message, Message> getCallMethod(Object request) {
        final String interest = request.getClass().getName();
        final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: " + interest);
        // The explicit type witness is required: in a call chain the builder's type arguments would otherwise be
        // inferred as Object and the Message marshallers / return type would not type-check.
        return MethodDescriptor.<Message, Message>newBuilder()
            .setType(MethodDescriptor.MethodType.UNARY)
            .setFullMethodName(MethodDescriptor.generateFullMethodName(interest, "_call"))
            .setRequestMarshaller(ProtoUtils.marshaller(reqIns))
            .setResponseMarshaller(ProtoUtils.marshaller((Message) this.marshallerRegistry.findResponseInstanceByRequest(interest)))
            .build();
    }

    /** Returns the (possibly newly created) channel for {@code endpoint}, or {@code null} if it fails the check. */
    private ManagedChannel getCheckedChannel(Endpoint endpoint) {
        final ManagedChannel ch = this.getChannel(endpoint, true);
        return this.checkConnectivity(endpoint, ch) ? ch : null;
    }

    /** Looks up the pooled channel; when {@code createIfAbsent} is set a missing one is created atomically. */
    private ManagedChannel getChannel(Endpoint endpoint, boolean createIfAbsent) {
        if (createIfAbsent) {
            return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
        }
        return this.managedChannelPool.get(endpoint);
    }

    /** Creates a plaintext channel to {@code endpoint} and starts watching its connectivity state. */
    private ManagedChannel newChannel(Endpoint endpoint) {
        final ManagedChannel ch = ManagedChannelBuilder.forAddress(endpoint.getIp(), endpoint.getPort())
            .usePlaintext()
            .directExecutor()
            .maxInboundMessageSize(GrpcRaftRpcFactory.RPC_MAX_INBOUND_MESSAGE_SIZE)
            .build();
        LOG.info("Creating new channel to: {}.", endpoint);
        this.notifyWhenStateChanged(ConnectivityState.IDLE, endpoint, ch);
        return ch;
    }

    private ManagedChannel removeChannel(Endpoint endpoint) {
        return this.managedChannelPool.remove(endpoint);
    }

    /** Registers a one-shot callback that fires once {@code ch} leaves {@code state}. */
    private void notifyWhenStateChanged(ConnectivityState state, Endpoint endpoint, ManagedChannel ch) {
        ch.notifyWhenStateChanged(state, () -> this.onStateChanged(endpoint, ch));
    }

    /**
     * Reacts to a state change of {@code ch} and re-arms the watcher for the new state. Watching stops once the
     * channel reaches SHUTDOWN.
     */
    private void onStateChanged(Endpoint endpoint, ManagedChannel ch) {
        final ConnectivityState state = ch.getState(false);
        LOG.info("The channel {} is in state: {}.", endpoint, state);
        switch (state) {
            case READY:
                this.notifyReady(endpoint);
                this.notifyWhenStateChanged(ConnectivityState.READY, endpoint, ch);
                break;
            case TRANSIENT_FAILURE:
                this.notifyFailure(endpoint);
                this.notifyWhenStateChanged(ConnectivityState.TRANSIENT_FAILURE, endpoint, ch);
                break;
            case SHUTDOWN:
                this.notifyShutdown(endpoint);
                break;
            case CONNECTING:
                this.notifyWhenStateChanged(ConnectivityState.CONNECTING, endpoint, ch);
                break;
            case IDLE:
                this.notifyWhenStateChanged(ConnectivityState.IDLE, endpoint, ch);
                break;
            default:
                break;
        }
    }

    /** Resets the failure count and, if a listener is registered, asks it to re-check the replicator for the peer. */
    private void notifyReady(Endpoint endpoint) {
        LOG.info("The channel {} has successfully established.", endpoint);
        this.clearConnFailuresCount(endpoint);
        final ReplicatorGroup rpGroup = this.replicatorGroup;
        if (rpGroup == null) {
            return;
        }
        try {
            RpcUtils.runInThread(() -> {
                final PeerId peer = new PeerId();
                if (peer.parse(endpoint.toString())) {
                    LOG.info("Peer {} is connected.", peer);
                    rpGroup.checkReplicator(peer, true);
                } else {
                    LOG.error("Fail to parse peer: {}.", endpoint);
                }
            });
        } catch (final Throwable t) {
            LOG.error("Fail to check replicator {}.", endpoint, t);
        }
    }

    private void notifyFailure(Endpoint endpoint) {
        LOG.warn("There has been some transient failure on this channel {}.", endpoint);
    }

    private void notifyShutdown(Endpoint endpoint) {
        LOG.warn("This channel {} has started shutting down. Any new RPCs should fail immediately.", endpoint);
    }

    /** Shuts down every pooled channel, then empties the pool. */
    private void closeAllChannels() {
        for (final Map.Entry<Endpoint, ManagedChannel> entry : this.managedChannelPool.entrySet()) {
            final ManagedChannel ch = entry.getValue();
            LOG.info("Shutdown managed channel: {}, {}.", entry.getKey(), ch);
            ManagedChannelHelper.shutdownAndAwaitTermination(ch);
        }
        this.managedChannelPool.clear();
    }

    /** Removes the channel for {@code endpoint} from the pool and shuts it down if one was present. */
    private void closeChannel(Endpoint endpoint) {
        final ManagedChannel ch = this.removeChannel(endpoint);
        LOG.info("Close connection: {}, {}.", endpoint, ch);
        if (ch != null) {
            ManagedChannelHelper.shutdownAndAwaitTermination(ch);
        }
    }

    private boolean checkChannel(Endpoint endpoint, boolean createIfAbsent) {
        final ManagedChannel ch = this.getChannel(endpoint, createIfAbsent);
        return ch != null && this.checkConnectivity(endpoint, ch);
    }

    /** Increments and returns the failure counter for {@code endpoint}, creating it on first use. */
    private int incConnFailuresCount(Endpoint endpoint) {
        return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet();
    }

    private void clearConnFailuresCount(Endpoint endpoint) {
        this.transientFailures.remove(endpoint);
    }

    /**
     * Decides whether {@code ch} may be used.
     *
     * <p>Channels not in TRANSIENT_FAILURE/SHUTDOWN are always usable. Otherwise a failure is recorded. While the
     * count stays below {@code RESET_CONN_THRESHOLD} the channel is still reported usable. Once the threshold is
     * reached the channel is evicted from the pool and shut down.
     *
     * @return {@code true} if the caller may use {@code ch}
     */
    private boolean checkConnectivity(Endpoint endpoint, ManagedChannel ch) {
        final ConnectivityState st = ch.getState(false);
        if (st != ConnectivityState.TRANSIENT_FAILURE && st != ConnectivityState.SHUTDOWN) {
            return true;
        }
        final int c = this.incConnFailuresCount(endpoint);
        if (c < RESET_CONN_THRESHOLD) {
            if (c == RESET_CONN_THRESHOLD - 1) {
                // Last tolerated failure: short-circuit the reconnect backoff so the channel retries right away.
                ch.resetConnectBackoff();
            }
            return true;
        }
        this.clearConnFailuresCount(endpoint);
        final ManagedChannel removedCh = this.removeChannel(endpoint);
        if (removedCh == null) {
            // Presumably removed (and closed) concurrently by another caller.
            return false;
        }
        LOG.warn("Channel[{}] in [INACTIVE] state {} times, it has been removed from the pool.", endpoint, c);
        if (removedCh != ch) {
            // The pool held a different channel than the one checked; close it too.
            ManagedChannelHelper.shutdownAndAwaitTermination(removedCh, 100L);
        }
        ManagedChannelHelper.shutdownAndAwaitTermination(ch, 100L);
        return false;
    }
}

