package org.apache.sysds.hops;

import java.util.ArrayList;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.lops.Transform;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/ReorgOp.class */
/**
 * Hop for the reorganization operations {@code TRANS}, {@code DIAG}, {@code REV},
 * {@code RESHAPE} and {@code SORT}. It propagates sizes, estimates memory and builds
 * the corresponding {@link Transform} lops.
 */
public class ReorgOp extends MultiThreadedHop {
    /**
     * When {@code true}, the sort-rewrite flag handed to the SPARK {@code SORT} lop in
     * {@link #constructLops()} is always {@code false}.
     */
    public static boolean FORCE_DIST_SORT_INDEXES = false;

    /** The reorganization operation represented by this hop. */
    private Types.ReOrgOp _op;

    /** Empty instance, populated by {@link #clone()}. */
    private ReorgOp() {
    }

    /**
     * Creates a single-input reorg hop, registers it as a parent of {@code input} and
     * derives its size information.
     *
     * @param name  hop name
     * @param dt    output data type
     * @param vt    output value type
     * @param op    reorg operation
     * @param input the single input hop
     */
    public ReorgOp(String name, Types.DataType dt, Types.ValueType vt, Types.ReOrgOp op, Hop input) {
        super(name, dt, vt);
        _op = op;
        getInput().add(0, input);
        input.getParent().add(this);
        refreshSizeInformation();
    }

    /**
     * Creates a multi-input reorg hop (e.g. RESHAPE or SORT), registers it as a parent
     * of every input, and derives its size information.
     *
     * @param name   hop name
     * @param dt     output data type
     * @param vt     output value type
     * @param op     reorg operation
     * @param inputs input hops, in positional order
     */
    public ReorgOp(String name, Types.DataType dt, Types.ValueType vt, Types.ReOrgOp op, ArrayList<Hop> inputs) {
        super(name, dt, vt);
        _op = op;
        for (int pos = 0; pos < inputs.size(); pos++) {
            Hop in = inputs.get(pos);
            getInput().add(pos, in);
            in.getParent().add(this);
        }
        refreshSizeInformation();
    }

