package org.apache.flink.runtime.state.heap;

import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.util.Preconditions;

/**
 * Heap-backed implementation of {@link InternalAggregatingState}.
 *
 * <p>Accumulators of type {@code ACC} live in the {@link StateTable} passed to the constructor.
 * Values handed to {@link #add(Object)} are folded into the accumulator of the current namespace
 * through the configured {@link AggregateFunction}, and {@link #get()} converts the stored
 * accumulator into an {@code OUT} result.
 *
 * @param <K> The type of the key.
 * @param <N> The type of the namespace.
 * @param <IN> The type of the values added to the state.
 * @param <ACC> The type of the accumulator kept in the state table.
 * @param <OUT> The type of the value returned from the state.
 */
class HeapAggregatingState<K, N, IN, ACC, OUT>
        extends AbstractHeapMergingState<K, N, IN, ACC, OUT>
        implements InternalAggregatingState<K, N, IN, ACC, OUT> {

    /**
     * Wraps the user's aggregate function. Not final because {@link #setAggregateFunction}
     * replaces it when the state is re-registered with a new descriptor.
     */
    private AggregateTransformation<IN, ACC, OUT> aggregateTransformation;

    /**
     * Creates a new heap aggregating state.
     *
     * @param stateTable The state table holding the accumulators.
     * @param keySerializer The serializer for the keys.
     * @param valueSerializer The serializer for the accumulators.
     * @param namespaceSerializer The serializer for the namespaces.
     * @param defaultValue The default value for the state.
     * @param aggregateFunction The aggregate function used to fold values into the accumulator.
     */
    private HeapAggregatingState(
            StateTable<K, N, ACC> stateTable,
            TypeSerializer<K> keySerializer,
            TypeSerializer<ACC> valueSerializer,
            TypeSerializer<N> namespaceSerializer,
            ACC defaultValue,
            AggregateFunction<IN, ACC, OUT> aggregateFunction) {
        super(stateTable, keySerializer, valueSerializer, namespaceSerializer, defaultValue);
        this.aggregateTransformation = new AggregateTransformation<>(aggregateFunction);
    }

    @Override
    public TypeSerializer<K> getKeySerializer() {
        return keySerializer;
    }

    @Override
    public TypeSerializer<N> getNamespaceSerializer() {
        return namespaceSerializer;
    }

    @Override
    public TypeSerializer<ACC> getValueSerializer() {
        return valueSerializer;
    }

    /**
     * Returns the aggregation result for the current accumulator.
     *
     * @return {@code AggregateFunction#getResult} applied to the accumulator returned by
     *     {@code getInternal()}, or {@code null} when that accumulator is {@code null}.
     */
    @Override
    public OUT get() {
        ACC accumulator = getInternal();
        return accumulator != null
                ? aggregateTransformation.aggFunction.getResult(accumulator)
                : null;
    }

    /**
     * Folds {@code value} into the accumulator of the current namespace.
     *
     * <p>A {@code null} value does not aggregate anything; it calls {@link #clear()} instead.
     *
     * @param value The value to aggregate, or {@code null} to clear.
     * @throws Exception Propagated from {@code StateTable#transform} / the aggregate function.
     */
    @Override
    public void add(IN value) throws Exception {
        // Read the namespace once so the null check and the transform see the same value.
        final N namespace = currentNamespace;

        if (value == null) {
            clear();
            return;
        }

        stateTable.transform(namespace, value, aggregateTransformation);
    }

    /**
     * Merges two accumulators with {@code AggregateFunction#merge}. Presumably invoked by
     * {@link AbstractHeapMergingState} when namespaces are merged — confirm against the superclass.
     */
    @Override
    protected ACC mergeState(ACC a, ACC b) {
        return aggregateTransformation.aggFunction.merge(a, b);
    }

    /**
     * Replaces the aggregate function used by this state.
     *
     * @param aggregateFunction The new aggregate function; must not be {@code null}.
     * @return This state, to allow call chaining.
     */
    HeapAggregatingState<K, N, IN, ACC, OUT> setAggregateFunction(
            AggregateFunction<IN, ACC, OUT> aggregateFunction) {
        this.aggregateTransformation = new AggregateTransformation<>(aggregateFunction);
        return this;
    }

    /**
     * {@link StateTransformationFunction} that adds one input to an accumulator. It lazily creates
     * the accumulator with {@code AggregateFunction#createAccumulator} when none is stored yet.
     */
    static final class AggregateTransformation<IN, ACC, OUT>
            implements StateTransformationFunction<ACC, IN> {

        private final AggregateFunction<IN, ACC, OUT> aggFunction;

        /** @param aggFunction The wrapped aggregate function; rejected if {@code null}. */
        AggregateTransformation(AggregateFunction<IN, ACC, OUT> aggFunction) {
            this.aggFunction = Preconditions.checkNotNull(aggFunction);
        }

        @Override
        public ACC apply(ACC accumulator, IN value) {
            if (accumulator == null) {
                accumulator = aggFunction.createAccumulator();
            }
            return aggFunction.add(value, accumulator);
        }
    }

    /**
     * Creates a new {@code HeapAggregatingState} for an aggregating state descriptor.
     *
     * <p>The serializers and default value are taken from {@code stateTable} and
     * {@code stateDesc}. {@code stateDesc} is expected to be an {@link AggregatingStateDescriptor}.
     */
    @SuppressWarnings("unchecked")
    static <T, K, N, SV, S extends State, IS extends S> IS create(
            StateDescriptor<S, SV> stateDesc,
            StateTable<K, N, SV> stateTable,
            TypeSerializer<K> keySerializer) {
        return (IS)
                new HeapAggregatingState<>(
                        stateTable,
                        keySerializer,
                        stateTable.getStateSerializer(),
                        stateTable.getNamespaceSerializer(),
                        stateDesc.getDefaultValue(),
                        ((AggregatingStateDescriptor<T, SV, ?>) stateDesc).getAggregateFunction());
    }

    /**
     * Re-configures an existing {@code HeapAggregatingState} from a (possibly updated)
     * descriptor. It replaces the aggregate function, namespace serializer, value serializer and
     * default value.
     *
     * <p>{@code existingState} is expected to be a {@code HeapAggregatingState} and
     * {@code stateDesc} an {@link AggregatingStateDescriptor}.
     */
    @SuppressWarnings("unchecked")
    static <T, K, N, SV, S extends State, IS extends S> IS update(
            StateDescriptor<S, SV> stateDesc,
            StateTable<K, N, SV> stateTable,
            IS existingState) {
        // OUT is irrelevant to the update, so Object is used as an erased stand-in. This keeps
        // the function's and the state's OUT type arguments identical (two separate wildcards
        // would capture to distinct, incompatible types).
        HeapAggregatingState<K, N, T, SV, Object> heapState =
                (HeapAggregatingState<K, N, T, SV, Object>) existingState;
        AggregatingStateDescriptor<T, SV, Object> aggregatingDesc =
                (AggregatingStateDescriptor<T, SV, Object>) stateDesc;

        return (IS)
                heapState
                        .setAggregateFunction(aggregatingDesc.getAggregateFunction())
                        .setNamespaceSerializer(stateTable.getNamespaceSerializer())
                        .setValueSerializer(stateTable.getStateSerializer())
                        .setDefaultValue(stateDesc.getDefaultValue());
    }
}
