package org.apache.sysds.api.jmlc;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;

/**
 * Helpers for rewriting a compiled runtime {@link Program} before it is executed through JMLC.
 * <p>
 * The rewrite filters remove-variable instructions (those reporting
 * {@link VariableCPInstruction#isRemoveVariable()}), so that variables named as outputs
 * are not removed.
 */
public class JMLCUtils {

    /**
     * Strips the given output names from every remove-variable instruction in the program.
     * This covers the main program blocks and the child blocks of every registered function.
     * Each affected {@link BasicProgramBlock} gets a new instruction list via
     * {@code setInstructions}.
     *
     * @param prog    the runtime program to rewrite in place
     * @param outputs names of variables that must not be removed
     */
    public static void cleanupRuntimeProgram(Program prog, String[] outputs) {
        HashMap<String, FunctionProgramBlock> funcs = prog.getFunctionProgramBlocks();
        HashSet<String> keep = new HashSet<>(Arrays.asList(outputs));
        if (funcs != null) {
            for (FunctionProgramBlock fpb : funcs.values()) {
                rCleanupChildren(fpb.getChildBlocks(), keep);
            }
        }
        rCleanupChildren(prog.getProgramBlocks(), keep);
    }

    /**
     * Recursively rewrites the remove-variable instructions reachable from {@code pb}.
     * <p>
     * While, if (both branches) and for blocks are descended into. For basic blocks, the
     * instruction list is filtered. Any other block type is left untouched.
     *
     * @param pb      the program block to process
     * @param outputs names of variables that must not be removed
     */
    public static void rCleanupRuntimeProgram(ProgramBlock pb, HashSet<String> outputs) {
        if (pb instanceof WhileProgramBlock) {
            rCleanupChildren(((WhileProgramBlock) pb).getChildBlocks(), outputs);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            rCleanupChildren(ipb.getChildBlocksIfBody(), outputs);
            rCleanupChildren(ipb.getChildBlocksElseBody(), outputs);
        } else if (pb instanceof ForProgramBlock) {
            // NOTE(review): presumably this also catches subclasses such as parfor blocks
            // via instanceof; confirm against the ProgramBlock hierarchy.
            rCleanupChildren(((ForProgramBlock) pb).getChildBlocks(), outputs);
        } else if (pb instanceof BasicProgramBlock) {
            BasicProgramBlock bpb = (BasicProgramBlock) pb;
            bpb.setInstructions(cleanupRuntimeInstructions(bpb.getInstructions(), outputs));
        }
    }

    /** Applies {@link #rCleanupRuntimeProgram} to every block in {@code blocks}. */
    private static void rCleanupChildren(Iterable<? extends ProgramBlock> blocks, HashSet<String> outputs) {
        for (ProgramBlock child : blocks) {
            rCleanupRuntimeProgram(child, outputs);
        }
    }

    /**
     * Convenience overload of {@link #cleanupRuntimeInstructions(ArrayList, HashSet)} that
     * takes the output names as varargs.
     *
     * @param insts   the instructions to filter (not modified)
     * @param outputs names of variables that must not be removed
     * @return a new list with the remove-variable instructions rewritten
     */
    public static ArrayList<Instruction> cleanupRuntimeInstructions(ArrayList<Instruction> insts, String... outputs) {
        return cleanupRuntimeInstructions(insts, new HashSet<String>(Arrays.asList(outputs)));
    }

    /**
     * Returns a copy of {@code insts} in which output names are removed from every
     * remove-variable instruction.
     * <p>
     * Handling of each remove-variable instruction:
     * <ul>
     *   <li>Its operands that are not in {@code outputs} are re-emitted as a freshly prepared
     *       remove instruction.</li>
     *   <li>If every operand is an output, the instruction is dropped entirely.</li>
     * </ul>
     * All other instructions are copied over unchanged and in their original order.
     *
     * @param insts   the instructions to filter (not modified)
     * @param outputs names of variables that must not be removed
     * @return a new list of instructions
     */
    public static ArrayList<Instruction> cleanupRuntimeInstructions(ArrayList<Instruction> insts, HashSet<String> outputs) {
        ArrayList<Instruction> ret = new ArrayList<>(insts.size());
        for (Instruction inst : insts) {
            if (!(inst instanceof VariableCPInstruction) || !((VariableCPInstruction) inst).isRemoveVariable()) {
                ret.add(inst);
                continue;
            }
            ArrayList<String> toRemove = new ArrayList<>();
            for (CPOperand input : ((VariableCPInstruction) inst).getInputs()) {
                if (!outputs.contains(input.getName())) {
                    toRemove.add(input.getName());
                }
            }
            // An rmvar with no remaining operands would be a no-op, so it is omitted.
            if (!toRemove.isEmpty()) {
                ret.add(VariableCPInstruction.prepareRemoveInstruction(toRemove.toArray(new String[0])));
            }
        }
        return ret;
    }
}
