package org.apache.sysds.hops.codegen;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/codegen/SpoofFusedOp.class */
/**
 * Hop for a fused (code-generated) operator. The hop wraps an operator class that is handed
 * to the {@link SpoofFused} lop, and derives its output size from its inputs according to a
 * {@link SpoofOutputDimsType}.
 */
public class SpoofFusedOp extends MultiThreadedHop {
    /** Operator class passed to the {@link SpoofFused} lop; reported by {@link #getClassName()}. */
    private Class<?> _class = null;
    /** Whether all execution types are allowed; returned by {@link #allowsAllExecTypes()}. */
    private boolean _distSupported = false;
    /** Fixed output column count for INPUT_DIMS_CONST2 and VECT_CONST2; -1 if not set. */
    private long _constDim2 = -1L;
    /** Rule used to derive the output dimensions from the inputs. */
    private SpoofOutputDimsType _dimsType;

    /**
     * Rules for deriving the output dimensions (rows x cols) of a fused operator, where in0/in1
     * denote the first/second input hop.
     */
    public enum SpoofOutputDimsType {
        /** rows(in0) x cols(in0). */
        INPUT_DIMS,
        /** rows(in0) x constDim2. */
        INPUT_DIMS_CONST2,
        /** rows(in0) x 1. */
        ROW_DIMS,
        /** cols(in0) x 1. */
        COLUMN_DIMS_ROWS,
        /** 1 x cols(in0). */
        COLUMN_DIMS_COLS,
        /** 1 x cols(in1). */
        RANK_DIMS_COLS,
        /** 0 x 0. */
        SCALAR,
        /** 1 x n, where size refresh leaves the current column count unchanged. */
        MULTI_SCALAR,
        /** rows(in0) x cols(in1). */
        ROW_RANK_DIMS,
        /** cols(in0) x cols(in1). */
        COLUMN_RANK_DIMS,
        /** cols(in1) x cols(in0). */
        COLUMN_RANK_DIMS_T,
        /** 1 x constDim2. */
        VECT_CONST2
    }

    /** Creates an empty instance whose fields keep their declared defaults; used by {@link #clone()}. */
    public SpoofFusedOp() {
        // nothing to do: field initializers provide the defaults
    }

    /**
     * Creates a fused operator hop.
     *
     * @param name hop name
     * @param dt output data type
     * @param vt output value type
     * @param cls operator class handed to the {@link SpoofFused} lop
     * @param dist whether all execution types are allowed for this operator
     * @param dimsType rule for deriving the output dimensions
     */
    public SpoofFusedOp(String name, Types.DataType dt, Types.ValueType vt, Class<?> cls,
            boolean dist, SpoofOutputDimsType dimsType) {
        super(name, dt, vt);
        _class = cls;
        _distSupported = dist;
        _dimsType = dimsType;
    }

    /** Intentionally a no-op: no fixed number of inputs is enforced for fused operators. */
    @Override
    public void checkArity() {
    }

    /** Returns the distribution flag given at construction. */
    @Override
    public boolean allowsAllExecTypes() {
        return _distSupported;
    }

