package com.hazelcast.jet.impl.pipeline.transform;

import com.hazelcast.function.BiFunctionEx;
import com.hazelcast.function.FunctionEx;
import com.hazelcast.function.Functions;
import com.hazelcast.jet.aggregate.AggregateOperation;
import com.hazelcast.jet.core.Edge;
import com.hazelcast.jet.core.Partitioner;
import com.hazelcast.jet.core.Vertex;
import com.hazelcast.jet.core.processor.Processors;
import com.hazelcast.jet.impl.pipeline.PipelineImpl;
import com.hazelcast.jet.impl.pipeline.Planner;
import java.util.List;
import javax.annotation.Nonnull;

/* loaded from: input_file:com/hazelcast/jet/impl/pipeline/transform/GroupTransform.class */
/**
 * Pipeline transform that groups the items of one or more upstream transforms
 * by key and aggregates each group with an {@link AggregateOperation}.
 * <p>
 * With a single upstream the transform is named {@code "group-and-aggregate"}.
 * With {@code N > 1} upstreams it is named {@code "N-way cogroup-and-aggregate"}.
 * The key function at index {@code i} of {@code groupKeyFns} is applied to the
 * inbound edge with ordinal {@code i}.
 *
 * @param <K>   type of the grouping key
 * @param <A>   type of the aggregate operation's accumulator
 * @param <R>   type of the aggregate operation's result
 * @param <OUT> type of the item emitted for each group
 */
public class GroupTransform<K, A, R, OUT> extends AbstractTransform {
    private static final long serialVersionUID = 1L;

    /** Key extractors, indexed by the ordinal of the inbound edge they apply to. */
    @Nonnull
    private final List<FunctionEx<?, ? extends K>> groupKeyFns;

    /** Aggregation applied to the items of each group. */
    @Nonnull
    private final AggregateOperation<A, R> aggrOp;

    /** Maps a group key and its aggregation result to the emitted item. */
    @Nonnull
    private final BiFunctionEx<? super K, ? super R, OUT> mapToOutputFn;

    /**
     * Creates the transform. The references are stored as-is (no copies are made).
     *
     * @param upstream      the transforms whose output is grouped; its size also
     *                      determines the transform's name
     * @param groupKeyFns   key extractors; the one at index {@code i} is used for
     *                      the inbound edge with ordinal {@code i} (presumably the
     *                      same order as {@code upstream} -- confirm with callers)
     * @param aggrOp        aggregate operation applied per key
     * @param mapToOutputFn maps each key and its aggregation result to the output item
     */
    public GroupTransform(
            @Nonnull List<Transform> upstream,
            @Nonnull List<FunctionEx<?, ? extends K>> groupKeyFns,
            @Nonnull AggregateOperation<A, R> aggrOp,
            @Nonnull BiFunctionEx<? super K, ? super R, OUT> mapToOutputFn
    ) {
        super(createName(upstream), upstream);
        this.groupKeyFns = groupKeyFns;
        this.aggrOp = aggrOp;
        this.mapToOutputFn = mapToOutputFn;
    }

    /**
     * Returns {@code "group-and-aggregate"} for exactly one upstream and
     * {@code "<size>-way cogroup-and-aggregate"} otherwise.
     */
    private static String createName(@Nonnull List<Transform> upstream) {
        int upstreamCount = upstream.size();
        return upstreamCount == 1
                ? "group-and-aggregate"
                : upstreamCount + "-way cogroup-and-aggregate";
    }

    /**
     * Adds this transform's vertices and edges to the planner's DAG.
     * <p>
     * A single-stage plan is used when any input should be rebalanced or when
     * the aggregate operation has no combine function. The latter is presumably
     * because the second stage ({@code combineByKeyP}) must merge partial
     * accumulators. Otherwise a two-stage (accumulate, then combine) plan is used.
     */
    @Override
    public void addToDag(Planner planner, PipelineImpl.Context context) {
        determineLocalParallelism(Vertex.LOCAL_PARALLELISM_USE_DEFAULT, context, false);
        if (shouldRebalanceAnyInput() || aggrOp.combineFn() == null) {
            addToDagSingleStage(planner);
        } else {
            addToDagTwoStage(planner);
        }
    }

    /**
     * Plans one {@code aggregateByKeyP} vertex.
     * <p>
     * Every inbound edge is marked distributed and partitioned by the key
     * function at that edge's ordinal.
     */
    private void addToDagSingleStage(Planner planner) {
        Planner.PlannerVertex aggregateVertex = planner.addVertex(this, name(),
                determinedLocalParallelism(),
                Processors.aggregateByKeyP(groupKeyFns, aggrOp, mapToOutputFn));
        planner.addEdges(this, aggregateVertex.v,
                (edge, ordinal) -> edge.distributed().partitioned(groupKeyFns.get(ordinal)));
    }

    /**
     * Plans two vertices.
     * <ol>
     *     <li>An {@code accumulateByKeyP} vertex, named with
     *     {@link AggregateTransform#FIRST_STAGE_VERTEX_NAME_SUFFIX}. Its inbound
     *     edges are partitioned by key with {@link Partitioner#HASH_CODE} but are
     *     not marked distributed.</li>
     *     <li>A {@code combineByKeyP} vertex that carries this transform's name. It
     *     is fed from the first vertex over a distributed edge partitioned by
     *     {@link Functions#entryKey()}.</li>
     * </ol>
     * Both vertices use the determined local parallelism.
     */
    private void addToDagTwoStage(Planner planner) {
        String vertexName = name();
        int localParallelism = determinedLocalParallelism();
        Vertex accumulateVertex = planner.dag
                .newVertex(vertexName + AggregateTransform.FIRST_STAGE_VERTEX_NAME_SUFFIX,
                        Processors.accumulateByKeyP(groupKeyFns, aggrOp))
                .localParallelism(localParallelism);
        Planner.PlannerVertex combineVertex = planner.addVertex(this, vertexName, localParallelism,
                Processors.combineByKeyP(aggrOp, mapToOutputFn));
        // No raw cast needed: the captured key type of groupKeyFns.get(ordinal)
        // is compatible with Partitioner.HASH_CODE.
        planner.addEdges(this, accumulateVertex,
                (edge, ordinal) -> edge.partitioned(groupKeyFns.get(ordinal), Partitioner.HASH_CODE));
        // Presumably the first stage emits key/accumulator entries, hence entryKey().
        planner.dag.edge(Edge.between(accumulateVertex, combineVertex.v)
                .distributed()
                .partitioned(Functions.entryKey()));
    }
}
