diff --git a/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala b/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala index 56bf848..9222f6f 100644 --- a/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala +++ b/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala @@ -7,7 +7,32 @@ import scala.collection.mutable.ListBuffer -class CompositeScalatraFilter extends Filter { +abstract class ControllerFilter extends Filter { + + def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean + + override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = { + val contextPath = request.getServletContext.getContextPath + val requestPath = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length) + val checkPath = if (requestPath.endsWith("/")) { + requestPath + } else { + requestPath + "/" + } + + if (!checkPath.startsWith("/upload/") && !checkPath.startsWith("/git/") && !checkPath.startsWith("/git-lfs/") && + !checkPath.startsWith("/assets/") && !checkPath.startsWith("/plugin-assets/")) { + val continue = process(request, response, checkPath) + if (!continue) { + return () + } + } + + chain.doFilter(request, response) + } +} + +class CompositeScalatraFilter extends ControllerFilter { private val filters = new ListBuffer[(ScalatraFilter, String)]() @@ -29,34 +54,23 @@ } } - override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = { - val contextPath = request.getServletContext.getContextPath - val requestPath = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length) - val checkPath = if (requestPath.endsWith("/")) { - requestPath - } else { - requestPath + "/" - } + override def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean = { + filters + .filter { + case (_, path) => + val start = path.replaceFirst("/\\*$", "/") + checkPath.startsWith(start) + } + .foreach { + case (filter, _) => + val mockChain = new MockFilterChain() + filter.doFilter(request, response, mockChain) + if (mockChain.continue == false) { + return false + } + } - if (!checkPath.startsWith("/upload/") && !checkPath.startsWith("/git/") && !checkPath.startsWith("/git-lfs/") && - !checkPath.startsWith("/plugin-assets/")) { - filters - .filter { - case (_, path) => - val start = path.replaceFirst("/\\*$", "/") - checkPath.startsWith(start) - } - .foreach { - case (filter, _) => - val mockChain = new MockFilterChain() - filter.doFilter(request, response, mockChain) - if (mockChain.continue == false) { - return () - } - } - } - - chain.doFilter(request, response) + true } } diff --git a/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala b/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala index 9137c15..c515b64 100644 --- a/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala +++ b/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala @@ -6,7 +6,7 @@ import gitbucket.core.controller.ControllerBase import gitbucket.core.plugin.PluginRegistry -class PluginControllerFilter extends Filter { +class PluginControllerFilter extends ControllerFilter { private var filterConfig: FilterConfig = null @@ -21,16 +21,13 @@ } } - override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = { - val contextPath = request.getServletContext.getContextPath - val requestUri = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length) - + override def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean = { PluginRegistry() .getControllers() .filter { case (_, path) => val start = path.replaceFirst("/\\*$", "/") - (requestUri + "/").startsWith(start) + checkPath.startsWith(start) } .foreach { case (controller, _) => @@ -42,11 +39,11 @@ controller.doFilter(request, response, mockChain) if (mockChain.continue == false) { - return () + return false } } - chain.doFilter(request, response) + true } }