/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.cloud.sleuth.instrument.rsocket;

import io.rsocket.frame.FrameType;
import java.net.URI;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import org.assertj.core.api.BDDAssertions;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.TraceContext;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.exporter.FinishedSpan;
import org.springframework.cloud.sleuth.test.TestSpanHandler;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.stereotype.Controller;
import org.springframework.util.MimeType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;

public abstract class TraceRSocketTests {
    public static final String EXPECTED_TRACE_ID = "b919095138aa4c6e";

    @Test
    public void should_instrument_responder() throws Exception {
        ConfigurableApplicationContext context = new SpringApplicationBuilder(new Class[]{MyConfig.class, this.testConfiguration()}).web(WebApplicationType.REACTIVE).properties(new String[]{"server.port=0", "spring.rsocket.server.transport=websocket", "spring.rsocket.server.mapping-path=/rsocket", "spring.jmx.enabled=false", "spring.application.name=TraceRSocketTests", "security.basic.enabled=false", "management.security.enabled=false"}).run(new String[0]);
        TestSpanHandler spans = (TestSpanHandler)context.getBean(TestSpanHandler.class);
        int port = (Integer)((Environment)context.getBean(Environment.class)).getProperty("local.server.port", Integer.class);
        TestController controller2 = (TestController)context.getBean(TestController.class);
        RSocketStrategies strategies = (RSocketStrategies)context.getBean(RSocketStrategies.class);
        RSocketRequester.Builder rsocketRequesterBuilder = RSocketRequester.builder().rsocketStrategies(strategies);
        RSocketRequester rSocketRequester = rsocketRequesterBuilder.websocket(URI.create("ws://localhost:" + port + "/rsocket"));
        this.whenRequestFnFIsSent(rSocketRequester, "api.c2.fnf").block();
        FrameType receivedFrame = controller2.getReceivedFrames().take();
        this.thenSpanWasReportedWithTags(spans, "api.c2.fnf", receivedFrame);
        spans.clear();
        controller2.reset();
        this.whenRequestResponseIsSent(rSocketRequester, "api.c2.rr").block();
        receivedFrame = controller2.getReceivedFrames().take();
        this.thenSpanWasReportedWithTags(spans, "api.c2.rr", receivedFrame);
        spans.clear();
        controller2.reset();
        this.whenRequestStreamIsSent(rSocketRequester, "api.c2.rs").blockLast();
        receivedFrame = controller2.getReceivedFrames().take();
        this.thenSpanWasReportedWithTags(spans, "api.c2.rs", receivedFrame);
        spans.clear();
        controller2.reset();
        this.whenRequestChannelIsSent(rSocketRequester, "api.c2.rc").blockLast();
        receivedFrame = controller2.getReceivedFrames().take();
        this.thenSpanWasReportedWithTags(spans, "api.c2.rc", receivedFrame);
        spans.clear();
        controller2.reset();
        this.whenNonSampledRequestFnfIsSent(rSocketRequester);
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, this.expectedTraceId());
        spans.clear();
        controller2.reset();
        this.whenNonSampledRequestResponseIsSent(rSocketRequester);
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, this.expectedTraceId());
        spans.clear();
        controller2.reset();
        this.whenNonSampledRequestStreamIsSent(rSocketRequester);
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, this.expectedTraceId());
        spans.clear();
        controller2.reset();
        this.whenNonSampledRequestChannelIsSent(rSocketRequester);
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, this.expectedTraceId());
        spans.clear();
        controller2.reset();
        context.close();
    }

    protected String expectedTraceId() {
        return EXPECTED_TRACE_ID;
    }

    protected String expectedSpanId() {
        return EXPECTED_TRACE_ID;
    }

    @Test
    public void should_instrument_requester_and_responder() throws Exception {
        ConfigurableApplicationContext context = new SpringApplicationBuilder(new Class[]{MyConfig.class, this.testConfiguration()}).web(WebApplicationType.REACTIVE).properties(new String[]{"server.port=0", "spring.rsocket.server.transport=websocket", "spring.rsocket.server.mapping-path=/rsocket", "spring.jmx.enabled=false", "spring.application.name=TraceRSocketTests", "security.basic.enabled=false", "management.security.enabled=false"}).run(new String[0]);
        Tracer tracer = (Tracer)context.getBean(Tracer.class);
        TestSpanHandler spans = (TestSpanHandler)context.getBean(TestSpanHandler.class);
        int port = (Integer)((Environment)context.getBean(Environment.class)).getProperty("local.server.port", Integer.class);
        TestController controller2 = (TestController)context.getBean(TestController.class);
        RSocketRequester.Builder rsocketRequesterBuilder = (RSocketRequester.Builder)context.getBean(RSocketRequester.Builder.class);
        RSocketRequester rSocketRequester = rsocketRequesterBuilder.websocket(URI.create("ws://localhost:" + port + "/rsocket"));
        Span nextSpanFnf = tracer.nextSpan().start();
        this.whenRequestFnFIsSent(rSocketRequester, "api.c2.fnf").contextWrite(ctx -> ctx.put(TraceContext.class, (Object)nextSpanFnf.context())).doFinally(signalType -> nextSpanFnf.end()).block();
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, nextSpanFnf.context().traceId());
        spans.clear();
        controller2.reset();
        Span nextSpanRR = tracer.nextSpan().start();
        this.whenRequestResponseIsSent(rSocketRequester, "api.c2.rr").contextWrite(ctx -> ctx.put(TraceContext.class, (Object)nextSpanRR.context())).doFinally(signalType -> nextSpanRR.end()).block();
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, nextSpanRR.context().traceId());
        spans.clear();
        controller2.reset();
        Span nextSpanRS = tracer.nextSpan().start();
        this.whenRequestStreamIsSent(rSocketRequester, "api.c2.rs").contextWrite(ctx -> ctx.put(TraceContext.class, (Object)nextSpanRS.context())).doFinally(signalType -> nextSpanRS.end()).blockLast();
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, nextSpanRS.context().traceId());
        spans.clear();
        controller2.reset();
        Span nextSpanRC = tracer.nextSpan().start();
        this.whenRequestChannelIsSent(rSocketRequester, "api.c2.rc").contextWrite(ctx -> ctx.put(TraceContext.class, (Object)nextSpanRC.context())).doFinally(signalType -> nextSpanRC.end()).blockLast();
        controller2.getReceivedFrames().take();
        this.thenNoSpanWasReported(spans, controller2, nextSpanRC.context().traceId());
        spans.clear();
        controller2.reset();
        context.close();
    }

    protected abstract Class testConfiguration();

    private void thenSpanWasReportedWithTags(TestSpanHandler spans, String path, FrameType frameType) {
        String expectedName = frameType.name() + " " + path;
        Awaitility.await().untilAsserted(() -> {
            FinishedSpan span = spans.reportedSpans().stream().filter(finished -> expectedName.equals(finished.getName())).findFirst().orElseThrow(() -> new AssertionError((Object)("Span with name [" + expectedName + "] not found")));
            BDDAssertions.then((Map)span.getTags()).containsEntry((Object)"messaging.controller.class", (Object)"org.springframework.cloud.sleuth.instrument.rsocket.TraceRSocketTests$TestController");
            BDDAssertions.then((Map)span.getTags()).containsKey((Object)"messaging.controller.method");
        });
    }

    private Mono<Void> whenRequestFnFIsSent(RSocketRequester requester, String path) {
        return requester.route(path, new Object[0]).send();
    }

    private Mono<String> whenRequestResponseIsSent(RSocketRequester requester, String path) {
        return requester.route(path, new Object[0]).retrieveMono(String.class);
    }

    private Flux<String> whenRequestStreamIsSent(RSocketRequester requester, String path) {
        return requester.route(path, new Object[0]).retrieveFlux(String.class);
    }

    private Flux<String> whenRequestChannelIsSent(RSocketRequester requester, String path) {
        return requester.route(path, new Object[0]).data((Object)Flux.fromArray((Object[])new String[]{"test1", "test2"})).retrieveFlux(String.class);
    }

    private void whenNonSampledRequestFnfIsSent(RSocketRequester requester) {
        ((RSocketRequester.RequestSpec)requester.route("api.c2.fnf", new Object[0]).metadata((Object)(this.expectedTraceId() + "-" + this.expectedSpanId() + "-0"), new MimeType("b3"){

            public String toString() {
                return "b3";
            }
        })).send().block();
    }

    private void whenNonSampledRequestResponseIsSent(RSocketRequester requester) {
        ((RSocketRequester.RequestSpec)requester.route("api.c2.rr", new Object[0]).metadata((Object)(this.expectedTraceId() + "-" + this.expectedSpanId() + "-0"), new MimeType("b3"){

            public String toString() {
                return "b3";
            }
        })).retrieveMono(String.class).block();
    }

    private void whenNonSampledRequestStreamIsSent(RSocketRequester requester) {
        ((RSocketRequester.RequestSpec)requester.route("api.c2.rs", new Object[0]).metadata((Object)(this.expectedTraceId() + "-" + this.expectedSpanId() + "-0"), new MimeType("b3"){

            public String toString() {
                return "b3";
            }
        })).retrieveFlux(String.class).blockLast();
    }

    private void whenNonSampledRequestChannelIsSent(RSocketRequester requester) {
        ((RSocketRequester.RequestSpec)requester.route("api.c2.rc", new Object[0]).metadata((Object)(this.expectedTraceId() + "-" + this.expectedSpanId() + "-0"), new MimeType("b3"){

            public String toString() {
                return "b3";
            }
        })).data((Object)Flux.fromArray((Object[])new String[]{"test1", "test2"})).retrieveFlux(String.class).blockLast();
    }

    private void thenNoSpanWasReported(TestSpanHandler spans, TestController controller2, String expectedTraceId) {
        Awaitility.await().untilAsserted(() -> {
            BDDAssertions.then((Object)controller2.getSpan()).isNotNull();
            BDDAssertions.then((String)controller2.getSpan().context().traceId()).isEqualTo(expectedTraceId);
        });
    }

    @Configuration(proxyBeanMethods=false)
    @EnableAutoConfiguration
    static class MyConfig {
        MyConfig() {
        }

        @Bean
        TestController controller(Tracer tracer) {
            return new TestController(tracer);
        }
    }

    @Controller
    @MessageMapping(value={"api.c2"})
    static class TestController {
        final Tracer tracer;
        Span span;
        ContextView interceptedContext;
        BlockingQueue<FrameType> receivedFrames = new LinkedBlockingDeque<FrameType>();

        TestController(Tracer tracer) {
            this.tracer = tracer;
        }

        BlockingQueue<FrameType> getReceivedFrames() {
            return this.receivedFrames;
        }

        Span getSpan() {
            return this.span;
        }

        void reset() {
            this.span = null;
        }

        @MessageMapping(value={"fnf"})
        Mono<Void> testFnf() {
            this.span = this.tracer.currentSpan();
            return Mono.deferContextual(c -> {
                this.interceptedContext = c;
                this.receivedFrames.offer(FrameType.REQUEST_FNF);
                return Mono.empty();
            });
        }

        @MessageMapping(value={"rr"})
        Mono<String> testRR() {
            this.span = this.tracer.currentSpan();
            return Mono.deferContextual(c -> {
                this.interceptedContext = c;
                this.receivedFrames.offer(FrameType.REQUEST_RESPONSE);
                return Mono.just((Object)"response");
            });
        }

        @MessageMapping(value={"rs"})
        Flux<String> testRS() {
            this.span = this.tracer.currentSpan();
            return Flux.deferContextual(c -> {
                this.interceptedContext = c;
                this.receivedFrames.offer(FrameType.REQUEST_STREAM);
                return Flux.just((Object)"stream");
            });
        }

        @MessageMapping(value={"rc"})
        Flux<String> testRC(@Payload Flux<String> inbound) {
            this.span = this.tracer.currentSpan();
            return Flux.deferContextual(c -> {
                this.interceptedContext = c;
                this.receivedFrames.offer(FrameType.REQUEST_CHANNEL);
                return inbound;
            });
        }
    }
}

