package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.ProgramConverter;

/**
 * Static helpers for the parameter-server runtime. They cover:
 * <ul>
 *   <li>shallow list and matrix copies, and cleanup;</li>
 *   <li>row slicing;</li>
 *   <li>permutation and sampling matrices;</li>
 *   <li>execution-context creation with parallelism assignment and recompilation;</li>
 *   <li>element-wise accumulation of lists of matrices.</li>
 * </ul>
 */
public class ParamservUtils {
    /** Prefix string {@code "_ps_"}; not referenced within this class (by name, meant for ps function names). */
    public static final String PS_FUNC_PREFIX = "_ps_";
    protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
    /** Mutable global seed, initialised to -1 (presumably "unset" -- TODO confirm with readers of this field). */
    public static long SEED = -1;

    /**
     * Copies a list, giving every {@link MatrixObject} entry a new object that wraps the same block
     * (see {@link #createShallowCopy(MatrixObject)}). Any other entry is reused by reference.
     *
     * @param listObject the list to copy; its names are carried over unchanged
     * @param cleanup    if true, the source list is passed to {@link #cleanupListObject(ListObject)} afterwards
     * @return the new list
     * @throws DMLRuntimeException if an entry is a nested list or a frame
     */
    public static ListObject copyList(ListObject listObject, boolean cleanup) {
        List<Data> entries = IntStream.range(0, listObject.getLength())
            .mapToObj(i -> copyListEntry(listObject.slice(i)))
            .collect(Collectors.toList());
        ListObject copy = new ListObject(entries, listObject.getNames());
        if (cleanup) {
            cleanupListObject(listObject);
        }
        return copy;
    }

    /** Copies a single list entry for {@link #copyList}: matrices shallowly, lists and frames rejected. */
    private static Data copyListEntry(Data entry) {
        if (entry instanceof MatrixObject) {
            return createShallowCopy((MatrixObject) entry);
        }
        if (entry instanceof ListObject || entry instanceof FrameObject) {
            throw new DMLRuntimeException("Copy list: does not support list or frame.");
        }
        return entry;
    }

    /**
     * Removes the named list variable from {@code ec} and cleans up its entries according to the list's status.
     * Does nothing if the variable is not present.
     */
    public static void cleanupListObject(ExecutionContext ec, String varName) {
        cleanupListObject(ec, (ListObject) ec.removeVariable(varName));
    }

    /**
     * Removes the named list variable from {@code ec} and cleans up the entries selected by {@code status}.
     * Does nothing if the variable is not present.
     */
    public static void cleanupListObject(ExecutionContext ec, String varName, boolean[] status) {
        cleanupListObject(ec, (ListObject) ec.removeVariable(varName), status);
    }

    /** Cleans up the entries of {@code listObject} selected by its own status array; {@code null} is a no-op. */
    public static void cleanupListObject(ExecutionContext ec, ListObject listObject) {
        if (listObject == null) {
            return;
        }
        cleanupListObject(ec, listObject, listObject.getStatus());
    }

    /**
     * Cleans up list entries via {@link #cleanupData(ExecutionContext, Data)}.
     *
     * @param ec         context used for cleanup
     * @param listObject list whose entries are cleaned; {@code null} is a no-op
     * @param status     per-entry mask; {@code null} selects every entry. Assumes the mask covers the list length.
     */
    public static void cleanupListObject(ExecutionContext ec, ListObject listObject, boolean[] status) {
        if (listObject == null) {
            return;
        }
        for (int i = 0; i < listObject.getLength(); i++) {
            if (status == null || status[i]) {
                cleanupData(ec, listObject.getData().get(i));
            }
        }
    }

    /**
     * Enables cleanup on a cacheable data object and asks {@code ec} to clean it up.
     * Non-cacheable data, including {@code null}, is ignored.
     */
    public static void cleanupData(ExecutionContext ec, Data data) {
        if (!(data instanceof CacheableData)) {
            return;
        }
        CacheableData<?> cd = (CacheableData<?>) data;
        cd.enableCleanup(true);
        ec.cleanupCacheableData(cd);
    }

    /** Removes the named variable from {@code ec} and cleans it up if it is cacheable data. */
    public static void cleanupData(ExecutionContext ec, String varName) {
        cleanupData(ec, ec.removeVariable(varName));
    }

    /** Cleans up the status-selected entries of {@code listObject} using a freshly created execution context. */
    public static void cleanupListObject(ListObject listObject) {
        cleanupListObject(ExecutionContextFactory.createContext(), listObject);
    }

