package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/functions/ReplicateTensorFunction.class */
public class ReplicateTensorFunction implements PairFlatMapFunction<Tuple2<TensorIndexes, TensorBlock>, TensorIndexes, TensorBlock> {
    private static final long serialVersionUID = 7181347334827684965L;
    private int _byDim;
    private long _numReplicas;

    public ReplicateTensorFunction(int i, long j) {
        this._byDim = i;
        this._numReplicas = j;
    }

    public Iterator<Tuple2<TensorIndexes, TensorBlock>> call(Tuple2<TensorIndexes, TensorBlock> tuple2) throws Exception {
        TensorIndexes tensorIndexes = (TensorIndexes) tuple2._1();
        TensorBlock tensorBlock = (TensorBlock) tuple2._2();
        if (tensorIndexes.getIndex(this._byDim) != 1 || (tensorBlock.getNumDims() > this._byDim && tensorBlock.getDim(this._byDim) > 1)) {
            throw new Exception("Expected dimension " + this._byDim + " to be 1 in ReplicateTensor");
        }
        ArrayList arrayList = new ArrayList();
        long[] indexes = tensorIndexes.getIndexes();
        for (int i = 1; i <= this._numReplicas; i++) {
            indexes[this._byDim] = i;
            arrayList.add(new Tuple2(new TensorIndexes(indexes), tensorBlock));
        }
        return arrayList.iterator();
    }
}
