package org.wso2.am.choreo.extensions.oauth;

import com.nimbusds.jwt.SignedJWT;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.APIMgtAuthorizationFailedException;
import org.wso2.carbon.apimgt.common.gateway.dto.JWTValidationInfo;
import org.wso2.carbon.apimgt.impl.caching.CacheProvider;
import org.wso2.carbon.apimgt.impl.jwt.SignedJWTInfo;
import org.wso2.carbon.apimgt.impl.utils.APIUtil;
import org.wso2.carbon.apimgt.rest.api.common.RestAPIAuthenticator;
import org.wso2.carbon.apimgt.rest.api.common.utils.JWTUtil;

/* loaded from: input_file:org/wso2/am/choreo/extensions/oauth/BackendJWTAuthenticationImpl.class */
/**
 * {@link RestAPIAuthenticator} that authenticates REST API requests using the JWT carried in the
 * {@code X-JWT-Assertion} protocol header.
 *
 * <p>The token is parsed as a signed JWT, optionally recorded as valid in the REST API token cache, and
 * then passed to {@code JWTUtil.handleScopeValidation} for scope checks.
 *
 * <p>NOTE(review): this class does not itself verify the JWT signature, issuer or expiry. It presumably
 * relies on an upstream component having validated the token and on clients being unable to set
 * {@code X-JWT-Assertion} directly — confirm before relying on it.
 */
public class BackendJWTAuthenticationImpl implements RestAPIAuthenticator {
    /** Tokens longer than this are rendered with a fixed-length mask. */
    private static final int MAX_LEN = 36;
    /** Upper bound on the number of trailing token characters left visible when masking. */
    private static final int MAX_VISIBLE_LEN = 8;
    /** At most {@code length / MIN_VISIBLE_LEN_RATIO} trailing characters are left visible when masking. */
    private static final int MIN_VISIBLE_LEN_RATIO = 5;
    private static final String MASK_CHAR = "X";
    /** Message-map key holding the request's protocol headers (value of CXF's Message.PROTOCOL_HEADERS). */
    private static final String PROTOCOL_HEADERS_KEY = "org.apache.cxf.message.Message.PROTOCOL_HEADERS";
    /** Header carrying the JWT this authenticator consumes. */
    private static final String JWT_ASSERTION_HEADER = "X-JWT-Assertion";
    private static final Log logger = LogFactory.getLog(BackendJWTAuthenticationImpl.class);

