package com.amazon.randomcutforest.state;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.ComponentList;
import com.amazon.randomcutforest.IComponentModel;
import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Config;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.executor.PointStoreCoordinator;
import com.amazon.randomcutforest.executor.SamplerPlusTree;
import com.amazon.randomcutforest.sampler.CompactSampler;
import com.amazon.randomcutforest.state.sampler.CompactSamplerMapper;
import com.amazon.randomcutforest.state.sampler.CompactSamplerState;
import com.amazon.randomcutforest.state.store.PointStoreMapper;
import com.amazon.randomcutforest.state.tree.CompactRandomCutTreeContext;
import com.amazon.randomcutforest.state.tree.CompactRandomCutTreeState;
import com.amazon.randomcutforest.state.tree.RandomCutTreeMapper;
import com.amazon.randomcutforest.store.PointStore;
import com.amazon.randomcutforest.tree.RandomCutTree;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import lombok.Generated;

/* loaded from: input_file:com/amazon/randomcutforest/state/RandomCutForestMapper.class */
/**
 * Converts between a {@link RandomCutForest} and its serializable {@link RandomCutForestState}.
 *
 * <p>The boolean flags on this mapper control which parts of the forest {@link #toState} writes:
 * the point store held by the update coordinator, the per-tree samplers, the tree structures, and
 * the execution context. {@link #toModel(RandomCutForestState, ExecutionContext, long)} rebuilds a
 * forest; it requires the sampler states and the point store state to be present in the state.
 */
public class RandomCutForestMapper implements IContextualStateMapper<RandomCutForest, RandomCutForestState, ExecutionContext> {
    /** When true, {@link #toState} also serializes each tree's structure. */
    private boolean saveTreeStateEnabled = false;
    /** When true, {@link #toState} serializes the point store owned by the update coordinator. */
    private boolean saveCoordinatorStateEnabled = true;
    /** When true, {@link #toState} serializes each tree's sampler; {@link #toModel} requires this. */
    private boolean saveSamplerStateEnabled = true;
    /** When true, {@link #toState} records the forest's parallel-execution settings. */
    private boolean saveExecutorContextEnabled = false;
    /** Passed to the point store and sampler mappers, and recorded as the state's compressed flag. */
    private boolean compressionEnabled = true;
    /** Copied into the state's partial-tree-state flag. */
    private boolean partialTreeStateEnabled = false;

    /**
     * Captures the configuration and (depending on this mapper's flags) the point store, samplers,
     * trees and execution context of a forest.
     *
     * @param randomCutForest the forest to serialize; its components are expected to be
     *                        {@link SamplerPlusTree} instances wrapping a {@link CompactSampler} and
     *                        a {@link RandomCutTree} (enforced by casts)
     * @return a new state object; the sampler/tree state lists are left null when the
     *         corresponding flag is off
     */
    @Override
    public RandomCutForestState toState(RandomCutForest randomCutForest) {
        RandomCutForestState state = new RandomCutForestState();
        state.setNumberOfTrees(randomCutForest.getNumberOfTrees());
        state.setDimensions(randomCutForest.getDimensions());
        state.setTimeDecay(randomCutForest.getTimeDecay());
        state.setSampleSize(randomCutForest.getSampleSize());
        state.setShingleSize(randomCutForest.getShingleSize());
        state.setCenterOfMassEnabled(randomCutForest.isCenterOfMassEnabled());
        state.setOutputAfter(randomCutForest.getOutputAfter());
        state.setStoreSequenceIndexesEnabled(randomCutForest.isStoreSequenceIndexesEnabled());
        state.setTotalUpdates(randomCutForest.getTotalUpdates());
        state.setCompact(true);
        state.setInternalShinglingEnabled(randomCutForest.isInternalShinglingEnabled());
        state.setBoundingBoxCacheFraction(randomCutForest.getBoundingBoxCacheFraction());
        state.setSaveSamplerStateEnabled(saveSamplerStateEnabled);
        state.setSaveTreeStateEnabled(saveTreeStateEnabled);
        state.setSaveCoordinatorStateEnabled(saveCoordinatorStateEnabled);
        state.setPrecision(randomCutForest.getPrecision().name());
        state.setCompressed(compressionEnabled);
        state.setPartialTreeState(partialTreeStateEnabled);
        state.setCurrentlySampling(randomCutForest.isCurrentlySampling());

        if (saveExecutorContextEnabled) {
            ExecutionContext executionContext = new ExecutionContext();
            executionContext.setParallelExecutionEnabled(randomCutForest.isParallelExecutionEnabled());
            executionContext.setThreadPoolSize(randomCutForest.getThreadPoolSize());
            state.setExecutionContext(executionContext);
        }

        if (saveCoordinatorStateEnabled) {
            PointStoreCoordinator coordinator = (PointStoreCoordinator) randomCutForest.getUpdateCoordinator();
            PointStoreMapper pointStoreMapper = new PointStoreMapper();
            pointStoreMapper.setCompressionEnabled(compressionEnabled);
            pointStoreMapper.setNumberOfTrees(randomCutForest.getNumberOfTrees());
            state.setPointStoreState(pointStoreMapper.toState((PointStore) coordinator.getStore()));
        }

        // A list stays null when its part is not being saved; that null is written into the state.
        List<CompactSamplerState> samplerStates = saveSamplerStateEnabled ? new ArrayList<>() : null;
        List<RandomCutTree> trees = saveTreeStateEnabled ? new ArrayList<>() : null;

        CompactSamplerMapper samplerMapper = new CompactSamplerMapper();
        samplerMapper.setCompressionEnabled(compressionEnabled);
        // Iterate as Object so this compiles regardless of how getComponents() is parameterized;
        // the cast below enforces the expected component type.
        for (Object component : randomCutForest.getComponents()) {
            SamplerPlusTree samplerPlusTree = (SamplerPlusTree) component;
            CompactSampler sampler = (CompactSampler) samplerPlusTree.getSampler();
            if (samplerStates != null) {
                samplerStates.add(samplerMapper.toState(sampler));
            }
            if (trees != null) {
                trees.add((RandomCutTree) samplerPlusTree.getTree());
            }
        }
        state.setCompactSamplerStates(samplerStates);

        if (trees != null) {
            RandomCutTreeMapper treeMapper = new RandomCutTreeMapper();
            List<CompactRandomCutTreeState> treeStates = trees.stream()
                    .map(tree -> treeMapper.toState(tree))
                    .collect(Collectors.toList());
            state.setCompactRandomCutTreeStates(treeStates);
        }
        return state;
    }

