diff --git a/src/main/scala/gitbucket/core/util/JDBCUtil.scala b/src/main/scala/gitbucket/core/util/JDBCUtil.scala index 41a2b1a..78829a7 100644 --- a/src/main/scala/gitbucket/core/util/JDBCUtil.scala +++ b/src/main/scala/gitbucket/core/util/JDBCUtil.scala @@ -1,6 +1,8 @@ package gitbucket.core.util +import java.io.FileOutputStream import java.sql._ +import java.text.SimpleDateFormat import ControlUtil._ import scala.collection.mutable.ListBuffer @@ -58,6 +60,74 @@ } } + def export(): Unit = { + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") + + using(new FileOutputStream("export.sql")) { out => + val dbMeta = conn.getMetaData + + allTables(dbMeta).foreach { tableName => + val sb = new StringBuilder() + + select(s"SELECT * FROM ${tableName}") { rs => + sb.append(s"INSERT INTO ${tableName} (") + val rsMeta = rs.getMetaData + val columns = (1 to rsMeta.getColumnCount).map { i => + (rsMeta.getColumnName(i), rsMeta.getColumnType(i)) + } + sb.append(columns.map(_._1).mkString(", ")) + sb.append(") VALUES (") + val values = columns.map { case (columnName, columnType) => + columnType match { + case Types.BOOLEAN => rs.getBoolean(columnName) + case Types.VARCHAR | Types.CLOB | Types.CHAR => rs.getString(columnName) + case Types.INTEGER => rs.getInt(columnName) + case Types.TIMESTAMP => rs.getTimestamp(columnName) + } + } + + val columnValues = values.map { value => + value match { + case x: String => "'" + x.replace("'", "''") + "'" + case x: Timestamp => "'" + dateFormat.format(x) + "'" + case null => "NULL" + case x => x + } + } + sb.append(columnValues.mkString(", ")) + sb.append(");\n") + } + + out.write(sb.toString.getBytes("UTF-8")) + } + } + } + + private def parentTables(meta: DatabaseMetaData, tableName: String): Seq[String] = { + using(meta.getImportedKeys(null, null, tableName)) { rs => + val parents = new ListBuffer[String] + while (rs.next) { + val tableName = rs.getString("PKTABLE_NAME") + parents += tableName + parents ++= parentTables(meta, tableName) + } + parents.toSeq + } + } + + private def allTables(meta: DatabaseMetaData): Seq[String] = { + using(meta.getTables(null, null, "%", Seq("TABLE").toArray)) { rs => + val tables = new ListBuffer[(String, Seq[String])] + while (rs.next) { + val name = rs.getString("TABLE_NAME") + if(name != "VERSIONS") { + tables += ((name, parentTables(meta, name))) + } + } + tables.sortWith { (a, b) => b._2.contains(a._1) }.map(_._1).toSeq + } + } + } }