/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.discovery.zen;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.coordination.ValidateJoinRequest;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

public class MembershipAction {

    private static final Logger logger = LogManager.getLogger(MembershipAction.class);

    public static final String DISCOVERY_JOIN_ACTION_NAME = "internal:discovery/zen/join";
    public static final String DISCOVERY_JOIN_VALIDATE_ACTION_NAME = "internal:discovery/zen/join/validate";
    public static final String DISCOVERY_LEAVE_ACTION_NAME = "internal:discovery/zen/leave";

    public interface JoinCallback {
        void onSuccess();

        void onFailure(Exception e);
    }

    public interface MembershipListener {
        void onJoin(DiscoveryNode node, JoinCallback callback);

        void onLeave(DiscoveryNode node);
    }

    private final TransportService transportService;

    private final MembershipListener listener;

    public MembershipAction(
        TransportService transportService,
        MembershipListener listener,
        Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators
    ) {
        this.transportService = transportService;
        this.listener = listener;

        transportService.registerRequestHandler(
            DISCOVERY_JOIN_ACTION_NAME,
            ThreadPool.Names.GENERIC,
            JoinRequest::new,
            new JoinRequestRequestHandler()
        );
        transportService.registerRequestHandler(
            DISCOVERY_JOIN_VALIDATE_ACTION_NAME,
            ThreadPool.Names.GENERIC,
            ValidateJoinRequest::new,
            new ValidateJoinRequestRequestHandler(transportService::getLocalNode, joinValidators)
        );
        transportService.registerRequestHandler(
            DISCOVERY_LEAVE_ACTION_NAME,
            ThreadPool.Names.GENERIC,
            LeaveRequest::new,
            new LeaveRequestRequestHandler()
        );
    }

    public void sendLeaveRequest(DiscoveryNode masterNode, DiscoveryNode node) {
        transportService.sendRequest(
            node,
            DISCOVERY_LEAVE_ACTION_NAME,
            new LeaveRequest(masterNode),
            EmptyTransportResponseHandler.INSTANCE_SAME
        );
    }

    public void sendLeaveRequestBlocking(DiscoveryNode masterNode, DiscoveryNode node, TimeValue timeout) {
        transportService.submitRequest(
            masterNode,
            DISCOVERY_LEAVE_ACTION_NAME,
            new LeaveRequest(node),
            EmptyTransportResponseHandler.INSTANCE_SAME
        ).txGet(timeout.millis(), TimeUnit.MILLISECONDS);
    }

    public void sendJoinRequestBlocking(DiscoveryNode masterNode, DiscoveryNode node, TimeValue timeout) {
        transportService.submitRequest(
            masterNode,
            DISCOVERY_JOIN_ACTION_NAME,
            new JoinRequest(node),
            EmptyTransportResponseHandler.INSTANCE_SAME
        ).txGet(timeout.millis(), TimeUnit.MILLISECONDS);
    }

    /**
     * Validates the join request, throwing a failure if it failed.
     */
    public void sendValidateJoinRequestBlocking(DiscoveryNode node, ClusterState state, TimeValue timeout) {
        transportService.submitRequest(
            node,
            DISCOVERY_JOIN_VALIDATE_ACTION_NAME,
            new ValidateJoinRequest(state),
            EmptyTransportResponseHandler.INSTANCE_SAME
        ).txGet(timeout.millis(), TimeUnit.MILLISECONDS);
    }

    public static class JoinRequest extends TransportRequest {

        private DiscoveryNode node;

        public DiscoveryNode getNode() {
            return node;
        }

        public JoinRequest(StreamInput in) throws IOException {
            super(in);
            node = new DiscoveryNode(in);
        }

        public JoinRequest(DiscoveryNode node) {
            this.node = node;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            node.writeTo(out);
        }
    }

    private class JoinRequestRequestHandler implements TransportRequestHandler<JoinRequest> {

        @Override
        public void messageReceived(final JoinRequest request, final TransportChannel channel, Task task) throws Exception {
            listener.onJoin(request.getNode(), new JoinCallback() {
                @Override
                public void onSuccess() {
                    try {
                        channel.sendResponse(TransportResponse.Empty.INSTANCE);
                    } catch (Exception e) {
                        onFailure(e);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    try {
                        channel.sendResponse(e);
                    } catch (Exception inner) {
                        inner.addSuppressed(e);
                        logger.warn("failed to send back failure on join request", inner);
                    }
                }
            });
        }
    }

    static class ValidateJoinRequestRequestHandler implements TransportRequestHandler<ValidateJoinRequest> {
        private final Supplier<DiscoveryNode> localNodeSupplier;
        private final Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators;

        ValidateJoinRequestRequestHandler(
            Supplier<DiscoveryNode> localNodeSupplier,
            Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators
        ) {
            this.localNodeSupplier = localNodeSupplier;
            this.joinValidators = joinValidators;
        }

        @Override
        public void messageReceived(ValidateJoinRequest request, TransportChannel channel, Task task) throws Exception {
            DiscoveryNode node = localNodeSupplier.get();
            assert node != null : "local node is null";
            joinValidators.stream().forEach(action -> action.accept(node, request.getState()));
            channel.sendResponse(TransportResponse.Empty.INSTANCE);
        }
    }

    public static class LeaveRequest extends TransportRequest {

        private DiscoveryNode node;

        public LeaveRequest(StreamInput in) throws IOException {
            super(in);
            node = new DiscoveryNode(in);
        }

        private LeaveRequest(DiscoveryNode node) {
            this.node = node;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            node.writeTo(out);
        }
    }

    private class LeaveRequestRequestHandler implements TransportRequestHandler<LeaveRequest> {

        @Override
        public void messageReceived(LeaveRequest request, TransportChannel channel, Task task) throws Exception {
            listener.onLeave(request.node);
            channel.sendResponse(TransportResponse.Empty.INSTANCE);
        }
    }
}
