package org.apache.sysds.hops;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.recompile.RecompileStatus;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/hops/MemoTable.class */
/**
 * Memo of data characteristics (rows, columns, non-zeros), keyed by hop ID.
 *
 * <p>Entries are added explicitly via {@link #memoizeStatistics} or seeded from the
 * transient-write statistics of a {@link RecompileStatus} (see {@code init}). They are used by
 * {@code getAllInputStats} to fill in dimensions / non-zero counts a hop does not know itself,
 * and {@link #extract} writes statistics of transient writes back into a {@link RecompileStatus}.
 */
public class MemoTable {
    /** Memoized characteristics, keyed by {@code Hop#getHopID()} boxed as {@link Long}. */
    private final HashMap<Long, DataCharacteristics> _memo = new HashMap<>();

    /** Creates an empty memo table. */
    public MemoTable() {
        // The memo map is created at its declaration; nothing else to initialize.
    }

    /**
     * Seeds the memo from the transient-write statistics of {@code status} for every DAG
     * reachable from {@code roots}.
     *
     * <p>Resets the visit status of the hops, then walks each DAG and, for every matrix
     * {@code TRANSIENTREAD} whose variable name has an entry in {@code status.getTWriteStats()},
     * memoizes that entry under the read hop's ID. Does nothing if {@code roots} is null or empty,
     * {@code status} is null, or it holds no transient-write statistics.
     *
     * @param roots root hops of the DAG(s) to scan
     * @param status source of transient-write statistics
     */
    public void init(ArrayList<Hop> roots, RecompileStatus status) {
        if (roots == null || roots.isEmpty() || !hasTWriteStats(status)) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            rinit(root, status);
        }
    }

    /**
     * Seeds the memo from the transient-write statistics of {@code status} for the DAG rooted
     * at {@code root}; single-root variant of {@link #init(ArrayList, RecompileStatus)}.
     *
     * @param root root hop of the DAG to scan; ignored if null
     * @param status source of transient-write statistics; ignored if null or empty
     */
    public void init(Hop root, RecompileStatus status) {
        if (root == null || !hasTWriteStats(status)) {
            return;
        }
        root.resetVisitStatus();
        rinit(root, status);
    }

    /**
     * Records, for every {@code TRANSIENTWRITE} {@link DataOp} in {@code hops}, the statistics of
     * its first input under the written variable name in {@code status.getTWriteStats()}.
     *
     * <p>Only the given hops are inspected (no recursion into their inputs). If no statistics
     * can be derived (the input hop is null), any existing entry for that name is removed.
     *
     * @param hops hops to inspect; a null list is ignored (consistent with {@code init})
     * @param status status to update; a null status is ignored
     */
    public void extract(ArrayList<Hop> hops, RecompileStatus status) {
        if (hops == null || status == null) {
            return;
        }
        for (Hop hop : hops) {
            if (hop instanceof DataOp && ((DataOp) hop).getOp() == Types.OpOpData.TRANSIENTWRITE) {
                String name = hop.getName();
                DataCharacteristics stats = getAllInputStats(hop.getInput().get(0));
                if (stats != null) {
                    status.getTWriteStats().put(name, stats);
                } else {
                    status.getTWriteStats().remove(name);
                }
            }
        }
    }

    /**
     * Memoizes the given dimensions and non-zero count for a hop, replacing any prior entry.
     *
     * @param hopId ID of the hop the statistics belong to
     * @param rows number of rows
     * @param cols number of columns
     * @param nnz number of non-zeros
     */
    public void memoizeStatistics(long hopId, long rows, long cols, long nnz) {
        // -1 is passed as the third MatrixCharacteristics argument, as everywhere in this class.
        _memo.put(Long.valueOf(hopId), new MatrixCharacteristics(rows, cols, -1, nnz));
    }

    /**
     * Memoizes the given characteristics for a hop, replacing any prior entry.
     *
     * @param hopId ID of the hop the statistics belong to
     * @param stats characteristics to store (stored by reference, not copied)
     */
    public void memoizeStatistics(long hopId, DataCharacteristics stats) {
        _memo.put(Long.valueOf(hopId), stats);
    }

    /**
     * Returns the characteristics of each hop in {@code inputs}, enriched from the memo where the
     * hop's own dimensions are not known.
     *
     * @param inputs hops to describe; may be null
     * @return one entry per input (same order), or null if {@code inputs} is null
     */
    public DataCharacteristics[] getAllInputStats(ArrayList<Hop> inputs) {
        if (inputs == null) {
            return null;
        }
        // Allocate with the declared element type: a MatrixCharacteristics[] returned as
        // DataCharacteristics[] would throw ArrayStoreException if a caller stored another
        // DataCharacteristics subtype into it.
        DataCharacteristics[] stats = new DataCharacteristics[inputs.size()];
        for (int i = 0; i < stats.length; i++) {
            stats[i] = statsWithMemo(inputs.get(i), false);
        }
        return stats;
    }

    /**
     * Returns the characteristics of {@code input}, enriched from the memo unless its dimensions
     * and non-zero count are all known.
     *
     * @param input hop to describe; may be null
     * @return the (possibly enriched) characteristics, or null if {@code input} is null
     */
    public DataCharacteristics getAllInputStats(Hop input) {
        return input == null ? null : statsWithMemo(input, true);
    }

    /**
     * Returns whether statistics are available for the inputs of {@code hop}.
     *
     * <p>True if any direct input has at least one known dimension or a memo entry, or if
     * {@code hop} itself is a {@code TRANSIENTREAD} {@link DataOp} or a {@link DataGenOp}.
     *
     * @param hop hop to check; must not be null
     * @return whether input statistics are available
     */
    public boolean hasInputStatistics(Hop hop) {
        for (Hop input : hop.getInput()) {
            if (input.dimsKnownAny() || _memo.containsKey(Long.valueOf(input.getHopID()))) {
                return true;
            }
        }
        return (hop instanceof DataOp && ((DataOp) hop).getOp() == Types.OpOpData.TRANSIENTREAD)
            || hop instanceof DataGenOp;
    }

    /** Returns whether {@code status} is non-null and holds transient-write statistics. */
    private static boolean hasTWriteStats(RecompileStatus status) {
        return status != null && !status.getTWriteStats().isEmpty();
    }

    /**
     * Builds characteristics for {@code hop} from its own dimensions and non-zero count, filling
     * in unknown values from the memo when the hop is not fully known.
     *
     * <p>Rows/columns {@code <= 0} are taken from the memo entry.
     *
     * <p>With {@code requireNnz} false, "known" means {@code hop.dimsKnown()}, and
     * {@code nnz <= 0} is replaced from the memo. With {@code requireNnz} true, "known" means
     * {@code hop.dimsKnown(true)}, and only {@code nnz < 0} is replaced.
     *
     * <p>NOTE(review): the list and single-hop callers disagree on whether {@code nnz == 0} is
     * unknown; this is preserved as-is. Confirm which is intended.
     */
    private DataCharacteristics statsWithMemo(Hop hop, boolean requireNnz) {
        long rows = hop.getDim1();
        long cols = hop.getDim2();
        long nnz = hop.getNnz();
        boolean known = requireNnz ? hop.dimsKnown(true) : hop.dimsKnown();
        if (!known) {
            DataCharacteristics memo = _memo.get(Long.valueOf(hop.getHopID()));
            if (memo != null) {
                rows = rows <= 0 ? memo.getRows() : rows;
                cols = cols <= 0 ? memo.getCols() : cols;
                boolean nnzUnknown = requireNnz ? nnz < 0 : nnz <= 0;
                nnz = nnzUnknown ? memo.getNonZeros() : nnz;
            }
        }
        return new MatrixCharacteristics(rows, cols, -1, nnz);
    }

    /**
     * Recursively walks the DAG below {@code hop} (skipping already-visited hops), memoizing
     * {@code status.getTWriteStats()} entries for matrix {@code TRANSIENTREAD} hops by variable
     * name, and marks each processed hop visited.
     */
    private void rinit(Hop hop, RecompileStatus status) {
        if (hop.isVisited()) {
            return;
        }
        if (hop instanceof DataOp
            && hop.getDataType() == Types.DataType.MATRIX
            && ((DataOp) hop).getOp() == Types.OpOpData.TRANSIENTREAD) {
            DataCharacteristics stats = status.getTWriteStats().get(hop.getName());
            if (stats != null) {
                _memo.put(Long.valueOf(hop.getHopID()), stats);
            }
        }
        if (hop.getInput() != null) {
            for (Hop input : hop.getInput()) {
                rinit(input, status);
            }
        }
        hop.setVisited();
    }
}
