package org.apache.sysds.runtime.util;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Stack;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageParser;

/* loaded from: input_file:org/apache/sysds/runtime/util/AutoDiff.class */
/**
 * Builds and executes backward (gradient) HOP DAGs from a serialized lineage trace.
 *
 * <p>The lineage trace is parsed and walked in post-order. Each visited lineage item is turned
 * into a HOP. For supported instructions, a gradient expression is built against a single
 * upstream gradient matrix, which is bound to the variable {@code "X"}.
 *
 * <p>Supported instructions:
 * <ul>
 *   <li>{@code AggregateBinary}: yields {@code dX} and {@code dW}.</li>
 *   <li>Binary {@code +}: yields {@code dB} as column sums.</li>
 * </ul>
 */
public class AutoDiff {
    /**
     * Prefix of creation lineage items whose numeric suffix refers to an already-constructed
     * operand. Presumably these are placeholder inputs; confirm against the trace producer.
     */
    private static final String ADVARPREFIX = "adVar";
    private static final boolean DEBUG = false;

    /**
     * Computes the gradients for the lineage trace passed as the first argument.
     *
     * @param matrixObject matrix used as the upstream gradient in the backward expressions
     * @param arrayList arguments; the string form of element 0 is parsed as the lineage trace
     * @param executionContext context used to compile and run the gradient DAGs
     * @return a named list of gradient values (names such as {@code dX}, {@code dW}, {@code dB})
     * @throws DMLRuntimeException if no lineage argument is given or the trace is unsupported
     */
    public static ListObject getBackward(MatrixObject matrixObject, ArrayList<Data> arrayList, ExecutionContext executionContext) {
        if (arrayList == null || arrayList.isEmpty()) {
            throw new DMLRuntimeException("AutoDiff requires a lineage trace argument.");
        }
        // NOTE(review): the reason for stripping "foo" is not visible here; presumably it removes
        // a marker added by the caller when serializing the trace -- confirm.
        String lineage = arrayList.get(0).toString().replace("foo", "");
        ArrayList<String> names = new ArrayList<>();
        List<Data> gradients = parseNComputeAutoDiffFromLineage(matrixObject, lineage, names, executionContext);
        return new ListObject(gradients, names);
    }

    /**
     * Parses a lineage trace, builds its gradient HOPs and executes them.
     *
     * @param matrixObject upstream gradient; bound to variable {@code "X"} in the context
     * @param str serialized lineage trace
     * @param arrayList output list receiving one name per returned gradient (same order)
     * @param executionContext context used for recompilation and execution
     * @return the computed gradient values, in the order the gradient HOPs were produced
     */
    public static List<Data> parseNComputeAutoDiffFromLineage(MatrixObject matrixObject, String str, ArrayList<String> arrayList, ExecutionContext executionContext) {
        LineageItem root = LineageParser.parseLineageTrace(str);
        root.resetVisitStatusNR();
        Map<Long, Hop> operands = new HashMap<>();
        // NOTE(review): this overwrites any existing variable named "X" in the context.
        executionContext.setVariable("X", matrixObject);
        Hop gradientInput = HopRewriteUtils.createTransientRead("X", matrixObject);
        ArrayList<Hop> gradientHops = constructHopsNR(root, operands, gradientInput, arrayList);

        List<Data> results = new ArrayList<>(gradientHops.size());
        for (int i = 0; i < gradientHops.size(); i++) {
            String outName = "advar" + i;
            Hop write = HopRewriteUtils.createTransientWrite(outName, gradientHops.get(i));
            ArrayList<Instruction> insts = Recompiler.recompileHopsDag(write, executionContext.getVariables(), null, true, true, 0L);
            executeInst(insts, executionContext);
            results.add(executionContext.getVariable(outName));
        }
        return results;
    }

