package org.apache.sysds.hops;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.lops.CumulativeOffsetBinary;
import org.apache.sysds.lops.CumulativePartialAggregate;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.DeCompression;
import org.apache.sysds.lops.Local;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.lops.Unary;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/UnaryOp.class */
/**
 * High-level operator (hop) for unary operations ({@link Types.OpOp1}) over exactly one input.
 * Covers element-wise functions, cumulative aggregates, casts, metadata queries, and the order
 * statistics median and IQM.
 */
public class UnaryOp extends MultiThreadedHop {
    /** Allows broadcasting the cumulative offsets in Spark cumulative aggregates (memory permitting). */
    private static final boolean ALLOW_CUMAGG_BROADCAST = true;
    /** Flag for caching cumulative aggregates; not referenced anywhere in this class. */
    private static final boolean ALLOW_CUMAGG_CACHING = false;

    /** The unary operation computed by this hop. */
    private Types.OpOp1 _op;

    /** Creates an empty instance; only used by {@link #clone()}, which populates all fields. */
    private UnaryOp() {
    }

    /**
     * Creates a unary hop and links it bidirectionally with its single input.
     *
     * @param name  hop name
     * @param dt    output data type
     * @param vt    output value type
     * @param op    unary operation
     * @param input the single input hop
     */
    public UnaryOp(String name, Types.DataType dt, Types.ValueType vt, Types.OpOp1 op, Hop input) {
        super(name, dt, vt);
        getInput().add(input);
        input.getParent().add(this);
        // refreshSizeInformation() dispatches on _op, so it must be assigned first
        _op = op;
        refreshSizeInformation();
    }

    /** Verifies that this hop has exactly one input. */
    @Override // org.apache.sysds.hops.Hop
    public void checkArity() {
        HopsException.check(_input.size() == 1, this, "should have arity 1 but has arity %d", Integer.valueOf(_input.size()));
    }

    /** @return the unary operation of this hop */
    public Types.OpOp1 getOp() {
        return _op;
    }

    /** @return the display string {@code u(<op>)} with the operation name in lower case */
    @Override // org.apache.sysds.hops.Hop
    public String getOpString() {
        // Locale.ROOT keeps the result independent of the JVM default locale (e.g. Turkish dotless i)
        return "u(" + _op.toString().toLowerCase(Locale.ROOT) + ")";
    }

