package org.apache.sysds.runtime.instructions.cp;

import java.util.Arrays;
import java.util.Random;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.class */
/**
 * CP instruction for the data generation operators: {@code rand} (matrix or tensor output),
 * {@code seq}, {@code sample}, {@code time} and {@code frame} initialization.
 *
 * <p>Instances are created by {@link #parseInstruction(String)}. Parameters that are known at
 * compile time are kept as literals; the others are kept as {@link CPOperand}s and resolved
 * against the {@link ExecutionContext} when the instruction is executed.
 */
public class DataGenCPInstruction extends UnaryCPInstruction {
    private static final Log LOG = LogFactory.getLog(DataGenCPInstruction.class.getName());

    // Operand positions of the seed in the raw instruction string, as used by
    // InstructionUtils.replaceOperand. Raw positions are one greater than the indices returned by
    // getInstructionPartsWithValueType (compare the rows/cols positions used in getLineageItem).
    // rand: parts[7] is the seed -> position 8.
    private static final int SEED_POSITION_RAND = 8;
    // sample: parts = [opcode, range, size, replace, seed, blocksize, out]; parts[4] is the seed,
    // hence position 5 (position 4 holds the replace flag).
    private static final int SEED_POSITION_SAMPLE = 5;

    /** Separator for multi-valued frame data and frame schema strings. */
    private static final String MULTI_VALUE_SEP = "·";

    private final Types.OpOpDG method;
    private final CPOperand rows;
    private final CPOperand cols;
    private final CPOperand dims;
    private final int blocksize;
    // false if min/max are equal non-numeric literals (only valid to fill string/boolean tensors)
    private final boolean minMaxAreDoubles;
    private final String minValueStr;
    private final String maxValueStr;
    private final double minValue;
    private final double maxValue;
    private final double sparsity;
    private final String pdf;
    private final String pdfParams;
    private final String frame_data;
    private final String schema;
    private final long seed;
    // Seed drawn at runtime when no explicit seed (-1) was given; it is shared by data generation
    // (generateSeed) and lineage (getLineageItem) within one execution, then cleared.
    private Long runtimeSeed;
    private final CPOperand seq_from;
    private final CPOperand seq_to;
    private final CPOperand seq_incr;
    private final boolean replace;
    private final int numThreads;

