diff --git a/src/main/scala/ssh/GitCommand.scala b/src/main/scala/ssh/GitCommand.scala index 95583fa..2a8d874 100644 --- a/src/main/scala/ssh/GitCommand.scala +++ b/src/main/scala/ssh/GitCommand.scala @@ -9,7 +9,7 @@ import org.eclipse.jgit.transport.{ReceivePack, UploadPack} import org.apache.sshd.server.command.UnknownCommand import servlet.{Database, CommitLogHook} -import service.SystemSettingsService +import service.{AccountService, RepositoryService, SystemSettingsService} import org.eclipse.jgit.errors.RepositoryNotFoundException import javax.servlet.ServletContext @@ -18,8 +18,10 @@ val CommandRegex = """\Agit-(upload|receive)-pack '/([a-zA-Z0-9\-_.]+)/([a-zA-Z0-9\-_.]+).git'\Z""".r } -abstract class GitCommand(val command: String) extends Command { - protected val logger = LoggerFactory.getLogger(classOf[GitCommand]) +abstract class GitCommand(val context: ServletContext, val command: String) extends Command { + self: RepositoryService with AccountService => + + private val logger = LoggerFactory.getLogger(classOf[GitCommand]) protected val (gitCommand, owner, repositoryName) = parseCommand protected var err: OutputStream = null protected var in: InputStream = null @@ -30,16 +32,18 @@ private def newTask(user: String): Runnable = new Runnable { override def run(): Unit = { - try { - runTask(user) - callback.onExit(0) - } catch { - case e: RepositoryNotFoundException => - logger.info(e.getMessage) - callback.onExit(1, "Repository Not Found") - case e: Throwable => - logger.error(e.getMessage, e) - callback.onExit(1) + Database(context) withTransaction { + try { + runTask(user) + callback.onExit(0) + } catch { + case e: RepositoryNotFoundException => + logger.info(e.getMessage) + callback.onExit(1, "Repository Not Found") + case e: Throwable => + logger.error(e.getMessage, e) + callback.onExit(1) + } } } } @@ -80,34 +84,47 @@ val repositoryName = split(1).substring(1, split(1).length - 5).split("/")(2) (gitCommand, owner, repositoryName) } + + protected def isWritableUser(username: String, repositoryInfo: RepositoryService.RepositoryInfo): Boolean = + getAccountByUserName(username) match { + case Some(account) => hasWritePermission(repositoryInfo.owner, repositoryInfo.name, Some(account)) + case None => false + } + } -class GitUploadPack(context: ServletContext, override val command: String) extends GitCommand(command: String) { +class GitUploadPack(context: ServletContext, command: String) extends GitCommand(context, command) + with RepositoryService with AccountService { override protected def runTask(user: String): Unit = { - using(Git.open(getRepositoryDir(owner, repositoryName))) { - git => - val repository = git.getRepository - val upload = new UploadPack(repository) - upload.upload(in, out, err) + getRepository(owner, repositoryName, null).foreach { repositoryInfo => + if(!repositoryInfo.repository.isPrivate || isWritableUser(user, repositoryInfo)){ + using(Git.open(getRepositoryDir(owner, repositoryName))) { git => + val repository = git.getRepository + val upload = new UploadPack(repository) + upload.upload(in, out, err) + } + } } } } -class GitReceivePack(context: ServletContext, override val command: String) extends GitCommand(command: String) with SystemSettingsService { +class GitReceivePack(context: ServletContext, command: String) extends GitCommand(context, command) + with SystemSettingsService with RepositoryService with AccountService { // TODO Correct this info. where i get base url? val BaseURL: String = loadSystemSettings().baseUrl.getOrElse("http://localhost:8080") override protected def runTask(user: String): Unit = { - using(Git.open(getRepositoryDir(owner, repositoryName))) { - git => - val repository = git.getRepository - val receive = new ReceivePack(repository) - receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, user, BaseURL)) - Database(context) withTransaction { + getRepository(owner, repositoryName, null).foreach { repositoryInfo => + if(isWritableUser(user, repositoryInfo)){ + using(Git.open(getRepositoryDir(owner, repositoryName))) { git => + val repository = git.getRepository + val receive = new ReceivePack(repository) + receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, user, BaseURL)) receive.receive(in, out, err) } + } } } diff --git a/src/main/scala/ssh/PublicKeyAuthenticator.scala b/src/main/scala/ssh/PublicKeyAuthenticator.scala index 1fb472f..2ba0db9 100644 --- a/src/main/scala/ssh/PublicKeyAuthenticator.scala +++ b/src/main/scala/ssh/PublicKeyAuthenticator.scala @@ -16,7 +16,6 @@ override def authenticate(username: String, key: PublicKey, session: ServerSession): Boolean = { Database(context) withTransaction { - // TODO Check permission to the repository here? getPublicKeys(username).exists { sshKey => str2PublicKey(sshKey.publicKey) match { case Some(publicKey) => key.equals(publicKey)