package org.apache.sysds.hops;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DnnTransform;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DnnUtils;

/* loaded from: input_file:org/apache/sysds/hops/DnnOp.class */
/**
 * Hop for deep-neural-network (DNN) operations: 2D convolution and its backward passes
 * (data/filter), max/avg pooling and their backward passes, bias add/multiply, and the
 * GPU-only batch-norm inference, channel-sums and Nesterov-update operators.
 *
 * <p>Tensor shape parameters (N, C, H, W, K, R, S, P, Q, strides, paddings) are parsed from
 * the trailing inputs into {@link #_cachedParams}. Values that are not known at compile time
 * are kept as negative numbers (-1), and every consumer treats a negative value as "unknown".
 */
public class DnnOp extends MultiThreadedHop {
    private static final Log LOG = LogFactory.getLog(DnnOp.class.getName());

    /** If true, unknown [C,H,W] of conv2d/pooling ops are inferred from a producing conv2d/pooling op. */
    private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;

    /** If true, the inferred-vs-given shape check in {@link #inferCHWPQFromParentOp()} is performed. */
    private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;

    private Types.OpOpDnn op;

    /** Lazily filled shape parameters; reset in {@link #refreshSizeInformation()}. */
    private DnnParameters _cachedParams;

    /**
     * Rows/columns/sparsity of one intermediate matrix allocated while executing a DnnOp.
     * Used by {@link DnnOp#computeIntermediateMemEstimate(long, long, long)} to estimate memory.
     */
    public static class IntermediateDimensions {
        int dim1;
        int dim2;
        double sp;

        /**
         * @param h        op whose dimensions are looked up
         * @param dim1Str  dimension name (see {@link DnnOp#getDim(String)}) for the rows
         * @param dim2Str  dimension name for the columns
         * @param sp       sparsity of the intermediate
         */
        public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str, double sp) {
            this.dim1 = (int) h.getDim(dim1Str);
            this.dim2 = (int) h.getDim(dim2Str);
            this.sp = sp;
        }

