/*
 * Decompiled with CFR 0.152.
 */
package com.sourceclear.methods;

import com.google.common.base.Stopwatch;
import com.sourceclear.analysis.latte.frameworks.Frameworks;
import com.sourceclear.analysis.latte.genids.Id;
import com.sourceclear.methods.CallChain;
import com.sourceclear.methods.CallChainsInspector;
import com.sourceclear.methods.CallGraph;
import com.sourceclear.methods.CallGraphBuilder;
import com.sourceclear.methods.CallSite;
import com.sourceclear.methods.Cleaner;
import com.sourceclear.methods.EntryPointResolver;
import com.sourceclear.methods.JGCallGraph;
import com.sourceclear.methods.JSMethodInfo;
import com.sourceclear.methods.MethodInfo;
import com.sourceclear.methods.MethodScanner;
import com.sourceclear.methods.ReachabilityInspector;
import com.sourceclear.methods.SomePairsShortestPaths;
import com.sourceclear.methods.VulnMethodsInput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jgrapht.Graph;
import org.jgrapht.graph.EdgeReversedGraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MethodScanner} that finds call chains leading from application entry points to a fixed
 * set of vulnerable methods.
 *
 * <p>The call graph is produced by {@link CallGraphBuilder#build} from the supplied
 * {@link VulnMethodsInput}; {@link #close()} delegates to the supplied {@link Cleaner}.
 */
public class VulnerablePartsDetector
implements MethodScanner {
    private static final Logger LOGGER = LoggerFactory.getLogger(VulnerablePartsDetector.class);
    private final Collection<MethodInfo> vulnerableMethods;
    private final EntryPointResolver entryPointResolver;
    private final CallGraphBuilder callGraphBuilder;
    private final Cleaner cleaner;
    private final VulnMethodsInput input;

    /**
     * @param vulnerableMethods  methods whose reachability is analysed
     * @param entryPointResolver decides which call-graph vertices may act as entry points
     * @param callGraphBuilder   builds the call graph from {@code input}
     * @param input              analysis input (static call chains, framework methods, ...)
     * @param cleaner            invoked by {@link #close()}
     */
    public VulnerablePartsDetector(Collection<MethodInfo> vulnerableMethods, EntryPointResolver entryPointResolver, CallGraphBuilder callGraphBuilder, VulnMethodsInput input, Cleaner cleaner) {
        this.vulnerableMethods = vulnerableMethods;
        this.entryPointResolver = entryPointResolver;
        this.callGraphBuilder = callGraphBuilder;
        this.cleaner = cleaner;
        this.input = input;
    }

    /**
     * Builds the call graph from the input and computes the call chains that reach the
     * vulnerable methods. Every invocation triggers a new {@code build}.
     */
    @Override
    public MethodScanner.VMReport scan() throws IOException {
        this.callGraphBuilder.build(this.input);
        return this.findVulnerableMethodsCallChainsTimed();
    }

    /**
     * Builds the call graph and delegates to
     * {@link ReachabilityInspector#findPossibleCallers} with (a copy of) the vulnerable methods
     * and {@code reachableFrom}. Every invocation triggers a new {@code build}.
     */
    @Override
    public Set<MethodInfo> reachableMethods(Set<MethodInfo> reachableFrom) throws IOException {
        this.callGraphBuilder.build(this.input);
        return new ReachabilityInspector(this.getCallGraph()).findPossibleCallers(new HashSet<MethodInfo>(this.vulnerableMethods), reachableFrom);
    }

    /**
     * Returns whether every vulnerable method is among the methods defined by the builder,
     * building the call graph first if the builder reports no defined methods yet.
     */
    @Override
    public boolean vulnerableMethodsAreDefined() throws IOException {
        Set<MethodInfo> methodsDefined = this.callGraphBuilder.getMethodsDefined();
        if (methodsDefined.isEmpty()) {
            this.callGraphBuilder.build(this.input);
            // Re-read after building: the set fetched above may be a snapshot taken before the
            // build populated the builder, in which case it would still be empty.
            methodsDefined = this.callGraphBuilder.getMethodsDefined();
        }
        return methodsDefined.containsAll(this.vulnerableMethods);
    }

    /** Runs {@link #findVulnerableMethodsCallChains()} and logs its wall-clock time at debug level. */
    private MethodScanner.VMReport findVulnerableMethodsCallChainsTimed() {
        Stopwatch stopwatch = Stopwatch.createStarted();
        MethodScanner.VMReport result = this.findVulnerableMethodsCallChains();
        LOGGER.debug("call chain traversal finished in {}s", stopwatch.elapsed(TimeUnit.SECONDS));
        return result;
    }

    /**
     * Assembles call chains around "spanning" call sites. For each candidate call site it picks
     * an entry point that reaches the caller and a vulnerable method reachable from the callee,
     * then concatenates entry-point path + call site + vulnerable-method path into one
     * {@link CallChain}.
     *
     * @return chains grouped by the vulnerable method they end in, plus the spanning call site
     *         each chain was built around
     */
    private MethodScanner.VMReport findVulnerableMethodsCallChains() {
        JGCallGraph callGraph = CallChainsInspector.cast(this.callGraphBuilder.getCallGraph());
        double vertexCount = callGraph.vertices().count();
        double edgeCount = callGraph.edges().count();

        // Every method that appears (as caller or callee) on one of the statically supplied call chains.
        Set<MethodInfo> thirdParty = this.input.staticCallChains().stream()
                .flatMap(css -> css.stream().flatMap(cs -> Stream.of(cs.getCaller(), cs.getCallee())))
                .collect(Collectors.toSet());

        // Indirect candidates: for each edge whose caller is an application vertex but which is not
        // itself an application edge, take the edges into that caller that go from a
        // non-third-party method to a third-party one.
        Set<CallSite> spanningCandidates = callGraph.edges()
                .filter(cs -> this.callGraphBuilder.getAppVertices().contains(cs.getCaller())
                        && !this.callGraphBuilder.getAppEdges().contains(cs))
                .flatMap(cs -> callGraph.getGraph().incomingEdgesOf(cs.getCaller()).stream())
                .filter(cs -> !thirdParty.contains(cs.getCaller()) && thirdParty.contains(cs.getCallee()))
                .collect(Collectors.toCollection(HashSet::new));
        LOGGER.debug("indirect vulnerable method calls: {} {}", spanningCandidates.size(), ratio(spanningCandidates.size(), edgeCount));

        // Direct candidates: calls from non-third-party code straight into a vulnerable method, and
        // calls made by test methods that have no callers of their own.
        callGraph.edges().filter(cs -> {
            boolean directVulnMethodCall = !thirdParty.contains(cs.getCaller()) && this.vulnerableMethods.contains(cs.getCallee());
            boolean testMethodEntryPoint = callGraph.getGraph().incomingEdgesOf(cs.getCaller()).isEmpty() && this.callGraphBuilder.isTestMethod(cs.getCaller());
            return testMethodEntryPoint || directVulnMethodCall;
        }).forEach(spanningCandidates::add);
        LOGGER.debug("total spanning call sites: {} {}", spanningCandidates.size(), ratio(spanningCandidates.size(), edgeCount));

        // Paths from every candidate callee on the call graph, and from every candidate caller on the
        // edge-reversed graph (i.e. walking back towards the callers).
        SomePairsShortestPaths<MethodInfo, CallSite> johnsonToVulnMethods = new SomePairsShortestPaths<MethodInfo, CallSite>(
                (Graph<MethodInfo, CallSite>) callGraph.getGraph(),
                spanningCandidates.stream().map(CallSite::getCallee).collect(Collectors.toSet()));
        SomePairsShortestPaths<MethodInfo, CallSite> johnsonToEntryPoints = new SomePairsShortestPaths<MethodInfo, CallSite>(
                new EdgeReversedGraph<MethodInfo, CallSite>(callGraph.getGraph()),
                spanningCandidates.stream().map(CallSite::getCaller).collect(Collectors.toSet()));

        // Method names (as Id strings) of the framework callback functions listed in the input.
        Set<String> frameworkCallbacks = this.input.frameworkMethods().stream()
                .map(Frameworks.Module::getFunctions)
                .flatMap(Collection::stream)
                .map(Frameworks.Function::getName)
                .map(JSMethodInfo::getId)
                .map(Id::toString)
                .collect(Collectors.toSet());
        // Entry points: resolver-approved vertices that either have no callers or are framework callbacks.
        List<MethodInfo> allEntryPoints = callGraph.vertices()
                .filter(v -> this.entryPointResolver.isEntryPoint(v)
                        && (callGraph.getGraph().incomingEdgesOf(v).isEmpty() || frameworkCallbacks.contains(v.getMethodName())))
                .collect(Collectors.toList());
        LOGGER.debug("application entry points: {} {}", allEntryPoints.size(), ratio(allEntryPoints.size(), vertexCount));
        LOGGER.debug("vulnerable methods: {} {}", this.vulnerableMethods.size(), ratio(this.vulnerableMethods.size(), vertexCount));

        Map<MethodInfo, Collection<CallChain>> result = new HashMap<MethodInfo, Collection<CallChain>>();
        HashMap<CallChain, CallSite> spanningMethods = new HashMap<CallChain, CallSite>();
        for (CallSite callSite : spanningCandidates) {
            MethodInfo caller = callSite.getCaller();
            MethodInfo callee = callSite.getCallee();

            // Pick one entry point deterministically (smallest toString) among those for which a
            // directed path exists from the caller in the reversed graph.
            Optional<MethodInfo> entryPoint = allEntryPoints.stream()
                    .filter(e -> johnsonToEntryPoints.hasDirectedPath(caller, e))
                    .min(Comparator.comparing(Object::toString));
            if (!entryPoint.isPresent()) {
                continue;
            }
            // Among vulnerable methods present in the graph and reachable from the callee, take the
            // one with the largest path weight (first one wins on ties, as with Stream.max).
            Optional<MethodInfo> vulnMethod = this.vulnerableMethods.stream()
                    .filter(callGraph::containsVertex)
                    .filter(m -> johnsonToVulnMethods.hasDirectedPath(callee, m))
                    .max(Comparator.comparingDouble(m -> johnsonToVulnMethods.getPathWeight(callee, m)));
            if (!vulnMethod.isPresent()) {
                continue;
            }
            MethodInfo entry = entryPoint.get();
            MethodInfo target = vulnMethod.get();

            List<CallSite> callSites = new ArrayList<CallSite>(johnsonToEntryPoints.getPath(caller, entry).getEdgeList());
            // The reversed-graph path is listed from the caller back towards the entry point; flip it
            // so the chain reads entry point -> ... -> caller.
            Collections.reverse(callSites);
            callSites.add(callSite);
            callSites.addAll(johnsonToVulnMethods.getPath(callee, target).getEdgeList());
            CallChain chain = new CallChain(callSites);

            spanningMethods.put(chain, callSite);
            result.computeIfAbsent(target, k -> new HashSet<CallChain>()).add(chain);
        }
        return new MethodScanner.VMReport(result, spanningMethods);
    }

    /**
     * Formats {@code part / whole} with two decimals for debug logging, or {@code "n/a"} when
     * {@code whole} is zero (empty graph) instead of printing NaN/Infinity.
     */
    private static String ratio(long part, double whole) {
        return whole == 0.0 ? "n/a" : String.format("%.02f", part / whole);
    }

    /** Releases resources by delegating to the configured {@link Cleaner}. */
    @Override
    public void close() {
        this.cleaner.clean();
    }

    /** Returns the call graph currently held by the builder. */
    @Override
    public CallGraph getCallGraph() {
        return this.callGraphBuilder.getCallGraph();
    }
}

