/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.shaded.org.apache.calcite.rel.rules;

import com.hazelcast.shaded.com.google.common.collect.ImmutableList;
import com.hazelcast.shaded.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.shaded.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.shaded.org.apache.calcite.plan.RelRule;
import com.hazelcast.shaded.org.apache.calcite.rel.RelCollations;
import com.hazelcast.shaded.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.shaded.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.shaded.org.apache.calcite.rel.core.Project;
import com.hazelcast.shaded.org.apache.calcite.rel.rules.ImmutableAggregateCaseToFilterRule;
import com.hazelcast.shaded.org.apache.calcite.rel.rules.TransformationRule;
import com.hazelcast.shaded.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.shaded.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.shaded.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.shaded.org.apache.calcite.rex.RexCall;
import com.hazelcast.shaded.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.shaded.org.apache.calcite.rex.RexNode;
import com.hazelcast.shaded.org.apache.calcite.sql.SqlKind;
import com.hazelcast.shaded.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.shaded.org.apache.calcite.sql.SqlPostfixOperator;
import com.hazelcast.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.shaded.org.apache.calcite.sql.type.SqlTypeName;
import com.hazelcast.shaded.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.shaded.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.shaded.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.immutables.value.Value;

@Value.Enclosing
public class AggregateCaseToFilterRule
extends RelRule<Config>
implements TransformationRule {
    protected AggregateCaseToFilterRule(Config config) {
        super(config);
    }

    @Deprecated
    protected AggregateCaseToFilterRule(RelBuilderFactory relBuilderFactory, String description) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).withDescription(description).as(Config.class));
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            int singleArg = AggregateCaseToFilterRule.soleArgument(aggregateCall);
            if (singleArg < 0 || !AggregateCaseToFilterRule.isThreeArgCase(project.getProjects().get(singleArg))) continue;
            return true;
        }
        return false;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>(aggregate.getAggCallList().size());
        ArrayList<RexNode> newProjects = new ArrayList<RexNode>(project.getProjects());
        ArrayList<RexNode> newCasts = new ArrayList<RexNode>();
        Iterator<Object> iterator = aggregate.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int fieldNumber = iterator.next();
            newCasts.add(rexBuilder.makeInputRef(project.getProjects().get(fieldNumber).getType(), fieldNumber));
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            AggregateCall newCall = AggregateCaseToFilterRule.transform(aggregateCall, project, newProjects);
            int i = newCasts.size();
            RelDataType oldType = aggregate.getRowType().getFieldList().get(i).getType();
            if (newCall == null) {
                newCalls.add(aggregateCall);
                newCasts.add(rexBuilder.makeInputRef(oldType, i));
                continue;
            }
            newCalls.add(newCall);
            newCasts.add(rexBuilder.makeCast(oldType, rexBuilder.makeInputRef(newCall.getType(), i)));
        }
        if (newCalls.equals(aggregate.getAggCallList())) {
            return;
        }
        RelBuilder relBuilder = call.builder().push(project.getInput()).project(newProjects);
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>)aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, (List<AggregateCall>)newCalls).convert(aggregate.getRowType(), false);
        call.transformTo(relBuilder.build());
        call.getPlanner().prune(aggregate);
    }

    private static @Nullable AggregateCall transform(AggregateCall aggregateCall, Project project, List<RexNode> newProjects) {
        int singleArg = AggregateCaseToFilterRule.soleArgument(aggregateCall);
        if (singleArg < 0) {
            return null;
        }
        RexNode rexNode = project.getProjects().get(singleArg);
        if (!AggregateCaseToFilterRule.isThreeArgCase(rexNode)) {
            return null;
        }
        RelOptCluster cluster = project.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RexCall caseCall = (RexCall)rexNode;
        boolean flip = RexLiteral.isNullLiteral((RexNode)caseCall.operands.get(1)) && !RexLiteral.isNullLiteral((RexNode)caseCall.operands.get(2));
        RexNode arg1 = (RexNode)caseCall.operands.get(flip ? 2 : 1);
        RexNode arg2 = (RexNode)caseCall.operands.get(flip ? 1 : 2);
        SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE;
        RexNode filterFromCase = rexBuilder.makeCall((SqlOperator)op, (RexNode)caseCall.operands.get(0));
        RexNode filter = aggregateCall.filterArg >= 0 ? rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, project.getProjects().get(aggregateCall.filterArg), filterFromCase) : filterFromCase;
        SqlKind kind = aggregateCall.getAggregation().getKind();
        if (aggregateCall.isDistinct()) {
            if (kind == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
                newProjects.add(arg1);
                newProjects.add(filter);
                return AggregateCall.create(SqlStdOperatorTable.COUNT, true, false, false, ImmutableList.of(Integer.valueOf(newProjects.size() - 2)), newProjects.size() - 1, null, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
            }
            return null;
        }
        if (kind == SqlKind.COUNT && arg1.isA(SqlKind.LITERAL) && !RexLiteral.isNullLiteral(arg1) && RexLiteral.isNullLiteral(arg2)) {
            newProjects.add(filter);
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false, ImmutableList.of(), newProjects.size() - 1, null, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
        }
        if (kind == SqlKind.SUM && AggregateCaseToFilterRule.isIntLiteral(arg1, BigDecimal.ONE) && AggregateCaseToFilterRule.isIntLiteral(arg2, BigDecimal.ZERO)) {
            newProjects.add(filter);
            RelDataTypeFactory typeFactory = cluster.getTypeFactory();
            RelDataType dataType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false, ImmutableList.of(), newProjects.size() - 1, null, RelCollations.EMPTY, dataType, aggregateCall.getName());
        }
        if (RexLiteral.isNullLiteral(arg2) && aggregateCall.getAggregation().allowsFilter() || kind == SqlKind.SUM && AggregateCaseToFilterRule.isIntLiteral(arg2, BigDecimal.ZERO)) {
            newProjects.add(arg1);
            newProjects.add(filter);
            return AggregateCall.create(aggregateCall.getAggregation(), false, false, false, ImmutableList.of(Integer.valueOf(newProjects.size() - 2)), newProjects.size() - 1, null, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
        }
        return null;
    }

    private static int soleArgument(AggregateCall aggregateCall) {
        return aggregateCall.getArgList().size() == 1 ? aggregateCall.getArgList().get(0) : -1;
    }

    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall)rexNode).operands.size() == 3;
    }

    private static boolean isIntLiteral(RexNode rexNode, BigDecimal value) {
        return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains((Object)rexNode.getType().getSqlTypeName()) && value.equals(((RexLiteral)rexNode).getValueAs(BigDecimal.class));
    }

    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableAggregateCaseToFilterRule.Config.of().withOperandSupplier(b0 -> b0.operand(Aggregate.class).oneInput(b1 -> b1.operand(Project.class).anyInputs()));

        @Override
        default public AggregateCaseToFilterRule toRule() {
            return new AggregateCaseToFilterRule(this);
        }
    }
}

