list keys and refactor

This commit is contained in:
juancwu 2026-01-11 17:57:11 -05:00
commit 656b3f4a80
4 changed files with 154 additions and 82 deletions

View file

@ -4,7 +4,10 @@ import (
"fmt"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
"golang.org/x/crypto/ssh"
@ -13,7 +16,7 @@ import (
)
func startEphemeralAgent(pemData []byte, target string) (string, func(), error) {
var key interface{}
var key any
var err error
key, err = ssh.ParseRawPrivateKey(pemData)
@ -75,3 +78,75 @@ func startEphemeralAgent(pemData []byte, target string) (string, func(), error)
}
return sockPath, cleanup, nil
}
func startSSH(dbPath string, args []string) {
user, host := parseDestination(args)
env := os.Environ()
if host != "" {
db, err := initDB(dbPath)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
pemData, err := findKey(db, user, host)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
db.Close()
targetName := host
if user != "" {
targetName = user + "@" + host
}
socketPath, cleanup, err := startEphemeralAgent(pemData, targetName)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
defer cleanup()
newEnv := []string{"SSH_AUTH_SOCK=" + socketPath}
for _, e := range env {
if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") {
newEnv = append(newEnv, e)
}
}
env = newEnv
}
sshPath, err := exec.LookPath("ssh")
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
sshCmd := exec.Command(sshPath, args...)
sshCmd.Env = env
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
for sig := range c {
if sshCmd.Process != nil {
sshCmd.Process.Signal(sig)
}
}
}()
err = sshCmd.Run()
if err != nil {
fmt.Println("Error:", err)
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
}

43
cmd.go Normal file
View file

@ -0,0 +1,43 @@
package main
import (
"fmt"
"os"
"text/tabwriter"
)
func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error {
db, err := initDB(dbPath)
if err != nil {
return err
}
defer db.Close()
err = addKey(db, userPattern, hostPattern, keyPath)
if err != nil {
return err
}
return nil
}
func handleListKey(dbPath string) error {
db, err := initDB(dbPath)
if err != nil {
return err
}
defer db.Close()
keys, err := listkeys(db)
if err != nil {
return err
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
fmt.Fprintln(w, "UD\tUser Pattern\tHost Pattern\tComment")
fmt.Fprintln(w, "--\t------------\t------------\t-------")
for _, k := range keys {
fmt.Fprintln(w, "%d\t%s\t%s\t%s\n", k.ID, k.UserPattern, k.HostPattern, k.Comment)
}
w.Flush()
return nil
}

26
db.go
View file

@ -9,6 +9,13 @@ import (
_ "modernc.org/sqlite"
)
type KeyRecord struct {
ID int
HostPattern string
UserPattern string
Comment string
}
func getDBPath() string {
home, _ := os.UserHomeDir()
localDataDir := filepath.Join(home, ".local", "share", "gosh")
@ -67,3 +74,22 @@ func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error {
return nil
}
func listkeys(db *sql.DB) ([]KeyRecord, error) {
rows, err := db.Query("SELECT id, host_pattern, user_pattern, comment FROM keys;")
if err != nil {
return nil, fmt.Errorf("failed to query keys: %w", err)
}
defer rows.Close()
var keys []KeyRecord
for rows.Next() {
var k KeyRecord
if err := rows.Scan(&k.ID, &k.HostPattern, &k.UserPattern, &k.Comment); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
keys = append(keys, k)
}
return keys, nil
}

90
main.go
View file

@ -4,16 +4,13 @@ import (
"flag"
"fmt"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
)
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [gosh-flags] [user@host] [ssh-flags]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\nManagement Commands:\n")
fmt.Fprintf(os.Stderr, " %s [flags] list-keys\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s [flags] add-key <user_pattern> <host_pattern> <path_to_key>\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\nFlags:\n")
flag.PrintDefaults()
@ -32,98 +29,29 @@ func main() {
cmd := args[0]
if cmd == "add-key" {
switch cmd {
case "add-key":
if argc != 4 {
fmt.Println("Error: Incorrect arguments for add-key.")
fmt.Println("Try: gosh [--db path] add-key <user_pattern> <host_pattern> <path_to_key>")
os.Exit(1)
}
db, err := initDB(*dbPath)
err := handleAddKey(*dbPath, args[1], args[2], args[3])
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
defer db.Close()
os.Exit(0)
userPattern := args[1]
hostPattern := args[2]
keyPath := args[3]
err = addKey(db, userPattern, hostPattern, keyPath)
case "list-keys":
err := handleListKey(*dbPath)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
return
os.Exit(0)
}
user, host := parseDestination(args)
env := os.Environ()
if host != "" {
db, err := initDB(*dbPath)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
pemData, err := findKey(db, user, host)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
db.Close()
targetName := host
if user != "" {
targetName = user + "@" + host
}
socketPath, cleanup, err := startEphemeralAgent(pemData, targetName)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
defer cleanup()
newEnv := []string{"SSH_AUTH_SOCK=" + socketPath}
for _, e := range env {
if !strings.HasPrefix(e, "SSH_AUTH_SOCK=") {
newEnv = append(newEnv, e)
}
}
env = newEnv
}
sshPath, err := exec.LookPath("ssh")
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
sshCmd := exec.Command(sshPath, args...)
sshCmd.Env = env
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
for sig := range c {
if sshCmd.Process != nil {
sshCmd.Process.Signal(sig)
}
}
}()
err = sshCmd.Run()
if err != nil {
fmt.Println("Error:", err)
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
startSSH(*dbPath, args)
}