package org.apache.sysds.runtime.compress.cost;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.runtime.compress.workload.Op;
import org.apache.sysds.runtime.compress.workload.OpMetadata;
import org.apache.sysds.runtime.compress.workload.OpSided;
import org.apache.sysds.runtime.compress.workload.WTreeNode;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;

/* loaded from: input_file:org/apache/sysds/runtime/compress/cost/CostEstimatorBuilder.class */
/**
 * Builds cost estimators from a compression workload tree.
 *
 * <p>Walking a {@link WTreeRoot} accumulates into an {@link InstructionTypeCounter} how often each
 * kind of operation occurs: scans, left/right/compressed multiplications, dictionary operations,
 * indexing, and (overlapping) decompressions. Ops inside a {@link WTreeNode} are weighted by that
 * node's {@code getReps()} value, multiplied through every enclosing node.
 */
public final class CostEstimatorBuilder implements Serializable {
    private static final long serialVersionUID = 14L;

    /**
     * The net operation count (counted ops minus decompressions) must be strictly greater than
     * this value for {@link #shouldTryToCompress()} to return {@code true}.
     */
    private static final int MIN_NET_COMPRESSED_OPS = 4;

    protected static final Log LOG = LogFactory.getLog(CostEstimatorBuilder.class.getName());

    /** Accumulated operation counts; shared with estimators created by this builder. */
    protected final InstructionTypeCounter counter;

    /**
     * Creates a builder by walking the given workload tree.
     *
     * @param wTreeRoot root of the workload tree; its own ops are counted with weight 1 and its
     *     child nodes are weighted by their repetition counts
     */
    public CostEstimatorBuilder(WTreeRoot wTreeRoot) {
        this.counter = new InstructionTypeCounter();
        // A decompressing root adds one decompression on top of whatever its ops contribute.
        if (wTreeRoot.isDecompressing()) {
            this.counter.decompressions++;
        }
        for (Op op : wTreeRoot.getOps()) {
            addOp(1, op, this.counter);
        }
        for (WTreeNode child : wTreeRoot.getChildNodes()) {
            addNode(1, child, this.counter);
        }
    }

    /**
     * Creates a builder around an already-populated counter. The counter is used directly, not
     * copied.
     *
     * @param instructionTypeCounter the counter to build estimators from
     */
    public CostEstimatorBuilder(InstructionTypeCounter instructionTypeCounter) {
        this.counter = instructionTypeCounter;
    }

    /**
     * Creates a computation-based cost estimator over this builder's counter.
     *
     * <p>NOTE(review): the decompiler reports this was widened from {@code protected}; it is kept
     * {@code public} so existing callers keep compiling.
     *
     * @param unusedFlag currently ignored; the same kind of estimator is returned either way
     * @return a new {@link ComputationCostEstimator} sharing this builder's counter
     */
    public ACostEstimate create(boolean unusedFlag) {
        return new ComputationCostEstimator(this.counter);
    }

    /**
     * Creates a hybrid cost estimator over this builder's counter.
     *
     * <p>NOTE(review): the decompiler reports this was widened from {@code protected}; it is kept
     * {@code public} so existing callers keep compiling.
     *
     * @return a new {@link HybridCostEstimator} sharing this builder's counter
     */
    public ACostEstimate createHybrid() {
        return new HybridCostEstimator(this.counter);
    }

    /**
     * Returns the live counter (not a copy).
     *
     * @return the accumulated instruction-type counter
     */
    public InstructionTypeCounter getCounter() {
        return this.counter;
    }

    /**
     * Adds every op of {@code node}, and recursively of its child nodes, to {@code typeCounter}.
     * Each contribution is scaled by {@code count * node.getReps()}.
     */
    private static void addNode(int count, WTreeNode node, InstructionTypeCounter typeCounter) {
        final int nodeCount = count * node.getReps();
        for (Op op : node.getOps()) {
            addOp(nodeCount, op, typeCounter);
        }
        for (WTreeNode child : node.getChildNodes()) {
            addNode(nodeCount, child, typeCounter);
        }
    }

    /**
     * Classifies a single op and adds its contribution, scaled by {@code count}, to
     * {@code typeCounter}.
     */
    private static void addOp(int count, Op op, InstructionTypeCounter typeCounter) {
        if (op.isDecompressing()) {
            if (op.isOverlapping()) {
                typeCounter.overlappingDecompressions += count * op.dim();
            } else {
                typeCounter.decompressions += count;
            }
        }
        if (op.isDensifying()) {
            typeCounter.isDensifying = true;
        }

        if (op instanceof OpSided) {
            final OpSided sided = (OpSided) op;
            final int weighted = count * op.dim();
            if (sided.isLeftMM()) {
                typeCounter.leftMultiplications += weighted;
            } else if (sided.isRightMM()) {
                typeCounter.rightMultiplications += weighted;
            } else {
                typeCounter.compressedMultiplications += weighted;
            }
        } else if (op instanceof OpMetadata) {
            // Metadata-only ops contribute nothing beyond the flags handled above.
        } else {
            addHopOp(count, op.getHop(), typeCounter);
        }
    }

    /** Classifies an op by its underlying {@link Hop} and adds its weighted contribution. */
    private static void addHopOp(int count, Hop hop, InstructionTypeCounter typeCounter) {
        if (hop instanceof AggUnaryOp) {
            switch (((AggUnaryOp) hop).getDirection()) {
                case Row:
                    typeCounter.scans += count;
                    break;
                default:
                    typeCounter.dictionaryOps += count;
                    break;
            }
        } else if (hop instanceof IndexingOp) {
            final IndexingOp indexingOp = (IndexingOp) hop;
            if (indexingOp.isRowLowerEqualsUpper() && indexingOp.isColLowerEqualsUpper()) {
                // Row and column bounds each collapse to one index (presumably a single-cell
                // lookup). Weighted by count like every other op; this was previously counted
                // once regardless of repetitions.
                typeCounter.indexing += count;
            } else if (indexingOp.isAllRows()) {
                typeCounter.dictionaryOps += count;
            }
            // Other indexing shapes add nothing to the counters.
        } else {
            typeCounter.dictionaryOps += count;
        }
    }

    /**
     * Decides whether compression looks worthwhile for the counted workload.
     *
     * <p>Counts scans, all multiplication kinds and dictionary ops, then subtracts plain and
     * overlapping decompressions. The sums are done in {@code long} so that large repetition-scaled
     * counts cannot overflow.
     *
     * @return {@code true} if that net count exceeds {@link #MIN_NET_COMPRESSED_OPS}
     */
    public boolean shouldTryToCompress() {
        final long countedOps = (long) this.counter.scans
            + this.counter.leftMultiplications
            + this.counter.rightMultiplications
            + this.counter.compressedMultiplications
            + this.counter.dictionaryOps;
        final long decompressingOps =
            (long) this.counter.decompressions + this.counter.overlappingDecompressions;
        return countedOps - decompressingOps > MIN_NET_COMPRESSED_OPS;
    }

    @Override
    public String toString() {
        return "CostVector: " + this.counter;
    }
}