    /**
     * Rebuilds a forest from a state.
     *
     * @param randomCutForestState the saved state; must contain sampler states and a point store
     *                             state, and its precision must be {@code FLOAT_32}
     * @param executionContext     execution settings to use; when null, the context stored in the
     *                             state is used instead
     * @param seed                 seed for the forest builder, whose random generator then seeds
     *                             every sampler and tree
     * @return the reconstructed forest, with sampling paused if it was paused when saved
     */
    @Override
    public RandomCutForest toModel(RandomCutForestState randomCutForestState, ExecutionContext executionContext, long seed) {
        ExecutionContext context = executionContext;
        if (context == null) {
            CommonUtils.checkNotNull(randomCutForestState.getExecutionContext(), "The executor context in the state object is null, an executor context must be passed explicitly to toModel()");
            context = randomCutForestState.getExecutionContext();
        }

        RandomCutForest.Builder builder = RandomCutForest.builder()
                .numberOfTrees(randomCutForestState.getNumberOfTrees())
                .dimensions(randomCutForestState.getDimensions())
                .timeDecay(randomCutForestState.getTimeDecay())
                .sampleSize(randomCutForestState.getSampleSize())
                .centerOfMassEnabled(randomCutForestState.isCenterOfMassEnabled())
                .outputAfter(randomCutForestState.getOutputAfter())
                .parallelExecutionEnabled(context.isParallelExecutionEnabled())
                .threadPoolSize(context.getThreadPoolSize())
                .storeSequenceIndexesEnabled(randomCutForestState.isStoreSequenceIndexesEnabled())
                .shingleSize(randomCutForestState.getShingleSize())
                .boundingBoxCacheFraction(randomCutForestState.getBoundingBoxCacheFraction())
                .internalShinglingEnabled(randomCutForestState.isInternalShinglingEnabled())
                .randomSeed(seed);

        CommonUtils.validateInternalState(Precision.valueOf(randomCutForestState.getPrecision()) == Precision.FLOAT_32, " use randomcutforest-serialization package");
        // Fail fast, before any reconstruction work, when a required part of the state is missing.
        CommonUtils.checkArgument(randomCutForestState.isSaveSamplerStateEnabled(), "samplers are not saved; no forest to reconstruct");
        CommonUtils.checkNotNull(randomCutForestState.getPointStoreState(), "point store state is null; the forest must be saved with coordinator state enabled");

        Random random = builder.getRandom();

        PointStore pointStore = new PointStoreMapper().toModel(randomCutForestState.getPointStoreState());
        PointStoreCoordinator coordinator = new PointStoreCoordinator(pointStore);
        coordinator.setTotalUpdates(randomCutForestState.getTotalUpdates());

        CompactRandomCutTreeContext treeContext = new CompactRandomCutTreeContext();
        treeContext.setPointStore(pointStore);
        treeContext.setMaxSize(randomCutForestState.getSampleSize());

        RandomCutTreeMapper treeMapper = new RandomCutTreeMapper();
        CompactSamplerMapper samplerMapper = new CompactSamplerMapper();
        // Tree states are only used when they were saved; otherwise trees are rebuilt from samples.
        List<CompactRandomCutTreeState> treeStates = randomCutForestState.isSaveTreeStateEnabled() ? randomCutForestState.getCompactRandomCutTreeStates() : null;
        List<CompactSamplerState> samplerStates = randomCutForestState.getCompactSamplerStates();

        ComponentList componentList = new ComponentList();
        int numberOfTrees = randomCutForestState.getNumberOfTrees();
        for (int i = 0; i < numberOfTrees; i++) {
            CompactSampler sampler = samplerMapper.toModel(samplerStates.get(i), random.nextLong());
            // Declared per iteration and final so the lambdas below may capture it.
            final RandomCutTree tree;
            if (treeStates != null) {
                CompactRandomCutTreeState treeState = treeStates.get(i);
                tree = treeMapper.toModel(treeState, treeContext, random.nextLong());
                sampler.getSample().forEach(sampled -> tree.addPointToPartialTree((Integer) sampled.getValue(), sampled.getSequenceIndex()));
                tree.setConfig(Config.BOUNDING_BOX_CACHE_FRACTION, treeState.getBoundingBoxCacheFraction());
                tree.validateAndReconstruct();
            } else {
                tree = new RandomCutTree.Builder()
                        .capacity(randomCutForestState.getSampleSize())
                        .randomSeed(random.nextLong())
                        .pointStoreView(pointStore)
                        .boundingBoxCacheFraction(randomCutForestState.getBoundingBoxCacheFraction())
                        .centerOfMassEnabled(randomCutForestState.isCenterOfMassEnabled())
                        .storeSequenceIndexesEnabled(randomCutForestState.isStoreSequenceIndexesEnabled())
                        .build();
                sampler.getSample().forEach(sampled -> tree.addPoint((Integer) sampled.getValue(), sampled.getSequenceIndex()));
            }
            componentList.add(new SamplerPlusTree(sampler, tree));
        }

        RandomCutForest forest = new RandomCutForest(builder, coordinator, componentList, random);
        if (!randomCutForestState.isCurrentlySampling()) {
            forest.pauseSampling();
        }
        return forest;
    }