    /**
     * Authenticates the request using the JWT from the {@code X-JWT-Assertion} header.
     *
     * <p>Always stores a masked form of the token under {@code "maskedToken"} in {@code hashMap}
     * (an empty string when the header is absent).
     *
     * @param hashMap request message map; read for the protocol headers and the {@code "JWT_TOKEN"} entry
     * @return {@code false} if the header is missing or the token cannot be parsed / scope validation
     *         throws; otherwise the result of {@code JWTUtil.handleScopeValidation}
     * @throws APIMgtAuthorizationFailedException declared by this signature; not thrown directly by this body
     */
    public boolean authenticate(HashMap<String, Object> hashMap) throws APIMgtAuthorizationFailedException {
        String accessToken = getAccessToken(hashMap);
        hashMap.put("maskedToken", getMaskedToken(accessToken));
        if (accessToken == null) {
            return false;
        }
        try {
            String jwtToken = (String) hashMap.get("JWT_TOKEN");
            // Never null: parsing failures surface as ParseException.
            SignedJWTInfo signedJwt = getSignedJwt(accessToken);
            // Checked per request instead of being stored in an instance field, so concurrent calls on a
            // shared authenticator do not write shared mutable state.
            if (APIUtil.getRESTAPICacheConfig().isTokenCacheEnabled()) {
                // NOTE(review): only valid=true is recorded; no expiry/claims are stored here — confirm how
                // readers of the REST API token cache interpret this entry.
                JWTValidationInfo validationInfo = new JWTValidationInfo();
                validationInfo.setValid(true);
                CacheProvider.getRESTAPITokenCache().put(getJWTTokenIdentifier(signedJwt), validationInfo);
            }
            return JWTUtil.handleScopeValidation(hashMap, signedJwt, jwtToken);
        } catch (APIManagementException | ParseException e) {
            logger.error("Not a JWT token. Failed to decode the token. Reason: " + e.getMessage());
            // Keep the cause available for troubleshooting without flooding error logs with stack traces.
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to authenticate using the X-JWT-Assertion header", e);
            }
            return false;
        }
    }

    /**
     * Returns whether this authenticator should handle the request: the {@code X-JWT-Assertion} header
     * is present and its first value contains a {@code '.'} (a cheap JWT-shape check).
     *
     * @param hashMap request message map
     * @return {@code true} if the header value looks like a JWT
     */
    public boolean canHandle(HashMap<String, Object> hashMap) {
        String token = getAccessToken(hashMap);
        return token != null && token.contains(".");
    }

    /**
     * @return the authentication type handled by this authenticator, {@code "jwt"}
     */
    public String getAuthenticationType() {
        return "jwt";
    }

    /**
     * @param hashMap request message map (unused)
     * @return always {@code 0}
     */
    public int getPriority(HashMap<String, Object> hashMap) {
        return 0;
    }

    /**
     * Parses {@code token} as a signed JWT together with its claims set. Never returns {@code null}.
     *
     * @param token serialized JWT
     * @return the parsed token wrapped in a {@link SignedJWTInfo}
     * @throws ParseException if the token or its claims set cannot be parsed
     */
    private SignedJWTInfo getSignedJwt(String token) throws ParseException {
        SignedJWT signedJWT = SignedJWT.parse(token);
        return new SignedJWTInfo(token, signedJWT, signedJWT.getJWTClaimsSet());
    }

    /**
     * Returns the first value of the {@code X-JWT-Assertion} header.
     *
     * @param hashMap request message map
     * @return the header value, or {@code null} when the protocol headers, the header, or its first value
     *         are missing
     */
    private String getAccessToken(HashMap<String, Object> hashMap) {
        Object headers = hashMap.get(PROTOCOL_HEADERS_KEY);
        if (!(headers instanceof Map)) {
            return null;
        }
        // Typed as Map/List so any implementation is accepted; Map.get still uses the concrete map's own
        // key comparison (presumably a case-insensitive TreeMap supplied by CXF — TODO confirm).
        Object values = ((Map<?, ?>) headers).get(JWT_ASSERTION_HEADER);
        if (!(values instanceof List) || ((List<?>) values).isEmpty()) {
            return null;
        }
        Object first = ((List<?>) values).get(0);
        return first == null ? null : first.toString();
    }

    /**
     * Returns a cache key for the token: its {@code jti} claim when non-empty, otherwise the string form
     * of the token's signature.
     *
     * @param signedJWTInfo parsed token
     * @return identifier used as the REST API token cache key
     */
    private String getJWTTokenIdentifier(SignedJWTInfo signedJWTInfo) {
        String jwtid = signedJWTInfo.getJwtClaimsSet().getJWTID();
        return StringUtils.isNotEmpty(jwtid) ? jwtid : signedJWTInfo.getSignedJWT().getSignature().toString();
    }

    /**
     * Builds a masked form of a token in which only the last {@code min(length / 5, 8)} characters stay
     * visible. Tokens longer than {@value #MAX_LEN} characters are rendered as {@code "..."} followed by
     * exactly {@value #MAX_LEN} mask characters before the visible suffix, bounding the output length.
     *
     * @param token token to mask; may be {@code null}
     * @return the masked token, or an empty string when {@code token} is {@code null}
     */
    private static String getMaskedToken(String token) {
        if (token == null) {
            return "";
        }
        int length = token.length();
        int visible = Math.min(length / MIN_VISIBLE_LEN_RATIO, MAX_VISIBLE_LEN);
        String maskedPrefix = length > MAX_LEN
                ? "..." + repeatMask(MAX_LEN)
                : repeatMask(length - visible);
        return maskedPrefix + token.substring(length - visible);
    }

    /** Returns {@code count} copies of {@link #MASK_CHAR} concatenated together. */
    private static String repeatMask(int count) {
        return String.join("", Collections.nCopies(count, MASK_CHAR));
    }
}
