package org.unitils.dbunit;

import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Properties;
import org.dbunit.dataset.datatype.IDataTypeFactory;
import org.unitils.core.Module;
import org.unitils.core.TestListener;
import org.unitils.core.Unitils;
import org.unitils.core.UnitilsException;
import org.unitils.core.dbsupport.DbSupport;
import org.unitils.core.dbsupport.DbSupportFactory;
import org.unitils.core.dbsupport.SQLHandler;
import org.unitils.database.DatabaseModule;
import org.unitils.database.transaction.TransactionalDataSource;
import org.unitils.dbunit.annotation.DataSet;
import org.unitils.dbunit.annotation.ExpectedDataSet;
import org.unitils.dbunit.datasetfactory.DataSetFactory;
import org.unitils.dbunit.datasetloadstrategy.DataSetLoadStrategy;
import org.unitils.dbunit.util.DbUnitAssert;
import org.unitils.dbunit.util.DbUnitDatabaseConnection;
import org.unitils.dbunit.util.MultiSchemaDataSet;
import org.unitils.dbunit.util.MultiSchemaXmlDataSetReader;
import org.unitils.thirdparty.org.apache.commons.io.IOUtils;
import org.unitils.util.AnnotationUtils;
import org.unitils.util.ConfigUtils;
import org.unitils.util.ModuleUtils;
import org.unitils.util.ReflectionUtils;

/**
 * Unitils module that loads DbUnit data sets into the database before a test method and verifies
 * the database content against an expected data set after it.
 * <p>
 * Data sets are selected with the {@link DataSet} and {@link ExpectedDataSet} annotations, looked up
 * on the test method or its class. One {@link DbUnitDatabaseConnection} is cached per schema name;
 * the underlying JDBC connections are closed after every insert and every assertion.
 * <p>
 * This module obtains its data source from the {@link DatabaseModule}, which must therefore be
 * enabled; this is checked in {@link DbUnitListener#beforeAll()}.
 */
public class DbUnitModule implements Module {

    /** Defaults for the {@link DataSet} / {@link ExpectedDataSet} annotation properties, read in {@link #init}. */
    private Map<Class<? extends Annotation>, Map<Method, String>> defaultAnnotationPropertyValues;

    /** Lazily created DbUnit connections, keyed by schema name. */
    private final Map<String, DbUnitDatabaseConnection> dbUnitDatabaseConnections =
            new HashMap<String, DbUnitDatabaseConnection>();

    /** The Unitils configuration passed to {@link #init}. */
    private Properties configuration;

    /**
     * Test listener that inserts the test data set before each test method and, when the test
     * method succeeded, checks the database against the expected data set afterwards.
     */
    protected class DbUnitListener extends TestListener {

        protected DbUnitListener() {
        }

        /**
         * Fails fast when the DatabaseModule is not available, since this module gets its data
         * source from it.
         */
        @Override // org.unitils.core.TestListener
        public void beforeAll() {
            // Qualified with DbUnitModule.this so an inherited TestListener member of the same
            // name can never shadow the module method.
            if (DbUnitModule.this.getDatabaseModule() == null) {
                throw new UnitilsException("Invalid configuration: When the DbUnitModule is enabled, the DatabaseModule should also be enabled and the DbUnitModule should be configured to run after the DatabaseModule");
            }
        }

        /** Inserts the data set configured for the test method, if any. */
        @Override // org.unitils.core.TestListener
        public void beforeTestMethod(Object obj, Method method) {
            DbUnitModule.this.insertTestData(method);
        }

        /**
         * Verifies the expected data set, but only when the test method did not throw; a failing
         * test would otherwise be masked by a (likely) failing data set comparison.
         */
        @Override // org.unitils.core.TestListener
        public void afterTestMethod(Object obj, Method method, Throwable th) {
            if (th == null) {
                DbUnitModule.this.assertDbContentAsExpected(method);
            }
        }
    }

    /**
     * Initializes the module.
     *
     * @param properties the Unitils configuration; kept for creating DbSupport instances and
     *                   reading annotation property defaults
     */
    @Override // org.unitils.core.Module
    public void init(Properties properties) {
        this.configuration = properties;
        this.defaultAnnotationPropertyValues = ModuleUtils.getAnnotationPropertyDefaults(DbUnitModule.class, properties, DataSet.class, ExpectedDataSet.class);
    }

