/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.rules.ImmutableAggregateMergeRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;

/**
 * Rule that matches an {@link Aggregate} whose input is another {@link Aggregate} and
 * replaces the pair with a single {@link Aggregate} applied to the bottom aggregate's input.
 *
 * <p>The rule transforms only when all of the following hold:
 * <ul>
 *   <li>every group key of the top aggregate refers to a group key (not an aggregate
 *       result) of the bottom aggregate;</li>
 *   <li>every top aggregate call has exactly one argument, and that argument refers to one
 *       of the bottom aggregate's calls;</li>
 *   <li>both calls of each pair pass {@link #isAggregateSupported} and can be combined via
 *       {@link SqlSplittableAggFunction#merge}.</li>
 * </ul>
 * The default configuration additionally requires the bottom aggregate to satisfy
 * {@link Aggregate#isSimple}.
 */
@Value.Enclosing
public class AggregateMergeRule
extends RelRule<Config>
implements TransformationRule {
    /** Creates an AggregateMergeRule from a rule configuration. */
    protected AggregateMergeRule(Config config) {
        super(config);
    }

    /**
     * Creates an AggregateMergeRule that matches exactly the given operand.
     *
     * @param operand operand tree the rule matches exactly
     * @param relBuilderFactory factory passed to {@code Config#withRelBuilderFactory}
     * @deprecated Build the rule from {@link Config#DEFAULT} via {@link Config#toRule()}
     *     instead.
     */
    @Deprecated
    public AggregateMergeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT
                .withRelBuilderFactory(relBuilderFactory)
                .withOperandSupplier(b -> b.exactly(operand))
                .as(Config.class));
    }

    /**
     * Returns whether an aggregate call may take part in a merge.
     *
     * <p>The call is rejected if any of these hold:
     * <ul>
     *   <li>it is DISTINCT;</li>
     *   <li>it has a FILTER;</li>
     *   <li>it is approximate;</li>
     *   <li>it has more than one argument.</li>
     * </ul>
     * Otherwise it is accepted only if its aggregate function can be unwrapped to a
     * {@link SqlSplittableAggFunction}.
     *
     * @param aggCall aggregate call to test
     * @return {@code true} if the call is eligible for merging
     */
    private static boolean isAggregateSupported(AggregateCall aggCall) {
        if (aggCall.isDistinct()
                || aggCall.hasFilter()
                || aggCall.isApproximate()
                || aggCall.getArgList().size() > 1) {
            return false;
        }
        return aggCall.getAggregation().maybeUnwrap(SqlSplittableAggFunction.class).isPresent();
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate topAgg = call.rel(0);
        final Aggregate bottomAgg = call.rel(1);
        // Cheap pre-check: the top aggregate cannot group by more keys than the bottom one has.
        if (topAgg.getGroupCount() > bottomAgg.getGroupCount()) {
            return;
        }

        // Maps the output position of each bottom group key to its input column, in bit order.
        // For example, bottom group {2, 5} yields {0 -> 2, 1 -> 5}.
        final ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
        final Map<Integer, Integer> map = new HashMap<>();
        bottomGroupSet.forEachInt(v -> map.put(map.size(), v));

        // Every top group key must reference a bottom group key, not a bottom aggregate result.
        for (int k : topAgg.getGroupSet()) {
            if (!map.containsKey(k)) {
                return;
            }
        }

        // Top group keys, rewritten in terms of the bottom input, must be a subset of the
        // bottom group keys.
        final ImmutableBitSet topGroupSet = topAgg.getGroupSet().permute(map);
        if (!bottomGroupSet.contains(topGroupSet)) {
            return;
        }

        final boolean hasEmptyGroup =
                topAgg.getGroupSets().stream().anyMatch(ImmutableBitSet::isEmpty);

        final List<AggregateCall> finalCalls = new ArrayList<>();
        for (AggregateCall topCall : topAgg.getAggCallList()) {
            if (!isAggregateSupported(topCall) || topCall.getArgList().isEmpty()) {
                return;
            }
            // The top call's single argument must point past the bottom group keys, into the
            // bottom aggregate calls; otherwise bail out.
            final int bottomIndex = topCall.getArgList().get(0) - bottomGroupSet.cardinality();
            if (bottomIndex >= bottomAgg.getAggCallList().size() || bottomIndex < 0) {
                return;
            }
            final AggregateCall bottomCall = bottomAgg.getAggCallList().get(bottomIndex);
            // Refuse to merge over a bottom COUNT when the top aggregate has an empty grouping
            // set, unless the top call is SUM0.
            // Presumably this is because, on empty input, the unmerged plan's top call yields
            // NULL while the merged COUNT would yield 0. NOTE(review): confirm against the
            // SqlSplittableAggFunction implementations.
            if (!isAggregateSupported(bottomCall)
                    || (bottomCall.getAggregation() == SqlStdOperatorTable.COUNT
                        && topCall.getAggregation().getKind() != SqlKind.SUM0
                        && hasEmptyGroup)) {
                return;
            }
            final SqlSplittableAggFunction splitter =
                    bottomCall.getAggregation().unwrapOrThrow(SqlSplittableAggFunction.class);
            final AggregateCall finalCall = splitter.merge(topCall, bottomCall);
            // The splitter could not combine this pair of calls; bail out.
            if (finalCall == null) {
                return;
            }
            finalCalls.add(finalCall);
        }

        // For non-SIMPLE group types, re-map the top grouping sets onto the bottom input.
        ImmutableList<ImmutableBitSet> newGroupingSets = null;
        if (topAgg.getGroupType() != Aggregate.Group.SIMPLE) {
            newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(
                    ImmutableBitSet.permute(topAgg.getGroupSets(), map));
        }

        final Aggregate finalAgg = topAgg.copy(topAgg.getTraitSet(), bottomAgg.getInput(),
                topGroupSet, newGroupingSets, finalCalls);
        call.transformTo(finalAgg);
    }

    /** Rule configuration. */
    @Value.Immutable
    public interface Config
    extends RelRule.Config {
        /**
         * Default configuration: matches an {@link Aggregate} whose single input is an
         * {@link Aggregate} satisfying {@link Aggregate#isSimple}.
         */
        Config DEFAULT = ImmutableAggregateMergeRule.Config.of()
                .withOperandSupplier(b0 ->
                        b0.operand(Aggregate.class).oneInput(b1 ->
                                b1.operand(Aggregate.class)
                                        .predicate(Aggregate::isSimple)
                                        .anyInputs()))
                .as(Config.class);

        @Override
        default AggregateMergeRule toRule() {
            return new AggregateMergeRule(this);
        }
    }
}

