package org.apache.seatunnel.engine.server.task;

import com.hazelcast.cluster.Address;
import com.hazelcast.spi.impl.operationservice.Operation;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.source.SourceEvent;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.api.source.SourceSplitEnumerator;
import org.apache.seatunnel.engine.common.utils.ExceptionUtil;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.core.job.ConnectorJarIdentifier;
import org.apache.seatunnel.engine.server.checkpoint.ActionStateKey;
import org.apache.seatunnel.engine.server.checkpoint.ActionSubtaskState;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointBarrier;
import org.apache.seatunnel.engine.server.checkpoint.operation.TaskAcknowledgeOperation;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.context.SeaTunnelSplitEnumeratorContext;
import org.apache.seatunnel.engine.server.task.operation.checkpoint.BarrierFlowOperation;
import org.apache.seatunnel.engine.server.task.operation.source.LastCheckpointNotifyOperation;
import org.apache.seatunnel.engine.server.task.record.Barrier;
import org.apache.seatunnel.engine.server.task.statemachine.SeaTunnelTaskState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/seatunnel/engine/server/task/SourceSplitEnumeratorTask.class */
/**
 * Coordinator task that hosts the {@link SourceSplitEnumerator} of a single {@link SourceAction}.
 *
 * <p>The task is driven by repeated {@link #call()} invocations, each of which advances a small
 * state machine (see {@code stateProcess()}):
 * INIT -> WAITING_RESTORE -> READY_START -> STARTING -> RUNNING -> PREPARE_CLOSE -> CLOSED.
 * Reader subtasks interact with it through {@link #receivedReader}, {@link #requestSplit},
 * {@link #addSplitsBack}, {@link #handleSourceEvent} and {@link #readerFinished}.
 *
 * @param <SplitT> the split type produced by the source
 */
public class SourceSplitEnumeratorTask<SplitT extends SourceSplit> extends CoordinatorTask {
    private static final Logger log = LoggerFactory.getLogger(SourceSplitEnumeratorTask.class);
    private static final long serialVersionUID = -3713701594297977775L;

    /** The source action whose enumerator this task runs. */
    private final SourceAction<?, SplitT, Serializable> source;

    /** Created or restored in {@link #restoreState}; {@code null} until then. */
    private SourceSplitEnumerator<SplitT, Serializable> enumerator;

    /** Context handed to the enumerator; also used as the lock around barrier handling. */
    private SeaTunnelSplitEnumeratorContext<SplitT> enumeratorContext;

    private Serializer<Serializable> enumeratorStateSerializer;
    private Serializer<SplitT> splitSerializer;

    /** Number of readers that must register before the enumerator is opened (= parallelism). */
    private int maxReaderSize;

    /** Task IDs of registered readers that have not yet reported {@link #readerFinished}. */
    private Set<Long> unfinishedReaders;

    /** Reader location -> address of the cluster member the reader registered from. */
    private Map<TaskLocation, Address> taskMemberMapping;

    private Map<Long, TaskLocation> taskIDToTaskLocationMapping;
    private Map<Integer, TaskLocation> taskIndexToTaskLocationMapping;

    private volatile SeaTunnelTaskState currState;
    private volatile boolean readerRegisterComplete;
    private volatile boolean prepareCloseTriggered;

    /**
     * Creates the task in the {@code CREATED} state.
     *
     * @param jobID        id of the owning job
     * @param taskLocation location of this coordinator task
     * @param sourceAction source action whose enumerator will be hosted
     */
    @SuppressWarnings("unchecked")
    public SourceSplitEnumeratorTask(
            long jobID, TaskLocation taskLocation, SourceAction<?, SplitT, ?> sourceAction) {
        super(jobID, taskLocation);
        // This class treats the action's third type argument as Serializable everywhere
        // (enumerator type, state serializer), so narrow it once here.
        this.source = (SourceAction<?, SplitT, Serializable>) sourceAction;
        this.currState = SeaTunnelTaskState.CREATED;
    }