    /**
     * Non-recursively walks the lineage DAG in post-order (inputs before consumers) and constructs
     * one HOP per unvisited lineage item.
     *
     * @param lineageItem root of the lineage DAG; visit flags are expected to be reset by the caller
     * @param map operand HOPs keyed by lineage item id (filled in by this method)
     * @param hop upstream gradient HOP used in the backward expressions
     * @param arrayList output list receiving the names of produced gradients
     * @return the gradient HOPs produced for supported instructions
     */
    public static ArrayList<Hop> constructHopsNR(LineageItem lineageItem, Map<Long, Hop> map, Hop hop, ArrayList<String> arrayList) {
        ArrayList<Hop> gradients = new ArrayList<>();
        // Explicit stacks: the items being visited and, in parallel, the next input index of each.
        Deque<LineageItem> items = new ArrayDeque<>();
        Deque<MutableInt> nextInput = new ArrayDeque<>();
        items.push(lineageItem);
        nextInput.push(new MutableInt(0));

        while (!items.isEmpty()) {
            LineageItem current = items.peek();
            MutableInt pos = nextInput.peek();
            if (current.isVisited()) {
                items.pop();
                nextInput.pop();
                continue;
            }
            LineageItem[] inputs = current.getInputs();
            if (inputs != null && pos.intValue() < inputs.length) {
                // Descend into the next unprocessed input first.
                items.push(inputs[pos.intValue()]);
                pos.increment();
                nextInput.push(new MutableInt(0));
            } else {
                // All inputs handled: build this item's HOP.
                constructSingleHop(current, map, hop, gradients, arrayList);
                items.pop();
                nextInput.pop();
                current.setVisited();
            }
        }
        return gradients;
    }

    /** Dispatches construction of the HOP(s) for one lineage item by its lineage type. */
    private static void constructSingleHop(LineageItem lineageItem, Map<Long, Hop> map, Hop hop, ArrayList<Hop> arrayList, ArrayList<String> arrayList2) {
        switch (lineageItem.getType()) {
            case Creation:
                constructCreationHop(lineageItem, map);
                return;
            case Instruction:
                constructGradientHops(lineageItem, map, hop, arrayList, arrayList2);
                return;
            case Literal:
                CPOperand literal = new CPOperand(lineageItem.getData());
                map.put(lineageItem.getId(), ScalarObjectFactory.createLiteralOp(literal.getValueType(), literal.getName()));
                return;
            default:
                throw new DMLRuntimeException("Lineage type " + lineageItem.getType() + " is not supported");
        }
    }

    /**
     * Builds the operand HOP for a creation lineage item: a placeholder re-mapping, a data
     * generator, a persistent read (createvar) or a Spark rand. Other creation instructions are
     * ignored.
     */
    private static void constructCreationHop(LineageItem item, Map<Long, Hop> operands) {
        String data = item.getData();
        if (data.startsWith(ADVARPREFIX)) {
            // Strip the full prefix; the old substring(3) left "ar" in front of the id and made
            // Long.parseLong fail for every such item.
            long placeholderId = Long.parseLong(data.substring(ADVARPREFIX.length()));
            Hop input = operands.remove(placeholderId);
            operands.put(item.getId(), input);
            return;
        }
        Instruction inst = InstructionParser.parseSingleInstruction(data);
        if (inst instanceof DataGenCPInstruction) {
            operands.put(item.getId(), createDataGenHop((DataGenCPInstruction) inst));
        } else if (inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isCreateVariable()) {
            operands.put(item.getId(), createPersistentReadHop(inst));
        } else if (inst instanceof RandSPInstruction) {
            operands.put(item.getId(), createRandSPHop((RandSPInstruction) inst));
        }
    }

    /** Creates a {@link DataGenOp} mirroring a CP data-generation instruction. */
    private static Hop createDataGenHop(DataGenCPInstruction inst) {
        HashMap<String, Hop> params = new HashMap<>();
        if (inst.getOpcode().equals(DataGen.RAND_OPCODE)) {
            if (inst.output.getDataType() == Types.DataType.TENSOR) {
                params.put(DataExpression.RAND_DIMS, new LiteralOp(inst.getDims()));
            } else {
                params.put("rows", new LiteralOp(inst.getRows()));
                params.put("cols", new LiteralOp(inst.getCols()));
            }
            params.put("min", new LiteralOp(inst.getMinValue()));
            params.put("max", new LiteralOp(inst.getMaxValue()));
            params.put(DataExpression.RAND_PDF, new LiteralOp(inst.getPdf()));
            params.put(DataExpression.RAND_LAMBDA, new LiteralOp(inst.getPdfParams()));
            params.put(DataExpression.RAND_SPARSITY, new LiteralOp(inst.getSparsity()));
            params.put("seed", new LiteralOp(inst.getSeed()));
        }
        // Locale.ROOT: locale-sensitive upper-casing (e.g. Turkish dotted I) would break valueOf.
        Types.OpOpDG method = Types.OpOpDG.valueOf(inst.getOpcode().toUpperCase(Locale.ROOT));
        Hop dataGen = new DataGenOp(method, new DataIdentifier("tmp"), params);
        dataGen.setBlocksize(inst.getBlocksize());
        return dataGen;
    }