    /**
     * Rebuilds a forest using the execution context stored in the state.
     *
     * @param randomCutForestState the saved state
     * @param seed                 seed for the forest builder
     * @return the reconstructed forest
     */
    public RandomCutForest toModel(RandomCutForestState randomCutForestState, long seed) {
        return toModel(randomCutForestState, (ExecutionContext) null, seed);
    }

    /**
     * Rebuilds a forest using the execution context stored in the state. This delegates to the
     * two-argument {@code toModel} inherited from {@link IContextualStateMapper}, which supplies the
     * seed (NOTE(review): presumably a fixed default seed; confirm in the interface).
     *
     * @param randomCutForestState the saved state
     * @return the reconstructed forest
     */
    public RandomCutForest toModel(RandomCutForestState randomCutForestState) {
        return toModel(randomCutForestState, (ExecutionContext) null);
    }

    @Generated
    public boolean isSaveTreeStateEnabled() {
        return this.saveTreeStateEnabled;
    }

    @Generated
    public boolean isSaveCoordinatorStateEnabled() {
        return this.saveCoordinatorStateEnabled;
    }

    @Generated
    public boolean isSaveSamplerStateEnabled() {
        return this.saveSamplerStateEnabled;
    }

    @Generated
    public boolean isSaveExecutorContextEnabled() {
        return this.saveExecutorContextEnabled;
    }

    @Generated
    public boolean isCompressionEnabled() {
        return this.compressionEnabled;
    }

    @Generated
    public boolean isPartialTreeStateEnabled() {
        return this.partialTreeStateEnabled;
    }

    @Generated
    public void setSaveTreeStateEnabled(boolean z) {
        this.saveTreeStateEnabled = z;
    }

    @Generated
    public void setSaveCoordinatorStateEnabled(boolean z) {
        this.saveCoordinatorStateEnabled = z;
    }

    @Generated
    public void setSaveSamplerStateEnabled(boolean z) {
        this.saveSamplerStateEnabled = z;
    }

    @Generated
    public void setSaveExecutorContextEnabled(boolean z) {
        this.saveExecutorContextEnabled = z;
    }

    @Generated
    public void setCompressionEnabled(boolean z) {
        this.compressionEnabled = z;
    }

    @Generated
    public void setPartialTreeStateEnabled(boolean z) {
        this.partialTreeStateEnabled = z;
    }
}