    /**
     * Returns the DbUnit connection for the given schema, creating and caching it on first use.
     *
     * @param str the schema name
     * @return the (cached) connection for that schema, never null
     */
    public DbUnitDatabaseConnection getDbUnitDatabaseConnection(String str) {
        DbUnitDatabaseConnection connection = dbUnitDatabaseConnections.get(str);
        if (connection == null) {
            connection = createDbUnitConnection(str);
            dbUnitDatabaseConnections.put(str, connection);
        }
        return connection;
    }

    /**
     * Inserts the data set configured by the {@link DataSet} annotation of the given test method
     * (or its class). Does nothing when no such annotation is present. JDBC connections are
     * always closed afterwards.
     *
     * @param method the test method
     * @throws UnitilsException if the data set cannot be found, read or inserted
     */
    public void insertTestData(Method method) {
        try {
            MultiSchemaDataSet testDataSets = getTestDataSets(method);
            if (testDataSets == null) {
                return;
            }
            insertDataSet(testDataSets, getDataSetOperation(method));
        } catch (Exception e) {
            throw new UnitilsException("Error inserting test data from DbUnit dataset for method " + method, e);
        } finally {
            closeJdbcConnection();
        }
    }

    /**
     * Reads a multi-schema XML data set from the given stream and inserts it using the given load
     * strategy. The stream is closed; JDBC connections are always closed afterwards.
     *
     * @param inputStream         the XML data set
     * @param dataSetLoadStrategy the strategy used to load each schema's data set
     * @throws UnitilsException if the data set cannot be read or inserted
     */
    public void insertTestData(InputStream inputStream, DataSetLoadStrategy dataSetLoadStrategy) {
        try {
            insertDataSet(getDataSet(inputStream), dataSetLoadStrategy);
        } catch (Exception e) {
            throw new UnitilsException("Error inserting test data from DbUnit dataset.", e);
        } finally {
            closeJdbcConnection();
        }
    }

    /** Executes the load strategy for every schema contained in the multi-schema data set. */
    private void insertDataSet(MultiSchemaDataSet multiSchemaDataSet, DataSetLoadStrategy loadStrategy) throws Exception {
        for (String schemaName : multiSchemaDataSet.getSchemaNames()) {
            loadStrategy.execute(getDbUnitDatabaseConnection(schemaName), multiSchemaDataSet.getDataSetForSchema(schemaName));
        }
    }

    /**
     * Compares the database content with the data set configured by the {@link ExpectedDataSet}
     * annotation of the given test method (or its class). Does nothing when no such annotation
     * is present. Pending database updates are flushed through the DatabaseModule before
     * comparing; JDBC connections are always closed afterwards.
     *
     * @param method the test method
     */
    public void assertDbContentAsExpected(Method method) {
        try {
            MultiSchemaDataSet expectedTestDataSet = getExpectedTestDataSet(method);
            if (expectedTestDataSet == null) {
                return;
            }
            getDatabaseModule().flushDatabaseUpdates();
            for (String schemaName : expectedTestDataSet.getSchemaNames()) {
                DbUnitAssert.assertDbContentAsExpected(expectedTestDataSet.getDataSetForSchema(schemaName), getDbUnitDatabaseConnection(schemaName));
            }
        } finally {
            closeJdbcConnection();
        }
    }

    /**
     * Resolves the test data set for a test method from its {@link DataSet} annotation.
     * <p>
     * When the annotation value is empty, the method-level default file
     * ({@code <ShortClassName>.<methodName>.<ext>}) is tried first, then the class-level default
     * ({@code <ShortClassName>.<ext>}). Files are resolved relative to the method's declaring
     * class.
     *
     * @param method the test method
     * @return the data set, or null when no {@link DataSet} annotation applies
     * @throws UnitilsException if the named (or class-level default) file cannot be found
     */
    public MultiSchemaDataSet getTestDataSets(Method method) {
        DataSet dataSet = (DataSet) AnnotationUtils.getMethodOrClassLevelAnnotation(DataSet.class, method);
        if (dataSet == null) {
            return null;
        }
        String value = dataSet.value();
        Class<?> declaringClass = method.getDeclaringClass();
        DataSetFactory dataSetFactory = getDataSetFactory(DataSet.class, method);
        if ("".equals(value)) {
            MultiSchemaDataSet methodLevelDataSet = getDataSet(declaringClass, getMethodLevelDefaultTestDataSetFileName(method, dataSetFactory.getDataSetFileExtension()), dataSetFactory);
            if (methodLevelDataSet != null) {
                return methodLevelDataSet;
            }
            value = getClassLevelDefaultTestDataSetFileName(declaringClass, dataSetFactory.getDataSetFileExtension());
        }
        MultiSchemaDataSet result = getDataSet(declaringClass, value, dataSetFactory);
        if (result == null) {
            throw new UnitilsException("Could not find DbUnit dataset with name " + value);
        }
        return result;
    }

    /**
     * Resolves the expected data set for a test method from its {@link ExpectedDataSet}
     * annotation. An empty annotation value means {@code <ShortClassName>.<methodName>-result.<ext>},
     * resolved relative to the method's declaring class.
     *
     * @param method the test method
     * @return the expected data set, or null when no {@link ExpectedDataSet} annotation applies
     * @throws UnitilsException if the file cannot be found
     */
    public MultiSchemaDataSet getExpectedTestDataSet(Method method) {
        ExpectedDataSet expectedDataSet = (ExpectedDataSet) AnnotationUtils.getMethodOrClassLevelAnnotation(ExpectedDataSet.class, method);
        if (expectedDataSet == null) {
            return null;
        }
        DataSetFactory dataSetFactory = getDataSetFactory(ExpectedDataSet.class, method);
        String value = expectedDataSet.value();
        if ("".equals(value)) {
            value = getDefaultExpectedDataSetFileName(method, dataSetFactory.getDataSetFileExtension());
        }
        MultiSchemaDataSet dataSet = getDataSet(method.getDeclaringClass(), value, dataSetFactory);
        if (dataSet == null) {
            throw new UnitilsException("Could not find expected DbUnit dataset with name " + value);
        }
        return dataSet;
    }

    /**
     * Loads a data set file as a class-path resource relative to the given class.
     * <p>
     * NOTE(review): {@code dataSetFactory} is not used to read the file; it is always parsed by
     * {@link #getDataSet(InputStream)} (multi-schema XML). Callers only use the factory for the
     * default file extension — confirm this is intended.
     *
     * @param cls            the class the resource name is resolved against
     * @param str            the resource name
     * @param dataSetFactory the configured data set factory (currently unused here)
     * @return the data set, or null when the resource does not exist
     * @throws UnitilsException if the resource exists but cannot be read
     */
    protected MultiSchemaDataSet getDataSet(Class cls, String str, DataSetFactory dataSetFactory) {
        try {
            InputStream resourceAsStream = cls.getResourceAsStream(str);
            if (resourceAsStream == null) {
                return null;
            }
            return getDataSet(resourceAsStream);
        } catch (Exception e) {
            throw new UnitilsException("Unable to create DbUnit dataset for file " + str, e);
        }
    }

    /**
     * Parses a multi-schema XML data set, using the schema name of the default DbSupport as the
     * default schema. The stream is always closed, whether or not parsing succeeds.
     *
     * @param inputStream the XML data set
     * @return the parsed data set
     * @throws UnitilsException if the data set cannot be read
     */
    public MultiSchemaDataSet getDataSet(InputStream inputStream) {
        try {
            DbSupport defaultDbSupport = DbSupportFactory.getDefaultDbSupport(this.configuration, new SQLHandler(getDatabaseModule().getDataSource()));
            return new MultiSchemaXmlDataSetReader(defaultDbSupport.getSchemaName()).readDataSetXml(inputStream);
        } catch (Exception e) {
            throw new UnitilsException("Unable to create DbUnit dataset for input stream.", e);
        } finally {
            IOUtils.closeQuietly(inputStream);
        }
    }

    /**
     * Instantiates the load strategy for the test method: the {@code loadStrategy} property of
     * its {@link DataSet} annotation, with the configured default substituted where applicable.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // ModuleUtils/ReflectionUtils are used with raw Class values
    protected DataSetLoadStrategy getDataSetOperation(Method method) {
        Class annotatedClass = (Class) AnnotationUtils.getMethodOrClassLevelAnnotationProperty(DataSet.class, "loadStrategy", DataSetLoadStrategy.class, method);
        Class loadStrategyClass = (Class) ModuleUtils.getClassValueReplaceDefault(DataSet.class, "loadStrategy", annotatedClass, this.defaultAnnotationPropertyValues, DataSetLoadStrategy.class);
        return (DataSetLoadStrategy) ReflectionUtils.createInstanceOfType(loadStrategyClass, false);
    }

    /**
     * Creates a DbUnit connection for the given schema on the DatabaseModule's data source and
     * configures DbUnit's data type factory for the schema's database dialect.
     *
     * @param str the schema name
     * @return a new connection
     */
    protected DbUnitDatabaseConnection createDbUnitConnection(String str) {
        TransactionalDataSource dataSource = getDatabaseModule().getDataSource();
        DbSupport dbSupport = DbSupportFactory.getDbSupport(this.configuration, new SQLHandler(dataSource), str);
        DbUnitDatabaseConnection connection = new DbUnitDatabaseConnection(dataSource, dbSupport.getSchemaName());
        IDataTypeFactory dataTypeFactory = (IDataTypeFactory) ConfigUtils.getConfiguredInstance(IDataTypeFactory.class, this.configuration, dbSupport.getDatabaseDialect());
        connection.getConfig().setProperty("http://www.dbunit.org/properties/datatypeFactory", dataTypeFactory);
        return connection;
    }

    /**
     * Closes the JDBC connection of every cached DbUnit connection. All connections are
     * attempted even if some fail; the first failure is then reported.
     *
     * @throws UnitilsException if closing any connection failed
     */
    protected void closeJdbcConnection() {
        SQLException firstFailure = null;
        for (DbUnitDatabaseConnection connection : dbUnitDatabaseConnections.values()) {
            try {
                connection.closeJdbcConnection();
            } catch (SQLException e) {
                // Keep going so one failing connection does not leak the others.
                if (firstFailure == null) {
                    firstFailure = e;
                }
            }
        }
        if (firstFailure != null) {
            throw new UnitilsException("Error while closing connection.", firstFailure);
        }
    }

    /** Returns {@code <ShortClassName>.<methodName>.<extension>} for the method's declaring class. */
    protected String getMethodLevelDefaultTestDataSetFileName(Method method, String str) {
        return getShortClassName(method.getDeclaringClass()) + "." + method.getName() + '.' + str;
    }

    /** Returns {@code <ShortClassName>.<extension>} for the given class. */
    protected String getClassLevelDefaultTestDataSetFileName(Class<?> cls, String str) {
        return getShortClassName(cls) + '.' + str;
    }

    /** Returns {@code <ShortClassName>.<methodName>-result.<extension>} for the method's declaring class. */
    protected static String getDefaultExpectedDataSetFileName(Method method, String str) {
        return getShortClassName(method.getDeclaringClass()) + "." + method.getName() + "-result." + str;
    }

    /**
     * Returns the class name without its package. Based on {@link Class#getName()}, so nested
     * classes keep the {@code Outer$Inner} form.
     */
    private static String getShortClassName(Class<?> cls) {
        String name = cls.getName();
        return name.substring(name.lastIndexOf('.') + 1);
    }

    /**
     * Instantiates the data set factory for the given annotation type: its {@code factory}
     * property on the test method or class, with the configured default substituted where
     * applicable.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // ModuleUtils/ReflectionUtils are used with raw Class values
    protected DataSetFactory getDataSetFactory(Class<? extends Annotation> cls, Method method) {
        Class annotatedClass = (Class) AnnotationUtils.getMethodOrClassLevelAnnotationProperty(cls, "factory", DataSetFactory.class, method);
        Class factoryClass = (Class) ModuleUtils.getClassValueReplaceDefault(cls, "factory", annotatedClass, this.defaultAnnotationPropertyValues, DataSetFactory.class);
        return (DataSetFactory) ReflectionUtils.createInstanceOfType(factoryClass, false);
    }

    /** Returns the DatabaseModule from the Unitils modules repository (null-checked in {@link DbUnitListener#beforeAll()}). */
    protected DatabaseModule getDatabaseModule() {
        return (DatabaseModule) Unitils.getInstance().getModulesRepository().getModuleOfType(DatabaseModule.class);
    }

    /** Creates the listener that hooks this module into the test lifecycle. */
    @Override // org.unitils.core.Module
    public TestListener createTestListener() {
        return new DbUnitListener();
    }
}