    /**
     * Sets the fixed output column count used by {@link SpoofOutputDimsType#INPUT_DIMS_CONST2}
     * and {@link SpoofOutputDimsType#VECT_CONST2}.
     *
     * @param constDim2 number of output columns
     */
    public void setConstDim2(long constDim2) {
        _constDim2 = constDim2;
    }

    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Estimates the output size. Operators whose direct superclass is {@link SpoofRowwise} are
     * estimated via {@code estimateSize(dim1, dim2)} (nnz is ignored); all others via the
     * partitioned estimate with exact sparsity.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        return _class.getGenericSuperclass().equals(SpoofRowwise.class)
            ? OptimizerUtils.estimateSize(dim1, dim2)
            : OptimizerUtils.estimatePartitionedSizeExactSparsity(dim1, dim2, getBlocksize(), nnz);
    }

    /** No intermediate memory is accounted for fused operators. */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0;
    }

    /**
     * Constructs (or returns the already constructed) {@link SpoofFused} lop over the lops of
     * all inputs.
     */
    @Override
    public Lop constructLops() {
        // return the cached lop if it has already been constructed
        if (getLops() != null) {
            return getLops();
        }

        LopProperties.ExecType et = optFindExecType();

        ArrayList<Lop> inputs = new ArrayList<>(getInput().size());
        for (Hop in : getInput()) {
            inputs.add(in.constructLops());
        }

        SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class,
            OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), et);
        setOutputDimensions(lop);
        setLineNumbers(lop);
        setLops(lop);
        return lop;
    }

    /**
     * Selects the execution type: a forced platform wins; otherwise it is chosen by memory
     * estimate and then checked against invalid CP dimensions/sizes.
     */
    @Override
    public LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            _etype = findExecTypeByMemEstimate();
            checkAndSetInvalidCPDimsAndSize();
        }
        return _etype;
    }

    @Override
    public String getOpString() {
        return "spoof(" + _class.getSimpleName() + ")";
    }

    /** Returns the fully qualified name of the operator class. */
    public String getClassName() {
        return _class.getName();
    }

    /**
     * Infers worst-case output dimensions from the memoized input statistics.
     *
     * @return the inferred characteristics, or {@code null} if the required input dimensions
     *     are unknown
     * @throws RuntimeException if the dims type is not handled
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics dc1 = memo.getAllInputStats(getInput().get(0));
        if (!dc1.dimsKnown()) {
            return null;
        }
        switch (_dimsType) {
            case ROW_DIMS:
                return new MatrixCharacteristics(dc1.getRows(), 1L, -1, -1L);
            case COLUMN_DIMS_ROWS:
                return new MatrixCharacteristics(dc1.getCols(), 1L, -1, -1L);
            case COLUMN_DIMS_COLS:
                return new MatrixCharacteristics(1L, dc1.getCols(), -1, -1L);
            case INPUT_DIMS:
                return new MatrixCharacteristics(dc1.getRows(), dc1.getCols(), -1, -1L);
            case INPUT_DIMS_CONST2:
                return new MatrixCharacteristics(dc1.getRows(), _constDim2, -1, -1L);
            case VECT_CONST2:
                return new MatrixCharacteristics(1L, _constDim2, -1, -1L);
            case SCALAR:
                return new MatrixCharacteristics(0L, 0L, -1, -1L);
            case MULTI_SCALAR:
                return new MatrixCharacteristics(1L, _dc.getCols(), -1, -1L);
            case RANK_DIMS_COLS: {
                DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
                return dc2.dimsKnown() ? new MatrixCharacteristics(1L, dc2.getCols(), -1, -1L) : null;
            }
            case ROW_RANK_DIMS: {
                DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
                return dc2.dimsKnown()
                    ? new MatrixCharacteristics(dc1.getRows(), dc2.getCols(), -1, -1L) : null;
            }
            case COLUMN_RANK_DIMS: {
                DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
                return dc2.dimsKnown()
                    ? new MatrixCharacteristics(dc1.getCols(), dc2.getCols(), -1, -1L) : null;
            }
            case COLUMN_RANK_DIMS_T: {
                DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
                return dc2.dimsKnown()
                    ? new MatrixCharacteristics(dc2.getCols(), dc1.getCols(), -1, -1L) : null;
            }
            default:
                throw new RuntimeException("Failed to infer worst-case size information for type: " + _dimsType);
        }
    }

    /**
     * Recomputes the output dimensions from the current input dimensions according to the
     * dims type.
     *
     * @throws RuntimeException if the dims type is not handled
     */
    @Override
    public void refreshSizeInformation() {
        // inputs are accessed lazily per case: some types do not read any input
        switch (_dimsType) {
            case ROW_DIMS:
                setDim1(getInput().get(0).getDim1());
                setDim2(1L);
                break;
            case COLUMN_DIMS_ROWS:
                setDim1(getInput().get(0).getDim2());
                setDim2(1L);
                break;
            case COLUMN_DIMS_COLS:
                setDim1(1L);
                setDim2(getInput().get(0).getDim2());
                break;
            case RANK_DIMS_COLS:
                setDim1(1L);
                setDim2(getInput().get(1).getDim2());
                break;
            case INPUT_DIMS:
                setDim1(getInput().get(0).getDim1());
                setDim2(getInput().get(0).getDim2());
                break;
            case INPUT_DIMS_CONST2:
                setDim1(getInput().get(0).getDim1());
                setDim2(_constDim2);
                break;
            case VECT_CONST2:
                setDim1(1L);
                setDim2(_constDim2);
                break;
            case SCALAR:
                setDim1(0L);
                setDim2(0L);
                break;
            case MULTI_SCALAR:
                // only the row count is fixed; the column count is left as is
                setDim1(1L);
                break;
            case ROW_RANK_DIMS:
                setDim1(getInput().get(0).getDim1());
                setDim2(getInput().get(1).getDim2());
                break;
            case COLUMN_RANK_DIMS:
                setDim1(getInput().get(0).getDim2());
                setDim2(getInput().get(1).getDim2());
                break;
            case COLUMN_RANK_DIMS_T:
                setDim1(getInput().get(1).getDim2());
                setDim2(getInput().get(0).getDim2());
                break;
            default:
                throw new RuntimeException("Failed to refresh size information for type: " + _dimsType);
        }
    }

    /** Creates a copy via {@code clone(this, false)} plus all fused-operator specific fields. */
    @Override
    public Object clone() throws CloneNotSupportedException {
        SpoofFusedOp ret = new SpoofFusedOp();
        ret.clone(this, false);
        ret._class = _class;
        ret._distSupported = _distSupported;
        ret._maxNumThreads = _maxNumThreads;
        ret._constDim2 = _constDim2;
        ret._dimsType = _dimsType;
        return ret;
    }

    /**
     * Returns true if {@code that} is a fused operator with the same class, distribution flag,
     * thread count and constant dim2, and the identical input hops (reference equality).
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof SpoofFusedOp)) {
            return false;
        }
        SpoofFusedOp other = (SpoofFusedOp) that;
        // NOTE(review): _dimsType is not compared; presumably implied by the operator class - confirm
        boolean sameConfig = _class.equals(other._class)
            && _distSupported == other._distSupported
            && _maxNumThreads == other._maxNumThreads
            && _constDim2 == other._constDim2
            && getInput().size() == other.getInput().size();
        if (!sameConfig) {
            return false;
        }
        for (int i = 0; i < getInput().size(); i++) {
            if (getInput().get(i) != other.getInput().get(i)) {
                return false;
            }
        }
        return true;
    }
}