    /** Initializes the enumerator context, serializers and reader bookkeeping structures. */
    @Override
    public void init() throws Exception {
        this.currState = SeaTunnelTaskState.INIT;
        super.init();
        this.readerRegisterComplete = false;
        log.info(
                "starting seatunnel source split enumerator task, source name: {}",
                this.source.getName());
        // NOTE(review): decompiler-generated accessor name; presumably the superclass's
        // metrics-context getter — confirm before renaming.
        this.enumeratorContext =
                new SeaTunnelSplitEnumeratorContext<>(
                        this.source.getParallelism(), this, mo44getMetricsContext());
        this.enumeratorStateSerializer = this.source.getSource().getEnumeratorStateSerializer();
        this.splitSerializer = this.source.getSource().getSplitSerializer();
        this.taskMemberMapping = new ConcurrentHashMap<>();
        this.taskIDToTaskLocationMapping = new ConcurrentHashMap<>();
        this.taskIndexToTaskLocationMapping = new ConcurrentHashMap<>();
        this.maxReaderSize = this.source.getParallelism();
        this.unfinishedReaders = new CopyOnWriteArraySet<>();
    }

    /**
     * Closes the parent task and the enumerator (if one was created), then marks progress done.
     * The enumerator is closed even when the parent's close fails.
     */
    @Override
    public void close() throws IOException {
        try {
            super.close();
        } finally {
            if (this.enumerator != null) {
                this.enumerator.close();
            }
        }
        this.progress.done();
    }

    /** Advances the state machine by one step and reports the current progress. */
    @Override
    @NonNull
    public ProgressState call() throws Exception {
        stateProcess();
        return this.progress.toState();
    }

    /**
     * Handles a checkpoint barrier. If the barrier requests a snapshot, the enumerator state is
     * captured and serialized. In all cases the barrier is forwarded to every registered reader,
     * and the snapshot is acknowledged to the master.
     *
     * @param barrier the barrier to process
     */
    @Override
    public void triggerBarrier(Barrier barrier) throws Exception {
        long startMillis = System.currentTimeMillis();
        log.debug("split enumer trigger barrier [{}]", barrier);
        if (barrier.prepareClose()) {
            this.prepareCloseTriggered = true;
            this.prepareCloseBarrierId.set(barrier.getId());
        }
        long checkpointId = barrier.getId();
        Serializable snapshot = null;
        byte[] serializedState = null;
        // Snapshot and barrier fan-out happen atomically with respect to the enumerator context.
        synchronized (this.enumeratorContext) {
            if (barrier.snapshot()) {
                snapshot = this.enumerator.snapshotState(checkpointId);
                serializedState = this.enumeratorStateSerializer.serialize(snapshot);
            }
            log.debug("source split enumerator send state [{}] to master", snapshot);
            sendToAllReader(location -> new BarrierFlowOperation(barrier, location));
        }
        if (barrier.snapshot()) {
            ActionSubtaskState subtaskState =
                    new ActionSubtaskState(
                            ActionStateKey.of(this.source),
                            -1,
                            Collections.singletonList(serializedState));
            getExecutionContext()
                    .sendToMaster(
                            new TaskAcknowledgeOperation(
                                    this.taskLocation,
                                    (CheckpointBarrier) barrier,
                                    Collections.singletonList(subtaskState)))
                    .join();
        }
        log.debug(
                "trigger barrier [{}] finished, cost {}ms. taskLocation [{}]",
                barrier.getId(),
                System.currentTimeMillis() - startMillis,
                this.taskLocation);
    }

