package org.apache.sysds.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.DataConverter;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.class */
/**
 * Control-program instruction for parameterized builtin functions, i.e. builtins whose
 * operands are passed as named {@code name=value} parameters: {@code cdf}, {@code invcdf},
 * {@code groupedagg}, {@code rmempty}, {@code replace}, {@code rexpand}, {@code lowertri},
 * {@code uppertri}, {@code tokenize}, {@code transformapply}, {@code transformdecode},
 * {@code transformcolmap}, {@code transformmeta}, {@code toString} and {@code nvlist}.
 *
 * <p>The {@code paramserv} opcode is also recognized by {@link #parseInstruction(String)},
 * but is instantiated as a {@link ParamservBuiltinCPInstruction}.
 */
public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
    private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());

    // Defaults used by toString(...) for parameters the caller did not specify.
    private static final int TOSTRING_MAXROWS = 100;
    private static final int TOSTRING_MAXCOLS = 100;
    private static final int TOSTRING_DECIMAL = 3;
    private static final boolean TOSTRING_SPARSE = false;
    private static final String TOSTRING_SEPARATOR = " ";
    private static final String TOSTRING_LINESEPARATOR = "\n";

    /** Named instruction parameters, in the order they appear in the instruction string. */
    protected final LinkedHashMap<String, String> params;

    /**
     * Creates a parameterized builtin instruction.
     *
     * @param op        operator used by the opcode, or {@code null} if the opcode needs none
     * @param paramsMap named parameters; stored by reference, not copied
     * @param out       output operand
     * @param opcode    instruction opcode
     * @param istr      full instruction string
     */
    public ParameterizedBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap,
            CPOperand out, String opcode, String istr) {
        super(CPInstruction.CPType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
        this.params = paramsMap;
    }

    /** Returns the parameter map backing this instruction (the live map, not a copy). */
    public HashMap<String, String> getParameterMap() {
        return params;
    }

    /** Returns the value of the named parameter, or {@code null} if it is absent. */
    public String getParam(String key) {
        return getParameterMap().get(key);
    }

    /**
     * Parses the {@code name=value} operands of an instruction into an ordered map.
     * The first element (opcode) and the last element (output operand) are skipped.
     * Each operand is split at its first {@code '='} only, so values that themselves
     * contain {@code '='}, or that are empty, are preserved instead of being truncated
     * or causing an index error.
     *
     * @param parts instruction parts as returned by
     *              {@code InstructionUtils.getInstructionPartsWithValueType}
     * @return parameter map in instruction order
     * @throws DMLRuntimeException if an operand does not contain {@code '='}
     */
    public static LinkedHashMap<String, String> constructParameterMap(String[] parts) {
        LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
        for (int i = 1; i < parts.length - 1; i++) {
            // limit 2: split only at the first '=' and keep an empty trailing value
            String[] nameValue = parts[i].split("=", 2);
            if (nameValue.length != 2) {
                throw new DMLRuntimeException("Invalid parameter, expected name=value: " + parts[i]);
            }
            paramMap.put(nameValue[0], nameValue[1]);
        }
        return paramMap;
    }

    /**
     * Parses a CP instruction string into a parameterized builtin instruction.
     *
     * @param str instruction string
     * @return the parsed instruction; a {@link ParamservBuiltinCPInstruction} for {@code paramserv}
     * @throws DMLRuntimeException for unknown opcodes or missing mandatory parameters
     */
    public static ParameterizedBuiltinCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand out = new CPOperand(parts[parts.length - 1]);
        LinkedHashMap<String, String> paramsMap = constructParameterMap(parts);

        if (opcode.equalsIgnoreCase("cdf") || opcode.equalsIgnoreCase("invcdf")) {
            String dist = paramsMap.get("dist");
            if (dist == null) {
                throw new DMLRuntimeException("Invalid distribution: " + str);
            }
            SimpleOperator op = new SimpleOperator(ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode, dist));
            return new ParameterizedBuiltinCPInstruction(op, paramsMap, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("groupedagg")) {
            String fn = paramsMap.get(Statement.GAGG_FN);
            if (fn == null) {
                throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
            }
            String order = paramsMap.get(Statement.GAGG_FN_CM_ORDER);
            if (fn.equalsIgnoreCase(Statement.GAGG_FN_CM) && order == null) {
                throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
            }
            return new ParameterizedBuiltinCPInstruction(
                InstructionUtils.parseGroupedAggOperator(fn, order), paramsMap, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rmempty") || opcode.equalsIgnoreCase("replace")
                || opcode.equalsIgnoreCase("rexpand") || opcode.equalsIgnoreCase("lowertri")
                || opcode.equalsIgnoreCase("uppertri")) {
            SimpleOperator op = new SimpleOperator(ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode));
            return new ParameterizedBuiltinCPInstruction(op, paramsMap, out, opcode, str);
        }
        if (opcode.equals("transformapply") || opcode.equals("transformdecode")
                || opcode.equals("transformcolmap") || opcode.equals("transformmeta")
                || opcode.equals("tokenize") || opcode.equals("toString") || opcode.equals("nvlist")) {
            // these opcodes are executed directly from the parameter map and need no operator
            return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
        }
        if ("paramserv".equals(opcode)) {
            return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str);
        }
        throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
    }

    /**
     * Executes the instruction: reads the inputs named in the parameter map from the
     * execution context, computes the result and binds it to the output variable.
     *
     * @param ec execution context holding inputs and receiving the output
     * @throws DMLRuntimeException for unknown opcodes or invalid parameters
     */
    @Override // org.apache.sysds.runtime.instructions.cp.CPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        String opcode = getOpcode();
        if (opcode.equalsIgnoreCase("cdf") || opcode.equalsIgnoreCase("invcdf")) {
            // the distribution function object was bound at parse time; it reads its arguments from params
            ec.setScalarOutput(output.getName(), new DoubleObject(((SimpleOperator) _optr).fn.execute(params)));
        } else if (opcode.equalsIgnoreCase("groupedagg")) {
            processGroupedAggregate(ec);
        } else if (opcode.equalsIgnoreCase("rmempty")) {
            processRemoveEmpty(ec);
        } else if (opcode.equalsIgnoreCase("replace")) {
            MatrixBlock target = ec.getMatrixInput(params.get("target"));
            double pattern = Double.parseDouble(params.get("pattern"));
            double replacement = Double.parseDouble(params.get("replacement"));
            ec.setMatrixOutput(output.getName(),
                target.replaceOperations((MatrixValue) new MatrixBlock(), pattern, replacement));
            ec.releaseMatrixInput(params.get("target"));
        } else if (opcode.equals("lowertri") || opcode.equals("uppertri")) {
            MatrixBlock target = ec.getMatrixInput(params.get("target"));
            boolean lower = opcode.equals("lowertri");
            boolean diag = Boolean.parseBoolean(params.get("diag"));
            boolean values = Boolean.parseBoolean(params.get("values"));
            ec.setMatrixOutput(output.getName(), target.extractTriangular(new MatrixBlock(), lower, diag, values));
            ec.releaseMatrixInput(params.get("target"));
        } else if (opcode.equalsIgnoreCase("rexpand")) {
            MatrixBlock target = ec.getMatrixInput(params.get("target"));
            double maxVal = Double.parseDouble(params.get("max"));
            boolean dirRows = params.get("dir").equals("rows");
            boolean cast = Boolean.parseBoolean(params.get("cast"));
            boolean ignore = Boolean.parseBoolean(params.get("ignore"));
            int k = Integer.parseInt(params.get(Statement.PS_PARALLELISM));
            ec.setMatrixOutput(output.getName(),
                target.rexpandOperations(new MatrixBlock(), maxVal, dirRows, cast, ignore, k));
            ec.releaseMatrixInput(params.get("target"));
        } else if (opcode.equalsIgnoreCase("tokenize")) {
            FrameBlock input = ec.getFrameInput(params.get("target"));
            Tokenizer tokenizer = TokenizerFactory.createTokenizer(
                params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC),
                Integer.parseInt(params.get("max_tokens")));
            ec.setFrameOutput(output.getName(), tokenizer.tokenize(input, new FrameBlock(tokenizer.getSchema())));
            ec.releaseFrameInput(params.get("target"));
        } else if (opcode.equalsIgnoreCase("transformapply")) {
            FrameBlock input = ec.getFrameInput(params.get("target"));
            FrameBlock meta = ec.getFrameInput(params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD2));
            ec.setMatrixOutput(output.getName(), EncoderFactory.createEncoder(
                params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC),
                input.getColumnNames(), input.getNumColumns(), meta).apply(input));
            ec.releaseFrameInput(params.get("target"));
            ec.releaseFrameInput(params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD2));
        } else if (opcode.equalsIgnoreCase("transformdecode")) {
            MatrixBlock input = ec.getMatrixInput(params.get("target"));
            FrameBlock meta = ec.getFrameInput(params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD2));
            String[] colNames = meta.getColumnNames();
            Decoder decoder = DecoderFactory.createDecoder(
                params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC),
                colNames, null, meta, input.getNumColumns());
            FrameBlock decoded = decoder.decode(input, new FrameBlock(decoder.getSchema()));
            // the decoded frame may have fewer columns than the meta frame; keep the matching prefix of names
            decoded.setColumnNames(Arrays.copyOfRange(colNames, 0, decoded.getNumColumns()));
            ec.setFrameOutput(output.getName(), decoded);
            ec.releaseMatrixInput(params.get("target"));
            ec.releaseFrameInput(params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD2));
        } else if (opcode.equalsIgnoreCase("transformcolmap")) {
            FrameBlock input = ec.getFrameInput(params.get("target"));
            // no meta frame is passed for the column mapping
            ec.setMatrixOutput(output.getName(), EncoderFactory.createEncoder(
                params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC),
                input.getColumnNames(), input.getNumColumns(), (FrameBlock) null).getColMapping(input));
            ec.releaseFrameInput(params.get("target"));
        } else if (opcode.equalsIgnoreCase("transformmeta")) {
            try {
                ec.setFrameOutput(output.getName(), TfMetaUtils.readTransformMetaDataFromFile(
                    params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC),
                    params.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD),
                    params.getOrDefault(DataExpression.DELIM_DELIMITER, ",")));
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        } else if (opcode.equalsIgnoreCase("toString")) {
            processToString(ec);
        } else if (opcode.equals("nvlist")) {
            processNvList(ec);
        } else {
            throw new DMLRuntimeException("Unknown opcode : " + opcode);
        }
    }

    /** Executes {@code groupedagg}: aggregates {@code target} per group given by {@code groups}. */
    private void processGroupedAggregate(ExecutionContext ec) {
        String weightsName = params.get(Statement.GAGG_WEIGHTS);
        MatrixBlock target = ec.getMatrixInput(params.get("target"));
        MatrixBlock groups = ec.getMatrixInput(params.get(Statement.GAGG_GROUPS));
        MatrixBlock weights = (weightsName != null) ? ec.getMatrixInput(weightsName) : null;

        // -1 is passed when the number of groups was not specified
        int numGroups = -1;
        String numGroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
        if (numGroupsStr != null) {
            numGroups = (int) Double.parseDouble(numGroupsStr);
        }
        int k = Integer.parseInt(params.get(Statement.PS_PARALLELISM));

        ec.setMatrixOutput(output.getName(),
            groups.groupedAggOperations(target, weights, new MatrixBlock(), numGroups, _optr, k));

        ec.releaseMatrixInput(params.get("target"));
        ec.releaseMatrixInput(params.get(Statement.GAGG_GROUPS));
        if (weightsName != null) {
            ec.releaseMatrixInput(weightsName);
        }
    }

    /**
     * Executes {@code rmempty}: removes empty rows or columns of {@code target}, optionally
     * restricted by a {@code select} vector.
     *
     * @throws DMLRuntimeException if {@code margin} is missing or not {@code rows}/{@code cols}
     */
    private void processRemoveEmpty(ExecutionContext ec) {
        String margin = params.get("margin");
        // null-safe: a missing margin is reported as unsupported instead of failing with an NPE
        if (!"rows".equals(margin) && !"cols".equals(margin)) {
            throw new DMLRuntimeException("Unsupported margin identifier '" + margin + "'.");
        }
        boolean hasSelect = params.containsKey("select");
        MatrixBlock target = ec.getMatrixInput(params.get("target"));
        // Boolean.parseBoolean is case-insensitive and returns false for null
        boolean emptyReturn = Boolean.parseBoolean(params.get("empty.return"));
        MatrixBlock select = hasSelect ? ec.getMatrixInput(params.get("select")) : null;

        ec.setMatrixOutput(output.getName(),
            target.removeEmptyOperations(new MatrixBlock(), margin.equals("rows"), emptyReturn, select));

        ec.releaseMatrixInput(params.get("target"));
        if (hasSelect) {
            ec.releaseMatrixInput(params.get("select"));
        }
    }

    /**
     * Executes {@code toString}: renders a matrix, tensor, frame or list as a string scalar,
     * using the {@code TOSTRING_*} defaults for any formatting parameter not given.
     *
     * @throws DMLRuntimeException if the target is of any other type
     */
    private void processToString(ExecutionContext ec) {
        int rows = getParam("rows") != null ? Integer.parseInt(getParam("rows")) : TOSTRING_MAXROWS;
        int cols = getParam("cols") != null ? Integer.parseInt(getParam("cols")) : TOSTRING_MAXCOLS;
        int decimal = getParam("decimal") != null ? Integer.parseInt(getParam("decimal")) : TOSTRING_DECIMAL;
        boolean sparse = getParam(DataExpression.DELIM_SPARSE) != null
            ? Boolean.parseBoolean(getParam(DataExpression.DELIM_SPARSE)) : TOSTRING_SPARSE;
        String separator = getParam(DataExpression.DELIM_DELIMITER) != null
            ? getParam(DataExpression.DELIM_DELIMITER) : TOSTRING_SEPARATOR;
        String lineSeparator = getParam("linesep") != null ? getParam("linesep") : TOSTRING_LINESEPARATOR;

        Data data = ec.getVariable(getParam("target"));
        String out;
        if (data instanceof MatrixObject) {
            MatrixBlock mb = ((MatrixObject) data).acquireRead();
            warnOnTruncation(mb, rows, cols);
            out = DataConverter.toString(mb, sparse, separator, lineSeparator, rows, cols, decimal);
        } else if (data instanceof TensorObject) {
            TensorBlock tb = ((TensorObject) data).acquireRead();
            warnOnTruncation(tb, rows, cols);
            out = DataConverter.toString(tb, sparse, separator, lineSeparator, "[", "]", rows, cols, decimal);
        } else if (data instanceof FrameObject) {
            FrameBlock fb = ((FrameObject) data).acquireRead();
            warnOnTruncation(fb, rows, cols);
            out = DataConverter.toString(fb, sparse, separator, lineSeparator, rows, cols, decimal);
        } else if (data instanceof ListObject) {
            out = DataConverter.toString((ListObject) data, rows, cols, sparse, separator, lineSeparator, rows, cols, decimal);
        } else {
            throw new DMLRuntimeException("toString only converts matrix, tensors, lists or frames to string");
        }
        // lists were not acquired via acquireRead above, so only the other types are released
        if (!(data instanceof ListObject)) {
            ec.releaseCacheableData(getParam("target"));
        }
        ec.setScalarOutput(output.getName(), new StringObject(out));
    }

    /**
     * Executes {@code nvlist}: builds a named list from all parameters. A value naming a live
     * variable contributes that variable; any other value contributes a string object.
     */
    private void processNvList(ExecutionContext ec) {
        List<Data> values = new ArrayList<>(params.size());
        for (String value : params.values()) {
            values.add(ec.containsVariable(value) ? ec.getVariable(value) : new StringObject(value));
        }
        ArrayList<String> names = new ArrayList<>(params.keySet());
        ListObject list = DMLScript.LINEAGE
            ? new ListObject(values, names, Arrays.asList(LineageItemUtils.getLineage(ec, getNvListLineageOperands(ec))))
            : new ListObject(values, names);
        list.deriveAndSetStatusFromData();
        ec.setVariable(output.getName(), list);
    }

    /**
     * Builds one lineage operand per {@code nvlist} parameter, in parameter order: a variable
     * operand when the value names a live variable, otherwise a string literal.
     */
    private CPOperand[] getNvListLineageOperands(ExecutionContext ec) {
        List<String> names = new ArrayList<>(params.keySet());
        CPOperand[] operands = new CPOperand[names.size()];
        for (int i = 0; i < operands.length; i++) {
            String name = names.get(i);
            String value = params.get(name);
            operands[i] = ec.containsVariable(value) ? new CPOperand(name, ec.getVariable(value)) : getStringLiteral(name);
        }
        return operands;
    }

    /**
     * Logs a warning when toString output is truncated to the default row/column limits
     * (only if the caller did not explicitly pass {@code rows}/{@code cols}).
     */
    private void warnOnTruncation(CacheBlock cb, int rows, int cols) {
        boolean truncRows = getParam("rows") == null && cb.getNumRows() > rows;
        boolean truncCols = getParam("cols") == null && cb.getNumColumns() > cols;
        if (truncRows || truncCols) {
            LOG.warn("Truncating " + cb.getClass().getSimpleName() + " of size " + cb.getNumRows() + "x"
                + cb.getNumColumns() + " to " + rows + "x" + cols + ". Use toString(X, rows=..., cols=...) if necessary.");
        }
    }

    /** Tensor variant of {@link #warnOnTruncation(CacheBlock, int, int)}; reports all dimensions. */
    private void warnOnTruncation(TensorBlock tb, int rows, int cols) {
        boolean truncRows = getParam("rows") == null && tb.getDim(0) > rows;
        boolean truncCols = getParam("cols") == null && tb.getDim(1) > cols;
        if (!truncRows && !truncCols) {
            return;
        }
        StringBuilder dims = new StringBuilder();
        for (int d = 0; d < tb.getNumDims(); d++) {
            if (d > 0) {
                dims.append("x");
            }
            dims.append(tb.getDim(d));
        }
        LOG.warn("Truncating " + tb.getClass().getSimpleName() + " of size " + dims.toString() + " to " + rows
            + "x" + cols + ". Use toString(X, rows=..., cols=...) if necessary.");
    }

    /**
     * Returns the lineage item of this instruction's output.
     *
     * @param ec execution context used to resolve input lineage
     * @return pair of output name and its lineage item
     * @throws DMLRuntimeException if lineage tracing is not supported for the opcode
     */
    @Override // org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        String opcode = getOpcode();
        CPOperand[] inputs;
        if (opcode.equalsIgnoreCase("groupedagg")) {
            CPOperand groups = new CPOperand(params.get(Statement.GAGG_GROUPS), Types.ValueType.FP64, Types.DataType.MATRIX);
            // "-1" stands in for absent optional inputs so the lineage arity stays fixed
            CPOperand weights = new CPOperand(params.getOrDefault(Statement.GAGG_WEIGHTS, String.valueOf(-1)),
                Types.ValueType.FP64, Types.DataType.MATRIX);
            CPOperand numGroups = new CPOperand(params.getOrDefault(Statement.GAGG_NUM_GROUPS, String.valueOf(-1)),
                Types.ValueType.INT64, Types.DataType.SCALAR, true);
            inputs = new CPOperand[] {getTargetOperand(), groups, weights, getStringLiteral(Statement.GAGG_FN), numGroups};
        } else if (opcode.equalsIgnoreCase("rmempty")) {
            CPOperand select = new CPOperand(params.getOrDefault("select", String.valueOf(-1)),
                Types.ValueType.FP64, Types.DataType.MATRIX);
            inputs = new CPOperand[] {getTargetOperand(), getStringLiteral("margin"), select};
        } else if (opcode.equalsIgnoreCase("replace")) {
            inputs = new CPOperand[] {getTargetOperand(), getFP64Literal("pattern"), getFP64Literal("replacement")};
        } else if (opcode.equalsIgnoreCase("rexpand")) {
            inputs = new CPOperand[] {getTargetOperand(), getFP64Literal("max"), getStringLiteral("dir"),
                getBoolLiteral("cast"), getBoolLiteral("ignore")};
        } else if (opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) {
            // NOTE(review): there appears to be no "lowertri" parameter, so this literal is presumably
            // built from a null value; lower vs. upper is already captured by the opcode. Confirm before
            // changing, since the operand list is part of the lineage item.
            inputs = new CPOperand[] {getTargetOperand(), getBoolLiteral("lowertri"), getBoolLiteral("diag"),
                getBoolLiteral("values")};
        } else if (opcode.equalsIgnoreCase("transformdecode") || opcode.equalsIgnoreCase("transformapply")) {
            CPOperand target = new CPOperand(params.get("target"), Types.ValueType.FP64, Types.DataType.FRAME);
            inputs = new CPOperand[] {target,
                getLiteral(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD2, Types.ValueType.UNKNOWN, Types.DataType.FRAME),
                getStringLiteral(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_SPEC)};
        } else if (opcode.equalsIgnoreCase("nvlist")) {
            inputs = getNvListLineageOperands(ec);
        } else {
            throw new DMLRuntimeException("Unsupported lineage tracing for: " + opcode);
        }
        return Pair.of(output.getName(), new LineageItem(opcode, LineageItemUtils.getLineage(ec, inputs)));
    }

    /** Returns the cacheable data bound to the {@code target} parameter. */
    public CacheableData<?> getTarget(ExecutionContext ec) {
        return ec.getCacheableData(params.get("target"));
    }

    /** Operand for the {@code target} parameter, typed as an FP64 matrix. */
    private CPOperand getTargetOperand() {
        return new CPOperand(params.get("target"), Types.ValueType.FP64, Types.DataType.MATRIX);
    }

    /** FP64 scalar literal built from the value of the given parameter. */
    private CPOperand getFP64Literal(String key) {
        return getLiteral(key, Types.ValueType.FP64);
    }

    /** String scalar literal built from the value of the given parameter. */
    private CPOperand getStringLiteral(String key) {
        return getLiteral(key, Types.ValueType.STRING);
    }

    /** Boolean scalar literal built from the value of the given parameter. */
    private CPOperand getBoolLiteral(String key) {
        return getLiteral(key, Types.ValueType.BOOLEAN);
    }

    /** Scalar literal of the given value type built from the value of the given parameter. */
    private CPOperand getLiteral(String key, Types.ValueType vt) {
        return new CPOperand(params.get(key), vt, Types.DataType.SCALAR, true);
    }

    /** Non-literal operand of the given value/data type named by the given parameter's value. */
    private CPOperand getLiteral(String key, Types.ValueType vt, Types.DataType dt) {
        return new CPOperand(params.get(key), vt, dt);
    }
}
