package org.apache.sysds.hops;

import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropyR;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoidR;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/QuaternaryOp.class */
/**
 * Hop for the fused quaternary operators of {@link Types.OpOp4}: WSLOSS, WSIGMOID, WDIVMM,
 * WCEMM and WUMM.
 *
 * <p>Inputs 0-2 are always present. The constructors that take a fourth hop add it as input 3
 * (the {@code baseType} constructor only when that hop is non-null). Lops are generated for
 * the CP and SPARK execution types only.
 */
public class QuaternaryOp extends MultiThreadedHop {

    /**
     * When true, Spark lop construction always selects the {@code *R} lop variants with both
     * input-cache flags disabled.
     */
    public static boolean FORCE_REPLICATION = false;

    /** The quaternary operation type. */
    private Types.OpOp4 _op;

    // WSLOSS: selects POST / POST_NZ instead of PRE / NONE weights (see checkWeightsType)
    private boolean _postWeights;

    // WSIGMOID: variant flags (see checkWSigmoidType)
    private boolean _logout;
    private boolean _minusin;

    // WDIVMM: base variant 0-4 (see checkWDivMMType); WCEMM: 1 selects BASIC_EPS; -1 = unset
    private int _baseType = -1;
    private boolean _mult;
    private boolean _minus;

    // WUMM: MULT (true) or DIV (false), plus the unary op or the binary op it is derived from
    private boolean _umult;
    private Types.OpOp1 _uop;
    private Types.OpOp2 _sop;

    /** Used by {@link #clone()} only; every field is assigned there. */
    private QuaternaryOp() {
        // field defaults: _op/_uop/_sop null, all flags false, _baseType -1
    }