    /**
     * Validates the number of inputs for the current operation.
     *
     * @throws HopsException if the arity does not match or the operation is unsupported
     */
    @Override
    public void checkArity() {
        int size = _input.size();
        switch (_op) {
            case TRANS:
            case DIAG:
            case REV:
                HopsException.check(size == 1, this, "should have arity 1 for op %s but has arity %d", _op, size);
                return;
            case RESHAPE:
                HopsException.check(size == 5, this, "should have arity 5 for op %s but has arity %d", _op, size);
                return;
            case SORT:
                // constructLops() builds exactly four lop inputs for SORT and no method reads beyond index 3
                HopsException.check(size == 4, this, "should have arity 4 for op %s but has arity %d", _op, size);
                return;
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + _op + "'.");
        }
    }

    /** @return the reorg operation of this hop */
    public Types.ReOrgOp getOp() {
        return _op;
    }

    @Override
    public String getOpString() {
        return "r(" + _op.toString() + ")";
    }

    /**
     * Reports whether this hop may run on the GPU. Only TRANS (when not a no-op) and
     * RESHAPE qualify, and only when {@code DMLScript.USE_ACCELERATOR} is set.
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        switch (_op) {
            case TRANS: {
                if (getDim1() == 1 && getDim2() == 1) {
                    // constructLops() reuses the input lop for a 1x1 transpose
                    return false;
                }
                // t(t(X)) is collapsed in constructLops(), so no transpose kernel is needed
                Hop in = getInput().get(0);
                boolean inputIsTranspose = in instanceof ReorgOp && ((ReorgOp) in).getOp() == Types.ReOrgOp.TRANS;
                return !inputIsTranspose;
            }
            case DIAG:
            case REV:
            case SORT:
                return false;
            case RESHAPE:
                return true;
            default:
                throw new RuntimeException("Unsupported operator:" + _op.name());
        }
    }

    /** TRANS and SORT are the only reorg operations that use multiple threads. */
    @Override
    public boolean isMultiThreadedOpType() {
        return _op == Types.ReOrgOp.TRANS || _op == Types.ReOrgOp.SORT;
    }

    /**
     * Builds (once) and returns the lop for this hop.
     *
     * @throws HopsException for unsupported operations
     */
    @Override
    public Lop constructLops() {
        if (getLops() != null) {
            return getLops();
        }
        LopProperties.ExecType et = optFindExecType();
        switch (_op) {
            case TRANS: {
                Lop in = getInput().get(0).constructLops();
                if (in instanceof Transform && ((Transform) in).getOp() == Types.ReOrgOp.TRANS) {
                    // t(t(X)) -> X: reuse the input of the inner transpose
                    setLops(in.getInputs().get(0));
                } else if (getDim1() == 1 && getDim2() == 1) {
                    // transposing a 1x1 result is the identity
                    setLops(in);
                } else {
                    Transform trans = new Transform(in, _op, getDataType(), getValueType(), et,
                        OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                    setOutputDimensions(trans);
                    setLineNumbers(trans);
                    setLops(trans);
                }
                break;
            }
            case DIAG:
            case REV: {
                Transform trans = new Transform(getInput().get(0).constructLops(), _op, getDataType(), getValueType(), et);
                setOutputDimensions(trans);
                setLineNumbers(trans);
                setLops(trans);
                break;
            }
            case RESHAPE: {
                Lop[] inputs = new Lop[5];
                for (int pos = 0; pos < inputs.length; pos++) {
                    inputs[pos] = getInput().get(pos).constructLops();
                }
                _outputEmptyBlocks = et == LopProperties.ExecType.SPARK && !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
                Transform trans = new Transform(inputs, _op, getDataType(), getValueType(), _outputEmptyBlocks, et);
                setOutputDimensions(trans);
                setLineNumbers(trans);
                setLops(trans);
                break;
            }
            case SORT: {
                Lop[] inputs = new Lop[4];
                for (int pos = 0; pos < inputs.length; pos++) {
                    inputs[pos] = getInput().get(pos).constructLops();
                }
                Transform sort;
                if (et == LopProperties.ExecType.SPARK) {
                    // NOTE(review): the scalar check below reads input 2; if it is meant to test the
                    // order-by input instead, the index may need to change - confirm against the parser.
                    boolean sortRewrite = !FORCE_DIST_SORT_INDEXES && isSortSPRewriteApplicable()
                        && getInput().get(2).getDataType().isScalar();
                    sort = new Transform(inputs, Types.ReOrgOp.SORT, getDataType(), getValueType(), et, sortRewrite, 1);
                } else {
                    sort = new Transform(inputs, Types.ReOrgOp.SORT, getDataType(), getValueType(), et, false,
                        OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                }
                setOutputDimensions(sort);
                setLineNumbers(sort);
                setLops(sort);
                break;
            }
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + _op + "'.");
        }
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /** Output size estimate for a {@code dim1 x dim2} result with {@code nnz} non-zeros. */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, OptimizerUtils.getSparsity(dim1, dim2, nnz));
    }

    /**
     * Intermediate memory estimate. It is zero for all operations except SORT, which
     * reserves {@code dim1 * 4} bytes unless input 3 is a literal {@code false} and the
     * input is a single column or empty.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (_op != Types.ReOrgOp.SORT) {
            return 0;
        }
        Hop flagInput = getInput().get(3);
        boolean noExtraBuffer = flagInput instanceof LiteralOp
            && !HopRewriteUtils.getBooleanValueSafe((LiteralOp) flagInput)
            && (dim2 == 1 || nnz == 0);
        // presumably an int (4-byte) per-row index buffer used by the sort - TODO confirm
        return noExtraBuffer ? 0 : dim1 * 4;
    }

    /**
     * Infers output characteristics from the memo table's statistics of input 0.
     *
     * @return the inferred characteristics, or {@code null} if they cannot be inferred
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memoTable) {
        DataCharacteristics in = memoTable.getAllInputStats(getInput().get(0));
        switch (_op) {
            case TRANS:
                return in.dimsKnown()
                    ? new MatrixCharacteristics(in.getCols(), in.getRows(), -1, in.getNonZeros())
                    : null;
            case DIAG: {
                // diag is overloaded on the input's column count, matching refreshSizeInformation():
                // a single-column input yields a k x k matrix, a wider input a k x 1 vector
                long k = in.getRows();
                long nnz = in.getNonZeros();
                if (in.getCols() == 1) {
                    return new MatrixCharacteristics(k, k, -1, nnz >= 0 ? nnz : k);
                }
                if (in.getCols() > 1) {
                    return new MatrixCharacteristics(k, 1L, -1, nnz >= 0 ? Math.min(k, nnz) : k);
                }
                return null;
            }
            case REV:
                return in.dimsKnown()
                    ? new MatrixCharacteristics(in.getRows(), in.getCols(), -1, in.getNonZeros())
                    : null;
            case RESHAPE:
                if (in.dimsKnown()) {
                    long cells = in.getRows() * in.getCols();
                    if (rowsKnown() && getDim1() != 0) {
                        return new MatrixCharacteristics(getDim1(), cells / getDim1(), -1, in.getNonZeros());
                    }
                    if (colsKnown() && getDim2() != 0) {
                        return new MatrixCharacteristics(cells / getDim2(), getDim2(), -1, in.getNonZeros());
                    }
                    if (dimsKnown()) {
                        return new MatrixCharacteristics(getDim1(), getDim2(), -1, -1L);
                    }
                }
                return null;
            case SORT: {
                Hop flagInput = getInput().get(3);
                if (!(flagInput instanceof LiteralOp)) {
                    // output width depends on a flag that is not a compile-time literal
                    return new MatrixCharacteristics(in.getRows(), -1L, -1, -1L);
                }
                boolean singleCol = HopRewriteUtils.getBooleanValueSafe((LiteralOp) flagInput);
                return new MatrixCharacteristics(in.getRows(), singleCol ? 1L : in.getCols(), -1,
                    singleCol ? in.getRows() : in.getNonZeros());
            }
            default:
                return null;
        }
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Selects the execution type. A forced type wins. Otherwise the choice is
     * memory-based at memory-based optimization levels. If not, CP is chosen when
     * input 0 is small or a vector, and SPARK otherwise.
     */
    @Override
    public LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            Hop in = getInput().get(0);
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (in.areDimsBelowThreshold() || in.isVector()) {
                _etype = LopProperties.ExecType.CP;
            } else {
                _etype = LopProperties.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /** Recomputes output dimensions and non-zeros from the inputs. */
    @Override
    public void refreshSizeInformation() {
        Hop in = getInput().get(0);
        switch (_op) {
            case TRANS:
                setDim1(in.getDim2());
                setDim2(in.getDim1());
                setNnz(in.getNnz());
                break;
            case DIAG: {
                long k = in.getDim1();
                setDim1(k);
                if (in.getDim2() == 1) {
                    // single-column input -> k x k output
                    setDim2(k);
                    setNnz(in.getNnz() >= 0 ? in.getNnz() : k);
                }
                if (in.getDim2() > 1) {
                    // multi-column input -> k x 1 output
                    setDim2(1L);
                    setNnz(in.getNnz() >= 0 ? Math.min(k, in.getNnz()) : k);
                }
                break;
            }
            case REV:
                setDim1(in.getDim1());
                setDim2(in.getDim2());
                setNnz(in.getNnz());
                break;
            case RESHAPE: {
                if (_dataType == Types.DataType.TENSOR) {
                    setNnz(in.getNnz());
                    break;
                }
                refreshRowsParameterInformation(getInput().get(1));
                refreshColsParameterInformation(getInput().get(2));
                setNnz(in.getNnz());
                // derive one missing dimension as in.getLength() divided by the known one
                if (!dimsKnown() && in.dimsKnown()) {
                    if (rowsKnown() && getDim1() != 0) {
                        setDim2(in.getLength() / getDim1());
                    } else if (colsKnown() && getDim2() != 0) {
                        setDim1(in.getLength() / getDim2());
                    }
                }
                break;
            }
            case SORT: {
                Hop flagInput = getInput().get(3);
                setDim1(in.getDim1());
                if (flagInput instanceof LiteralOp) {
                    boolean singleCol = HopRewriteUtils.getBooleanValueSafe((LiteralOp) flagInput);
                    setDim2(singleCol ? 1L : in.getDim2());
                    setNnz(singleCol ? in.getDim1() : in.getNnz());
                } else {
                    setDim2(-1L);
                    setNnz(-1L);
                }
                break;
            }
            default:
                break;
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        ReorgOp copy = new ReorgOp();
        copy.clone(this, false);
        copy._op = _op;
        copy._maxNumThreads = _maxNumThreads;
        return copy;
    }

    /**
     * Two reorg hops are equal if they have the same operation and thread count, and
     * identical (reference-equal) inputs.
     */
    @Override
    public boolean compare(Hop hop) {
        if (!(hop instanceof ReorgOp)) {
            return false;
        }
        ReorgOp other = (ReorgOp) hop;
        if (_op != other._op || _maxNumThreads != other._maxNumThreads
            || getInput().size() != hop.getInput().size()) {
            return false;
        }
        for (int pos = 0; pos < _input.size(); pos++) {
            if (getInput().get(pos) != other.getInput().get(pos)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether the Spark broadcast budget admits a single {@code dim1 x 1} column
     * of input 0. If its dimensions are unknown, input 0's full output estimate is
     * checked instead.
     */
    private boolean isSortSPRewriteApplicable() {
        Hop in = getInput().get(0);
        return OptimizerUtils.checkSparkBroadcastMemoryBudget(
            in.dimsKnown() ? OptimizerUtils.estimateSize(in.getDim1(), 1L) : in.getOutputMemEstimate());
    }
}
