/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBeans;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Planner rule that rewrites an {@link Aggregate} containing
 * {@code WITHIN DISTINCT} aggregate calls (calls whose
 * {@link AggregateCall#distinctKeys} is non-null) into two stacked aggregates.
 *
 * <p>The lower aggregate groups the input by several grouping sets:
 * <ul>
 *   <li>the original group set, and</li>
 *   <li>for each distinct-key set, the union of those keys with the original
 *   group set.</li>
 * </ul>
 * It computes the following:
 * <ul>
 *   <li>plain (non-{@code WITHIN DISTINCT}) calls directly;</li>
 *   <li>{@code MIN}, and optionally {@code MAX}, of every argument of a
 *   {@code WITHIN DISTINCT} call;</li>
 *   <li>a {@code GROUPING} indicator over all group keys.</li>
 * </ul>
 *
 * <p>The upper aggregate regroups by the original group set. Each call
 * carries a filter on the {@code GROUPING} value, so it only sees rows
 * produced by the grouping set that belongs to it.
 *
 * <p>For example, for {@code SUM(sal) WITHIN DISTINCT (job) ... GROUP BY
 * deptno} the lower aggregate uses grouping sets {@code {deptno}} and
 * {@code {deptno, job}}.
 *
 * <p>Calls marked {@code DISTINCT} (e.g. {@code COUNT(DISTINCT x)}) are first
 * turned into {@code WITHIN DISTINCT} calls over their own arguments.
 */
public class AggregateExpandWithinDistinctRule
    extends RelRule<AggregateExpandWithinDistinctRule.Config> {

    /** Creates an AggregateExpandWithinDistinctRule. */
    protected AggregateExpandWithinDistinctRule(Config config) {
        super(config);
    }

    /**
     * Returns whether this rule may fire on {@code aggregate}.
     *
     * <p>That is the case when all of the following hold:
     * <ul>
     *   <li>at least one call has distinct keys;</li>
     *   <li>no call is still reducible by
     *   {@link CoreRules#AGGREGATE_REDUCE_FUNCTIONS};</li>
     *   <li>no call has a {@code FILTER} argument;</li>
     *   <li>the aggregate uses a simple (single) group set.</li>
     * </ul>
     *
     * @param aggregate candidate aggregate
     * @return true if the rule applies
     */
    private static boolean hasWithinDistinct(Aggregate aggregate) {
        final List<AggregateCall> calls = aggregate.getAggCallList();
        return calls.stream().anyMatch(c -> c.distinctKeys != null)
            // Do not fire while AGGREGATE_REDUCE_FUNCTIONS could still rewrite a call.
            && calls.stream().noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
            // FILTER arguments are not handled by this rewrite.
            && calls.stream().noneMatch(c -> c.filterArg >= 0)
            // GROUPING SETS / CUBE / ROLLUP are not handled by this rewrite.
            && aggregate.getGroupType() == Aggregate.Group.SIMPLE;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);

        // Turn DISTINCT calls into WITHIN DISTINCT calls over their arguments.
        final List<AggregateCall> aggCallList = aggregate.getAggCallList()
            .stream()
            .map(c -> unDistinct(c, aggregate.getInput()::fieldIsNullable))
            .collect(Util.toImmutableList());

        // Collect the distinct-key sets used by the calls.
        final ArrayListMultimap<ImmutableBitSet, AggregateCall> argLists =
            ArrayListMultimap.create();
        // Sentinel key for calls with no WITHIN DISTINCT. It holds a single bit
        // one past the last input field, so it cannot equal any real key set
        // (not even the empty set, i.e. "WITHIN DISTINCT ()").
        final ImmutableBitSet notDistinct =
            ImmutableBitSet.of(aggregate.getInput().getRowType().getFieldCount());
        for (AggregateCall aggCall : aggCallList) {
            ImmutableBitSet distinctKeys = aggCall.distinctKeys;
            if (distinctKeys == null) {
                distinctKeys = notDistinct;
            } else if (distinctKeys.intersects(aggregate.getGroupSet())) {
                // Keys that are already group keys are redundant; drop them.
                distinctKeys = distinctKeys.rebuild()
                    .removeAll(aggregate.getGroupSet())
                    .build();
            }
            argLists.put(distinctKeys, aggCall);
            assert aggCall.filterArg < 0;
        }

        // Grouping sets of the lower aggregate:
        // - the original group set;
        // - plus, per distinct-key set, that set unioned with the group set.
        // The TreeSet removes duplicates and gives a deterministic order.
        final TreeSet<ImmutableBitSet> groupSetTreeSet =
            new TreeSet<>(ImmutableBitSet.ORDERING);
        groupSetTreeSet.add(aggregate.getGroupSet());
        for (ImmutableBitSet key : argLists.keySet()) {
            if (key.equals(notDistinct)) {
                continue;
            }
            groupSetTreeSet.add(key.union(aggregate.getGroupSet()));
        }
        final ImmutableList<ImmutableBitSet> groupSets =
            ImmutableList.copyOf(groupSetTreeSet);
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

        // Argument list for GROUPING: the original group keys first, followed by
        // the remaining keys of fullGroupSet.
        final LinkedHashSet<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
        fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
        fullGroupOrderedSet.addAll(fullGroupSet.asSet());
        final ImmutableIntList fullGroupList =
            ImmutableIntList.copyOf(fullGroupOrderedSet);

        final RelBuilder b = call.builder();
        b.push(aggregate.getInput());
        final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();

        /**
         * Records which output ordinal of the lower aggregate holds each value
         * that the upper aggregate will read.
         */
        class Registrar {
            /** Number of group fields in the lower aggregate's output. */
            final int g = fullGroupSet.cardinality();
            /** Input field ordinal to the output ordinal of its MIN call. */
            final Map<Integer, Integer> args = new HashMap<>();
            /** Aggregate call index to the output ordinal of its lower call. */
            final Map<Integer, Integer> aggs = new HashMap<>();

            /** Maps input field ordinals to the output ordinals of their MIN calls. */
            List<Integer> fields(List<Integer> fields) {
                return Util.transform(fields, this::field);
            }

            /** Returns the output ordinal of the MIN call for a registered field. */
            int field(int field) {
                return Objects.requireNonNull(args.get(field));
            }

            /**
             * Registers MIN (and, if throwIfNotUnique, MAX immediately after it)
             * of an input field, once per field.
             *
             * @return the output ordinal of the MIN call
             */
            int register(int field) {
                return args.computeIfAbsent(field, key -> {
                    final int ordinal = g + aggCalls.size();
                    aggCalls.add(b.aggregateCall(SqlStdOperatorTable.MIN, b.field(key)));
                    if (config.throwIfNotUnique()) {
                        // MAX sits at ordinal + 1; compared against MIN later.
                        aggCalls.add(b.aggregateCall(SqlStdOperatorTable.MAX, b.field(key)));
                    }
                    return ordinal;
                });
            }

            /** Adds a lower aggregate call for call index {@code i}; returns its ordinal. */
            int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
                final int ordinal = g + aggCalls.size();
                aggs.put(i, ordinal);
                aggCalls.add(aggregateCall);
                return ordinal;
            }

            /** Returns the output ordinal registered for call index {@code i}. */
            int getAgg(int i) {
                return Objects.requireNonNull(aggs.get(i));
            }
        }

        // Phase 1: the lower aggregate, over the grouping sets.
        final Registrar registrar = new Registrar();
        Ord.forEach(aggCallList, (c, i) -> {
            if (c.distinctKeys == null) {
                registrar.registerAgg(i,
                    b.aggregateCall(c.getAggregation(), b.fields(c.getArgList())));
            } else {
                c.getArgList().forEach(registrar::register);
            }
        });
        final int grouping = registrar.registerAgg(-1,
            b.aggregateCall(SqlStdOperatorTable.GROUPING, b.fields(fullGroupList)));
        b.aggregate(b.groupKey(fullGroupSet, groupSets), aggCalls);

        // Phase 2: the upper aggregate, over the original group set. Each call is
        // filtered to rows of the grouping set that matches its distinct keys.
        aggCalls.clear();
        final String message = "more than one distinct value in agg UNIQUE_VALUE";
        Ord.forEach(aggCallList, (c, i) -> {
            final List<RexNode> filters = new ArrayList<>();
            final RexNode groupFilter = b.equals(b.field(grouping),
                b.literal(
                    AggregateExpandDistinctAggregatesRule.groupValue(fullGroupList,
                        union(aggregate.getGroupSet(), c.distinctKeys))));
            filters.add(groupFilter);
            final RelBuilder.AggCall aggCall;
            if (c.distinctKeys == null) {
                // Already computed by the lower aggregate; the filter selects one row.
                aggCall = b.aggregateCall(SqlStdOperatorTable.MIN,
                    b.field(registrar.getAgg(i)));
            } else {
                aggCall = b.aggregateCall(c.getAggregation(),
                    b.fields(registrar.fields(c.getArgList())));
                if (config.throwIfNotUnique()) {
                    for (int j : c.getArgList()) {
                        final int minOrdinal = registrar.field(j);
                        // Within this call's grouping set, MIN and MAX must agree;
                        // otherwise THROW_UNLESS raises the message.
                        filters.add(
                            b.call(SqlInternalOperators.THROW_UNLESS,
                                b.or(b.not(groupFilter),
                                    b.isNotDistinctFrom(b.field(minOrdinal),
                                        b.field(minOrdinal + 1))),
                                b.literal(message)));
                    }
                }
            }
            aggCalls.add(aggCall.filter(b.and(filters)));
        });
        b.aggregate(
            b.groupKey(
                AggregateExpandDistinctAggregatesRule.remap(fullGroupSet,
                    aggregate.getGroupSet()),
                AggregateExpandDistinctAggregatesRule.remap(fullGroupSet,
                    aggregate.getGroupSets())),
            aggCalls);
        b.convert(aggregate.getRowType(), false);
        call.transformTo(b.build());
    }

    /**
     * Rewrites a {@code DISTINCT} aggregate call as a non-distinct call whose
     * distinct keys are the call's original arguments. Other calls are
     * returned unchanged.
     *
     * <p>For {@code COUNT}, non-nullable arguments are removed from the
     * argument list. Counting a non-nullable value counts every row anyway.
     *
     * @param aggregateCall call to rewrite
     * @param isNullable    tests whether an input field ordinal is nullable
     * @return the rewritten call, or {@code aggregateCall} if it is not distinct
     */
    private static AggregateCall unDistinct(AggregateCall aggregateCall,
        IntPredicate isNullable) {
        if (!aggregateCall.isDistinct()) {
            return aggregateCall;
        }
        final boolean isCount =
            aggregateCall.getAggregation().getKind() == SqlKind.COUNT;
        final List<Integer> argList = aggregateCall.getArgList();
        final List<Integer> newArgList = argList.stream()
            .filter(arg -> !isCount || isNullable.test(arg))
            .collect(Collectors.toList());
        return aggregateCall.withDistinct(false)
            .withDistinctKeys(ImmutableBitSet.of(argList))
            .withArgList(newArgList);
    }

    /** Returns {@code s0} unioned with {@code s1}, or {@code s0} if {@code s1} is null. */
    private static ImmutableBitSet union(ImmutableBitSet s0,
        @Nullable ImmutableBitSet s1) {
        return s1 == null ? s0 : s0.union(s1);
    }

    /** Rule configuration. */
    public interface Config extends RelRule.Config {
        /** Default configuration: matches a LogicalAggregate that has WITHIN DISTINCT calls. */
        Config DEFAULT = EMPTY
            .withOperandSupplier(b ->
                b.operand(LogicalAggregate.class)
                    .predicate(AggregateExpandWithinDistinctRule::hasWithinDistinct)
                    .anyInputs())
            .as(Config.class);

        @Override
        default AggregateExpandWithinDistinctRule toRule() {
            return new AggregateExpandWithinDistinctRule(this);
        }

        /**
         * Whether the rewrite also computes MAX next to MIN for each
         * WITHIN DISTINCT argument. If so, it adds a THROW_UNLESS filter that
         * fails when the two differ. Defaults to true.
         */
        @ImmutableBeans.Property
        @ImmutableBeans.BooleanDefault(true)
        boolean throwIfNotUnique();

        /** Sets {@link #throwIfNotUnique()}. */
        Config withThrowIfNotUnique(boolean throwIfNotUnique);
    }
}