    /**
     * Creates a four-input op whose {@code postWeights} flag is consumed by
     * {@link #checkWeightsType()}.
     */
    public QuaternaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp4 op,
            Hop in0, Hop in1, Hop in2, Hop in3, boolean postWeights) {
        this(name, dt, vt, op, in0, in1, in2);
        getInput().add(3, in3);
        in3.getParent().add(this);
        _postWeights = postWeights;
    }

    /** Creates a three-input op with the flags consumed by {@link #checkWSigmoidType()}. */
    public QuaternaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp4 op,
            Hop in0, Hop in1, Hop in2, boolean logout, boolean minusin) {
        this(name, dt, vt, op, in0, in1, in2);
        _logout = logout;
        _minusin = minusin;
    }

    /**
     * Creates an op with an optional fourth input ({@code in3} may be null) and the base type /
     * flags consumed by {@link #checkWDivMMType()} and {@link #checkWCeMMType()}.
     */
    public QuaternaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp4 op,
            Hop in0, Hop in1, Hop in2, Hop in3, int baseType, boolean mult, boolean minus) {
        this(name, dt, vt, op, in0, in1, in2);
        if (in3 != null) {
            getInput().add(3, in3);
            in3.getParent().add(this);
        }
        _baseType = baseType;
        _mult = mult;
        _minus = minus;
    }

    /**
     * Creates a three-input op for WUMM: {@code umult} selects MULT vs. DIV, and
     * {@code uop}/{@code sop} determine the applied unary operator.
     */
    public QuaternaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp4 op,
            Hop in0, Hop in1, Hop in2, boolean umult, Types.OpOp1 uop, Types.OpOp2 sop) {
        this(name, dt, vt, op, in0, in1, in2);
        _umult = umult;
        _uop = uop;
        _sop = sop;
    }

    /** Creates a three-input op and links this hop as parent of all inputs. */
    public QuaternaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp4 op,
            Hop in0, Hop in1, Hop in2) {
        super(name, dt, vt);
        _op = op;
        getInput().add(0, in0);
        getInput().add(1, in1);
        getInput().add(2, in2);
        in0.getParent().add(this);
        in1.getParent().add(this);
        in2.getParent().add(this);
    }

    @Override
    public void checkArity() {
        int arity = _input.size();
        HopsException.check(arity == 3 || arity == 4, this,
            "should have arity 3 or 4 but has arity %d", arity);
    }

    public Types.OpOp4 getOp() {
        return _op;
    }

    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Builds (once) and returns the lop for this hop, choosing the CP or SPARK variant per
     * {@link #optFindExecType()}.
     *
     * @throws HopsException for unsupported exec types, unknown ops, or wrapped lop errors
     */
    @Override
    public Lop constructLops() {
        // lops are cached; rebuild only if not yet constructed
        if (getLops() != null) {
            return getLops();
        }
        try {
            LopProperties.ExecType et = optFindExecType();
            switch (_op) {
                case WSLOSS: {
                    WeightedSquaredLoss.WeightsType wtype = checkWeightsType();
                    if (et == LopProperties.ExecType.CP) {
                        constructCPLopsWeightedSquaredLoss(wtype);
                    } else if (et == LopProperties.ExecType.SPARK) {
                        constructSparkLopsWeightedSquaredLoss(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wsloss exec type: " + et);
                    }
                    break;
                }
                case WSIGMOID: {
                    WeightedSigmoid.WSigmoidType wtype = checkWSigmoidType();
                    if (et == LopProperties.ExecType.CP) {
                        constructCPLopsWeightedSigmoid(wtype);
                    } else if (et == LopProperties.ExecType.SPARK) {
                        constructSparkLopsWeightedSigmoid(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wsigmoid exec type: " + et);
                    }
                    break;
                }
                case WDIVMM: {
                    WeightedDivMM.WDivMMType wtype = checkWDivMMType();
                    if (et == LopProperties.ExecType.CP) {
                        constructCPLopsWeightedDivMM(wtype);
                    } else if (et == LopProperties.ExecType.SPARK) {
                        constructSparkLopsWeightedDivMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wdivmm exec type: " + et);
                    }
                    break;
                }
                case WCEMM: {
                    WeightedCrossEntropy.WCeMMType wtype = checkWCeMMType();
                    if (et == LopProperties.ExecType.CP) {
                        constructCPLopsWeightedCeMM(wtype);
                    } else if (et == LopProperties.ExecType.SPARK) {
                        constructSparkLopsWeightedCeMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wcemm exec type: " + et);
                    }
                    break;
                }
                case WUMM: {
                    WeightedUnaryMM.WUMMType wtype =
                        _umult ? WeightedUnaryMM.WUMMType.MULT : WeightedUnaryMM.WUMMType.DIV;
                    if (et == LopProperties.ExecType.CP) {
                        constructCPLopsWeightedUMM(wtype);
                    } else if (et == LopProperties.ExecType.SPARK) {
                        constructSparkLopsWeightedUMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wumm exec type: " + et);
                    }
                    break;
                }
                default:
                    throw new HopsException(printErrorLocation() + "Unknown QuaternaryOp (" + _op
                        + ") while constructing Lops");
            }
            constructAndSetLopsDataFlowProperties();
            return getLops();
        } catch (LopsException e) {
            throw new HopsException(printErrorLocation() + "error constructing lops for QuaternaryOp.", e);
        }
    }

    @Override
    public String getOpString() {
        return "q(" + _op.toString() + ")";
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** Sets output dimensions and line numbers on {@code lop} and installs it as this hop's lop. */
    private void registerLop(Lop lop) {
        setOutputDimensions(lop);
        setLineNumbers(lop);
        setLops(lop);
    }

    /** Estimated size of {@code hop} from its dimensions, via {@code OptimizerUtils.estimateSize}. */
    private static double estimateSizeOf(Hop hop) {
        return OptimizerUtils.estimateSize(hop.getDim1(), hop.getDim2());
    }

    /**
     * Size condition for the single-lop (non-{@code R}) Spark variants: inputs 1 and 2 together
     * fit the broadcast budget and each, doubled, fits the local budget. Ignores FORCE_REPLICATION.
     */
    private static boolean fitsBroadcastBoth(double size1, double size2, double bcBudget, double localBudget) {
        return size1 + size2 < bcBudget && 2 * size1 < localBudget && 2 * size2 < localBudget;
    }

    /** Cache flag for input 1 of a Spark {@code *R} lop. */
    private static boolean cacheFirstInput(double size1, double bcBudget, double localBudget) {
        return !FORCE_REPLICATION && size1 < bcBudget && 2 * size1 < localBudget;
    }

    /** Cache flag for input 2 of a Spark {@code *R} lop, given the decision for input 1. */
    private static boolean cacheSecondInput(boolean cache1, double size1, double size2,
            double bcBudget, double localBudget) {
        return !FORCE_REPLICATION
            && ((!cache1 && size2 < bcBudget) || (cache1 && size1 + size2 < bcBudget))
            && 2 * size2 < localBudget;
    }

    private void constructCPLopsWeightedSquaredLoss(WeightedSquaredLoss.WeightsType wtype) {
        WeightedSquaredLoss lop = new WeightedSquaredLoss(
            getInput().get(0).constructLops(), getInput().get(1).constructLops(),
            getInput().get(2).constructLops(), getInput().get(3).constructLops(),
            getDataType(), getValueType(), wtype, LopProperties.ExecType.CP);
        lop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        registerLop(lop);
    }

    private void constructSparkLopsWeightedSquaredLoss(WeightedSquaredLoss.WeightsType wtype) {
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        double localBudget = OptimizerUtils.getLocalMemBudget();
        Hop in0 = getInput().get(0);
        Hop in1 = getInput().get(1);
        Hop in2 = getInput().get(2);
        Hop in3 = getInput().get(3);
        double size1 = estimateSizeOf(in1);
        double size2 = estimateSizeOf(in2);

        // the single-lop variant is only used when wtype.hasFourInputs() is false
        boolean singleLop = !wtype.hasFourInputs() && fitsBroadcastBoth(size1, size2, bcBudget, localBudget);
        if (!FORCE_REPLICATION && singleLop) {
            registerLop(new WeightedSquaredLoss(in0.constructLops(), in1.constructLops(),
                in2.constructLops(), in3.constructLops(), Types.DataType.SCALAR,
                Types.ValueType.FP64, wtype, LopProperties.ExecType.SPARK));
            return;
        }
        boolean cache1 = cacheFirstInput(size1, bcBudget, localBudget);
        boolean cache2 = cacheSecondInput(cache1, size1, size2, bcBudget, localBudget);
        registerLop(new WeightedSquaredLossR(in0.constructLops(), in1.constructLops(),
            in2.constructLops(), in3.constructLops(), Types.DataType.SCALAR,
            Types.ValueType.FP64, wtype, cache1, cache2, LopProperties.ExecType.SPARK));
    }

    private void constructCPLopsWeightedSigmoid(WeightedSigmoid.WSigmoidType wtype) {
        WeightedSigmoid lop = new WeightedSigmoid(
            getInput().get(0).constructLops(), getInput().get(1).constructLops(),
            getInput().get(2).constructLops(),
            getDataType(), getValueType(), wtype, LopProperties.ExecType.CP);
        lop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        registerLop(lop);
    }

    private void constructSparkLopsWeightedSigmoid(WeightedSigmoid.WSigmoidType wtype) {
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        double localBudget = OptimizerUtils.getLocalMemBudget();
        Hop in0 = getInput().get(0);
        Hop in1 = getInput().get(1);
        Hop in2 = getInput().get(2);
        double size1 = estimateSizeOf(in1);
        double size2 = estimateSizeOf(in2);

        if (!FORCE_REPLICATION && fitsBroadcastBoth(size1, size2, bcBudget, localBudget)) {
            registerLop(new WeightedSigmoid(in0.constructLops(), in1.constructLops(),
                in2.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype,
                LopProperties.ExecType.SPARK));
            return;
        }
        boolean cache1 = cacheFirstInput(size1, bcBudget, localBudget);
        boolean cache2 = cacheSecondInput(cache1, size1, size2, bcBudget, localBudget);
        registerLop(new WeightedSigmoidR(in0.constructLops(), in1.constructLops(),
            in2.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype,
            cache1, cache2, LopProperties.ExecType.SPARK));
    }

    private void constructCPLopsWeightedDivMM(WeightedDivMM.WDivMMType wtype) {
        WeightedDivMM lop = new WeightedDivMM(
            getInput().get(0).constructLops(), getInput().get(1).constructLops(),
            getInput().get(2).constructLops(), getInput().get(3).constructLops(),
            getDataType(), getValueType(), wtype, LopProperties.ExecType.CP);
        lop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        registerLop(lop);
    }

    private void constructSparkLopsWeightedDivMM(WeightedDivMM.WDivMMType wtype) {
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        double localBudget = OptimizerUtils.getLocalMemBudget();
        Hop in0 = getInput().get(0);
        Hop in1 = getInput().get(1);
        Hop in2 = getInput().get(2);
        Hop in3 = getInput().get(3);
        double size1 = estimateSizeOf(in1);
        double size2 = estimateSizeOf(in2);

        // single-lop variant only for types without four inputs, or whose fourth is a scalar
        boolean singleLop = (!wtype.hasFourInputs() || wtype.hasScalar())
            && fitsBroadcastBoth(size1, size2, bcBudget, localBudget);
        if (!FORCE_REPLICATION && singleLop) {
            registerLop(new WeightedDivMM(in0.constructLops(), in1.constructLops(),
                in2.constructLops(), in3.constructLops(), Types.DataType.MATRIX,
                Types.ValueType.FP64, wtype, LopProperties.ExecType.SPARK));
            return;
        }
        boolean cache1 = cacheFirstInput(size1, bcBudget, localBudget);
        boolean cache2 = cacheSecondInput(cache1, size1, size2, bcBudget, localBudget);
        registerLop(new WeightedDivMMR(in0.constructLops(), in1.constructLops(),
            in2.constructLops(), in3.constructLops(), Types.DataType.MATRIX,
            Types.ValueType.FP64, wtype, cache1, cache2, LopProperties.ExecType.SPARK));
    }

    private void constructCPLopsWeightedCeMM(WeightedCrossEntropy.WCeMMType wtype) {
        WeightedCrossEntropy lop = new WeightedCrossEntropy(
            getInput().get(0).constructLops(), getInput().get(1).constructLops(),
            getInput().get(2).constructLops(), getInput().get(3).constructLops(),
            getDataType(), getValueType(), wtype, LopProperties.ExecType.CP);
        lop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        registerLop(lop);
    }

    private void constructSparkLopsWeightedCeMM(WeightedCrossEntropy.WCeMMType wtype) {
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        double localBudget = OptimizerUtils.getLocalMemBudget();
        Hop in0 = getInput().get(0);
        Hop in1 = getInput().get(1);
        Hop in2 = getInput().get(2);
        Hop in3 = getInput().get(3);
        double size1 = estimateSizeOf(in1);
        double size2 = estimateSizeOf(in2);

        if (!FORCE_REPLICATION && fitsBroadcastBoth(size1, size2, bcBudget, localBudget)) {
            registerLop(new WeightedCrossEntropy(in0.constructLops(), in1.constructLops(),
                in2.constructLops(), in3.constructLops(), Types.DataType.SCALAR,
                Types.ValueType.FP64, wtype, LopProperties.ExecType.SPARK));
            return;
        }
        boolean cache1 = cacheFirstInput(size1, bcBudget, localBudget);
        boolean cache2 = cacheSecondInput(cache1, size1, size2, bcBudget, localBudget);
        registerLop(new WeightedCrossEntropyR(in0.constructLops(), in1.constructLops(),
            in2.constructLops(), in3.constructLops(), Types.DataType.SCALAR,
            Types.ValueType.FP64, wtype, cache1, cache2, LopProperties.ExecType.SPARK));
    }

    /** Unary operator applied by WUMM: {@code _uop} if set, else POW2 for binary POW, else MULT2. */
    private Types.OpOp1 getWUMMUnaryOp() {
        if (_uop != null) {
            return _uop;
        }
        return _sop == Types.OpOp2.POW ? Types.OpOp1.POW2 : Types.OpOp1.MULT2;
    }

    private void constructCPLopsWeightedUMM(WeightedUnaryMM.WUMMType wtype) {
        WeightedUnaryMM lop = new WeightedUnaryMM(
            getInput().get(0).constructLops(), getInput().get(1).constructLops(),
            getInput().get(2).constructLops(), getDataType(), getValueType(), wtype,
            getWUMMUnaryOp(), LopProperties.ExecType.CP);
        lop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        registerLop(lop);
    }

    private void constructSparkLopsWeightedUMM(WeightedUnaryMM.WUMMType wtype) {
        Types.OpOp1 uop = getWUMMUnaryOp();
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        double localBudget = OptimizerUtils.getLocalMemBudget();
        Hop in0 = getInput().get(0);
        Hop in1 = getInput().get(1);
        Hop in2 = getInput().get(2);
        double size1 = estimateSizeOf(in1);
        double size2 = estimateSizeOf(in2);

        if (!FORCE_REPLICATION && fitsBroadcastBoth(size1, size2, bcBudget, localBudget)) {
            registerLop(new WeightedUnaryMM(in0.constructLops(), in1.constructLops(),
                in2.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, uop,
                LopProperties.ExecType.SPARK));
            return;
        }
        boolean cache1 = cacheFirstInput(size1, bcBudget, localBudget);
        boolean cache2 = cacheSecondInput(cache1, size1, size2, bcBudget, localBudget);
        registerLop(new WeightedUnaryMMR(in0.constructLops(), in1.constructLops(),
            in2.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, uop,
            cache1, cache2, LopProperties.ExecType.SPARK));
    }

    /**
     * WSLOSS weights type: PRE/POST when input 3 is not a literal, otherwise POST_NZ/NONE;
     * the choice within each pair is made by {@code _postWeights}.
     */
    private WeightedSquaredLoss.WeightsType checkWeightsType() {
        boolean literalWeights = getInput().get(3) instanceof LiteralOp;
        if (!literalWeights) {
            return _postWeights ? WeightedSquaredLoss.WeightsType.POST : WeightedSquaredLoss.WeightsType.PRE;
        }
        return _postWeights ? WeightedSquaredLoss.WeightsType.POST_NZ : WeightedSquaredLoss.WeightsType.NONE;
    }

    /** WSIGMOID variant derived from the {@code _logout} / {@code _minusin} flags. */
    private WeightedSigmoid.WSigmoidType checkWSigmoidType() {
        if (_logout && _minusin) {
            return WeightedSigmoid.WSigmoidType.LOG_MINUS;
        }
        if (_logout) {
            return WeightedSigmoid.WSigmoidType.LOG;
        }
        if (_minusin) {
            return WeightedSigmoid.WSigmoidType.MINUS;
        }
        return WeightedSigmoid.WSigmoidType.BASIC;
    }

    /**
     * WDIVMM variant derived from {@code _baseType} (0-4) and, for left/right (1/2), from the
     * data type of input 3 and the {@code _minus}/{@code _mult} flags.
     *
     * @throws HopsException if {@code _baseType} is not in 0-4 (previously returned null, which
     *         surfaced later as a NullPointerException during lop construction)
     */
    private WeightedDivMM.WDivMMType checkWDivMMType() {
        switch (_baseType) {
            case 0:
                return WeightedDivMM.WDivMMType.MULT_BASIC;
            case 1:
                if (getInput().get(3).getDataType() == Types.DataType.MATRIX) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_4_LEFT;
                }
                if (_minus) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_LEFT;
                }
                return _mult ? WeightedDivMM.WDivMMType.MULT_LEFT : WeightedDivMM.WDivMMType.DIV_LEFT;
            case 2:
                if (getInput().get(3).getDataType() == Types.DataType.MATRIX) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_4_RIGHT;
                }
                if (_minus) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_RIGHT;
                }
                return _mult ? WeightedDivMM.WDivMMType.MULT_RIGHT : WeightedDivMM.WDivMMType.DIV_RIGHT;
            case 3:
                return WeightedDivMM.WDivMMType.DIV_LEFT_EPS;
            case 4:
                return WeightedDivMM.WDivMMType.DIV_RIGHT_EPS;
            default:
                throw new HopsException(printErrorLocation()
                    + "Unknown WDIVMM base type (" + _baseType + ") while constructing Lops");
        }
    }

    /** WCEMM variant: BASIC_EPS for base type 1, BASIC otherwise. */
    private WeightedCrossEntropy.WCeMMType checkWCeMMType() {
        return _baseType == 1 ? WeightedCrossEntropy.WCeMMType.BASIC_EPS : WeightedCrossEntropy.WCeMMType.BASIC;
    }

    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        switch (_op) {
            case WSLOSS:
            case WCEMM:
                // scalar output: 8 bytes
                return 8.0d;
            case WSIGMOID:
            case WDIVMM:
            case WUMM:
                return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, OptimizerUtils.getSparsity(dim1, dim2, nnz));
            default:
                return 0;
        }
    }

    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0;
    }

    /**
     * Infers output characteristics from the memoized input statistics.
     *
     * @return null for the scalar-output ops WSLOSS and WCEMM
     * @throws RuntimeException for unknown ops
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        switch (_op) {
            case WSLOSS:
            case WCEMM:
                // scalar output: nothing to infer (consistent with computeOutputMemEstimate)
                return null;
            case WSIGMOID:
            case WUMM: {
                DataCharacteristics dc0 = memo.getAllInputStats(getInput().get(0));
                return new MatrixCharacteristics(dc0.getRows(), dc0.getCols(), -1, dc0.getNonZeros());
            }
            case WDIVMM:
                // NOTE(review): setNonZeros mutates the object returned by getAllInputStats;
                // this is only safe if MemoTable returns a fresh copy -- confirm.
                if (_baseType == 0) {
                    return memo.getAllInputStats(getInput().get(0));
                }
                if (_baseType == 1 || _baseType == 3) {
                    return memo.getAllInputStats(getInput().get(2)).setNonZeros(-1L);
                }
                return memo.getAllInputStats(getInput().get(1)).setNonZeros(-1L);
            default:
                throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated.");
        }
    }

    /** True if every input present (3 or 4 of them) reports dims below threshold. */
    private boolean allInputDimsBelowThreshold() {
        for (Hop in : getInput()) {
            if (!in.areDimsBelowThreshold()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Chooses the execution type: a forced platform if set, otherwise by memory estimate
     * (memory-based opt levels) or by input dimension thresholds.
     */
    @Override
    public LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (allInputDimsBelowThreshold()) {
                // checks only the inputs that exist: WSIGMOID/WUMM have 3, so get(3) would throw
                _etype = LopProperties.ExecType.CP;
            } else {
                _etype = LopProperties.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /** Copies dims of {@code in} to this hop; nnz is copied if {@code withNnz}, else set to -1. */
    private void setSizeFromInput(Hop in, boolean withNnz) {
        setDim1(in.getDim1());
        setDim2(in.getDim2());
        setNnz(withNnz ? in.getNnz() : -1L);
    }

    @Override
    public void refreshSizeInformation() {
        switch (_op) {
            case WSIGMOID:
            case WUMM:
                setSizeFromInput(getInput().get(0), true);
                break;
            case WDIVMM:
                if (_baseType == 0) {
                    setSizeFromInput(getInput().get(0), true);
                } else if (_baseType == 1 || _baseType == 3) {
                    setSizeFromInput(getInput().get(2), false);
                } else {
                    setSizeFromInput(getInput().get(1), false);
                }
                break;
            default:
                // WSLOSS, WCEMM: scalar output, nothing to refresh
                break;
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        QuaternaryOp ret = new QuaternaryOp();
        ret.clone(this, false);
        ret._op = _op;
        ret._postWeights = _postWeights;
        ret._logout = _logout;
        ret._minusin = _minusin;
        ret._baseType = _baseType;
        ret._mult = _mult;
        ret._minus = _minus;
        ret._umult = _umult;
        ret._uop = _uop;
        ret._sop = _sop;
        ret._maxNumThreads = _maxNumThreads;
        return ret;
    }

    /** Equal if same op, identical (by reference) inputs, and identical configuration. */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof QuaternaryOp)) {
            return false;
        }
        QuaternaryOp other = (QuaternaryOp) that;
        if (_op != other._op || getInput().size() != other.getInput().size()) {
            return false;
        }
        for (int i = 0; i < getInput().size(); i++) {
            if (getInput().get(i) != other.getInput().get(i)) {
                return false;
            }
        }
        return _postWeights == other._postWeights
            && _logout == other._logout
            && _minusin == other._minusin
            && _baseType == other._baseType
            && _mult == other._mult
            && _minus == other._minus
            && _umult == other._umult
            && _uop == other._uop
            && _sop == other._sop
            && _maxNumThreads == other._maxNumThreads;
    }
}
