/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.text.functions;

import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.Accumulator;
import org.apache.spark.AccumulatorParam;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.spark.text.accumulators.MaxPerPartitionAccumulator;
import org.deeplearning4j.spark.text.functions.FoldBetweenPartitionFunction;
import org.deeplearning4j.spark.text.functions.FoldWithinPartitionFunction;
import org.deeplearning4j.spark.text.functions.MapPerPartitionVoidFunction;

/**
 * Builds a cumulative (running) sum over an RDD of per-sentence counts.
 *
 * <p>The work is done in two steps, both of which preserve partitioning:
 * <ol>
 *   <li>{@link #cumSumWithinPartition()} folds every partition locally with
 *       {@link FoldWithinPartitionFunction}. It gathers a {@code Counter<Integer>} through a
 *       {@link MaxPerPartitionAccumulator} and broadcasts that counter to the workers. The counter
 *       presumably holds partition index to that partition's total; confirm against
 *       {@code FoldWithinPartitionFunction}.</li>
 *   <li>{@link #cumSumBetweenPartition()} maps every partition with
 *       {@link FoldBetweenPartitionFunction}, using the broadcast counter, to produce the final
 *       {@code JavaRDD<Long>}.</li>
 * </ol>
 *
 * <p>{@link #buildCumSum()} runs both steps and returns the result.
 */
public class CountCumSum {
    // Inputs, fixed at construction time.
    private final JavaSparkContext sc;
    private final JavaRDD<AtomicLong> sentenceCountRDD;

    // Intermediate and final results, filled in by the build steps.
    private JavaRDD<AtomicLong> foldWithinPartitionRDD;
    private Broadcast<Counter<Integer>> broadcastedMaxPerPartitionCounter;
    private JavaRDD<Long> cumSumRDD;

    /**
     * Creates a cumulative-sum builder over the given counts.
     *
     * @param sentenceCountRDD the per-sentence counts to accumulate; its underlying
     *                         {@code SparkContext} is wrapped for creating accumulators and broadcasts
     */
    public CountCumSum(JavaRDD<AtomicLong> sentenceCountRDD) {
        this.sentenceCountRDD = sentenceCountRDD;
        this.sc = new JavaSparkContext(sentenceCountRDD.context());
    }

    /**
     * Returns the cumulative-sum RDD produced by {@link #buildCumSum()}.
     *
     * @return the cumulative-sum RDD
     * @throws IllegalAccessError if the cumulative sum has not been built yet
     */
    public JavaRDD<Long> getCumSumRDD() {
        if (cumSumRDD == null) {
            // NOTE(review): IllegalStateException would be the idiomatic type here; IllegalAccessError
            // is kept only so existing callers see the same exception type.
            throw new IllegalAccessError("Cumulative Sum list not defined. Call buildCumSum() first.");
        }
        return cumSumRDD;
    }

    /**
     * Runs a Spark action over every partition of {@code rdd} with a {@link MapPerPartitionVoidFunction}.
     * This forces the RDD's lineage to be evaluated, and cached if it is marked for caching.
     *
     * @param rdd the RDD to evaluate; left raw so it can accept RDDs of any element type
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // raw parameter kept for interface compatibility
    public void actionForMapPartition(JavaRDD rdd) {
        rdd.foreachPartition(new MapPerPartitionVoidFunction());
    }

    /**
     * Step 1: folds each partition locally and broadcasts the per-partition counter.
     *
     * <p>The result is cached in {@code foldWithinPartitionRDD}, and the accumulated counter is
     * broadcast as {@code broadcastedMaxPerPartitionCounter}.
     */
    public void cumSumWithinPartition() {
        Accumulator<Counter<Integer>> maxPerPartitionAcc =
                sc.accumulator(new Counter<Integer>(), new MaxPerPartitionAccumulator());
        foldWithinPartitionRDD = sentenceCountRDD
                .mapPartitionsWithIndex(new FoldWithinPartitionFunction(maxPerPartitionAcc), true)
                .cache();
        // Accumulator updates made inside a transformation only happen once an action runs,
        // so evaluate the RDD before reading the accumulator's value.
        actionForMapPartition(foldWithinPartitionRDD);
        broadcastedMaxPerPartitionCounter = sc.broadcast(maxPerPartitionAcc.value());
    }

    /**
     * Step 2: adjusts every partition with the broadcast per-partition counter to build the
     * final cumulative sum, then releases the cached intermediate RDD.
     *
     * @throws IllegalStateException if {@link #cumSumWithinPartition()} has not been run
     */
    public void cumSumBetweenPartition() {
        if (foldWithinPartitionRDD == null || broadcastedMaxPerPartitionCounter == null) {
            throw new IllegalStateException(
                    "Within-partition fold not computed. Call cumSumWithinPartition() first.");
        }
        cumSumRDD = foldWithinPartitionRDD
                .mapPartitionsWithIndex(new FoldBetweenPartitionFunction(broadcastedMaxPerPartitionCounter), true)
                .setName("cumSumRDD")
                .cache();
        // Materialize cumSumRDD while its parent is still cached. Unpersisting first would make
        // Spark recompute the within-partition fold, re-running its accumulator updates, the
        // first time cumSumRDD is evaluated.
        actionForMapPartition(cumSumRDD);
        foldWithinPartitionRDD.unpersist();
    }

    /**
     * Runs both steps and returns the cumulative-sum RDD.
     *
     * @return the cached cumulative-sum RDD
     */
    public JavaRDD<Long> buildCumSum() {
        cumSumWithinPartition();
        cumSumBetweenPartition();
        return getCumSumRDD();
    }
}

