add token framing, crypto and errors
This commit is contained in:
parent
71483982ca
commit
a6aad1a1d6
5 changed files with 448 additions and 0 deletions
56
crypto.go
Normal file
56
crypto.go
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
package ficha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
keySize = chacha20poly1305.KeySize
|
||||||
|
nonceSize = chacha20poly1305.NonceSizeX
|
||||||
|
)
|
||||||
|
|
||||||
|
// encrypt seals plaintext under key. aad is authenticated but not encrypted.
|
||||||
|
// A fresh random nonce is generated per call; with a 24-byte nonce, random generation is collision-safe.
|
||||||
|
func encrypt(key, plaintext, aad []byte) (nonce, ciphertext []byte, err error) {
|
||||||
|
if len(key) != keySize {
|
||||||
|
return nil, nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
aead, err := chacha20poly1305.NewX(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("ficha: init aead: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce = make([]byte, nonceSize)
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("ficha: read random nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext = aead.Seal(nil, nonce, plaintext, aad)
|
||||||
|
return nonce, ciphertext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt verifies and opens ciphertext. Any auth failure returns ErrInvalidToken
|
||||||
|
func decrypt(key, nonce, ciphertext, aad []byte) ([]byte, error) {
|
||||||
|
if len(key) != keySize {
|
||||||
|
return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
|
||||||
|
}
|
||||||
|
if len(nonce) != nonceSize {
|
||||||
|
return nil, errors.New("ficha: invalid nonce size")
|
||||||
|
}
|
||||||
|
|
||||||
|
aead, err := chacha20poly1305.NewX(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ficha: init aead: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := aead.Open(nil, nonce, ciphertext, aad)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
155
crypto_test.go
Normal file
155
crypto_test.go
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
package ficha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testKey(t *testing.T) []byte {
|
||||||
|
t.Helper()
|
||||||
|
k := make([]byte, keySize)
|
||||||
|
if _, err := rand.Read(k); err != nil {
|
||||||
|
t.Fatalf("generate key: %v", err)
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptRoundtrip(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
plaintext := []byte("hello, ficha")
|
||||||
|
aad := []byte("v1.k1")
|
||||||
|
|
||||||
|
nonce, ciphertext, err := encrypt(key, plaintext, aad)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nonce) != nonceSize {
|
||||||
|
t.Errorf("nonce size: got %d, want %d", len(nonce), nonceSize)
|
||||||
|
}
|
||||||
|
if len(ciphertext) != len(plaintext)+16 {
|
||||||
|
t.Errorf("ciphertext size: got %d, want %d", len(ciphertext), len(plaintext)+16)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := decrypt(key, nonce, ciphertext, aad)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, plaintext) {
|
||||||
|
t.Errorf("got %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptEmptyPlaintext(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
nonce, ciphertext, err := encrypt(key, []byte{}, []byte("aad"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
got, err := decrypt(key, nonce, ciphertext, []byte("aad"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt: %v", err)
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Errorf("expected empty plaintext, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWrongKey(t *testing.T) {
|
||||||
|
key1 := testKey(t)
|
||||||
|
key2 := testKey(t)
|
||||||
|
|
||||||
|
nonce, ciphertext, err := encrypt(key1, []byte("secret"), []byte("v1.k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := decrypt(key2, nonce, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptTamperedCiphertext(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
nonce, ciphertext, err := encrypt(key, []byte("don't tamper"), []byte("v1.k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tampered := append([]byte(nil), ciphertext...)
|
||||||
|
tampered[len(tampered)/2] ^= 0x01
|
||||||
|
|
||||||
|
if _, err := decrypt(key, nonce, tampered, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptTamperedNonce(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
nonce, ciphertext, err := encrypt(key, []byte("nonce matters"), []byte("v1.k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tampered := append([]byte(nil), nonce...)
|
||||||
|
tampered[0] ^= 0x01
|
||||||
|
|
||||||
|
if _, err := decrypt(key, tampered, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptTamperedAAD(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
nonce, ciphertext, err := encrypt(key, []byte("aad authenticated"), []byte("v1.k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := decrypt(key, nonce, ciphertext, []byte("v1.k2")); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptInvalidSizes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyLen int
|
||||||
|
nonceLen int
|
||||||
|
}{
|
||||||
|
{"key too short", 16, nonceSize},
|
||||||
|
{"key too long", 64, nonceSize},
|
||||||
|
{"nonce too short", keySize, 12},
|
||||||
|
{"nonce too long", keySize, 32},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
key := make([]byte, tc.keyLen)
|
||||||
|
nonce := make([]byte, tc.nonceLen)
|
||||||
|
ct := make([]byte, 32)
|
||||||
|
if _, err := decrypt(key, nonce, ct, nil); err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptProducesUniqueNonces(t *testing.T) {
|
||||||
|
key := testKey(t)
|
||||||
|
const N = 1000
|
||||||
|
seen := make(map[string]bool, N)
|
||||||
|
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
nonce, _, err := encrypt(key, []byte("same message"), []byte("v1.k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iter %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if seen[string(nonce)] {
|
||||||
|
t.Fatalf("duplicate nonce after %d iterations", i)
|
||||||
|
}
|
||||||
|
seen[string(nonce)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
15
errors.go
Normal file
15
errors.go
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
package ficha
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrInvalidToken means the token is malformed, has an unsupported
|
||||||
|
// version, fails descryption, or has been tampered with.
|
||||||
|
ErrInvalidToken = errors.New("ficha: invalid token")
|
||||||
|
// ErrExpiredToken means the token's expiry time has passed.
|
||||||
|
ErrExpiredToken = errors.New("ficha: token expired")
|
||||||
|
// ErrRevokedToken means the token's ID is in the revocation store.
|
||||||
|
ErrRevokedToken = errors.New("ficha: token revoked")
|
||||||
|
// ErrUnknownKey means the token references a key ID not in the keyring.
|
||||||
|
ErrUnknownKey = errors.New("ficha: unknown key id")
|
||||||
|
)
|
||||||
58
framing.go
Normal file
58
framing.go
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
package ficha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const tokenVersion = "v1"
|
||||||
|
|
||||||
|
// encodeToken builds the wire format: v1.<keyID>.<base64url(none||ciphertext)>
|
||||||
|
func encodeToken(keyID string, nonce, ciphertext []byte) string {
|
||||||
|
body := make([]byte, 0, len(nonce)+len(ciphertext))
|
||||||
|
body = append(body, nonce...)
|
||||||
|
body = append(body, ciphertext...)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(tokenVersion) + 1 + len(keyID) + 1 + base64.RawURLEncoding.EncodedLen(len(body)))
|
||||||
|
b.WriteString(tokenVersion)
|
||||||
|
b.WriteByte('.')
|
||||||
|
b.WriteString(keyID)
|
||||||
|
b.WriteByte('.')
|
||||||
|
b.WriteString(base64.RawURLEncoding.EncodeToString(body))
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeToken parses the wire format and splits the body back into nonce and ciphertext.
|
||||||
|
// Returns ErrInvalidToken for any malformed input.
|
||||||
|
func decodeToken(token string) (keyID string, nonce, ciphertext []byte, err error) {
|
||||||
|
parts := strings.SplitN(token, ".", 3)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", nil, nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
if parts[0] != tokenVersion {
|
||||||
|
return "", nil, nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
if parts[1] == "" {
|
||||||
|
return "", nil, nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// body must be at least nonce + auth tag (16 bytes)
|
||||||
|
if len(body) < nonceSize+16 {
|
||||||
|
return "", nil, nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts[1], body[:nonceSize], body[nonceSize:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// aadFor builds the AAD bytes for a given version+keyID. Both encrypt
|
||||||
|
// and decrypt must use the same construction.
|
||||||
|
func aadFor(keyID string) []byte {
|
||||||
|
return []byte(tokenVersion + "." + keyID)
|
||||||
|
}
|
||||||
164
framing_test.go
Normal file
164
framing_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
package ficha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeDecodeRoundtrip(t *testing.T) {
|
||||||
|
keyID := "k1"
|
||||||
|
nonce := bytes.Repeat([]byte{0xAB}, nonceSize)
|
||||||
|
ciphertext := []byte("not-real-ciphertext-but-long-enough-for-tag")
|
||||||
|
|
||||||
|
token := encodeToken(keyID, nonce, ciphertext)
|
||||||
|
|
||||||
|
if !strings.HasPrefix(token, "v1.k1.") {
|
||||||
|
t.Errorf("unexpected prefix: %s", token)
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(token, "+/=") {
|
||||||
|
t.Errorf("token should be url-safe with no padding: %s", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeToken: %v", err)
|
||||||
|
}
|
||||||
|
if gotKeyID != keyID {
|
||||||
|
t.Errorf("keyID: got %q, want %q", gotKeyID, keyID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotNonce, nonce) {
|
||||||
|
t.Errorf("nonce: got %x, want %x", gotNonce, nonce)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotCT, ciphertext) {
|
||||||
|
t.Errorf("ciphertext: got %q, want %q", gotCT, ciphertext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeTokenMalformed(t *testing.T) {
|
||||||
|
// Build a baseline valid token to mutate.
|
||||||
|
validNonce := bytes.Repeat([]byte{0x01}, nonceSize)
|
||||||
|
validCT := bytes.Repeat([]byte{0x02}, 32) // > 16 byte tag minimum
|
||||||
|
valid := encodeToken("k1", validNonce, validCT)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
}{
|
||||||
|
{"empty", ""},
|
||||||
|
{"only version", "v1"},
|
||||||
|
{"missing body", "v1.k1"},
|
||||||
|
{"too few parts", "v1.k1"},
|
||||||
|
{"unknown version", "v9.k1.YWFhYQ"},
|
||||||
|
{"empty key id", "v1..YWFhYQ"},
|
||||||
|
{"invalid base64", "v1.k1.!!!not_base64!!!"},
|
||||||
|
{"body too short", "v1.k1.YWFhYQ"}, // 4 bytes < nonceSize + 16
|
||||||
|
{"body too short after truncation", valid[:len("v1.k1.")+8]},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, _, _, err := decodeToken(tc.input)
|
||||||
|
if !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAadIsDeterministic(t *testing.T) {
|
||||||
|
a := aadFor("k1")
|
||||||
|
b := aadFor("k1")
|
||||||
|
if !bytes.Equal(a, b) {
|
||||||
|
t.Errorf("aadFor not deterministic: %q vs %q", a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := aadFor("k2")
|
||||||
|
if bytes.Equal(a, c) {
|
||||||
|
t.Error("different key IDs produced the same AAD")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFramingEndToEndWithCrypto(t *testing.T) {
|
||||||
|
// This is the first test that exercises crypto + framing together.
|
||||||
|
// It catches mismatches between encrypt's AAD and decrypt's AAD.
|
||||||
|
key := testKey(t)
|
||||||
|
keyID := "k1"
|
||||||
|
plaintext := []byte("framing meets crypto")
|
||||||
|
|
||||||
|
aad := aadFor(keyID)
|
||||||
|
nonce, ct, err := encrypt(key, plaintext, aad)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := encodeToken(keyID, nonce, ct)
|
||||||
|
|
||||||
|
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, plaintext) {
|
||||||
|
t.Errorf("got %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFramingTamperKeyIDFailsAuth(t *testing.T) {
|
||||||
|
// If an attacker swaps the key ID in the wire format, AAD changes,
|
||||||
|
// and decrypt should fail. This is the whole point of putting
|
||||||
|
// keyID into the AAD.
|
||||||
|
key := testKey(t)
|
||||||
|
plaintext := []byte("bind keyID to ciphertext")
|
||||||
|
|
||||||
|
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a token claiming to be from k1, then rewrite to k2.
|
||||||
|
original := encodeToken("k1", nonce, ct)
|
||||||
|
tampered := strings.Replace(original, "v1.k1.", "v1.k2.", 1)
|
||||||
|
|
||||||
|
_, gotNonce, gotCT, err := decodeToken(tampered)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to decrypt under the SAME key but with the tampered AAD.
|
||||||
|
if _, err := decrypt(key, gotNonce, gotCT, aadFor("k2")); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken on key ID swap, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFramingTruncationFailsAuth(t *testing.T) {
|
||||||
|
// A token that's structurally valid but had bytes lopped off
|
||||||
|
// should pass framing and fail authentication.
|
||||||
|
key := testKey(t)
|
||||||
|
plaintext := []byte("hello, ficha")
|
||||||
|
|
||||||
|
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := encodeToken("k1", nonce, ct)
|
||||||
|
truncated := token[:len(token)-5]
|
||||||
|
|
||||||
|
// Framing may or may not accept it (depends on how much got chopped),
|
||||||
|
// but if it does, decryption MUST fail.
|
||||||
|
gotKeyID, gotNonce, gotCT, err := decodeToken(truncated)
|
||||||
|
if err != nil {
|
||||||
|
// Truncated past the minimum — framing rejected it. Also fine.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID)); !errors.Is(err, ErrInvalidToken) {
|
||||||
|
t.Errorf("expected ErrInvalidToken on truncated ciphertext, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue