/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.delegation.hive;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationImpl;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.catalog.UnresolvedIdentifier;
import org.apache.flink.table.catalog.hive.HiveCatalog;
import org.apache.flink.table.operations.CatalogSinkModifyOperation;
import org.apache.flink.table.operations.Operation;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerContext;
import org.apache.flink.table.planner.delegation.hive.HiveParserCalcitePlanner;
import org.apache.flink.table.planner.delegation.hive.HiveParserUtils;
import org.apache.flink.table.planner.delegation.hive.SqlFunctionConverter;
import org.apache.flink.table.planner.delegation.hive.copy.HiveParserQB;
import org.apache.flink.table.planner.delegation.hive.copy.HiveParserSqlFunctionConverter;
import org.apache.flink.table.planner.delegation.hive.copy.HiveParserTypeConverter;
import org.apache.flink.table.planner.operations.PlannerQueryOperation;
import org.apache.flink.table.planner.plan.nodes.hive.LogicalDistribution;
import org.apache.flink.util.Preconditions;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.exec.FunctionInfo;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.parse.QBMetaData;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.SettableUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

/**
 * Helper that turns the relational plan of a Hive-dialect {@code INSERT} query into a
 * {@link CatalogSinkModifyOperation}.
 *
 * <p>The query plan handed to this class must be rooted at a {@link Project}, a {@link Sort} or a
 * {@link LogicalDistribution}. The helper rewrites the projection(s) of that plan so that the
 * produced row matches the destination table: it re-orders / pads columns according to an
 * explicit destination column list, splices in literals for static partition values, and adds
 * casts where the query type differs from the target column type.
 */
public class HiveParserDMLHelper {
    private final PlannerContext plannerContext;
    private final SqlFunctionConverter funcConverter;
    private final CatalogManager catalogManager;

    /**
     * Creates the helper.
     *
     * @param plannerContext supplies the type factory and the cluster (and thereby the
     *     {@link RexBuilder} and trait set) used when rewriting plans
     * @param funcConverter converter applied to UDF based casts; passed through to
     *     {@link #addTypeConversions}
     * @param catalogManager used to qualify the identifier of the destination table
     */
    public HiveParserDMLHelper(PlannerContext plannerContext, SqlFunctionConverter funcConverter, CatalogManager catalogManager) {
        this.plannerContext = plannerContext;
        this.funcConverter = funcConverter;
        this.catalogManager = catalogManager;
    }

    /**
     * Builds the sink operation for inserting the rows produced by {@code queryRelNode} into
     * {@code destTable}.
     *
     * @param queryRelNode plan of the source query; must be a {@link Project}, or a {@link Sort} /
     *     {@link LogicalDistribution} whose input chain ends in a {@link Project}
     * @param destTable Hive table that receives the data
     * @param staticPartSpec static partition column name to value; may be empty but not null
     * @param destSchema explicit destination column list, or {@code null}/empty if none was given
     * @param overwrite whether existing data should be overwritten
     * @return the operation describing the insert
     * @throws SemanticException if the query does not produce the expected number of columns or
     *     a conversion function cannot be resolved
     * @throws IllegalArgumentException if the plan does not have one of the supported shapes
     */
    public CatalogSinkModifyOperation createInsertOperation(RelNode queryRelNode, Table destTable, Map<String, String> staticPartSpec, List<String> destSchema, boolean overwrite) throws SemanticException {
        // Sanity check: only SEL, SEL + SORT, SEL + DIST and SEL + DIST + SORT shapes are handled.
        Preconditions.checkArgument(
                queryRelNode instanceof Project || queryRelNode instanceof Sort || queryRelNode instanceof LogicalDistribution,
                "Expect top RelNode to be Project, Sort, or LogicalDistribution, actually got " + queryRelNode);
        if (!(queryRelNode instanceof Project)) {
            RelNode input = ((SingleRel) queryRelNode).getInput();
            Preconditions.checkArgument(
                    input instanceof Project || input instanceof LogicalDistribution,
                    "Expect input to be a Project or LogicalDistribution, actually got " + input);
            if (input instanceof LogicalDistribution) {
                RelNode distInput = ((LogicalDistribution) input).getInput();
                Preconditions.checkArgument(
                        distInput instanceof Project,
                        "Expect input of LogicalDistribution to be a Project, actually got " + distInput);
            }
        }

        // Honour an explicit destination column list, e.g. "insert into dest(c1, c2) select ...".
        queryRelNode = handleDestSchema((SingleRel) queryRelNode, destTable, destSchema, staticPartSpec.keySet());

        // Track every target column (regular columns first, then partition columns) and its type.
        FlinkTypeFactory typeFactory = plannerContext.getTypeFactory();
        Map<String, RelDataType> targetColToCalcType = new LinkedHashMap<>();
        List<TypeInfo> targetHiveTypes = new ArrayList<>();
        List<FieldSchema> allCols = new ArrayList<>(destTable.getCols());
        allCols.addAll(destTable.getPartCols());
        for (FieldSchema col : allCols) {
            TypeInfo hiveType = TypeInfoUtils.getTypeInfoFromTypeString(col.getType());
            targetHiveTypes.add(hiveType);
            targetColToCalcType.put(col.getName(), HiveParserTypeConverter.convert(hiveType, typeFactory));
        }

        if (!staticPartSpec.isEmpty()) {
            queryRelNode = addStaticPartitions(queryRelNode, destTable, staticPartSpec, targetColToCalcType);
        }

        queryRelNode = addTypeConversions(
                plannerContext.getCluster().getRexBuilder(),
                queryRelNode,
                new ArrayList<>(targetColToCalcType.values()),
                targetHiveTypes,
                funcConverter);

        List<String> targetTablePath = Arrays.asList(destTable.getDbName(), destTable.getTableName());
        UnresolvedIdentifier unresolvedIdentifier = UnresolvedIdentifier.of(targetTablePath);
        ObjectIdentifier identifier = catalogManager.qualifyIdentifier(unresolvedIdentifier);
        return new CatalogSinkModifyOperation(identifier, new PlannerQueryOperation(queryRelNode), staticPartSpec, overwrite, Collections.emptyMap());
    }

    /**
     * Builds the sink operation for the single insert clause recorded in the analyzer's query
     * block, deriving destination table, static partition spec and overwrite flag from it.
     *
     * @param analyzer planner holding the analysed query block
     * @param queryRelNode plan of the source query
     * @return the operation describing the insert
     * @throws SemanticException if no destination table or partition is recorded (reported as
     *     unsupported {@code INSERT DIRECTORY}), or if the plan cannot be adapted
     * @throws IllegalStateException if more than one destination is recorded
     */
    public Operation createInsertOperation(HiveParserCalcitePlanner analyzer, RelNode queryRelNode) throws SemanticException {
        HiveParserQB topQB = analyzer.getQB();
        QBMetaData qbMetaData = topQB.getMetaData();
        Map<String, Table> nameToDestTable = qbMetaData.getNameToDestTable();
        Map<String, Partition> nameToDestPart = qbMetaData.getNameToDestPartition();
        Preconditions.checkState(nameToDestTable.size() <= 1 && nameToDestPart.size() <= 1, "Only support inserting to 1 table");

        String insClauseName;
        Table destTable;
        if (!nameToDestTable.isEmpty()) {
            insClauseName = nameToDestTable.keySet().iterator().next();
            destTable = nameToDestTable.values().iterator().next();
        } else if (!nameToDestPart.isEmpty()) {
            insClauseName = nameToDestPart.keySet().iterator().next();
            destTable = nameToDestPart.values().iterator().next().getTable();
        } else {
            throw new SemanticException("INSERT DIRECTORY is not supported");
        }

        // Insertion-ordered so static partition values keep the order of the partition columns.
        Map<String, String> staticPartSpec = new LinkedHashMap<>();
        if (destTable.isPartitioned()) {
            List<String> partCols = HiveCatalog.getFieldNames(destTable.getTTable().getPartitionKeys());
            if (!nameToDestPart.isEmpty()) {
                // A destination partition object carries one value per partition column.
                Partition destPart = nameToDestPart.values().iterator().next();
                Preconditions.checkState(partCols.size() == destPart.getValues().size(), "Part cols and static spec doesn't match");
                for (int i = 0; i < partCols.size(); ++i) {
                    staticPartSpec.put(partCols.get(i), destPart.getValues().get(i));
                }
            } else {
                // Only the partition columns that have a value in the spec are static.
                Map<String, String> spec = qbMetaData.getPartSpecForAlias(insClauseName);
                if (spec != null) {
                    for (String partCol : partCols) {
                        String val = spec.get(partCol);
                        if (val != null) {
                            staticPartSpec.put(partCol, val);
                        }
                    }
                }
            }
        }

        // Compare case-insensitively with a fixed locale on BOTH sides: the default-locale
        // toLowerCase() can mangle names (e.g. "I" -> dotless "ı" under a Turkish locale), which
        // would silently turn an INSERT OVERWRITE into a plain INSERT.
        String qualifiedDestName = (destTable.getDbName() + "." + destTable.getTableName()).toLowerCase(Locale.ROOT);
        boolean overwrite = topQB.getParseInfo().getInsertOverwriteTables().keySet().stream()
                .anyMatch(name -> name.toLowerCase(Locale.ROOT).equals(qualifiedDestName));
        return createInsertOperation(queryRelNode, destTable, staticPartSpec, analyzer.getDestSchemaForClause(insClauseName), overwrite);
    }

    /**
     * Splices the static partition values into the projection of {@code queryRelNode} and
     * re-creates the Sort/Distribution on top of it where field indices need shifting.
     */
    private RelNode addStaticPartitions(RelNode queryRelNode, Table destTable, Map<String, String> staticPartSpec, Map<String, RelDataType> targetColToCalcType) {
        if (queryRelNode instanceof Project) {
            return replaceProjectForStaticPart((Project) queryRelNode, staticPartSpec, destTable, targetColToCalcType);
        }
        if (!(queryRelNode instanceof Sort)) {
            return replaceDistForStaticParts((LogicalDistribution) queryRelNode, destTable, staticPartSpec, targetColToCalcType);
        }

        Sort sort = (Sort) queryRelNode;
        RelNode oldInput = sort.getInput();
        RelNode newInput;
        if (oldInput instanceof LogicalDistribution) {
            newInput = replaceDistForStaticParts((LogicalDistribution) oldInput, destTable, staticPartSpec, targetColToCalcType);
        } else {
            newInput = replaceProjectForStaticPart((Project) oldInput, staticPartSpec, destTable, targetColToCalcType);
            int numDynmPart = destTable.getTTable().getPartitionKeys().size() - staticPartSpec.size();
            // The sort keys only need shifting if the sort actually has keys and there are
            // dynamic partition columns located after the inserted static ones.
            if (!sort.getCollation().getFieldCollations().isEmpty() && numDynmPart > 0) {
                sort.replaceInput(0, null);
                sort = LogicalSort.create(
                        newInput,
                        shiftRelCollation(sort.getCollation(), (Project) oldInput, staticPartSpec.size(), numDynmPart),
                        sort.offset,
                        sort.fetch);
            }
        }
        sort.replaceInput(0, newInput);
        return sort;
    }

    /**
     * Rebuilds a distribution node on top of a projection that has been extended with the static
     * partition values, shifting its collation and distribution keys accordingly.
     */
    private RelNode replaceDistForStaticParts(LogicalDistribution hiveDist, Table destTable, Map<String, String> staticPartSpec, Map<String, RelDataType> targetColToType) {
        Project project = (Project) hiveDist.getInput();
        RelNode expandedProject = replaceProjectForStaticPart(project, staticPartSpec, destTable, targetColToType);
        hiveDist.replaceInput(0, null);
        int toShift = staticPartSpec.size();
        int numDynmPart = destTable.getTTable().getPartitionKeys().size() - toShift;
        return LogicalDistribution.create(
                expandedProject,
                shiftRelCollation(hiveDist.getCollation(), project, toShift, numDynmPart),
                shiftDistKeys(hiveDist.getDistKeys(), project, toShift, numDynmPart));
    }

    /**
     * Returns a copy of {@code distKeys} in which every key at or beyond the insertion point
     * (size of the original projection minus {@code numDynmPart}) is increased by
     * {@code toShift}.
     */
    private static List<Integer> shiftDistKeys(List<Integer> distKeys, Project origProject, int toShift, int numDynmPart) {
        List<Integer> shiftedDistKeys = new ArrayList<>(distKeys.size());
        int insertIndex = origProject.getProjects().size() - numDynmPart;
        for (int key : distKeys) {
            shiftedDistKeys.add(key >= insertIndex ? key + toShift : key);
        }
        return shiftedDistKeys;
    }

    /**
     * Returns a canonized collation in which every field index at or beyond the insertion point
     * (size of the original projection minus {@code numDynmPart}) is increased by
     * {@code toShift}.
     */
    private RelCollation shiftRelCollation(RelCollation collation, Project origProject, int toShift, int numDynmPart) {
        List<RelFieldCollation> fieldCollations = collation.getFieldCollations();
        int insertIndex = origProject.getProjects().size() - numDynmPart;
        List<RelFieldCollation> shiftedCollations = new ArrayList<>(fieldCollations.size());
        for (RelFieldCollation fieldCollation : fieldCollations) {
            if (fieldCollation.getFieldIndex() >= insertIndex) {
                fieldCollation = fieldCollation.withFieldIndex(fieldCollation.getFieldIndex() + toShift);
            }
            shiftedCollations.add(fieldCollation);
        }
        return plannerContext.getCluster().traitSet().canonize(RelCollationImpl.of(shiftedCollations));
    }

    /**
     * Walks down the first input of each node until a {@link Project} is found and replaces that
     * projection with one that casts mismatching expressions to the target types.
     *
     * @param rexBuilder builder used to create the casts
     * @param queryRelNode root of the plan to adapt; modified in place via {@code replaceInput}
     * @param targetCalcTypes target column types, one per projected expression
     * @param targetHiveTypes Hive types of the target columns, aligned with {@code targetCalcTypes}
     * @param funcConverter converter for UDF based casts; when {@code null} a plain cast is used
     * @return the adapted plan root
     * @throws SemanticException if a conversion UDF cannot be resolved
     */
    static RelNode addTypeConversions(RexBuilder rexBuilder, RelNode queryRelNode, List<RelDataType> targetCalcTypes, List<TypeInfo> targetHiveTypes, SqlFunctionConverter funcConverter) throws SemanticException {
        if (queryRelNode instanceof Project) {
            return replaceProjectForTypeConversion(rexBuilder, (Project) queryRelNode, targetCalcTypes, targetHiveTypes, funcConverter);
        }
        RelNode newInput = addTypeConversions(rexBuilder, queryRelNode.getInput(0), targetCalcTypes, targetHiveTypes, funcConverter);
        queryRelNode.replaceInput(0, newInput);
        return queryRelNode;
    }

    /**
     * Creates the expression converting {@code srcRex} to the target type. A plain cast is used
     * when no converter is available, when the conversion UDF is a {@link SettableUDF}, or when
     * the converter has no overloaded operator for the UDF; otherwise the UDF call is run
     * through the converter.
     */
    private static RexNode createConversionCast(RexBuilder rexBuilder, RexNode srcRex, TypeInfo targetHiveType, RelDataType targetCalType, SqlFunctionConverter funcConverter) throws SemanticException {
        if (funcConverter == null) {
            return rexBuilder.makeCast(targetCalType, srcRex);
        }
        // The conversion UDF is looked up under the base name of the target Hive type.
        String udfName = TypeInfoUtils.getBaseName(targetHiveType.getTypeName());
        FunctionInfo functionInfo;
        try {
            functionInfo = FunctionRegistry.getFunctionInfo(udfName);
        } catch (SemanticException e) {
            throw new SemanticException(String.format("Failed to get UDF %s for casting", udfName), e);
        }
        if (functionInfo == null || functionInfo.getGenericUDF() == null) {
            throw new SemanticException(String.format("Failed to get UDF %s for casting", udfName));
        }
        if (functionInfo.getGenericUDF() instanceof SettableUDF) {
            return rexBuilder.makeCast(targetCalType, srcRex);
        }
        RexCall cast = (RexCall) rexBuilder.makeCall(
                HiveParserSqlFunctionConverter.getCalciteOperator(
                        udfName,
                        functionInfo.getGenericUDF(),
                        Collections.singletonList(srcRex.getType()),
                        targetCalType),
                srcRex);
        if (!funcConverter.hasOverloadedOp(cast.getOperator(), SqlFunctionCategory.USER_DEFINED_FUNCTION)) {
            return rexBuilder.makeCast(targetCalType, srcRex);
        }
        return cast.accept(funcConverter);
    }

