/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveRelFieldTrimmer;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.RelFieldTrimmer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trims columns from the single input of a plan root and re-obtains them by joining back the tables
 * they originate from.
 *
 * <p>The lineage of every column of the root input is resolved to table columns. For each referenced
 * table, the first key returned by {@code RelOptHiveTable#getNonNullableKeys()} whose columns are all
 * projected as-is is selected. If the table contributes more columns than that key, the input is trimmed
 * so that only the key columns (plus constants and columns of tables that are not joined back) must
 * survive. The table is then inner-joined back on the key to restore the other columns.
 * NOTE(review): the row count is preserved only if the selected key is unique, which is presumably what
 * {@code getNonNullableKeys()} guarantees; confirm.
 */
public class HiveCardinalityPreservingJoinOptimization
extends HiveRelFieldTrimmer {
    private static final Logger LOG = LoggerFactory.getLogger(HiveCardinalityPreservingJoinOptimization.class);

    public HiveCardinalityPreservingJoinOptimization() {
        // NOTE(review): the meaning of this flag is defined by HiveRelFieldTrimmer (not visible here).
        super(false);
    }

    /**
     * Rewrites {@code root} so that its input carries fewer columns, joining back tables on a projected
     * key to restore the removed ones.
     *
     * <p>{@code root} is returned unchanged in any of these cases:
     * <ul>
     *   <li>it does not have exactly one input;</li>
     *   <li>the lineage of some input column cannot be determined;</li>
     *   <li>no table has a usable projected key;</li>
     *   <li>trimming removes nothing.</li>
     * </ul>
     *
     * @param relBuilder builder used for the joins and the final projection. It is also published through
     *     {@code REL_BUILDER} for the duration of the call and removed again in {@code finally}.
     * @param root the plan root
     * @return a copy of {@code root} over the rewritten input, or {@code root} itself
     */
    @Override
    public RelNode trim(RelBuilder relBuilder, RelNode root) {
        try {
            if (root.getInputs().size() != 1) {
                LOG.debug("Only plans where root has one input are supported. Root: {}", root);
                return root;
            }
            REL_BUILDER.set(relBuilder);
            RexBuilder rexBuilder = relBuilder.getRexBuilder();
            RelNode rootInput = root.getInput(0);
            List<RelDataTypeField> rootInputFields = rootInput.getRowType().getFieldList();
            int fieldCount = rootInputFields.size();

            // 1. Determine the lineage of every column of the root input.
            List<RexInputRef> rootFieldList = new ArrayList<>(fieldCount);
            List<String> newColumnNames = new ArrayList<>(fieldCount);
            for (int i = 0; i < fieldCount; ++i) {
                RelDataTypeField field = rootInputFields.get(i);
                rootFieldList.add(rexBuilder.makeInputRef(field.getType(), i));
                newColumnNames.add(field.getName());
            }
            BitSet constants = new BitSet();
            List<JoinedBackFields> lineages = getExpressionLineageOf(rootFieldList, rootInput, constants);
            if (lineages == null) {
                LOG.debug("Some projected field lineage can not be determined");
                return root;
            }

            // 2. Choose the tables to join back and the columns that must survive trimming.
            ImmutableBitSet fieldsUsed = ImmutableBitSet.of(constants.stream().toArray());
            List<TableToJoinBack> tableToJoinBackList = new ArrayList<>(lineages.size());
            Map<Integer, RexNode> rexNodesToShuttle = new HashMap<>(fieldCount);
            for (JoinedBackFields joinedBackFields : lineages) {
                // Only a key whose columns are projected unmodified can serve as join key: a column
                // computed from a key column (e.g. key + 1) does not hold the key value.
                Optional<ImmutableBitSet> projectedKeys = joinedBackFields.relOptHiveTable.getNonNullableKeys().stream()
                        .filter(joinedBackFields.directFieldsInSourceTable::contains)
                        .findFirst();
                if (projectedKeys.isPresent() && !projectedKeys.get().equals(joinedBackFields.fieldsInSourceTable)) {
                    tableToJoinBackList.add(new TableToJoinBack(projectedKeys.get(), joinedBackFields));
                    fieldsUsed = fieldsUsed.union(joinedBackFields.getSource(projectedKeys.get()));
                    for (TableInputRefHolder holder : joinedBackFields.mapping) {
                        if (!fieldsUsed.get(holder.indexInOriginalRowType)) {
                            rexNodesToShuttle.put(holder.indexInOriginalRowType, holder.rexNode);
                        }
                    }
                } else {
                    fieldsUsed = fieldsUsed.union(joinedBackFields.fieldsInOriginalRowType);
                }
            }
            if (tableToJoinBackList.isEmpty()) {
                LOG.debug("None of the tables has keys projected, unable to join back");
                return root;
            }

            // 3. Trim the root input down to the used columns.
            Set<RelDataTypeField> extraFields = Collections.emptySet();
            RelFieldTrimmer.TrimResult trimResult = dispatchTrimFields(rootInput, fieldsUsed, extraFields);
            RelNode newInput = (RelNode) trimResult.left;
            if (newInput.getRowType().equals(rootInput.getRowType())) {
                LOG.debug("Nothing was trimmed out.");
                return root;
            }
            Mapping newInputMapping = (Mapping) trimResult.right;

            // 4. Join back every selected table on its key.
            Map<RexTableInputRef, Integer> tableInputRefMapping = new HashMap<>();
            RelOptCluster cluster = relBuilder.getCluster();
            for (TableToJoinBack tableToJoinBack : tableToJoinBackList) {
                JoinedBackFields joinedBackFields = tableToJoinBack.joinedBackFields;
                RelOptHiveTable relOptTable = joinedBackFields.relOptHiveTable;
                LOG.debug("Joining back table {}", relOptTable.getName());
                HiveTableScan tableScan = new HiveTableScan(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
                        relOptTable, relOptTable.getHiveTableMD().getTableName(), null, false, false);
                RelNode projectTableAccessRel = tableScan.project(joinedBackFields.fieldsInSourceTable,
                        new HashSet<RelDataTypeField>(0), (RelBuilder) REL_BUILDER.get());

                // Maps a column index of the table to its position in projectTableAccessRel.
                Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION,
                        tableScan.getRowType().getFieldCount(), joinedBackFields.fieldsInSourceTable.cardinality());
                int projectIndex = 0;
                for (int sourceIndex : joinedBackFields.fieldsInSourceTable) {
                    projectMapping.set(sourceIndex, projectIndex++);
                }

                // Non-key columns are read from the right side of the join, i.e. after all current left columns.
                int offset = newInput.getRowType().getFieldCount();
                for (TableInputRefHolder holder : joinedBackFields.mapping) {
                    int indexInSourceTable = holder.tableInputRef.getIndex();
                    if (!tableToJoinBack.keys.get(indexInSourceTable)) {
                        tableInputRefMapping.put(holder.tableInputRef, offset + projectMapping.getTarget(indexInSourceTable));
                    }
                }

                relBuilder.push(newInput);
                relBuilder.push(projectTableAccessRel);
                RexNode joinCondition = joinCondition(
                        newInput, newInputMapping, tableToJoinBack, projectTableAccessRel, projectMapping, rexBuilder);
                newInput = relBuilder.join(JoinRelType.INNER, joinCondition).build();
            }

            // 5. Restore the original row type on top of the joins.
            TableInputRefMapper tableInputRefMapper = new TableInputRefMapper(tableInputRefMapping, rexBuilder, newInput);
            List<RexNode> projects = new ArrayList<>(fieldCount);
            for (int i = 0; i < fieldCount; ++i) {
                // A column in fieldsUsed survived trimming and is read directly. It may still have a shuttle
                // candidate, registered before a later table added it to fieldsUsed. Such a candidate can
                // reference columns that were never joined back, so the kept column takes precedence.
                RexNode rexNode = fieldsUsed.get(i) ? null : rexNodesToShuttle.get(i);
                if (rexNode != null) {
                    projects.add(tableInputRefMapper.apply(rexNode));
                } else {
                    int target = newInputMapping.getTarget(i);
                    projects.add(rexBuilder.makeInputRef(newInput.getRowType().getFieldList().get(target).getType(), target));
                }
            }
            relBuilder.push(newInput);
            relBuilder.project(projects, newColumnNames);
            return root.copy(root.getTraitSet(), Collections.singletonList(relBuilder.build()));
        } finally {
            REL_BUILDER.remove();
        }
    }

    /**
     * Resolves the lineage of each expression and groups the referenced table columns by table.
     *
     * @param projectExpressions references to the columns of {@code projectInput}
     * @param projectInput node whose columns are analysed
     * @param constants output: receives the index of every column whose lineage is a constant
     * @return one entry per referenced table, in order of first reference, or {@code null} when some
     *     column has no single lineage expression or has a non-constant lineage without table references
     */
    private List<JoinedBackFields> getExpressionLineageOf(List<RexInputRef> projectExpressions, RelNode projectInput, BitSet constants) {
        RelMetadataQuery relMetadataQuery = RelMetadataQuery.instance();
        Map<RexTableInputRef.RelTableRef, JoinedBackFieldsBuilder> fieldMappingBuilders = new HashMap<>();
        List<RexTableInputRef.RelTableRef> tablesOrdered = new ArrayList<>();
        for (RexInputRef expr : projectExpressions) {
            Set<RexNode> expressionLineage = relMetadataQuery.getExpressionLineage(projectInput, expr);
            if (expressionLineage == null || expressionLineage.size() != 1) {
                LOG.debug("Lineage of expression in node {} can not be determined: {}", projectInput, expr);
                return null;
            }
            RexNode rexNode = expressionLineage.iterator().next();
            Set<RexTableInputRef> refs = HiveCalciteUtil.findRexTableInputRefs(rexNode);
            if (refs.isEmpty()) {
                if (!RexUtil.isConstant(rexNode)) {
                    LOG.debug("Unknown expression that should be a constant: {}", rexNode);
                    return null;
                }
                constants.set(expr.getIndex());
                continue;
            }
            for (RexTableInputRef rexTableInputRef : refs) {
                fieldMappingBuilders.computeIfAbsent(rexTableInputRef.getTableRef(), tableRef -> {
                    // Remember first-reference order; HashMap iteration order is unspecified.
                    tablesOrdered.add(tableRef);
                    return new JoinedBackFieldsBuilder(tableRef);
                }).add(expr, rexNode, rexTableInputRef);
            }
        }
        return tablesOrdered.stream()
                .map(fieldMappingBuilders::get)
                .map(JoinedBackFieldsBuilder::build)
                .collect(Collectors.toList());
    }

    /**
     * Builds the conjunction of {@code left key column = right key column} for every key column of the
     * table being joined back.
     *
     * <p>Only columns projected unmodified are used as left-hand side, so that a computed column such as
     * {@code key + 1} is never compared with the key.
     *
     * @param leftInput the trimmed input, possibly already joined with other tables
     * @param leftInputMapping maps root-input column indexes to positions in {@code leftInput}
     * @param tableToJoinBack the table and key to join on
     * @param rightInput the projected scan of the joined-back table
     * @param rightInputKeyMapping maps table column indexes to positions in {@code rightInput}
     * @param rexBuilder builder for the condition expressions
     * @return the join condition
     */
    private RexNode joinCondition(RelNode leftInput, Mapping leftInputMapping, TableToJoinBack tableToJoinBack, RelNode rightInput, Mapping rightInputKeyMapping, RexBuilder rexBuilder) {
        List<RexNode> equalsConditions = new ArrayList<>(tableToJoinBack.keys.cardinality());
        BitSet usedKeys = new BitSet();
        int rightOffset = leftInput.getRowType().getFieldCount();
        for (TableInputRefHolder holder : tableToJoinBack.joinedBackFields.mapping) {
            int indexInSourceTable = holder.tableInputRef.getIndex();
            if (!tableToJoinBack.keys.get(indexInSourceTable) || usedKeys.get(indexInSourceTable) || !holder.projectsColumnAsIs()) {
                continue;
            }
            usedKeys.set(indexInSourceTable);
            int leftKeyIndex = leftInputMapping.getTarget(holder.indexInOriginalRowType);
            RelDataTypeField leftKeyField = leftInput.getRowType().getFieldList().get(leftKeyIndex);
            int rightKeyIndex = rightInputKeyMapping.getTarget(indexInSourceTable);
            RelDataTypeField rightKeyField = rightInput.getRowType().getFieldList().get(rightKeyIndex);
            equalsConditions.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
                    rexBuilder.makeInputRef(leftKeyField.getType(), leftKeyField.getIndex()),
                    rexBuilder.makeInputRef(rightKeyField.getType(), rightOffset + rightKeyIndex)));
        }
        return RexUtil.composeConjunction(rexBuilder, equalsConditions);
    }

    /** The columns of the root input that trace back to one table. */
    private static final class JoinedBackFields {
        private final RelOptHiveTable relOptHiveTable;
        // Root-input column indexes whose lineage references this table.
        private final ImmutableBitSet fieldsInOriginalRowType;
        // Table column indexes referenced by any lineage expression.
        private final ImmutableBitSet fieldsInSourceTable;
        // Table column indexes that some root-input column projects unmodified.
        private final ImmutableBitSet directFieldsInSourceTable;
        private final List<TableInputRefHolder> mapping;

        private JoinedBackFields(RexTableInputRef.RelTableRef relTableRef, ImmutableBitSet fieldsInOriginalRowType, ImmutableBitSet fieldsInSourceTable, ImmutableBitSet directFieldsInSourceTable, List<TableInputRefHolder> mapping) {
            this.relOptHiveTable = (RelOptHiveTable) relTableRef.getTable();
            this.fieldsInOriginalRowType = fieldsInOriginalRowType;
            this.fieldsInSourceTable = fieldsInSourceTable;
            this.directFieldsInSourceTable = directFieldsInSourceTable;
            this.mapping = mapping;
        }

        /**
         * Returns the root-input column indexes whose lineage references any of the given table columns.
         *
         * @param fields table column indexes
         * @return the matching root-input column indexes
         */
        public ImmutableBitSet getSource(ImmutableBitSet fields) {
            ImmutableBitSet.Builder targetFieldsBuilder = ImmutableBitSet.builder();
            for (TableInputRefHolder fieldMapping : mapping) {
                if (fields.get(fieldMapping.tableInputRef.getIndex())) {
                    targetFieldsBuilder.set(fieldMapping.indexInOriginalRowType);
                }
            }
            return targetFieldsBuilder.build();
        }
    }

    /** A table selected for join back together with the key it is joined on. */
    private static final class TableToJoinBack {
        private final JoinedBackFields joinedBackFields;
        private final ImmutableBitSet keys;

        private TableToJoinBack(ImmutableBitSet keys, JoinedBackFields joinedBackFields) {
            this.joinedBackFields = joinedBackFields;
            this.keys = keys;
        }
    }

    /** Links one root-input column to one table column referenced by its lineage expression. */
    private static final class TableInputRefHolder {
        private final RexTableInputRef tableInputRef;
        // Complete lineage expression of the root-input column.
        private final RexNode rexNode;
        private final int indexInOriginalRowType;

        private TableInputRefHolder(RexInputRef inputRef, RexNode rexNode, RexTableInputRef sourceTableRef) {
            this.indexInOriginalRowType = inputRef.getIndex();
            this.rexNode = rexNode;
            this.tableInputRef = sourceTableRef;
        }

        /** Returns whether the root-input column is exactly the referenced table column (no computation). */
        private boolean projectsColumnAsIs() {
            return rexNode.equals(tableInputRef);
        }
    }

    /** Rewrites table column references into input references over the joined plan. */
    private static final class TableInputRefMapper
    extends RexShuttle {
        private final Map<RexTableInputRef, Integer> tableInputRefMapping;
        private final RexBuilder rexBuilder;
        private final RelNode newInput;

        private TableInputRefMapper(Map<RexTableInputRef, Integer> tableInputRefMapping, RexBuilder rexBuilder, RelNode newInput) {
            this.tableInputRefMapping = tableInputRefMapping;
            this.rexBuilder = rexBuilder;
            this.newInput = newInput;
        }

        @Override
        public RexNode visitTableInputRef(RexTableInputRef ref) {
            Integer source = tableInputRefMapping.get(ref);
            if (source == null) {
                throw new IllegalStateException("No joined back column is available for " + ref);
            }
            RelDataTypeField field = newInput.getRowType().getFieldList().get(source);
            return rexBuilder.makeInputRef(field.getType(), source);
        }
    }

    /** Accumulates the lineage information of one table while the root-input columns are scanned. */
    private static class JoinedBackFieldsBuilder {
        private final RexTableInputRef.RelTableRef relTableRef;
        private final ImmutableBitSet.Builder fieldsInOriginalRowTypeBuilder = ImmutableBitSet.builder();
        private final ImmutableBitSet.Builder fieldsInSourceTableBuilder = ImmutableBitSet.builder();
        private final ImmutableBitSet.Builder directFieldsInSourceTableBuilder = ImmutableBitSet.builder();
        private final List<TableInputRefHolder> mapping = new ArrayList<>();

        private JoinedBackFieldsBuilder(RexTableInputRef.RelTableRef relTableRef) {
            this.relTableRef = relTableRef;
        }

        /**
         * Records that root-input column {@code rexInputRef}, whose lineage is {@code rexNode}, references
         * table column {@code sourceTableInputRef}.
         */
        public void add(RexInputRef rexInputRef, RexNode rexNode, RexTableInputRef sourceTableInputRef) {
            fieldsInOriginalRowTypeBuilder.set(rexInputRef.getIndex());
            fieldsInSourceTableBuilder.set(sourceTableInputRef.getIndex());
            if (rexNode.equals(sourceTableInputRef)) {
                directFieldsInSourceTableBuilder.set(sourceTableInputRef.getIndex());
            }
            mapping.add(new TableInputRefHolder(rexInputRef, rexNode, sourceTableInputRef));
        }

        public JoinedBackFields build() {
            return new JoinedBackFields(relTableRef, fieldsInOriginalRowTypeBuilder.build(),
                    fieldsInSourceTableBuilder.build(), directFieldsInSourceTableBuilder.build(), mapping);
        }
    }
}