    /**
     * Creates the enumerator, restoring it from the first non-null serialized state found in
     * {@code actionStateList}; otherwise a fresh enumerator is created. Completes
     * {@code restoreComplete} afterwards.
     *
     * @param actionStateList checkpointed states for this action (may contain no state)
     */
    @Override
    public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
        log.debug("restoreState for split enumerator [{}]", actionStateList);
        Optional<Serializable> restoredState =
                actionStateList.stream()
                        .map(ActionSubtaskState::getState)
                        .flatMap(states -> states.stream())
                        .filter(Objects::nonNull)
                        .map(
                                bytes ->
                                        (Serializable)
                                                ExceptionUtil.sneaky(
                                                        () ->
                                                                (Serializable)
                                                                        this.enumeratorStateSerializer
                                                                                .deserialize(bytes)))
                        .findFirst();
        if (restoredState.isPresent()) {
            this.enumerator =
                    this.source
                            .getSource()
                            .restoreEnumerator(this.enumeratorContext, restoredState.get());
        } else {
            this.enumerator = this.source.getSource().createEnumerator(this.enumeratorContext);
        }
        this.restoreComplete.complete(null);
        log.debug("restoreState split enumerator [{}] finished", actionStateList);
    }

    /** Returns the split serializer, first waiting until state restoration has completed. */
    public Serializer<SplitT> getSplitSerializer()
            throws ExecutionException, InterruptedException {
        getEnumerator();
        return this.splitSerializer;
    }

    /** Returns splits to the enumerator on behalf of the given subtask. */
    public void addSplitsBack(List<SplitT> splits, int subtaskId)
            throws ExecutionException, InterruptedException {
        getEnumerator().addSplitsBack(splits, subtaskId);
    }

    /**
     * Registers a reader with this task and with the enumerator. Once the number of registered
     * readers equals the source parallelism, registration is flagged complete.
     *
     * @param readerLocation location of the registering reader
     * @param memberAddress  address of the member the reader runs on
     */
    public void receivedReader(TaskLocation readerLocation, Address memberAddress)
            throws InterruptedException, ExecutionException {
        log.info("received reader register, readerID: {}", readerLocation);
        SourceSplitEnumerator<SplitT, Serializable> splitEnumerator = getEnumerator();
        addTaskMemberMapping(readerLocation, memberAddress);
        splitEnumerator.registerReader(readerLocation.getTaskIndex());
        int registered = this.taskMemberMapping.size();
        if (registered == this.maxReaderSize) {
            this.readerRegisterComplete = true;
            log.debug("reader register complete, current task size {}", registered);
        } else {
            log.debug(
                    "current task size {}, need size {} to complete register",
                    registered,
                    this.maxReaderSize);
        }
    }

    /** Forwards a split request from the given reader subtask to the enumerator. */
    public void requestSplit(long taskIndex) throws ExecutionException, InterruptedException {
        getEnumerator().handleSplitRequest((int) taskIndex);
    }

    /** Forwards a reader-originated source event to the enumerator. */
    public void handleSourceEvent(int subtaskId, SourceEvent sourceEvent)
            throws ExecutionException, InterruptedException {
        getEnumerator().handleSourceEvent(subtaskId, sourceEvent);
    }

    /** Records a reader's location/address and marks it as not yet finished. */
    public void addTaskMemberMapping(TaskLocation taskLocation, Address address) {
        this.taskMemberMapping.put(taskLocation, address);
        this.taskIDToTaskLocationMapping.put(taskLocation.getTaskID(), taskLocation);
        this.taskIndexToTaskLocationMapping.put(taskLocation.getTaskIndex(), taskLocation);
        this.unfinishedReaders.add(taskLocation.getTaskID());
    }

    /**
     * Returns the member address of the reader with the given task ID, or {@code null} if no such
     * reader is registered.
     */
    public Address getTaskMemberAddress(long taskID) {
        TaskLocation location = this.taskIDToTaskLocationMapping.get(taskID);
        // ConcurrentHashMap rejects null keys, so guard the lookup for unknown readers.
        return location == null ? null : this.taskMemberMapping.get(location);
    }

    /** Returns the location of the reader with the given task ID, or {@code null} if unknown. */
    public TaskLocation getTaskMemberLocation(long taskID) {
        return this.taskIDToTaskLocationMapping.get(taskID);
    }

    /**
     * Returns the member address of the reader with the given task index, or {@code null} if no
     * such reader is registered.
     */
    public Address getTaskMemberAddressByIndex(int taskIndex) {
        TaskLocation location = this.taskIndexToTaskLocationMapping.get(taskIndex);
        return location == null ? null : this.taskMemberMapping.get(location);
    }

    /** Returns the location of the reader with the given task index, or {@code null} if unknown. */
    public TaskLocation getTaskMemberLocationByIndex(int taskIndex) {
        return this.taskIndexToTaskLocationMapping.get(taskIndex);
    }

    /**
     * Waits until {@link #restoreState} has completed and returns the enumerator. Polls every
     * 200 ms while {@code restoreComplete} has not been assigned yet (presumably by
     * {@code super.init()} — confirm).
     */
    private SourceSplitEnumerator<SplitT, Serializable> getEnumerator()
            throws InterruptedException, ExecutionException {
        while (this.restoreComplete == null) {
            log.warn("Task init is not complete, try to get it again after 200 ms");
            Thread.sleep(200L);
        }
        this.restoreComplete.get();
        return this.enumerator;
    }

    /** Marks a reader finished; when no registered reader remains, flags prepare-close. */
    public void readerFinished(long taskID) {
        this.unfinishedReaders.remove(taskID);
        if (this.unfinishedReaders.isEmpty()) {
            this.prepareCloseStatus = true;
        }
    }

    /** Performs one step of the task's state machine; sleeps briefly while waiting. */
    private void stateProcess() throws Exception {
        switch (this.currState) {
            case INIT:
                this.currState = SeaTunnelTaskState.WAITING_RESTORE;
                reportTaskStatus(SeaTunnelTaskState.WAITING_RESTORE);
                return;
            case WAITING_RESTORE:
                if (this.restoreComplete.isDone()) {
                    this.currState = SeaTunnelTaskState.READY_START;
                    reportTaskStatus(SeaTunnelTaskState.READY_START);
                } else {
                    Thread.sleep(100L);
                }
                return;
            case READY_START:
                // Open only after start() has been called and every reader has registered.
                if (this.startCalled && this.readerRegisterComplete) {
                    this.currState = SeaTunnelTaskState.STARTING;
                    this.enumerator.open();
                } else {
                    Thread.sleep(100L);
                }
                return;
            case STARTING:
                this.currState = SeaTunnelTaskState.RUNNING;
                log.info("received enough reader, starting enumerator...");
                this.enumerator.run();
                return;
            case RUNNING:
                if (this.prepareCloseStatus) {
                    // All readers finished on their own: ask the master for a final checkpoint.
                    getExecutionContext()
                            .sendToMaster(
                                    new LastCheckpointNotifyOperation(
                                            this.jobID, this.taskLocation));
                    this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else if (this.prepareCloseTriggered) {
                    this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else {
                    Thread.sleep(100L);
                }
                return;
            case PREPARE_CLOSE:
                if (this.closeCalled) {
                    this.currState = SeaTunnelTaskState.CLOSED;
                } else {
                    Thread.sleep(100L);
                }
                return;
            case CLOSED:
                close();
                return;
            case CANCELLING:
                close();
                this.currState = SeaTunnelTaskState.CANCELED;
                return;
            default:
                throw new IllegalArgumentException("Unknown Enumerator State: " + this.currState);
        }
    }

    /** Returns the task indexes of all readers registered so far. */
    public Set<Integer> getRegisteredReaders() {
        return this.taskMemberMapping.keySet().stream()
                .map(TaskLocation::getTaskIndex)
                .collect(Collectors.toSet());
    }

    /**
     * Sends one operation per registered reader, built by {@code operationFactory}. All
     * operations are dispatched first; then every response is awaited.
     */
    private void sendToAllReader(Function<TaskLocation, Operation> operationFactory) {
        List<CompletableFuture<?>> pending = new ArrayList<>(this.taskMemberMapping.size());
        this.taskMemberMapping.forEach(
                (location, address) -> {
                    log.debug(
                            "split enumerator send to read--size: {}, location: {}, address: {}",
                            this.taskMemberMapping.size(),
                            location,
                            address.toString());
                    pending.add(
                            getExecutionContext()
                                    .sendToMember(operationFactory.apply(location), address));
                });
        pending.forEach(CompletableFuture::join);
    }

    @Override
    public Set<URL> getJarsUrl() {
        return new HashSet<>(this.source.getJarUrls());
    }

    @Override
    public Set<ConnectorJarIdentifier> getConnectorPluginJars() {
        return new HashSet<>(this.source.getConnectorJarIdentifiers());
    }

    /**
     * Notifies the enumerator that a checkpoint completed. If it was the prepare-close barrier,
     * {@code closeCall()} is invoked.
     */
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        getEnumerator().notifyCheckpointComplete(checkpointId);
        if (this.prepareCloseBarrierId.get() == checkpointId) {
            closeCall();
        }
    }

    /**
     * Notifies the enumerator that a checkpoint was aborted. If it was the prepare-close barrier,
     * {@code closeCall()} is invoked.
     */
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        getEnumerator().notifyCheckpointAborted(checkpointId);
        if (this.prepareCloseBarrierId.get() == checkpointId) {
            closeCall();
        }
    }
}