    /**
     * Returns whether this operation has a GPU implementation. Requires the accelerator to be
     * enabled and a non-scalar result that is not a cast of a scalar.
     */
    @Override // org.apache.sysds.hops.Hop
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR || isScalarOrScalarCastOperation()) {
            return false;
        }
        switch (_op) {
            case EXP:
            case SQRT:
            case LOG:
            case ABS:
            case ROUND:
            case FLOOR:
            case CEIL:
            case SIN:
            case COS:
            case TAN:
            case ASIN:
            case ACOS:
            case ATAN:
            case SINH:
            case COSH:
            case TANH:
            case SIGN:
            case SIGMOID:
            case CUMSUM:
            case CUMPROD:
            case CUMMIN:
            case CUMMAX:
            case CUMSUMPROD:
                return true;
            default:
                return false;
        }
    }

    /** @return true for cumulative and expensive element-wise operations, which may use multiple threads */
    @Override // org.apache.sysds.hops.MultiThreadedHop
    public boolean isMultiThreadedOpType() {
        return isCumulativeUnaryOperation() || isExpensiveUnaryOperation();
    }

    /**
     * Constructs (or returns the already constructed) low-level operator for this hop.
     *
     * @return the root lop of this hop
     * @throws HopsException if lop construction fails
     */
    @Override // org.apache.sysds.hops.Hop
    public Lop constructLops() {
        // reuse previously constructed lops
        if (getLops() != null) {
            return getLops();
        }
        try {
            Hop input = getInput().get(0);
            Lop ret;
            switch (_op) {
                case COMPRESS:
                    ret = new Compression(input.constructLops(), getDataType(), getValueType(), optFindExecType(), 0);
                    break;
                case DECOMPRESS:
                    ret = new DeCompression(input.constructLops(), getDataType(), getValueType(), optFindExecType());
                    break;
                case LOCAL:
                    ret = new Local(input.constructLops(), getDataType(), getValueType());
                    break;
                default:
                    ret = isScalarOrScalarCastOperation()
                        ? constructLopsScalarUnary(input)
                        : constructLopsMatrixUnary(input);
                    break;
            }
            setOutputDimensions(ret);
            setLineNumbers(ret);
            setLops(ret);
            constructAndSetLopsDataFlowProperties();
            return getLops();
        } catch (Exception e) {
            throw new HopsException(printErrorLocation() + "error constructing Lops for UnaryOp Hop -- \n ", e);
        }
    }

    /** Returns whether the output is a scalar or the operation casts a scalar input to a matrix/frame. */
    private boolean isScalarOrScalarCastOperation() {
        return getDataType() == Types.DataType.SCALAR
            || ((_op == Types.OpOp1.CAST_AS_MATRIX || _op == Types.OpOp1.CAST_AS_FRAME)
                && getInput().get(0).getDataType() == Types.DataType.SCALAR);
    }

    /** Builds the lop for scalar outputs and scalar casts: IQM, median, or a {@link UnaryCP}. */
    private Lop constructLopsScalarUnary(Hop input) {
        if (_op == Types.OpOp1.IQM) {
            return constructLopsIQM();
        }
        if (_op == Types.OpOp1.MEDIAN) {
            return constructLopsMedian();
        }
        return new UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
    }

    /** Builds the lop for non-scalar outputs: a {@link Unary} lop or a Spark cumulative-aggregate plan. */
    private Lop constructLopsMatrixUnary(Hop input) {
        Types.ExecType et = optFindExecType();
        // cumulative aggregates outside CP/GPU need the multi-level Spark plan
        if (isCumulativeUnaryOperation() && et != Types.ExecType.CP && et != Types.ExecType.GPU) {
            return constructLopsSparkCumulativeUnary();
        }
        Lop in = input.constructLops();
        int k = (isCumulativeUnaryOperation() || isExpensiveUnaryOperation())
            ? OptimizerUtils.getConstrainedNumThreads(_maxNumThreads) : 1;
        // in-place update only if this hop is the input's sole consumer and the input is not a read DataOp
        boolean inplace = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE
            && input.getParent().size() == 1
            && !((input instanceof DataOp) && ((DataOp) input).isRead());
        return new Unary(in, _op, getDataType(), getValueType(), et, k, inplace);
    }

    /** Builds median as a sort-by-value lop followed by a pick-by-count at quantile 0.5. */
    private Lop constructLopsMedian() {
        Types.ExecType et = optFindExecType();
        Hop input = getInput().get(0);
        SortKeys sort = SortKeys.constructSortByValueLop(input.constructLops(), SortKeys.OperationTypes.WithoutWeights,
            Types.DataType.MATRIX, Types.ValueType.FP64, et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getBlocksize(), input.getNnz());
        PickByCount pick = new PickByCount(sort, Data.createLiteralLop(Types.ValueType.FP64, Double.toString(0.5d)),
            getDataType(), getValueType(), PickByCount.OperationTypes.MEDIAN, et, true);
        pick.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz());
        setLineNumbers(pick);
        // the caller (constructLops) registers the returned lop via setLops
        return pick;
    }

    /** Builds the interquartile mean as a sort-by-value lop followed by an IQM pick-by-count. */
    private Lop constructLopsIQM() {
        Types.ExecType et = optFindExecType();
        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        Hop input = getInput().get(0);
        SortKeys sort = SortKeys.constructSortByValueLop(input.constructLops(), SortKeys.OperationTypes.WithoutWeights,
            Types.DataType.MATRIX, Types.ValueType.FP64, et, k);
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getBlocksize(), input.getNnz());
        PickByCount pick = new PickByCount(sort, null, getDataType(), getValueType(),
            PickByCount.OperationTypes.IQM, et, true);
        pick.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz());
        setLineNumbers(pick);
        return pick;
    }

    /**
     * Builds the Spark plan for cumulative aggregates.
     * <ul>
     *   <li>Partial aggregates are computed per row block, recursively, until they fit the local
     *       budget. At least one level is forced if dims are unknown or Spark is forced.</li>
     *   <li>The remaining aggregates are cumulatively aggregated in CP.</li>
     *   <li>Offsets are then applied level by level in reverse order.</li>
     * </ul>
     */
    private Lop constructLopsSparkCumulativeUnary() {
        Hop input = getInput().get(0);
        long rlen = input.getDim1();
        long clen = input.getDim2();
        long blen = input.getBlocksize();
        boolean force = !dimsKnown() || _etypeForced == Types.ExecType.SPARK;
        Types.AggOp aggtype = getCumulativeAggType();
        Lop x = input.constructLops();

        // single row block: apply a 1 x clen offset vector filled with the init value
        if (rlen > 0 && clen > 0 && rlen <= blen) {
            Lop offsets = HopRewriteUtils.createDataGenOpByVal(new LiteralOp(1L), new LiteralOp(clen), null,
                Types.DataType.MATRIX, Types.ValueType.FP64, getCumulativeInitValue()).constructLops();
            return constructCumOffBinary(x, offsets, aggtype, rlen, clen, blen);
        }

        // recursive pre-aggregation until the aggregates fit into the local memory budget
        List<Lop> levels = new ArrayList<>();
        Lop temp = x;
        while (((2 * OptimizerUtils.estimateSize(temp.getOutputParameters().getNumRows(), clen)
                + OptimizerUtils.estimateSize(1L, clen)) > OptimizerUtils.getLocalMemBudget()
                && temp.getOutputParameters().getNumRows() > 1) || force) {
            levels.add(temp);
            // one aggregate row per row block; the double division makes the ceil effective
            long rlenAgg = (long) Math.ceil((double) temp.getOutputParameters().getNumRows() / blen);
            CumulativePartialAggregate preagg = new CumulativePartialAggregate(temp, Types.DataType.MATRIX,
                Types.ValueType.FP64, aggtype, Types.ExecType.SPARK);
            preagg.getOutputParameters().setDimensions(rlenAgg, clen, blen, -1L);
            setLineNumbers(preagg);
            temp = preagg;
            force = false; // at most one forced level
        }

        // in-memory cumulative aggregate of the final partial aggregates
        if (temp.getOutputParameters().getNumRows() != 1) {
            Unary unary = new Unary(temp, _op, Types.DataType.MATRIX, Types.ValueType.FP64, Types.ExecType.CP,
                OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), true);
            unary.getOutputParameters().setDimensions(temp.getOutputParameters().getNumRows(), clen, blen, -1L);
            setLineNumbers(unary);
            temp = unary;
        }

        // apply offsets from the deepest level back to the original input
        // NOTE(review): every level is annotated with the input's rlen, not the level's own row count — confirm intended
        for (int level = levels.size() - 1; level >= 0; level--) {
            temp = constructCumOffBinary(levels.get(level), temp, aggtype, rlen, clen, blen);
        }
        return temp;
    }

    /**
     * Creates a Spark {@link CumulativeOffsetBinary} that applies {@code offsets} to {@code data}.
     * The offsets are broadcast if allowed and their estimated size fits the Spark broadcast budget.
     */
    private Lop constructCumOffBinary(Lop data, Lop offsets, Types.AggOp aggtype, long rlen, long clen, long blen) {
        double initValue = getCumulativeInitValue();
        boolean broadcast = ALLOW_CUMAGG_BROADCAST && OptimizerUtils.checkSparkBroadcastMemoryBudget(
            OptimizerUtils.estimateSize(offsets.getOutputParameters().getNumRows(), offsets.getOutputParameters().getNumCols()));
        CumulativeOffsetBinary binary = new CumulativeOffsetBinary(data, offsets, Types.DataType.MATRIX,
            Types.ValueType.FP64, initValue, broadcast, aggtype, Types.ExecType.SPARK);
        binary.getOutputParameters().setDimensions(rlen, clen, blen, -1L);
        setLineNumbers(binary);
        return binary;
    }

    /** @return the aggregate type backing a cumulative operation, or null for other operations */
    private Types.AggOp getCumulativeAggType() {
        switch (_op) {
            case CUMSUM:
                return Types.AggOp.SUM;
            case CUMPROD:
                return Types.AggOp.PROD;
            case CUMMIN:
                return Types.AggOp.MIN;
            case CUMMAX:
                return Types.AggOp.MAX;
            case CUMSUMPROD:
                return Types.AggOp.SUM_PROD;
            default:
                return null;
        }
    }

    /** @return the initial (offset) value of a cumulative operation, or NaN for other operations */
    private double getCumulativeInitValue() {
        switch (_op) {
            case CUMSUM:
            case CUMSUMPROD:
                return 0.0d;
            case CUMPROD:
                return 1.0d;
            case CUMMIN:
                return Double.POSITIVE_INFINITY;
            case CUMMAX:
                return Double.NEGATIVE_INFINITY;
            default:
                return Double.NaN;
        }
    }

    /** Computes the memory estimate; metadata operations are charged a fixed 4 bytes. */
    @Override // org.apache.sysds.hops.Hop
    public void computeMemEstimate(MemoTable memoTable) {
        super.computeMemEstimate(memoTable);
        if (isMetadataOperation()) {
            _memEstimate = 4.0d;
        }
    }

    /** Output size estimate; assumes dense output when running on GPU. */
    @Override // org.apache.sysds.hops.Hop
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = isGPUEnabled() ? 1.0d : OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Intermediate memory estimate.
     * <ul>
     *   <li>IQM and median: three times the input estimate.</li>
     *   <li>Cumulative operations: a worst-case sparse block just below the sparsity turn point.</li>
     *   <li>GPU: an additional dense buffer.</li>
     * </ul>
     */
    @Override // org.apache.sysds.hops.Hop
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double ret = 0;
        if (_op == Types.OpOp1.IQM || _op == Types.OpOp1.MEDIAN) {
            ret = getInput().get(0).getMemEstimate() * 3.0d;
        } else if (isCumulativeUnaryOperation()) {
            ret += MatrixBlock.estimateSizeSparseInMemory(dim1, dim2,
                MatrixBlock.SPARSITY_TURN_POINT - UtilFunctions.DOUBLE_EPS);
        }
        if (isGPUEnabled()) {
            ret += OptimizerUtils.estimateSize(dim1, dim2);
        }
        return ret;
    }

    /**
     * Infers output characteristics from memoized input statistics.
     *
     * @return the inferred characteristics, or null if the input dimensions are unknown
     */
    @Override // org.apache.sysds.hops.Hop
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memoTable) {
        DataCharacteristics dc = memoTable.getAllInputStats(getInput().get(0));
        if (!dc.dimsKnown()) {
            return null;
        }
        if (isSparsityPreservingOperation()) {
            return new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
        }
        if (_op == Types.OpOp1.CUMSUMPROD) {
            return new MatrixCharacteristics(dc.getRows(), 1L, -1, -1L);
        }
        return new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, -1L);
    }

    /**
     * Returns whether the input's non-zero count is propagated to the output. This holds for
     * zero-preserving operations (f(0) == 0) and value-preserving ones (compress, decompress, local).
     * COS, ACOS and COSH are deliberately excluded: cos(0) = 1, acos(0) = pi/2, cosh(0) = 1.
     */
    private boolean isSparsityPreservingOperation() {
        return _op == Types.OpOp1.ABS || _op == Types.OpOp1.SIN || _op == Types.OpOp1.TAN
            || _op == Types.OpOp1.SINH || _op == Types.OpOp1.TANH || _op == Types.OpOp1.ASIN
            || _op == Types.OpOp1.ATAN || _op == Types.OpOp1.SQRT || _op == Types.OpOp1.ROUND
            || _op == Types.OpOp1.SPROP || _op == Types.OpOp1.COMPRESS || _op == Types.OpOp1.DECOMPRESS
            || _op == Types.OpOp1.LOCAL;
    }

    @Override // org.apache.sysds.hops.Hop
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** @return true for operations that are forced to CP in the non-memory-based heuristic */
    private boolean isInMemoryOperation() {
        return _op == Types.OpOp1.INVERSE;
    }

    /** @return true for CUMSUM, CUMPROD, CUMMIN, CUMMAX and CUMSUMPROD */
    public boolean isCumulativeUnaryOperation() {
        return _op == Types.OpOp1.CUMSUM || _op == Types.OpOp1.CUMPROD || _op == Types.OpOp1.CUMMIN
            || _op == Types.OpOp1.CUMMAX || _op == Types.OpOp1.CUMSUMPROD;
    }

    /** @return true for all {@code CAST_AS_*} operations */
    public boolean isCastUnaryOperation() {
        return _op == Types.OpOp1.CAST_AS_MATRIX || _op == Types.OpOp1.CAST_AS_SCALAR
            || _op == Types.OpOp1.CAST_AS_FRAME || _op == Types.OpOp1.CAST_AS_LIST
            || _op == Types.OpOp1.CAST_AS_BOOLEAN || _op == Types.OpOp1.CAST_AS_DOUBLE
            || _op == Types.OpOp1.CAST_AS_INT;
    }

    /** @return true for operations treated as expensive, which are eligible for multi-threading */
    public boolean isExpensiveUnaryOperation() {
        return _op == Types.OpOp1.EXP || _op == Types.OpOp1.LOG || _op == Types.OpOp1.SIGMOID
            || _op == Types.OpOp1.COMPRESS || _op == Types.OpOp1.DECOMPRESS
            || _op == Types.OpOp1.MEDIAN || _op == Types.OpOp1.IQM;
    }

    /** @return true for operations that only query metadata (nrow, ncol, length, exists, lineage, list cast) */
    public boolean isMetadataOperation() {
        return _op == Types.OpOp1.NROW || _op == Types.OpOp1.NCOL || _op == Types.OpOp1.LENGTH
            || _op == Types.OpOp1.EXISTS || _op == Types.OpOp1.LINEAGE || _op == Types.OpOp1.CAST_AS_LIST;
    }

    /**
     * Selects the execution type (CP or Spark). The steps are:
     * <ul>
     *   <li>Honor a forced type; otherwise decide by memory estimate or a size/shape heuristic.</li>
     *   <li>Pull cheap single-consumer unaries into Spark when the input already runs in Spark.</li>
     *   <li>Pin single-node operations to CP.</li>
     * </ul>
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.apache.sysds.hops.Hop
    public Types.ExecType optFindExecType(boolean transitive) {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() || isInMemoryOperation()) {
                _etype = Types.ExecType.CP;
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }

        // pull the unary into Spark if its input runs in Spark and this hop is the input's only
        // consumer (excluding cumulative ops, casts, order statistics and DataOp inputs)
        if (_etype == Types.ExecType.CP && _etypeForced != Types.ExecType.CP
            && getInput().get(0).optFindExecType() == Types.ExecType.SPARK
            && getDataType().isMatrix()
            && !isCumulativeUnaryOperation() && !isCastUnaryOperation()
            && _op != Types.OpOp1.MEDIAN && _op != Types.OpOp1.IQM
            && !(getInput().get(0) instanceof DataOp)
            && getInput().get(0).getParent().size() == 1) {
            _etype = Types.ExecType.SPARK;
        }

        setRequiresRecompileIfNecessary();

        // single-node operations always run in CP
        if (_op == Types.OpOp1.PRINT || _op == Types.OpOp1.ASSERT || _op == Types.OpOp1.STOP
            || _op == Types.OpOp1.TYPEOF || _op == Types.OpOp1.INVERSE || _op == Types.OpOp1.EIGEN
            || _op == Types.OpOp1.CHOLESKY || _op == Types.OpOp1.SVD
            || getInput().get(0).getDataType() == Types.DataType.LIST || isMetadataOperation()) {
            _etype = Types.ExecType.CP;
        } else {
            setRequiresRecompileIfNecessary();
        }
        return _etype;
    }

    /**
     * Derives output dimensions, nnz and compression flags from the input. Scalars are left untouched.
     */
    @Override // org.apache.sysds.hops.Hop
    public void refreshSizeInformation() {
        Hop input = getInput().get(0);
        if (getDataType() == Types.DataType.SCALAR) {
            return;
        }
        if ((_op == Types.OpOp1.CAST_AS_MATRIX || _op == Types.OpOp1.CAST_AS_FRAME || _op == Types.OpOp1.CAST_AS_SCALAR)
            && input.getDataType() == Types.DataType.LIST) {
            // list with multiple elements -> column vector; otherwise dimensions stay unknown
            setDim1(input.getLength() > 1 ? input.getLength() : -1L);
            setDim2(input.getLength() > 1 ? 1L : -1L);
            return;
        }
        if ((_op == Types.OpOp1.CAST_AS_MATRIX || _op == Types.OpOp1.CAST_AS_FRAME)
            && input.getDataType() == Types.DataType.SCALAR) {
            setDim1(1L);
            setDim2(1L);
            return;
        }
        if (_op == Types.OpOp1.CAST_AS_LIST) {
            setDim1(-1L);
            setDim2(1L);
            return;
        }
        if (_op == Types.OpOp1.CUMSUMPROD) {
            setDim1(input.getDim1());
            setDim2(1L);
            return;
        }
        if (_op == Types.OpOp1.TYPEOF || _op == Types.OpOp1.DETECTSCHEMA || _op == Types.OpOp1.COLNAMES) {
            setDim1(1L);
            setDim2(input.getDim2());
            return;
        }
        setDim1(input.getDim1());
        setDim2(input.getDim2());
        if (isSparsityPreservingOperation()) {
            setNnz(input.getNnz());
        }
        // propagate compressed output, except when decompressing
        // NOTE(review): the doubled compressed size looks like a conservative estimate — confirm
        if (input._compressedOutput && _op != Types.OpOp1.DECOMPRESS) {
            setCompressedOutput(true);
            setCompressedSize(input.compressedSize() * 2);
        }
    }

    /** @return a copy of this hop (via the base-class clone) carrying the same operation */
    @Override // org.apache.sysds.hops.Hop
    public Object clone() throws CloneNotSupportedException {
        UnaryOp ret = new UnaryOp();
        ret.clone(this, false);
        ret._op = _op;
        return ret;
    }

    /**
     * Returns whether {@code hop} is a UnaryOp with the same operation over the identical input.
     * PRINT hops never compare equal, presumably so that distinct prints are never merged.
     */
    @Override // org.apache.sysds.hops.Hop
    public boolean compare(Hop hop) {
        if (!(hop instanceof UnaryOp) || _op == Types.OpOp1.PRINT) {
            return false;
        }
        UnaryOp that = (UnaryOp) hop;
        return _op == that._op && getInput().get(0) == that.getInput().get(0);
    }
}
