diff --git a/src/main/scala/ssh/GitCommand.scala b/src/main/scala/ssh/GitCommand.scala index f855168..4a13975 100644 --- a/src/main/scala/ssh/GitCommand.scala +++ b/src/main/scala/ssh/GitCommand.scala @@ -9,15 +9,15 @@ import org.eclipse.jgit.transport.{ReceivePack, UploadPack} import org.apache.sshd.server.command.UnknownCommand import servlet.{Database, CommitLogHook} -import service.SystemSettingsService.SystemSettings import service.SystemSettingsService +import org.eclipse.jgit.errors.RepositoryNotFoundException class GitCommandFactory extends CommandFactory { private val logger = LoggerFactory.getLogger(classOf[GitCommandFactory]) override def createCommand(command: String): Command = { - logger.info(s"command: String -> " + command) + logger.debug(s"command: $command") command match { // TODO MUST use regular expression and UnitTest case s if s.startsWith("git-upload-pack") => new GitUploadPack(command) @@ -28,24 +28,24 @@ } abstract class GitCommand(val command: String) extends Command { - private val logger = LoggerFactory.getLogger(classOf[GitCommand]) + protected val logger = LoggerFactory.getLogger(classOf[GitCommand]) protected val (gitCommand, owner, repositoryName) = parseCommand protected var err: OutputStream = null protected var in: InputStream = null protected var out: OutputStream = null protected var callback: ExitCallback = null - protected def runnable: Runnable + protected def runnable(user: String): Runnable override def start(env: Environment): Unit = { logger.info(s"start command : " + command) logger.info(s"parsed command : $gitCommand, $owner, $repositoryName") - val thread = new Thread(runnable) + val user = env.getEnv.get("USER") + val thread = new Thread(runnable(user)) thread.start() } - override def destroy(): Unit = { - } + override def destroy(): Unit = {} override def setExitCallback(callback: ExitCallback): Unit = { this.callback = callback @@ -64,47 +64,70 @@ } private def parseCommand: (String, String, String) = { - // command sample: git-upload-pack '/username/repository_name.git' - // command sample: git-receive-pack '/username/repository_name.git' - // TODO This is not correct.... + // command sample: git-upload-pack '/owner/repository_name.git' + // command sample: git-receive-pack '/owner/repository_name.git' + // TODO This is not correct.... but works val split = command.split(" ") val gitCommand = split(0) - val gitUser = split(1).substring(1, split(1).length - 5).split("/")(1) - val gitRepo = split(1).substring(1, split(1).length - 5).split("/")(2) - (gitCommand, gitUser, gitRepo) + val owner = split(1).substring(1, split(1).length - 5).split("/")(1) + val repositoryName = split(1).substring(1, split(1).length - 5).split("/")(2) + (gitCommand, owner, repositoryName) } } -class GitUploadPack(command: String) extends GitCommand(command: String) { - override def runnable = new Runnable { +class GitUploadPack(override val command: String) extends GitCommand(command: String) { + override def runnable(user: String) = new Runnable { override def run(): Unit = { - using(Git.open(getRepositoryDir(owner, repositoryName))) { git => - val repository = git.getRepository - val upload = new UploadPack(repository) - upload.upload(in, out, err) - callback.onExit(0) + try { + using(Git.open(getRepositoryDir(owner, repositoryName))) { + git => + val repository = git.getRepository + val upload = new UploadPack(repository) + try { + upload.upload(in, out, err) + callback.onExit(0) + } catch { + case e: Throwable => + logger.error(e.getMessage, e) + callback.onExit(1) + } + } + } catch { + case e: RepositoryNotFoundException => + logger.info(e.getMessage, e) + callback.onExit(1) } } } } -class GitReceivePack(command: String) extends GitCommand(command: String) with SystemSettingsService { - override def runnable = new Runnable { +class GitReceivePack(override val command: String) extends GitCommand(command: String) with SystemSettingsService { + // TODO Correct this info. where i get base url? + val BaseURL: String = loadSystemSettings().baseUrl.getOrElse("http://localhost:8080") - // TODO correct this info - val pusher: String = "user1" - val baseURL: String = loadSystemSettings().baseUrl.getOrElse("http://localhost:8080") - + override def runnable(user: String) = new Runnable { override def run(): Unit = { - using(Git.open(getRepositoryDir(owner, repositoryName))) { git => - val repository = git.getRepository - // TODO hook commit - val receive = new ReceivePack(repository) - receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, pusher, baseURL)) - Database(SshServer.getServletContext) withTransaction { - receive.receive(in, out, err) - callback.onExit(0) - } + try { + using(Git.open(getRepositoryDir(owner, repositoryName))) { + git => + val repository = git.getRepository + val receive = new ReceivePack(repository) + receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, user, BaseURL)) + Database(SshServer.getServletContext) withTransaction { + try { + receive.receive(in, out, err) + callback.onExit(0) + } catch { + case e: Throwable => + logger.error(e.getMessage, e) + callback.onExit(1) + } + } + } + } catch { + case e: RepositoryNotFoundException => + logger.info(e.getMessage, e) + callback.onExit(1) } } } diff --git a/src/main/scala/ssh/PublicKeyAuthenticator.scala b/src/main/scala/ssh/PublicKeyAuthenticator.scala index a3c8ef0..a59761a 100644 --- a/src/main/scala/ssh/PublicKeyAuthenticator.scala +++ b/src/main/scala/ssh/PublicKeyAuthenticator.scala @@ -1,37 +1,49 @@ package ssh -import org.apache.sshd.server.{PublickeyAuthenticator, PasswordAuthenticator} +import org.apache.sshd.server.PublickeyAuthenticator import org.slf4j.LoggerFactory import org.apache.sshd.server.session.ServerSession -import java.security.{KeyFactory, PublicKey} +import java.security.PublicKey import org.apache.commons.codec.binary.Base64 -import java.security.spec.X509EncodedKeySpec import org.apache.sshd.common.util.Buffer +import org.eclipse.jgit.lib.Constants + + +object DummyData { + val userPublicKeys = List( + "ssh-rsa AAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local", + "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local" + ) +} class PublicKeyAuthenticator extends PublickeyAuthenticator { private val logger = LoggerFactory.getLogger(classOf[PublicKeyAuthenticator]) override def authenticate(username: String, key: PublicKey, session: ServerSession): Boolean = { - // TODO this string is read from DB and Users register this public key string on Account Profile view - val testAuthkey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local" - toPublicKey(testAuthkey) match { + // TODO userPublicKeys is read from DB and Users register this public key string list on Account Profile view + DummyData.userPublicKeys.exists(str => str2PublicKey(str) match { case Some(publicKey) => key.equals(publicKey) case _ => false + }) + } + + private def str2PublicKey(key: String): Option[PublicKey] = { + // TODO RFC 4716 Public Key is not supported... + val parts = key.split(" ") + if (parts.size < 2) { + logger.debug(s"Invalid PublicKey Format: key") + return None + } + try { + val encodedKey = parts(1) + val decode = Base64.decodeBase64(Constants.encodeASCII(encodedKey)) + Some(new Buffer(decode).getRawPublicKey) + } catch { + case e: Throwable => + logger.debug(e.getMessage, e) + None } } - private def toPublicKey(key: String): Option[PublicKey] = { - try { - val parts = key.split(" ") - val encodedKey = key.split(" ")(1) - val decode = Base64.decodeBase64(encodedKey) - Some(new Buffer(decode).getRawPublicKey) - } catch { - case e: Throwable => { - logger.error(e.getMessage, e) - None - } - } - } } diff --git a/src/main/scala/ssh/SshServerListener.scala b/src/main/scala/ssh/SshServerListener.scala index dda7dcd..477f3a1 100644 --- a/src/main/scala/ssh/SshServerListener.scala +++ b/src/main/scala/ssh/SshServerListener.scala @@ -3,30 +3,30 @@ import javax.servlet.{ServletContext, ServletContextEvent, ServletContextListener} import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider import org.slf4j.LoggerFactory +import util.Directory object SshServer { private val logger = LoggerFactory.getLogger(SshServer.getClass) - val DEFAULT_PORT: Int = 29418 // TODO read from config - val SSH_SERVICE_ENABLE = true + val DEFAULT_PORT: Int = 29418 + // TODO read from config + val SSH_SERVICE_ENABLE = true // TODO read from config private val server = org.apache.sshd.SshServer.setUpDefaultServer() - // TODO think other way to create database session + // TODO think other way. this is for create database session private var context: ServletContext = null private def configure() = { - server.setPort(DEFAULT_PORT) - // TODO gitbucket.ser should be in GITBUCKET_HOME - server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider("gitbucket.ser")) - + server.setPort(DEFAULT_PORT) // TODO read from config + server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider(s"${Directory.GitBucketHome}/gitbucket.ser")) server.setPublickeyAuthenticator(new PublicKeyAuthenticator) server.setCommandFactory(new GitCommandFactory) } - def start(context: ServletContext) = { + def start(context: ServletContext) = this.synchronized { if (SSH_SERVICE_ENABLE) { this.context = context configure() @@ -39,7 +39,7 @@ server.stop(true) } - def getServletContext = this.context; + def getServletContext = this.context } /* @@ -52,7 +52,7 @@ class SshServerListener extends ServletContextListener { override def contextInitialized(sce: ServletContextEvent): Unit = { - SshServer.start(sce.getServletContext()) + SshServer.start(sce.getServletContext) } override def contextDestroyed(sce: ServletContextEvent): Unit = { @@ -60,5 +60,3 @@ } } - -