package org.apache.drill.exec.planner.physical.visitor;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.ops.QueryContext;
import org.apache.drill.exec.planner.physical.BroadcastExchangePrel;
import org.apache.drill.exec.planner.physical.ExchangePrel;
import org.apache.drill.exec.planner.physical.HashAggPrel;
import org.apache.drill.exec.planner.physical.HashJoinPrel;
import org.apache.drill.exec.planner.physical.JoinPrel;
import org.apache.drill.exec.planner.physical.Prel;
import org.apache.drill.exec.planner.physical.RuntimeFilterPrel;
import org.apache.drill.exec.planner.physical.ScanPrel;
import org.apache.drill.exec.planner.physical.SortPrel;
import org.apache.drill.exec.planner.physical.StreamAggPrel;
import org.apache.drill.exec.planner.physical.TopNPrel;
import org.apache.drill.exec.work.filter.BloomFilter;
import org.apache.drill.exec.work.filter.BloomFilterDef;
import org.apache.drill.exec.work.filter.RuntimeFilterDef;
import org.apache.drill.shaded.guava.com.google.common.collect.HashMultimap;
import org.apache.drill.shaded.guava.com.google.common.collect.Multimap;

/* loaded from: input_file:org/apache/drill/exec/planner/physical/visitor/RuntimeFilterVisitor.class */
public class RuntimeFilterVisitor extends BasePrelVisitor<Prel, Void, RuntimeException> {
    /** Probe-side scans that {@code generateRuntimeFilter} selected; {@code visitScan} wraps exactly these. */
    private final Set<ScanPrel> toAddRuntimeFilter = new HashSet<>();
    /** For each selected probe-side scan, the hash joins whose runtime filter is applied on top of it. */
    private final Multimap<ScanPrel, HashJoinPrel> probeSideScan2hj = HashMultimap.create();
    /** Value of the {@code HASHJOIN_BLOOM_FILTER_FPP_KEY} option; passed to {@code BloomFilter.optimalNumOfBytes}. */
    private final double fpp;
    /** Value of the {@code HASHJOIN_BLOOM_FILTER_MAX_SIZE_KEY} option; caps every computed bloom filter size. */
    private final int bloomFilterMaxSizeInBytesDef;
    /** Generator of runtime filter identifiers; static, hence shared by all instances of this visitor. */
    private static final AtomicLong rfIdCounter = new AtomicLong();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/drill/exec/planner/physical/visitor/RuntimeFilterVisitor$BlockNodeVisitor.class */
    /**
     * Visitor that records whether a subtree contains one of the operators this
     * class treats as "block" nodes: {@link StreamAggPrel}, {@link HashAggPrel},
     * {@link SortPrel}, {@link TopNPrel} or {@link HashJoinPrel}.
     *
     * <p>The extra visitor argument is the node at which descent stops; the walk
     * also does not descend below a block node once one is found.
     */
    public static class BlockNodeVisitor extends BasePrelVisitor<Void, Prel, RuntimeException> {
        /** Set to {@code true} the first time a block node is reached; never reset. */
        private boolean encounteredBlockNode;

        private BlockNodeVisitor() {
        }

        /**
         * Inspects {@code prel} and, unless it is the stop node or a block node,
         * recurses into each of its children.
         *
         * @param prel  node to inspect; a {@link RelSubset} is replaced by its best plan
         * @param prel2 stop node (compared by identity); nothing at or below it is inspected
         * @return always {@code null}; the outcome is read via {@link #isEncounteredBlockNode()}
         */
        @Override
        public Void visitPrel(Prel prel, Prel prel2) throws RuntimeException {
            if (prel == prel2) {
                return null;
            }
            Prel current = prel;
            if (prel instanceof RelSubset) {
                // A subset without a chosen best plan yields null and ends this branch.
                current = (Prel) ((RelSubset) prel).getBest();
            }
            if (current == null) {
                return null;
            }
            if (isBlockNode(current)) {
                this.encounteredBlockNode = true;
                return null;
            }
            for (Iterator<Prel> children = current.iterator(); children.hasNext(); ) {
                visitPrel(children.next(), prel2);
            }
            return null;
        }