    /**
     * Returns a projection whose expressions are cast to the target types where the SQL type
     * names differ and the target Hive type is primitive; returns {@code project} itself when no
     * expression needed a cast.
     */
    private static RelNode replaceProjectForTypeConversion(RexBuilder rexBuilder, Project project, List<RelDataType> targetCalcTypes, List<TypeInfo> targetHiveTypes, SqlFunctionConverter funcConverter) throws SemanticException {
        List<RexNode> exprs = project.getProjects();
        Preconditions.checkState(exprs.size() == targetCalcTypes.size(), "Expressions and target types size mismatch");
        List<RexNode> updatedExprs = new ArrayList<>(exprs.size());
        boolean updated = false;
        for (int i = 0; i < exprs.size(); ++i) {
            RexNode expr = exprs.get(i);
            RelDataType calcType = targetCalcTypes.get(i);
            if (expr.getType().getSqlTypeName() != calcType.getSqlTypeName()) {
                TypeInfo hiveType = targetHiveTypes.get(i);
                // Only primitive target types are converted; others are passed through as is.
                if (hiveType.getCategory() == ObjectInspector.Category.PRIMITIVE) {
                    expr = createConversionCast(rexBuilder, expr, hiveType, calcType, funcConverter);
                    updated = true;
                }
            }
            updatedExprs.add(expr);
        }
        if (!updated) {
            return project;
        }
        RelNode newProject = LogicalProject.create(project.getInput(), Collections.emptyList(), updatedExprs, getProjectNames(project));
        project.replaceInput(0, null);
        return newProject;
    }

    /**
     * Adapts the plan to an explicit destination column list by adding a projection that puts the
     * query columns into the table's natural column order and fills columns missing from the
     * list with typed NULLs. Returns the plan unchanged when no list is given or when the list
     * already equals the natural order (regular columns followed by non-static partition
     * columns).
     */
    private RelNode handleDestSchema(SingleRel queryRelNode, Table destTable, List<String> destSchema, Set<String> staticParts) throws SemanticException {
        if (destSchema == null || destSchema.isEmpty()) {
            return queryRelNode;
        }
        List<FieldSchema> naturalSchema = new ArrayList<>(destTable.getCols());
        if (destTable.isPartitioned()) {
            naturalSchema.addAll(
                    destTable.getTTable().getPartitionKeys().stream()
                            .filter(f -> !staticParts.contains(f.getName()))
                            .collect(Collectors.toList()));
        }
        if (destSchema.equals(HiveCatalog.getFieldNames(naturalSchema))) {
            return queryRelNode;
        }

        // One entry per natural column: the Integer position of the column in destSchema, or the
        // RelDataType of the column if it is absent from destSchema (to be filled with NULL).
        List<Object> updatedIndices = new ArrayList<>(naturalSchema.size());
        for (FieldSchema col : naturalSchema) {
            int index = destSchema.indexOf(col.getName());
            if (index < 0) {
                updatedIndices.add(
                        HiveParserTypeConverter.convert(
                                TypeInfoUtils.getTypeInfoFromTypeString(col.getType()),
                                plannerContext.getTypeFactory()));
            } else {
                updatedIndices.add(index);
            }
        }

        if (queryRelNode instanceof Project) {
            return addProjectForDestSchema((Project) queryRelNode, updatedIndices);
        }
        if (!(queryRelNode instanceof Sort)) {
            return handleDestSchemaForDist((LogicalDistribution) queryRelNode, updatedIndices);
        }

        Sort sort = (Sort) queryRelNode;
        RelNode sortInput = sort.getInput();
        if (sortInput instanceof LogicalDistribution) {
            RelNode newDist = handleDestSchemaForDist((LogicalDistribution) sortInput, updatedIndices);
            sort.replaceInput(0, newDist);
            return sort;
        }
        RelNode addedProject = addProjectForDestSchema((Project) sortInput, updatedIndices);
        // A sort with keys must be re-created because its field indices change.
        if (!sort.getCollation().getFieldCollations().isEmpty()) {
            sort.replaceInput(0, null);
            sort = LogicalSort.create(addedProject, updateRelCollation(sort.getCollation(), updatedIndices), sort.offset, sort.fetch);
        }
        sort.replaceInput(0, addedProject);
        return sort;
    }

    /**
     * Rebuilds a distribution node on top of the destination-schema projection, remapping its
     * collation and distribution keys to the new field positions.
     */
    private RelNode handleDestSchemaForDist(LogicalDistribution hiveDist, List<Object> updatedIndices) throws SemanticException {
        Project project = (Project) hiveDist.getInput();
        RelNode addedProject = addProjectForDestSchema(project, updatedIndices);
        hiveDist.replaceInput(0, null);
        return LogicalDistribution.create(
                addedProject,
                updateRelCollation(hiveDist.getCollation(), updatedIndices),
                updateDistKeys(hiveDist.getDistKeys(), updatedIndices));
    }

