add agent wrapper

This commit is contained in:
juancwu 2026-01-11 17:39:15 -05:00
commit b5d9c45492
6 changed files with 220 additions and 4 deletions

77
agent.go Normal file
View file

@ -0,0 +1,77 @@
package main
import (
"fmt"
"net"
"os"
"path/filepath"
"syscall"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
)
func startEphemeralAgent(pemData []byte, target string) (string, func(), error) {
var key interface{}
var err error
key, err = ssh.ParseRawPrivateKey(pemData)
if err != nil {
if _, ok := err.(*ssh.PassphraseMissingError); ok {
fmt.Printf("\033[1;32m? Gosh:\033[0m Key for \033[1m%s\033[0m is encrypted. Enter passphrase: ", target)
pass, readErr := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if readErr != nil {
return "", nil, fmt.Errorf("failed to read password: %w", err)
}
key, err = ssh.ParseRawPrivateKeyWithPassphrase(pemData, pass)
if err != nil {
return "", nil, fmt.Errorf("incorrect passphrase or invalid key: %w", err)
}
} else {
return "", nil, fmt.Errorf("invalid key format: %w", err)
}
} else {
fmt.Printf("Using unencrypted key for %s\n", target)
}
keyring := agent.NewKeyring()
addedKey := agent.AddedKey{
PrivateKey: key,
Comment: "gosh-ephemeral",
LifetimeSecs: 60,
}
if err := keyring.Add(addedKey); err != nil {
return "", nil, fmt.Errorf("failed to add key to agent: %w", err)
}
tempDir, err := os.MkdirTemp("", "gosh-*")
if err != nil {
return "", nil, fmt.Errorf("failed to make temporary directory: %w", err)
}
sockPath := filepath.Join(tempDir, "agent.sock")
l, err := net.Listen("unix", sockPath)
if err != nil {
return "", nil, fmt.Errorf("failed to listen on socket: %w", err)
}
go func() {
for {
conn, err := l.Accept()
if err != nil {
return
}
go agent.ServeAgent(keyring, conn)
}
}()
cleanup := func() {
l.Close()
os.RemoveAll(tempDir)
}
return sockPath, cleanup, nil
}

16
db.go
View file

@ -34,6 +34,22 @@ func initDB(dbPath string) (*sql.DB, error) {
return db, nil return db, nil
} }
func findKey(db *sql.DB, user, host string) ([]byte, error) {
query := `
SELECT encrypted_pem FROM keys
WHERE (host_pattern = ? OR host_pattern = '*') AND (user_pattern = ? OR user_pattern = '*')
ORDER BY id DESC LIMIT 1;
`
var pemData []byte
err := db.QueryRow(query, host, user).Scan(&pemData)
if err != nil {
return nil, err
}
return pemData, nil
}
func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error { func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error {
pemData, err := os.ReadFile(keyPath) pemData, err := os.ReadFile(keyPath)
if err != nil { if err != nil {

8
go.mod
View file

@ -2,7 +2,11 @@ module git.juancwu.dev/juancwu/gosh
go 1.25.1 go 1.25.1
require modernc.org/sqlite v1.43.0 require (
golang.org/x/crypto v0.46.0
golang.org/x/term v0.39.0
modernc.org/sqlite v1.43.0
)
require ( require (
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
@ -11,7 +15,7 @@ require (
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.40.0 // indirect
modernc.org/libc v1.66.10 // indirect modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect

8
go.sum
View file

@ -10,6 +10,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
@ -17,8 +19,10 @@ golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=

75
main.go
View file

@ -4,6 +4,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/signal"
"strings"
"syscall"
) )
func main() { func main() {
@ -40,6 +44,7 @@ func main() {
fmt.Println("Error:", err) fmt.Println("Error:", err)
os.Exit(1) os.Exit(1)
} }
defer db.Close()
userPattern := args[1] userPattern := args[1]
hostPattern := args[2] hostPattern := args[2]
@ -51,4 +56,74 @@ func main() {
} }
return return
} }
user, host := parseDestination(args)
env := os.Environ()
if host != "" {
db, err := initDB(*dbPath)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
pemData, err := findKey(db, user, host)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
db.Close()
targetName := host
if user != "" {
targetName = user + "@" + host
}
socketPath, cleanup, err := startEphemeralAgent(pemData, targetName)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
defer cleanup()
newEnv := []string{"SSH_AUTH_SOCK=" + socketPath}
for _, e := range env {
if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") {
newEnv = append(newEnv, e)
}
}
env = newEnv
}
sshPath, err := exec.LookPath("ssh")
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
sshCmd := exec.Command(sshPath, args...)
sshCmd.Env = env
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
for sig := range c {
if sshCmd.Process != nil {
sshCmd.Process.Signal(sig)
}
}
}()
err = sshCmd.Run()
if err != nil {
fmt.Println("Error:", err)
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
} }

40
utils.go Normal file
View file

@ -0,0 +1,40 @@
package main
import (
"strings"
)
func parseDestination(args []string) (user, host string) {
argsWithParams := map[string]bool{
"-b": true, "-c": true, "-D": true, "-E": true, "-e": true,
"-F": true, "-I": true, "-i": true, "-L": true, "-l": true,
"-m": true, "-O": true, "-o": true, "-p": true, "-R": true,
"-S": true, "-W": true, "-w": true,
}
skipNext := false
for _, arg := range args {
if skipNext {
skipNext = false
continue
}
if strings.HasPrefix(arg, "-") {
flag := arg
if len(arg) > 2 && !strings.Contains(arg, "=") {
flag = "-" + string(arg[len(arg)-1])
}
if argsWithParams[flag] {
skipNext = true
}
continue
}
parts := strings.SplitN(arg, "@", 2)
if len(parts) == 2 {
return parts[0], parts[1]
}
return "", parts[0] // No user specified, just host
}
return "", ""
}