package org.apache.sysds.hops.recompile;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/hops/recompile/LiteralReplacement.class */
public class LiteralReplacement {
    // Exclusive upper bound on (rows * columns) of a matrix variable whose cells are read
    // by the right-indexing and full-aggregate replacements in this class.
    private static final long REPLACE_LITERALS_MAX_MATRIX_SIZE = 1000000;
    // NOTE(review): not referenced anywhere in this file as shown; the name suggests it
    // toggles statistics reporting for literal-replacement operations - confirm before use.
    private static final boolean REPORT_LITERAL_REPLACE_OPS_STATS = true;

    /**
     * Recursively visits the inputs of {@code hop} and swaps each input that matches one of
     * the replacement patterns of this class for the hop that pattern produced.
     *
     * <p>The patterns are tried in a fixed order and the first non-null result wins. An input
     * that matches no pattern is processed recursively instead. The hop is marked as visited
     * at the end, and already visited hops are skipped.
     *
     * @param hop              root of the (sub-)DAG to process
     * @param executionContext supplies the variable map (and lineage) used by the patterns
     * @param z                when {@code true}, only the three scalar patterns are tried and
     *                         the matrix/list based ones are skipped (presumably a
     *                         "scalars only" flag; the name was lost in decompilation)
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public static void rReplaceLiterals(Hop hop, ExecutionContext executionContext, boolean z) {
        if (hop.isVisited()) {
            return;
        }
        if (hop.getInput() != null) {
            LocalVariableMap vars = executionContext.getVariables();
            for (int pos = 0; pos < hop.getInput().size(); pos++) {
                Hop child = hop.getInput().get(pos);

                // scalar patterns, always attempted
                Hop replacement = replaceLiteralScalarRead(child, vars);
                if (replacement == null) {
                    replacement = replaceLiteralValueTypeCastScalarRead(child, vars);
                }
                if (replacement == null) {
                    replacement = replaceLiteralValueTypeCastLiteral(child, vars);
                }

                // matrix and list patterns, skipped when z is set
                if (!z) {
                    if (replacement == null) {
                        replacement = replaceLiteralDataTypeCastMatrixRead(child, vars);
                    }
                    if (replacement == null) {
                        replacement = replaceLiteralValueTypeCastRightIndexing(child, vars);
                    }
                    if (replacement == null) {
                        replacement = replaceLiteralFullUnaryAggregate(child, vars);
                    }
                    if (replacement == null) {
                        replacement = replaceLiteralFullUnaryAggregateRightIndexing(child, vars);
                    }
                    if (replacement == null) {
                        replacement = replaceTReadMatrixFromList(child, executionContext);
                    }
                    if (replacement == null) {
                        replacement = replaceTReadMatrixFromListAppend(child, executionContext);
                    }
                    if (replacement == null) {
                        replacement = replaceTReadMatrixLookupFromList(child, executionContext);
                    }
                    if (replacement == null) {
                        replacement = replaceTReadScalarLookupFromList(child, vars);
                    }
                }

                if (replacement == null) {
                    // nothing to replace at this input: descend into it
                    rReplaceLiterals(child, executionContext, z);
                    continue;
                }

                if (child.getParent().size() > 1) {
                    // rewire every parent of the child; iterate over a snapshot because the
                    // parent list is modified while references are removed
                    ArrayList<Hop> parents = new ArrayList<>(child.getParent());
                    for (Hop parent : parents) {
                        int childPos = HopRewriteUtils.getChildReferencePos(parent, child);
                        HopRewriteUtils.removeChildReferenceByPos(parent, child, childPos);
                        HopRewriteUtils.addChildReference(parent, replacement, childPos);
                    }
                } else {
                    HopRewriteUtils.replaceChildReference(hop, child, replacement, pos);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Turns a scalar, non-persistent-read data operator into a literal when a value for its
     * name is present in the variable map.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables to look the hop's name up in
     * @return the literal built from the bound scalar, or {@code null} if the hop does not
     *         match or no variable is bound to its name
     */
    private static LiteralOp replaceLiteralScalarRead(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof DataOp)) {
            return null;
        }
        if (((DataOp) hop).getOp() == Types.OpOpData.PERSISTENTREAD) {
            return null;
        }
        if (hop.getDataType() != Types.DataType.SCALAR) {
            return null;
        }
        Data bound = localVariableMap.get(hop.getName());
        if (bound == null) {
            return null;
        }
        return ScalarObjectFactory.createLiteralOp((ScalarObject) bound);
    }

    /**
     * Turns a value-type cast (to double, int or boolean) over a data operator into a literal
     * when the cast yields a scalar and a variable is bound to the data operator's name.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables to look the cast input's name up in
     * @return the literal created from the bound scalar and the cast operator, or
     *         {@code null} if the pattern does not match
     */
    private static LiteralOp replaceLiteralValueTypeCastScalarRead(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof UnaryOp)) {
            return null;
        }
        UnaryOp cast = (UnaryOp) hop;
        Types.OpOp1 op = cast.getOp();
        boolean valueTypeCast = op == Types.OpOp1.CAST_AS_DOUBLE
                || op == Types.OpOp1.CAST_AS_INT
                || op == Types.OpOp1.CAST_AS_BOOLEAN;
        if (!valueTypeCast) {
            return null;
        }
        if (!(hop.getInput().get(0) instanceof DataOp) || hop.getDataType() != Types.DataType.SCALAR) {
            return null;
        }
        Data bound = localVariableMap.get(hop.getInput().get(0).getName());
        if (bound == null) {
            return null;
        }
        return ScalarObjectFactory.createLiteralOp((ScalarObject) bound, cast);
    }

    /**
     * Folds a value-type cast (to double, int or boolean) applied to a literal into a new
     * literal of the target type.
     *
     * @param hop              candidate hop
     * @param localVariableMap not used by this pattern
     * @return the converted literal, or {@code null} if the pattern does not match
     * @throws DMLRuntimeException wrapping any {@link HopsException} raised by the conversion
     */
    private static LiteralOp replaceLiteralValueTypeCastLiteral(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof UnaryOp)) {
            return null;
        }
        Types.OpOp1 op = ((UnaryOp) hop).getOp();
        if (op != Types.OpOp1.CAST_AS_DOUBLE && op != Types.OpOp1.CAST_AS_INT && op != Types.OpOp1.CAST_AS_BOOLEAN) {
            return null;
        }
        if (!(hop.getInput().get(0) instanceof LiteralOp)) {
            return null;
        }
        LiteralOp source = (LiteralOp) hop.getInput().get(0);
        LiteralOp converted = null;
        try {
            switch (op) {
                case CAST_AS_INT:
                    converted = new LiteralOp(HopRewriteUtils.getIntValue(source));
                    break;
                case CAST_AS_DOUBLE:
                    converted = new LiteralOp(HopRewriteUtils.getDoubleValue(source));
                    break;
                case CAST_AS_BOOLEAN:
                    converted = new LiteralOp(HopRewriteUtils.getBooleanValue(source));
                    break;
            }
        } catch (HopsException e) {
            throw new DMLRuntimeException(e);
        }
        return converted;
    }

    /**
     * Replaces a cast-to-scalar over a matrix data operator with a literal holding the single
     * cell of the bound matrix.
     *
     * <p>The matrix block is acquired for reading and is released on every path, including
     * the dimension-mismatch failure.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables to look the cast input's name up in
     * @return a literal with the value at cell (0, 0), or {@code null} if the pattern does
     *         not match or no variable is bound to the input's name
     * @throws DMLRuntimeException if the bound matrix is not of dimension 1 x 1
     */
    private static LiteralOp replaceLiteralDataTypeCastMatrixRead(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof UnaryOp) || ((UnaryOp) hop).getOp() != Types.OpOp1.CAST_AS_SCALAR) {
            return null;
        }
        Hop input = hop.getInput().get(0);
        if (!(input instanceof DataOp) || input.getDataType() != Types.DataType.MATRIX) {
            return null;
        }
        Data data = localVariableMap.get(input.getName());
        if (data == null) {
            return null;
        }
        MatrixObject matrixObject = (MatrixObject) data;
        MatrixBlock block = matrixObject.acquireRead();
        double value;
        try {
            if (block.getNumRows() != 1 || block.getNumColumns() != 1) {
                throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix of dimension (" + block.getNumRows() + " x " + block.getNumColumns() + ") to scalar.");
            }
            value = block.getValue(0, 0);
        } finally {
            // always pair acquireRead() with release(), also when the check above throws
            matrixObject.release();
        }
        return new LiteralOp(value);
    }

    /**
     * Replaces a cast-to-scalar over a 1 x 1 right-indexing of a matrix variable with a
     * literal holding the indexed cell.
     *
     * <p>The replacement applies only if the indexing dimensions are known to be 1 x 1, the
     * indexed hop is a data operator whose name is present in the variable map, all four
     * index bounds can be resolved (see {@code isIntValueDataLiteral}), and the bound matrix
     * has fewer than {@code REPLACE_LITERALS_MAX_MATRIX_SIZE} cells. The lower row and column
     * bounds are decremented by one before the cell is read.
     *
     * <p>The matrix block is released even if reading the cell fails.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables used to resolve the matrix and the index bounds
     * @return a literal with the indexed value, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralValueTypeCastRightIndexing(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof UnaryOp) || ((UnaryOp) hop).getOp() != Types.OpOp1.CAST_AS_SCALAR) {
            return null;
        }
        if (!(hop.getInput().get(0) instanceof IndexingOp) || hop.getInput().get(0).getDataType() != Types.DataType.MATRIX) {
            return null;
        }
        IndexingOp indexingOp = (IndexingOp) hop.getInput().get(0);
        Hop source = indexingOp.getInput().get(0);
        Hop rowLower = indexingOp.getInput().get(1);
        Hop rowUpper = indexingOp.getInput().get(2);
        Hop colLower = indexingOp.getInput().get(3);
        Hop colUpper = indexingOp.getInput().get(4);

        boolean singleCell = indexingOp.dimsKnown() && indexingOp.getDim1() == 1 && indexingOp.getDim2() == 1;
        if (!singleCell || !(source instanceof DataOp) || !localVariableMap.keySet().contains(source.getName())) {
            return null;
        }
        if (!isIntValueDataLiteral(rowLower, localVariableMap) || !isIntValueDataLiteral(rowUpper, localVariableMap)
                || !isIntValueDataLiteral(colLower, localVariableMap) || !isIntValueDataLiteral(colUpper, localVariableMap)) {
            return null;
        }

        long row = getIntValueDataLiteral(rowLower, localVariableMap);
        long col = getIntValueDataLiteral(colLower, localVariableMap);
        MatrixObject matrixObject = (MatrixObject) localVariableMap.get(source.getName());
        if (matrixObject.getNumRows() * matrixObject.getNumColumns() >= REPLACE_LITERALS_MAX_MATRIX_SIZE) {
            return null;
        }

        MatrixBlock block = matrixObject.acquireRead();
        double value;
        try {
            value = block.getValue(((int) row) - 1, ((int) col) - 1);
        } finally {
            // always pair acquireRead() with release(), also when the cell access throws
            matrixObject.release();
        }
        return new LiteralOp(value);
    }

    /**
     * Replaces a replaceable full aggregate (see {@code isReplaceableUnaryAggregate}) over a
     * matrix variable with a literal holding the aggregate value.
     *
     * <p>The replacement applies only if the aggregate input is a data operator whose name is
     * present in the variable map and the bound matrix has fewer than
     * {@code REPLACE_LITERALS_MAX_MATRIX_SIZE} cells. The matrix block is released even if
     * computing the aggregate fails.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables to look the aggregate input's name up in
     * @return a literal with the aggregate value, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralFullUnaryAggregate(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof AggUnaryOp) || !isReplaceableUnaryAggregate((AggUnaryOp) hop)) {
            return null;
        }
        Hop input = hop.getInput().get(0);
        if (!(input instanceof DataOp) || !localVariableMap.keySet().contains(input.getName())) {
            return null;
        }
        MatrixObject matrixObject = (MatrixObject) localVariableMap.get(input.getName());
        if (matrixObject.getNumRows() * matrixObject.getNumColumns() >= REPLACE_LITERALS_MAX_MATRIX_SIZE) {
            return null;
        }

        MatrixBlock block = matrixObject.acquireRead();
        double aggregate;
        try {
            aggregate = replaceUnaryAggregate((AggUnaryOp) hop, block);
        } finally {
            // always pair acquireRead() with release(), also when the aggregation throws
            matrixObject.release();
        }
        return new LiteralOp(aggregate);
    }

    /**
     * Replaces a replaceable full aggregate (see {@code isReplaceableUnaryAggregate}) over a
     * right-indexing of a matrix variable with a literal holding the aggregate of the
     * indexed range.
     *
     * <p>The replacement applies only if the indexed hop is a data operator whose name is
     * present in the variable map, all four index bounds can be resolved (see
     * {@code isIntValueDataLiteral}), and the bound matrix has fewer than
     * {@code REPLACE_LITERALS_MAX_MATRIX_SIZE} cells. Each bound is decremented by one before
     * it is passed to {@code slice}.
     *
     * <p>The matrix block is released even if slicing or aggregating fails.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables used to resolve the matrix and the index bounds
     * @return a literal with the aggregate value, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing(Hop hop, LocalVariableMap localVariableMap) {
        if (!(hop instanceof AggUnaryOp) || !isReplaceableUnaryAggregate((AggUnaryOp) hop)) {
            return null;
        }
        if (!(hop.getInput().get(0) instanceof IndexingOp)) {
            return null;
        }
        IndexingOp indexingOp = (IndexingOp) hop.getInput().get(0);
        Hop source = indexingOp.getInput().get(0);
        if (!(source instanceof DataOp) || !localVariableMap.keySet().contains(source.getName())) {
            return null;
        }
        Hop rowLower = indexingOp.getInput().get(1);
        Hop rowUpper = indexingOp.getInput().get(2);
        Hop colLower = indexingOp.getInput().get(3);
        Hop colUpper = indexingOp.getInput().get(4);
        if (!isIntValueDataLiteral(rowLower, localVariableMap) || !isIntValueDataLiteral(rowUpper, localVariableMap)
                || !isIntValueDataLiteral(colLower, localVariableMap) || !isIntValueDataLiteral(colUpper, localVariableMap)) {
            return null;
        }

        long rl = getIntValueDataLiteral(rowLower, localVariableMap);
        long ru = getIntValueDataLiteral(rowUpper, localVariableMap);
        long cl = getIntValueDataLiteral(colLower, localVariableMap);
        long cu = getIntValueDataLiteral(colUpper, localVariableMap);
        MatrixObject matrixObject = (MatrixObject) localVariableMap.get(source.getName());
        if (matrixObject.getNumRows() * matrixObject.getNumColumns() >= REPLACE_LITERALS_MAX_MATRIX_SIZE) {
            return null;
        }

        MatrixBlock block = matrixObject.acquireRead();
        double aggregate;
        try {
            aggregate = replaceUnaryAggregate((AggUnaryOp) hop, block.slice((int) (rl - 1), (int) (ru - 1), (int) (cl - 1), (int) (cu - 1), (CacheBlock) new MatrixBlock()));
        } finally {
            // always pair acquireRead() with release(), also when slice/aggregation throws
            matrixObject.release();
        }
        return new LiteralOp(aggregate);
    }

    /**
     * Replaces a cast-to-matrix over a transient read of a single-element list with a
     * transient read of that element.
     *
     * <p>On a match the list element is registered in the execution context's variables under
     * a freshly generated matrix variable name; when {@code DMLScript.LINEAGE} is set, the
     * element's lineage item is registered under the same name.
     *
     * @param hop              candidate hop
     * @param executionContext provides the variables (and lineage) to read from and update
     * @return the new transient read, or {@code null} if the pattern does not match or the
     *         list does not have exactly one element
     */
    private static DataOp replaceTReadMatrixFromList(Hop hop, ExecutionContext executionContext) {
        if (!HopRewriteUtils.isUnary(hop, Types.OpOp1.CAST_AS_MATRIX)) {
            return null;
        }
        Hop listRead = hop.getInput().get(0);
        if (listRead.getDataType() != Types.DataType.LIST || !HopRewriteUtils.isData(listRead, Types.OpOpData.TRANSIENTREAD)) {
            return null;
        }
        ListObject list = (ListObject) executionContext.getVariables().get(listRead.getName());
        if (list.getLength() != 1) {
            return null;
        }
        String varName = Dag.getNextUniqueVarname(Types.DataType.MATRIX);
        MatrixObject element = (MatrixObject) list.slice(0);
        executionContext.getVariables().put(varName, element);
        DataOp read = HopRewriteUtils.createTransientRead(varName, hop);
        if (DMLScript.LINEAGE) {
            executionContext.getLineage().set(varName, list.getLineageItem(0));
        }
        return read;
    }

    /**
     * Replaces a cbind/rbind whose first input is a transient read of a list with a
     * cbind/rbind over transient reads of the individual list elements.
     *
     * <p>Applies only to lists of at most 128 elements. Each element is registered in the
     * execution context's variables under a freshly generated matrix variable name; when
     * {@code DMLScript.LINEAGE} is set, its lineage item is registered under the same name.
     *
     * <p>NOTE(review): only input 0 of the n-ary operator is inspected and the result is
     * built from the list elements alone - confirm that matching operators never carry
     * further inputs.
     *
     * @param hop              candidate hop
     * @param executionContext provides the variables (and lineage) to read from and update
     * @return the new n-ary operator, or {@code null} if the pattern does not match or the
     *         list is longer than 128 elements
     */
    private static NaryOp replaceTReadMatrixFromListAppend(Hop hop, ExecutionContext executionContext) {
        if (!HopRewriteUtils.isNary(hop, Types.OpOpN.CBIND, Types.OpOpN.RBIND)) {
            return null;
        }
        Hop listRead = hop.getInput().get(0);
        if (listRead.getDataType() != Types.DataType.LIST || !HopRewriteUtils.isData(listRead, Types.OpOpData.TRANSIENTREAD)) {
            return null;
        }
        ListObject list = (ListObject) executionContext.getVariables().get(listRead.getName());
        if (list.getLength() > 128) {
            return null;
        }
        ArrayList<Hop> elementReads = new ArrayList<>();
        for (int idx = 0; idx < list.getLength(); idx++) {
            String varName = Dag.getNextUniqueVarname(Types.DataType.MATRIX);
            MatrixObject element = (MatrixObject) list.slice(idx);
            executionContext.getVariables().put(varName, element);
            elementReads.add(HopRewriteUtils.createTransientRead(varName, element));
            if (DMLScript.LINEAGE) {
                executionContext.getLineage().set(varName, list.getLineageItem(idx));
            }
        }
        return HopRewriteUtils.createNary(((NaryOp) hop).getOp(), elementReads.toArray(new Hop[0]));
    }

    /**
     * Replaces a cast-to-matrix over a single-element indexing of a transient-read list with
     * a transient read of the selected element.
     *
     * <p>The indexing must use the very same literal hop (reference equality) as lower and
     * upper bound. A numeric literal is decremented by one and used as position; any other
     * literal selects the element by the literal's name. The element is registered in the
     * variables under a freshly generated matrix variable name; when
     * {@code DMLScript.LINEAGE} is set, its lineage item is registered under the same name.
     *
     * @param hop              candidate hop
     * @param executionContext provides the variables (and lineage) to read from and update
     * @return the new transient read, or {@code null} if the pattern does not match
     */
    private static DataOp replaceTReadMatrixLookupFromList(Hop hop, ExecutionContext executionContext) {
        LocalVariableMap vars = executionContext.getVariables();
        if (!HopRewriteUtils.isUnary(hop, Types.OpOp1.CAST_AS_MATRIX) || !(hop.getInput().get(0) instanceof IndexingOp)) {
            return null;
        }
        Hop indexing = hop.getInput().get(0);
        Hop listRead = hop.getInput().get(0).getInput().get(0);
        if (listRead.getDataType() != Types.DataType.LIST || !HopRewriteUtils.isData(listRead, Types.OpOpData.TRANSIENTREAD)) {
            return null;
        }
        if (!(indexing.getInput().get(1) instanceof LiteralOp) || !(indexing.getInput().get(2) instanceof LiteralOp)) {
            return null;
        }
        // lower and upper bound must be the identical hop instance
        if (indexing.getInput().get(1) != indexing.getInput().get(2)) {
            return null;
        }

        ListObject list = (ListObject) vars.get(listRead.getName());
        String varName = Dag.getNextUniqueVarname(Types.DataType.MATRIX);
        LiteralOp key = (LiteralOp) indexing.getInput().get(1);

        MatrixObject element;
        if (!key.getValueType().isNumeric()) {
            element = (MatrixObject) list.slice(key.getName());
        } else {
            element = (MatrixObject) list.slice(((int) key.getLongValue()) - 1);
        }
        LineageItem lineageItem;
        if (!key.getValueType().isNumeric()) {
            lineageItem = list.getLineageItem(key.getName());
        } else {
            lineageItem = list.getLineageItem(((int) key.getLongValue()) - 1);
        }

        vars.put(varName, element);
        if (DMLScript.LINEAGE) {
            executionContext.getLineage().set(varName, lineageItem);
        }
        return HopRewriteUtils.createTransientRead(varName, hop);
    }

    /**
     * Replaces a cast-to-scalar over a single-element indexing of a transient-read list with
     * a literal built from the selected list element.
     *
     * <p>The indexing must use the very same literal hop (reference equality) as lower and
     * upper bound. A numeric literal is decremented by one and used as position; any other
     * literal selects the element by the literal's name.
     *
     * @param hop              candidate hop
     * @param localVariableMap variables to look the list up in
     * @return the literal for the selected element, or {@code null} if the pattern does not match
     */
    private static LiteralOp replaceTReadScalarLookupFromList(Hop hop, LocalVariableMap localVariableMap) {
        LiteralOp result = null;
        if (HopRewriteUtils.isUnary(hop, Types.OpOp1.CAST_AS_SCALAR) && (hop.getInput().get(0) instanceof IndexingOp)) {
            Hop indexing = hop.getInput().get(0);
            Hop listRead = hop.getInput().get(0).getInput().get(0);
            boolean listTransientRead = listRead.getDataType() == Types.DataType.LIST
                    && HopRewriteUtils.isData(listRead, Types.OpOpData.TRANSIENTREAD);
            if (listTransientRead
                    && (indexing.getInput().get(1) instanceof LiteralOp)
                    && (indexing.getInput().get(2) instanceof LiteralOp)
                    && indexing.getInput().get(1) == indexing.getInput().get(2)) {
                ListObject list = (ListObject) localVariableMap.get(listRead.getName());
                LiteralOp key = (LiteralOp) indexing.getInput().get(1);
                ScalarObject element;
                if (!key.getValueType().isNumeric()) {
                    element = (ScalarObject) list.slice(key.getName());
                } else {
                    element = (ScalarObject) list.slice(((int) key.getLongValue()) - 1);
                }
                result = ScalarObjectFactory.createLiteralOp(element);
            }
        }
        return result;
    }

    /**
     * Checks whether {@code getIntValueDataLiteral} has a source to resolve the given hop from.
     *
     * <p>Accepted are: a data operator whose name is present in the variable map, any
     * literal, and an nrow/ncol operator over a data operator whose name is present in the
     * variable map.
     *
     * @param hop              hop to test
     * @param localVariableMap variables to check names against
     * @return {@code true} if the hop has one of the accepted shapes
     */
    private static boolean isIntValueDataLiteral(Hop hop, LocalVariableMap localVariableMap) {
        if ((hop instanceof DataOp) && localVariableMap.keySet().contains(hop.getName())) {
            return true;
        }
        if (hop instanceof LiteralOp) {
            return true;
        }
        if (!(hop instanceof UnaryOp)) {
            return false;
        }
        Types.OpOp1 op = ((UnaryOp) hop).getOp();
        if (op != Types.OpOp1.NROW && op != Types.OpOp1.NCOL) {
            return false;
        }
        Hop input = hop.getInput().get(0);
        return (input instanceof DataOp) && localVariableMap.keySet().contains(input.getName());
    }

    /**
     * Resolves the long value of a hop.
     *
     * <p>A literal yields its int value, an nrow/ncol operator yields the row/column count of
     * the variable bound to its input's name, and any other hop yields the long value of the
     * scalar variable bound to the hop's own name.
     *
     * @param hop              hop to resolve
     * @param localVariableMap variables to resolve names against
     * @return the resolved value
     * @throws DMLRuntimeException wrapping any {@link HopsException} raised while resolving
     */
    private static long getIntValueDataLiteral(Hop hop, LocalVariableMap localVariableMap) {
        try {
            if (hop instanceof LiteralOp) {
                return HopRewriteUtils.getIntValue((LiteralOp) hop);
            }
            if ((hop instanceof UnaryOp) && ((UnaryOp) hop).getOp() == Types.OpOp1.NROW) {
                return ((CacheableData) localVariableMap.get(hop.getInput().get(0).getName())).getNumRows();
            }
            if ((hop instanceof UnaryOp) && ((UnaryOp) hop).getOp() == Types.OpOp1.NCOL) {
                return ((CacheableData) localVariableMap.get(hop.getInput().get(0).getName())).getNumColumns();
            }
            return ((ScalarObject) localVariableMap.get(hop.getName())).getLongValue();
        } catch (HopsException e) {
            throw new DMLRuntimeException("Failed to get int value for literal replacement", e);
        }
    }

    /**
     * Checks whether an aggregate can be evaluated by {@code replaceUnaryAggregate}.
     *
     * @param aggUnaryOp aggregate to test
     * @return {@code true} if the direction is {@code RowCol}, the operation is one of sum,
     *         sum of squares, min or max, and the first input has a matrix data type
     */
    private static boolean isReplaceableUnaryAggregate(AggUnaryOp aggUnaryOp) {
        if (aggUnaryOp.getDirection() != Types.Direction.RowCol) {
            return false;
        }
        boolean supportedOp = aggUnaryOp.getOp() == Types.AggOp.SUM
                || aggUnaryOp.getOp() == Types.AggOp.SUM_SQ
                || aggUnaryOp.getOp() == Types.AggOp.MIN
                || aggUnaryOp.getOp() == Types.AggOp.MAX;
        if (!supportedOp) {
            return false;
        }
        return aggUnaryOp.getInput().get(0).getDataType().isMatrix();
    }

    /**
     * Evaluates a sum, sum-of-squares, min or max aggregate on the given block.
     *
     * <p>When {@code DMLScript.STATISTICS} is set, the elapsed nanoseconds are reported via
     * {@code Statistics.maintainCPHeavyHitters} under the key {@code "rlit"}.
     *
     * @param aggUnaryOp  aggregate whose operation selects the computation
     * @param matrixBlock block to aggregate
     * @return the aggregate value
     * @throws DMLRuntimeException if the aggregate operation is none of the four supported ones
     */
    private static double replaceUnaryAggregate(AggUnaryOp aggUnaryOp, MatrixBlock matrixBlock) {
        final boolean recordStats = DMLScript.STATISTICS;
        final long startNanos = recordStats ? System.nanoTime() : 0L;

        final double result;
        switch (aggUnaryOp.getOp()) {
            case SUM:
                result = matrixBlock.sum();
                break;
            case SUM_SQ:
                result = matrixBlock.sumSq();
                break;
            case MIN:
                result = matrixBlock.min();
                break;
            case MAX:
                result = matrixBlock.max();
                break;
            default:
                throw new DMLRuntimeException("Unsupported unary aggregate replacement: " + aggUnaryOp.getOp());
        }

        if (recordStats) {
            Statistics.maintainCPHeavyHitters("rlit", System.nanoTime() - startNanos);
        }
        return result;
    }
}
