package org.apache.sysds.hops;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/ParameterizedBuiltinOp.class */
/**
 * Hop for parameterized built-in functions ({@link Types.ParamBuiltinOp}), e.g. grouped aggregates,
 * removeEmpty, rexpand, replace, transform*, toString and paramserv.
 *
 * <p>All named parameters are stored as regular hop inputs; {@link #getParamIndexMap()} maps each
 * parameter name (e.g. {@code "target"}, {@code "margin"}, {@code "select"}) to its position in
 * {@link #getInput()}.
 */
public class ParameterizedBuiltinOp extends MultiThreadedHop {
    private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinOp.class.getName());

    /**
     * If true, Spark removeEmpty never sets the broadcast flag on its lop, even when the broadcast
     * memory-budget check ({@code isRemoveEmptyBcSP}) would allow it.
     */
    public static boolean FORCE_DIST_RM_EMPTY = false;

    private Types.ParamBuiltinOp _op;
    private boolean _outputPermutationMatrix = false;
    // set during Spark removeEmpty lop construction if the broadcast budget check passes
    private boolean _bRmEmptyBC = false;
    // parameter name -> index into getInput()
    private HashMap<String, Integer> _paramIndexMap = new HashMap<>();

    /** Creates an uninitialized hop; only used by {@link #clone()}. */
    private ParameterizedBuiltinOp() {
        // remaining state is copied in clone()
    }

    /**
     * Creates a parameterized builtin hop and wires all parameters as inputs.
     *
     * @param name hop name
     * @param dataType output data type
     * @param valueType output value type
     * @param op the parameterized builtin operation
     * @param params named parameters; iteration order defines the input order
     */
    public ParameterizedBuiltinOp(String name, Types.DataType dataType, Types.ValueType valueType,
            Types.ParamBuiltinOp op, LinkedHashMap<String, Hop> params) {
        super(name, dataType, valueType);
        _op = op;
        int index = 0;
        for (Map.Entry<String, Hop> param : params.entrySet()) {
            Hop input = param.getValue();
            getInput().add(input);
            input.getParent().add(this);
            _paramIndexMap.put(param.getKey(), index++);
        }
        refreshSizeInformation();
    }

    @Override // org.apache.sysds.hops.Hop
    public void checkArity() {
        int numInputs = _input.size();
        int numParams = _paramIndexMap.size();
        HopsException.check(numInputs == numParams, this, "has %d inputs but %d parameters",
            Integer.valueOf(numInputs), Integer.valueOf(numParams));
    }

    /** Returns the live (mutable) map from parameter name to input index. */
    public HashMap<String, Integer> getParamIndexMap() {
        return _paramIndexMap;
    }

    @Override // org.apache.sysds.hops.Hop
    public String getOpString() {
        return String.valueOf(_op);
    }

    public Types.ParamBuiltinOp getOp() {
        return _op;
    }

    public void setOutputPermutationMatrix(boolean flag) {
        _outputPermutationMatrix = flag;
    }

    /** Returns the {@code "target"} parameter hop, or null if absent. */
    public Hop getTargetHop() {
        return getParameterHop("target");
    }

    /** Returns the input hop bound to the given parameter name, or null if there is no such parameter. */
    public Hop getParameterHop(String name) {
        Integer index = _paramIndexMap.get(name);
        return index != null ? getInput().get(index.intValue()) : null;
    }

    @Override // org.apache.sysds.hops.Hop
    public boolean isGPUEnabled() {
        return false;
    }

    @Override // org.apache.sysds.hops.MultiThreadedHop
    public boolean isMultiThreadedOpType() {
        return HopRewriteUtils.isValidOp(_op, Types.ParamBuiltinOp.GROUPEDAGG,
            Types.ParamBuiltinOp.REXPAND, Types.ParamBuiltinOp.PARAMSERV);
    }

    @Override // org.apache.sysds.hops.Hop
    public Lop constructLops() {
        // reuse already constructed lops
        if (getLops() != null) {
            return getLops();
        }
        HashMap<String, Lop> inputLops = new HashMap<>();
        for (Map.Entry<String, Integer> param : _paramIndexMap.entrySet()) {
            inputLops.put(param.getKey(), getInput().get(param.getValue().intValue()).constructLops());
        }
        switch (_op) {
            case GROUPEDAGG:
                constructLopsGroupedAggregate(inputLops, optFindExecType());
                break;
            case RMEMPTY:
                constructLopsRemoveEmpty(inputLops, optFindExecType());
                break;
            case REXPAND:
                constructLopsRExpand(inputLops, optFindExecType());
                break;
            case CDF:
            case INVCDF:
            case REPLACE:
            case LOWER_TRI:
            case UPPER_TRI:
            case TOKENIZE:
            case TRANSFORMAPPLY:
            case TRANSFORMDECODE:
            case TRANSFORMCOLMAP:
            case TRANSFORMMETA:
            case TOSTRING:
            case PARAMSERV:
            case LIST: {
                Lop pbilop = new ParameterizedBuiltin(inputLops, _op, getDataType(), getValueType(), optFindExecType());
                setOutputDimensions(pbilop);
                setLineNumbers(pbilop);
                setLops(pbilop);
                break;
            }
            default:
                throw new HopsException("Unknown ParamBuiltinOp: " + _op);
        }
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /**
     * Builds the grouped-aggregate lop for CP or Spark.
     *
     * @throws HopsException if {@code et} is neither CP nor SPARK
     */
    private void constructLopsGroupedAggregate(HashMap<String, Lop> inputLops, Types.ExecType et) {
        // reblock is re-enabled below only for the distributed GroupedAggregate
        setRequiresReblock(false);

        // derive output dims from a literal number of groups, if available
        long outDim1 = -1;
        long outDim2 = -1;
        Lop numGroups = inputLops.get(Statement.GAGG_NUM_GROUPS);
        if (!dimsKnown() && numGroups instanceof Data && ((Data) numGroups).isLiteral()) {
            long ngroups = ((Data) numGroups).getLongValue();
            Lop combined = inputLops.get(GroupedAggregate.COMBINEDINPUT);
            long inDim1 = combined.getOutputParameters().getNumRows();
            long inDim2 = combined.getOutputParameters().getNumCols();
            if (inDim1 == 1 && inDim2 > 1) {
                outDim1 = ngroups;
                outDim2 = 1;
            } else {
                outDim1 = inDim2;
                outDim2 = ngroups;
            }
        }

        Lop grpAgg;
        if (et == Types.ExecType.CP) {
            grpAgg = new GroupedAggregate(inputLops, getDataType(), getValueType(), et,
                OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
            grpAgg.getOutputParameters().setDimensions(outDim1, outDim2, getBlocksize(), -1L);
        } else if (et == Types.ExecType.SPARK) {
            Hop groups = getParameterHop(Statement.GAGG_GROUPS);
            // broadcast the groups only for unweighted aggregates within the broadcast budget
            boolean broadcastGroups = _paramIndexMap.get(Statement.GAGG_WEIGHTS) == null
                && OptimizerUtils.checkSparkBroadcastMemoryBudget(groups.getDim1(), groups.getDim2(),
                    (long) groups.getBlocksize(), groups.getNnz());
            Hop fn = getParameterHop(Statement.GAGG_FN);
            if (broadcastGroups
                && fn instanceof LiteralOp
                && Statement.GAGG_FN_SUM.equals(((LiteralOp) fn).getStringValue())
                && inputLops.get(Statement.GAGG_NUM_GROUPS) != null) {
                grpAgg = new GroupedAggregateM(inputLops, getDataType(), getValueType(), true, Types.ExecType.SPARK);
                grpAgg.getOutputParameters().setDimensions(outDim1, outDim2, getTargetHop().getBlocksize(), -1L);
            } else {
                grpAgg = new GroupedAggregate(inputLops, getDataType(), getValueType(), et, broadcastGroups);
                grpAgg.getOutputParameters().setDimensions(outDim1, outDim2, -1L, -1L);
                setRequiresReblock(true);
            }
        } else {
            // previously this fell through with a null lop and failed in setLineNumbers
            throw new HopsException("Unsupported exec type for " + _op + ": " + et);
        }
        setLineNumbers(grpAgg);
        setLops(grpAgg);
    }

    /**
     * Builds the removeEmpty lop. In CP this is a single lop over the given inputs. On Spark, this
     * method builds temporary hops that compute the offsets (cumsum of the selection vector,
     * multiplied by the selection vector) and the maximum offset. It then feeds their lops to the
     * Spark lop.
     *
     * @throws HopsException if {@code et} is neither CP nor SPARK, or if on Spark 'margin' is not a literal
     */
    private void constructLopsRemoveEmpty(HashMap<String, Lop> inputLops, Types.ExecType et) {
        Hop target = getTargetHop();
        Hop margin = getParameterHop("margin");
        Hop select = getParameterHop("select");

        if (et == Types.ExecType.CP) {
            ParameterizedBuiltin cpLop = new ParameterizedBuiltin(inputLops, _op, getDataType(), getValueType(), et);
            setOutputDimensions(cpLop);
            setLineNumbers(cpLop);
            setLops(cpLop);
            return;
        }
        if (et != Types.ExecType.SPARK) {
            // previously no lop was created at all for other exec types
            throw new HopsException("Unsupported exec type for " + _op + ": " + et);
        }
        if (!(margin instanceof LiteralOp)) {
            throw new HopsException("Parameter 'margin' must be a literal argument.");
        }

        long rlen = target.getDim1();
        long clen = target.getDim2();
        int blen = target.getBlocksize();
        boolean rmRows = "rows".equals(((LiteralOp) margin).getStringValue());

        // selection vector: user-provided 'select', or derived from (target != 0)
        BinaryOp notZero = null;
        Hop selectVector;
        if (select == null) {
            notZero = HopRewriteUtils.createBinary(target, new LiteralOp(0L), Types.OpOp2.NOTEQUAL);
            notZero.setForcedExecType(Types.ExecType.SPARK);
            selectVector = notZero;
            boolean targetIsVectorAlongMargin = rmRows ? clen == 1 : rlen == 1;
            if (!targetIsVectorAlongMargin) {
                selectVector = HopRewriteUtils.createAggUnaryOp(notZero, Types.AggOp.MAX,
                    rmRows ? Types.Direction.Row : Types.Direction.Col);
                selectVector.setForcedExecType(Types.ExecType.SPARK);
            }
        } else {
            selectVector = select;
        }

        // offsets = cumsum(select); for margin=cols the cumsum runs on the transposed vector
        Hop cumsumInput = selectVector;
        if (!rmRows) {
            cumsumInput = HopRewriteUtils.createTranspose(selectVector);
            HopRewriteUtils.updateHopCharacteristics(cumsumInput, blen, this);
        }
        Hop cumsum = HopRewriteUtils.createUnary(cumsumInput, Types.OpOp1.CUMSUM);
        HopRewriteUtils.updateHopCharacteristics(cumsum, blen, this);
        Hop offsets = cumsum;
        if (!rmRows) {
            offsets = HopRewriteUtils.createTranspose(cumsum);
            HopRewriteUtils.updateHopCharacteristics(offsets, blen, this);
        }
        Hop maxDim = HopRewriteUtils.createAggUnaryOp(offsets, Types.AggOp.MAX, Types.Direction.RowCol);
        HopRewriteUtils.updateHopCharacteristics(maxDim, blen, this);
        Hop maskedOffsets = HopRewriteUtils.createBinary(offsets, selectVector, Types.OpOp2.MULT);
        HopRewriteUtils.updateHopCharacteristics(maskedOffsets, blen, this);

        Lop targetLop = target.constructLops();
        Lop offsetLop = maskedOffsets.constructLops();
        Lop maxDimLop = maxDim.constructLops();
        HashMap<String, Lop> sparkInputs = new HashMap<>();
        sparkInputs.put("target", targetLop);
        sparkInputs.put("offset", offsetLop);
        sparkInputs.put("maxdim", maxDimLop);
        sparkInputs.put("margin", inputLops.get("margin"));
        sparkInputs.put("empty.return", inputLops.get("empty.return"));

        if (!FORCE_DIST_RM_EMPTY && isRemoveEmptyBcSP()) {
            _bRmEmptyBC = true;
        }
        ParameterizedBuiltin sparkLop = new ParameterizedBuiltin(sparkInputs, _op, getDataType(), getValueType(), et, _bRmEmptyBC);
        setOutputDimensions(sparkLop);
        setLineNumbers(sparkLop);
        // detach the temporary (target != 0) hop from the target's parent list
        if (select == null) {
            HopRewriteUtils.removeChildReference(notZero, target);
        }
        setLops(sparkLop);
    }

    /** Builds the rexpand lop with the constrained number of threads. */
    private void constructLopsRExpand(HashMap<String, Lop> inputLops, Types.ExecType et) {
        ParameterizedBuiltin pbilop = new ParameterizedBuiltin(inputLops, _op, getDataType(), getValueType(), et,
            OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        setOutputDimensions(pbilop);
        setLineNumbers(pbilop);
        setLops(pbilop);
    }

    @Override // org.apache.sysds.hops.Hop
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (getOp() != Types.ParamBuiltinOp.TOSTRING) {
            return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, OptimizerUtils.getSparsity(dim1, dim2, nnz));
        }
        return estimateToStringOutputSize();
    }

    /**
     * Estimates the output size of toString from a character count. Unknown input dims and nnz
     * fall back to 100 rows, 100 cols and 100*100 non-zeros. The 'rows'/'cols' parameters cap the
     * printed dims and default to 100.
     */
    private double estimateToStringOutputSize() {
        final long defaultDim = 100;
        Hop input = getInput().get(0);
        long inNnz = input.getNnz() < 0 ? defaultDim * defaultDim : input.getNnz();
        long inRows = input.getDim1() < 0 ? defaultDim : input.getDim1();
        long inCols = input.getDim2() < 0 ? defaultDim : input.getDim2();

        long rowsParam = defaultDim;
        long colsParam = defaultDim;
        boolean sparse = false;
        String delimiter = " ";
        String lineSeparator = ProgramConverter.NEWLINE;
        try {
            Hop rowsHop = getParameterHop("rows");
            if (rowsHop instanceof LiteralOp) {
                rowsParam = ((LiteralOp) rowsHop).getLongValue();
            }
            Hop colsHop = getParameterHop("cols");
            if (colsHop instanceof LiteralOp) {
                colsParam = ((LiteralOp) colsHop).getLongValue();
            }
            Hop sparseHop = getParameterHop(DataExpression.DELIM_SPARSE);
            if (sparseHop instanceof LiteralOp) {
                sparse = ((LiteralOp) sparseHop).getBooleanValue();
            }
            Hop delimHop = getParameterHop(DataExpression.DELIM_DELIMITER);
            if (delimHop instanceof LiteralOp) {
                delimiter = ((LiteralOp) delimHop).getStringValue();
            }
            Hop lineSepHop = getParameterHop("linesep");
            if (lineSepHop instanceof LiteralOp) {
                lineSeparator = ((LiteralOp) lineSepHop).getStringValue();
            }
        } catch (HopsException e) {
            LOG.warn("Invalid values when trying to compute dims1, dims2 & nnz", e);
            // = 36 + (7*100*100 + 1*100*99 + 1*100) * 2, i.e. the default dense 100x100 estimate
            // with single-character delimiter and line separator
            return 160036.0d;
        }
        long rows = Math.min(inRows, rowsParam);
        long cols = Math.min(inCols, colsParam);

        long numChars = sparse
            ? (7 * inNnz) + (8 * inNnz) + (delimiter.length() * 2 * inNnz) + (lineSeparator.length() * inNnz)
            : (7 * rows * cols) + (delimiter.length() * rows * (cols - 1)) + (lineSeparator.length() * rows);
        // presumably 2 bytes per char plus a constant 36-byte object overhead
        return 36 + numChars * 2;
    }

    @Override // org.apache.sysds.hops.Hop
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double ret = 0.0;
        if (_op == Types.ParamBuiltinOp.RMEMPTY) {
            if (isLiteralParamEqual("margin", "cols")) {
                ret += 1 * dim2; // 1 byte per column (selection vector)
                ret += 4 * dim2; // 4 bytes per column for additional intermediates
            } else {
                ret += 1 * dim1; // 1 byte per row (selection vector)
            }
        } else if (_op == Types.ParamBuiltinOp.REXPAND && isLiteralParamEqual("dir", "rows")) {
            // non-literal 'dir' previously failed with ClassCastException; now yields no estimate
            ret = 12 * Math.min(dim1, LibMatrixReorg.PAR_NUMCELL_THRESHOLD);
        }
        return ret;
    }

    @Override // org.apache.sysds.hops.Hop
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics dc = memo.getAllInputStats(getTargetHop());

        if (_op == Types.ParamBuiltinOp.GROUPEDAGG) {
            long outCols = dc.getRows() == 1 ? 1L : dc.getCols();
            Hop numGroups = getParameterHop(Statement.GAGG_NUM_GROUPS);
            if (numGroups instanceof LiteralOp) {
                long ngroups = HopRewriteUtils.getIntValueSafe((LiteralOp) numGroups);
                return new MatrixCharacteristics(ngroups, outCols, -1, ngroups);
            }
            // worst case: one group per input row
            long outRows = dc.getRows();
            return outRows >= 1 ? new MatrixCharacteristics(outRows, outCols, -1, outRows) : null;
        }
        if (_op == Types.ParamBuiltinOp.RMEMPTY) {
            if (!dc.dimsKnown()) {
                return null;
            }
            Hop select = getParameterHop("select");
            DataCharacteristics selectDc = select != null ? memo.getAllInputStats(select) : null;
            boolean selectNnzKnown = selectDc != null && selectDc.nnzKnown();
            long rows = dc.getRows();
            long cols = dc.getCols();
            if (isLiteralParamEqual("margin", "cols")) {
                if (selectNnzKnown) {
                    cols = selectDc.getNonZeros();
                }
            } else if (selectNnzKnown) {
                rows = selectDc.getNonZeros();
            }
            return new MatrixCharacteristics(rows, cols, -1, dc.getNonZeros());
        }
        if (_op == Types.ParamBuiltinOp.REPLACE) {
            return dc.dimsKnown()
                ? new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1,
                    isNonZeroReplaceArguments() ? dc.getNonZeros() : -1L)
                : null;
        }
        if (_op == Types.ParamBuiltinOp.REXPAND) {
            Hop dir = getParameterHop("dir");
            long maxVal = computeDimParameterInformation(getParameterHop("max"), memo);
            if (dc.dimsKnown() && dir instanceof LiteralOp) {
                String dirVal = ((LiteralOp) dir).getStringValue();
                long outNnz = dc.nnzKnown() ? dc.getNonZeros() : dc.getRows();
                if ("cols".equals(dirVal)) {
                    return new MatrixCharacteristics(dc.getRows(), maxVal, -1, outNnz);
                }
                if ("rows".equals(dirVal)) {
                    return new MatrixCharacteristics(maxVal, dc.getRows(), -1, outNnz);
                }
            }
            return null;
        }
        if ((_op == Types.ParamBuiltinOp.TRANSFORMDECODE || _op == Types.ParamBuiltinOp.TRANSFORMAPPLY)
            && dc.dimsKnown()) {
            return new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getLength());
        }
        return null;
    }

    @Override // org.apache.sysds.hops.Hop
    public boolean allowsAllExecTypes() {
        return false;
    }

    @Override // org.apache.sysds.hops.Hop
    public Types.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (_op == Types.ParamBuiltinOp.GROUPEDAGG && getTargetHop().areDimsBelowThreshold()) {
                _etype = Types.ExecType.CP;
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        // these operations always run in CP, overriding forced or chosen exec types
        if (HopRewriteUtils.isValidOp(_op, Types.ParamBuiltinOp.TRANSFORMCOLMAP, Types.ParamBuiltinOp.TRANSFORMMETA,
                Types.ParamBuiltinOp.TOSTRING, Types.ParamBuiltinOp.LIST, Types.ParamBuiltinOp.CDF,
                Types.ParamBuiltinOp.INVCDF, Types.ParamBuiltinOp.PARAMSERV)) {
            _etype = Types.ExecType.CP;
        }
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    @Override // org.apache.sysds.hops.Hop
    public void refreshSizeInformation() {
        switch (_op) {
            case GROUPEDAGG: {
                Hop numGroups = getParameterHop(Statement.GAGG_NUM_GROUPS);
                long ngroups = numGroups instanceof LiteralOp
                    ? HopRewriteUtils.getIntValueSafe((LiteralOp) numGroups) : -1;
                Hop target = getTargetHop();
                setDim1(ngroups);
                setDim2(target.getDim1() == 1 ? 1L : target.getDim2());
                break;
            }
            case RMEMPTY: {
                Hop target = getTargetHop();
                Hop margin = getParameterHop("margin");
                Hop select = getParameterHop("select");
                if (margin instanceof LiteralOp) {
                    String marginVal = ((LiteralOp) margin).getStringValue();
                    if ("rows".equals(marginVal)) {
                        setDim2(target.getDim2());
                        if (select != null) {
                            setDim1(select.getNnz());
                        }
                    } else if ("cols".equals(marginVal)) {
                        setDim1(target.getDim1());
                        if (select != null) {
                            setDim2(select.getNnz());
                        }
                    }
                }
                setNnz(target.getNnz());
                break;
            }
            case REXPAND: {
                Hop target = getTargetHop();
                double maxVal = computeSizeInformation(getParameterHop("max"));
                Hop dir = getParameterHop("dir");
                // non-literal 'dir' previously failed with ClassCastException; dims stay unchanged
                if (dir instanceof LiteralOp) {
                    String dirVal = ((LiteralOp) dir).getStringValue();
                    if ("cols".equals(dirVal)) {
                        setDim1(target.getDim1());
                        setDim2(UtilFunctions.toLong(maxVal));
                    } else if ("rows".equals(dirVal)) {
                        setDim1(UtilFunctions.toLong(maxVal));
                        setDim2(target.getDim1());
                    }
                }
                break;
            }
            case REPLACE: {
                Hop target = getTargetHop();
                setDim1(target.getDim1());
                setDim2(target.getDim2());
                if (isNonZeroReplaceArguments()) {
                    setNnz(target.getNnz());
                }
                break;
            }
            case LOWER_TRI:
            case UPPER_TRI: {
                Hop target = getTargetHop();
                setDim1(target.getDim1());
                setDim2(target.getDim2());
                break;
            }
            case TRANSFORMDECODE:
                setDim1(getTargetHop().getDim1());
                break;
            case TRANSFORMCOLMAP:
                setDim1(getTargetHop().getDim2());
                setDim2(3L);
                break;
            case LIST:
                setDim1(getInput().size());
                setDim2(1L);
                break;
            default:
                // CDF, INVCDF, TOKENIZE, TRANSFORMAPPLY, TRANSFORMMETA, TOSTRING, PARAMSERV: no size update
                break;
        }
    }

    @Override // org.apache.sysds.hops.Hop
    public Object clone() throws CloneNotSupportedException {
        ParameterizedBuiltinOp ret = new ParameterizedBuiltinOp();
        // copy generic attributes
        ret.clone(this, false);
        // copy specific attributes; parameter hops are shared (shallow copy of the index map)
        ret._op = _op;
        ret._outputEmptyBlocks = _outputEmptyBlocks;
        ret._outputPermutationMatrix = _outputPermutationMatrix;
        ret._paramIndexMap = new HashMap<>(_paramIndexMap);
        return ret;
    }

    /**
     * Two hops are equal if they have the same op, flags and parameter names, and each parameter
     * refers to the identical input hop.
     */
    @Override // org.apache.sysds.hops.Hop
    public boolean compare(Hop that) {
        if (!(that instanceof ParameterizedBuiltinOp)) {
            return false;
        }
        ParameterizedBuiltinOp other = (ParameterizedBuiltinOp) that;
        if (_op != other._op
            || _paramIndexMap == null || other._paramIndexMap == null
            || _paramIndexMap.size() != other._paramIndexMap.size()
            || _outputEmptyBlocks != other._outputEmptyBlocks
            || _outputPermutationMatrix != other._outputPermutationMatrix) {
            return false;
        }
        for (Map.Entry<String, Integer> param : _paramIndexMap.entrySet()) {
            Integer otherIndex = other._paramIndexMap.get(param.getKey());
            if (otherIndex == null) {
                // same number of parameters but different names (previously an NPE)
                return false;
            }
            Hop otherInput = other.getInput().get(otherIndex.intValue());
            if (otherInput == null || getInput().get(param.getValue().intValue()) != otherInput) {
                return false;
            }
        }
        return true;
    }

    /** A grouped aggregate is transpose-safe iff its aggregation function is a literal 'sum'. */
    @Override // org.apache.sysds.hops.Hop
    public boolean isTransposeSafe() {
        try {
            // previously always returned false: the positive result was unconditionally overwritten
            return _op == Types.ParamBuiltinOp.GROUPEDAGG
                && isLiteralParamEqual(Statement.GAGG_FN, Statement.GAGG_FN_SUM);
        } catch (Exception e) {
            LOG.warn("Check for transpose-safeness failed, continue assuming false.", e);
            return false;
        }
    }

    /** Returns true iff this is a grouped aggregate whose function is a literal 'count'. */
    public boolean isCountFunction() {
        try {
            // previously always returned false: the positive result was unconditionally overwritten
            return _op == Types.ParamBuiltinOp.GROUPEDAGG
                && isLiteralParamEqual(Statement.GAGG_FN, Statement.GAGG_FN_COUNT);
        } catch (Exception e) {
            LOG.warn("Check for count function failed, continue assuming false.", e);
            return false;
        }
    }

    /**
     * Returns true iff both 'pattern' and 'replacement' are literals different from zero. In that
     * case replace keeps the nnz of its target.
     */
    private boolean isNonZeroReplaceArguments() {
        try {
            Hop pattern = getParameterHop("pattern");
            Hop replacement = getParameterHop("replacement");
            return pattern instanceof LiteralOp && ((LiteralOp) pattern).getDoubleValue() != 0d
                && replacement instanceof LiteralOp && ((LiteralOp) replacement).getDoubleValue() != 0d;
        } catch (Exception e) {
            LOG.warn("Non Zero Replace Arguments exception: " + e.getMessage(), e);
            return false;
        }
    }

    /** Returns true iff the target is a diag reorg over an input with a single column. */
    public boolean isTargetDiagInput() {
        Hop target = getTargetHop();
        return target instanceof ReorgOp
            && ((ReorgOp) target).getOp() == Types.ReOrgOp.DIAG
            && target.getInput().get(0).getDim2() == 1;
    }

    /**
     * Builds pseudo function calls that mirror how paramserv invokes the update, aggregation and
     * (optional) validation functions.
     *
     * @return the pseudo calls, or an empty list if the parameters cannot be resolved
     */
    public List<FunctionOp> getParamservPseudoFunctionCalls() {
        try {
            String[] updateFn = DMLProgram.splitFunctionKey(
                ((LiteralOp) getParameterHop(Statement.PS_UPDATE_FUN)).getStringValue());
            String[] aggFn = DMLProgram.splitFunctionKey(
                ((LiteralOp) getParameterHop(Statement.PS_AGGREGATION_FUN)).getStringValue());
            Hop valFunHop = getParameterHop(Statement.PS_VAL_FUN);
            String[] valFn = valFunHop == null ? null
                : DMLProgram.splitFunctionKey(((LiteralOp) valFunHop).getStringValue());

            Hop model = getParameterHop(Statement.PS_MODEL);
            Hop hyperParams = getParameterHop(Statement.PS_HYPER_PARAMS);
            Hop batchSize = ObjectUtils.defaultIfNull(getParameterHop(Statement.PS_BATCH_SIZE), new LiteralOp(64L));

            FunctionOp updateCall = new FunctionOp(FunctionOp.FunctionType.DML, updateFn[0], updateFn[1],
                new String[] {Statement.PS_MODEL, Statement.PS_HYPER_PARAMS, Statement.PS_FEATURES, Statement.PS_LABELS},
                Arrays.asList(model, hyperParams,
                    HopRewriteUtils.createIndexingOp(getParameterHop(Statement.PS_FEATURES), batchSize),
                    HopRewriteUtils.createIndexingOp(getParameterHop(Statement.PS_LABELS), batchSize)),
                new String[] {Statement.PS_GRADIENTS}, false, true);
            FunctionOp aggCall = new FunctionOp(FunctionOp.FunctionType.DML, aggFn[0], aggFn[1],
                new String[] {Statement.PS_MODEL, Statement.PS_HYPER_PARAMS, Statement.PS_GRADIENTS},
                Arrays.asList(model, hyperParams, updateCall),
                new String[] {Statement.PS_MODEL}, false, true);
            if (valFn == null) {
                return Arrays.asList(updateCall, aggCall);
            }
            FunctionOp valCall = new FunctionOp(FunctionOp.FunctionType.DML, valFn[0], valFn[1],
                new String[] {Statement.PS_MODEL, Statement.PS_HYPER_PARAMS, "valfeatures", "vallabels"},
                Arrays.asList(model, hyperParams,
                    getParameterHop(Statement.PS_VAL_FEATURES), getParameterHop(Statement.PS_VAL_LABELS)),
                new String[] {"loss", "accuracy"}, false, true);
            return Arrays.asList(updateCall, aggCall, valCall);
        } catch (Exception e) {
            // best effort: unresolvable parameters yield no pseudo calls
            LOG.debug("Could not construct paramserv pseudo function calls.", e);
            return Collections.emptyList();
        }
    }

    /**
     * Returns true if a single column vector sized along the removeEmpty margin (rows by default,
     * columns for margin='cols') fits the Spark broadcast memory budget. If the target dims are
     * unknown, the target's output memory estimate is used instead.
     */
    private boolean isRemoveEmptyBcSP() {
        // consistent with constructLopsRemoveEmpty, which sizes everything from the target hop
        Hop target = getTargetHop();
        boolean rmCols = isLiteralParamEqual("margin", "cols");
        double size = target.dimsKnown()
            ? OptimizerUtils.estimateSize(rmCols ? target.getDim2() : target.getDim1(), 1L)
            : target.getOutputMemEstimate();
        return OptimizerUtils.checkSparkBroadcastMemoryBudget(size);
    }

    /** Returns true iff the named parameter is a literal whose string value equals {@code value}. */
    private boolean isLiteralParamEqual(String name, String value) {
        Hop param = getParameterHop(name);
        return param instanceof LiteralOp && value.equals(((LiteralOp) param).getStringValue());
    }
}
