001package com.avaje.ebean.dbmigration;
002
import com.avaje.ebean.Transaction;
import com.avaje.ebeaninternal.api.SpiEbeanServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.persistence.PersistenceException;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
015
016/**
017 * Runs DDL scripts.
018 */
019public class DdlRunner {
020
021  protected static final Logger logger = LoggerFactory.getLogger(DdlRunner.class);
022
023  protected DdlParser ddlParser = new DdlParser();
024
025  protected final String scriptName;
026
027  protected final boolean expectErrors;
028
029  /**
030   * Construct with a script name (for logging) and flag indicating if errors are expected.
031   */
032  public DdlRunner(boolean expectErrors, String scriptName) {
033    this.expectErrors = expectErrors;
034    this.scriptName = scriptName;
035  }
036
037  /**
038   * Parse the content into sql statements and execute them in a transaction.
039   */
040  public int runAll(String content, SpiEbeanServer server) {
041
042    List<String> statements = ddlParser.parse(new StringReader(content));
043    return runStatements(statements, server);
044  }
045
046  /**
047   * Execute all the statements in a single transaction.
048   */
049  public int runStatements(List<String> statements, SpiEbeanServer server) {
050
051    Transaction t = server.createTransaction();
052    try {
053      int statementCount = runStatements(expectErrors, statements, t.getConnection());
054      t.commit();
055
056      return statementCount;
057
058    } catch (Exception e) {
059      throw new PersistenceException("Error: " + e.getMessage(), e);
060
061    } finally {
062      t.end();
063    }
064  }
065
066  /**
067   * Execute the list of statements.
068   */
069  private int runStatements(boolean expectErrors, List<String> statements, Connection c) {
070
071    List<String> noDuplicates = new ArrayList<String>();
072
073    for (String statement : statements) {
074      if (!noDuplicates.contains(statement)) {
075        noDuplicates.add(statement);
076      }
077    }
078
079    logger.info("Executing {} - {} statements", scriptName, noDuplicates.size());
080
081    for (int i = 0; i < noDuplicates.size(); i++) {
082      String xOfy = (i + 1) + " of " + noDuplicates.size();
083      runStatement(expectErrors, xOfy, noDuplicates.get(i), c);
084    }
085
086    return noDuplicates.size();
087  }
088
089  /**
090   * Execute the statement.
091   */
092  private void runStatement(boolean expectErrors, String oneOf, String stmt, Connection c) {
093
094    PreparedStatement pstmt = null;
095    try {
096
097      // trim and remove trailing ; or /
098      stmt = stmt.trim();
099      if (stmt.endsWith(";")) {
100        stmt = stmt.substring(0, stmt.length() - 1);
101      } else if (stmt.endsWith("/")) {
102        stmt = stmt.substring(0, stmt.length() - 1);
103      }
104
105      if (logger.isDebugEnabled()) {
106        logger.debug("executing " + oneOf + " " + getSummary(stmt));
107      }
108
109      pstmt = c.prepareStatement(stmt);
110      pstmt.execute();
111
112    } catch (Exception e) {
113      if (expectErrors) {
114        logger.debug(" ... ignoring error executing " + getSummary(stmt) + "  error: " + e.getMessage());
115      } else {
116        String msg = "Error executing stmt[" + stmt + "] error[" + e.getMessage() + "]";
117        throw new RuntimeException(msg, e);
118      }
119
120    } finally {
121      if (pstmt != null) {
122        try {
123          pstmt.close();
124        } catch (SQLException e) {
125          logger.error("Error closing pstmt", e);
126        }
127      }
128    }
129  }
130
131  private String getSummary(String s) {
132    if (s.length() > 80) {
133      return s.substring(0, 80).trim() + "...";
134    }
135    return s;
136  }
137
138}