From 656b3f4a805510bbe427bfd408f21979ad37e26d Mon Sep 17 00:00:00 2001 From: juancwu <46619361+juancwu@users.noreply.github.com> Date: Sun, 11 Jan 2026 17:57:11 -0500 Subject: [PATCH] list keys and refactor --- agent.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++- cmd.go | 43 +++++++++++++++++++++++++++ db.go | 26 ++++++++++++++++ main.go | 90 ++++++-------------------------------------------------- 4 files changed, 154 insertions(+), 82 deletions(-) create mode 100644 cmd.go diff --git a/agent.go b/agent.go index 368e94f..e716493 100644 --- a/agent.go +++ b/agent.go @@ -4,7 +4,10 @@ import ( "fmt" "net" "os" + "os/exec" + "os/signal" "path/filepath" + "strings" "syscall" "golang.org/x/crypto/ssh" @@ -13,7 +16,7 @@ import ( ) func startEphemeralAgent(pemData []byte, target string) (string, func(), error) { - var key interface{} + var key any var err error key, err = ssh.ParseRawPrivateKey(pemData) @@ -75,3 +78,75 @@ func startEphemeralAgent(pemData []byte, target string) (string, func(), error) } return sockPath, cleanup, nil } + +func startSSH(dbPath string, args []string) { + user, host := parseDestination(args) + env := os.Environ() + + if host != "" { + db, err := initDB(dbPath) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + pemData, err := findKey(db, user, host) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + db.Close() + + targetName := host + if user != "" { + targetName = user + "@" + host + } + + socketPath, cleanup, err := startEphemeralAgent(pemData, targetName) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + defer cleanup() + + newEnv := []string{"SSH_AUTH_SOCK=" + socketPath} + for _, e := range env { + if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") { + newEnv = append(newEnv, e) + } + } + env = newEnv + } + + sshPath, err := exec.LookPath("ssh") + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + sshCmd := exec.Command(sshPath, args...) + sshCmd.Env = env + sshCmd.Stdin = os.Stdin + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + for sig := range c { + if sshCmd.Process != nil { + sshCmd.Process.Signal(sig) + } + } + }() + + err = sshCmd.Run() + if err != nil { + fmt.Println("Error:", err) + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } +} diff --git a/cmd.go b/cmd.go new file mode 100644 index 0000000..907f069 --- /dev/null +++ b/cmd.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "os" + "text/tabwriter" +) + +func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error { + db, err := initDB(dbPath) + if err != nil { + return err + } + defer db.Close() + + err = addKey(db, userPattern, hostPattern, keyPath) + if err != nil { + return err + } + return nil +} + +func handleListKey(dbPath string) error { + db, err := initDB(dbPath) + if err != nil { + return err + } + defer db.Close() + + keys, err := listkeys(db) + if err != nil { + return err + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "UD\tUser Pattern\tHost Pattern\tComment") + fmt.Fprintln(w, "--\t------------\t------------\t-------") + for _, k := range keys { + fmt.Fprintln(w, "%d\t%s\t%s\t%s\n", k.ID, k.UserPattern, k.HostPattern, k.Comment) + } + w.Flush() + return nil +} diff --git a/db.go b/db.go index 8fdd417..b29bf04 100644 --- a/db.go +++ b/db.go @@ -9,6 +9,13 @@ import ( _ "modernc.org/sqlite" ) +type KeyRecord struct { + ID int + HostPattern string + UserPattern string + Comment string +} + func getDBPath() string { home, _ := os.UserHomeDir() localDataDir := filepath.Join(home, ".local", "share", "gosh") @@ -67,3 +74,22 @@ func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error { return nil } + +func listkeys(db *sql.DB) ([]KeyRecord, error) { + rows, err := db.Query("SELECT id, host_pattern, user_pattern, comment FROM keys;") + if err != nil { + return nil, fmt.Errorf("failed to query keys: %w", err) + } + defer rows.Close() + + var keys []KeyRecord + for rows.Next() { + var k KeyRecord + if err := rows.Scan(&k.ID, &k.HostPattern, &k.UserPattern, &k.Comment); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + keys = append(keys, k) + } + + return keys, nil +} diff --git a/main.go b/main.go index 427b5f1..8fb79dd 100644 --- a/main.go +++ b/main.go @@ -4,16 +4,13 @@ import ( "flag" "fmt" "os" - "os/exec" - "os/signal" - "strings" - "syscall" ) func main() { flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [gosh-flags] [user@host] [ssh-flags]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "\nManagement Commands:\n") + fmt.Fprintf(os.Stderr, " %s [flags] list-keys\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s [flags] add-key \n", os.Args[0]) fmt.Fprintf(os.Stderr, "\nFlags:\n") flag.PrintDefaults() @@ -32,98 +29,29 @@ func main() { cmd := args[0] - if cmd == "add-key" { + switch cmd { + case "add-key": if argc != 4 { fmt.Println("Error: Incorrect arguments for add-key.") fmt.Println("Try: gosh [--db path] add-key ") os.Exit(1) } - db, err := initDB(*dbPath) + err := handleAddKey(*dbPath, args[1], args[2], args[3]) if err != nil { fmt.Println("Error:", err) os.Exit(1) } - defer db.Close() + os.Exit(0) - userPattern := args[1] - hostPattern := args[2] - keyPath := args[3] - err = addKey(db, userPattern, hostPattern, keyPath) + case "list-keys": + err := handleListKey(*dbPath) if err != nil { fmt.Println("Error:", err) os.Exit(1) } - return + os.Exit(0) } - user, host := parseDestination(args) - env := os.Environ() - - if host != "" { - db, err := initDB(*dbPath) - if err != nil { - fmt.Println("Error:", err) - os.Exit(1) - } - - pemData, err := findKey(db, user, host) - if err != nil { - fmt.Println("Error:", err) - os.Exit(1) - } - - db.Close() - - targetName := host - if user != "" { - targetName = user + "@" + host - } - - socketPath, cleanup, err := startEphemeralAgent(pemData, targetName) - if err != nil { - fmt.Println("Error:", err) - os.Exit(1) - } - defer cleanup() - - newEnv := []string{"SSH_AUTH_SOCK=" + socketPath} - for _, e := range env { - if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") { - newEnv = append(newEnv, e) - } - } - env = newEnv - } - - sshPath, err := exec.LookPath("ssh") - if err != nil { - fmt.Println("Error:", err) - os.Exit(1) - } - - sshCmd := exec.Command(sshPath, args...) - sshCmd.Env = env - sshCmd.Stdin = os.Stdin - sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - for sig := range c { - if sshCmd.Process != nil { - sshCmd.Process.Signal(sig) - } - } - }() - - err = sshCmd.Run() - if err != nil { - fmt.Println("Error:", err) - if exitErr, ok := err.(*exec.ExitError); ok { - os.Exit(exitErr.ExitCode()) - } - os.Exit(1) - } + startSSH(*dbPath, args) }