package org.apache.sysds.hops;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Hop for data-generation operations ({@link Types.OpOpDG}), including {@code RAND},
 * {@code SEQ}, {@code SINIT}, {@code FRAMEINIT} and {@code TIME}.
 *
 * <p>All named parameters (e.g. {@code rows}, {@code cols}, {@code min}, {@code max},
 * {@code seed}, the sparsity, or the sequence bounds) are stored as regular hop inputs;
 * {@code _paramIndexMap} maps each parameter name to its position in {@link #getInput()}.
 */
public class DataGenOp extends MultiThreadedHop {
    /** Seed value whose literal name ({@code "-1"}) marks a rand/sinit without an explicit seed. */
    public static final long UNSPECIFIED_SEED = -1;

    /** The concrete data-generation operation. */
    private Types.OpOpDG _op;
    /** Parameter name -> index into {@link #getInput()}. */
    private HashMap<String, Integer> _paramIndexMap = new HashMap<>();
    /** Output identifier, handed to the {@link DataGen} lop. */
    private DataIdentifier _id;
    /** Literal sparsity of a RAND op, or -1 if it is not a known literal. */
    private double _sparsity = -1.0;
    /** Scratch-space base directory, handed to the {@link DataGen} lop. */
    private String _baseDir;
    /** Sequence increment inferred during size propagation; Double.MAX_VALUE if unknown. */
    private double _incr = Double.MAX_VALUE;

    /** Bare constructor used only by {@link #clone()}. */
    private DataGenOp() {
    }

    /**
     * Creates a data-generation hop with named input parameters.
     *
     * @param op              the data-generation operation
     * @param id              output identifier; unknown data/value types default to MATRIX/FP64
     * @param inputParameters parameter name -> input hop; each hop becomes an input of this op.
     *                        If {@code RAND_DIMS} is present the output is typed as TENSOR.
     */
    public DataGenOp(Types.OpOpDG op, DataIdentifier id, HashMap<String, Hop> inputParameters) {
        super(id.getName(),
            id.getDataType().isUnknown() ? Types.DataType.MATRIX : id.getDataType(),
            id.getValueType().isUnknown() ? Types.ValueType.FP64 : id.getValueType());
        _id = id;
        _op = op;
        if (inputParameters.containsKey(DataExpression.RAND_DIMS)) {
            setDataType(Types.DataType.TENSOR);
        }
        // wire every named parameter as an input and remember its position
        int index = 0;
        for (Map.Entry<String, Hop> entry : inputParameters.entrySet()) {
            Hop input = entry.getValue();
            getInput().add(input);
            input.getParent().add(this);
            _paramIndexMap.put(entry.getKey(), index++);
        }
        // only a literal sparsity of a rand op is usable for size/memory estimates
        Hop sparsity = inputParameters.get(DataExpression.RAND_SPARSITY);
        if (op == Types.OpOpDG.RAND && sparsity instanceof LiteralOp) {
            _sparsity = HopRewriteUtils.getDoubleValue((LiteralOp) sparsity);
        }
        _baseDir = createBaseDir();
        refreshSizeInformation();
    }

    /**
     * Creates a parameterless data-generation hop with a scalar INT64 output.
     *
     * @param op the data-generation operation
     * @param id output identifier (only its name is used for the hop name)
     */
    public DataGenOp(Types.OpOpDG op, DataIdentifier id) {
        super(id.getName(), Types.DataType.SCALAR, Types.ValueType.INT64);
        _id = id;
        _op = op;
        _baseDir = createBaseDir();
        refreshSizeInformation();
    }

    /** Builds the scratch-space base directory string used by all constructors. */
    private static String createBaseDir() {
        return ConfigurationManager.getScratchSpace() + "/_p" + DMLScript.getUUID() + "//_t0/";
    }

    /** Returns true if {@code hop} is a literal whose (safe) double value equals {@code value}. */
    private static boolean isLiteralValue(Hop hop, double value) {
        return hop instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hop) == value;
    }

    /** Verifies that every input corresponds to exactly one named parameter. */
    @Override
    public void checkArity() {
        int numInputs = _input.size();
        int numParams = _paramIndexMap.size();
        HopsException.check(numInputs == numParams, this, "has %d inputs but %d parameters",
            Integer.valueOf(numInputs), Integer.valueOf(numParams));
    }

    /** Returns {@code dg(<op>)}; lowercased with {@link Locale#ROOT} to be locale-independent. */
    @Override
    public String getOpString() {
        return "dg(" + _op.toString().toLowerCase(Locale.ROOT) + ")";
    }

    public Types.OpOpDG getOp() {
        return _op;
    }

    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    /** Only RAND is executed multi-threaded. */
    @Override
    public boolean isMultiThreadedOpType() {
        return _op == Types.OpOpDG.RAND;
    }

    /**
     * Constructs (or returns the cached) {@link DataGen} lop. Row/column parameters are replaced
     * by literal lops when the corresponding output dimension is already known.
     */
    @Override
    public Lop constructLops() {
        if (getLops() != null) {
            return getLops();
        }
        Types.ExecType et = optFindExecType();

        HashMap<String, Lop> inputLops = new HashMap<>();
        for (Map.Entry<String, Integer> entry : _paramIndexMap.entrySet()) {
            String name = entry.getKey();
            Lop lop;
            if (name.equals("rows") && rowsKnown()) {
                lop = new LiteralOp(getDim1()).constructLops();
            } else if (name.equals("cols") && colsKnown()) {
                lop = new LiteralOp(getDim2()).constructLops();
            } else {
                lop = getInput().get(entry.getValue().intValue()).constructLops();
            }
            inputLops.put(name, lop);
        }

        DataGen dataGen = new DataGen(_op, _id, inputLops, _baseDir, getDataType(), getValueType(), et);
        dataGen.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        // for Spark rand the nnz is reported as unknown unless it is known to be zero
        long nnz = (_op == Types.OpOpDG.RAND && et == Types.ExecType.SPARK && getNnz() != 0) ? -1L : getNnz();
        dataGen.getOutputParameters().setDimensions(getDim1(), getDim2(),
            getBlocksize() > 0 ? getBlocksize() : ConfigurationManager.getBlocksize(),
            nnz, getUpdateType());
        setLineNumbers(dataGen);
        setLops(dataGen);
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Output size estimate: a RAND with known literal sparsity uses that sparsity (or an empty
     * block if it generates the constant 0); everything else is estimated as dense.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (_op == Types.OpOpDG.RAND && _sparsity != -1.0) {
            if (hasConstantValue(0.0)) {
                return OptimizerUtils.estimateSizeEmptyBlock(dim1, dim2);
            }
            return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, _sparsity);
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
    }

    /**
     * Intermediate estimate for RAND with known dims: 32 bytes fixed overhead plus 8 bytes per
     * output block (presumably one long per block, e.g. a seed; NOTE(review): confirm).
     * All other cases need no intermediate memory.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (_op == Types.OpOpDG.RAND && dimsKnown()) {
            // divide in floating point: integer division would truncate partial blocks
            // (e.g. 1500 rows at blocksize 1000 would count as 1 block instead of 2)
            double blocksize = ConfigurationManager.getBlocksize();
            long numBlocks = (long) (Math.ceil(dim1 / blocksize) * Math.ceil(dim2 / blocksize));
            return 32.0 + numBlocks * 8.0;
        }
        return 0.0;
    }

    /**
     * Worst-case output characteristics for RAND/SINIT (from rows/cols parameters) and for the
     * sequence patterns {@code seq(1, N, 1)} and {@code seq(N, 1, -1)}; null if unknown.
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        if ((_op == Types.OpOpDG.RAND || _op == Types.OpOpDG.SINIT)
            && OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION) {
            // tensors are sized by dims, which are not inferred here
            if (_paramIndexMap.containsKey(DataExpression.RAND_DIMS)) {
                return null;
            }
            long rows = computeDimParameterInformation(getParam("rows"), memo);
            long cols = computeDimParameterInformation(getParam("cols"), memo);
            if (rows < 0 || cols < 0) {
                return null;
            }
            long nnz = _sparsity >= 0.0 ? (long) (_sparsity * rows * cols) : -1L;
            return new MatrixCharacteristics(rows, cols, -1, nnz);
        }
        if (_op != Types.OpOpDG.SEQ) {
            return null;
        }
        Hop from = getParam(Statement.SEQ_FROM);
        Hop to = getParam(Statement.SEQ_TO);
        Hop incr = getParam(Statement.SEQ_INCR);
        // seq(1, N, 1) -> N rows
        if (isLiteralValue(from, 1.0) && isLiteralValue(incr, 1.0)) {
            long length = computeDimParameterInformation(to, memo);
            if (length > 0) {
                return new MatrixCharacteristics(length, 1L, -1, -1L);
            }
        }
        // seq(N, 1, -1) -> N rows
        if (isLiteralValue(to, 1.0) && isLiteralValue(incr, -1.0)) {
            long length = computeDimParameterInformation(from, memo);
            if (length > 0) {
                return new MatrixCharacteristics(length, 1L, -1, -1L);
            }
        }
        return null;
    }

    /**
     * Selects the execution type (forced, memory-based, or dimension threshold based) and marks
     * the hop for recompilation if necessary. SINIT and TIME are always executed in CP.
     */
    @Override
    public Types.ExecType optFindExecType(boolean transitive) {
        checkAndSetForcedPlatform();
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (areDimsBelowThreshold() || isVector()) {
                _etype = Types.ExecType.CP;
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        setRequiresRecompileIfNecessary();
        if (_op == Types.OpOpDG.SINIT || _op == Types.OpOpDG.TIME) {
            _etype = Types.ExecType.CP;
        }
        return _etype;
    }

    /**
     * Refreshes output dimensions and nnz from the current parameter inputs:
     * rows/cols for RAND/SINIT/FRAMEINIT (non-tensor), the sequence length for SEQ (falling back
     * to the row count of a single ctable consumer), and 0x0 scalar INT64 for TIME.
     */
    @Override
    public void refreshSizeInformation() {
        if (_op == Types.OpOpDG.RAND || _op == Types.OpOpDG.SINIT || _op == Types.OpOpDG.FRAMEINIT) {
            if (_dataType != Types.DataType.TENSOR) {
                Hop rowsHop = getParam("rows");
                Hop colsHop = getParam("cols");
                refreshRowsParameterInformation(rowsHop);
                refreshColsParameterInformation(colsHop);
            }
        } else if (_op == Types.OpOpDG.SEQ) {
            Hop fromHop = getParam(Statement.SEQ_FROM);
            Hop toHop = getParam(Statement.SEQ_TO);
            Hop incrHop = getParam(Statement.SEQ_INCR);
            // Double.MAX_VALUE marks a bound that could not be determined
            double from = computeBoundsInformation(fromHop);
            boolean fromKnown = from != Double.MAX_VALUE;
            double to = computeBoundsInformation(toHop);
            boolean toKnown = to != Double.MAX_VALUE;
            double incr = computeBoundsInformation(incrHop);
            boolean incrKnown = incr != Double.MAX_VALUE;
            // an increment of 1 is flipped to -1 for descending (or single-value) ranges
            if (fromKnown && toKnown && incr == 1.0) {
                incr = from >= to ? -1.0 : 1.0;
            }
            if (fromKnown && toKnown && incrKnown) {
                setDim1(UtilFunctions.getSeqLength(from, to, incr, false));
                setDim2(1L);
                _incr = incr;
            }
            if (getDim1() == -1 && getParent().size() == 1) {
                Hop parent = getParent().get(0);
                parent.refreshSizeInformation();
                boolean sizedCtable = HopRewriteUtils.isTernary(parent, Types.OpOp3.CTABLE) && parent.getDim1() >= 0;
                setDim1(sizedCtable ? parent.getDim1() : -1L);
            }
        } else if (_op == Types.OpOpDG.TIME) {
            setDim1(0L);
            setDim2(0L);
            _dataType = Types.DataType.SCALAR;
            _valueType = Types.ValueType.INT64;
        }

        if (_op == Types.OpOpDG.RAND && hasConstantValue(0.0)) {
            setNnz(0L);
        } else if (!dimsKnown() || _sparsity < 0.0) {
            setNnz(-1L);
        } else {
            setNnz((long) (_sparsity * getLength()));
        }
    }

    public HashMap<String, Integer> getParamIndexMap() {
        return _paramIndexMap;
    }

    /** Returns the input hop of the named parameter (NPE if the parameter does not exist). */
    public Hop getParam(String name) {
        return getInput().get(getParamIndex(name));
    }

    /** Returns the input index of the named parameter (NPE if the parameter does not exist). */
    public int getParamIndex(String name) {
        return _paramIndexMap.get(name).intValue();
    }

    /** Same as {@link #getParam(String)}. */
    public Hop getInput(String name) {
        return getInput().get(getParamIndex(name));
    }

    /**
     * Replaces the input of the named parameter.
     *
     * @param name      parameter name
     * @param hop       new input hop
     * @param linkParent if true, this op is also added to {@code hop}'s parents; the old input's
     *                   parent list is not modified here
     */
    public void setInput(String name, Hop hop, boolean linkParent) {
        getInput().set(getParamIndex(name), hop);
        if (linkParent) {
            hop.getParent().add(this);
        }
    }

    /**
     * Returns true if this RAND generates a single constant value: sparsity 1 and either equal
     * literal min/max or the very same min/max hop (e.g. after common subexpression elimination).
     */
    public boolean hasConstantValue() {
        if (_op != Types.OpOpDG.RAND) {
            return false;
        }
        Hop min = getParam("min");
        Hop max = getParam("max");
        Hop sparsity = getParam(DataExpression.RAND_SPARSITY);
        if (min instanceof LiteralOp && max instanceof LiteralOp && sparsity instanceof LiteralOp) {
            try {
                return HopRewriteUtils.getDoubleValue((LiteralOp) sparsity) == 1.0
                    && HopRewriteUtils.getDoubleValue((LiteralOp) min) == HopRewriteUtils.getDoubleValue((LiteralOp) max);
            } catch (Exception e) {
                // getDoubleValue may fail (presumably on non-numeric literals): not a known constant
                return false;
            }
        }
        return min == max && isLiteralValue(sparsity, 1.0);
    }

    /**
     * Returns true if this RAND generates only the given value: literal min and max equal to
     * {@code val}, and, for non-zero {@code val}, a missing or literal-1 sparsity.
     */
    public boolean hasConstantValue(double val) {
        if (_op != Types.OpOpDG.RAND) {
            return false;
        }
        Hop min = getParam("min");
        Hop max = getParam("max");
        boolean constant = isLiteralValue(min, val) && isLiteralValue(max, val);
        // zero output is constant regardless of sparsity; other values require sparsity 1
        if (constant && val != 0.0) {
            Hop sparsity = getParam(DataExpression.RAND_SPARSITY);
            constant = sparsity == null || isLiteralValue(sparsity, 1.0);
        }
        return constant;
    }

    /** Returns true for RAND/SINIT whose seed input is named like {@link #UNSPECIFIED_SEED}. */
    public boolean hasUnspecifiedSeed() {
        if (_op != Types.OpOpDG.RAND && _op != Types.OpOpDG.SINIT) {
            return false;
        }
        return getParam("seed").getName().equals(String.valueOf(UNSPECIFIED_SEED));
    }

    /** Returns the {@code min} input (meaningful as "the constant" if hasConstantValue()). */
    public Hop getConstantValue() {
        return getInput().get(_paramIndexMap.get("min").intValue());
    }

    public void setIncrementValue(double incr) {
        _incr = incr;
    }

    public double getIncrementValue() {
        return _incr;
    }

    /** Generates a seed from {@link System#nanoTime()}. */
    public static long generateRandomSeed() {
        return System.nanoTime();
    }

    /**
     * Shallow copy: generic hop attributes via {@code clone(this, false)} plus all op-specific
     * fields; the parameter index map is copied, the parameter hops are not.
     */
    @Override
    @SuppressWarnings("unchecked")
    public Object clone() throws CloneNotSupportedException {
        DataGenOp ret = new DataGenOp();
        ret.clone(this, false);
        ret._op = _op;
        ret._id = _id;
        ret._sparsity = _sparsity;
        ret._baseDir = _baseDir;
        ret._incr = _incr;
        ret._paramIndexMap = (HashMap<String, Integer>) _paramIndexMap.clone();
        ret._maxNumThreads = _maxNumThreads;
        return ret;
    }

    /**
     * Returns true if {@code that} is an equivalent data-generation op: same op, sparsity, base
     * dir, thread count, and identical (reference-equal) inputs for the same parameter names.
     * TIME ops are never equal, and a RAND/SINIT with unspecified seed is only equal if its min
     * and max are the same hop.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof DataGenOp) || _op == Types.OpOpDG.TIME) {
            return false;
        }
        DataGenOp other = (DataGenOp) that;
        boolean sameAttributes = _op == other._op
            && _sparsity == other._sparsity
            && _baseDir.equals(other._baseDir)
            && _paramIndexMap != null && other._paramIndexMap != null
            // equal sizes keep the comparison symmetric (no strict-subset matches)
            && _paramIndexMap.size() == other._paramIndexMap.size()
            && _maxNumThreads == other._maxNumThreads;
        if (!sameAttributes) {
            return false;
        }
        for (Map.Entry<String, Integer> entry : _paramIndexMap.entrySet()) {
            int otherIndex = other._paramIndexMap.getOrDefault(entry.getKey(), -1).intValue();
            if (otherIndex < 0) {
                return false;
            }
            Hop otherInput = other.getInput().get(otherIndex);
            if (otherInput == null || getInput().get(entry.getValue().intValue()) != otherInput) {
                return false;
            }
        }
        // without a fixed seed, outputs only coincide if the value range is a single shared hop
        if (hasUnspecifiedSeed() && getParam("min") != getParam("max")) {
            return false;
        }
        return true;
    }
}
