diff --git a/src/main/scala/ScalatraBootstrap.scala b/src/main/scala/ScalatraBootstrap.scala index cfa37e6..ac8f57e 100644 --- a/src/main/scala/ScalatraBootstrap.scala +++ b/src/main/scala/ScalatraBootstrap.scala @@ -1,4 +1,4 @@ -import _root_.servlet.{BasicAuthenticationFilter, TransactionFilter} +import _root_.servlet.{BasicAuthenticationFilter, TransactionFilter, AccessTokenAuthenticationFilter} import app._ import plugin.PluginRegistry @@ -14,7 +14,8 @@ context.getFilterRegistration("transactionFilter").addMappingForUrlPatterns(EnumSet.allOf(classOf[DispatcherType]), true, "/*") context.addFilter("basicAuthenticationFilter", new BasicAuthenticationFilter) context.getFilterRegistration("basicAuthenticationFilter").addMappingForUrlPatterns(EnumSet.allOf(classOf[DispatcherType]), true, "/git/*") - + context.addFilter("accessTokenAuthenticationFilter", new AccessTokenAuthenticationFilter) + context.getFilterRegistration("accessTokenAuthenticationFilter").addMappingForUrlPatterns(EnumSet.allOf(classOf[DispatcherType]), true, "/api/v3/*") // Register controllers context.mount(new AnonymousAccessController, "/*") diff --git a/src/main/scala/app/ControllerBase.scala b/src/main/scala/app/ControllerBase.scala index 5bb1ca2..ed9691b 100644 --- a/src/main/scala/app/ControllerBase.scala +++ b/src/main/scala/app/ControllerBase.scala @@ -74,12 +74,7 @@ } } - private def LoginAccount: Option[Account] = { - Option(request.getHeader("Authorization")) match { - case Some(auth) if auth.startsWith("token ") => AccessTokenService.getAccountByAccessToken(auth.substring(6).trim) - case _ => session.getAs[Account](Keys.Session.LoginAccount) - } - } + private def LoginAccount: Option[Account] = request.getAs[Account](Keys.Session.LoginAccount).orElse(session.getAs[Account](Keys.Session.LoginAccount)) def ajaxGet(path : String)(action : => Any) : Route = super.get(path){ diff --git a/src/main/scala/servlet/AccessTokenAuthenticationFilter.scala b/src/main/scala/servlet/AccessTokenAuthenticationFilter.scala new file mode 100644 index 0000000..0dc3949 --- /dev/null +++ b/src/main/scala/servlet/AccessTokenAuthenticationFilter.scala @@ -0,0 +1,41 @@ +package servlet + +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import service.AccessTokenService +import util.Keys +import org.scalatra.servlet.ServletApiImplicits._ +import model.Account +import org.scalatra._ + +class AccessTokenAuthenticationFilter extends Filter with AccessTokenService { + private val tokenHeaderPrefix = "token " + + override def init(filterConfig: FilterConfig): Unit = {} + + override def destroy(): Unit = {} + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + implicit val request = req.asInstanceOf[HttpServletRequest] + implicit val session = req.getAttribute(Keys.Request.DBSession).asInstanceOf[slick.jdbc.JdbcBackend#Session] + val response = res.asInstanceOf[HttpServletResponse] + Option(request.getHeader("Authorization")).map{ + case auth if auth.startsWith("token ") => AccessTokenService.getAccountByAccessToken(auth.substring(6).trim).toRight(Unit) + // TODO Basic Authentication Support + case _ => Left(Unit) + }.orElse{ + Option(request.getSession.getAttribute(Keys.Session.LoginAccount).asInstanceOf[Account]).map(Right(_)) + } match { + case Some(Right(account)) => request.setAttribute(Keys.Session.LoginAccount, account); chain.doFilter(req, res) + case None => chain.doFilter(req, res) + case Some(Left(_)) => { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setContentType("Content-Type: application/json; charset=utf-8") + val w = response.getWriter() + w.print("""{ "message": "Bad credentials" }""") + w.close() + } + } + } +}