/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.adapter.druid.DruidQuery;
import org.apache.calcite.adapter.druid.DruidRules;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;

/**
 * Druid-related planner rules used by Hive.
 *
 * <p>The {@link DruidRules} instances below are the stock Calcite Druid rules, each constructed
 * with {@link HiveRelFactories#HIVE_BUILDER} so that the relational expressions they create go
 * through Hive's factories. {@link #EXPAND_SINGLE_DISTINCT_AGGREGATES_DRUID_RULE} is a
 * Hive-specific rule defined in this class.
 */
public class HiveDruidRules {
    public static final DruidRules.DruidFilterRule FILTER = new DruidRules.DruidFilterRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidProjectRule PROJECT = new DruidRules.DruidProjectRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidAggregateRule AGGREGATE = new DruidRules.DruidAggregateRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidAggregateProjectRule AGGREGATE_PROJECT = new DruidRules.DruidAggregateProjectRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidSortRule SORT = new DruidRules.DruidSortRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidSortProjectTransposeRule SORT_PROJECT_TRANSPOSE = new DruidRules.DruidSortProjectTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidProjectSortTransposeRule PROJECT_SORT_TRANSPOSE = new DruidRules.DruidProjectSortTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidProjectFilterTransposeRule PROJECT_FILTER_TRANSPOSE = new DruidRules.DruidProjectFilterTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidFilterProjectTransposeRule FILTER_PROJECT_TRANSPOSE = new DruidRules.DruidFilterProjectTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidAggregateFilterTransposeRule AGGREGATE_FILTER_TRANSPOSE = new DruidRules.DruidAggregateFilterTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidFilterAggregateTransposeRule FILTER_AGGREGATE_TRANSPOSE = new DruidRules.DruidFilterAggregateTransposeRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidPostAggregationProjectRule POST_AGGREGATION_PROJECT = new DruidRules.DruidPostAggregationProjectRule(HiveRelFactories.HIVE_BUILDER);
    public static final DruidRules.DruidHavingFilterRule HAVING_FILTER_RULE = new DruidRules.DruidHavingFilterRule(HiveRelFactories.HIVE_BUILDER);
    public static final AggregateExpandDistinctAggregatesDruidRule EXPAND_SINGLE_DISTINCT_AGGREGATES_DRUID_RULE = new AggregateExpandDistinctAggregatesDruidRule(HiveRelFactories.HIVE_BUILDER);

