From e97929cd0861415baf9ec31d091f78321406c600 Mon Sep 17 00:00:00 2001 From: juancwu <46619361+juancwu@users.noreply.github.com> Date: Sun, 11 Jan 2026 18:47:54 -0500 Subject: [PATCH] read key from stdin --- cmd.go | 12 ++++++++++-- db.go | 4 ++-- utils.go | 39 +++++++++++++++++++++++++++++---------- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/cmd.go b/cmd.go index 390190b..3a82545 100644 --- a/cmd.go +++ b/cmd.go @@ -13,7 +13,7 @@ func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error { } defer db.Close() - pemData, err := os.ReadFile(keyPath) + pemData, err := readKeyFile(keyPath) if err != nil { return fmt.Errorf("failed to read key: %w", err) } @@ -23,10 +23,18 @@ func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error { return fmt.Errorf("failed to check and encrypt key: %w", err) } - err = addKey(db, finalPemData, userPattern, hostPattern, keyPath) + source := "Stdin" + if keyPath != "-" { + source = keyPath + } + + err = addKey(db, finalPemData, userPattern, hostPattern, source) if err != nil { return err } + + fmt.Printf("Key for %s@%s imported successfully.\n", userPattern, hostPattern) + return nil } diff --git a/db.go b/db.go index a0cf2ad..08e4f10 100644 --- a/db.go +++ b/db.go @@ -57,12 +57,12 @@ func findKey(db *sql.DB, user, host string) ([]byte, error) { return pemData, nil } -func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, keyPath string) error { +func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, source string) error { query := ` INSERT INTO keys (host_pattern, user_pattern, encrypted_pem, comment) VALUES (?, ?, ? ,?); ` - _, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+keyPath) + _, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+source) if err != nil { return fmt.Errorf("failed to add key: %w", err) } diff --git a/utils.go b/utils.go index a32eaff..7e07f1f 100644 --- a/utils.go +++ b/utils.go @@ -2,11 +2,11 @@ package main import ( "fmt" + "io" "os" "os/exec" "path/filepath" "strings" - "syscall" "golang.org/x/crypto/ssh" "golang.org/x/term" @@ -57,29 +57,37 @@ func checkAndEncryptKey(pemData []byte) ([]byte, error) { return nil, err } - fmt.Println("Warning: This key is unencrypted.") - fmt.Print("Would you like to encrypt it before storing? (y/N): ") + tty, err := getTTY() + if err != nil { + fmt.Println("Warning: cannot open TTY for encryption prompt. Storing unencrypted.") + return pemData, nil + } + defer tty.Close() + ttyFd := int(tty.Fd()) + + fmt.Fprintln(tty, "Warning: This key is unencrypted.") + fmt.Fprint(tty, "Would you like to encrypt it before storing? (y/N): ") var response string - fmt.Scanln(&response) + fmt.Fscanln(tty, &response) if strings.ToLower(response) != "y" { return pemData, nil } - fmt.Print("Enter new passphrase: ") - bytePass, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() + fmt.Fprint(tty, "Enter new passphrase: ") + bytePass, err := term.ReadPassword(ttyFd) + fmt.Fprintln(tty) if err != nil { return nil, fmt.Errorf("failed to read passphrase: %w", err) } passphrase := string(bytePass) - fmt.Print("Confirm passphrase: ") - bytePassConfirm, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Fprint(tty, "Confirm passphrase: ") + bytePassConfirm, err := term.ReadPassword(ttyFd) if err != nil { return nil, fmt.Errorf("fialed to read passphrase: %w", err) } - fmt.Println() + fmt.Fprintln(tty) if passphrase != string(bytePassConfirm) { return nil, fmt.Errorf("passphrases do not match") @@ -110,3 +118,14 @@ func checkAndEncryptKey(pemData []byte) ([]byte, error) { return encryptedData, nil } + +func getTTY() (*os.File, error) { + return os.OpenFile("/dev/tty", os.O_RDWR, 0) +} + +func readKeyFile(path string) ([]byte, error) { + if path == "-" { + return io.ReadAll(os.Stdin) + } + return os.ReadFile(path) +}