package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/functions/ReplicateVectorFunction.class */
public class ReplicateVectorFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
    private static final long serialVersionUID = -1505557561471236851L;
    private boolean _byRow;
    private long _numReplicas;

    public ReplicateVectorFunction(boolean z, long j) {
        this._byRow = z;
        this._numReplicas = j;
    }

    public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
        MatrixIndexes matrixIndexes = (MatrixIndexes) tuple2._1();
        MatrixBlock matrixBlock = (MatrixBlock) tuple2._2();
        if (this._byRow && (matrixIndexes.getRowIndex() != 1 || matrixBlock.getNumRows() > 1)) {
            throw new Exception("Expected a row vector in ReplicateVector");
        }
        if (!this._byRow && (matrixIndexes.getColumnIndex() != 1 || matrixBlock.getNumColumns() > 1)) {
            throw new Exception("Expected a column vector in ReplicateVector");
        }
        ArrayList arrayList = new ArrayList();
        for (int i = 1; i <= this._numReplicas; i++) {
            if (this._byRow) {
                arrayList.add(new Tuple2(new MatrixIndexes(i, matrixIndexes.getColumnIndex()), matrixBlock));
            } else {
                arrayList.add(new Tuple2(new MatrixIndexes(matrixIndexes.getRowIndex(), i), matrixBlock));
            }
        }
        return arrayList.iterator();
    }
}