        /** Tells whether {@code node} is one of the operator types counted as a block node. */
        private static boolean isBlockNode(Prel node) {
            return node instanceof StreamAggPrel
                    || node instanceof HashAggPrel
                    || node instanceof SortPrel
                    || node instanceof TopNPrel
                    || node instanceof HashJoinPrel;
        }

        /** @return {@code true} if any visit so far reached a block node */
        public boolean isEncounteredBlockNode() {
            return this.encounteredBlockNode;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/drill/exec/planner/physical/visitor/RuntimeFilterVisitor$RFHelperHolder.class */
    /**
     * Mutable state passed along while {@code RuntimeFilterInfoPaddingHelper}
     * walks the plan: whether the walk is currently flagged as being on a build
     * side, and the most recently recorded build-side exchange.
     */
    public static class RFHelperHolder {
        /** Flag consulted by the padding helper before it records an exchange. */
        private boolean fromBuildSide;
        /** Last exchange handed to {@link #setBuildSideExchange}; {@code null} until one is recorded. */
        private ExchangePrel exchangePrel;

        private RFHelperHolder() {
        }

        /** Records {@code exchangePrel}, replacing any previously recorded exchange. */
        public void setBuildSideExchange(ExchangePrel exchangePrel) {
            this.exchangePrel = exchangePrel;
        }

        /**
         * @return {@code true} only when an exchange has been recorded and it is
         *         not a {@link BroadcastExchangePrel}
         */
        public boolean needToRouteToForeman() {
            if (this.exchangePrel == null) {
                return false;
            }
            return !(this.exchangePrel instanceof BroadcastExchangePrel);
        }

        /** @return the current build-side flag */
        public boolean isFromBuildSide() {
            return this.fromBuildSide;
        }

        /** Sets the build-side flag to {@code z}. */
        public void setFromBuildSide(boolean z) {
            this.fromBuildSide = z;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/drill/exec/planner/physical/visitor/RuntimeFilterVisitor$RuntimeFilterInfoPaddingHelper.class */
    /**
     * Second-pass visitor that completes the {@link RuntimeFilterDef} attached to
     * each {@link HashJoinPrel}: it enables bloom filter generation and derives the
     * send-to-foreman / local flags from the exchange recorded in the
     * {@link RFHelperHolder}.
     */
    public static class RuntimeFilterInfoPaddingHelper extends BasePrelVisitor<Void, RFHelperHolder, RuntimeException> {
        /**
         * Dispatches every child of {@code prel} back through this visitor.
         *
         * @return always {@code null}
         */
        @Override
        public Void visitPrel(Prel prel, RFHelperHolder rFHelperHolder) throws RuntimeException {
            for (Iterator<Prel> children = prel.iterator(); children.hasNext(); ) {
                children.next().accept(this, rFHelperHolder);
            }
            return null;
        }

        /**
         * Records {@code exchangePrel} in the holder when the holder is flagged as
         * build side, then continues into the exchange's children.
         *
         * @return always {@code null}
         */
        @Override
        public Void visitExchange(ExchangePrel exchangePrel, RFHelperHolder rFHelperHolder) throws RuntimeException {
            boolean onBuildSide = rFHelperHolder != null && rFHelperHolder.isFromBuildSide();
            if (onBuildSide) {
                rFHelperHolder.setBuildSideExchange(exchangePrel);
            }
            return visitPrel((Prel) exchangePrel, rFHelperHolder);
        }

        /**
         * For a hash join that carries a runtime filter definition: visits the left
         * input, flags the holder as build side, visits the right input, and then
         * writes the routing decision into the definition and each of its bloom
         * filter definitions. Every join is finally traversed via {@link #visitPrel}.
         *
         * <p>NOTE(review): the holder is created once and then reused for all joins
         * reached below; nothing in this class sets the build-side flag back to
         * {@code false}, and each recorded exchange overwrites the previous one, so
         * the decision reflects the last exchange visited. Confirm this is intended.
         *
         * <p>NOTE(review): the inputs of such a join appear to be visited twice,
         * once explicitly here and once more by the trailing {@code visitPrel} call
         * (assuming the join's iterator yields its inputs) - verify.
         *
         * @return always {@code null}
         */
        @Override
        public Void visitJoin(JoinPrel joinPrel, RFHelperHolder rFHelperHolder) throws RuntimeException {
            RFHelperHolder holder = rFHelperHolder;
            if (joinPrel instanceof HashJoinPrel) {
                HashJoinPrel hashJoin = (HashJoinPrel) joinPrel;
                RuntimeFilterDef filterDef = hashJoin.getRuntimeFilterDef();
                if (filterDef != null) {
                    filterDef.setGenerateBloomFilter(true);
                    if (holder == null) {
                        holder = new RFHelperHolder();
                    }
                    Prel leftInput = (Prel) hashJoin.getLeft();
                    leftInput.accept(this, holder);
                    Prel rightInput = (Prel) hashJoin.getRight();
                    // From here on, exchanges that are visited get recorded in the holder.
                    holder.setFromBuildSide(true);
                    rightInput.accept(this, holder);
                    boolean viaForeman = holder.needToRouteToForeman();
                    filterDef.setSendToForeman(viaForeman);
                    boolean local = !viaForeman;
                    for (Iterator<BloomFilterDef> defs = filterDef.getBloomFilterDefs().iterator(); defs.hasNext(); ) {
                        defs.next().setLocal(local);
                    }
                }
            }
            return visitPrel((Prel) joinPrel, holder);
        }
    }

    private RuntimeFilterVisitor(QueryContext queryContext) {
        this.bloomFilterMaxSizeInBytesDef = queryContext.getOption(ExecConstants.HASHJOIN_BLOOM_FILTER_MAX_SIZE_KEY).num_val.intValue();
        this.fpp = queryContext.getOption(ExecConstants.HASHJOIN_BLOOM_FILTER_FPP_KEY).float_val.doubleValue();
    }

    /**
     * Entry point: rewrites {@code prel} with a {@link RuntimeFilterVisitor} and
     * then runs a {@link RuntimeFilterInfoPaddingHelper} over the rewritten plan.
     *
     * @param prel         root of the physical plan to process
     * @param queryContext supplies the bloom filter options
     * @return the plan produced by the first pass (the second pass returns nothing
     *         and only mutates runtime filter definitions)
     */
    public static Prel addRuntimeFilter(Prel prel, QueryContext queryContext) {
        RuntimeFilterVisitor filterVisitor = new RuntimeFilterVisitor(queryContext);
        Prel rewritten = (Prel) prel.accept(filterVisitor, null);
        RuntimeFilterInfoPaddingHelper paddingHelper = new RuntimeFilterInfoPaddingHelper();
        paddingHelper.visitPrel(rewritten, (RFHelperHolder) null);
        return rewritten;
    }

    /**
     * Visits every child of {@code prel} and rebuilds {@code prel} over the
     * visited children when any of them changed.
     *
     * @param prel node whose children are visited
     * @param r6   unused visitor argument, forwarded unchanged
     * @return {@code prel} itself when the visited children equal its current
     *         inputs, otherwise a copy of {@code prel} with the new inputs
     */
    @Override
    public Prel visitPrel(Prel prel, Void r6) throws RuntimeException {
        List<RelNode> visitedChildren = new ArrayList<>();
        for (Iterator<Prel> children = prel.iterator(); children.hasNext(); ) {
            visitedChildren.add((Prel) children.next().accept(this, r6));
        }
        if (visitedChildren.equals(prel.getInputs())) {
            // Nothing below changed; keep the existing node instance.
            return prel;
        }
        return (Prel) prel.copy(prel.getTraitSet(), visitedChildren);
    }

    /**
     * For a {@link HashJoinPrel}, computes a runtime filter definition (possibly
     * {@code null}) and stores it on the join before visiting the join's children;
     * other join types are only traversed.
     *
     * @param joinPrel join being visited
     * @param r6       unused visitor argument, forwarded unchanged
     * @return result of {@link #visitPrel(Prel, Void)} for this join
     */
    @Override
    public Prel visitJoin(JoinPrel joinPrel, Void r6) throws RuntimeException {
        if (!(joinPrel instanceof HashJoinPrel)) {
            return visitPrel((Prel) joinPrel, r6);
        }
        HashJoinPrel hashJoin = (HashJoinPrel) joinPrel;
        // Must run before the children are visited: it marks the scan that visitScan will wrap.
        RuntimeFilterDef filterDef = generateRuntimeFilter(hashJoin);
        hashJoin.setRuntimeFilterDef(filterDef);
        return visitPrel((Prel) joinPrel, r6);
    }

    /**
     * Wraps a scan that was selected for runtime filtering in one
     * {@link RuntimeFilterPrel} per associated hash join, assigning a fresh
     * identifier to each join's runtime filter definition along the way.
     *
     * @param scanPrel scan being visited
     * @param r8       unused visitor argument
     * @return {@code scanPrel} unchanged when it was not selected (or has no
     *         associated hash join), otherwise the outermost
     *         {@link RuntimeFilterPrel} stacked on top of it
     */
    @Override
    public Prel visitScan(ScanPrel scanPrel, Void r8) throws RuntimeException {
        if (!this.toAddRuntimeFilter.contains(scanPrel)) {
            return scanPrel;
        }
        RuntimeFilterPrel outermost = null;
        for (HashJoinPrel hashJoinPrel : this.probeSideScan2hj.get(scanPrel)) {
            long identifier = rfIdCounter.incrementAndGet();
            hashJoinPrel.getRuntimeFilterDef().setRuntimeFilterIdentifier(identifier);
            if (outermost == null) {
                outermost = new RuntimeFilterPrel(scanPrel, identifier);
            } else {
                // Each further join adds another filter layer above the previous one.
                outermost = new RuntimeFilterPrel(outermost, identifier);
            }
        }
        if (outermost == null) {
            // No join is mapped to this scan: keep the scan instead of handing a
            // null child back to visitPrel.
            return scanPrel;
        }
        return outermost;
    }

    /**
     * Builds the runtime filter definition for a hash join, or returns
     * {@code null} when no filter applies.
     *
     * <p>A filter is only considered for equi-joins of type INNER or RIGHT whose
     * right input leads to an exchange (see {@code findRightExchangePrel}). For
     * each join key pair, the left field must resolve to a scan with no block
     * node between the left input and that scan; each such key yields one
     * {@link BloomFilterDef}.
     *
     * <p>Side effects: every qualifying scan is added to
     * {@code toAddRuntimeFilter}, and the last qualifying scan is mapped to
     * {@code hashJoinPrel} in {@code probeSideScan2hj}.
     *
     * @param hashJoinPrel join to analyse
     * @return the new definition, or {@code null} if the join is unsupported or
     *         no key produced a bloom filter definition
     */
    private RuntimeFilterDef generateRuntimeFilter(HashJoinPrel hashJoinPrel) {
        JoinRelType joinType = hashJoinPrel.getJoinType();
        boolean supportedJoin = hashJoinPrel.analyzeCondition().isEqui()
                && (joinType == JoinRelType.INNER || joinType == JoinRelType.RIGHT);
        if (!supportedJoin) {
            return null;
        }
        RelNode left = hashJoinPrel.getLeft();
        RelNode right = hashJoinPrel.getRight();
        if (findRightExchangePrel(right) == null) {
            return null;
        }
        List<String> leftFieldNames = left.getRowType().getFieldNames();
        List<String> rightFieldNames = right.getRowType().getFieldNames();
        List<Integer> leftKeys = hashJoinPrel.getLeftKeys();
        List<Integer> rightKeys = hashJoinPrel.getRightKeys();
        RelMetadataQuery metadataQuery = left.getCluster().getMetadataQuery();

        List<BloomFilterDef> bloomFilterDefs = new ArrayList<>();
        ScanPrel probeSideScan = null;
        for (int keyPos = 0; keyPos < leftKeys.size(); keyPos++) {
            // Left and right key lists are read pairwise by position.
            int leftKey = leftKeys.get(keyPos);
            int rightKey = rightKeys.get(keyPos);
            String leftFieldName = leftFieldNames.get(leftKey);
            String rightFieldName = rightFieldNames.get(rightKey);

            ScanPrel scan = findLeftScanPrel(leftFieldName, left);
            if (scan == null || containBlockNode((Prel) left, scan)) {
                continue;
            }
            // findLeftScanPrel only returns a scan that has this field, so the lookup is non-null.
            int columnIndex = scan.getRowType().getField(leftFieldName, true, true).getIndex();
            Double ndv = metadataQuery.getDistinctRowCount(scan, ImmutableBitSet.of(columnIndex), (RexNode) null);
            if (ndv == null) {
                // No distinct-count metadata: fall back to 10% of the left input's estimated row count.
                ndv = Double.valueOf(left.estimateRowCount(metadataQuery) * 0.1d);
            }
            int optimalSizeInBytes = BloomFilter.optimalNumOfBytes(ndv.longValue(), this.fpp);
            int sizeInBytes = Math.min(optimalSizeInBytes, this.bloomFilterMaxSizeInBytesDef);
            // The boolean passed here is a placeholder; RuntimeFilterInfoPaddingHelper
            // later calls setLocal on every bloom filter definition.
            BloomFilterDef bloomFilterDef = new BloomFilterDef(sizeInBytes, false, leftFieldName, rightFieldName);
            bloomFilterDef.setLeftNDV(ndv);
            bloomFilterDefs.add(bloomFilterDef);
            this.toAddRuntimeFilter.add(scan);
            probeSideScan = scan;
        }
        if (bloomFilterDefs.isEmpty()) {
            return null;
        }
        // The -1 identifier is replaced in visitScan via setRuntimeFilterIdentifier;
        // send-to-foreman is set later by RuntimeFilterInfoPaddingHelper.
        RuntimeFilterDef runtimeFilterDef = new RuntimeFilterDef(true, false, bloomFilterDefs, false, -1L);
        this.probeSideScan2hj.put(probeSideScan, hashJoinPrel);
        return runtimeFilterDef;
    }

    /**
     * Follows the first input of each node, starting at {@code relNode}, until a
     * {@link ScanPrel} is reached, and returns that scan if its row type has a
     * field named {@code str}. A {@link RelSubset} is followed through its best
     * plan.
     *
     * @param str     field name to look up on the scan's row type
     * @param relNode node at which the descent starts
     * @return the scan, or {@code null} when the scan lacks the field, a subset
     *         has no best plan, or the descent ends at a node without inputs
     */
    private ScanPrel findLeftScanPrel(String str, RelNode relNode) {
        RelNode current = relNode;
        while (current != null) {
            if (current instanceof ScanPrel) {
                boolean hasField = current.getRowType().getField(str, true, true) != null;
                return hasField ? (ScanPrel) current : null;
            }
            if (current instanceof RelSubset) {
                current = ((RelSubset) current).getBest();
                continue;
            }
            List<RelNode> inputs = current.getInputs();
            if (inputs.isEmpty()) {
                // Leaf node that is not a ScanPrel: there is no scan on this path.
                // Without this guard get(0) would throw IndexOutOfBoundsException.
                return null;
            }
            current = inputs.get(0);
        }
        return null;
    }

    /**
     * Searches downward from {@code relNode} for the first {@link ExchangePrel},
     * descending only through single-input nodes and through the best plan of a
     * {@link RelSubset}.
     *
     * @param relNode node at which the search starts
     * @return the first exchange found, or {@code null} when the search reaches a
     *         {@link ScanPrel}, a subset without a best plan, or a node whose
     *         number of inputs is not exactly one
     */
    private ExchangePrel findRightExchangePrel(RelNode relNode) {
        RelNode current = relNode;
        while (current != null) {
            if (current instanceof ExchangePrel) {
                return (ExchangePrel) current;
            }
            if (current instanceof ScanPrel) {
                return null;
            }
            if (current instanceof RelSubset) {
                current = ((RelSubset) current).getBest();
                continue;
            }
            List<RelNode> inputs = current.getInputs();
            if (inputs.size() != 1) {
                // Leaves and multi-input nodes end the search.
                return null;
            }
            current = inputs.get(0);
        }
        return null;
    }

    /**
     * Runs a fresh {@link BlockNodeVisitor} from {@code prel}, using
     * {@code prel2} as the node at which the walk stops.
     *
     * @param prel  node at which the walk starts
     * @param prel2 stop node handed to the visitor
     * @return {@code true} if the visitor reached a block node
     */
    private boolean containBlockNode(Prel prel, Prel prel2) {
        final BlockNodeVisitor detector = new BlockNodeVisitor();
        prel.accept(detector, prel2);
        boolean found = detector.isEncounteredBlockNode();
        return found;
    }
}
