package org.apache.sysds.runtime.instructions.cp;

import org.antlr.v4.runtime.atn.PredictionContext;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Ctable;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.class */
/**
 * Control-program instruction executing the {@code ctable} and {@code ctableexpand} opcodes.
 *
 * <p>The concrete kernel is chosen from the data types of the three inputs (see
 * {@link Ctable#findCtableOperationByInputDataTypes}), except for {@code ctableexpand}, which
 * always runs the sequence-expand kernel with a scalar weight.
 */
public class CtableCPInstruction extends ComputationCPInstruction {
    /** Scalar operand with the requested number of output rows; the value -1 means unknown. */
    private final CPOperand _outDim1;
    /** Scalar operand with the requested number of output columns; the value -1 means unknown. */
    private final CPOperand _outDim2;
    /** True for the {@code ctableexpand} opcode. */
    private final boolean _isExpand;
    /** Flag forwarded to the scalar-weight transform kernel. */
    private final boolean _ignoreZeros;

    private CtableCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
            String outputDim1, boolean dim1IsLiteral, String outputDim2, boolean dim2IsLiteral,
            boolean isExpand, boolean ignoreZeros, String opcode, String istr) {
        super(CPInstruction.CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
        _outDim1 = new CPOperand(outputDim1, Types.ValueType.FP64, Types.DataType.SCALAR, dim1IsLiteral);
        _outDim2 = new CPOperand(outputDim2, Types.ValueType.FP64, Types.DataType.SCALAR, dim2IsLiteral);
        _isExpand = isExpand;
        _ignoreZeros = ignoreZeros;
    }

    /**
     * Parses a ctable instruction string.
     *
     * <p>Expected fields after splitting: [0] opcode ({@code ctable} or {@code ctableexpand}),
     * [1..3] inputs, [4] and [5] output dimensions encoded as {@code name·isLiteral},
     * [6] output operand, [7] ignore-zeros flag.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is neither {@code ctable} nor {@code ctableexpand}
     */
    public static CtableCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 7);
        String opcode = parts[0];
        boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");
        if (!isExpand && !opcode.equalsIgnoreCase("ctable")) {
            throw new DMLRuntimeException("Unexpected opcode in CtableCPInstruction: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        // output dimensions carry "<name>·<isLiteral>"
        String[] dim1Fields = parts[4].split("·");
        String[] dim2Fields = parts[5].split("·");
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        return new CtableCPInstruction(in1, in2, in3, out,
            dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
            dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]),
            isExpand, ignoreZeros, opcode, str);
    }

    private Ctable.OperationTypes findCtableOperation() {
        return Ctable.findCtableOperationByInputDataTypes(
            input1.getDataType(), input2.getDataType(), input3.getDataType());
    }

    /** Reads a scalar operand (variable or literal) as a double. */
    private static double scalarValue(ExecutionContext ec, CPOperand op) {
        return ec.getScalarInput(op.getName(), op.getValueType(), op.isLiteral()).getDoubleValue();
    }

    /** Releases the operand's pinned matrix input if the operand is a matrix. */
    private static void releaseIfMatrix(ExecutionContext ec, CPOperand op) {
        if (op.getDataType() == Types.DataType.MATRIX) {
            ec.releaseMatrixInput(op.getName());
        }
    }

    @Override // org.apache.sysds.runtime.instructions.cp.CPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        MatrixBlock in1Block = ec.getMatrixInput(input1.getName());
        MatrixBlock in2Block = null;
        CTableMap resultMap = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
        Ctable.OperationTypes ctableOp = _isExpand
            ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT
            : findCtableOperation();

        long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
        long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
        boolean outputDimsKnown = outputDim1 != -1 && outputDim2 != -1;

        // A non-null result block makes the kernels write into it directly; a null block leaves
        // them to aggregate into resultMap, which is converted after the switch.
        MatrixBlock resultBlock = null;
        if (_isExpand) {
            // open-ended column count; the expand kernel fills the block
            resultBlock = new MatrixBlock(in1Block.getNumRows(), Integer.MAX_VALUE, true);
        } else if (outputDimsKnown) {
            // widen before multiplying: rows*cols can exceed Integer.MAX_VALUE
            long inputCells = (long) in1Block.getNumRows() * in1Block.getNumColumns();
            if (!MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputCells)) {
                resultBlock = new MatrixBlock((int) outputDim1, (int) outputDim2, false);
            }
        }

        switch (ctableOp) {
            case CTABLE_TRANSFORM:
                in2Block = ec.getMatrixInput(input2.getName());
                in1Block.ctableOperations(_optr, in2Block,
                    ec.getMatrixInput(input3.getName()), resultMap, resultBlock);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                in2Block = ec.getMatrixInput(input2.getName());
                in1Block.ctableOperations(_optr, in2Block,
                    scalarValue(ec, input3), _ignoreZeros, resultMap, resultBlock);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                in2Block = ec.getMatrixInput(input2.getName());
                in1Block.ctableSeqOperations(in2Block, scalarValue(ec, input3), resultBlock);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                in1Block.ctableOperations(_optr, scalarValue(ec, input2),
                    scalarValue(ec, input3), resultMap, resultBlock);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                in1Block.ctableOperations(_optr, scalarValue(ec, input2),
                    ec.getMatrixInput(input3.getName()), resultMap, resultBlock);
                break;
            default:
                throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + toString());
        }

        releaseIfMatrix(ec, input1);
        releaseIfMatrix(ec, input2);
        releaseIfMatrix(ec, input3);

        if (resultBlock == null) {
            // map-aggregated path: honor the requested output dimensions when both are known
            resultBlock = outputDimsKnown
                ? DataConverter.convertToMatrixBlock(resultMap, (int) outputDim1, (int) outputDim2)
                : DataConverter.convertToMatrixBlock(resultMap);
        } else {
            resultBlock.examSparsity();
        }
        if (checkGuardedRepresentationChange(in1Block, in2Block, resultBlock)) {
            resultBlock.examSparsity();
        }
        ec.setMatrixOutput(output.getName(), resultBlock);
    }

    @Override // org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        // output-dimension operands are traced unless both are the literal "-1"
        boolean dimsUnspecified = _outDim1.getName().equals("-1") && _outDim2.getName().equals("-1");
        return Pair.of(output.getName(), new LineageItem(getOpcode(), dimsUnspecified
            ? LineageItemUtils.getLineage(ec, input1, input2, input3)
            : LineageItemUtils.getLineage(ec, input1, input2, input3, _outDim1, _outDim2)));
    }
}