    /**
     * Remaps every field index of {@code collation} to the position at which that index occurs
     * in {@code updatedIndices}; an empty collation is returned unchanged.
     *
     * @throws IllegalStateException if a referenced field is not part of {@code updatedIndices}
     */
    private RelCollation updateRelCollation(RelCollation collation, List<Object> updatedIndices) {
        List<RelFieldCollation> fieldCollations = collation.getFieldCollations();
        if (fieldCollations.isEmpty()) {
            return collation;
        }
        List<RelFieldCollation> updatedCollations = new ArrayList<>(fieldCollations.size());
        for (RelFieldCollation fieldCollation : fieldCollations) {
            // Autoboxed to Integer, so indexOf matches only the Integer entries by value.
            int newIndex = updatedIndices.indexOf(fieldCollation.getFieldIndex());
            Preconditions.checkState(newIndex >= 0, "Sort/Order references a non-existing field");
            updatedCollations.add(fieldCollation.withFieldIndex(newIndex));
        }
        return plannerContext.getCluster().traitSet().canonize(RelCollationImpl.of(updatedCollations));
    }

    /**
     * Remaps every distribution key to the position at which it occurs in
     * {@code updatedIndices}.
     *
     * @throws IllegalStateException if a referenced field is not part of {@code updatedIndices}
     */
    private List<Integer> updateDistKeys(List<Integer> distKeys, List<Object> updatedIndices) {
        List<Integer> updatedDistKeys = new ArrayList<>(distKeys.size());
        for (Integer key : distKeys) {
            int newKey = updatedIndices.indexOf(key);
            Preconditions.checkState(newKey >= 0, "Cluster/Distribute references a non-existing field");
            updatedDistKeys.add(newKey);
        }
        return updatedDistKeys;
    }

    /**
     * Creates a projection equal to {@code project} with one extra expression per static
     * partition value. Each value becomes a character literal cast to the partition column's
     * target type and is inserted, in spec iteration order, in front of the trailing
     * dynamic-partition expressions.
     */
    private RelNode replaceProjectForStaticPart(Project project, Map<String, String> staticPartSpec, Table destTable, Map<String, RelDataType> targetColToType) {
        List<RexNode> extendedExprs = new ArrayList<>(project.getProjects());
        int numDynmPart = destTable.getTTable().getPartitionKeys().size() - staticPartSpec.size();
        int insertIndex = extendedExprs.size() - numDynmPart;
        RexBuilder rexBuilder = plannerContext.getCluster().getRexBuilder();
        for (Map.Entry<String, String> spec : staticPartSpec.entrySet()) {
            RexNode literal = rexBuilder.makeCharLiteral(HiveParserUtils.asUnicodeString(spec.getValue()));
            RexNode toAdd = rexBuilder.makeAbstractCast(targetColToType.get(spec.getKey()), literal);
            extendedExprs.add(insertIndex++, toAdd);
        }
        // Null field names: the cast selects the field-name overload of create().
        RelNode res = LogicalProject.create(project.getInput(), Collections.emptyList(), extendedExprs, (List<String>) null);
        project.replaceInput(0, null);
        return res;
    }

    /** Returns the output field names of {@code project} in projection order. */
    private static List<String> getProjectNames(Project project) {
        return project.getNamedProjects().stream().map(p -> p.right).collect(Collectors.toList());
    }

    /**
     * Adds a projection on top of {@code input} with one expression per entry of
     * {@code updatedIndices}: an input reference for an {@link Integer} entry, a NULL literal of
     * the given type for a {@link RelDataType} entry.
     *
     * @throws SemanticException if the number of Integer entries differs from the number of
     *     expressions produced by {@code input}
     */
    private RelNode addProjectForDestSchema(Project input, List<Object> updatedIndices) throws SemanticException {
        int destSchemaSize = (int) updatedIndices.stream().filter(o -> o instanceof Integer).count();
        if (destSchemaSize != input.getProjects().size()) {
            throw new SemanticException(String.format("Expected %d columns, but SEL produces %d columns", destSchemaSize, input.getProjects().size()));
        }
        List<RexNode> exprs = new ArrayList<>(updatedIndices.size());
        RexBuilder rexBuilder = plannerContext.getCluster().getRexBuilder();
        for (Object entry : updatedIndices) {
            if (entry instanceof Integer) {
                exprs.add(rexBuilder.makeInputRef(input, (Integer) entry));
            } else {
                exprs.add(rexBuilder.makeNullLiteral((RelDataType) entry));
            }
        }
        // Null field names: the cast selects the field-name overload of create().
        return LogicalProject.create(input, Collections.emptyList(), exprs, (List<String>) null);
    }
}

