/*
 * Decompiled with CFR 0.152.
 */
package org.apache.dolphinscheduler.server.worker.task.sql;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.apache.dolphinscheduler.alert.utils.MailUtils;
import org.apache.dolphinscheduler.common.enums.CommandType;
import org.apache.dolphinscheduler.common.enums.DataType;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.apache.dolphinscheduler.common.enums.ShowType;
import org.apache.dolphinscheduler.common.enums.TaskTimeoutStrategy;
import org.apache.dolphinscheduler.common.process.Property;
import org.apache.dolphinscheduler.common.task.AbstractParameters;
import org.apache.dolphinscheduler.common.task.sql.SqlBinds;
import org.apache.dolphinscheduler.common.task.sql.SqlParameters;
import org.apache.dolphinscheduler.common.task.sql.SqlType;
import org.apache.dolphinscheduler.common.utils.CollectionUtils;
import org.apache.dolphinscheduler.common.utils.CommonUtils;
import org.apache.dolphinscheduler.common.utils.EnumUtils;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.common.utils.ParameterUtils;
import org.apache.dolphinscheduler.dao.AlertDao;
import org.apache.dolphinscheduler.dao.datasource.BaseDataSource;
import org.apache.dolphinscheduler.dao.datasource.DataSourceFactory;
import org.apache.dolphinscheduler.dao.entity.User;
import org.apache.dolphinscheduler.server.entity.SQLTaskExecutionContext;
import org.apache.dolphinscheduler.server.entity.TaskExecutionContext;
import org.apache.dolphinscheduler.server.utils.ParamUtils;
import org.apache.dolphinscheduler.server.utils.UDFUtils;
import org.apache.dolphinscheduler.server.worker.task.AbstractTask;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.slf4j.Logger;

/**
 * Worker task that executes a SQL statement against the task's configured datasource.
 *
 * <p>On a single connection it runs, in order: any UDF creation statements, the configured pre
 * statements, the main statement (as a query or a non-query depending on {@code sqlType}), and
 * finally the configured post statements. For queries, at most {@code limit} rows are collected,
 * a preview of {@code displayRows} rows (default 10) is logged, and unless {@code sendEmail} is
 * explicitly {@code false} the result set is mailed as an attachment.
 *
 * <p>{@code ${name}} placeholders in the SQL (including any quote characters immediately around
 * them) are replaced by JDBC {@code ?} markers and bound from the task's resolved parameters.
 */
public class SqlTask
extends AbstractTask {
    /** Matches a {@code ${name}} placeholder plus any surrounding single/double quote characters. */
    private static final String PARAM_PLACEHOLDER_REGEX = "['\"]*\\$\\{(.*?)\\}['\"]*";

    private final SqlParameters sqlParameters;
    private final AlertDao alertDao;
    private BaseDataSource baseDataSource;
    private final TaskExecutionContext taskExecutionContext;

    /**
     * Parses and validates the SQL task parameters carried by the execution context.
     *
     * @param taskExecutionContext execution context holding the JSON task params
     * @param logger task logger
     * @throws RuntimeException if the params cannot be parsed or fail {@code checkParameters()}
     */
    public SqlTask(TaskExecutionContext taskExecutionContext, Logger logger) {
        super(taskExecutionContext, logger);
        this.taskExecutionContext = taskExecutionContext;
        logger.info("sql task params {}", (Object) taskExecutionContext.getTaskParams());
        this.sqlParameters = (SqlParameters) JSONUtils.parseObject((String) taskExecutionContext.getTaskParams(), SqlParameters.class);
        // parseObject may yield null for unparsable input; report it as invalid params, not an NPE
        if (this.sqlParameters == null || !this.sqlParameters.checkParameters()) {
            throw new RuntimeException("sql task params is not valid");
        }
        this.alertDao = (AlertDao) SpringApplicationContext.getBean(AlertDao.class);
    }

    /**
     * Resolves the datasource, builds the parameter bindings for the main, pre and post
     * statements, and executes them. Sets exit status 0 on success and -1 on failure.
     *
     * @throws Exception any failure, rethrown after being logged
     */
    @Override
    public void handle() throws Exception {
        String threadLoggerInfoName = String.format("TaskLogInfo-%s", this.taskExecutionContext.getTaskAppId());
        Thread.currentThread().setName(threadLoggerInfoName);
        this.logger.info("Full sql parameters: {}", (Object) this.sqlParameters);
        this.logger.info("sql type : {}, datasource : {}, sql : {} , localParams : {},udfs : {},showType : {},connParams : {}, query max result limit : {}",
                this.sqlParameters.getType(),
                this.sqlParameters.getDatasource(),
                this.sqlParameters.getSql(),
                this.sqlParameters.getLocalParams(),
                this.sqlParameters.getUdfs(),
                this.sqlParameters.getShowType(),
                this.sqlParameters.getConnParams(),
                this.sqlParameters.getLimit());
        try {
            SQLTaskExecutionContext sqlTaskExecutionContext = this.taskExecutionContext.getSqlTaskExecutionContext();
            DbType dbType = DbType.valueOf((String) this.sqlParameters.getType());
            DataSourceFactory.loadClass(dbType);
            this.baseDataSource = DataSourceFactory.getDatasource(dbType, (String) sqlTaskExecutionContext.getConnectionParams());
            SqlBinds mainSqlBinds = this.getSqlAndSqlParamsMap(this.sqlParameters.getSql());
            List<SqlBinds> preStatementSqlBinds = this.buildSqlBinds(this.sqlParameters.getPreStatements());
            List<SqlBinds> postStatementSqlBinds = this.buildSqlBinds(this.sqlParameters.getPostStatements());
            List<String> createFuncs = UDFUtils.createFuncs(sqlTaskExecutionContext.getUdfFuncTenantCodeMap(), this.logger);
            this.executeFuncAndSql(mainSqlBinds, preStatementSqlBinds, postStatementSqlBinds, createFuncs);
            this.setExitStatusCode(0);
        }
        catch (Exception e) {
            this.setExitStatusCode(-1);
            // trailing throwable argument makes SLF4J log the stack trace as well
            this.logger.error("sql task error: {}", e.toString(), e);
            throw e;
        }
    }

    /**
     * Converts each raw statement into its {@link SqlBinds}.
     *
     * @param statements raw SQL statements; {@code null} is treated as empty
     * @return a new mutable list with one binding per statement, in input order
     */
    private List<SqlBinds> buildSqlBinds(List<String> statements) {
        if (statements == null) {
            return new ArrayList<>();
        }
        return statements.stream().map(this::getSqlAndSqlParamsMap).collect(Collectors.toList());
    }

    /**
     * Replaces schedule-time and {@code ${name}} placeholders in {@code sql} and records which
     * task property binds to each resulting {@code ?} marker (1-based).
     *
     * <p>Side effect: when a title is configured, its placeholders are resolved and the resolved
     * title is written back into the task parameters.
     *
     * @param sql raw statement text
     * @return the rewritten SQL with its positional parameter map; if no parameters can be
     *     resolved, the SQL is returned unchanged with an empty map
     */
    private SqlBinds getSqlAndSqlParamsMap(String sql) {
        Map<Integer, Property> sqlParamsMap = new HashMap<>();
        Map<String, Property> paramsMap = ParamUtils.convert(
                ParamUtils.getUserDefParamsMap(this.taskExecutionContext.getDefinedParams()),
                this.taskExecutionContext.getDefinedParams(),
                this.sqlParameters.getLocalParametersMap(),
                CommandType.of((Integer) this.taskExecutionContext.getCmdTypeIfComplement()),
                this.taskExecutionContext.getScheduleTime());
        if (paramsMap == null) {
            return new SqlBinds(sql, sqlParamsMap);
        }
        if (StringUtils.isNotEmpty(this.sqlParameters.getTitle())) {
            String title = ParameterUtils.convertParameterPlaceholders(this.sqlParameters.getTitle(), ParamUtils.convert(paramsMap));
            this.logger.info("SQL title : {}", (Object) title);
            this.sqlParameters.setTitle(title);
        }
        String scheduledSql = ParameterUtils.replaceScheduleTime(sql, this.taskExecutionContext.getScheduleTime());
        this.setSqlParamsMap(scheduledSql, PARAM_PLACEHOLDER_REGEX, sqlParamsMap, paramsMap);
        String formatSql = scheduledSql.replaceAll(PARAM_PLACEHOLDER_REGEX, "?");
        this.printReplacedSql(scheduledSql, formatSql, PARAM_PLACEHOLDER_REGEX, sqlParamsMap);
        return new SqlBinds(formatSql, sqlParamsMap);
    }

    @Override
    public AbstractParameters getParameters() {
        return this.sqlParameters;
    }

    /**
     * Runs UDF creation, pre statements, the main statement and post statements on one
     * connection. The result set, main statement and connection are always closed afterwards.
     *
     * @param mainSqlBinds the main statement and its bindings
     * @param preStatementsBinds statements executed (as updates) before the main statement
     * @param postStatementsBinds statements executed (as updates) after the main statement
     * @param createFuncs UDF creation statements; skipped when empty
     * @throws Exception any failure, rethrown after being logged
     */
    public void executeFuncAndSql(SqlBinds mainSqlBinds, List<SqlBinds> preStatementsBinds, List<SqlBinds> postStatementsBinds, List<String> createFuncs) throws Exception {
        Connection connection = null;
        PreparedStatement stmt = null;
        ResultSet resultSet = null;
        try {
            CommonUtils.loadKerberosConf();
            connection = this.createConnection();
            if (CollectionUtils.isNotEmpty(createFuncs)) {
                this.createTempFunction(connection, createFuncs);
            }
            this.preSql(connection, preStatementsBinds);
            stmt = this.prepareStatementAndBind(connection, mainSqlBinds);
            if (this.sqlParameters.getSqlType() == SqlType.QUERY.ordinal()) {
                resultSet = stmt.executeQuery();
                this.resultProcess(resultSet);
            } else if (this.sqlParameters.getSqlType() == SqlType.NON_QUERY.ordinal()) {
                stmt.executeUpdate();
            }
            this.postSql(connection, postStatementsBinds);
        }
        catch (Exception e) {
            this.logger.error("execute sql error: {}", e.getMessage(), e);
            throw e;
        }
        finally {
            this.close(resultSet, stmt, connection);
        }
    }

    /**
     * Collects up to {@code limit} rows as JSON, logs a preview of the first rows and, unless
     * {@code sendEmail} is explicitly {@code false}, mails the full collected result.
     */
    private void resultProcess(ResultSet resultSet) throws Exception {
        JSONArray resultJSONArray = new JSONArray();
        ResultSetMetaData md = resultSet.getMetaData();
        int num = md.getColumnCount();
        int limit = this.sqlParameters.getLimit();
        for (int rowCount = 0; rowCount < limit && resultSet.next(); ++rowCount) {
            // ordered JSONObject keeps the column order of the result set
            JSONObject mapOfColValues = new JSONObject(true);
            for (int i = 1; i <= num; ++i) {
                mapOfColValues.put(md.getColumnLabel(i), resultSet.getObject(i));
            }
            resultJSONArray.add(mapOfColValues);
        }
        String result = JSONUtils.toJsonString(resultJSONArray);
        this.logger.debug("execute sql result : {}", (Object) result);
        int displayRows = this.sqlParameters.getDisplayRows() > 0 ? this.sqlParameters.getDisplayRows() : 10;
        displayRows = Math.min(displayRows, resultJSONArray.size());
        this.logger.info("display sql result {} rows as follows:", (Object) displayRows);
        for (int i = 0; i < displayRows; ++i) {
            String row = JSONUtils.toJsonString(resultJSONArray.get(i));
            this.logger.info("row {} : {}", (Object) (i + 1), (Object) row);
        }
        // a null sendEmail flag means "send" (only an explicit false disables mailing)
        if (!Boolean.FALSE.equals(this.sqlParameters.getSendEmail())) {
            String title = StringUtils.isNotEmpty(this.sqlParameters.getTitle())
                    ? this.sqlParameters.getTitle()
                    : this.taskExecutionContext.getTaskName() + " query result sets";
            this.sendAttachment(title, result);
        }
    }

    /** Executes the pre statements as updates, in order, closing each statement afterwards. */
    private void preSql(Connection connection, List<SqlBinds> preStatementsBinds) throws Exception {
        this.executeUpdates(connection, preStatementsBinds, "pre statement execute result: {}, for sql: {}");
    }

    /** Executes the post statements as updates, in order, closing each statement afterwards. */
    private void postSql(Connection connection, List<SqlBinds> postStatementsBinds) throws Exception {
        this.executeUpdates(connection, postStatementsBinds, "post statement execute result: {},for sql: {}");
    }

    /**
     * Prepares, binds and executes each statement as an update. The first failure propagates
     * (it is not swallowed) and the statement is closed either way.
     *
     * @param connection open connection
     * @param statementBinds statements to run; {@code null} is treated as empty
     * @param resultLogFormat SLF4J format with two slots: update count and SQL text
     */
    private void executeUpdates(Connection connection, List<SqlBinds> statementBinds, String resultLogFormat) throws Exception {
        if (statementBinds == null) {
            return;
        }
        for (SqlBinds sqlBind : statementBinds) {
            try (PreparedStatement pstmt = this.prepareStatementAndBind(connection, sqlBind)) {
                int result = pstmt.executeUpdate();
                this.logger.info(resultLogFormat, (Object) result, (Object) sqlBind.getSql());
            }
        }
    }

    /** Executes each UDF creation statement on a single {@link Statement}. */
    private void createTempFunction(Connection connection, List<String> createFuncs) throws Exception {
        try (Statement funcStmt = connection.createStatement()) {
            for (String createFunc : createFuncs) {
                this.logger.info("hive create function sql: {}", (Object) createFunc);
                funcStmt.execute(createFunc);
            }
        }
    }

    /**
     * Opens a JDBC connection to the resolved datasource. For Hive, the user, password and the
     * task's connection params (parsed with {@code ;} separator and {@code hiveconf:} prefix) are
     * passed as connection properties.
     */
    private Connection createConnection() throws Exception {
        String jdbcUrl = this.baseDataSource.getJdbcUrl();
        if (DbType.HIVE != DbType.valueOf((String) this.sqlParameters.getType())) {
            return DriverManager.getConnection(jdbcUrl, this.baseDataSource.getUser(), this.baseDataSource.getPassword());
        }
        Properties paramProp = new Properties();
        paramProp.setProperty("user", this.baseDataSource.getUser());
        paramProp.setProperty("password", this.baseDataSource.getPassword());
        Map<?, ?> connParamMap = CollectionUtils.stringToMap((String) this.sqlParameters.getConnParams(), ";", "hiveconf:");
        if (connParamMap != null) {
            paramProp.putAll(connParamMap);
        }
        return DriverManager.getConnection(jdbcUrl, paramProp);
    }

    /**
     * Closes the given JDBC resources (any may be {@code null}) in result set, statement,
     * connection order. A failure closing one is logged and does not stop the others.
     */
    private void close(ResultSet resultSet, PreparedStatement pstmt, Connection connection) {
        this.closeQuietly(resultSet, "close result set error : {}");
        this.closeQuietly(pstmt, "close prepared statement error : {}");
        this.closeQuietly(connection, "close connection error : {}");
    }

    /** Closes {@code resource} if non-null, logging (not propagating) any failure. */
    private void closeQuietly(AutoCloseable resource, String errorLogFormat) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception e) {
            this.logger.error(errorLogFormat, e.getMessage(), e);
        }
    }

    /**
     * Prepares {@code sqlBinds.getSql()}, applies the task timeout when the timeout strategy is
     * FAILED or WARNFAILED, and binds the positional parameters.
     *
     * @return the prepared statement; the caller owns it and must close it
     * @throws Exception if preparing or binding fails (the statement is closed in that case)
     */
    private PreparedStatement prepareStatementAndBind(Connection connection, SqlBinds sqlBinds) throws Exception {
        TaskTimeoutStrategy timeoutStrategy = TaskTimeoutStrategy.of((int) this.taskExecutionContext.getTaskTimeoutStrategy());
        boolean timeoutFlag = timeoutStrategy == TaskTimeoutStrategy.FAILED || timeoutStrategy == TaskTimeoutStrategy.WARNFAILED;
        PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql());
        try {
            if (timeoutFlag) {
                stmt.setQueryTimeout(this.taskExecutionContext.getTaskTimeout());
            }
            Map<Integer, Property> params = sqlBinds.getParamsMap();
            if (params != null) {
                for (Map.Entry<Integer, Property> entry : params.entrySet()) {
                    Property prop = entry.getValue();
                    ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue());
                }
            }
        }
        catch (Exception e) {
            // the caller never receives the statement on failure, so close it here to avoid a leak
            try {
                stmt.close();
            }
            catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        this.logger.info("prepare statement replace sql : {} ", (Object) stmt);
        return stmt;
    }

    /**
     * Mails {@code content} to the task's warning-group users plus the configured receivers and
     * CC receivers. Null or blank addresses are skipped.
     *
     * @param title mail subject
     * @param content mail body (the query result JSON when called from this task)
     * @throws RuntimeException if the configured show type is not a valid {@link ShowType}, or if
     *     the mail result does not report {@code status == true}
     */
    public void sendAttachment(String title, String content) {
        List<User> users = this.alertDao.queryUserByAlertGroupId(this.taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId());
        List<String> receiversList = new ArrayList<>();
        if (users != null) {
            for (User user : users) {
                addIfNotBlank(receiversList, user.getEmail());
            }
        }
        addCommaSeparated(receiversList, this.sqlParameters.getReceivers());
        List<String> receiversCcList = new ArrayList<>();
        addCommaSeparated(receiversCcList, this.sqlParameters.getReceiversCc());

        // show type may be configured as e.g. "TABLE,ATTACHMENT"; commas are removed to get the enum name
        String showType = this.sqlParameters.getShowType();
        String showTypeName = showType == null ? "" : showType.replace(",", "").trim();
        if (!EnumUtils.isValidEnum(ShowType.class, showTypeName)) {
            this.logger.error("showType: {} is not valid ", (Object) showTypeName);
            throw new RuntimeException(String.format("showType: %s is not valid ", showTypeName));
        }
        Map<?, ?> mailResult = MailUtils.sendMails(receiversList, receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp());
        if (mailResult == null || !Boolean.TRUE.equals(mailResult.get("status"))) {
            throw new RuntimeException("send mail failed!");
        }
    }

    /** Splits {@code csv} on commas and adds each trimmed, non-blank token to {@code target}. */
    private static void addCommaSeparated(List<String> target, String csv) {
        if (StringUtils.isEmpty(csv)) {
            return;
        }
        for (String token : csv.split(",")) {
            addIfNotBlank(target, token);
        }
    }

    /** Adds the trimmed {@code address} to {@code target} unless it is null or blank. */
    private static void addIfNotBlank(List<String> target, String address) {
        if (StringUtils.isNotBlank(address)) {
            target.add(address.trim());
        }
    }

    /**
     * Scans {@code content} for {@code rgex} matches and maps each match whose group 1 names a
     * known property to the next 1-based parameter index. Unknown names are logged and skipped
     * without consuming an index.
     *
     * @param content SQL text to scan
     * @param rgex placeholder pattern; group 1 must capture the parameter name
     * @param sqlParamsMap output map from parameter index to property
     * @param paramsPropsMap available properties keyed by name
     */
    public void setSqlParamsMap(String content, String rgex, Map<Integer, Property> sqlParamsMap, Map<String, Property> paramsPropsMap) {
        Matcher m = Pattern.compile(rgex).matcher(content);
        int index = 1;
        while (m.find()) {
            String paramName = m.group(1);
            Property prop = paramsPropsMap.get(paramName);
            if (prop == null) {
                this.logger.error("setSqlParamsMap: No Property with paramName: {} is found in paramsPropsMap of task instance with id: {}. So couldn't put Property in sqlParamsMap.", (Object) paramName, (Object) this.taskExecutionContext.getTaskInstanceId());
                continue;
            }
            sqlParamsMap.put(index++, prop);
            this.logger.info("setSqlParamsMap: Property with paramName: {} put in sqlParamsMap of content {} successfully.", (Object) paramName, (Object) content);
        }
    }

    /**
     * Logs the rewritten SQL and its bound parameters as {@code value(type)} entries, in index
     * order 1..size. Missing indices are skipped.
     */
    public void printReplacedSql(String content, String formatSql, String rgex, Map<Integer, Property> sqlParamsMap) {
        this.logger.info("after replace sql , preparing : {}", (Object) formatSql);
        StringBuilder logPrint = new StringBuilder("replaced sql , parameters:");
        if (sqlParamsMap == null) {
            this.logger.info("printReplacedSql: sqlParamsMap is null.");
        } else {
            for (int i = 1; i <= sqlParamsMap.size(); ++i) {
                Property prop = sqlParamsMap.get(i);
                if (prop != null) {
                    logPrint.append(prop.getValue()).append('(').append(prop.getType()).append(')');
                }
            }
        }
        this.logger.info("Sql Params are {}", (Object) logPrint);
    }
}

