diff --git a/src/main/scala/gitbucket/core/util/JDBCUtil.scala b/src/main/scala/gitbucket/core/util/JDBCUtil.scala index 631ad88..60116a5 100644 --- a/src/main/scala/gitbucket/core/util/JDBCUtil.scala +++ b/src/main/scala/gitbucket/core/util/JDBCUtil.scala @@ -6,6 +6,7 @@ import javax.xml.stream.{XMLStreamConstants, XMLInputFactory, XMLOutputFactory} import ControlUtil._ import scala.StringBuilder +import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ListBuffer @@ -144,6 +145,8 @@ writer.writeStartDocument("UTF-8", "1.0") writer.writeStartElement("tables") + println(allTablesInDatabase.mkString(", ")) + allTablesInDatabase.reverse.foreach { tableName => if (targetTables.contains(tableName)) { writer.writeStartElement("delete") @@ -292,12 +295,34 @@ val result = TableDependency(tableName, childTables(meta, tableName)) result } - tables.sortWith { (a, b) => - a.children.contains(b.tableName) - }.map(_.tableName) + + val edges = tables.flatMap { table => + table.children.map { child => (table.tableName, child) } + } + + tsort(edges).toSeq } case class TableDependency(tableName: String, children: Seq[String]) + + + def tsort[A](edges: Traversable[(A, A)]): Iterable[A] = { + @tailrec + def tsort(toPreds: Map[A, Set[A]], done: Iterable[A]): Iterable[A] = { + val (noPreds, hasPreds) = toPreds.partition { _._2.isEmpty } + if (noPreds.isEmpty) { + if (hasPreds.isEmpty) done else sys.error(hasPreds.toString) + } else { + val found = noPreds.map { _._1 } + tsort(hasPreds.mapValues { _ -- found }, done ++ found) + } + } + + val toPred = edges.foldLeft(Map[A, Set[A]]()) { (acc, e) => + acc + (e._1 -> acc.getOrElse(e._1, Set())) + (e._2 -> (acc.getOrElse(e._2, Set()) + e._1)) + } + tsort(toPred, Seq()) + } } }