add agent wrapper
This commit is contained in:
parent
1208542bcb
commit
b5d9c45492
6 changed files with 220 additions and 4 deletions
77
agent.go
Normal file
77
agent.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func startEphemeralAgent(pemData []byte, target string) (string, func(), error) {
|
||||
var key interface{}
|
||||
var err error
|
||||
|
||||
key, err = ssh.ParseRawPrivateKey(pemData)
|
||||
|
||||
if err != nil {
|
||||
if _, ok := err.(*ssh.PassphraseMissingError); ok {
|
||||
fmt.Printf("\033[1;32m? Gosh:\033[0m Key for \033[1m%s\033[0m is encrypted. Enter passphrase: ", target)
|
||||
pass, readErr := term.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Println()
|
||||
if readErr != nil {
|
||||
return "", nil, fmt.Errorf("failed to read password: %w", err)
|
||||
}
|
||||
|
||||
key, err = ssh.ParseRawPrivateKeyWithPassphrase(pemData, pass)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("incorrect passphrase or invalid key: %w", err)
|
||||
}
|
||||
} else {
|
||||
return "", nil, fmt.Errorf("invalid key format: %w", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Using unencrypted key for %s\n", target)
|
||||
}
|
||||
|
||||
keyring := agent.NewKeyring()
|
||||
addedKey := agent.AddedKey{
|
||||
PrivateKey: key,
|
||||
Comment: "gosh-ephemeral",
|
||||
LifetimeSecs: 60,
|
||||
}
|
||||
if err := keyring.Add(addedKey); err != nil {
|
||||
return "", nil, fmt.Errorf("failed to add key to agent: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "gosh-*")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to make temporary directory: %w", err)
|
||||
}
|
||||
sockPath := filepath.Join(tempDir, "agent.sock")
|
||||
|
||||
l, err := net.Listen("unix", sockPath)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to listen on socket: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go agent.ServeAgent(keyring, conn)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
l.Close()
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
return sockPath, cleanup, nil
|
||||
}
|
||||
16
db.go
16
db.go
|
|
@ -34,6 +34,22 @@ func initDB(dbPath string) (*sql.DB, error) {
|
|||
return db, nil
|
||||
}
|
||||
|
||||
func findKey(db *sql.DB, user, host string) ([]byte, error) {
|
||||
query := `
|
||||
SELECT encrypted_pem FROM keys
|
||||
WHERE (host_pattern = ? OR host_pattern = '*') AND (user_pattern = ? OR user_pattern = '*')
|
||||
ORDER BY id DESC LIMIT 1;
|
||||
`
|
||||
|
||||
var pemData []byte
|
||||
err := db.QueryRow(query, host, user).Scan(&pemData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pemData, nil
|
||||
}
|
||||
|
||||
func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error {
|
||||
pemData, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
|
|
|
|||
8
go.mod
8
go.mod
|
|
@ -2,7 +2,11 @@ module git.juancwu.dev/juancwu/gosh
|
|||
|
||||
go 1.25.1
|
||||
|
||||
require modernc.org/sqlite v1.43.0
|
||||
require (
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/term v0.39.0
|
||||
modernc.org/sqlite v1.43.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
|
|
@ -11,7 +15,7 @@ require (
|
|||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
modernc.org/libc v1.66.10 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
|
|
|
|||
8
go.sum
8
go.sum
|
|
@ -10,6 +10,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
|
|||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
|
||||
|
|
@ -17,8 +19,10 @@ golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
|
|||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
|
||||
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
|
||||
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
|
||||
|
|
|
|||
75
main.go
75
main.go
|
|
@ -4,6 +4,10 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
@ -40,6 +44,7 @@ func main() {
|
|||
fmt.Println("Error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
userPattern := args[1]
|
||||
hostPattern := args[2]
|
||||
|
|
@ -51,4 +56,74 @@ func main() {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
user, host := parseDestination(args)
|
||||
env := os.Environ()
|
||||
|
||||
if host != "" {
|
||||
db, err := initDB(*dbPath)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pemData, err := findKey(db, user, host)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
targetName := host
|
||||
if user != "" {
|
||||
targetName = user + "@" + host
|
||||
}
|
||||
|
||||
socketPath, cleanup, err := startEphemeralAgent(pemData, targetName)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
newEnv := []string{"SSH_AUTH_SOCK=" + socketPath}
|
||||
for _, e := range env {
|
||||
if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") {
|
||||
newEnv = append(newEnv, e)
|
||||
}
|
||||
}
|
||||
env = newEnv
|
||||
}
|
||||
|
||||
sshPath, err := exec.LookPath("ssh")
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd := exec.Command(sshPath, args...)
|
||||
sshCmd.Env = env
|
||||
sshCmd.Stdin = os.Stdin
|
||||
sshCmd.Stdout = os.Stdout
|
||||
sshCmd.Stderr = os.Stderr
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
for sig := range c {
|
||||
if sshCmd.Process != nil {
|
||||
sshCmd.Process.Signal(sig)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = sshCmd.Run()
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
40
utils.go
Normal file
40
utils.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseDestination(args []string) (user, host string) {
|
||||
argsWithParams := map[string]bool{
|
||||
"-b": true, "-c": true, "-D": true, "-E": true, "-e": true,
|
||||
"-F": true, "-I": true, "-i": true, "-L": true, "-l": true,
|
||||
"-m": true, "-O": true, "-o": true, "-p": true, "-R": true,
|
||||
"-S": true, "-W": true, "-w": true,
|
||||
}
|
||||
|
||||
skipNext := false
|
||||
for _, arg := range args {
|
||||
if skipNext {
|
||||
skipNext = false
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
flag := arg
|
||||
if len(arg) > 2 && !strings.Contains(arg, "=") {
|
||||
flag = "-" + string(arg[len(arg)-1])
|
||||
}
|
||||
|
||||
if argsWithParams[flag] {
|
||||
skipNext = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return "", parts[0] // No user specified, just host
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue