package org.apache.drill.exec.physical.impl.aggregate;

import com.google.common.collect.Lists;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JVar;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import org.apache.drill.common.exceptions.ExecutionSetupException;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ErrorCollectorImpl;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.compile.sig.GeneratorMapping;
import org.apache.drill.exec.compile.sig.MappingSet;
import org.apache.drill.exec.exception.ClassTransformationException;
import org.apache.drill.exec.exception.SchemaChangeException;
import org.apache.drill.exec.expr.ClassGenerator;
import org.apache.drill.exec.expr.CodeGenerator;
import org.apache.drill.exec.expr.ExpressionTreeMaterializer;
import org.apache.drill.exec.expr.TypeHelper;
import org.apache.drill.exec.expr.ValueVectorWriteExpression;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.physical.config.HashAggregate;
import org.apache.drill.exec.physical.impl.aggregate.HashAggregator;
import org.apache.drill.exec.physical.impl.common.Comparator;
import org.apache.drill.exec.physical.impl.common.HashTableConfig;
import org.apache.drill.exec.record.AbstractRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.selection.SelectionVector2;
import org.apache.drill.exec.record.selection.SelectionVector4;
import org.apache.drill.exec.vector.AllocationHelper;
import org.apache.drill.exec.vector.ValueVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Operator batch that groups and aggregates the rows of an upstream {@link RecordBatch} using a
 * runtime-generated {@link HashAggregator}.
 *
 * <p>The aggregator class is generated once, from the schema of the first incoming batch (see
 * {@link #buildSchema()}). If the aggregator later reports {@code UPDATE_AGGREGATOR}, the fragment
 * is failed with an "unsupported schema change" error instead of regenerating code.
 */
public class HashAggBatch extends AbstractRecordBatch<HashAggregate> {
    static final Logger logger = LoggerFactory.getLogger(HashAggBatch.class);

    private HashAggregator aggregator;
    // Upstream batch; replaced by aggregator.getNewIncoming() when the aggregator switches sources.
    private RecordBatch incoming;
    private LogicalExpression[] aggrExprs;
    private TypedFieldId[] groupByOutFieldIds;
    private TypedFieldId[] aggrOutFieldIds;
    // One IS_NOT_DISTINCT_FROM comparator per group-by key, handed to the hash table config.
    private final List<Comparator> comparators;
    // Schema the aggregator was generated for; reported if the aggregator asks to be rebuilt.
    private BatchSchema incomingSchema;
    // Set by killIncoming(); the next innerNext() call cleans up and returns NONE.
    private boolean wasKilled;

    // Generator mappings used when emitting the aggregate-expression code into "BatchHolder".
    private final GeneratorMapping updateAggrInside =
            GeneratorMapping.create("setupInterior", "updateAggrValuesInternal", "resetValues", "cleanup");
    private final GeneratorMapping updateAggrOutside =
            GeneratorMapping.create("setupInterior", "outputRecordValues", "resetValues", "cleanup");
    private final MappingSet updateAggrValuesMapping =
            new MappingSet("incomingRowIdx", "outRowIdx", "htRowIdx", "incoming", "outgoing", "aggrValuesContainer",
                    updateAggrInside, updateAggrOutside, updateAggrInside);

    /**
     * Creates the operator batch.
     *
     * @param popConfig physical operator configuration (group-by and aggregate expressions)
     * @param incoming upstream batch supplying the rows to aggregate
     * @param context fragment context
     * @throws ExecutionSetupException if the base record batch cannot be set up
     */
    public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext context)
            throws ExecutionSetupException {
        super(popConfig, context);
        this.incoming = incoming;
        this.wasKilled = false;
        // Treat a missing group-by list as empty, matching createAggregatorInternal().
        List<NamedExpression> groupByExprs = popConfig.getGroupByExprs();
        int numKeys = groupByExprs == null ? 0 : groupByExprs.size();
        this.comparators = Lists.newArrayListWithExpectedSize(numKeys);
        for (int i = 0; i < numKeys; i++) {
            this.comparators.add(Comparator.IS_NOT_DISTINCT_FROM);
        }
    }

    /**
     * Returns the number of rows in the current output batch.
     *
     * @return 0 once this batch is DONE or when no aggregator exists (setup failed or never ran);
     *     otherwise the aggregator's output count
     */
    @Override
    public int getRecordCount() {
        if (state == AbstractRecordBatch.BatchState.DONE || aggregator == null) {
            return 0;
        }
        return aggregator.getOutputCount();
    }

    /**
     * Reads the first upstream batch, generates the aggregator for its schema and allocates the
     * (empty) output vectors.
     *
     * <p>Upstream NONE ends the stream with an empty schema. OUT_OF_MEMORY and STOP are recorded
     * as the matching state without generating an aggregator. If generation fails, the state
     * becomes DONE; the failure itself has already been reported by {@link #createAggregator()}.
     */
    @Override
    public void buildSchema() throws SchemaChangeException {
        switch (next(incoming)) {
            case NONE:
                state = AbstractRecordBatch.BatchState.DONE;
                container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
                return;
            case OUT_OF_MEMORY:
                state = AbstractRecordBatch.BatchState.OUT_OF_MEMORY;
                return;
            case STOP:
                state = AbstractRecordBatch.BatchState.STOP;
                return;
            default:
                break;
        }
        incomingSchema = incoming.getSchema();
        if (!createAggregator()) {
            state = AbstractRecordBatch.BatchState.DONE;
        }
        for (VectorWrapper<?> wrapper : container) {
            AllocationHelper.allocatePrecomputedChildCount(wrapper.getValueVector(), 0, 0, 0);
        }
    }

    /**
     * Produces the next output batch.
     *
     * <p>Once the aggregator reports its build complete (or early output), pending output batches
     * are returned first. Otherwise the aggregator's {@code doWork()} is driven until it stops
     * asking to be called again, and its outcome is translated into an {@link RecordBatch.IterOutcome}.
     *
     * @return the iteration outcome for this call
     * @throws IllegalStateException if the aggregator reports an unrecognized work outcome
     */
    @Override
    public RecordBatch.IterOutcome innerNext() {
        if (aggregator.allFlushed()) {
            return RecordBatch.IterOutcome.NONE;
        }

        if (aggregator.buildComplete() || aggregator.earlyOutput()) {
            HashAggregator.AggIterOutcome batchOutcome = aggregator.outputCurrentBatch();
            if (batchOutcome == HashAggregator.AggIterOutcome.AGG_NONE) {
                return RecordBatch.IterOutcome.NONE;
            }
            if (batchOutcome == HashAggregator.AggIterOutcome.AGG_OK) {
                return RecordBatch.IterOutcome.OK;
            }
            // Any other outcome: the aggregator has a new input source to consume; continue working on it.
            incoming = aggregator.getNewIncoming();
        }

        if (wasKilled) {
            // killIncoming() was called earlier: release aggregator state, propagate the kill, end the stream.
            aggregator.cleanup();
            incoming.kill(false);
            return RecordBatch.IterOutcome.NONE;
        }

        HashAggregator.AggOutcome workOutcome;
        do {
            workOutcome = aggregator.doWork();
        } while (workOutcome == HashAggregator.AggOutcome.CALL_WORK_AGAIN);

        switch (workOutcome) {
            case CLEANUP_AND_RETURN:
                container.zeroVectors();
                aggregator.cleanup();
                state = AbstractRecordBatch.BatchState.DONE;
                return aggregator.getOutcome();
            case RETURN_OUTCOME:
                return aggregator.getOutcome();
            case UPDATE_AGGREGATOR: {
                String changeMessage = SchemaChangeException.schemaChanged(
                        "Hash aggregate does not support schema change", incomingSchema, incoming.getSchema())
                        .getMessage();
                // Pass the text as an argument, never as the format string, so a '%' in a
                // schema/column name cannot be misread as a format directive.
                context.fail(UserException.unsupportedError().message("%s", changeMessage).build(logger));
                close();
                killIncoming(false);
                return RecordBatch.IterOutcome.STOP;
            }
            default:
                throw new IllegalStateException(String.format("Unknown state %s.", workOutcome));
        }
    }

    /**
     * Generates and sets up the aggregator, recording the time spent as operator setup time.
     *
     * <p>On a code-generation or schema failure the fragment is failed, the output container is
     * cleared and the upstream batch is killed.
     *
     * @return {@code true} if the aggregator was created, {@code false} if setup failed
     */
    private boolean createAggregator() {
        try {
            stats.startSetup();
            aggregator = createAggregatorInternal();
            return true;
        } catch (IOException | ClassTransformationException | SchemaChangeException e) {
            context.fail(e);
            container.clear();
            incoming.kill(false);
            return false;
        } finally {
            stats.stopSetup();
        }
    }

    /**
     * Materializes the group-by and aggregate expressions against the incoming schema, adds one
     * output vector per expression, generates the aggregator class and calls its {@code setup}.
     *
     * @return the configured aggregator
     * @throws SchemaChangeException if any expression fails to materialize
     * @throws ClassTransformationException if the generated class cannot be produced
     * @throws IOException if the generated class cannot be produced
     */
    private HashAggregator createAggregatorInternal()
            throws SchemaChangeException, ClassTransformationException, IOException {
        CodeGenerator<HashAggregator> codeGen = CodeGenerator.get(
                HashAggregator.TEMPLATE_DEFINITION, context.getFunctionRegistry(), context.getOptions());
        ClassGenerator<HashAggregator> root = codeGen.getRoot();
        ClassGenerator<HashAggregator> batchHolderGen = root.getInnerGenerator("BatchHolder");
        codeGen.plainJavaCapable(true);
        container.clear();

        List<NamedExpression> groupByExprs = popConfig.getGroupByExprs();
        List<NamedExpression> aggrExprList = popConfig.getAggrExprs();
        int numGroupByExprs = groupByExprs == null ? 0 : groupByExprs.size();
        int numAggrExprs = aggrExprList == null ? 0 : aggrExprList.size();
        aggrExprs = new LogicalExpression[numAggrExprs];
        groupByOutFieldIds = new TypedFieldId[numGroupByExprs];
        aggrOutFieldIds = new TypedFieldId[numAggrExprs];
        ErrorCollectorImpl collector = new ErrorCollectorImpl();

        for (int i = 0; i < numGroupByExprs; i++) {
            NamedExpression ne = groupByExprs.get(i);
            LogicalExpression expr = ExpressionTreeMaterializer.materialize(
                    ne.getExpr(), incoming, collector, context.getFunctionRegistry());
            // Check here rather than only in the aggregate loop: with no aggregate expressions a
            // group-by failure would otherwise never be reported, and a vector would be allocated for it.
            throwIfMaterializationFailed(collector);
            if (expr != null) {
                groupByOutFieldIds[i] = container.add(newOutputVector(ne, expr));
            }
        }

        for (int i = 0; i < numAggrExprs; i++) {
            NamedExpression ne = aggrExprList.get(i);
            LogicalExpression expr = ExpressionTreeMaterializer.materialize(
                    ne.getExpr(), incoming, collector, context.getFunctionRegistry());
            if (expr instanceof IfExpression) {
                throw UserException.unsupportedError(
                        new UnsupportedOperationException("Union type not supported in aggregate functions"))
                        .build(logger);
            }
            throwIfMaterializationFailed(collector);
            if (expr != null) {
                aggrOutFieldIds[i] = container.add(newOutputVector(ne, expr));
                aggrExprs[i] = new ValueVectorWriteExpression(aggrOutFieldIds[i], expr, true);
            }
        }

        setupUpdateAggrValues(batchHolderGen);
        setupGetIndex(root);
        root.getBlock("resetValues")._return(JExpr.TRUE);
        container.buildSchema(BatchSchema.SelectionVectorMode.NONE);

        HashAggregator agg = (HashAggregator) context.getImplementationClass(codeGen);
        HashTableConfig htConfig = new HashTableConfig(
                (int) context.getOptions().getOption(ExecConstants.MIN_HASH_TABLE_SIZE),
                0.75f, groupByExprs, null, comparators);
        agg.setup(popConfig, htConfig, context, stats, oContext, incoming, this, aggrExprs,
                batchHolderGen.getWorkspaceTypes(), groupByOutFieldIds, container);
        return agg;
    }

    /** Throws a SchemaChangeException describing every error collected so far, if any. */
    private static void throwIfMaterializationFailed(ErrorCollectorImpl collector) throws SchemaChangeException {
        if (collector.hasErrors()) {
            throw new SchemaChangeException("Failure while materializing expression. " + collector.toErrorString());
        }
    }

    /** Creates an output vector named after {@code ne}'s reference and typed after {@code expr}. */
    private ValueVector newOutputVector(NamedExpression ne, LogicalExpression expr) {
        MaterializedField field = MaterializedField.create(ne.getRef().getAsNamePart().getName(), expr.getMajorType());
        return TypeHelper.getNewVector(field, oContext.getAllocator());
    }

    /** Emits the code for every materialized aggregate expression using the update-aggr mapping set. */
    private void setupUpdateAggrValues(ClassGenerator<HashAggregator> cg) {
        cg.setMappingSet(updateAggrValuesMapping);
        for (LogicalExpression aggr : aggrExprs) {
            cg.addExpr(aggr, ClassGenerator.BlkCreateMode.TRUE);
        }
    }

    /**
     * Generates {@code getVectorIndex} (and, for selection-vector inputs, the {@code doSetup}
     * field assignment) according to the incoming batch's selection-vector mode.
     */
    private void setupGetIndex(ClassGenerator<HashAggregator> cg) {
        switch (incoming.getSchema().getSelectionVectorMode()) {
            case FOUR_BYTE: {
                JVar sv4 = cg.declareClassField("sv4_", cg.getModel()._ref(SelectionVector4.class));
                cg.getBlock("doSetup").assign(sv4, JExpr.direct("incoming").invoke("getSelectionVector4"));
                cg.getBlock("getVectorIndex")._return(sv4.invoke("get").arg(JExpr.direct("recordIndex")));
                return;
            }
            case NONE:
                cg.getBlock("getVectorIndex")._return(JExpr.direct("recordIndex"));
                return;
            case TWO_BYTE: {
                JVar sv2 = cg.declareClassField("sv2_", cg.getModel()._ref(SelectionVector2.class));
                cg.getBlock("doSetup").assign(sv2, JExpr.direct("incoming").invoke("getSelectionVector2"));
                cg.getBlock("getVectorIndex")._return(sv2.invoke("getIndex").arg(JExpr.direct("recordIndex")));
                return;
            }
            default:
                return;
        }
    }

    /** Cleans up the aggregator (if one was created) before closing the base batch. */
    @Override
    public void close() {
        if (aggregator != null) {
            aggregator.cleanup();
        }
        super.close();
    }

    /**
     * Marks this batch as killed, so the next {@link #innerNext()} ends the stream, and forwards
     * the kill to the upstream batch.
     */
    @Override
    protected void killIncoming(boolean sendUpstream) {
        wasKilled = true;
        incoming.kill(sendUpstream);
    }
}
