add token framing, crypto and errors

This commit is contained in:
juancwu 2026-04-29 01:28:05 +00:00
commit a6aad1a1d6
5 changed files with 448 additions and 0 deletions

56
crypto.go Normal file
View file

@ -0,0 +1,56 @@
package ficha
import (
"crypto/rand"
"errors"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
const (
keySize = chacha20poly1305.KeySize
nonceSize = chacha20poly1305.NonceSizeX
)
// encrypt seals plaintext under key. aad is authenticated but not encrypted.
// A fresh random nonce is generated per call; with a 24-byte nonce, random generation is collision-safe.
func encrypt(key, plaintext, aad []byte) (nonce, ciphertext []byte, err error) {
if len(key) != keySize {
return nil, nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
}
aead, err := chacha20poly1305.NewX(key)
if err != nil {
return nil, nil, fmt.Errorf("ficha: init aead: %w", err)
}
nonce = make([]byte, nonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, nil, fmt.Errorf("ficha: read random nonce: %w", err)
}
ciphertext = aead.Seal(nil, nonce, plaintext, aad)
return nonce, ciphertext, nil
}
// decrypt verifies and opens ciphertext. Any auth failure returns ErrInvalidToken
func decrypt(key, nonce, ciphertext, aad []byte) ([]byte, error) {
if len(key) != keySize {
return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
}
if len(nonce) != nonceSize {
return nil, errors.New("ficha: invalid nonce size")
}
aead, err := chacha20poly1305.NewX(key)
if err != nil {
return nil, fmt.Errorf("ficha: init aead: %w", err)
}
plaintext, err := aead.Open(nil, nonce, ciphertext, aad)
if err != nil {
return nil, ErrInvalidToken
}
return plaintext, nil
}

155
crypto_test.go Normal file
View file

@ -0,0 +1,155 @@
package ficha
import (
"bytes"
"crypto/rand"
"errors"
"testing"
)
func testKey(t *testing.T) []byte {
t.Helper()
k := make([]byte, keySize)
if _, err := rand.Read(k); err != nil {
t.Fatalf("generate key: %v", err)
}
return k
}
func TestEncryptDecryptRoundtrip(t *testing.T) {
key := testKey(t)
plaintext := []byte("hello, ficha")
aad := []byte("v1.k1")
nonce, ciphertext, err := encrypt(key, plaintext, aad)
if err != nil {
t.Fatalf("encrypt: %v", err)
}
if len(nonce) != nonceSize {
t.Errorf("nonce size: got %d, want %d", len(nonce), nonceSize)
}
if len(ciphertext) != len(plaintext)+16 {
t.Errorf("ciphertext size: got %d, want %d", len(ciphertext), len(plaintext)+16)
}
got, err := decrypt(key, nonce, ciphertext, aad)
if err != nil {
t.Fatalf("decrypt: %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Errorf("got %q, want %q", got, plaintext)
}
}
func TestEncryptDecryptEmptyPlaintext(t *testing.T) {
key := testKey(t)
nonce, ciphertext, err := encrypt(key, []byte{}, []byte("aad"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
got, err := decrypt(key, nonce, ciphertext, []byte("aad"))
if err != nil {
t.Fatalf("decrypt: %v", err)
}
if len(got) != 0 {
t.Errorf("expected empty plaintext, got %q", got)
}
}
func TestDecryptWrongKey(t *testing.T) {
key1 := testKey(t)
key2 := testKey(t)
nonce, ciphertext, err := encrypt(key1, []byte("secret"), []byte("v1.k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
if _, err := decrypt(key2, nonce, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestDecryptTamperedCiphertext(t *testing.T) {
key := testKey(t)
nonce, ciphertext, err := encrypt(key, []byte("don't tamper"), []byte("v1.k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
tampered := append([]byte(nil), ciphertext...)
tampered[len(tampered)/2] ^= 0x01
if _, err := decrypt(key, nonce, tampered, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestDecryptTamperedNonce(t *testing.T) {
key := testKey(t)
nonce, ciphertext, err := encrypt(key, []byte("nonce matters"), []byte("v1.k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
tampered := append([]byte(nil), nonce...)
tampered[0] ^= 0x01
if _, err := decrypt(key, tampered, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestDecryptTamperedAAD(t *testing.T) {
key := testKey(t)
nonce, ciphertext, err := encrypt(key, []byte("aad authenticated"), []byte("v1.k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
if _, err := decrypt(key, nonce, ciphertext, []byte("v1.k2")); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestDecryptInvalidSizes(t *testing.T) {
tests := []struct {
name string
keyLen int
nonceLen int
}{
{"key too short", 16, nonceSize},
{"key too long", 64, nonceSize},
{"nonce too short", keySize, 12},
{"nonce too long", keySize, 32},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
key := make([]byte, tc.keyLen)
nonce := make([]byte, tc.nonceLen)
ct := make([]byte, 32)
if _, err := decrypt(key, nonce, ct, nil); err == nil {
t.Error("expected error, got nil")
}
})
}
}
func TestEncryptProducesUniqueNonces(t *testing.T) {
key := testKey(t)
const N = 1000
seen := make(map[string]bool, N)
for i := 0; i < N; i++ {
nonce, _, err := encrypt(key, []byte("same message"), []byte("v1.k1"))
if err != nil {
t.Fatalf("iter %d: %v", i, err)
}
if seen[string(nonce)] {
t.Fatalf("duplicate nonce after %d iterations", i)
}
seen[string(nonce)] = true
}
}

15
errors.go Normal file
View file

@ -0,0 +1,15 @@
package ficha
import "errors"
var (
// ErrInvalidToken means the token is malformed, has an unsupported
// version, fails descryption, or has been tampered with.
ErrInvalidToken = errors.New("ficha: invalid token")
// ErrExpiredToken means the token's expiry time has passed.
ErrExpiredToken = errors.New("ficha: token expired")
// ErrRevokedToken means the token's ID is in the revocation store.
ErrRevokedToken = errors.New("ficha: token revoked")
// ErrUnknownKey means the token references a key ID not in the keyring.
ErrUnknownKey = errors.New("ficha: unknown key id")
)

58
framing.go Normal file
View file

@ -0,0 +1,58 @@
package ficha
import (
"encoding/base64"
"strings"
)
const tokenVersion = "v1"
// encodeToken builds the wire format: v1.<keyID>.<base64url(none||ciphertext)>
func encodeToken(keyID string, nonce, ciphertext []byte) string {
body := make([]byte, 0, len(nonce)+len(ciphertext))
body = append(body, nonce...)
body = append(body, ciphertext...)
var b strings.Builder
b.Grow(len(tokenVersion) + 1 + len(keyID) + 1 + base64.RawURLEncoding.EncodedLen(len(body)))
b.WriteString(tokenVersion)
b.WriteByte('.')
b.WriteString(keyID)
b.WriteByte('.')
b.WriteString(base64.RawURLEncoding.EncodeToString(body))
return b.String()
}
// decodeToken parses the wire format and splits the body back into nonce and ciphertext.
// Returns ErrInvalidToken for any malformed input.
func decodeToken(token string) (keyID string, nonce, ciphertext []byte, err error) {
parts := strings.SplitN(token, ".", 3)
if len(parts) != 3 {
return "", nil, nil, ErrInvalidToken
}
if parts[0] != tokenVersion {
return "", nil, nil, ErrInvalidToken
}
if parts[1] == "" {
return "", nil, nil, ErrInvalidToken
}
body, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return "", nil, nil, ErrInvalidToken
}
// body must be at least nonce + auth tag (16 bytes)
if len(body) < nonceSize+16 {
return "", nil, nil, ErrInvalidToken
}
return parts[1], body[:nonceSize], body[nonceSize:], nil
}
// aadFor builds the AAD bytes for a given version+keyID. Both encrypt
// and decrypt must use the same construction.
func aadFor(keyID string) []byte {
return []byte(tokenVersion + "." + keyID)
}

164
framing_test.go Normal file
View file

@ -0,0 +1,164 @@
package ficha
import (
"bytes"
"errors"
"strings"
"testing"
)
func TestEncodeDecodeRoundtrip(t *testing.T) {
keyID := "k1"
nonce := bytes.Repeat([]byte{0xAB}, nonceSize)
ciphertext := []byte("not-real-ciphertext-but-long-enough-for-tag")
token := encodeToken(keyID, nonce, ciphertext)
if !strings.HasPrefix(token, "v1.k1.") {
t.Errorf("unexpected prefix: %s", token)
}
if strings.ContainsAny(token, "+/=") {
t.Errorf("token should be url-safe with no padding: %s", token)
}
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
if err != nil {
t.Fatalf("decodeToken: %v", err)
}
if gotKeyID != keyID {
t.Errorf("keyID: got %q, want %q", gotKeyID, keyID)
}
if !bytes.Equal(gotNonce, nonce) {
t.Errorf("nonce: got %x, want %x", gotNonce, nonce)
}
if !bytes.Equal(gotCT, ciphertext) {
t.Errorf("ciphertext: got %q, want %q", gotCT, ciphertext)
}
}
func TestDecodeTokenMalformed(t *testing.T) {
// Build a baseline valid token to mutate.
validNonce := bytes.Repeat([]byte{0x01}, nonceSize)
validCT := bytes.Repeat([]byte{0x02}, 32) // > 16 byte tag minimum
valid := encodeToken("k1", validNonce, validCT)
tests := []struct {
name string
input string
}{
{"empty", ""},
{"only version", "v1"},
{"missing body", "v1.k1"},
{"too few parts", "v1.k1"},
{"unknown version", "v9.k1.YWFhYQ"},
{"empty key id", "v1..YWFhYQ"},
{"invalid base64", "v1.k1.!!!not_base64!!!"},
{"body too short", "v1.k1.YWFhYQ"}, // 4 bytes < nonceSize + 16
{"body too short after truncation", valid[:len("v1.k1.")+8]},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, _, _, err := decodeToken(tc.input)
if !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
})
}
}
func TestAadIsDeterministic(t *testing.T) {
a := aadFor("k1")
b := aadFor("k1")
if !bytes.Equal(a, b) {
t.Errorf("aadFor not deterministic: %q vs %q", a, b)
}
c := aadFor("k2")
if bytes.Equal(a, c) {
t.Error("different key IDs produced the same AAD")
}
}
func TestFramingEndToEndWithCrypto(t *testing.T) {
// This is the first test that exercises crypto + framing together.
// It catches mismatches between encrypt's AAD and decrypt's AAD.
key := testKey(t)
keyID := "k1"
plaintext := []byte("framing meets crypto")
aad := aadFor(keyID)
nonce, ct, err := encrypt(key, plaintext, aad)
if err != nil {
t.Fatalf("encrypt: %v", err)
}
token := encodeToken(keyID, nonce, ct)
gotKeyID, gotNonce, gotCT, err := decodeToken(token)
if err != nil {
t.Fatalf("decodeToken: %v", err)
}
got, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID))
if err != nil {
t.Fatalf("decrypt: %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Errorf("got %q, want %q", got, plaintext)
}
}
func TestFramingTamperKeyIDFailsAuth(t *testing.T) {
// If an attacker swaps the key ID in the wire format, AAD changes,
// and decrypt should fail. This is the whole point of putting
// keyID into the AAD.
key := testKey(t)
plaintext := []byte("bind keyID to ciphertext")
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
// Build a token claiming to be from k1, then rewrite to k2.
original := encodeToken("k1", nonce, ct)
tampered := strings.Replace(original, "v1.k1.", "v1.k2.", 1)
_, gotNonce, gotCT, err := decodeToken(tampered)
if err != nil {
t.Fatalf("decodeToken: %v", err)
}
// Try to decrypt under the SAME key but with the tampered AAD.
if _, err := decrypt(key, gotNonce, gotCT, aadFor("k2")); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken on key ID swap, got %v", err)
}
}
func TestFramingTruncationFailsAuth(t *testing.T) {
// A token that's structurally valid but had bytes lopped off
// should pass framing and fail authentication.
key := testKey(t)
plaintext := []byte("hello, ficha")
nonce, ct, err := encrypt(key, plaintext, aadFor("k1"))
if err != nil {
t.Fatalf("encrypt: %v", err)
}
token := encodeToken("k1", nonce, ct)
truncated := token[:len(token)-5]
// Framing may or may not accept it (depends on how much got chopped),
// but if it does, decryption MUST fail.
gotKeyID, gotNonce, gotCT, err := decodeToken(truncated)
if err != nil {
// Truncated past the minimum — framing rejected it. Also fine.
return
}
if _, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID)); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken on truncated ciphertext, got %v", err)
}
}