    /** Creates a persistent-read {@link DataOp} from the parts of a createvar instruction. */
    private static Hop createPersistentReadHop(Instruction inst) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
        Types.DataType dataType = Types.DataType.valueOf(parts[4]);
        Types.ValueType valueType = dataType == Types.DataType.MATRIX ? Types.ValueType.FP64 : Types.ValueType.STRING;
        HashMap<String, Hop> params = new HashMap<>();
        params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
        params.put("rows", new LiteralOp(Long.parseLong(parts[6])));
        params.put("cols", new LiteralOp(Long.parseLong(parts[7])));
        params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
        params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
        // NOTE(review): substring(5) presumably strips a fixed variable-name prefix -- confirm.
        DataOp read = new DataOp(parts[1].substring(5), dataType, valueType, Types.OpOpData.PERSISTENTREAD, params);
        read.setFileName(parts[2]);
        return read;
    }

    /** Creates a RAND {@link DataGenOp} mirroring a Spark rand instruction. */
    private static Hop createRandSPHop(RandSPInstruction inst) {
        HashMap<String, Hop> params = new HashMap<>();
        if (inst.output.getDataType() == Types.DataType.TENSOR) {
            params.put(DataExpression.RAND_DIMS, new LiteralOp(inst.getDims()));
        } else {
            params.put("rows", new LiteralOp(inst.getRows()));
            params.put("cols", new LiteralOp(inst.getCols()));
        }
        params.put("min", new LiteralOp(inst.getMinValue()));
        params.put("max", new LiteralOp(inst.getMaxValue()));
        params.put(DataExpression.RAND_PDF, new LiteralOp(inst.getPdf()));
        params.put(DataExpression.RAND_LAMBDA, new LiteralOp(inst.getPdfParams()));
        params.put(DataExpression.RAND_SPARSITY, new LiteralOp(inst.getSparsity()));
        params.put("seed", new LiteralOp(inst.getSeed()));
        Hop dataGen = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("tmp"), params);
        dataGen.setBlocksize(inst.getBlocksize());
        return dataGen;
    }

    /**
     * Builds gradient HOPs for an instruction lineage item.
     *
     * <ul>
     *   <li>Matrix multiply {@code Y = A %*% B} yields {@code dX = G %*% t(B)} and
     *       {@code dW = t(A) %*% G}, where {@code G} is the upstream gradient.</li>
     *   <li>Binary {@code +} yields {@code dB = colSums(G)}.</li>
     * </ul>
     *
     * <p>Opcodes without a CP type are skipped.
     *
     * @throws DMLRuntimeException for unsupported instruction types or binary opcodes
     */
    private static void constructGradientHops(LineageItem item, Map<Long, Hop> operands, Hop gradient, ArrayList<Hop> gradients, ArrayList<String> names) {
        CPInstruction.CPType type = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
        if (type == null) {
            return;
        }
        switch (type) {
            case AggregateBinary: {
                Hop left = operands.get(item.getInputs()[0].getId());
                Hop right = operands.get(item.getInputs()[1].getId());
                ReorgOp leftT = HopRewriteUtils.createTranspose(left);
                Hop dX = HopRewriteUtils.createMatrixMultiply(gradient, HopRewriteUtils.createTranspose(right));
                Hop dW = HopRewriteUtils.createMatrixMultiply(leftT, gradient);
                operands.put(item.getId(), dX);
                // NOTE(review): id + 1 may collide with an unrelated lineage item's id -- confirm.
                operands.put(item.getId() + 1, dW);
                gradients.add(dX);
                gradients.add(dW);
                names.add("dX");
                names.add("dW");
                return;
            }
            case Binary: {
                if (!item.getOpcode().equals("+")) {
                    // Previously a null HOP was registered here, failing later with an obscure NPE.
                    throw new DMLRuntimeException("Unsupported autoDiff instruction type: " + type.name() + " (" + item.getOpcode() + ").");
                }
                Hop dB = HopRewriteUtils.createAggUnaryOp(gradient, Types.AggOp.SUM, Types.Direction.Col);
                operands.put(item.getId(), dB);
                gradients.add(dB);
                names.add("dB");
                return;
            }
            default:
                throw new DMLRuntimeException("Unsupported autoDiff instruction type: " + type.name() + " (" + item.getOpcode() + ").");
        }
    }

    /**
     * Executes the given instructions in a fresh basic program block against the context.
     *
     * @throws DMLRuntimeException wrapping any failure during execution
     */
    private static void executeInst(ArrayList<Instruction> arrayList, ExecutionContext executionContext) {
        try {
            BasicProgramBlock block = new BasicProgramBlock(new Program());
            block.setInstructions(arrayList);
            block.execute(executionContext);
        } catch (Exception e) {
            throw new DMLRuntimeException("Error executing autoDiff instruction", e);
        }
    }
}
