From f5539b6a95e76727ac6fd10c4386b4aac859b14a Mon Sep 17 00:00:00 2001 From: juancwu <46619361+juancwu@users.noreply.github.com> Date: Sun, 11 Jan 2026 18:36:35 -0500 Subject: [PATCH] check for unencrypted key before adding to database --- cmd.go | 12 +++++++++- db.go | 9 ++----- utils.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/cmd.go b/cmd.go index 56f95ae..390190b 100644 --- a/cmd.go +++ b/cmd.go @@ -13,7 +13,17 @@ func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error { } defer db.Close() - err = addKey(db, userPattern, hostPattern, keyPath) + pemData, err := os.ReadFile(keyPath) + if err != nil { + return fmt.Errorf("failed to read key: %w", err) + } + + finalPemData, err := checkAndEncryptKey(pemData) + if err != nil { + return fmt.Errorf("failed to check and encrypt key: %w", err) + } + + err = addKey(db, finalPemData, userPattern, hostPattern, keyPath) if err != nil { return err } diff --git a/db.go b/db.go index 953922e..a0cf2ad 100644 --- a/db.go +++ b/db.go @@ -57,17 +57,12 @@ func findKey(db *sql.DB, user, host string) ([]byte, error) { return pemData, nil } -func addKey(db *sql.DB, userPattern, hostPattern, keyPath string) error { - pemData, err := os.ReadFile(keyPath) - if err != nil { - return fmt.Errorf("failed to read key file: %w", err) - } - +func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, keyPath string) error { query := ` INSERT INTO keys (host_pattern, user_pattern, encrypted_pem, comment) VALUES (?, ?, ? ,?); ` - _, err = db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+keyPath) + _, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+keyPath) if err != nil { return fmt.Errorf("failed to add key: %w", err) } diff --git a/utils.go b/utils.go index 7ee4f41..a32eaff 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,15 @@ package main import ( + "fmt" + "os" + "os/exec" + "path/filepath" "strings" + "syscall" + + "golang.org/x/crypto/ssh" + "golang.org/x/term" ) func parseDestination(args []string) (user, host string) { @@ -38,3 +46,67 @@ func parseDestination(args []string) (user, host string) { } return "", "" } + +func checkAndEncryptKey(pemData []byte) ([]byte, error) { + _, err := ssh.ParseRawPrivateKey(pemData) + if err != nil { + // Key is encrypted, so there is no more further action needed + if _, ok := err.(*ssh.PassphraseMissingError); ok { + return pemData, nil + } + return nil, err + } + + fmt.Println("Warning: This key is unencrypted.") + fmt.Print("Would you like to encrypt it before storing? (y/N): ") + + var response string + fmt.Scanln(&response) + if strings.ToLower(response) != "y" { + return pemData, nil + } + + fmt.Print("Enter new passphrase: ") + bytePass, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + return nil, fmt.Errorf("failed to read passphrase: %w", err) + } + passphrase := string(bytePass) + + fmt.Print("Confirm passphrase: ") + bytePassConfirm, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return nil, fmt.Errorf("fialed to read passphrase: %w", err) + } + fmt.Println() + + if passphrase != string(bytePassConfirm) { + return nil, fmt.Errorf("passphrases do not match") + } + + tempDir, err := os.MkdirTemp("", "gosh-encrypt") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempDir) + + tempKeyPath := filepath.Join(tempDir, "gosh_temp_key") + if err := os.WriteFile(tempKeyPath, pemData, 0600); err != nil { + return nil, err + } + + cmd := exec.Command("ssh-keygen", "-p", "-f", tempKeyPath, "-P", "", "-N", passphrase, "-Z", "aes256-ctr") + if output, err := cmd.CombinedOutput(); err != nil { + return nil, fmt.Errorf("ssh-keygen failed: %s: %s", err, string(output)) + } + + encryptedData, err := os.ReadFile(tempKeyPath) + if err != nil { + return nil, err + } + + fmt.Println("Key encrypted successfully (AES-256-CTR).") + + return encryptedData, nil +}