package org.apache.sysds.runtime.controlprogram.federated;

import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.class */
/**
 * Collects per-matrix instruction counters for aggregate-binary instructions and
 * periodically hands the collected counters to
 * {@link CompressedMatrixBlockFactory#compressAsync}.
 *
 * <p>Counters are kept in a two-level map: the outer key is the {@code long} id passed to
 * {@link #incrementWorkload(ExecutionContext, long, Instruction)} and
 * {@link #compressRun(ExecutionContext, long)}; the inner key is the numeric variable name
 * of an instruction input.
 *
 * <p>NOTE(review): the maps are concurrent, but {@code counter} is a plain {@code int}
 * updated without synchronization. If this class is shared between threads, increments of
 * {@code counter} may be lost, which would only shift when a compression run is triggered
 * - confirm whether that is acceptable.
 */
public class FederatedWorkloadAnalyzer {
    protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName());

    /** A compression run is triggered when {@code counter % COMPRESS_RUN_FREQUENCY} equals this value minus one. */
    private static final int COMPRESS_RUN_FREQUENCY = 10;

    /** Outer id -> (numeric variable name -> instruction counter). */
    private final ConcurrentHashMap<Long, ConcurrentHashMap<Long, InstructionTypeCounter>> m = new ConcurrentHashMap<>();

    /** Number of counter updates recorded so far; drives the compression-run check. */
    private int counter = 0;

    /**
     * Records the workload of {@code instruction} under id {@code j}.
     * Instructions that are not {@link ComputationCPInstruction}s are ignored.
     *
     * @param executionContext context used to look up the instruction's inputs
     * @param j                key of the outer counter map
     * @param instruction      the instruction to analyze
     */
    public void incrementWorkload(ExecutionContext executionContext, long j, Instruction instruction) {
        if (instruction instanceof ComputationCPInstruction) {
            incrementWorkload(executionContext, j, (ComputationCPInstruction) instruction);
        }
    }

    /**
     * Submits every counter stored under id {@code j} for asynchronous compression, but only
     * when the update counter is at the last position of a {@code COMPRESS_RUN_FREQUENCY} cycle.
     *
     * @param executionContext context passed through to the compression call
     * @param j                key of the outer counter map
     */
    public void compressRun(ExecutionContext executionContext, long j) {
        if (this.counter % COMPRESS_RUN_FREQUENCY != COMPRESS_RUN_FREQUENCY - 1) {
            return;
        }
        get(j).forEach((id, typeCounter) ->
            CompressedMatrixBlockFactory.compressAsync(executionContext, Long.toString(id), typeCounter));
    }

    /** Resolves the counter map for id {@code j} and delegates to the map-based overload. */
    private void incrementWorkload(ExecutionContext executionContext, long j, ComputationCPInstruction computationCPInstruction) {
        incrementWorkload(executionContext, get(j), computationCPInstruction);
    }

    /**
     * Updates the counters in {@code concurrentHashMap} for the two inputs of an
     * {@link AggregateBinaryCPInstruction}; any other instruction type is ignored.
     *
     * <p>For each input whose dimensions pass {@code validSize}, the counter registered under
     * the input's name (parsed as a {@code long}) is updated: {@code incRMM} with dimension 0
     * of the first input, {@code incLMM} with dimension 1 of the second input. Each update
     * also increments the internal update counter.
     *
     * @param executionContext         context used to look up both inputs as {@link MatrixObject}s
     * @param concurrentHashMap        counters keyed by numeric variable name
     * @param computationCPInstruction the instruction to analyze
     * @throws NumberFormatException if the name of an input that passes the size check is not
     *                               a parsable {@code long}
     */
    public void incrementWorkload(ExecutionContext executionContext, ConcurrentHashMap<Long, InstructionTypeCounter> concurrentHashMap, ComputationCPInstruction computationCPInstruction) {
        if (!(computationCPInstruction instanceof AggregateBinaryCPInstruction)) {
            return;
        }
        final String leftName = computationCPInstruction.input1.getName();
        final MatrixObject left = (MatrixObject) executionContext.getCacheableData(leftName);
        final String rightName = computationCPInstruction.input2.getName();
        final MatrixObject right = (MatrixObject) executionContext.getCacheableData(rightName);

        // Dimensions 0 and 1 are presumably rows and columns - confirm against MatrixObject.getDim.
        final int leftDim0 = (int) left.getDim(0);
        final int leftDim1 = (int) left.getDim(1);
        final int rightDim0 = (int) right.getDim(0);
        final int rightDim1 = (int) right.getDim(1);

        if (validSize(leftDim0, leftDim1)) {
            getOrMakeCounter(concurrentHashMap, Long.parseLong(leftName)).incRMM(leftDim0);
            this.counter++;
        }
        if (validSize(rightDim0, rightDim1)) {
            getOrMakeCounter(concurrentHashMap, Long.parseLong(rightName)).incLMM(rightDim1);
            this.counter++;
        }
    }

    /**
     * Returns the counter registered under {@code j}, creating and registering a new one if
     * none exists. Uses {@code computeIfAbsent} so that concurrent callers always share the
     * same counter instance instead of overwriting each other's entry.
     */
    private static InstructionTypeCounter getOrMakeCounter(ConcurrentHashMap<Long, InstructionTypeCounter> concurrentHashMap, long j) {
        return concurrentHashMap.computeIfAbsent(j, key -> new InstructionTypeCounter());
    }

    /**
     * Returns the counter map registered under {@code j}, creating and registering an empty
     * one if none exists. Uses {@code computeIfAbsent} so that concurrent callers always
     * share the same inner map instead of overwriting each other's entry.
     */
    private ConcurrentHashMap<Long, InstructionTypeCounter> get(long j) {
        return this.m.computeIfAbsent(j, key -> new ConcurrentHashMap<>());
    }

    /** An input qualifies when its first dimension exceeds 90 and is at least as large as its second. */
    private static boolean validSize(int i, int i2) {
        return i > 90 && i >= i2;
    }
}
