package com.wing321.test;

import com.wing321.test.spring.listener.PersistenceTestExecutionListener;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.test.context.TestExecutionListeners;
import org.springframework.test.context.jdbc.SqlScriptsTestExecutionListener;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.test.context.transaction.TransactionalTestExecutionListener;
import org.springframework.test.jdbc.JdbcTestUtils;
import org.springframework.transaction.annotation.Transactional;

import javax.sql.DataSource;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;

/**
 * Base class for persistence (database) tests.
 *
 * <p>Tests run inside a transaction managed by the {@code transactionManager} bean, rolled back by
 * default ({@code defaultRollback = true}). Subclasses get a {@link JdbcTemplate} built from the
 * autowired {@link DataSource}, plus helpers to count/delete rows, drop tables and run SQL scripts.
 *
 * @author woate
 */
@Transactional
@TransactionConfiguration(transactionManager = "transactionManager", defaultRollback = true)
@TestExecutionListeners({PersistenceTestExecutionListener.class,
        TransactionalTestExecutionListener.class,
        SqlScriptsTestExecutionListener.class})
public abstract class PersistenceBaseTest extends BaseTest {
    /** JDBC access for subclasses; (re)created whenever a data source is injected. */
    protected JdbcTemplate jdbcTemplate;

    /** Encoding used to read scripts in {@link #executeSqlScript(String, boolean)}. */
    private String sqlScriptEncoding = "UTF-8";

    /**
     * Injects the data source and wraps it in a new {@link JdbcTemplate}.
     *
     * @param dataSource the data source used by every helper of this class
     */
    @Autowired
    public void setDataSource(DataSource dataSource) {
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Counts the rows in the given table.
     *
     * @param tableName name of the table to count rows in
     * @return the number of rows in the table
     */
    protected int countRowsInTable(String tableName) {
        return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
    }

    /**
     * Counts the rows in the given table that match the given WHERE clause.
     *
     * @param tableName   name of the table to count rows in
     * @param whereClause the WHERE clause (without the {@code WHERE} keyword) to filter rows
     * @return the number of matching rows
     */
    protected int countRowsInTableWhere(String tableName, String whereClause) {
        return JdbcTestUtils.countRowsInTableWhere(this.jdbcTemplate, tableName, whereClause);
    }

    /**
     * Deletes all rows from the given tables.
     *
     * @param names names of the tables to clear
     * @return the total number of rows deleted
     */
    protected int deleteFromTables(String... names) {
        return JdbcTestUtils.deleteFromTables(this.jdbcTemplate, names);
    }

    /**
     * Deletes the rows of a table that match the given WHERE clause.
     *
     * @param tableName   name of the table to delete rows from
     * @param whereClause the WHERE clause (without the {@code WHERE} keyword); may contain
     *                    {@code ?} placeholders bound from {@code args}
     * @param args        values bound to the placeholders of {@code whereClause}
     * @return the number of rows deleted
     */
    protected int deleteFromTableWhere(String tableName, String whereClause, Object... args) {
        return JdbcTestUtils.deleteFromTableWhere(this.jdbcTemplate, tableName, whereClause, args);
    }

    /**
     * Drops the given tables (removes their definitions, not just their rows).
     *
     * @param names names of the tables to drop
     */
    protected void dropTables(String... names) {
        JdbcTestUtils.dropTables(this.jdbcTemplate, names);
    }

    /**
     * Sets the encoding used to read SQL scripts in {@link #executeSqlScript(String, boolean)}.
     *
     * @param sqlScriptEncoding the script file encoding (defaults to {@code UTF-8})
     */
    public void setSqlScriptEncoding(String sqlScriptEncoding) {
        this.sqlScriptEncoding = sqlScriptEncoding;
    }

    /**
     * Executes the SQL script at the given resource location against this test's data source.
     *
     * <p>The location is resolved through the inherited {@code applicationContext}, so Spring
     * resource prefixes such as {@code classpath:} or {@code file:} can be used.
     *
     * @param sqlResourcePath location of the SQL script
     * @param continueOnError whether to keep executing remaining statements after one fails
     * @throws DataAccessException if the script cannot be read or a statement fails (and
     *                             {@code continueOnError} is {@code false})
     */
    protected void executeSqlScript(String sqlResourcePath, boolean continueOnError) throws DataAccessException {
        // Resolve only via the application context: a raw ClassLoader lookup returns null for
        // prefixed or missing paths, and URL -> File conversion fails for resources inside jars.
        Resource resource = this.applicationContext.getResource(sqlResourcePath);
        // ignoreFailedDrops = false: a failing DROP statement is treated like any other error.
        new ResourceDatabasePopulator(continueOnError, false, this.sqlScriptEncoding, resource)
                .execute(this.jdbcTemplate.getDataSource());
    }
}
