package org.apache.sysds.lops;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.instructions.InstructionUtils;

/* loaded from: input_file:org/apache/sysds/lops/PartialAggregate.class */
public class PartialAggregate extends Lop {
    // Aggregation operation of this lop; assigned once in init().
    private Types.AggOp operation;
    // Aggregation direction (Row, Col or RowCol are the values handled in this class); assigned once in init().
    private Types.Direction direction;
    // Thread count appended to CP/FED instructions; both constructors start from -1,
    // only the constructor taking an int overrides it.
    private int _numThreads;
    // Aggregation type appended to SPARK instructions; both constructors start from MULTI_BLOCK,
    // only the constructor taking a SparkAggType overrides it.
    private AggBinaryOp.SparkAggType _aggtype;

    /**
     * Creates a partial-aggregate lop with an explicit thread count.
     * The Spark aggregation type keeps its default, {@code MULTI_BLOCK}.
     *
     * @param lop       input lop; this lop registers itself as one of its outputs
     * @param aggOp     aggregation operation
     * @param direction aggregation direction
     * @param dataType  data type of the output, passed to the {@code Lop} constructor
     * @param valueType value type of the output, passed to the {@code Lop} constructor
     * @param execType  execution type used to set up the lop properties
     * @param i         number of threads emitted into CP/FED instructions
     */
    public PartialAggregate(
            Lop lop,
            Types.AggOp aggOp,
            Types.Direction direction,
            Types.DataType dataType,
            Types.ValueType valueType,
            Types.ExecType execType,
            int i) {
        super(Lop.Type.PartialAggregate, dataType, valueType);
        // Field defaults are established before the lop is wired to its input.
        _numThreads = -1;
        _aggtype = AggBinaryOp.SparkAggType.MULTI_BLOCK;
        init(lop, aggOp, direction, dataType, valueType, execType);
        // The caller-supplied thread count replaces the default.
        _numThreads = i;
    }

    /**
     * Creates a partial-aggregate lop with an explicit Spark aggregation type.
     * The thread count keeps its default, {@code -1}.
     *
     * @param lop          input lop; this lop registers itself as one of its outputs
     * @param aggOp        aggregation operation
     * @param direction    aggregation direction
     * @param dataType     data type of the output, passed to the {@code Lop} constructor
     * @param valueType    value type of the output, passed to the {@code Lop} constructor
     * @param sparkAggType aggregation type emitted into SPARK instructions
     * @param execType     execution type used to set up the lop properties
     */
    public PartialAggregate(
            Lop lop,
            Types.AggOp aggOp,
            Types.Direction direction,
            Types.DataType dataType,
            Types.ValueType valueType,
            AggBinaryOp.SparkAggType sparkAggType,
            Types.ExecType execType) {
        super(Lop.Type.PartialAggregate, dataType, valueType);
        // Field defaults are established before the lop is wired to its input.
        _numThreads = -1;
        _aggtype = AggBinaryOp.SparkAggType.MULTI_BLOCK;
        init(lop, aggOp, direction, dataType, valueType, execType);
        // The caller-supplied aggregation type replaces the default.
        _aggtype = sparkAggType;
    }

    /**
     * Shared constructor logic: stores the operation and direction, links this lop
     * with its input in both directions and sets the lop properties for the execution type.
     *
     * <p>{@code dataType} and {@code valueType} are accepted but not used here.
     */
    private void init(Lop lop, Types.AggOp aggOp, Types.Direction direction, Types.DataType dataType, Types.ValueType valueType, Types.ExecType execType) {
        operation = aggOp;
        this.direction = direction;

        // Link both ways: the input becomes our input, we become its output.
        addInput(lop);
        lop.addOutput(this);

        // Properties are derived from the (now populated) input list.
        lps.setProperties(inputs, execType);
    }

    /**
     * Returns the correction location for this lop's own operation and direction.
     *
     * @return result of {@link #getCorrectionLocation(Types.AggOp, Types.Direction)}
     *         applied to the stored operation and direction
     */
    public Types.CorrectionLocationType getCorrectionLocation() {
        return PartialAggregate.getCorrectionLocation(operation, direction);
    }

