package org.apache.sysds.runtime.functionobjects;

import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/* loaded from: input_file:org/apache/sysds/runtime/functionobjects/DiagIndex.class */
/**
 * Index function for the diagonal reorg operation, covering both directions: vector-to-matrix
 * (an n x 1 input becomes an n x n output) and matrix-to-vector (an input with more than one
 * column becomes an n x 1 output).
 *
 * <p>Instances are created through {@link #getDiagIndexFnObject()} and
 * {@link #getDiagIndexFnObject(boolean)}; every call returns a new instance.
 */
public class DiagIndex extends IndexFunction {
    private static final long serialVersionUID = -5294771266108903886L;

    /**
     * Non-zero count reported for the output when it is not derived by this function.
     * NOTE(review): relies on DataCharacteristics treating a negative nnz as "unknown".
     */
    private static final long UNKNOWN_NNZ = -1L;

    /** {@code true} for vector-to-matrix, {@code false} for matrix-to-vector. */
    private final boolean diagV2M;

    private DiagIndex(boolean v2m) {
        this.diagV2M = v2m;
    }

    /**
     * Creates a vector-to-matrix diag index function.
     *
     * @return a new instance, equivalent to {@code getDiagIndexFnObject(true)}
     */
    public static DiagIndex getDiagIndexFnObject() {
        return getDiagIndexFnObject(true);
    }

    /**
     * Creates a diag index function for the requested direction.
     *
     * @param v2m {@code true} for vector-to-matrix, {@code false} for matrix-to-vector
     * @return a new instance (no caching)
     */
    public static DiagIndex getDiagIndexFnObject(boolean v2m) {
        return new DiagIndex(v2m);
    }

    /**
     * Maps an input block index to the output block index. The output row block equals the
     * input row block. The output column block equals the input row block for
     * vector-to-matrix, and is always 1 for matrix-to-vector.
     *
     * @param in  input block index (read only)
     * @param out receives the output block index
     */
    @Override
    public void execute(MatrixIndexes in, MatrixIndexes out) {
        out.setIndexes(in.getRowIndex(), diagV2M ? in.getRowIndex() : 1L);
    }

    /**
     * Maps an input cell index {@code (r, c)} to {@code (r, r)}.
     *
     * <p>NOTE(review): unlike the block-level overload, this ignores {@code diagV2M}. It
     * therefore only yields a diagonal placement (vector-to-matrix); confirm that no
     * matrix-to-vector caller relies on it.
     *
     * @param in  input cell index (read only)
     * @param out receives the output cell index
     */
    @Override
    public void execute(MatrixValue.CellIndex in, MatrixValue.CellIndex out) {
        out.set(in.row, in.row);
    }

    /**
     * Computes the output dimensions from the input dimensions.
     *
     * <p>The direction is inferred from the input shape, not from {@code diagV2M}: a single
     * input column means vector-to-matrix ({@code row x row}), otherwise matrix-to-vector
     * ({@code row x 1}).
     *
     * @param row    number of input rows
     * @param col    number of input columns
     * @param retDim receives the output dimensions
     * @return always {@code false}
     */
    @Override
    public boolean computeDimension(int row, int col, MatrixValue.CellIndex retDim) {
        retDim.set(row, (col == 1) ? row : 1);
        return false;
    }

    /**
     * Computes the output data characteristics from the input characteristics.
     *
     * <p>The output shape uses the same input-shape rule as the {@code int} overload. The
     * blocksize is carried over from the input. The output non-zero count is reported as
     * unknown.
     *
     * @param in  input characteristics (read only)
     * @param out receives the output dimensions, blocksize and nnz
     * @return always {@code false}
     */
    @Override
    public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
        long rows = in.getRows();
        long cols = (in.getCols() == 1) ? rows : 1L;
        // The 4-arg set is (rows, cols, blocksize, nnz). Previously the blocksize was passed in
        // the nnz slot, recording a bogus non-zero count. Report nnz as unknown instead; the
        // long literal keeps overload resolution on the nnz variant.
        out.set(rows, cols, in.getBlocksize(), UNKNOWN_NNZ);
        return false;
    }
}
