package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.Iterator;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import scala.Tuple2;

/**
 * Spark partition function that applies a tensor-tensor binary operation between every block of a
 * distributed left-hand tensor and a block looked up from a broadcast right-hand tensor.
 *
 * <p>For each input pair {@code (ix, lhs)} the right-hand block is fetched from the broadcast at a
 * block index derived from {@code ix}. Dimensions flagged in the replication array use block index
 * {@code 1}; all other dimensions reuse the left-hand block's index. The result is emitted under
 * the original key {@code ix}. The mapping is performed by a {@link LazyIterableIterator} wrapped
 * around the partition iterator.
 *
 * <p>NOTE(review): the use of index {@code 1} presumes block indexes are 1-based and that the
 * broadcast spans a single block along every replicated dimension — confirm against callers.
 */
public class TensorTensorBinaryOpPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<TensorIndexes, TensorBlock>>, TensorIndexes, TensorBlock> {
    private static final long serialVersionUID = 8029096658247920867L;

    /** Operator applied as {@code lhs.binaryOperations(_op, rhs, ...)}. */
    private final BinaryOperator _op;
    /** Broadcast right-hand operand, addressed by an {@code int[]} block index. */
    private final PartitionedBroadcast<TensorBlock> _ptV;
    /** Per-dimension flags; {@code true} pins the broadcast block index to 1 in that dimension. */
    private final boolean[] _replicateDim;

    /**
     * Lazily maps each {@code (ix, lhs)} pair of a partition to
     * {@code (ix, lhs op broadcastBlock(ix))}.
     */
    public class MapBinaryPartitionIterator extends LazyIterableIterator<Tuple2<TensorIndexes, TensorBlock>> {
        public MapBinaryPartitionIterator(Iterator<Tuple2<TensorIndexes, TensorBlock>> it) {
            super(it);
        }

        /**
         * Applies the binary operation to one input pair.
         *
         * @param tuple2 pair of left-hand block index and left-hand block
         * @return pair of the unchanged index and the freshly allocated result block
         */
        @Override
        public Tuple2<TensorIndexes, TensorBlock> computeNext(Tuple2<TensorIndexes, TensorBlock> tuple2) {
            TensorIndexes ix = tuple2._1();
            TensorBlock lhs = tuple2._2();

            TensorBlock rhs = _ptV.getBlock(broadcastBlockIndex(ix));
            TensorBlock result = lhs.binaryOperations(_op, rhs, new TensorBlock());
            return new Tuple2<>(ix, result);
        }

        /**
         * Computes the broadcast block index matching a left-hand block index.
         *
         * @param ix left-hand block index
         * @return one entry per broadcast dimension: {@code 1} for replicated dimensions, otherwise
         *         the corresponding entry of {@code ix}
         * @throws ArithmeticException if a left-hand block index does not fit in an {@code int}
         */
        private int[] broadcastBlockIndex(TensorIndexes ix) {
            int[] index = new int[_ptV.getDataCharacteristics().getNumDims()];
            for (int i = 0; i < index.length; i++) {
                // toIntExact instead of a plain cast: an overflowing index would otherwise wrap
                // around and silently address the wrong broadcast block.
                index[i] = _replicateDim[i] ? 1 : Math.toIntExact(ix.getIndex(i));
            }
            return index;
        }
    }

    /**
     * @param binaryOperator       operator applied to every (left, broadcast) block pair
     * @param partitionedBroadcast broadcast right-hand tensor
     * @param zArr                 per-dimension replication flags; read for every dimension of the
     *                             broadcast, so presumably at least that long — verify at call sites
     */
    public TensorTensorBinaryOpPartitionFunction(BinaryOperator binaryOperator, PartitionedBroadcast<TensorBlock> partitionedBroadcast, boolean[] zArr) {
        this._op = binaryOperator;
        this._ptV = partitionedBroadcast;
        this._replicateDim = zArr;
    }

    /**
     * Wraps a partition's iterator so that each element is transformed on demand.
     *
     * @param it iterator over the partition's {@code (index, block)} pairs
     * @return a lazy iterator producing the transformed pairs
     */
    @Override
    public LazyIterableIterator<Tuple2<TensorIndexes, TensorBlock>> call(Iterator<Tuple2<TensorIndexes, TensorBlock>> it) throws Exception {
        return new MapBinaryPartitionIterator(it);
    }
}