    /**
     * Splits an {@link Aggregate} whose input is a {@link DruidQuery} and which contains exactly
     * one distinct {@code COUNT} into two stacked aggregates.
     *
     * <p>The rule fires only when all of the following hold:
     * <ul>
     *   <li>there is exactly one {@code COUNT(DISTINCT ...)} call;</li>
     *   <li>every distinct call uses the same argument list;</li>
     *   <li>no call has a filter;</li>
     *   <li>there is at least one non-distinct call, and every non-distinct call is
     *       {@code COUNT}, {@code SUM}, {@code SUM0}, {@code MIN} or {@code MAX};</li>
     *   <li>the aggregate has no grouping sets beyond its plain group set.</li>
     * </ul>
     *
     * <p>Shape of the rewrite, with {@code G} the original group keys and {@code D} the distinct
     * arguments:
     * <ol>
     *   <li>A bottom aggregate grouped by {@code G ∪ D} computes every non-distinct call.</li>
     *   <li>A top aggregate grouped by {@code G} evaluates the distinct calls without
     *       {@code DISTINCT}. Each {@code D} tuple now occurs at most once per {@code G} group.
     *       It also rolls up the partial non-distinct results: {@code COUNT} becomes a
     *       {@link SqlSumEmptyIsZeroAggFunction} over the partial counts, and every other call
     *       re-applies its own function.</li>
     * </ol>
     */
    public static class AggregateExpandDistinctAggregatesDruidRule
    extends RelOptRule {
        /**
         * Creates the rule, matching an {@link Aggregate} whose input is a {@link DruidQuery}.
         *
         * @param relBuilderFactory factory for the {@link RelBuilder} used to build the rewrite
         */
        public AggregateExpandDistinctAggregatesDruidRule(RelBuilderFactory relBuilderFactory) {
            super(operand(Aggregate.class, operand(DruidQuery.class, none())), relBuilderFactory, null);
        }

        /**
         * Checks the preconditions listed on the class and, when they all hold, registers the
         * two-level aggregate as an equivalent of the matched one.
         *
         * @param call the rule call; {@code call.rel(0)} is the matched {@link Aggregate}
         */
        @Override
        public void onMatch(RelOptRuleCall call) {
            final Aggregate aggregate = (Aggregate) call.rel(0);
            if (!aggregate.containsDistinctCall()) {
                return;
            }
            // Both replacement aggregates are created with groupSets == null, which means a plain
            // GROUP BY over groupSet. Rewriting a GROUPING SETS / ROLLUP aggregate would
            // therefore change its result, so only plain grouping is handled.
            if (aggregate.getGroupSets().size() != 1
                    || !aggregate.getGroupSets().get(0).equals(aggregate.getGroupSet())) {
                return;
            }
            final long numCountDistinct = aggregate.getAggCallList().stream()
                    .filter(c -> c.isDistinct() && c.getAggregation().getKind() == SqlKind.COUNT)
                    .count();
            if (numCountDistinct != 1L) {
                return;
            }

            int nonDistinctAggCallCount = 0;
            int filterCount = 0;
            int unsupportedNonDistinctAggCallCount = 0;
            // Distinct (argList, filterArg) pairs. Insertion order is kept for determinism.
            final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
            for (AggregateCall aggCall : aggregate.getAggCallList()) {
                if (aggCall.filterArg >= 0) {
                    ++filterCount;
                }
                if (aggCall.isDistinct()) {
                    argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
                    continue;
                }
                ++nonDistinctAggCallCount;
                // Only these functions can be rolled up from the bottom aggregate's partials.
                switch (aggCall.getAggregation().getKind()) {
                    case COUNT:
                    case SUM:
                    case SUM0:
                    case MIN:
                    case MAX:
                        break;
                    default:
                        ++unsupportedNonDistinctAggCallCount;
                        break;
                }
            }
            // argLists.size() == 1 is required by convertSingletonDistinct. Without this check,
            // a second distinct call over different arguments (e.g. SUM(DISTINCT y) next to
            // COUNT(DISTINCT x)) would fail its precondition and throw.
            if (filterCount == 0
                    && unsupportedNonDistinctAggCallCount == 0
                    && nonDistinctAggCallCount > 0
                    && argLists.size() == 1) {
                final RelBuilder relBuilder = call.builder();
                this.convertSingletonDistinct(relBuilder, aggregate, argLists);
                call.transformTo(relBuilder.build());
            }
        }

        /**
         * Pushes the two-level replacement for {@code aggregate} onto {@code relBuilder}.
         *
         * @param relBuilder builder that receives the bottom aggregate and then the top aggregate;
         *     the top aggregate is left on its stack
         * @param aggregate the aggregate being rewritten
         * @param argLists the (argument list, filter) pairs of the distinct calls; must contain
         *     exactly one element
         * @return {@code relBuilder}
         * @throws IllegalArgumentException if {@code argLists} does not have exactly one element
         */
        private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
            Preconditions.checkArgument(argLists.size() == 1);
            relBuilder.push(aggregate.getInput());
            final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
            final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

            // Bottom group keys are the original keys plus the (single) distinct argument list.
            // The TreeSet keeps them sorted, matching the bottom aggregate's output column order.
            final TreeSet<Integer> bottomGroupSet = new TreeSet<>(originalGroupSet.asList());
            bottomGroupSet.addAll(argLists.iterator().next().left);
            final ImmutableBitSet bottomGroupBitSet = ImmutableBitSet.of(bottomGroupSet);

            // Bottom aggregate: compute every non-distinct call at the finer grouping.
            final RelNode bottomInput = relBuilder.peek();
            final int bottomGroupCount = bottomGroupBitSet.cardinality();
            final List<AggregateCall> bottomAggregateCalls = Lists.newArrayList();
            for (AggregateCall aggCall : originalAggCalls) {
                if (!aggCall.isDistinct()) {
                    bottomAggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), false,
                            aggCall.isApproximate(), aggCall.getArgList(), -1, bottomGroupCount,
                            bottomInput, null, aggCall.name));
                }
            }
            relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
                    bottomGroupBitSet, null, bottomAggregateCalls));

            // Top aggregate: regroup by the original keys and combine the bottom results.
            final RelNode topInput = relBuilder.peek();
            final int topGroupCount = originalGroupSet.cardinality();
            final List<AggregateCall> topAggregateCalls = Lists.newArrayList();
            int nonDistinctOrdinal = 0;
            for (AggregateCall aggCall : originalAggCalls) {
                final AggregateCall newCall;
                if (aggCall.isDistinct()) {
                    // The distinct arguments are now group keys of the bottom aggregate, so
                    // DISTINCT can be dropped. Each argument is remapped to its key position.
                    final List<Integer> newArgList = new ArrayList<>();
                    for (int arg : aggCall.getArgList()) {
                        newArgList.add(bottomGroupSet.headSet(arg).size());
                    }
                    newCall = AggregateCall.create(aggCall.getAggregation(), false,
                            aggCall.isApproximate(), newArgList, -1, topGroupCount, topInput,
                            aggCall.getType(), aggCall.name);
                } else {
                    // The bottom aggregate emits its group keys first, then its calls in order.
                    final List<Integer> newArgs = new ArrayList<>(1);
                    newArgs.add(bottomGroupSet.size() + nonDistinctOrdinal);
                    ++nonDistinctOrdinal;
                    // Partial counts are summed (empty input -> 0). The other supported
                    // functions are re-applied to their own partial results.
                    final SqlAggFunction topFunction = aggCall.getAggregation().getKind() == SqlKind.COUNT
                            ? new SqlSumEmptyIsZeroAggFunction()
                            : aggCall.getAggregation();
                    newCall = AggregateCall.create(topFunction, false, aggCall.isApproximate(),
                            newArgs, -1, topGroupCount, topInput, aggCall.getType(),
                            aggCall.getName());
                }
                topAggregateCalls.add(newCall);
            }

            // Positions, within the bottom output, of the keys that belong to the original group.
            final List<Integer> topGroupKeys = new ArrayList<>();
            int position = 0;
            for (int bottomKey : bottomGroupSet) {
                if (originalGroupSet.get(bottomKey)) {
                    topGroupKeys.add(position);
                }
                ++position;
            }
            relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
                    aggregate.indicator, ImmutableBitSet.of(topGroupKeys), null, topAggregateCalls));
            return relBuilder;
        }
    }
}

