diff --git a/src/main/scala/gitbucket/core/controller/FileUploadController.scala b/src/main/scala/gitbucket/core/controller/FileUploadController.scala index 39bdb99..fe9b456 100644 --- a/src/main/scala/gitbucket/core/controller/FileUploadController.scala +++ b/src/main/scala/gitbucket/core/controller/FileUploadController.scala @@ -80,20 +80,9 @@ post("/import") { session.get(Keys.Session.LoginAccount).collect { case loginAccount: Account if loginAccount.isAdmin => execute({ (file, fileId) => - using(file.getInputStream){ in => - import JDBCUtil._ - val sql = IOUtils.toString(in, "UTF-8") - val conn = request2Session(request).conn - conn.setAutoCommit(false) - try { - conn.update(sql) - conn.commit() - } catch { - case e: Throwable => - conn.rollback() - throw e - } - } + import JDBCUtil._ + val conn = request2Session(request).conn + conn.importAsXML(file.getInputStream) }, _ => true) } redirect("/admin/data") diff --git a/src/main/scala/gitbucket/core/controller/SystemSettingsController.scala b/src/main/scala/gitbucket/core/controller/SystemSettingsController.scala index 9b897de..4f4ee99 100644 --- a/src/main/scala/gitbucket/core/controller/SystemSettingsController.scala +++ b/src/main/scala/gitbucket/core/controller/SystemSettingsController.scala @@ -280,7 +280,7 @@ post("/admin/export")(adminOnly { import gitbucket.core.util.JDBCUtil._ val session = request2Session(request) - val file = session.conn.export(request.getParameterValues("tableNames").toSeq) + val file = session.conn.exportAsXML(request.getParameterValues("tableNames").toSeq) contentType = "application/octet-stream" response.setHeader("Content-Disposition", "attachment; filename=" + file.getName) diff --git a/src/main/scala/gitbucket/core/util/JDBCUtil.scala b/src/main/scala/gitbucket/core/util/JDBCUtil.scala index f269da8..80e1469 100644 --- a/src/main/scala/gitbucket/core/util/JDBCUtil.scala +++ b/src/main/scala/gitbucket/core/util/JDBCUtil.scala @@ -3,7 +3,10 @@ import java.io._ import java.sql._ import java.text.SimpleDateFormat +import javax.xml.stream.{XMLStreamConstants, XMLInputFactory, XMLOutputFactory} import ControlUtil._ +import scala.StringBuilder +import scala.collection.mutable import scala.collection.mutable.ListBuffer /** @@ -60,61 +63,131 @@ } } - def export(targetTables: Seq[String]): File = { - val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") - val file = File.createTempFile("gitbucket-export-", ".sql") + def importAsXML(in: InputStream): Unit = { + conn.setAutoCommit(false) + try { + val factory = XMLInputFactory.newInstance() + using(factory.createXMLStreamReader(in)){ reader => + // stateful objects + var elementName = "" + var insertTable = "" + var insertColumns = Map.empty[String, (String, String)] - using(new FileOutputStream(file)) { out => + while(reader.hasNext){ + reader.next() + + reader.getEventType match { + case XMLStreamConstants.START_ELEMENT => + elementName = reader.getName.getLocalPart + if(elementName == "insert"){ + insertTable = reader.getAttributeValue(null, "table") + } else if(elementName == "delete"){ + val tableName = reader.getAttributeValue(null, "table") + conn.update(s"DELETE FROM ${tableName}") + } else if(elementName == "column"){ + val columnName = reader.getAttributeValue(null, "name") + val columnType = reader.getAttributeValue(null, "type") + val columnValue = reader.getElementText + insertColumns = insertColumns + (columnName -> (columnType, columnValue)) + } + case XMLStreamConstants.END_ELEMENT => + // Execute insert statement + reader.getName.getLocalPart match { + case "insert" => { + val sb = new StringBuilder() + sb.append(s"INSERT INTO ${insertTable} (") + sb.append(insertColumns.map { case (columnName, _) => columnName }.mkString(", ")) + sb.append(") VALUES (") + sb.append(insertColumns.map { case (_, (columnType, columnValue)) => + if(columnType == null || columnValue == null){ + "NULL" + } else if(columnType == "string"){ + "'" + columnValue.replace("'", "''") + "'" + } else if(columnType == "timestamp"){ + "'" + columnValue + "'" + } else { + columnValue.toString + } + }.mkString(", ")) + sb.append(")") + + conn.update(sb.toString) + + insertColumns = Map.empty[String, (String, String)] // Clear column information + } + case _ => // Nothing to do + } + case _ => // Nothing to do + } + } + } + + conn.commit() + + } catch { + case e: Exception => { + conn.rollback() + throw e + } + } + } + + def exportAsXML(targetTables: Seq[String]): File = { + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") + val file = File.createTempFile("gitbucket-export-", ".xml") + + val factory = XMLOutputFactory.newInstance() + using(factory.createXMLStreamWriter(new FileOutputStream(file))){ writer => val dbMeta = conn.getMetaData val allTablesInDatabase = allTablesOrderByDependencies(dbMeta) + writer.writeStartDocument("UTF-8", "1.0") + writer.writeStartElement("tables") + allTablesInDatabase.reverse.foreach { tableName => if (targetTables.contains(tableName)) { - out.write(s"DELETE FROM ${tableName};\n".getBytes("UTF-8")) + writer.writeStartElement("delete") + writer.writeAttribute("table", tableName) + writer.writeEndElement() } } allTablesInDatabase.foreach { tableName => if (targetTables.contains(tableName)) { - val sb = new StringBuilder() select(s"SELECT * FROM ${tableName}") { rs => - sb.append(s"INSERT INTO ${tableName} (") - + writer.writeStartElement("insert") + writer.writeAttribute("table", tableName) val rsMeta = rs.getMetaData - val columns = (1 to rsMeta.getColumnCount).map { i => - (rsMeta.getColumnName(i), rsMeta.getColumnType(i)) - } - sb.append(columns.map(_._1).mkString(", ")) - sb.append(") VALUES (") - - val values = columns.map { case (columnName, columnType) => - if(rs.getObject(columnName) == null){ - null + (1 to rsMeta.getColumnCount).foreach { i => + val columnName = rsMeta.getColumnName(i) + val (columnType, columnValue) = if(rs.getObject(columnName) == null){ + (null, null) } else { - columnType match { - case Types.BOOLEAN | Types.BIT => rs.getBoolean(columnName) - case Types.VARCHAR | Types.CLOB | Types.CHAR => rs.getString(columnName) - case Types.INTEGER => rs.getInt(columnName) - case Types.TIMESTAMP => rs.getTimestamp(columnName) + rsMeta.getColumnType(i) match { + case Types.BOOLEAN | Types.BIT => ("boolean" , rs.getBoolean(columnName)) + case Types.VARCHAR | Types.CLOB | Types.CHAR | Types.LONGVARCHAR + => ("string" , rs.getString(columnName)) + case Types.INTEGER => ("int" , rs.getInt(columnName)) + case Types.TIMESTAMP => ("timestamp", dateFormat.format(rs.getTimestamp(columnName))) } } - } - - val columnValues = values.map { value => - value match { - case x: String => "'" + x.replace("'", "''") + "'" - case x: Timestamp => "'" + dateFormat.format(x) + "'" - case null => "NULL" - case x => x + writer.writeStartElement("column") + writer.writeAttribute("name", columnName) + if(columnType != null){ + writer.writeAttribute("type", columnType) } + if(columnValue != null){ + writer.writeCharacters(columnValue.toString) + } + writer.writeEndElement() } - sb.append(columnValues.mkString(", ")) - sb.append(");\n") + writer.writeEndElement() } - - out.write(sb.toString.getBytes("UTF-8")) } } + + writer.writeEndElement() + writer.writeEndDocument() } file @@ -133,27 +206,6 @@ } } -// private def parentTables(meta: DatabaseMetaData, tableName: String): Seq[String] = { -// val normalizedTableName = -// if(meta.getDatabaseProductName == "PostgreSQL"){ -// tableName.toLowerCase -// } else { -// tableName -// } -// -// using(meta.getImportedKeys(null, null, normalizedTableName)) { rs => -// val parents = new ListBuffer[String] -// while (rs.next) { -// val parentTableName = rs.getString("PKTABLE_NAME").toUpperCase -// if(!parents.contains(parentTableName)){ -// parents += parentTableName -// parents ++= parentTables(meta, parentTableName) -// } -// } -// parents.distinct.toSeq -// } -// } - private def childTables(meta: DatabaseMetaData, tableName: String): Seq[String] = { val normalizedTableName = if(meta.getDatabaseProductName == "PostgreSQL"){ @@ -179,7 +231,6 @@ private def allTablesOrderByDependencies(meta: DatabaseMetaData): Seq[String] = { val tables = allTableNames.map { tableName => val result = TableDependency(tableName, childTables(meta, tableName)) - println(result) result } tables.sortWith { (a, b) =>