diff --git a/agent.go b/agent.go new file mode 100644 index 0000000..368e94f --- /dev/null +++ b/agent.go @@ -0,0 +1,77 @@ +package main + +import ( + "fmt" + "net" + "os" + "path/filepath" + "syscall" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/term" +) + +func startEphemeralAgent(pemData []byte, target string) (string, func(), error) { + var key interface{} + var err error + + key, err = ssh.ParseRawPrivateKey(pemData) + + if err != nil { + if _, ok := err.(*ssh.PassphraseMissingError); ok { + fmt.Printf("\033[1;32m? Gosh:\033[0m Key for \033[1m%s\033[0m is encrypted. Enter passphrase: ", target) + pass, readErr := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if readErr != nil { + return "", nil, fmt.Errorf("failed to read password: %w", err) + } + + key, err = ssh.ParseRawPrivateKeyWithPassphrase(pemData, pass) + if err != nil { + return "", nil, fmt.Errorf("incorrect passphrase or invalid key: %w", err) + } + } else { + return "", nil, fmt.Errorf("invalid key format: %w", err) + } + } else { + fmt.Printf("Using unencrypted key for %s\n", target) + } + + keyring := agent.NewKeyring() + addedKey := agent.AddedKey{ + PrivateKey: key, + Comment: "gosh-ephemeral", + LifetimeSecs: 60, + } + if err := keyring.Add(addedKey); err != nil { + return "", nil, fmt.Errorf("failed to add key to agent: %w", err) + } + + tempDir, err := os.MkdirTemp("", "gosh-*") + if err != nil { + return "", nil, fmt.Errorf("failed to make temporary directory: %w", err) + } + sockPath := filepath.Join(tempDir, "agent.sock") + + l, err := net.Listen("unix", sockPath) + if err != nil { + return "", nil, fmt.Errorf("failed to listen on socket: %w", err) + } + + go func() { + for { + conn, err := l.Accept() + if err != nil { + return + } + go agent.ServeAgent(keyring, conn) + } + }() + + cleanup := func() { + l.Close() + os.RemoveAll(tempDir) + } + return sockPath, cleanup, nil +} diff --git a/db.go b/db.go index 6563db5..8fdd417 100644 --- a/db.go +++ b/db.go @@ -34,6 +34,22 @@ func initDB(dbPath string) (*sql.DB, error) { return db, nil } +func findKey(db *sql.DB, user, host string) ([]byte, error) { + query := ` + SELECT encrypted_pem FROM keys + WHERE (host_pattern = ? OR host_pattern = '*') AND (user_pattern = ? OR user_pattern = '*') + ORDER BY id DESC LIMIT 1; + ` + + var pemData []byte + err := db.QueryRow(query, host, user).Scan(&pemData) + if err != nil { + return nil, err + } + + return pemData, nil +} + func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error { pemData, err := os.ReadFile(keyPath) if err != nil { diff --git a/go.mod b/go.mod index 9fbc220..1ba9698 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,11 @@ module git.juancwu.dev/juancwu/gosh go 1.25.1 -require modernc.org/sqlite v1.43.0 +require ( + golang.org/x/crypto v0.46.0 + golang.org/x/term v0.39.0 + modernc.org/sqlite v1.43.0 +) require ( github.com/dustin/go-humanize v1.0.1 // indirect @@ -11,7 +15,7 @@ require ( github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/sys v0.36.0 // indirect + golang.org/x/sys v0.40.0 // indirect modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index 4059e7f..a6ddfc3 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= @@ -17,8 +19,10 @@ golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= diff --git a/main.go b/main.go index 32e60a4..427b5f1 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,10 @@ import ( "flag" "fmt" "os" + "os/exec" + "os/signal" + "strings" + "syscall" ) func main() { @@ -40,6 +44,7 @@ func main() { fmt.Println("Error:", err) os.Exit(1) } + defer db.Close() userPattern := args[1] hostPattern := args[2] @@ -51,4 +56,74 @@ func main() { } return } + + user, host := parseDestination(args) + env := os.Environ() + + if host != "" { + db, err := initDB(*dbPath) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + pemData, err := findKey(db, user, host) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + db.Close() + + targetName := host + if user != "" { + targetName = user + "@" + host + } + + socketPath, cleanup, err := startEphemeralAgent(pemData, targetName) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + defer cleanup() + + newEnv := []string{"SSH_AUTH_SOCK=" + socketPath} + for _, e := range env { + if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") { + newEnv = append(newEnv, e) + } + } + env = newEnv + } + + sshPath, err := exec.LookPath("ssh") + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + sshCmd := exec.Command(sshPath, args...) + sshCmd.Env = env + sshCmd.Stdin = os.Stdin + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + for sig := range c { + if sshCmd.Process != nil { + sshCmd.Process.Signal(sig) + } + } + }() + + err = sshCmd.Run() + if err != nil { + fmt.Println("Error:", err) + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } } diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..7ee4f41 --- /dev/null +++ b/utils.go @@ -0,0 +1,40 @@ +package main + +import ( + "strings" +) + +func parseDestination(args []string) (user, host string) { + argsWithParams := map[string]bool{ + "-b": true, "-c": true, "-D": true, "-E": true, "-e": true, + "-F": true, "-I": true, "-i": true, "-L": true, "-l": true, + "-m": true, "-O": true, "-o": true, "-p": true, "-R": true, + "-S": true, "-W": true, "-w": true, + } + + skipNext := false + for _, arg := range args { + if skipNext { + skipNext = false + continue + } + if strings.HasPrefix(arg, "-") { + flag := arg + if len(arg) > 2 && !strings.Contains(arg, "=") { + flag = "-" + string(arg[len(arg)-1]) + } + + if argsWithParams[flag] { + skipNext = true + } + continue + } + + parts := strings.SplitN(arg, "@", 2) + if len(parts) == 2 { + return parts[0], parts[1] + } + return "", parts[0] // No user specified, just host + } + return "", "" +}