package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/AggUnaryOp.class */
/**
 * Hop for a unary aggregate: a single input reduced by an aggregation function
 * ({@link Types.AggOp}) over all cells ({@code RowCol}), per row ({@code Row}) or per
 * column ({@code Col}).
 *
 * <p>Lop construction uses a generic {@link PartialAggregate} by default and applies two
 * fused rewrites when applicable: a {@link TernaryAggregate} for sums over element-wise
 * products/powers (e.g. {@code sum(A*B*C)}, {@code sum(A^3)}), and a {@link UAggOuterChain}
 * for SUM/MAXINDEX/MININDEX aggregates over outer comparison operations.
 */
public class AggUnaryOp extends MultiThreadedHop {
    // NOTE(review): private and not referenced anywhere in this class.
    private static final boolean ALLOW_UNARYAGG_WO_FINAL_AGG = true;

    /** Aggregation function, e.g. SUM, MIN or MAXINDEX. */
    private Types.AggOp _op;

    /** Aggregation direction: RowCol (all cells), Row (per row) or Col (per column). */
    private Types.Direction _direction;

    /** Creates an empty instance whose state is populated by {@link #clone()}. */
    private AggUnaryOp() {
    }

    /**
     * Creates a unary aggregate over {@code hop} and registers this hop as a parent of it.
     *
     * @param str       hop name
     * @param dataType  output data type
     * @param valueType output value type
     * @param aggOp     aggregation function
     * @param direction aggregation direction
     * @param hop       the single input hop
     */
    public AggUnaryOp(String str, Types.DataType dataType, Types.ValueType valueType, Types.AggOp aggOp,
            Types.Direction direction, Hop hop) {
        super(str, dataType, valueType);
        _op = aggOp;
        _direction = direction;
        getInput().add(0, hop);
        hop.getParent().add(this);
    }

    /** Validates that this hop has exactly one input. */
    @Override
    public void checkArity() {
        HopsException.check(_input.size() == 1, this, "should have arity 1 but has arity %d",
            Integer.valueOf(_input.size()));
    }

    /** @return the aggregation function */
    public Types.AggOp getOp() {
        return _op;
    }

    /** @param aggOp the new aggregation function */
    public void setOp(Types.AggOp aggOp) {
        _op = aggOp;
    }

    /** @return the aggregation direction */
    public Types.Direction getDirection() {
        return _direction;
    }

    /** @param direction the new aggregation direction */
    public void setDirection(Types.Direction direction) {
        _direction = direction;
    }