    /** Equivalent to {@code newMatrixObject(mb, true)}. */
    public static MatrixObject newMatrixObject(MatrixBlock mb) {
        return newMatrixObject(mb, true);
    }

    /**
     * Wraps a block in a new FP64 {@link MatrixObject} with a unique temp file name and BINARY format.
     *
     * @param mb      block installed via {@code acquireModify} and then released
     * @param cleanup value passed to {@code enableCleanup}
     */
    public static MatrixObject newMatrixObject(MatrixBlock mb, boolean cleanup) {
        // Dimensions are given as -1 (presumably "unknown"; expected to be refreshed once the block is set).
        // NOTE(review): confirm the meaning of the 4th MatrixCharacteristics argument for this constructor.
        MatrixCharacteristics mc = new MatrixCharacteristics(-1L, -1L,
            ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
        MatrixObject mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(),
            new MetaDataFormat(mc, Types.FileFormat.BINARY));
        mo.acquireModify(mb);
        mo.release();
        mo.enableCleanup(cleanup);
        return mo;
    }

    /**
     * Returns a new {@link MatrixObject} with cleanup disabled. It wraps the block returned by the source's
     * {@code acquireReadAndRelease()}; no data copy is made here.
     */
    public static MatrixObject createShallowCopy(MatrixObject mo) {
        return newMatrixObject(mo.acquireReadAndRelease(), false);
    }

    /** Slices rows {@code rl..rh} (1-based) of {@code mo} into a new MatrixObject with cleanup disabled. */
    public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
        return newMatrixObject(sliceMatrixBlock(mo.acquireReadAndRelease(), rl, rh), false);
    }

    /**
     * Slices rows of a block, converting the 1-based bounds to 0-based before delegating.
     *
     * @throws ArithmeticException if a bound does not fit into an {@code int}; the old cast wrapped silently
     */
    public static MatrixBlock sliceMatrixBlock(MatrixBlock mb, long rl, long rh) {
        return mb.slice2(Math.toIntExact(rl) - 1, Math.toIntExact(rh) - 1);
    }

    /**
     * Builds a {@code numEntries x numEntries} ctable from a sequence and a sample of size {@code numEntries}
     * drawn from {@code numEntries} values. The {@code false} flag presumably means without replacement, which
     * yields a permutation matrix.
     */
    public static MatrixBlock generatePermutation(int numEntries, long seed) {
        MatrixBlock seq = new MatrixBlock(numEntries, 1, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, seed);
        return seq.ctableSeqOperations(sample, 1.0d, new MatrixBlock(numEntries, numEntries, true));
    }

    /** Builds an {@code nsamples x nrows} selection matrix from a sample drawn with the replace flag {@code false}. */
    public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) {
        return generateSelectionMatrix(nsamples, nrows, false, seed);
    }

    /** Builds an {@code nsamples x nrows} selection matrix from a sample drawn with the replace flag {@code true}. */
    public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) {
        return generateSelectionMatrix(nsamples, nrows, true, seed);
    }

    /** Shared body of the subsample and replication generators: ctable of a sequence against a sample. */
    private static MatrixBlock generateSelectionMatrix(int nsamples, int nrows, boolean replace, long seed) {
        MatrixBlock seq = new MatrixBlock(nsamples, 1, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, replace, seed);
        return seq.ctableSeqOperations(sample, 1.0d, new MatrixBlock(nsamples, nrows, true), false);
    }

    /** Equivalent to {@code createExecutionContext(ec, varsMap, ignored1, ignored2, k, false)}. */
    public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap,
        String ignored1, String ignored2, int k) {
        return createExecutionContext(ec, varsMap, ignored1, ignored2, k, false);
    }

    /**
     * Recompiles the program blocks and function blocks of {@code ec}'s program with parallelism {@code k}.
     * It then returns a new context holding a copy of {@code varsMap} and deep copies of all functions.
     *
     * @param ignored1        not read by this method (kept for signature compatibility)
     * @param ignored2        not read by this method (kept for signature compatibility)
     * @param forceExecTypeCP if true, every block is force-recompiled to CP
     */
    public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap,
        String ignored1, String ignored2, int k, boolean forceExecTypeCP) {
        Program prog = ec.getProgram();
        recompileProgramBlocks(k, prog.getProgramBlocks(), forceExecTypeCP);
        prog.getFunctionProgramBlocks(functionMapFlag(prog))
            .forEach((fkey, fpb) -> recompileProgramBlocks(k, fpb.getChildBlocks(), forceExecTypeCP));
        return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), copyProgramFunctions(prog));
    }

    /** Creates {@code num} contexts, each with a copy of {@code ec}'s variables and its own function copies. */
    public static List<ExecutionContext> copyExecutionContext(ExecutionContext ec, int num) {
        return IntStream.range(0, num)
            .mapToObj(ignored -> ExecutionContextFactory.createContext(
                new LocalVariableMap(ec.getVariables()), copyProgramFunctions(ec.getProgram())))
            .collect(Collectors.toList());
    }

    /**
     * Returns the flag used with {@code Program.getFunctionProgramBlocks(boolean)} in this class.
     * It is {@code true} when the map returned for {@code false} is empty, and {@code false} otherwise.
     */
    private static boolean functionMapFlag(Program prog) {
        return prog.getFunctionProgramBlocks(false).isEmpty();
    }

    /**
     * Builds a new {@link Program} over the same DML program, containing deep copies of all function blocks.
     * Each copy is also registered as a top-level program block.
     */
    private static Program copyProgramFunctions(Program prog) {
        Program newProg = new Program(prog.getDMLProg());
        boolean opt = functionMapFlag(prog);
        for (Map.Entry<String, FunctionProgramBlock> e : prog.getFunctionProgramBlocks(opt).entrySet()) {
            String[] nsAndName = DMLProgram.splitFunctionKey(e.getKey());
            FunctionProgramBlock copy = ProgramConverter.createDeepCopyFunctionProgramBlock(
                e.getValue(), new HashSet<>(), new HashSet<>());
            copy._namespace = nsAndName[0];
            copy._functionName = nsAndName[1];
            newProg.addFunctionProgramBlock(nsAndName[0], nsAndName[1], copy, opt);
            newProg.addProgramBlock(copy);
        }
        return newProg;
    }

    /** Equivalent to {@code recompileProgramBlocks(k, pbs, false)}. */
    public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs) {
        recompileProgramBlocks(k, pbs, false);
    }

    /**
     * Resets the HOP DAG visit status of each block, then recursively assigns parallelism {@code k} and recompiles.
     *
     * @param forceExecTypeCP if true, every block is recompiled (forced to CP); otherwise only blocks whose
     *                        HOPs were updated are recompiled
     * @throws DMLRuntimeException wrapping any {@link IOException} from recompilation
     */
    public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs, boolean forceExecTypeCP) {
        for (ProgramBlock pb : pbs) {
            DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
        }
        try {
            // Forcing CP implies recompiling everything, so the initial recompile flag equals forceExecTypeCP.
            rAssignParallelismAndRecompile(pbs, k, forceExecTypeCP, forceExecTypeCP);
        } catch (IOException e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Walks the program-block tree, setting parallelism and recompiling where needed.
     * The returned recompile flag is cumulative: once set, every later sibling in this traversal is also
     * recompiled.
     */
    private static boolean rAssignParallelismAndRecompile(List<ProgramBlock> pbs, int k, boolean recompile,
        boolean forceExecTypeCP) throws IOException {
        for (ProgramBlock pb : pbs) {
            // ParForProgramBlock must be tested before ForProgramBlock.
            if (pb instanceof ParForProgramBlock) {
                ParForProgramBlock pfpb = (ParForProgramBlock) pb;
                if (!pfpb.isDegreeOfParallelismFixed()) {
                    pfpb.setDegreeOfParallelism(k);
                    if (k == 1) {
                        pfpb.setOptimizationMode(ParForProgramBlock.POptMode.NONE);
                    }
                    // Children get a single thread (presumably to avoid nested multi-threading).
                    recompile |= rAssignParallelismAndRecompile(pfpb.getChildBlocks(), 1, recompile, forceExecTypeCP);
                }
            } else if (pb instanceof ForProgramBlock) {
                recompile |= rAssignParallelismAndRecompile(
                    ((ForProgramBlock) pb).getChildBlocks(), k, recompile, forceExecTypeCP);
            } else if (pb instanceof WhileProgramBlock) {
                recompile |= rAssignParallelismAndRecompile(
                    ((WhileProgramBlock) pb).getChildBlocks(), k, recompile, forceExecTypeCP);
            } else if (pb instanceof FunctionProgramBlock) {
                recompile |= rAssignParallelismAndRecompile(
                    ((FunctionProgramBlock) pb).getChildBlocks(), k, recompile, forceExecTypeCP);
            } else if (pb instanceof IfProgramBlock) {
                IfProgramBlock ipb = (IfProgramBlock) pb;
                recompile |= rAssignParallelismAndRecompile(ipb.getChildBlocksIfBody(), k, recompile, forceExecTypeCP);
                if (ipb.getChildBlocksElseBody() != null) {
                    recompile |= rAssignParallelismAndRecompile(
                        ipb.getChildBlocksElseBody(), k, recompile, forceExecTypeCP);
                }
            } else {
                for (Hop hop : pb.getStatementBlock().getHops()) {
                    recompile |= rAssignParallelismAndRecompile(hop, k, recompile);
                }
            }
            if (recompile) {
                if (forceExecTypeCP) {
                    Recompiler.rRecompileProgramBlock2Forced(pb, pb.getThreadID(), new HashSet<>(), Types.ExecType.CP);
                } else {
                    Recompiler.recompileProgramBlockInstructions(pb);
                }
            }
        }
        return recompile;
    }

    /**
     * Depth-first pass over a HOP DAG that sets {@code k} threads on every unvisited {@link MultiThreadedHop}.
     * Hops are marked visited after processing. Returns true if any such hop was found (or {@code recompile}
     * was already true).
     */
    private static boolean rAssignParallelismAndRecompile(Hop hop, int k, boolean recompile) {
        if (hop.isVisited()) {
            return recompile;
        }
        if (hop instanceof MultiThreadedHop) {
            ((MultiThreadedHop) hop).setMaxNumThreads(k);
            recompile = true;
        }
        for (Hop input : hop.getInput()) {
            recompile |= rAssignParallelismAndRecompile(input, k, recompile);
        }
        hop.setVisited();
        return recompile;
    }

    /** Appends {@code right} to {@code left} via {@code MatrixBlock.append} into a new result block. */
    public static MatrixBlock cbindMatrix(MatrixBlock left, MatrixBlock right) {
        return left.append(right, new MatrixBlock());
    }

    /** Equivalent to {@code accrueGradients(accGradients, gradients, false, cleanup)}. */
    public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean cleanup) {
        return accrueGradients(accGradients, gradients, false, cleanup);
    }

    /**
     * Adds {@code gradients} element-wise into {@code accGradients}, modifying its blocks in place.
     * If {@code accGradients} is null, a shallow copy of {@code gradients} is returned instead.
     *
     * @param par     process list entries with a parallel stream
     * @param cleanup clean up {@code gradients} afterwards
     * @throws DMLRuntimeException if the two lists differ in length
     */
    public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par,
        boolean cleanup) {
        return accrueInPlace(accGradients, gradients, par, cleanup);
    }

    /** Equivalent to {@code accrueModels(accModels, models, false, cleanup)}. */
    public static ListObject accrueModels(ListObject accModels, ListObject models, boolean cleanup) {
        return accrueModels(accModels, models, false, cleanup);
    }

    /**
     * Adds {@code models} element-wise into {@code accModels}, modifying its blocks in place.
     * If {@code accModels} is null, a shallow copy of {@code models} is returned instead.
     *
     * @param par     process list entries with a parallel stream
     * @param cleanup clean up {@code models} afterwards
     * @throws DMLRuntimeException if the two lists differ in length
     */
    public static ListObject accrueModels(ListObject accModels, ListObject models, boolean par, boolean cleanup) {
        return accrueInPlace(accModels, models, par, cleanup);
    }

    /**
     * Shared body of accrueGradients and accrueModels. For each index it adds {@code delta}'s matrix into
     * {@code acc}'s matrix in place. Both lists must hold MatrixObjects.
     */
    private static ListObject accrueInPlace(ListObject acc, ListObject delta, boolean par, boolean cleanup) {
        if (acc == null) {
            return copyList(delta, cleanup);
        }
        // Guard against silently ignoring extra entries or failing with an index error mid-way.
        if (acc.getLength() != delta.getLength()) {
            throw new DMLRuntimeException("Accrue: list length mismatch ("
                + acc.getLength() + " vs. " + delta.getLength() + ").");
        }
        IntStream range = IntStream.range(0, acc.getLength());
        (par ? range.parallel() : range).forEach(i -> {
            MatrixBlock target = ((MatrixObject) acc.getData().get(i)).acquireReadAndRelease();
            MatrixBlock addend = ((MatrixObject) delta.getData().get(i)).acquireReadAndRelease();
            target.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), (MatrixValue) addend);
        });
        if (cleanup) {
            cleanupListObject(delta);
        }
        return acc;
    }
}
