read key from stdin

This commit is contained in:
juancwu 2026-01-11 18:47:54 -05:00
commit e97929cd08
3 changed files with 41 additions and 14 deletions

12
cmd.go
View file

@ -13,7 +13,7 @@ func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error {
} }
defer db.Close() defer db.Close()
pemData, err := os.ReadFile(keyPath) pemData, err := readKeyFile(keyPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read key: %w", err) return fmt.Errorf("failed to read key: %w", err)
} }
@ -23,10 +23,18 @@ func handleAddKey(dbPath, userPattern, hostPattern, keyPath string) error {
return fmt.Errorf("failed to check and encrypt key: %w", err) return fmt.Errorf("failed to check and encrypt key: %w", err)
} }
err = addKey(db, finalPemData, userPattern, hostPattern, keyPath) source := "Stdin"
if keyPath != "-" {
source = keyPath
}
err = addKey(db, finalPemData, userPattern, hostPattern, source)
if err != nil { if err != nil {
return err return err
} }
fmt.Printf("Key for %s@%s imported successfully.\n", userPattern, hostPattern)
return nil return nil
} }

4
db.go
View file

@ -57,12 +57,12 @@ func findKey(db *sql.DB, user, host string) ([]byte, error) {
return pemData, nil return pemData, nil
} }
func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, keyPath string) error { func addKey(db *sql.DB, pemData []byte, userPattern, hostPattern, source string) error {
query := ` query := `
INSERT INTO keys (host_pattern, user_pattern, encrypted_pem, comment) INSERT INTO keys (host_pattern, user_pattern, encrypted_pem, comment)
VALUES (?, ?, ? ,?); VALUES (?, ?, ? ,?);
` `
_, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+keyPath) _, err := db.Exec(query, hostPattern, userPattern, pemData, "Imported from "+source)
if err != nil { if err != nil {
return fmt.Errorf("failed to add key: %w", err) return fmt.Errorf("failed to add key: %w", err)
} }

View file

@ -2,11 +2,11 @@ package main
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/term" "golang.org/x/term"
@ -57,29 +57,37 @@ func checkAndEncryptKey(pemData []byte) ([]byte, error) {
return nil, err return nil, err
} }
fmt.Println("Warning: This key is unencrypted.") tty, err := getTTY()
fmt.Print("Would you like to encrypt it before storing? (y/N): ") if err != nil {
fmt.Println("Warning: cannot open TTY for encryption prompt. Storing unencrypted.")
return pemData, nil
}
defer tty.Close()
ttyFd := int(tty.Fd())
fmt.Fprintln(tty, "Warning: This key is unencrypted.")
fmt.Fprint(tty, "Would you like to encrypt it before storing? (y/N): ")
var response string var response string
fmt.Scanln(&response) fmt.Fscanln(tty, &response)
if strings.ToLower(response) != "y" { if strings.ToLower(response) != "y" {
return pemData, nil return pemData, nil
} }
fmt.Print("Enter new passphrase: ") fmt.Fprint(tty, "Enter new passphrase: ")
bytePass, err := term.ReadPassword(int(syscall.Stdin)) bytePass, err := term.ReadPassword(ttyFd)
fmt.Println() fmt.Fprintln(tty)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read passphrase: %w", err) return nil, fmt.Errorf("failed to read passphrase: %w", err)
} }
passphrase := string(bytePass) passphrase := string(bytePass)
fmt.Print("Confirm passphrase: ") fmt.Fprint(tty, "Confirm passphrase: ")
bytePassConfirm, err := term.ReadPassword(int(syscall.Stdin)) bytePassConfirm, err := term.ReadPassword(ttyFd)
if err != nil { if err != nil {
return nil, fmt.Errorf("fialed to read passphrase: %w", err) return nil, fmt.Errorf("fialed to read passphrase: %w", err)
} }
fmt.Println() fmt.Fprintln(tty)
if passphrase != string(bytePassConfirm) { if passphrase != string(bytePassConfirm) {
return nil, fmt.Errorf("passphrases do not match") return nil, fmt.Errorf("passphrases do not match")
@ -110,3 +118,14 @@ func checkAndEncryptKey(pemData []byte) ([]byte, error) {
return encryptedData, nil return encryptedData, nil
} }
func getTTY() (*os.File, error) {
return os.OpenFile("/dev/tty", os.O_RDWR, 0)
}
func readKeyFile(path string) ([]byte, error) {
if path == "-" {
return io.ReadAll(os.Stdin)
}
return os.ReadFile(path)
}