    /**
     * Maps an aggregation operation and direction to the correction location type.
     *
     * <ul>
     *   <li>SUM, SUM_SQ, TRACE: {@code LASTROW} for Col, {@code LASTCOLUMN} for Row/RowCol</li>
     *   <li>MEAN: {@code LASTTWOROWS} for Col, {@code LASTTWOCOLUMNS} for Row/RowCol</li>
     *   <li>VAR: {@code LASTFOURROWS} for Col, {@code LASTFOURCOLUMNS} for Row/RowCol</li>
     *   <li>MAXINDEX, MININDEX: always {@code LASTCOLUMN}</li>
     *   <li>any other operation: {@code NONE}</li>
     * </ul>
     *
     * @param aggOp     aggregation operation
     * @param direction aggregation direction; only inspected for SUM, SUM_SQ, TRACE, MEAN and VAR
     * @return the correction location type for the given combination
     * @throws LopsException if the direction is inspected and is none of Col, Row, RowCol
     */
    public static Types.CorrectionLocationType getCorrectionLocation(Types.AggOp aggOp, Types.Direction direction) {
        // Every case returns directly. With nested switches, a 'break' only leaves the
        // inner switch, so control would fall through the following case groups and the
        // direction-specific result would be overwritten by the MAXINDEX/MININDEX value.
        switch (aggOp) {
            case SUM:
            case SUM_SQ:
            case TRACE:
                return selectCorrectionByDirection(direction,
                        Types.CorrectionLocationType.LASTROW,
                        Types.CorrectionLocationType.LASTCOLUMN);
            case MEAN:
                return selectCorrectionByDirection(direction,
                        Types.CorrectionLocationType.LASTTWOROWS,
                        Types.CorrectionLocationType.LASTTWOCOLUMNS);
            case VAR:
                return selectCorrectionByDirection(direction,
                        Types.CorrectionLocationType.LASTFOURROWS,
                        Types.CorrectionLocationType.LASTFOURCOLUMNS);
            case MAXINDEX:
            case MININDEX:
                return Types.CorrectionLocationType.LASTCOLUMN;
            default:
                return Types.CorrectionLocationType.NONE;
        }
    }

    /** Picks {@code forCol} for Col and {@code forRowOrRowCol} for Row/RowCol; rejects any other direction. */
    private static Types.CorrectionLocationType selectCorrectionByDirection(Types.Direction direction, Types.CorrectionLocationType forCol, Types.CorrectionLocationType forRowOrRowCol) {
        switch (direction) {
            case Col:
                return forCol;
            case Row:
            case RowCol:
                return forRowOrRowCol;
            default:
                throw new LopsException("PartialAggregate.getCorrectionLocation() - Unknown aggregate direction: " + direction);
        }
    }

    /**
     * Returns the Spark aggregation type of this lop.
     *
     * @return the stored aggregation type ({@code MULTI_BLOCK} unless set via constructor)
     */
    @Override
    public AggBinaryOp.SparkAggType getAggType() {
        return _aggtype;
    }

    /**
     * Sets this lop's output dimensions according to its own aggregation direction.
     *
     * @param j  first dimension of the input; used as output row count for direction Row
     * @param j2 second dimension of the input; used as output column count for direction Col
     * @param j3 forwarded unchanged as the third {@code setDimensions} argument
     *           (presumably the block size; confirm against OutputParameters)
     */
    public void setDimensionsBasedOnDirection(long j, long j2, long j3) {
        PartialAggregate.setDimensionsBasedOnDirection(this, j, j2, j3, direction);
    }

    /**
     * Sets the output dimensions of {@code lop} for an aggregation along {@code direction}:
     * Row gives {@code j x 1}, Col gives {@code 1 x j2}, RowCol gives {@code 1 x 1}.
     * The last {@code setDimensions} argument is always {@code -1}.
     *
     * @param lop       lop whose output parameters are updated
     * @param j         output row count for direction Row
     * @param j2        output column count for direction Col
     * @param j3        forwarded unchanged as the third {@code setDimensions} argument
     * @param direction aggregation direction
     * @throws LopsException if the direction is none of Row, Col, RowCol, or if setting
     *                       the dimensions raises a {@code HopsException}
     */
    public static void setDimensionsBasedOnDirection(Lop lop, long j, long j2, long j3, Types.Direction direction) {
        try {
            final long outRows;
            final long outCols;
            if (direction == Types.Direction.Row) {
                outRows = j;
                outCols = 1L;
            } else if (direction == Types.Direction.Col) {
                outRows = 1L;
                outCols = j2;
            } else if (direction == Types.Direction.RowCol) {
                outRows = 1L;
                outCols = 1L;
            } else {
                // Reference comparisons above mean a null direction also ends up here.
                throw new LopsException("In PartialAggregate Lop, Unknown aggregate direction " + direction);
            }
            lop.outParams.setDimensions(outRows, outCols, j3, -1L);
        } catch (HopsException hopsFailure) {
            throw new LopsException("In PartialAggregate Lop, error setting dimensions based on direction", hopsFailure);
        }
    }

    /**
     * Returns a short description: the text "Partial Aggregate " followed by the operation.
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("Partial Aggregate ");
        text.append(operation);
        return text.toString();
    }

    /**
     * Returns the opcode for this lop's own operation and direction.
     *
     * @return result of {@link #getOpcode(Types.AggOp, Types.Direction)} applied to the stored fields
     */
    private String getOpcode() {
        return PartialAggregate.getOpcode(operation, direction);
    }

    /**
     * Builds the instruction string for this lop.
     *
     * <p>The base instruction consists of the execution type name, the opcode, the prepared
     * input operand of the first input and the prepared output operand. Depending on the
     * execution type, further operands are appended:
     * <ul>
     *   <li>SPARK: the aggregation type name</li>
     *   <li>CP and FED: the thread count, then "1" if the opcode is "uarimin" or "uarimax"
     *       (case-insensitive)</li>
     *   <li>FED only: additionally the name of the federated output setting</li>
     * </ul>
     *
     * @param str  label passed to the input operand preparation
     * @param str2 label passed to the output operand preparation
     * @return the assembled instruction string
     */
    @Override
    public String getInstructions(String str, String str2) {
        final Types.ExecType execType = getExecType();
        // The name is resolved before the opcode to keep the original evaluation order.
        final String execName = execType.name();
        final String opcode = getOpcode();

        final String base = InstructionUtils.concatOperands(
                execName,
                opcode,
                getInputs().get(0).prepInputOperand(str),
                prepOutputOperand(str2));

        if (execType == Types.ExecType.SPARK) {
            return InstructionUtils.concatOperands(base, _aggtype.name());
        }

        final boolean federated = execType == Types.ExecType.FED;
        if (execType != Types.ExecType.CP && !federated) {
            // Other execution types get no additional operands.
            return base;
        }

        String instruction = InstructionUtils.concatOperands(base, Integer.toString(_numThreads));
        final boolean rowIndexOpcode = opcode.equalsIgnoreCase("uarimin") || opcode.equalsIgnoreCase("uarimax");
        if (rowIndexOpcode) {
            instruction = InstructionUtils.concatOperands(instruction, "1");
        }
        if (federated) {
            instruction = InstructionUtils.concatOperands(instruction, _fedOutput.name());
        }
        return instruction;
    }

    /*
     * Static opcode lookup for an (aggregation operation, direction) pair.
     *
     * WARNING: this is NOT the real implementation. The decompiler (JADX) could not
     * reconstruct the method and emitted a placeholder that unconditionally throws
     * UnsupportedOperationException. As a consequence, the instance getOpcode() and
     * getInstructions(), which call this method, cannot complete in this decompiled form.
     *
     * Original decompiler diagnostics, kept for reference:
     *   JADX WARN: Can't fix incorrect switch cases order, some code will duplicate
     *   JADX WARN: Failed to find 'out' block for switch in B:2:0x0008. Please report as an issue.
     *   JADX WARN: Failed to find 'out' block for switch in B:46:0x00f9. Please report as an issue.
     *   JADX WARN: Removed duplicated region for block: B:45:0x00f1
     *   JADX WARN: Removed duplicated region for block: B:52:0x0124 A[RETURN]
     *   JADX WARN: Removed duplicated region for block: B:53:0x0127
     *   JADX WARN: Removed duplicated region for block: B:85:0x01ad
     *   Code decompiled incorrectly, please refer to instructions dump.
     *   To view partially-correct add '--show-bad-code' argument
     *
     * NOTE(review): the only opcode strings visible in this file are "uarimin" and
     * "uarimax" (compared in getInstructions); the full mapping must be restored from
     * the original sources or bytecode - do not guess it.
     */
    public static java.lang.String getOpcode(org.apache.sysds.common.Types.AggOp r4, org.apache.sysds.common.Types.Direction r5) {
        /*
            Method dump skipped by the decompiler, instructions count: 443
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.apache.sysds.lops.PartialAggregate.getOpcode(org.apache.sysds.common.Types$AggOp, org.apache.sysds.common.Types$Direction):java.lang.String");
    }
}