    /**
     * Full constructor; all other constructors delegate here.
     *
     * @throws DMLRuntimeException if min/max are not numeric and differ from each other
     */
    private DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand in, CPOperand out,
            CPOperand rows, CPOperand cols, CPOperand dims, int blocksize, String minValueStr,
            String maxValueStr, double sparsity, long seed, String pdf, String pdfParams, int k,
            CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, boolean replace, String frameData,
            String schema, String opcode, String istr) {
        super(CPInstruction.CPType.Rand, op, in, out, opcode, istr);
        this.method = method;
        this.rows = rows;
        this.cols = cols;
        this.dims = dims;
        this.blocksize = blocksize;
        this.minValueStr = minValueStr;
        this.maxValueStr = maxValueStr;

        double minDouble;
        double maxDouble;
        boolean numeric;
        try {
            // values that are still variable placeholders are resolved later; use -1 meanwhile
            minDouble = minValueStr.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? -1.0d : Double.parseDouble(minValueStr);
            maxDouble = maxValueStr.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? -1.0d : Double.parseDouble(maxValueStr);
            numeric = true;
        } catch (NumberFormatException e) {
            // non-numeric values are only allowed as a single fill value (min == max)
            if (!minValueStr.equals(maxValueStr)) {
                throw new DMLRuntimeException("Rand instruction does not support non numeric Datatypes for range initializations.");
            }
            minDouble = -1.0d;
            maxDouble = -1.0d;
            numeric = false;
        }
        this.minMaxAreDoubles = numeric;
        this.minValue = minDouble;
        this.maxValue = maxDouble;
        this.sparsity = sparsity;
        this.seed = seed;
        this.pdf = pdf;
        this.pdfParams = pdfParams;
        this.numThreads = k;
        this.seq_from = seqFrom;
        this.seq_to = seqTo;
        this.seq_incr = seqIncr;
        this.replace = replace;
        this.frame_data = frameData;
        this.schema = schema;
    }

    /** Constructor for {@code rand}. */
    private DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand in, CPOperand out,
            CPOperand rows, CPOperand cols, CPOperand dims, int blocksize, String minValueStr,
            String maxValueStr, double sparsity, long seed, String pdf, String pdfParams, int k,
            String opcode, String istr) {
        this(op, method, in, out, rows, cols, dims, blocksize, minValueStr, maxValueStr, sparsity,
            seed, pdf, pdfParams, k, null, null, null, false, null, null, opcode, istr);
    }

    /** Constructor for {@code sample}; the population range is passed as max value. */
    private DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand in, CPOperand out,
            CPOperand rows, CPOperand cols, CPOperand dims, int blocksize, String maxValueStr,
            boolean replace, long seed, String opcode, String istr) {
        this(op, method, in, out, rows, cols, dims, blocksize, "0", maxValueStr, 1.0d, seed, null,
            null, 1, null, null, null, replace, null, null, opcode, istr);
    }

    /** Constructor for {@code seq}. */
    private DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand in, CPOperand out,
            CPOperand rows, CPOperand cols, CPOperand dims, int blocksize, CPOperand seqFrom,
            CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) {
        this(op, method, in, out, rows, cols, dims, blocksize, "0", "1", 1.0d, -1L, null, null, 1,
            seqFrom, seqTo, seqIncr, false, null, null, opcode, istr);
    }

    /** Constructor for {@code time}. */
    private DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand out, String opcode, String istr) {
        this(op, method, null, out, null, null, null, 0, "0", "0", DataExpression.DEFAULT_DELIM_FILL_VALUE,
            0L, null, null, 1, null, null, null, false, null, null, opcode, istr);
    }

    /** Constructor for frame initialization from literal data and a schema string. */
    public DataGenCPInstruction(Operator op, Types.OpOpDG method, CPOperand out, CPOperand rows,
            CPOperand cols, String data, String schema, String opcode, String istr) {
        this(op, method, null, out, rows, cols, null, 0, "0", "0", DataExpression.DEFAULT_DELIM_FILL_VALUE,
            0L, null, null, 1, null, null, null, false, data, schema, opcode, istr);
    }

    /** Returns the literal number of rows, or -1 if rows are absent or not a literal. */
    public long getRows() {
        return literalAsLong(this.rows);
    }

    /** Returns the literal number of columns, or -1 if columns are absent or not a literal. */
    public long getCols() {
        return literalAsLong(this.cols);
    }

    /** Returns the name of the tensor dimensions operand, or null if there is none. */
    public String getDims() {
        return this.dims != null ? this.dims.getName() : null;
    }

    public int getBlocksize() {
        return this.blocksize;
    }

    public double getMinValue() {
        return this.minValue;
    }

    public double getMaxValue() {
        return this.maxValue;
    }

    public double getSparsity() {
        return this.sparsity;
    }

    public String getPdf() {
        return this.pdf;
    }

    public String getPdfParams() {
        return this.pdfParams;
    }

    public long getSeed() {
        return this.seed;
    }

    /** True if this generates a dense single column of ones. */
    public boolean isOnesCol() {
        return this.minValue == this.maxValue && this.minValue == 1.0d && this.sparsity == 1.0d && getCols() == 1;
    }

    /** True if this is a dense constant fill ({@code matrix(v, ...)}-style call). */
    public boolean isMatrixCall() {
        return this.minValue == this.maxValue && this.sparsity == 1.0d;
    }

    /** Returns the literal seq start, or -1 if absent or not a literal. */
    public long getFrom() {
        return literalAsLong(this.seq_from);
    }

    /** Returns the literal seq end, or -1 if absent or not a literal. */
    public long getTo() {
        return literalAsLong(this.seq_to);
    }

    /** Returns the literal seq increment, or -1 if absent or not a literal. */
    public long getIncr() {
        return literalAsLong(this.seq_incr);
    }

    /** Parses the operand's literal value as long; -1 if the operand is null or not a literal. */
    private static long literalAsLong(CPOperand operand) {
        if (operand != null && operand.isLiteral()) {
            return UtilFunctions.parseToLong(operand.getName());
        }
        return -1L;
    }

    /**
     * Parses a serialized data generation instruction.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not a supported data generation method
     */
    public static DataGenCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        Types.OpOpDG method = null;
        if (opcode.equalsIgnoreCase(DataGen.RAND_OPCODE)) {
            method = Types.OpOpDG.RAND;
            InstructionUtils.checkNumFields(parts, 10, 11);
        } else if (opcode.equalsIgnoreCase(DataGen.SEQ_OPCODE)) {
            method = Types.OpOpDG.SEQ;
            InstructionUtils.checkNumFields(parts, 7);
        } else if (opcode.equalsIgnoreCase(DataGen.SAMPLE_OPCODE)) {
            method = Types.OpOpDG.SAMPLE;
            InstructionUtils.checkNumFields(parts, 6);
        } else if (opcode.equalsIgnoreCase(DataGen.TIME_OPCODE)) {
            method = Types.OpOpDG.TIME;
            InstructionUtils.checkNumFields(parts, 1);
        } else if (opcode.equalsIgnoreCase("frame")) {
            method = Types.OpOpDG.FRAMEINIT;
            InstructionUtils.checkNumFields(parts, 5);
        }

        // the output operand is always the last field
        CPOperand out = new CPOperand(parts[parts.length - 1]);

        if (method == Types.OpOpDG.RAND) {
            return parseRandInstruction(parts, out, opcode, str);
        }
        if (method == Types.OpOpDG.SEQ) {
            // parts: [opcode, rows, cols, blocksize, from, to, incr, out]
            return new DataGenCPInstruction(null, method, null, out, null, null, null,
                Integer.parseInt(parts[3]), new CPOperand(parts[4]), new CPOperand(parts[5]),
                new CPOperand(parts[6]), opcode, str);
        }
        if (method == Types.OpOpDG.FRAMEINIT) {
            // parts: [opcode, data, rows, cols, schema, out]
            return new DataGenCPInstruction(null, method, out, new CPOperand(parts[2]),
                new CPOperand(parts[3]), parts[1], parts[4], opcode, str);
        }
        if (method == Types.OpOpDG.SAMPLE) {
            // parts: [opcode, range, size, replace, seed, blocksize, out]; output is a column vector
            boolean withReplacement = !parts[3].contains(Lop.VARIABLE_NAME_PLACEHOLDER)
                && Boolean.parseBoolean(parts[3]);
            return new DataGenCPInstruction(null, method, null, out, new CPOperand(parts[2]),
                new CPOperand("1", Types.ValueType.INT64, Types.DataType.SCALAR), null,
                Integer.parseInt(parts[5]), parts[1], withReplacement, Long.parseLong(parts[4]),
                opcode, str);
        }
        if (method == Types.OpOpDG.TIME) {
            return new DataGenCPInstruction(null, method, out, opcode, str);
        }
        throw new DMLRuntimeException("Unrecognized data generation method: " + method);
    }

    /**
     * Parses the fields of a {@code rand} instruction. A matrix rand has 12 parts
     * ({@code rows, cols} at 1-2); a tensor rand has a single {@code dims} operand at 1.
     */
    private static DataGenCPInstruction parseRandInstruction(String[] parts, CPOperand out, String opcode, String str) {
        CPOperand rowsOp = null;
        CPOperand colsOp = null;
        CPOperand dimsOp = null;
        int base; // index of the blocksize field
        if (parts.length == 12) {
            rowsOp = new CPOperand(parts[1]);
            colsOp = new CPOperand(parts[2]);
            base = 3;
        } else {
            dimsOp = new CPOperand(parts[1]);
            base = 2;
        }
        int blocksize = Integer.parseInt(parts[base]);
        String minValueStr = parts[base + 1];
        String maxValueStr = parts[base + 2];
        double sparsity = parts[base + 3].contains(Lop.VARIABLE_NAME_PLACEHOLDER)
            ? -1.0d : Double.parseDouble(parts[base + 3]);
        long seed = parts[base + 4].contains(Lop.VARIABLE_NAME_PLACEHOLDER)
            ? -1L : Long.parseLong(parts[base + 4]);
        String pdf = parts[base + 5];
        String pdfParams = parts[base + 6].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? null : parts[base + 6];
        int k = Integer.parseInt(parts[base + 7]);
        return new DataGenCPInstruction(null, Types.OpOpDG.RAND, null, out, rowsOp, colsOp, dimsOp,
            blocksize, minValueStr, maxValueStr, sparsity, seed, pdf, pdfParams, k, opcode, str);
    }

    @Override // org.apache.sysds.runtime.instructions.cp.CPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        try {
            CacheBlock out = null;
            IntObject scalarOut = null;
            if (this.method == Types.OpOpDG.RAND) {
                out = processRandInstruction(executionContext);
            } else if (this.method == Types.OpOpDG.SEQ) {
                out = processSeqInstruction(executionContext);
            } else if (this.method == Types.OpOpDG.SAMPLE) {
                out = processSampleInstruction(executionContext);
            } else if (this.method == Types.OpOpDG.TIME) {
                scalarOut = new IntObject(System.nanoTime());
            } else if (this.method == Types.OpOpDG.FRAMEINIT) {
                out = processFrameInitInstruction(executionContext);
            }

            if (this.output.isScalar()) {
                executionContext.setScalarOutput(this.output.getName(), scalarOut);
            } else {
                setCacheBlockOutput(executionContext, out);
            }
        } finally {
            // Clear the runtime seed only after the output and its lineage (computed in
            // setCacheBlockOutput) are set, so the lineage records the seed that generated the
            // data. The next execution then draws a fresh seed.
            this.runtimeSeed = null;
        }
    }

    /** Generates {@code seq(from, to, incr)}; the increment is adjusted by LibMatrixDatagen. */
    private CacheBlock processSeqInstruction(ExecutionContext ec) {
        double from = ec.getScalarInput(this.seq_from).getDoubleValue();
        double to = ec.getScalarInput(this.seq_to).getDoubleValue();
        double incr = LibMatrixDatagen.updateSeqIncr(from, to, ec.getScalarInput(this.seq_incr).getDoubleValue());
        if (LOG.isTraceEnabled()) {
            LOG.trace("Process DataGenCPInstruction seq with seqFrom=" + from + ", seqTo=" + to + ", seqIncr=" + incr);
        }
        return MatrixBlock.seqOperations(from, to, incr);
    }

    /**
     * Draws {@code size} values from the population range {@code [1, maxValue]}-style range.
     *
     * @throws DMLRuntimeException if more values than the range holds are requested without replacement
     */
    private CacheBlock processSampleInstruction(ExecutionContext ec) {
        long size = ec.getScalarInput(this.rows).getLongValue();
        long range = UtilFunctions.toLong(this.maxValue);
        checkValidDimensions(size, 1L);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Process DataGenCPInstruction sample with range=" + range + ", size=" + size + ", replace=" + this.replace + ", seed=" + this.seed);
        }
        if (range < size && !this.replace) {
            throw new DMLRuntimeException("Sample (size=" + size + ") larger than population (size=" + range + ") can only be generated with replacement.");
        }
        return MatrixBlock.sampleOperations(range, (int) size, this.replace, this.seed);
    }

    /**
     * Builds a frame from the schema and literal data. The data is either empty (random frame),
     * a single value for all cells, one row of values repeated for every row, or
     * {@code rows * cols} values laid out row by row.
     */
    private CacheBlock processFrameInitInstruction(ExecutionContext ec) {
        int lrows = (int) ec.getScalarInput(this.rows).getLongValue();
        int lcols = (int) ec.getScalarInput(this.cols).getLongValue();
        Types.ValueType[] vt = parseFrameSchema(lcols);
        if (vt.length != lcols) {
            throw new DMLRuntimeException("schema-dimension mismatch");
        }
        if (this.frame_data.equals("")) {
            return UtilFunctions.generateRandomFrameBlock(lrows, lcols, vt, new Random(10L));
        }

        String[] data = this.frame_data.split(MULTI_VALUE_SEP);
        // guard against integer division by zero for zero-row frames
        int rowLength = (lrows == 0) ? 0 : data.length / lrows;
        if (data.length != vt.length && data.length > 1 && rowLength != vt.length) {
            throw new DMLRuntimeException("data values should be equal to number of columns, or a single values for all columns");
        }

        FrameBlock frame = new FrameBlock(vt);
        if (data.length <= 1) {
            // single value: replicate the whole data string into every cell
            String[] row = new String[lcols];
            Arrays.fill(row, this.frame_data);
            for (int r = 0; r < lrows; r++) {
                frame.appendRow(row);
            }
        } else if (rowLength == vt.length) {
            // row-major data: slice one row of lcols values per frame row
            for (int r = 0; r < lrows; r++) {
                frame.appendRow((String[]) ArrayUtils.subarray(data, r * lcols, (r + 1) * lcols));
            }
        } else {
            // one row of values (data.length == cols, checked above), repeated for every row
            for (int r = 0; r < lrows; r++) {
                frame.appendRow(data);
            }
        }
        return frame;
    }

    /** Expands the schema string into one value type per column. */
    private Types.ValueType[] parseFrameSchema(int lcols) {
        String[] parts = this.schema.split(MULTI_VALUE_SEP);
        if (parts[0].equals(DataExpression.DEFAULT_SCHEMAPARAM)) {
            return UtilFunctions.nCopies(lcols, Types.ValueType.STRING);
        }
        if (parts.length == 1 && lcols > 1) {
            // a single type applies to all columns
            return UtilFunctions.nCopies(lcols, Types.ValueType.valueOf(parts[0]));
        }
        return UtilFunctions.stringToValueType(parts);
    }

    /** Generates the rand output as tensor or matrix depending on the output operand. */
    private CacheBlock processRandInstruction(ExecutionContext executionContext) {
        long lSeed = generateSeed();
        return this.output.isTensor()
            ? processRandInstructionTensor(executionContext, lSeed)
            : processRandInstructionMatrix(executionContext, lSeed);
    }

    private CacheBlock processRandInstructionMatrix(ExecutionContext executionContext, long lSeed) {
        long lrows = executionContext.getScalarInput(this.rows).getLongValue();
        long lcols = executionContext.getScalarInput(this.cols).getLongValue();
        checkValidDimensions(lrows, lcols);
        boolean constant = this.minValue == this.maxValue && this.sparsity == 1.0d;
        // with compression enabled, a constant matrix with more than 1000 rows and at least twice
        // as many rows as columns is created directly as a compressed constant block
        if (ConfigurationManager.isCompressionEnabled() && constant
                && lrows > 1000 && lcols > 0 && lrows / lcols > 1) {
            return CompressedMatrixBlockFactory.createConstant((int) lrows, (int) lcols, this.minValue);
        }
        return MatrixBlock.randOperations(getGenerator(lrows, lcols), lSeed, this.numThreads);
    }

    private CacheBlock processRandInstructionTensor(ExecutionContext executionContext, long lSeed) {
        TensorBlock tensor = new TensorBlock(this.output.getValueType(),
            DataConverter.getTensorDimensions(executionContext, this.dims)).allocateBlock();
        if (!this.minValueStr.equals(this.maxValueStr)) {
            // random values: generate a (dim0 x product of remaining dims) matrix and copy it in
            long nrows = tensor.getDim(0);
            long ncols = 1;
            for (int d = 1; d < tensor.getNumDims(); d++) {
                ncols *= tensor.getDim(d);
            }
            tensor.set(MatrixBlock.randOperations(getGenerator(nrows, ncols), lSeed, this.numThreads));
        } else if (this.minMaxAreDoubles) {
            tensor.set(Double.valueOf(this.minValue));
        } else {
            if (this.output.getValueType() != Types.ValueType.STRING && this.output.getValueType() != Types.ValueType.BOOLEAN) {
                throw new DMLRuntimeException("Rand instruction cannot fill numeric tensor with non numeric elements.");
            }
            tensor.set(this.minValueStr);
        }
        return tensor;
    }

    /**
     * Returns the explicit seed, or (if the seed is -1) the runtime seed, drawing one if needed.
     */
    private long generateSeed() {
        long lSeed = this.seed;
        if (lSeed == -1) {
            if (this.runtimeSeed == null) {
                this.runtimeSeed = Long.valueOf(DataGenOp.generateRandomSeed());
            }
            lSeed = this.runtimeSeed.longValue();
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Process DataGenCPInstruction rand with seed = " + lSeed + ".");
        }
        return lSeed;
    }

    private RandomMatrixGenerator getGenerator(long nrows, long ncols) {
        return LibMatrixDatagen.createRandomMatrixGenerator(this.pdf, (int) nrows, (int) ncols,
            this.blocksize, this.sparsity, this.minValue, this.maxValue, this.pdfParams);
    }

    /**
     * Binds the generated block to the output variable. Matrix outputs above the caching
     * threshold also get their lineage item attached.
     */
    private void setCacheBlockOutput(ExecutionContext executionContext, CacheBlock out) {
        if (this.output.isMatrix()) {
            MatrixBlock mb = (MatrixBlock) out;
            if (out.getInMemorySize() < OptimizerUtils.SAFE_REP_CHANGE_THRES) {
                mb.examSparsity();
            }
            LineageItem lineage = CacheableData.isBelowCachingThreshold(out)
                ? null : getLineageItem(executionContext).getValue();
            executionContext.setMatrixOutputAndLineage(this.output.getName(), mb, lineage);
        } else if (this.output.isTensor()) {
            executionContext.setTensorOutput(this.output.getName(), (TensorBlock) out);
        } else if (this.output.isFrame()) {
            executionContext.setFrameOutput(this.output.getName(), (FrameBlock) out);
        }
    }

    /** Rejects dimensions above {@link OptimizerUtils#MAX_NUMCELLS_CP_DENSE}. */
    private static void checkValidDimensions(long nrows, long ncols) {
        if (nrows > OptimizerUtils.MAX_NUMCELLS_CP_DENSE || ncols > OptimizerUtils.MAX_NUMCELLS_CP_DENSE) {
            throw new DMLRuntimeException("DataGenCPInstruction does not support dimensions larger than integer: rows=" + nrows + ", cols=" + ncols + ".");
        }
    }

    /**
     * Builds the lineage item: the instruction string with operand names normalized and
     * non-literal operands (and an unspecified seed) replaced by their runtime values.
     */
    @Override // org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        String inst = this.instString;
        switch (this.method) {
            case RAND:
            case SAMPLE: {
                if (getSeed() == -1) {
                    // record the runtime seed; constant outputs use -1 as they do not depend on it
                    if (this.runtimeSeed == null) {
                        this.runtimeSeed = Long.valueOf((this.minValue == this.maxValue && this.sparsity == 1.0d)
                            ? -1L : DataGenOp.generateRandomSeed());
                    }
                    int seedPosition = (this.method == Types.OpOpDG.RAND) ? SEED_POSITION_RAND : SEED_POSITION_SAMPLE;
                    inst = InstructionUtils.replaceOperand(inst, seedPosition, String.valueOf(this.runtimeSeed));
                }
                inst = InstructionUtils.replaceOperandName(inst);
                if (this.method == Types.OpOpDG.RAND) {
                    inst = replaceNonLiteral(inst, this.rows, 2, executionContext);
                    inst = replaceNonLiteral(inst, this.cols, 3, executionContext);
                } else {
                    inst = replaceNonLiteral(inst, this.rows, 3, executionContext);
                }
                break;
            }
            case SEQ: {
                inst = InstructionUtils.replaceOperandName(inst);
                inst = replaceNonLiteral(inst, this.seq_from, 5, executionContext);
                inst = replaceNonLiteral(inst, this.seq_to, 6, executionContext);
                inst = replaceNonLiteral(inst, this.seq_incr, 7, executionContext);
                break;
            }
            case FRAMEINIT: {
                inst = InstructionUtils.replaceOperandName(inst);
                CPOperand dataLiteral = new CPOperand(this.frame_data, Types.ValueType.STRING, Types.DataType.SCALAR, true);
                inst = InstructionUtils.replaceOperand(inst, 2, dataLiteral.getLineageLiteral());
                inst = replaceNonLiteral(inst, this.rows, 3, executionContext);
                inst = replaceNonLiteral(inst, this.cols, 4, executionContext);
                if (!this.schema.equalsIgnoreCase(DataExpression.DEFAULT_SCHEMAPARAM)) {
                    CPOperand schemaLiteral = new CPOperand(this.schema, Types.ValueType.STRING, Types.DataType.SCALAR, true);
                    inst = InstructionUtils.replaceOperand(inst, 5, schemaLiteral.getLineageLiteral());
                }
                break;
            }
            case TIME:
                break;
            default:
                throw new DMLRuntimeException("Unsupported datagen op: " + this.method);
        }
        return Pair.of(this.output.getName(), new LineageItem(inst, getOpcode()));
    }

    /** Replaces the operand at {@code position} with its runtime value if it is not a literal. */
    private static String replaceNonLiteral(String inst, CPOperand operand, int position, ExecutionContext executionContext) {
        if (!operand.isLiteral()) {
            inst = InstructionUtils.replaceOperand(inst, position,
                new CPOperand(executionContext.getScalarInput(operand)).getLineageLiteral());
        }
        return inst;
    }
}