    /**
     * Reports whether this aggregate can run on the GPU backend: only when accelerators are
     * enabled, none of the fused rewrites applies, and the (op, direction) pair is supported.
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        try {
            // the fused ternary / outer-chain rewrites are not GPU-enabled
            if (isTernaryAggregateRewriteApplicable() || isUnaryAggregateOuterCPRewriteApplicable()) {
                return false;
            }
            boolean anyDirection = _direction == Types.Direction.RowCol
                || _direction == Types.Direction.Row
                || _direction == Types.Direction.Col;
            if (_op == Types.AggOp.SUM || _op == Types.AggOp.SUM_SQ || _op == Types.AggOp.MAX
                || _op == Types.AggOp.MIN || _op == Types.AggOp.MEAN || _op == Types.AggOp.VAR) {
                return anyDirection;
            }
            // product is only supported as a full (all-cell) aggregate
            return _op == Types.AggOp.PROD && _direction == Types.Direction.RowCol;
        } catch (HopsException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Constructs (once) the lop DAG for this aggregate, choosing between the ternary-aggregate
     * rewrite, the outer-chain rewrite and a generic partial aggregate for the selected
     * execution type.
     *
     * @return the root lop of this hop
     * @throws HopsException wrapping any failure during lop construction
     */
    @Override
    public Lop constructLops() {
        // return already created lops
        if (getLops() != null) {
            return getLops();
        }
        try {
            Types.ExecType et = optFindExecType();
            Hop input = getInput().get(0);
            if (et == Types.ExecType.CP || et == Types.ExecType.GPU || et == Types.ExecType.FED) {
                Lop agg;
                if (isTernaryAggregateRewriteApplicable()) {
                    agg = constructLopsTernaryAggregateRewrite(et);
                } else if (isUnaryAggregateOuterCPRewriteApplicable()) {
                    agg = constructLopsUAggOuterChain(Types.ExecType.CP);
                    if (getDataType() == Types.DataType.SCALAR) {
                        agg = constructCastAsScalar(agg);
                    }
                } else {
                    int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                    agg = new PartialAggregate(input.constructLops(), _op, _direction,
                        getDataType(), getValueType(), et, k);
                }
                setOutputDimensions(agg);
                setLineNumbers(agg);
                setLops(agg);
                if (getDataType() == Types.DataType.SCALAR) {
                    agg.getOutputParameters().setDimensions(1L, 1L, getBlocksize(), getNnz());
                }
            } else if (et == Types.ExecType.SPARK) {
                if (isTernaryAggregateRewriteApplicable()) {
                    Lop agg = constructLopsTernaryAggregateRewrite(et);
                    setOutputDimensions(agg);
                    setLineNumbers(agg);
                    setLops(agg);
                } else if (isUnaryAggregateOuterSPRewriteApplicable()) {
                    Lop chain = constructLopsUAggOuterChain(Types.ExecType.SPARK);
                    setLineNumbers(chain);
                    setLops(chain);
                    if (getDataType() == Types.DataType.SCALAR) {
                        setLops(constructCastAsScalar(chain));
                    }
                } else {
                    PartialAggregate agg = new PartialAggregate(input.constructLops(), _op, _direction,
                        input.getDataType(), getValueType(),
                        getSparkUnaryAggregationType(requiresAggregation(input, _direction)), et);
                    agg.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getBlocksize());
                    setLineNumbers(agg);
                    setLops(agg);
                    if (getDataType() == Types.DataType.SCALAR) {
                        setLops(constructCastAsScalar(agg));
                    }
                }
            } else {
                throw new HopsException("ExecType " + et + " not recognized in " + toString());
            }
            constructAndSetLopsDataFlowProperties();
            return getLops();
        } catch (Exception e) {
            throw new HopsException(printErrorLocation() + "In AggUnary Hop, error constructing Lops ", e);
        }
    }

    /**
     * Builds a {@link UAggOuterChain} lop that fuses this aggregate with its outer binary input
     * and sets its dimensions according to the aggregation direction.
     */
    private Lop constructLopsUAggOuterChain(Types.ExecType execType) {
        BinaryOp outer = (BinaryOp) getInput().get(0);
        Lop chain = new UAggOuterChain(outer.getInput().get(0).constructLops(),
            outer.getInput().get(1).constructLops(), _op, _direction, outer.getOp(),
            Types.DataType.MATRIX, getValueType(), execType);
        PartialAggregate.setDimensionsBasedOnDirection(chain, getDim1(), getDim2(),
            outer.getBlocksize(), _direction);
        return chain;
    }

    /** Wraps {@code lop} into a CAST_AS_SCALAR lop producing this hop's output type. */
    private Lop constructCastAsScalar(Lop lop) {
        Lop cast = new UnaryCP(lop, Types.OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
        cast.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
        setLineNumbers(cast);
        return cast;
    }

    /** @return a display string of the form {@code ua(<op><direction>)} */
    @Override
    public String getOpString() {
        return "ua(" + _op.toString() + _direction.toString() + ")";
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** Estimates the output size; under GPU execution the output is assumed dense. */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = isGPUEnabled() ? 1.0d : OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Estimates worst-case intermediate memory: correction rows/columns for SUM, SUM_SQ
     * (2), MEAN (3) and VAR (5), per-column counters for MIN/MAX, and extra buffers for the
     * GPU variance kernel and the index aggregates.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        double val = 0.0d;
        switch (_op) {
            case MAX:
            case MIN:
                // 4 bytes per output column for column-wise aggregation
                if (_direction == Types.Direction.Col) {
                    val = dim2 * 4;
                }
                break;
            case SUM:
            case SUM_SQ:
                if (_direction == Types.Direction.Col) {
                    val = OptimizerUtils.estimateSizeExactSparsity(2L, dim2, sparsity);
                } else if (_direction == Types.Direction.Row) {
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2L, 1.0d);
                }
                break;
            case MEAN:
                if (_direction == Types.Direction.Col) {
                    val = OptimizerUtils.estimateSizeExactSparsity(3L, dim2, sparsity);
                } else if (_direction == Types.Direction.Row) {
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 3L, 1.0d);
                }
                break;
            case VAR: {
                if (isGPUEnabled()) {
                    // two dense temporaries of the full input size
                    Hop input = getInput().get(0);
                    val = 2 * OptimizerUtils.estimateSize(input.getDim1(), input.getDim2());
                    // plus a temporary vector as long as the aggregation result.
                    // NOTE(review): the decompiled original referenced an unresolved local
                    // ("r0") here; reconstructed from the input dims -- confirm upstream.
                    if (_direction == Types.Direction.Col) {
                        val += OptimizerUtils.estimateSize(input.getDim2(), 1L);
                    } else if (_direction == Types.Direction.Row) {
                        val += OptimizerUtils.estimateSize(1L, input.getDim1());
                    }
                } else if (_direction == Types.Direction.Col) {
                    val = OptimizerUtils.estimateSizeExactSparsity(5L, dim2, sparsity);
                } else if (_direction == Types.Direction.Row) {
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 5L, 1.0d);
                }
                break;
            }
            case MAXINDEX:
            case MININDEX: {
                Hop input = getInput().get(0);
                val = isUnaryAggregateOuterCPRewriteApplicable()
                    ? 3 * OptimizerUtils.estimateSizeExactSparsity(1L, input.getDim2(), 1.0d)
                    : OptimizerUtils.estimateSizeExactSparsity(dim1, 2L, 1.0d);
                break;
            }
            default:
                val = 0.0d;
                break;
        }
        return val;
    }

    /**
     * Infers output characteristics from the memoized input stats: {@code 1 x cols} for
     * column aggregates and {@code rows x 1} for row aggregates (blocksize and nnz unknown).
     *
     * @return the inferred characteristics, or {@code null} if the relevant input dimension
     *         is unknown or the direction is RowCol
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memoTable) {
        DataCharacteristics inStats = memoTable.getAllInputStats(getInput().get(0));
        if (_direction == Types.Direction.Col && inStats.colsKnown()) {
            return new MatrixCharacteristics(1L, inStats.getCols(), -1, -1L);
        }
        if (_direction == Types.Direction.Row && inStats.rowsKnown()) {
            // unknown blocksize/nnz, consistent with the column case (was blocksize -2)
            return new MatrixCharacteristics(inStats.getRows(), 1L, -1, -1L);
        }
        return null;
    }

    /**
     * Selects the execution type (forced, memory-based or threshold-based), then optionally
     * pulls a CP aggregate into Spark when its input is produced in Spark.
     *
     * @param transitive whether the Spark pull-up refinement may be applied
     * @return the selected execution type
     */
    @Override
    public Types.ExecType optFindExecType(boolean transitive) {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector()) {
                _etype = Types.ExecType.CP;
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        if (transitive && _etype == Types.ExecType.CP && _etypeForced != Types.ExecType.CP
            && isSparkPullUpApplicable()) {
            _etype = Types.ExecType.SPARK;
        }
        updateETFed();
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /**
     * Checks whether the input lives in Spark (RDD-only data op, or a non-data hop planned for
     * Spark) and either this is its only consumer, all other consumers run in Spark, or the
     * aggregate needs no final aggregation.
     */
    private boolean isSparkPullUpApplicable() {
        Hop input = getInput(0);
        boolean sparkInput = (input instanceof DataOp)
            ? ((DataOp) input).hasOnlyRDD()
            : input.optFindExecType() == Types.ExecType.SPARK;
        if (!sparkInput) {
            return false;
        }
        return input.getParent().size() == 1
            || input.getParent().stream()
                .filter(parent -> parent != this)
                .allMatch(parent -> parent.optFindExecType(false) == Types.ExecType.SPARK)
            || !requiresAggregation(input, _direction);
    }

    /**
     * Returns false when a column aggregate's input has more than one row but fits into a
     * single row block, or a row aggregate's input has more than one column but fits into a
     * single column block; true otherwise.
     */
    private static boolean requiresAggregation(Hop input, Types.Direction direction) {
        long blen = input.getBlocksize();
        boolean colAggInSingleBlock = direction == Types.Direction.Col
            && input.getDim1() > 1L && input.getDim1() <= blen;
        boolean rowAggInSingleBlock = direction == Types.Direction.Row
            && input.getDim2() > 1L && input.getDim2() <= blen;
        return !(colAggInSingleBlock || rowAggInSingleBlock);
    }

    /**
     * Chooses the Spark aggregation type: NONE if no aggregation is needed, SINGLE_BLOCK for
     * scalar outputs or known outputs within one block, MULTI_BLOCK otherwise.
     */
    private AggBinaryOp.SparkAggType getSparkUnaryAggregationType(boolean needsAggregation) {
        if (!needsAggregation) {
            return AggBinaryOp.SparkAggType.NONE;
        }
        boolean singleBlock = getDataType() == Types.DataType.SCALAR
            || (dimsKnown() && getDim1() <= getBlocksize() && getDim2() <= getBlocksize());
        return singleBlock ? AggBinaryOp.SparkAggType.SINGLE_BLOCK : AggBinaryOp.SparkAggType.MULTI_BLOCK;
    }

    /**
     * Checks whether a full or column SUM over a single-consumer binary input can be fused
     * into a ternary aggregate: {@code X^3}, {@code (A*B)*C} / {@code A*(B*C)} with all
     * operands of equal size, or {@code A*B} with equal-size operands. Disabled when
     * accelerators are in use.
     */
    private boolean isTernaryAggregateRewriteApplicable() {
        if (DMLScript.USE_ACCELERATOR) {
            return false;
        }
        if (!OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES || _op != Types.AggOp.SUM
            || (_direction != Types.Direction.RowCol && _direction != Types.Direction.Col)) {
            return false;
        }
        Hop input = getInput().get(0);
        if (input.getParent().size() != 1 || !(input instanceof BinaryOp)) {
            return false;
        }
        BinaryOp binary = (BinaryOp) input;
        if (binary.getOp() == Types.OpOp2.POW && (binary.getInput().get(1) instanceof LiteralOp)) {
            return HopRewriteUtils.getIntValueSafe((LiteralOp) binary.getInput().get(1)) == 3;
        }
        if (binary.getOp() != Types.OpOp2.MULT) {
            return false;
        }
        Hop left = input.getInput().get(0);
        Hop right = input.getInput().get(1);
        if (isBinaryMult(left)) {
            return HopRewriteUtils.isEqualSize(left.getInput().get(0), input)
                && HopRewriteUtils.isEqualSize(left.getInput().get(1), input)
                && HopRewriteUtils.isEqualSize(right, input);
        }
        if (isBinaryMult(right)) {
            return HopRewriteUtils.isEqualSize(right.getInput().get(0), input)
                && HopRewriteUtils.isEqualSize(right.getInput().get(1), input)
                && HopRewriteUtils.isEqualSize(left, input);
        }
        return HopRewriteUtils.isEqualSize(left, right);
    }

    /** True if {@code hop} is a binary element-wise multiplication. */
    private static boolean isBinaryMult(Hop hop) {
        return hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == Types.OpOp2.MULT;
    }

    /** True if {@code hop} is a binary multiplication or power. */
    private static boolean isBinaryMultOrPow(Hop hop) {
        return hop instanceof BinaryOp
            && (((BinaryOp) hop).getOp() == Types.OpOp2.MULT || ((BinaryOp) hop).getOp() == Types.OpOp2.POW);
    }

    /** True for the six comparison operators (<, <=, >, >=, ==, !=). */
    private static boolean isCompareOperator(Types.OpOp2 opOp2) {
        return opOp2 == Types.OpOp2.LESS
            || opOp2 == Types.OpOp2.LESSEQUAL
            || opOp2 == Types.OpOp2.GREATER
            || opOp2 == Types.OpOp2.GREATEREQUAL
            || opOp2 == Types.OpOp2.EQUAL
            || opOp2 == Types.OpOp2.NOTEQUAL;
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Checks whether the Spark outer-chain rewrite applies: the input must be an outer binary
     * operation whose right-hand side fits the broadcast budget (twice into both the broadcast
     * and the local budget for MAXINDEX/MININDEX).
     */
    private boolean isUnaryAggregateOuterSPRewriteApplicable() {
        Hop input = getInput().get(0);
        if (!(input instanceof BinaryOp) || !((BinaryOp) input).isOuter()) {
            return false;
        }
        Hop right = input.getInput().get(1);
        double size = right.dimsKnown()
            ? OptimizerUtils.estimateSize(right.getDim1(), right.getDim2())
            : right.getOutputMemEstimate();
        if (_op == Types.AggOp.MAXINDEX || _op == Types.AggOp.MININDEX) {
            return 2.0d * size < SparkExecutionContext.getBroadcastMemoryBudget()
                && 2.0d * size < OptimizerUtils.getLocalMemBudget();
        }
        return OptimizerUtils.checkSparkBroadcastMemoryBudget(size);
    }

    /**
     * Checks whether the CP outer-chain rewrite applies: MAXINDEX, MININDEX or SUM over an
     * outer comparison operation.
     */
    private boolean isUnaryAggregateOuterCPRewriteApplicable() {
        Hop input = getInput().get(0);
        return input instanceof BinaryOp
            && ((BinaryOp) input).isOuter()
            && (_op == Types.AggOp.MAXINDEX || _op == Types.AggOp.MININDEX || _op == Types.AggOp.SUM)
            && isCompareOperator(((BinaryOp) input).getOp());
    }

    /**
     * Builds the fused {@link TernaryAggregate} (sum of a three-way product) for the binary
     * input. Recognized shapes: {@code A^3}, {@code (A*B)*C}, {@code (A^2)*B},
     * {@code A*(B*C)}, {@code A*(B^2)}; any other product {@code A*B} becomes {@code A*B*1}.
     *
     * @param execType ignored; the execution type follows the product input, with GPU
     *                 mapped to CP
     */
    private Lop constructLopsTernaryAggregateRewrite(Types.ExecType execType) {
        BinaryOp product = (BinaryOp) getInput().get(0);
        Hop left = product.getInput().get(0);
        Hop right = product.getInput().get(1);
        Lop in1 = null;
        Lop in2 = null;
        Lop in3 = null;
        boolean handled = false;

        if (product.getOp() == Types.OpOp2.POW) {
            // A^3 -> A*A*A
            assert HopRewriteUtils.isLiteralOfValue(right, 3.0d) : "this case can only occur with a power of 3";
            in1 = left.constructLops();
            in2 = in1;
            in3 = in1;
            handled = true;
        } else if (isBinaryMultOrPow(left)) {
            BinaryOp leftOp = (BinaryOp) left;
            if (leftOp.getOp() == Types.OpOp2.MULT) {
                // (A*B)*C
                in1 = left.getInput().get(0).constructLops();
                in2 = left.getInput().get(1).constructLops();
                in3 = right.constructLops();
                handled = true;
            } else if (!isBinaryMult(right) && HopRewriteUtils.isLiteralOfValue(leftOp.getInput().get(1), 2.0d)) {
                // (A^2)*B
                in1 = leftOp.getInput().get(0).constructLops();
                in2 = in1;
                in3 = right.constructLops();
                handled = true;
            }
        } else if (isBinaryMultOrPow(right)) {
            BinaryOp rightOp = (BinaryOp) right;
            if (rightOp.getOp() == Types.OpOp2.MULT) {
                // A*(B*C)
                in1 = left.constructLops();
                in2 = right.getInput().get(0).constructLops();
                in3 = right.getInput().get(1).constructLops();
                handled = true;
            } else if (HopRewriteUtils.isLiteralOfValue(rightOp.getInput().get(1), 2.0d)) {
                // A*(B^2)
                in1 = rightOp.getInput().get(0).constructLops();
                in2 = in1;
                in3 = left.constructLops();
                handled = true;
            }
        }
        if (!handled) {
            // plain product A*B, expressed as A*B*1
            in1 = left.constructLops();
            in2 = right.constructLops();
            in3 = new LiteralOp(1L).constructLops();
        }

        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        Types.ExecType inputExecType = product.optFindExecType();
        Types.ExecType et = inputExecType == Types.ExecType.GPU ? Types.ExecType.CP : inputExecType;
        return new TernaryAggregate(in1, in2, in3, Types.AggOp.SUM, Types.OpOp2.MULT, _direction,
            getDataType(), Types.ValueType.FP64, et, k);
    }

    /** Updates output dims for non-scalar outputs: {@code 1 x cols} (Col) or {@code rows x 1} (Row). */
    @Override
    public void refreshSizeInformation() {
        if (getDataType() == Types.DataType.SCALAR) {
            return;
        }
        Hop input = getInput().get(0);
        if (_direction == Types.Direction.Col) {
            setDim1(1L);
            setDim2(input.getDim2());
        } else if (_direction == Types.Direction.Row) {
            setDim1(input.getDim1());
            setDim2(1L);
        }
    }

    /** Full aggregates of SUM, SUM_SQ, MIN, MAX, PROD, MEAN and VAR are invariant to transposition. */
    @Override
    public boolean isTransposeSafe() {
        return _direction == Types.Direction.RowCol
            && (_op == Types.AggOp.SUM || _op == Types.AggOp.SUM_SQ
                || _op == Types.AggOp.MIN || _op == Types.AggOp.MAX
                || _op == Types.AggOp.PROD || _op == Types.AggOp.MEAN
                || _op == Types.AggOp.VAR);
    }

    /** Copies generic hop attributes plus op, direction and thread count. */
    @Override
    public Object clone() throws CloneNotSupportedException {
        AggUnaryOp copy = new AggUnaryOp();
        copy.clone(this, false);
        copy._op = _op;
        copy._direction = _direction;
        copy._maxNumThreads = _maxNumThreads;
        return copy;
    }

    /** Equal if op, direction and thread count match and both consume the identical input hop. */
    @Override
    public boolean compare(Hop hop) {
        if (!(hop instanceof AggUnaryOp)) {
            return false;
        }
        AggUnaryOp other = (AggUnaryOp) hop;
        return _op == other._op
            && _direction == other._direction
            && _maxNumThreads == other._maxNumThreads
            && getInput().get(0) == other.getInput().get(0);
    }
}
