From c55203f43f401cef7f61e8d512212c9e46ee4207 Mon Sep 17 00:00:00 2001 From: juancwu <46619361+juancwu@users.noreply.github.com> Date: Sun, 11 Jan 2026 20:39:16 -0500 Subject: [PATCH] move away from sqlite into json file to store keys --- agent.go | 12 +-- cmd.go | 91 ----------------------- db.go | 138 ---------------------------------- main.go | 45 ++++++++--- store.go | 221 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 258 insertions(+), 249 deletions(-) delete mode 100644 cmd.go delete mode 100644 db.go create mode 100644 store.go diff --git a/agent.go b/agent.go index ceee523..03b0146 100644 --- a/agent.go +++ b/agent.go @@ -79,25 +79,17 @@ func startEphemeralAgent(pemData []byte, target string) (string, func(), error) return sockPath, cleanup, nil } -func startSSH(dbPath string, args []string) { +func startSSH(storePath string, args []string) { user, host := parseDestination(args) env := os.Environ() if host != "" { - db, err := initDB(dbPath) + pemData, err := findKey(storePath, user, host) if err != nil { fmt.Println("Error:", err) os.Exit(1) } - pemData, err := findKey(db, user, host) - if err != nil { - fmt.Println("Error:", err) - os.Exit(1) - } - - db.Close() - targetName := host if user != "" { targetName = user + "@" + host diff --git a/cmd.go b/cmd.go deleted file mode 100644 index 3a82545..0000000 --- a/cmd.go +++ /dev/null @@ -1,91 +0,0 @@ -package main - -import ( - "fmt" - "os" - "text/tabwriter" -) - -func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error { - db, err := initDB(dbPath) - if err != nil { - return err - } - defer db.Close() - - pemData, err := readKeyFile(keyPath) - if err != nil { - return fmt.Errorf("failed to read key: %w", err) - } - - finalPemData, err := checkAndEncryptKey(pemData) - if err != nil { - return fmt.Errorf("failed to check and encrypt key: %w", err) - } - - source := "Stdin" - if keyPath != "-" { - source = keyPath - } - - err = addKey(db, finalPemData, userPattern, hostPattern, source) - if err != nil { - return err - } - - fmt.Printf("Key for %s@%s imported successfully.\n", userPattern, hostPattern) - - return nil -} - -func handleListKey(dbPath string) error { - db, err := initDB(dbPath) - if err != nil { - return err - } - defer db.Close() - - keys, err := listkeys(db) - if err != nil { - return err - } - - w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) - fmt.Fprintln(w, "UD\tUser Pattern\tHost Pattern\tComment") - fmt.Fprintln(w, "--\t------------\t------------\t-------") - for _, k := range keys { - fmt.Fprintf(w, "%d\t%s\t%s\t%s\n", k.ID, k.UserPattern, k.HostPattern, k.Comment) - } - w.Flush() - return nil -} - -func handleUpdateKey(dbPath, userPattern, hostPattern, keyPath string) error { - db, err := initDB(dbPath) - if err != nil { - return err - } - defer db.Close() - - err = updateKey(db, userPattern, hostPattern, keyPath) - if err != nil { - return err - } - - return nil -} - -func handleDeleteKey(dbPath string, id int) error { - db, err := initDB(dbPath) - if err != nil { - return err - } - defer db.Close() - - err = deleteKey(db, id) - if err != nil { - return err - } - - return nil -} diff --git a/db.go b/db.go deleted file mode 100644 index 08e4f10..0000000 --- a/db.go +++ /dev/null @@ -1,138 +0,0 @@ -package main - -import ( - "database/sql" - "fmt" - "os" - "path/filepath" - - _ "modernc.org/sqlite" -) - -type KeyRecord struct { - ID int - HostPattern string - UserPattern string - Comment string -} - -func getDBPath() string { - home, _ := os.UserHomeDir() - localDataDir := filepath.Join(home, ".local", "share", "gosh") - err := os.MkdirAll(localDataDir, 0700) - if err != nil { - fmt.Println("[WARNING] Failed to create local data direction '", localDataDir, "': ", err) - fmt.Println("[WARNING] Putting database in current working directory.") - return "./gosh.db" - } - return filepath.Join(localDataDir, "gosh.db") -} - -func initDB(dbPath string) (*sql.DB, error) { - if dbPath == "" { - dbPath = getDBPath() - } - - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - return db, nil -} - -func findKey(db *sql.DB, user, host string) ([]byte, error) { - query := ` - SELECT encrypted_pem FROM keys - WHERE (host_pattern = ? OR host_pattern = '*') AND (user_pattern = ? OR user_pattern = '*') - ORDER BY id DESC LIMIT 1; - ` - - var pemData []byte - err := db.QueryRow(query, host, user).Scan(&pemData) - if err != nil { - return nil, err - } - - return pemData, nil -} - -func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, source string) error { - query := ` - INSERT INTO keys (host_pattern, user_pattern, encrypted_pem, comment) - VALUES (?, ?, ? ,?); - ` - _, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+source) - if err != nil { - return fmt.Errorf("failed to add key: %w", err) - } - - return nil -} - -func listkeys(db *sql.DB) ([]KeyRecord, error) { - rows, err := db.Query("SELECT id, host_pattern, user_pattern, comment FROM keys;") - if err != nil { - return nil, fmt.Errorf("failed to query keys: %w", err) - } - defer rows.Close() - - var keys []KeyRecord - for rows.Next() { - var k KeyRecord - if err := rows.Scan(&k.ID, &k.HostPattern, &k.UserPattern, &k.Comment); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) - } - keys = append(keys, k) - } - - return keys, nil -} - -func updateKey(db *sql.DB, userPattern, hostPattern, keyPath string) error { - pemData, err := os.ReadFile(keyPath) - if err != nil { - return fmt.Errorf("failed to read key: %w", err) - } - - res, err := db.Exec( - "UPDATE keys SET encrypted_pem=?, comment=? WHERE user_pattern = ? AND host_pattern = ?;", - pemData, "Updated from "+keyPath, userPattern, hostPattern, - ) - if err != nil { - return fmt.Errorf("failed to update key: %w", err) - } - - rows, err := res.RowsAffected() - if err == nil { - if rows == 0 { - fmt.Printf("No key found with user '%s' and host '%s'.\n", userPattern, hostPattern) - } else { - fmt.Printf("Key for %s@%s updated successfully.\n", userPattern, hostPattern) - } - } else { - fmt.Println("Warning: could not verify update result.", err) - } - - return nil -} - -func deleteKey(db *sql.DB, id int) error { - res, err := db.Exec("DELETE FROM keys WHERE id = ?;", id) - if err != nil { - return fmt.Errorf("failed to delete key: %w", err) - } - - rows, err := res.RowsAffected() - if err == nil { - if rows == 0 { - fmt.Printf("No key found with ID %d.\n", id) - } else { - fmt.Printf("Key with ID %d deleted.\n", id) - } - } else { - fmt.Println("Warning: could not confirm key deletion.", err) - } - - return nil -} diff --git a/main.go b/main.go index 846b905..fe6b13f 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strconv" + "text/tabwriter" ) func main() { @@ -13,13 +14,13 @@ func main() { fmt.Fprintf(os.Stderr, "\nManagement Commands:\n") fmt.Fprintf(os.Stderr, " %s [flags] list-keys\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s [flags] add-key \n", os.Args[0]) - fmt.Fprintf(os.Stderr, " %s [flags] update-key \n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s [flags] update-key \n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s [flags] delete-key \n", os.Args[0]) fmt.Fprintf(os.Stderr, "\nFlags:\n") flag.PrintDefaults() } - dbPath := flag.String("db", "", "Path to the keys database (optional)") + storePath := flag.String("store", "", "Path to the keys store (optional)") flag.Parse() args := flag.Args() @@ -36,11 +37,15 @@ func main() { case "add-key": if argc != 4 { fmt.Println("Error: Incorrect arguments for add-key.") - fmt.Println("Try: gosh [--db path] add-key ") + fmt.Println("Try: gosh [-store path] add-key ") os.Exit(1) } - err := handleAddKey(*dbPath, args[1], args[2], args[3]) + user := args[1] + host := args[2] + keyPath := args[3] + + err := addKey(*storePath, user, host, keyPath) if err != nil { fmt.Println("Error:", err) os.Exit(1) @@ -48,20 +53,40 @@ func main() { os.Exit(0) case "list-keys": - err := handleListKey(*dbPath) + keys, err := listkeys(*storePath) if err != nil { fmt.Println("Error:", err) os.Exit(1) } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "UD\tUser Pattern\tHost Pattern\tComment") + fmt.Fprintln(w, "--\t------------\t------------\t-------") + for _, k := range keys { + fmt.Fprintf(w, "%d\t%s\t%s\t%s\n", k.ID, k.UserPattern, k.HostPattern, k.Comment) + } + w.Flush() + os.Exit(0) case "update-key": - if argc != 4 { - fmt.Println("Usage: gosh update-key ") + if argc != 5 { + fmt.Println("Usage: gosh update-key ") os.Exit(1) } - err := handleUpdateKey(*dbPath, args[1], args[2], args[3]) + idStr := args[1] + user := args[2] + host := args[3] + keyPath := args[4] + + id, err := strconv.ParseInt(idStr, 10, 32) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + + err = updateKey(*storePath, int(id), user, host, keyPath) if err != nil { fmt.Println("Error:", err) os.Exit(1) @@ -80,13 +105,13 @@ func main() { os.Exit(1) } - err = handleDeleteKey(*dbPath, int(id)) + err = deleteKey(*storePath, int(id)) if err != nil { fmt.Println("Error:", err) os.Exit(1) } os.Exit(0) default: - startSSH(*dbPath, args) + startSSH(*storePath, args) } } diff --git a/store.go b/store.go new file mode 100644 index 0000000..532dc19 --- /dev/null +++ b/store.go @@ -0,0 +1,221 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + + _ "modernc.org/sqlite" +) + +type KeyRecord struct { + ID int `json:"id"` + HostPattern string `json:"host_pattern"` + UserPattern string `json:"user_pattern"` + EncryptedPem []byte `json:"encrypted_pem"` // Storea as base64, default marshal encoding + Comment string `json:"comment"` +} + +type Storage struct { + Keys []KeyRecord `json:"keys"` +} + +func getStorePath(customPath string) string { + if customPath != "" { + return customPath + } + home, _ := os.UserHomeDir() + localDataDir := filepath.Join(home, ".local", "share", "gosh") + err := os.MkdirAll(localDataDir, 0700) + if err != nil { + fmt.Println("Warning: Failed to create local data direction '", localDataDir, "': ", err) + fmt.Println("Warning: Putting database in current working directory.") + return "./keys.json" + } + return filepath.Join(localDataDir, "keys.json") +} + +func loadStore(storePath string) (*Storage, error) { + data, err := os.ReadFile(storePath) + if os.IsNotExist(err) { + return &Storage{Keys: []KeyRecord{}}, nil + } + if err != nil { + return nil, err + } + var store Storage + if err := json.Unmarshal(data, &store); err != nil { + return nil, fmt.Errorf("corrupt store file: %w", err) + } + return &store, nil +} + +func saveStore(storePath string, store *Storage) error { + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(storePath, data, 0600) +} + +func findKey(storePath, user, host string) ([]byte, error) { + store, err := loadStore(getStorePath(storePath)) + if err != nil { + return nil, err + } + + sort.Slice(store.Keys, func(i, j int) bool { + return len(store.Keys[i].HostPattern) > len(store.Keys[j].HostPattern) + }) + + for _, k := range store.Keys { + hostMatched, _ := filepath.Match(k.HostPattern, host) + userMatched, _ := filepath.Match(k.UserPattern, user) + if hostMatched && userMatched { + return k.EncryptedPem, nil + } + } + + return nil, fmt.Errorf("no matching key found") +} + +func addKey(storePath, userPattern, hostPattern, keyPath string) error { + realPath := getStorePath(storePath) + store, err := loadStore(realPath) + if err != nil { + return err + } + + pemData, err := readKeyFile(keyPath) + if err != nil { + return err + } + + finalPemData, err := checkAndEncryptKey(pemData) + if err != nil { + return err + } + + source := "Stdin" + if keyPath != "-" { + source = keyPath + } + + maxID := 0 + for _, k := range store.Keys { + if k.ID > maxID { + maxID = k.ID + } + } + + newKey := KeyRecord{ + ID: maxID + 1, + HostPattern: hostPattern, + UserPattern: userPattern, + EncryptedPem: finalPemData, + Comment: "Imported from " + source, + } + + store.Keys = append(store.Keys, newKey) + + if err := saveStore(realPath, store); err != nil { + return err + } + + fmt.Printf("Key for %s@%s imported successfully (ID: %d).\n", userPattern, hostPattern, newKey.ID) + + return nil +} + +func listkeys(storePath string) ([]KeyRecord, error) { + store, err := loadStore(getStorePath(storePath)) + if err != nil { + return nil, err + } + + return store.Keys, nil +} + +func updateKey(storePath string, keyID int, userPattern, hostPattern, keyPath string) error { + realPath := getStorePath(storePath) + store, err := loadStore(realPath) + if err != nil { + return err + } + + pemData, err := readKeyFile(keyPath) + if err != nil { + return err + } + + finalPemData, err := checkAndEncryptKey(pemData) + if err != nil { + return err + } + + source := "Stdin" + if keyPath != "-" { + source = keyPath + } + + found := false + for i, k := range store.Keys { + if k.ID == keyID { + store.Keys[i].HostPattern = hostPattern + store.Keys[i].UserPattern = userPattern + store.Keys[i].EncryptedPem = finalPemData + store.Keys[i].Comment = "Updated from " + source + found = true + break + } + } + + if !found { + fmt.Printf("no key found with ID %d.\n", keyID) + return nil + } + + err = saveStore(realPath, store) + if err != nil { + return fmt.Errorf("failed to save store: %w", err) + } + + fmt.Printf("key %d updated successfully.\n", keyID) + + return nil +} + +func deleteKey(storePath string, id int) error { + realPath := getStorePath(storePath) + store, err := loadStore(realPath) + if err != nil { + return err + } + + newKeys := []KeyRecord{} + found := false + for _, k := range store.Keys { + if k.ID == id { + found = true + continue + } + newKeys = append(newKeys, k) + } + + if !found { + fmt.Printf("no key found with ID %d.\n", id) + return nil + } + + store.Keys = newKeys + err = saveStore(realPath, store) + if err != nil { + return err + } + + fmt.Printf("key %d deleted.\n", id) + + return nil +}