/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.interpreter;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.enumerable.AggImpState;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.interpreter.AbstractSingleNode;
import org.apache.calcite.interpreter.Context;
import org.apache.calcite.interpreter.Interpreter;
import org.apache.calcite.interpreter.JaninoRexCompiler;
import org.apache.calcite.interpreter.Row;
import org.apache.calcite.interpreter.Scalar;
import org.apache.calcite.interpreter.Sink;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.flink.shaded.calcite.com.google.common.base.Supplier;
import org.apache.flink.shaded.calcite.com.google.common.base.Throwables;
import org.apache.flink.shaded.calcite.com.google.common.collect.ImmutableList;
import org.apache.flink.shaded.calcite.com.google.common.collect.Lists;
import org.apache.flink.shaded.calcite.com.google.common.collect.Maps;

/**
 * Interpreter node that implements an {@link Aggregate}.
 *
 * <p>Each input row is routed to one {@link Grouping} per grouping set of the
 * aggregate. A grouping keeps one list of accumulators (one per aggregate call)
 * for every distinct key. When the input is exhausted, each grouping emits one
 * output row per key. The output row layout is:
 * <ol>
 *   <li>one column per field of the union of all grouping sets;</li>
 *   <li>if {@code rel.indicator}, one more column per such field;</li>
 *   <li>one column per aggregate call.</li>
 * </ol>
 */
public class AggregateNode
extends AbstractSingleNode<Aggregate> {
    /** One entry per grouping set of the aggregate. */
    private final List<Grouping> groups = Lists.newArrayList();
    /** Union of all grouping sets; these fields form the group-key output columns. */
    private final ImmutableBitSet unionGroups;
    /** Number of columns in each emitted row (see class comment for the layout). */
    private final int outputRowLength;
    /** One factory per aggregate call, in the order of {@code rel.getAggCallList()}. */
    private final ImmutableList<AccumulatorFactory> accumulatorFactories;
    /** Data context from the interpreter, handed to compiled scalar accumulators. */
    private final DataContext dataContext;

    public AggregateNode(Interpreter interpreter, Aggregate rel) {
        super(interpreter, rel);
        this.dataContext = interpreter.getDataContext();
        ImmutableBitSet union = ImmutableBitSet.of();
        if (rel.getGroupSets() != null) {
            for (ImmutableBitSet group : rel.getGroupSets()) {
                union = union.union(group);
                this.groups.add(new Grouping(group));
            }
        }
        this.unionGroups = union;
        final int groupCount = this.unionGroups.cardinality();
        this.outputRowLength = groupCount
            + (rel.indicator ? groupCount : 0)
            + rel.getAggCallList().size();
        final ImmutableList.Builder<AccumulatorFactory> builder = ImmutableList.builder();
        for (AggregateCall aggregateCall : rel.getAggCallList()) {
            builder.add(this.getAccumulator(aggregateCall));
        }
        this.accumulatorFactories = builder.build();
    }

    /**
     * Drains the source, feeding every row to every grouping set, then emits
     * the aggregated rows of each grouping set to the sink.
     */
    @Override
    public void run() throws InterruptedException {
        Row r;
        while ((r = this.source.receive()) != null) {
            for (Grouping group : this.groups) {
                group.send(r);
            }
        }
        for (Grouping group : this.groups) {
            group.end(this.sink);
        }
    }

    /**
     * Creates the accumulator factory for one aggregate call.
     *
     * <p>{@code COUNT} is handled by {@link CountAccumulator}. {@code SUM} is
     * handled by the {@link IntSum} user-defined aggregate. Any other function is
     * translated through its {@link AggImpState} implementor and compiled with
     * Janino into an "add" {@link Scalar}.
     *
     * @param call aggregate call to implement
     * @return factory that produces a fresh accumulator for each group key
     */
    private AccumulatorFactory getAccumulator(final AggregateCall call) {
        if (call.getAggregation() == SqlStdOperatorTable.COUNT) {
            return new AccumulatorFactory() {
                @Override
                public Accumulator get() {
                    return new CountAccumulator(call);
                }
            };
        }
        if (call.getAggregation() == SqlStdOperatorTable.SUM) {
            // NOTE(review): IntSum.add takes int parameters and is invoked reflectively,
            // so SUM over non-INTEGER arguments presumably fails at invocation — confirm.
            return new UdaAccumulatorFactory(AggregateFunctionImpl.create(IntSum.class), call);
        }
        final JavaTypeFactory typeFactory = (JavaTypeFactory) this.rel.getCluster().getTypeFactory();
        // NOTE(review): agg.state and agg.context are read below but not assigned in
        // this method; confirm that the AggImpState constructor initializes them.
        final AggImpState agg = new AggImpState(0, call, false);
        final int stateSize = agg.state.size();
        final BlockBuilder addBlock = new BlockBuilder();
        final PhysType inputPhysType = PhysTypeImpl.of(typeFactory, this.rel.getInput().getRowType(), JavaRowFormat.ARRAY);

        // Row type of the accumulator state: one field per state expression.
        final RelDataTypeFactory.FieldInfoBuilder stateFields = typeFactory.builder();
        for (Expression expression : agg.state) {
            stateFields.add("a", typeFactory.createJavaType((Class) expression.getType()));
        }
        final PhysType accPhysType = PhysTypeImpl.of(typeFactory, stateFields.build(), JavaRowFormat.ARRAY);
        final ParameterExpression inParameter = Expressions.parameter(inputPhysType.getJavaRowType(), "in");
        final ParameterExpression accParameter = Expressions.parameter(accPhysType.getJavaRowType(), "acc");

        // Re-point the aggregate's state at the fields of the "acc" parameter.
        final List<Expression> accumulator = new ArrayList<Expression>(stateSize);
        for (int j = 0; j < stateSize; ++j) {
            accumulator.add(accPhysType.fieldReference(accParameter, j));
        }
        agg.state = accumulator;

        final AggAddContextImpl addContext = new AggAddContextImpl(addBlock, accumulator) {
            @Override
            public List<RexNode> rexArguments() {
                final List<RexNode> args = new ArrayList<RexNode>();
                for (int index : agg.call.getArgList()) {
                    args.add(RexInputRef.of(index, inputPhysType.getRowType()));
                }
                return args;
            }

            @Override
            public RexNode rexFilterArgument() {
                return agg.call.filterArg < 0
                    ? null
                    : RexInputRef.of(agg.call.filterArg, inputPhysType.getRowType());
            }

            @Override
            public RexToLixTranslator rowTranslator() {
                return RexToLixTranslator.forAggregation(typeFactory, this.currentBlock(),
                        new RexToLixTranslator.InputGetterImpl(
                            Collections.singletonList(Pair.of(inParameter, inputPhysType))))
                    .setNullable(this.currentNullables());
            }
        };
        agg.implementor.implementAdd(agg.context, addContext);

        final ParameterExpression contextParameter = Expressions.parameter(Context.class, "context");
        final ParameterExpression outputValuesParameter = Expressions.parameter(Object[].class, "outputValues");
        final Scalar addScalar = JaninoRexCompiler.baz(contextParameter, outputValuesParameter, addBlock.toBlock());
        // NOTE(review): the init and end scalars are passed as null, so
        // ScalarAccumulator.end() will dereference a null endScalar — confirm intended.
        return new ScalarAccumulatorDef(null, addScalar, null,
            this.rel.getInput().getRowType().getFieldCount(), stateSize, this.dataContext);
    }

    /**
     * Accumulator backed by a user-defined aggregate function whose init/add/result
     * methods are invoked reflectively. Rows whose argument is null are skipped.
     */
    private static class UdaAccumulator
    implements Accumulator {
        private final UdaAccumulatorFactory factory;
        private Object value;

        public UdaAccumulator(UdaAccumulatorFactory factory) {
            this.factory = factory;
            try {
                this.value = factory.aggFunction.initMethod.invoke(factory.instance, new Object[0]);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw Throwables.propagate(e);
            }
        }

        @Override
        public void send(Row row) {
            Object[] args = new Object[]{this.value, row.getValues()[this.factory.argOrdinal]};
            // Index 0 is the running value; skip the row if any actual argument is null.
            for (int i = 1; i < args.length; ++i) {
                if (args[i] == null) {
                    return;
                }
            }
            try {
                this.value = this.factory.aggFunction.addMethod.invoke(this.factory.instance, args);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw Throwables.propagate(e);
            }
        }

        @Override
        public Object end() {
            Object[] args = new Object[]{this.value};
            try {
                return this.factory.aggFunction.resultMethod.invoke(this.factory.instance, args);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw Throwables.propagate(e);
            }
        }
    }

    /**
     * Factory for {@link UdaAccumulator}s.
     *
     * <p>Supports only single-argument aggregate calls. For a non-static
     * aggregate class, one instance is created here and shared by every
     * accumulator this factory produces.
     */
    private static class UdaAccumulatorFactory
    implements AccumulatorFactory {
        public final AggregateFunctionImpl aggFunction;
        public final int argOrdinal;
        public final Object instance;

        public UdaAccumulatorFactory(AggregateFunctionImpl aggFunction, AggregateCall call) {
            this.aggFunction = aggFunction;
            if (call.getArgList().size() != 1) {
                throw new UnsupportedOperationException("in current implementation, aggregate must have precisely one argument");
            }
            this.argOrdinal = call.getArgList().get(0);
            if (aggFunction.isStatic) {
                this.instance = null;
            } else {
                try {
                    this.instance = aggFunction.declaringClass.newInstance();
                }
                catch (InstantiationException | IllegalAccessException e) {
                    throw Throwables.propagate(e);
                }
            }
        }

        @Override
        public Accumulator get() {
            return new UdaAccumulator(this);
        }
    }

    /** User-defined aggregate: sum of {@code int} values into a {@code long}. */
    public static class LongSum {
        public long init() {
            return 0L;
        }

        public long add(long accumulator, int v) {
            return accumulator + (long) v;
        }

        public long merge(long accumulator0, long accumulator1) {
            return accumulator0 + accumulator1;
        }

        public long result(long accumulator) {
            return accumulator;
        }
    }

    /** User-defined aggregate: sum of {@code int} values into an {@code int} (may overflow). */
    public static class IntSum {
        public int init() {
            return 0;
        }

        public int add(int accumulator, int v) {
            return accumulator + v;
        }

        public int merge(int accumulator0, int accumulator1) {
            return accumulator0 + accumulator1;
        }

        public int result(int accumulator) {
            return accumulator;
        }
    }

    /** Running state of one aggregate call for one group key. */
    private interface Accumulator {
        void send(Row row);

        Object end();
    }

    /** The accumulators of one group key, in aggregate-call order. */
    private static class AccumulatorList
    extends ArrayList<Accumulator> {
        private AccumulatorList() {
        }

        public void send(Row row) {
            for (Accumulator a : this) {
                a.send(row);
            }
        }

        /** Writes each accumulator's result into the trailing columns of {@code r}. */
        public void end(Row.RowBuilder r) {
            final int offset = r.size() - this.size();
            for (int i = 0; i < this.size(); ++i) {
                r.set(offset + i, this.get(i).end());
            }
        }
    }

    /** Aggregation state for a single grouping set. */
    private class Grouping {
        private final ImmutableBitSet grouping;
        /** Keyed by a row holding this grouping's field values, packed in ascending ordinal order. */
        private final Map<Row, AccumulatorList> accumulators = Maps.newHashMap();

        private Grouping(ImmutableBitSet grouping) {
            this.grouping = grouping;
        }

        public void send(Row row) {
            // The key has grouping.cardinality() slots, so fill it densely:
            // "i" is an input ordinal and "j" is a position within the key.
            Row.RowBuilder builder = Row.newBuilder(this.grouping.cardinality());
            int j = 0;
            for (Integer i : this.grouping) {
                builder.set(j++, row.getObject(i));
            }
            Row key = builder.build();
            AccumulatorList list = this.accumulators.get(key);
            if (list == null) {
                list = new AccumulatorList();
                for (AccumulatorFactory factory : AggregateNode.this.accumulatorFactories) {
                    list.add(factory.get());
                }
                this.accumulators.put(key, list);
            }
            list.send(row);
        }

        public void end(Sink sink) throws InterruptedException {
            final int groupCount = AggregateNode.this.unionGroups.cardinality();
            for (Map.Entry<Row, AccumulatorList> e : this.accumulators.entrySet()) {
                Row key = e.getKey();
                AccumulatorList list = e.getValue();
                Row.RowBuilder rb = Row.newBuilder(AggregateNode.this.outputRowLength);
                // "index" is the output column among the union group fields, and
                // "keyIndex" is the position in the packed key. Both bit sets iterate
                // in ascending order and grouping is a subset of unionGroups, so the
                // keyIndex-th member of grouping matches the current groupPos.
                int index = 0;
                int keyIndex = 0;
                for (Integer groupPos : AggregateNode.this.unionGroups) {
                    if (this.grouping.get(groupPos)) {
                        rb.set(index, key.getObject(keyIndex++));
                        if (AggregateNode.this.rel.indicator) {
                            rb.set(groupCount + index, true);
                        }
                    }
                    // NOTE(review): fields outside this grouping set (and their
                    // indicator columns) are left null; confirm the intended indicator values.
                    ++index;
                }
                list.end(rb);
                sink.send(rb.build());
            }
        }
    }

    /**
     * Accumulator driven by compiled scalars. It stages the input row and the
     * current state in the def's shared scratch contexts, so accumulators from
     * the same def are presumably not safe to use concurrently.
     */
    private static class ScalarAccumulator
    implements Accumulator {
        final ScalarAccumulatorDef def;
        final Object[] values;

        private ScalarAccumulator(ScalarAccumulatorDef def, Object[] values) {
            this.def = def;
            this.values = values;
        }

        @Override
        public void send(Row row) {
            // Context layout: [input row fields..., accumulator state...].
            System.arraycopy(row.getValues(), 0, this.def.sendContext.values, 0, this.def.rowLength);
            System.arraycopy(this.values, 0, this.def.sendContext.values, this.def.rowLength, this.values.length);
            this.def.addScalar.execute(this.def.sendContext, this.values);
        }

        @Override
        public Object end() {
            System.arraycopy(this.values, 0, this.def.endContext.values, 0, this.values.length);
            return this.def.endScalar.execute(this.def.endContext);
        }
    }

    /** Factory and shared scratch contexts for {@link ScalarAccumulator}s. */
    private static class ScalarAccumulatorDef
    implements AccumulatorFactory {
        final Scalar initScalar;
        final Scalar addScalar;
        final Scalar endScalar;
        final Context sendContext;
        final Context endContext;
        final int rowLength;
        final int accumulatorLength;

        private ScalarAccumulatorDef(Scalar initScalar, Scalar addScalar, Scalar endScalar, int rowLength, int accumulatorLength, DataContext root) {
            this.initScalar = initScalar;
            this.addScalar = addScalar;
            this.endScalar = endScalar;
            this.accumulatorLength = accumulatorLength;
            this.rowLength = rowLength;
            this.sendContext = new Context(root);
            this.sendContext.values = new Object[rowLength + accumulatorLength];
            this.endContext = new Context(root);
            this.endContext.values = new Object[accumulatorLength];
        }

        @Override
        public Accumulator get() {
            return new ScalarAccumulator(this, new Object[this.accumulatorLength]);
        }
    }

    /** Creates a fresh {@link Accumulator} for each new group key. */
    private interface AccumulatorFactory
    extends Supplier<Accumulator> {
    }

    /**
     * Accumulator for {@code COUNT}: counts rows in which every argument is non-null.
     * With no arguments, every row is counted.
     *
     * <p>NOTE(review): does not consult {@code call.filterArg} or
     * {@code call.isDistinct()}; confirm this is acceptable for callers.
     */
    private static class CountAccumulator
    implements Accumulator {
        private final AggregateCall call;
        long cnt;

        public CountAccumulator(AggregateCall call) {
            this.call = call;
            this.cnt = 0L;
        }

        @Override
        public void send(Row row) {
            for (Integer i : this.call.getArgList()) {
                if (row.getObject(i) == null) {
                    return;
                }
            }
            ++this.cnt;
        }

        @Override
        public Object end() {
            return this.cnt;
        }
    }
}

