add token framing, crypto and errors
This commit is contained in:
parent
71483982ca
commit
a6aad1a1d6
5 changed files with 448 additions and 0 deletions
56
crypto.go
Normal file
56
crypto.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const (
|
||||
keySize = chacha20poly1305.KeySize
|
||||
nonceSize = chacha20poly1305.NonceSizeX
|
||||
)
|
||||
|
||||
// encrypt seals plaintext under key. aad is authenticated but not encrypted.
|
||||
// A fresh random nonce is generated per call; with a 24-byte nonce, random generation is collision-safe.
|
||||
func encrypt(key, plaintext, aad []byte) (nonce, ciphertext []byte, err error) {
|
||||
if len(key) != keySize {
|
||||
return nil, nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
|
||||
}
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ficha: init aead: %w", err)
|
||||
}
|
||||
|
||||
nonce = make([]byte, nonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, nil, fmt.Errorf("ficha: read random nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext = aead.Seal(nil, nonce, plaintext, aad)
|
||||
return nonce, ciphertext, nil
|
||||
}
|
||||
|
||||
// decrypt verifies and opens ciphertext. Any auth failure returns ErrInvalidToken
|
||||
func decrypt(key, nonce, ciphertext, aad []byte) ([]byte, error) {
|
||||
if len(key) != keySize {
|
||||
return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
|
||||
}
|
||||
if len(nonce) != nonceSize {
|
||||
return nil, errors.New("ficha: invalid nonce size")
|
||||
}
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ficha: init aead: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
155
crypto_test.go
Normal file
155
crypto_test.go
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testKey(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
k := make([]byte, keySize)
|
||||
if _, err := rand.Read(k); err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
func TestEncryptDecryptRoundtrip(t *testing.T) {
|
||||
key := testKey(t)
|
||||
plaintext := []byte("hello, ficha")
|
||||
aad := []byte("v1.k1")
|
||||
|
||||
nonce, ciphertext, err := encrypt(key, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
if len(nonce) != nonceSize {
|
||||
t.Errorf("nonce size: got %d, want %d", len(nonce), nonceSize)
|
||||
}
|
||||
if len(ciphertext) != len(plaintext)+16 {
|
||||
t.Errorf("ciphertext size: got %d, want %d", len(ciphertext), len(plaintext)+16)
|
||||
}
|
||||
|
||||
got, err := decrypt(key, nonce, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Errorf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptEmptyPlaintext(t *testing.T) {
|
||||
key := testKey(t)
|
||||
nonce, ciphertext, err := encrypt(key, []byte{}, []byte("aad"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
got, err := decrypt(key, nonce, ciphertext, []byte("aad"))
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt: %v", err)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected empty plaintext, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWrongKey(t *testing.T) {
|
||||
key1 := testKey(t)
|
||||
key2 := testKey(t)
|
||||
|
||||
nonce, ciphertext, err := encrypt(key1, []byte("secret"), []byte("v1.k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
if _, err := decrypt(key2, nonce, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptTamperedCiphertext(t *testing.T) {
|
||||
key := testKey(t)
|
||||
nonce, ciphertext, err := encrypt(key, []byte("don't tamper"), []byte("v1.k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
tampered := append([]byte(nil), ciphertext...)
|
||||
tampered[len(tampered)/2] ^= 0x01
|
||||
|
||||
if _, err := decrypt(key, nonce, tampered, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptTamperedNonce(t *testing.T) {
|
||||
key := testKey(t)
|
||||
nonce, ciphertext, err := encrypt(key, []byte("nonce matters"), []byte("v1.k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
tampered := append([]byte(nil), nonce...)
|
||||
tampered[0] ^= 0x01
|
||||
|
||||
if _, err := decrypt(key, tampered, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptTamperedAAD(t *testing.T) {
|
||||
key := testKey(t)
|
||||
nonce, ciphertext, err := encrypt(key, []byte("aad authenticated"), []byte("v1.k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
if _, err := decrypt(key, nonce, ciphertext, []byte("v1.k2")); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyLen int
|
||||
nonceLen int
|
||||
}{
|
||||
{"key too short", 16, nonceSize},
|
||||
{"key too long", 64, nonceSize},
|
||||
{"nonce too short", keySize, 12},
|
||||
{"nonce too long", keySize, 32},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
key := make([]byte, tc.keyLen)
|
||||
nonce := make([]byte, tc.nonceLen)
|
||||
ct := make([]byte, 32)
|
||||
if _, err := decrypt(key, nonce, ct, nil); err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptProducesUniqueNonces(t *testing.T) {
|
||||
key := testKey(t)
|
||||
const N = 1000
|
||||
seen := make(map[string]bool, N)
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
nonce, _, err := encrypt(key, []byte("same message"), []byte("v1.k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("iter %d: %v", i, err)
|
||||
}
|
||||
if seen[string(nonce)] {
|
||||
t.Fatalf("duplicate nonce after %d iterations", i)
|
||||
}
|
||||
seen[string(nonce)] = true
|
||||
}
|
||||
}
|
||||
15
errors.go
Normal file
15
errors.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package ficha
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidToken means the token is malformed, has an unsupported
|
||||
// version, fails descryption, or has been tampered with.
|
||||
ErrInvalidToken = errors.New("ficha: invalid token")
|
||||
// ErrExpiredToken means the token's expiry time has passed.
|
||||
ErrExpiredToken = errors.New("ficha: token expired")
|
||||
// ErrRevokedToken means the token's ID is in the revocation store.
|
||||
ErrRevokedToken = errors.New("ficha: token revoked")
|
||||
// ErrUnknownKey means the token references a key ID not in the keyring.
|
||||
ErrUnknownKey = errors.New("ficha: unknown key id")
|
||||
)
|
||||
58
framing.go
Normal file
58
framing.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const tokenVersion = "v1"
|
||||
|
||||
// encodeToken builds the wire format: v1.<keyID>.<base64url(none||ciphertext)>
|
||||
func encodeToken(keyID string, nonce, ciphertext []byte) string {
|
||||
body := make([]byte, 0, len(nonce)+len(ciphertext))
|
||||
body = append(body, nonce...)
|
||||
body = append(body, ciphertext...)
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(tokenVersion) + 1 + len(keyID) + 1 + base64.RawURLEncoding.EncodedLen(len(body)))
|
||||
b.WriteString(tokenVersion)
|
||||
b.WriteByte('.')
|
||||
b.WriteString(keyID)
|
||||
b.WriteByte('.')
|
||||
b.WriteString(base64.RawURLEncoding.EncodeToString(body))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeToken parses the wire format and splits the body back into nonce and ciphertext.
|
||||
// Returns ErrInvalidToken for any malformed input.
|
||||
func decodeToken(token string) (keyID string, nonce, ciphertext []byte, err error) {
|
||||
parts := strings.SplitN(token, ".", 3)
|
||||
if len(parts) != 3 {
|
||||
return "", nil, nil, ErrInvalidToken
|
||||
}
|
||||
if parts[0] != tokenVersion {
|
||||
return "", nil, nil, ErrInvalidToken
|
||||
}
|
||||
if parts[1] == "" {
|
||||
return "", nil, nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
body, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return "", nil, nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// body must be at least nonce + auth tag (16 bytes)
|
||||
if len(body) < nonceSize+16 {
|
||||
return "", nil, nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return parts[1], body[:nonceSize], body[nonceSize:], nil
|
||||
}
|
||||
|
||||
// aadFor builds the AAD bytes for a given version+keyID. Both encrypt
|
||||
// and decrypt must use the same construction.
|
||||
func aadFor(keyID string) []byte {
|
||||
return []byte(tokenVersion + "." + keyID)
|
||||
}
|
||||
164
framing_test.go
Normal file
164
framing_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncodeDecodeRoundtrip(t *testing.T) {
|
||||
keyID := "k1"
|
||||
nonce := bytes.Repeat([]byte{0xAB}, nonceSize)
|
||||
ciphertext := []byte("not-real-ciphertext-but-long-enough-for-tag")
|
||||
|
||||
token := encodeToken(keyID, nonce, ciphertext)
|
||||
|
||||
if !strings.HasPrefix(token, "v1.k1.") {
|
||||
t.Errorf("unexpected prefix: %s", token)
|
||||
}
|
||||
if strings.ContainsAny(token, "+/=") {
|
||||
t.Errorf("token should be url-safe with no padding: %s", token)
|
||||
}
|
||||
|
||||
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeToken: %v", err)
|
||||
}
|
||||
if gotKeyID != keyID {
|
||||
t.Errorf("keyID: got %q, want %q", gotKeyID, keyID)
|
||||
}
|
||||
if !bytes.Equal(gotNonce, nonce) {
|
||||
t.Errorf("nonce: got %x, want %x", gotNonce, nonce)
|
||||
}
|
||||
if !bytes.Equal(gotCT, ciphertext) {
|
||||
t.Errorf("ciphertext: got %q, want %q", gotCT, ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTokenMalformed(t *testing.T) {
|
||||
// Build a baseline valid token to mutate.
|
||||
validNonce := bytes.Repeat([]byte{0x01}, nonceSize)
|
||||
validCT := bytes.Repeat([]byte{0x02}, 32) // > 16 byte tag minimum
|
||||
valid := encodeToken("k1", validNonce, validCT)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"only version", "v1"},
|
||||
{"missing body", "v1.k1"},
|
||||
{"too few parts", "v1.k1"},
|
||||
{"unknown version", "v9.k1.YWFhYQ"},
|
||||
{"empty key id", "v1..YWFhYQ"},
|
||||
{"invalid base64", "v1.k1.!!!not_base64!!!"},
|
||||
{"body too short", "v1.k1.YWFhYQ"}, // 4 bytes < nonceSize + 16
|
||||
{"body too short after truncation", valid[:len("v1.k1.")+8]},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, _, _, err := decodeToken(tc.input)
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAadIsDeterministic(t *testing.T) {
|
||||
a := aadFor("k1")
|
||||
b := aadFor("k1")
|
||||
if !bytes.Equal(a, b) {
|
||||
t.Errorf("aadFor not deterministic: %q vs %q", a, b)
|
||||
}
|
||||
|
||||
c := aadFor("k2")
|
||||
if bytes.Equal(a, c) {
|
||||
t.Error("different key IDs produced the same AAD")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFramingEndToEndWithCrypto(t *testing.T) {
|
||||
// This is the first test that exercises crypto + framing together.
|
||||
// It catches mismatches between encrypt's AAD and decrypt's AAD.
|
||||
key := testKey(t)
|
||||
keyID := "k1"
|
||||
plaintext := []byte("framing meets crypto")
|
||||
|
||||
aad := aadFor(keyID)
|
||||
nonce, ct, err := encrypt(key, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
token := encodeToken(keyID, nonce, ct)
|
||||
|
||||
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeToken: %v", err)
|
||||
}
|
||||
|
||||
got, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID))
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Errorf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFramingTamperKeyIDFailsAuth(t *testing.T) {
|
||||
// If an attacker swaps the key ID in the wire format, AAD changes,
|
||||
// and decrypt should fail. This is the whole point of putting
|
||||
// keyID into the AAD.
|
||||
key := testKey(t)
|
||||
plaintext := []byte("bind keyID to ciphertext")
|
||||
|
||||
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Build a token claiming to be from k1, then rewrite to k2.
|
||||
original := encodeToken("k1", nonce, ct)
|
||||
tampered := strings.Replace(original, "v1.k1.", "v1.k2.", 1)
|
||||
|
||||
_, gotNonce, gotCT, err := decodeToken(tampered)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeToken: %v", err)
|
||||
}
|
||||
|
||||
// Try to decrypt under the SAME key but with the tampered AAD.
|
||||
if _, err := decrypt(key, gotNonce, gotCT, aadFor("k2")); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken on key ID swap, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFramingTruncationFailsAuth(t *testing.T) {
|
||||
// A token that's structurally valid but had bytes lopped off
|
||||
// should pass framing and fail authentication.
|
||||
key := testKey(t)
|
||||
plaintext := []byte("hello, ficha")
|
||||
|
||||
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
|
||||
token := encodeToken("k1", nonce, ct)
|
||||
truncated := token[:len(token)-5]
|
||||
|
||||
// Framing may or may not accept it (depends on how much got chopped),
|
||||
// but if it does, decryption MUST fail.
|
||||
gotKeyID, gotNonce, gotCT, err := decodeToken(truncated)
|
||||
if err != nil {
|
||||
// Truncated past the minimum — framing rejected it. Also fine.
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID)); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken on truncated ciphertext, got %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue