diff --git a/src/main/scala/ScalatraBootstrap.scala b/src/main/scala/ScalatraBootstrap.scala index 80e3a58..231da65 100644 --- a/src/main/scala/ScalatraBootstrap.scala +++ b/src/main/scala/ScalatraBootstrap.scala @@ -32,20 +32,25 @@ context.addFilter("pluginControllerFilter", new PluginControllerFilter) context.getFilterRegistration("pluginControllerFilter").addMappingForUrlPatterns(EnumSet.allOf(classOf[DispatcherType]), true, "/*") - context.mount(new IndexController, "/") - context.mount(new ApiController, "/api/v3") context.mount(new FileUploadController, "/upload") - context.mount(new SystemSettingsController, "/admin") - context.mount(new DashboardController, "/*") - context.mount(new AccountController, "/*") - context.mount(new RepositoryViewerController, "/*") - context.mount(new WikiController, "/*") - context.mount(new LabelsController, "/*") - context.mount(new PrioritiesController, "/*") - context.mount(new MilestonesController, "/*") - context.mount(new IssuesController, "/*") - context.mount(new PullRequestsController, "/*") - context.mount(new RepositorySettingsController, "/*") + + val filter = new CompositeScalatraFilter() + filter.mount(new IndexController, "/") + filter.mount(new ApiController, "/api/v3") + filter.mount(new SystemSettingsController, "/admin") + filter.mount(new DashboardController, "/*") + filter.mount(new AccountController, "/*") + filter.mount(new RepositoryViewerController, "/*") + filter.mount(new WikiController, "/*") + filter.mount(new LabelsController, "/*") + filter.mount(new PrioritiesController, "/*") + filter.mount(new MilestonesController, "/*") + filter.mount(new IssuesController, "/*") + filter.mount(new PullRequestsController, "/*") + filter.mount(new RepositorySettingsController, "/*") + + context.addFilter("compositeScalatraFilter", filter) + context.getFilterRegistration("compositeScalatraFilter").addMappingForUrlPatterns(EnumSet.allOf(classOf[DispatcherType]), true, "/*") // Create GITBUCKET_HOME directory if it does not exist val dir = new java.io.File(Directory.GitBucketHome) diff --git a/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala b/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala new file mode 100644 index 0000000..c180738 --- /dev/null +++ b/src/main/scala/gitbucket/core/servlet/CompositeScalatraFilter.scala @@ -0,0 +1,63 @@ +package gitbucket.core.servlet + +import javax.servlet._ +import javax.servlet.http.HttpServletRequest + +import org.scalatra.ScalatraFilter + +import scala.collection.mutable.ListBuffer + +class CompositeScalatraFilter extends Filter { + + private val filters = new ListBuffer[(ScalatraFilter, String)]() + + def mount(filter: ScalatraFilter, path: String): Unit = { + filters += ((filter, path)) + } + + override def init(filterConfig: FilterConfig): Unit = { + filters.foreach { case (filter, _) => + filter.init(filterConfig) + } + } + + override def destroy(): Unit = { + filters.foreach { case (filter, _) => + filter.destroy() + } + } + + override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = { + val requestUri = request.asInstanceOf[HttpServletRequest].getRequestURI + + filters + .filter { case (_, path) => + val start = path.replaceFirst("/\\*$", "/") + (requestUri + "/").startsWith(start) + } + .foreach { case (filter, _) => + val mockChain = new MockFilterChain() + filter.doFilter(request, response, mockChain) + if(mockChain.continue == false){ + return () + } + } + + chain.doFilter(request, response) + } + +} + +class MockFilterChain extends FilterChain { + var continue: Boolean = false + + override def doFilter(request: ServletRequest, response: ServletResponse): Unit = { + continue = true + } +} + +class FilterChainFilter(chain: FilterChain) extends Filter { + override def init(filterConfig: FilterConfig): Unit = () + override def destroy(): Unit = () + override def doFilter(request: ServletRequest, response: ServletResponse, mockChain: FilterChain) = chain.doFilter(request, response) +} diff --git a/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala b/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala index f8be024..1e9713f 100644 --- a/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala +++ b/src/main/scala/gitbucket/core/servlet/PluginControllerFilter.scala @@ -21,25 +21,27 @@ } override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = { - val controller = PluginRegistry().getControllers().filter { case (_, path) => - val requestUri = request.asInstanceOf[HttpServletRequest].getRequestURI - val start = path.replaceFirst("/\\*$", "/") - (requestUri + "/").startsWith(start) - } + val requestUri = request.asInstanceOf[HttpServletRequest].getRequestURI - val filterChainWrapper = controller.foldLeft(chain){ case (chain, (controller, _)) => - new FilterChainWrapper(controller, chain) - } - filterChainWrapper.doFilter(request, response) - } - - class FilterChainWrapper(controller: ControllerBase, chain: FilterChain) extends FilterChain { - override def doFilter(request: ServletRequest, response: ServletResponse): Unit = { - if(controller.config == null){ - controller.init(filterConfig) + PluginRegistry().getControllers() + .filter { case (_, path) => + val start = path.replaceFirst("/\\*$", "/") + (requestUri + "/").startsWith(start) } - controller.doFilter(request, response, chain) - } + .foreach { case (controller, _) => + controller match { + case x: ControllerBase if(x.config == null) => x.init(filterConfig) + case _ => () + } + val mockChain = new MockFilterChain() + controller.doFilter(request, response, mockChain) + + if(mockChain.continue == false){ + return () + } + } + + chain.doFilter(request, response) } }