        /** Dense intermediate whose rows and columns are both named dimensions. */
        public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str) {
            this(h, dim1Str, dim2Str, 1.0);
        }

        /** Dense intermediate with a fixed number of rows and a named column dimension. */
        public IntermediateDimensions(DnnOp h, int dim1, String dim2Str) {
            this.dim1 = dim1;
            this.dim2 = (int) h.getDim(dim2Str);
            this.sp = 1.0;
        }

        /**
         * Adds two size estimates. A negative (unknown) operand or a sum that reaches
         * {@code OptimizerUtils.DEFAULT_SIZE} yields {@code OptimizerUtils.DEFAULT_SIZE}.
         */
        static double guardedAdd(double d1, double d2) {
            if (d1 < 0 || d2 < 0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            double sum = d1 + d2;
            return sum >= OptimizerUtils.DEFAULT_SIZE ? OptimizerUtils.DEFAULT_SIZE : sum;
        }

        /**
         * Sums the estimated sizes of all given intermediates, each multiplied by
         * {@code numWorkers} (one copy per concurrently running worker).
         */
        public static double addEstimateSizes(ArrayList<IntermediateDimensions> dims, int numWorkers) {
            double total = 0;
            for (IntermediateDimensions d : dims) {
                double size = OptimizerUtils.estimateSizeExactSparsity(d.dim1, d.dim2, d.sp) * numWorkers;
                total = guardedAdd(total, size);
            }
            return total;
        }

        /**
         * Maximum of two size estimates. A negative (unknown) operand or a result that reaches
         * {@code OptimizerUtils.DEFAULT_SIZE} yields {@code OptimizerUtils.DEFAULT_SIZE}.
         */
        public static double guardedMax(double d1, double d2) {
            if (d1 < 0 || d2 < 0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            double max = Math.max(d1, d2);
            return max >= OptimizerUtils.DEFAULT_SIZE ? OptimizerUtils.DEFAULT_SIZE : max;
        }
    }

    /** Constructor used by {@link #clone()}. */
    private DnnOp() {
        this._cachedParams = newUnknownParams();
    }

    /**
     * Creates a DNN hop and registers it as a parent (consumer) of all its inputs.
     *
     * @param l    hop name
     * @param dt   output data type
     * @param vt   output value type
     * @param o    DNN operation
     * @param inp  inputs; the expected count depends on the operation
     */
    public DnnOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOpDnn o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this._cachedParams = newUnknownParams();
        this.op = o;
        for (int i = 0; i < inp.size(); i++) {
            Hop in = inp.get(i);
            getInput().add(i, in);
            in.getParent().add(this);
        }
        refreshSizeInformation();
    }

    /** Returns a parameter object in which every shape value is unknown (-1). */
    private DnnParameters newUnknownParams() {
        return new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, this._maxNumThreads);
    }

    /** True for ops whose output has the same dimensions as their first input. */
    private boolean isOutputShapedLikeFirstInput() {
        return op == Types.OpOpDnn.BIASADD || op == Types.OpOpDnn.BIASMULT
            || op == Types.OpOpDnn.BATCH_NORM2D_TEST || op == Types.OpOpDnn.UPDATE_NESTEROV_X;
    }

    @Override
    public void checkArity() {
        HopsException.check(_input.size() >= 1, this, "should have at least one input but has %d inputs", Integer.valueOf(_input.size()));
    }

    public Types.OpOpDnn getOp() {
        return op;
    }

    @Override
    public String getOpString() {
        return op.toString();
    }

    /** DNN ops currently have no Spark implementation. */
    private static boolean isEligibleForSpark() {
        return false;
    }

    @Override
    public boolean isGPUEnabled() {
        return DMLScript.USE_ACCELERATOR;
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Builds (or returns the cached) lops for this hop.
     *
     * @throws HopsException if the chosen execution type is not supported for the operation
     */
    @Override
    public Lop constructLops() {
        // return already created lops
        if (getLops() != null) {
            return getLops();
        }
        LopProperties.ExecType et = optFindExecType();
        ArrayList<Hop> inputs = getInput();
        switch (op) {
            case MAX_POOL:
            case MAX_POOL_BACKWARD:
            case AVG_POOL:
            case AVG_POOL_BACKWARD:
            case CONV2D:
            case CONV2D_BACKWARD_DATA:
            case CONV2D_BACKWARD_FILTER:
            case BIASADD:
            case BIASMULT:
                if (et != LopProperties.ExecType.CP && et != LopProperties.ExecType.GPU) {
                    throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
                }
                setLops(constructDnnLops(et, inputs));
                break;
            case BATCH_NORM2D_TEST:
            case CHANNEL_SUMS:
            case UPDATE_NESTEROV_X:
                // these operators only have a GPU implementation
                if (et != LopProperties.ExecType.GPU) {
                    throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
                }
                setLops(constructDnnLops(et, inputs));
                break;
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + op + "'.");
        }
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    public void setOp(Types.OpOpDnn opOpDnn) {
        this.op = opOpDnn;
    }

    /** Number of inputs {@link #constructDnnLops} expects for the current operation. */
    private int getNumExpectedInputs() {
        switch (op) {
            case MAX_POOL_BACKWARD:
            case AVG_POOL_BACKWARD:
            case CONV2D:
            case CONV2D_BACKWARD_DATA:
            case CONV2D_BACKWARD_FILTER:
                return 14;
            case BIASADD:
            case BIASMULT:
                return 2;
            case BATCH_NORM2D_TEST:
                return 6;
            case CHANNEL_SUMS:
                return 3;
            case UPDATE_NESTEROV_X:
                return 4;
            default:
                // MAX_POOL, AVG_POOL and every other operation
                return 13;
        }
    }

    /**
     * If {@code hop} is {@code max(x, 0)} or {@code max(0, x)} (a relu), returns {@code x};
     * otherwise returns null.
     */
    private static Hop isInputReLU(Hop hop) {
        if (!HopRewriteUtils.isBinary(hop, Types.OpOp2.MAX)) {
            return null;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        if (HopRewriteUtils.isLiteralOfValue(left, 0.0)) {
            return right;
        }
        if (HopRewriteUtils.isLiteralOfValue(right, 0.0)) {
            return left;
        }
        return null;
    }

    private static boolean isInputConv2d(Hop hop) {
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.CONV2D);
    }

    /** True if strides, paddings, pooling window and input shape are known and identical. */
    private static boolean isPoolingParametersEqualAndKnown(DnnParameters a, DnnParameters b) {
        return isEqualAndKnown(a.stride_h, b.stride_h)
            && isEqualAndKnown(a.stride_w, b.stride_w)
            && isEqualAndKnown(a.pad_h, b.pad_h)
            && isEqualAndKnown(a.pad_w, b.pad_w)
            && isEqualAndKnown(a.R, b.R)
            && isEqualAndKnown(a.S, b.S)
            && isEqualAndKnown(a.N, b.N)
            && isEqualAndKnown(a.C, b.C)
            && isEqualAndKnown(a.H, b.H)
            && isEqualAndKnown(a.W, b.W);
    }

    /** True if both strides are 1 and both paddings are 0. */
    public boolean isStride1Pad0() {
        DnnParameters p = parseInput();
        return p.stride_h == 1 && p.stride_w == 1 && p.pad_h == 0 && p.pad_w == 0;
    }

    private static boolean isEqualAndKnown(int a, int b) {
        return a >= 0 && b >= 0 && a == b;
    }

    /**
     * For a (max|avg) pooling backward op, finds a forward pooling op of the same kind that
     * consumes the same first input with identical, known parameters, and returns its lops.
     * Returns null if there is no such op or if this op is not a pooling backward op.
     */
    private Lop getMaxPoolOutputLop() {
        if (op != Types.OpOpDnn.MAX_POOL_BACKWARD && op != Types.OpOpDnn.AVG_POOL_BACKWARD) {
            return null;
        }
        Types.OpOpDnn forwardOp = (op == Types.OpOpDnn.MAX_POOL_BACKWARD) ? Types.OpOpDnn.MAX_POOL : Types.OpOpDnn.AVG_POOL;
        for (Hop consumer : getInput().get(0).getParent()) {
            if (consumer instanceof DnnOp) {
                DnnOp candidate = (DnnOp) consumer;
                if (candidate.getOp() == forwardOp
                    && isPoolingParametersEqualAndKnown(candidate._cachedParams, this._cachedParams)) {
                    return candidate.constructLops();
                }
            }
        }
        return null;
    }

    /**
     * Builds the {@link DnnTransform} lop for this hop, applying operator fusion where possible.
     * When {@code OptimizerUtils.ALLOW_OPERATOR_FUSION} is set, relu + max_pool and
     * relu + max_pool_backward are fused in CP, and conv2d + bias_add is fused for any
     * execution type.
     *
     * @param et      execution type
     * @param inputs  input hops; must have {@link #getNumExpectedInputs()} entries
     * @return the constructed lop
     * @throws HopsException if the number of inputs is wrong
     */
    public Lop constructDnnLops(LopProperties.ExecType et, ArrayList<Hop> inputs) {
        if (inputs.size() != getNumExpectedInputs()) {
            throw new HopsException("Incorrect number of inputs for " + op.name());
        }

        Lop in1;              // first data input of the lop
        Lop in2 = null;       // bias of a fused conv2d + bias_add, otherwise unused
        ArrayList<Hop> paramInputs = inputs; // inputs 1..n-1 of this list become further lop inputs
        Types.OpOpDnn lopOp = op;

        Hop reluInput = isInputReLU(inputs.get(0));
        boolean fuseInCP = OptimizerUtils.ALLOW_OPERATOR_FUSION && et == LopProperties.ExecType.CP;
        if (fuseInCP && op == Types.OpOpDnn.MAX_POOL && reluInput != null) {
            in1 = reluInput.constructLops();
            lopOp = Types.OpOpDnn.RELU_MAX_POOL;
        } else if (fuseInCP && op == Types.OpOpDnn.MAX_POOL_BACKWARD && reluInput != null) {
            in1 = reluInput.constructLops();
            lopOp = Types.OpOpDnn.RELU_MAX_POOL_BACKWARD;
        } else if (OptimizerUtils.ALLOW_OPERATOR_FUSION && op == Types.OpOpDnn.BIASADD && isInputConv2d(inputs.get(0))) {
            lopOp = Types.OpOpDnn.CONV2D_BIAS_ADD;
            Hop conv = inputs.get(0);
            in1 = conv.getInput().get(0).constructLops(); // image of the conv2d
            in2 = inputs.get(1).constructLops();          // bias of the bias_add
            paramInputs = conv.getInput();                // remaining inputs come from the conv2d
        } else {
            in1 = inputs.get(0).constructLops();
        }

        double intermediateMemEstimate = computeIntermediateMemEstimate(-1L, -1L, -1L);
        if (et == LopProperties.ExecType.GPU && getDim1() >= 0 && getDim2() >= 0) {
            // allow the intermediates to use whatever remains of the initial GPU budget after
            // the output and data inputs
            double optimisticEstimate = GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate()
                - inputs.get(0).getOutputMemEstimate();
            if (in2 != null) {
                optimisticEstimate -= inputs.get(1).getOutputMemEstimate();
            }
            intermediateMemEstimate = Math.max(intermediateMemEstimate, optimisticEstimate);
        }

        Lop poolOutput = (et == LopProperties.ExecType.GPU) ? getMaxPoolOutputLop() : null;
        Lop[] paramLops = new Lop[paramInputs.size() - 1];
        for (int i = 1; i < paramInputs.size(); i++) {
            paramLops[i - 1] = paramInputs.get(i).constructLops();
        }

        DnnTransform transform = new DnnTransform(in1, lopOp, getDataType(), getValueType(), et,
            OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), intermediateMemEstimate);
        setOutputDimensions(transform);
        setLineNumbers(transform);
        in1.addOutput(transform);
        if (in2 != null) {
            transform.addInput(in2);
            in2.addOutput(transform);
        }
        for (Lop paramLop : paramLops) {
            transform.addInput(paramLop);
            paramLop.addOutput(transform);
        }
        if (poolOutput != null) {
            transform.addInput(poolOutput);
            poolOutput.addOutput(transform);
        }
        transform.updateLopProperties();
        return transform;
    }

    /**
     * Output size estimate. The output is assumed dense, except that BIASMULT without an
     * accelerator inherits the sparsity of its first input.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (getOp() == Types.OpOpDnn.BIASMULT && !DMLScript.USE_ACCELERATOR) {
            return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, getInput().get(0).getSparsity());
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0d);
    }

    /**
     * Combines intermediate size estimates.
     *
     * Without an accelerator, only the CP intermediates are counted, one copy per worker thread.
     * With an accelerator, the result is the guarded maximum of the GPU intermediates and the CP
     * intermediates. The CP intermediates are counted once per worker, but fall back to a
     * single-copy estimate when that fits within the GPU estimate.
     *
     * @param gpuIntermediates intermediates counted once, only when the accelerator is enabled
     * @param cpIntermediates  intermediates counted per worker thread
     */
    private double computeIntermediateMemEstimateHelper(ArrayList<IntermediateDimensions> gpuIntermediates,
        ArrayList<IntermediateDimensions> cpIntermediates) {
        int numWorkers = (int) Math.min(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), Math.max(getDim("N"), 1L));
        if (!DMLScript.USE_ACCELERATOR) {
            return IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        }
        double gpuSize = IntermediateDimensions.addEstimateSizes(gpuIntermediates, 1);
        double cpSize = IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        if (cpSize > gpuSize) {
            double cpSingleWorkerSize = IntermediateDimensions.addEstimateSizes(cpIntermediates, 1);
            if (cpSingleWorkerSize <= gpuSize) {
                cpSize = cpSingleWorkerSize;
            }
        }
        return IntermediateDimensions.guardedMax(cpSize, gpuSize);
    }

    /** Estimates the memory of intermediates for conv2d/pooling ops; 0 for all other ops. */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        ArrayList<IntermediateDimensions> gpuIntermediates = new ArrayList<>();
        ArrayList<IntermediateDimensions> cpIntermediates = new ArrayList<>();
        Types.OpOpDnn o = getOp();
        if (o == Types.OpOpDnn.CONV2D) {
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
            gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
            cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
        } else if (o == Types.OpOpDnn.CONV2D_BACKWARD_DATA) {
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
            gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
            cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
            cpIntermediates.add(new IntermediateDimensions(this, "PQ", "CRS"));
        } else if (o == Types.OpOpDnn.CONV2D_BACKWARD_FILTER) {
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
            cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
            cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
        } else if (o == Types.OpOpDnn.MAX_POOL || o == Types.OpOpDnn.AVG_POOL) {
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
        } else if (o == Types.OpOpDnn.MAX_POOL_BACKWARD || o == Types.OpOpDnn.AVG_POOL_BACKWARD) {
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
            gpuIntermediates.add(new IntermediateDimensions(this, 1, "CPQ"));
        }
        if (gpuIntermediates.isEmpty() && cpIntermediates.isEmpty()) {
            return 0;
        }
        return computeIntermediateMemEstimateHelper(gpuIntermediates, cpIntermediates);
    }

    /** Returns the inferred output characteristics, or null if the dimensions are not known. */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        if (isOutputShapedLikeFirstInput()) {
            DataCharacteristics in = memo.getAllInputStats(getInput())[0];
            MatrixCharacteristics ret = new MatrixCharacteristics(
                in.rowsKnown() ? in.getRows() : -1L,
                in.colsKnown() ? in.getCols() : -1L, -1, -1L);
            return ret.dimsKnown() ? ret : null;
        }
        if (op == Types.OpOpDnn.CHANNEL_SUMS) {
            return new MatrixCharacteristics(Hop.computeSizeInformation(getInput().get(1)), 1L, -1, -1L);
        }
        refreshSizeInformation();
        DataCharacteristics dc = _dc;
        return dc.dimsKnown() ? dc : null;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** Chooses the execution type; SPARK is replaced by CP because no Spark implementation exists. */
    @Override
    public LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            _etype = OptimizerUtils.isMemoryBasedOptLevel() ? findExecTypeByMemEstimate() : LopProperties.ExecType.SPARK;
            checkAndSetInvalidCPDimsAndSize();
        }
        if (_etype == LopProperties.ExecType.SPARK && !isEligibleForSpark()) {
            _etype = LopProperties.ExecType.CP;
        }
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /**
     * Fills unknown entries of {@link #_cachedParams} from the shape inputs and derives P and Q
     * where possible.
     *
     * Ops with two data inputs (conv2d family, pooling backward) have their shape inputs shifted
     * by one position compared to the forward pooling ops. Must only be called for conv2d/pooling
     * ops: other ops have too few inputs.
     *
     * @return the (shared, mutable) cached parameters
     */
    DnnParameters parseInput() {
        boolean twoDataInputs = op == Types.OpOpDnn.MAX_POOL_BACKWARD || op == Types.OpOpDnn.AVG_POOL_BACKWARD
            || op == Types.OpOpDnn.CONV2D || op == Types.OpOpDnn.CONV2D_BACKWARD_FILTER
            || op == Types.OpOpDnn.CONV2D_BACKWARD_DATA;
        int off = twoDataInputs ? 1 : 0;
        ArrayList<Hop> in = getInput();
        // NOTE(review): input 10+off is not passed; presumably it is the filter's channel count,
        // which duplicates C - confirm against DnnParameters.setIfUnknown.
        _cachedParams.setIfUnknown(in.get(5 + off), in.get(6 + off), in.get(7 + off), in.get(8 + off),
            in.get(9 + off), in.get(11 + off), in.get(12 + off), in.get(1 + off), in.get(2 + off),
            in.get(3 + off), in.get(4 + off), _maxNumThreads);

        if (INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
            boolean isPool = getOp() == Types.OpOpDnn.MAX_POOL || getOp() == Types.OpOpDnn.AVG_POOL;
            boolean isConv = getOp() == Types.OpOpDnn.CONV2D;
            boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0
                || _cachedParams.P < 0 || _cachedParams.Q < 0;
            if ((isPool || isConv) && unknownCHWPQ) {
                inferCHWPQFromParentOp();
            }
        }

        // unknown R but known H: fall back to R = H
        if (_cachedParams.R < 0 && _cachedParams.H > 0) {
            _cachedParams.R = _cachedParams.H;
        }
        if (_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0
            && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
            _cachedParams.P = (int) DnnUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
        }
        if (_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0
            && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
            _cachedParams.Q = (int) DnnUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
        }
        return _cachedParams;
    }

    private static boolean isInputBiasAdd(Hop hop) {
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD);
    }

    /** Throws if both values are known (non-negative) and differ. */
    private static void throwExceptionIfNotEqual(int i, int i2, String str) {
        if (i >= 0 && i2 >= 0 && i != i2) {
            throw new DMLRuntimeException("Inferred " + str + " from parent doesn't match with given " + str
                + IOUtilFunctions.LIBSVM_INDEX_DELIM + i + " != " + i2);
        }
    }

    /**
     * Fills unknown C, H, W from the op producing this op's first input (the "parent").
     * An optional bias_add and then an optional relu are looked through first.
     *
     * For a pooling parent, [C,H,W] = parent's [C,P,Q]; for a conv2d parent, [C,H,W] = parent's
     * [K,P,Q]. Other producers are ignored.
     */
    private void inferCHWPQFromParentOp() {
        Hop in = getInput().get(0);
        Hop beforeBias = isInputBiasAdd(in) ? in.getInput().get(0) : in;
        Hop reluInput = isInputReLU(beforeBias);
        Hop producer = (reluInput != null) ? reluInput : beforeBias;
        if (!(producer instanceof DnnOp)) {
            return;
        }
        DnnOp parent = (DnnOp) producer;
        boolean parentIsPool = parent.getOp() == Types.OpOpDnn.MAX_POOL || parent.getOp() == Types.OpOpDnn.AVG_POOL;
        boolean parentIsConv = parent.getOp() == Types.OpOpDnn.CONV2D;
        if (!parentIsPool && !parentIsConv) {
            return;
        }

        DnnParameters parentParams = parent.parseInput();
        int prevC = _cachedParams.C;
        int prevH = _cachedParams.H;
        int prevW = _cachedParams.W;
        int inferredC = parentIsPool ? parentParams.C : parentParams.K;
        _cachedParams.C = _cachedParams.C < 0 ? inferredC : _cachedParams.C;
        _cachedParams.H = _cachedParams.H < 0 ? parentParams.P : _cachedParams.H;
        _cachedParams.W = _cachedParams.W < 0 ? parentParams.Q : _cachedParams.W;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Inferring [C,H,W] from " + (parentIsPool ? "pooling" : "conv2d") + " parent: ["
                + prevC + "," + prevH + "," + prevW + "]-> ["
                + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
        }
        if (THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
            // NOTE(review): a known previous value is never overwritten above, so these checks can
            // never fail as written. Comparing the parent's inferred values with the given ones
            // would make them effective, but might reject plans whose parent P/Q came from the
            // R = H fallback in parseInput(). Left unchanged pending confirmation.
            throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
            throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
            throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
        }
    }

    /**
     * Recomputes the output dimensions.
     *
     * For conv2d/pooling ops, the cached parameters are reset first so that they are re-parsed
     * from the inputs.
     *
     * @throws RuntimeException for operations whose sizes cannot be refreshed
     */
    @Override
    public void refreshSizeInformation() {
        if (isOutputShapedLikeFirstInput()) {
            Hop in = getInput().get(0);
            setDim1(in.getDim1());
            setDim2(in.getDim2());
            setNnz(-1L);
            return;
        }
        if (op == Types.OpOpDnn.CHANNEL_SUMS) {
            setDim1(Hop.computeSizeInformation(getInput().get(1)));
            setDim2(1L);
            setNnz(-1L);
            return;
        }

        _cachedParams = newUnknownParams();
        switch (op) {
            case MAX_POOL:
            case AVG_POOL:
                setDim1(getDim("N"));
                setDim2(getDim("CPQ"));
                break;
            case MAX_POOL_BACKWARD:
            case AVG_POOL_BACKWARD:
            case CONV2D_BACKWARD_DATA:
                setDim1(getDim("N"));
                setDim2(getDim("CHW"));
                break;
            case CONV2D:
                setDim1(getDim("N"));
                setDim2(getDim("KPQ"));
                break;
            case CONV2D_BACKWARD_FILTER:
                setDim1(getDim("K"));
                setDim2(getDim("CRS"));
                break;
            default:
                throw new RuntimeException("The sizes are not refreshed for " + op.name());
        }
        setNnz(-1L);
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        DnnOp ret = new DnnOp();
        // copy generic attributes
        ret.clone(this, false);
        // copy specific attributes
        ret.op = op;
        ret._maxNumThreads = _maxNumThreads;
        return ret;
    }

    /** True if {@code that} is a DnnOp with the same op, thread count and identical input hops. */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof DnnOp)) {
            return false;
        }
        DnnOp other = (DnnOp) that;
        boolean ret = op == other.op
            && getInput().size() == that.getInput().size()
            && _maxNumThreads == other._maxNumThreads;
        if (ret) {
            for (int i = 0; i < _input.size(); i++) {
                ret &= getInput().get(i) == other.getInput().get(i);
            }
        }
        return ret;
    }

    /**
     * Returns a named dimension of this op, or -1 if it is unknown.
     *
     * The value is computed from the cached parameters. It is reconciled with the dimensions of
     * the data input that has that shape (filter: K x CRS; image: N x CHW; dout: N x KPQ;
     * pooling dout: N x CPQ).
     *
     * @param dimString one of "K", "CRS", "N", "CHW", "KPQ", "PQ", "CPQ"
     * @return the dimension, or -1 if unknown
     * @throws RuntimeException for shape-preserving or channel-sums ops, unsupported names, or
     *         when the parameters and an input's dimensions disagree
     */
    public long getDim(String dimString) {
        if (isOutputShapedLikeFirstInput() || op == Types.OpOpDnn.CHANNEL_SUMS) {
            throw new RuntimeException("getDim method should not be invoked for " + op.name());
        }
        try {
            parseInput();
            Hop filterHop = null;   // K x CRS
            Hop imageHop = null;    // N x CHW
            Hop doutHop = null;     // N x KPQ
            Hop poolDoutHop = null; // N x CPQ
            Types.OpOpDnn o = getOp();
            if (o == Types.OpOpDnn.CONV2D) {
                imageHop = getInput().get(0);
                filterHop = getInput().get(1);
            } else if (o == Types.OpOpDnn.CONV2D_BACKWARD_DATA) {
                filterHop = getInput().get(0);
                doutHop = getInput().get(1);
            } else if (o == Types.OpOpDnn.CONV2D_BACKWARD_FILTER) {
                imageHop = getInput().get(0);
                doutHop = getInput().get(1);
            } else if (o == Types.OpOpDnn.MAX_POOL || o == Types.OpOpDnn.AVG_POOL) {
                imageHop = getInput().get(0);
            } else if (o == Types.OpOpDnn.MAX_POOL_BACKWARD || o == Types.OpOpDnn.AVG_POOL_BACKWARD) {
                imageHop = getInput().get(0);
                poolDoutHop = getInput().get(1);
            }

            DnnParameters p = _cachedParams;
            long dim;
            switch (dimString) {
                case "K":
                    dim = reconcileWithHop(p.K >= 0 ? p.K : -1L, filterHop, true);
                    break;
                case "CRS":
                    dim = reconcileWithHop(nonNegativeMultiply(p.C, p.R, p.S), filterHop, false);
                    break;
                case "N": {
                    // the first available N x ? input wins: image, then dout, then pooling dout
                    Hop nSource = imageHop != null ? imageHop : (doutHop != null ? doutHop : poolDoutHop);
                    dim = reconcileWithHop(p.N >= 0 ? p.N : -1L, nSource, true);
                    break;
                }
                case "CHW":
                    dim = reconcileWithHop(nonNegativeMultiply(p.C, p.H, p.W), imageHop, false);
                    break;
                case "KPQ":
                    dim = reconcileWithHop(nonNegativeMultiply(p.K, p.P, p.Q), doutHop, false);
                    break;
                case "PQ":
                    dim = nonNegativeMultiply(p.P, p.Q);
                    break;
                case "CPQ":
                    dim = reconcileWithHop(nonNegativeMultiply(p.C, p.P, p.Q), poolDoutHop, false);
                    break;
                default:
                    throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
            }

            if (LOG.isDebugEnabled() && dim < 0) {
                LOG.debug("Unknown dimension " + dimString + " for DnnOp:" + op.name()
                    + " img_dim=[" + p.N + " " + p.C + " " + p.H + " " + p.W + "]"
                    + " filter_dim=[" + p.K + " " + p.C + " " + p.R + " " + p.S + "]"
                    + " output_feature_map=[" + p.P + " " + p.Q + "]"
                    + " stride=[" + p.stride_h + " " + p.stride_w + "]"
                    + " pad=[" + p.pad_h + " " + p.pad_w + "]");
            }
            return dim;
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reconciles a dimension computed from parameters with the row (or column) count of
     * {@code hop}. If {@code hop} is null, the parameter value is returned unchanged.
     */
    private static long reconcileWithHop(long fromParams, Hop hop, boolean useRows) {
        if (hop == null) {
            return fromParams;
        }
        return getNonNegative(fromParams, useRows ? hop.getDim1() : hop.getDim2());
    }

    /** Product of three values, or -1 if any of them is negative (unknown). */
    private static long nonNegativeMultiply(long j, long j2, long j3) {
        if (j < 0 || j2 < 0 || j3 < 0) {
            return -1L;
        }
        return j * j2 * j3;
    }

    /** Product of two values, or -1 if any of them is negative (unknown). */
    private static long nonNegativeMultiply(long j, long j2) {
        if (j < 0 || j2 < 0) {
            return -1L;
        }
        return j * j2;
    }

    /**
     * Returns whichever value is known (non-negative), or -1 if neither is.
     *
     * @throws RuntimeException if both values are known but differ
     */
    private static long getNonNegative(long j, long j2) {
        if (j >= 0 && j2 >= 0) {
            if (j == j2) {
                return j;
            }
            throw new RuntimeException("Incorrect dimensions in DnnOp: " + j + " != " + j2);
        }
        if (j >= 0) {
            return j;
        }
        if (j2 >= 0) {
            return j2;
        }
        return -1L;
    }
}
