From a6aad1a1d6641f5e57202da37baacead42ad81b9 Mon Sep 17 00:00:00 2001 From: juancwu Date: Wed, 29 Apr 2026 01:28:05 +0000 Subject: [PATCH] add token framing, crypto and errors --- crypto.go | 56 +++++++++++++++++ crypto_test.go | 155 +++++++++++++++++++++++++++++++++++++++++++++ errors.go | 15 +++++ framing.go | 58 +++++++++++++++++ framing_test.go | 164 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 448 insertions(+) create mode 100644 crypto.go create mode 100644 crypto_test.go create mode 100644 errors.go create mode 100644 framing.go create mode 100644 framing_test.go diff --git a/crypto.go b/crypto.go new file mode 100644 index 0000000..cb48946 --- /dev/null +++ b/crypto.go @@ -0,0 +1,56 @@ +package ficha + +import ( + "crypto/rand" + "errors" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + keySize = chacha20poly1305.KeySize + nonceSize = chacha20poly1305.NonceSizeX +) + +// encrypt seals plaintext under key. aad is authenticated but not encrypted. +// A fresh random nonce is generated per call; with a 24-byte nonce, random generation is collision-safe. +func encrypt(key, plaintext, aad []byte) (nonce, ciphertext []byte, err error) { + if len(key) != keySize { + return nil, nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key)) + } + + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, nil, fmt.Errorf("ficha: init aead: %w", err) + } + + nonce = make([]byte, nonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("ficha: read random nonce: %w", err) + } + + ciphertext = aead.Seal(nil, nonce, plaintext, aad) + return nonce, ciphertext, nil +} + +// decrypt verifies and opens ciphertext. Any auth failure returns ErrInvalidToken +func decrypt(key, nonce, ciphertext, aad []byte) ([]byte, error) { + if len(key) != keySize { + return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key)) + } + if len(nonce) != nonceSize { + return nil, errors.New("ficha: invalid nonce size") + } + + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, fmt.Errorf("ficha: init aead: %w", err) + } + + plaintext, err := aead.Open(nil, nonce, ciphertext, aad) + if err != nil { + return nil, ErrInvalidToken + } + return plaintext, nil +} diff --git a/crypto_test.go b/crypto_test.go new file mode 100644 index 0000000..5f3c90a --- /dev/null +++ b/crypto_test.go @@ -0,0 +1,155 @@ +package ficha + +import ( + "bytes" + "crypto/rand" + "errors" + "testing" +) + +func testKey(t *testing.T) []byte { + t.Helper() + k := make([]byte, keySize) + if _, err := rand.Read(k); err != nil { + t.Fatalf("generate key: %v", err) + } + return k +} + +func TestEncryptDecryptRoundtrip(t *testing.T) { + key := testKey(t) + plaintext := []byte("hello, ficha") + aad := []byte("v1.k1") + + nonce, ciphertext, err := encrypt(key, plaintext, aad) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + if len(nonce) != nonceSize { + t.Errorf("nonce size: got %d, want %d", len(nonce), nonceSize) + } + if len(ciphertext) != len(plaintext)+16 { + t.Errorf("ciphertext size: got %d, want %d", len(ciphertext), len(plaintext)+16) + } + + got, err := decrypt(key, nonce, ciphertext, aad) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Errorf("got %q, want %q", got, plaintext) + } +} + +func TestEncryptDecryptEmptyPlaintext(t *testing.T) { + key := testKey(t) + nonce, ciphertext, err := encrypt(key, []byte{}, []byte("aad")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + got, err := decrypt(key, nonce, ciphertext, []byte("aad")) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + if len(got) != 0 { + t.Errorf("expected empty plaintext, got %q", got) + } +} + +func TestDecryptWrongKey(t *testing.T) { + key1 := testKey(t) + key2 := testKey(t) + + nonce, ciphertext, err := encrypt(key1, []byte("secret"), []byte("v1.k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + if _, err := decrypt(key2, nonce, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken, got %v", err) + } +} + +func TestDecryptTamperedCiphertext(t *testing.T) { + key := testKey(t) + nonce, ciphertext, err := encrypt(key, []byte("don't tamper"), []byte("v1.k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + tampered := append([]byte(nil), ciphertext...) + tampered[len(tampered)/2] ^= 0x01 + + if _, err := decrypt(key, nonce, tampered, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken, got %v", err) + } +} + +func TestDecryptTamperedNonce(t *testing.T) { + key := testKey(t) + nonce, ciphertext, err := encrypt(key, []byte("nonce matters"), []byte("v1.k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + tampered := append([]byte(nil), nonce...) + tampered[0] ^= 0x01 + + if _, err := decrypt(key, tampered, ciphertext, []byte("v1.k1")); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken, got %v", err) + } +} + +func TestDecryptTamperedAAD(t *testing.T) { + key := testKey(t) + nonce, ciphertext, err := encrypt(key, []byte("aad authenticated"), []byte("v1.k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + if _, err := decrypt(key, nonce, ciphertext, []byte("v1.k2")); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken, got %v", err) + } +} + +func TestDecryptInvalidSizes(t *testing.T) { + tests := []struct { + name string + keyLen int + nonceLen int + }{ + {"key too short", 16, nonceSize}, + {"key too long", 64, nonceSize}, + {"nonce too short", keySize, 12}, + {"nonce too long", keySize, 32}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + key := make([]byte, tc.keyLen) + nonce := make([]byte, tc.nonceLen) + ct := make([]byte, 32) + if _, err := decrypt(key, nonce, ct, nil); err == nil { + t.Error("expected error, got nil") + } + }) + } +} + +func TestEncryptProducesUniqueNonces(t *testing.T) { + key := testKey(t) + const N = 1000 + seen := make(map[string]bool, N) + + for i := 0; i < N; i++ { + nonce, _, err := encrypt(key, []byte("same message"), []byte("v1.k1")) + if err != nil { + t.Fatalf("iter %d: %v", i, err) + } + if seen[string(nonce)] { + t.Fatalf("duplicate nonce after %d iterations", i) + } + seen[string(nonce)] = true + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..6403d36 --- /dev/null +++ b/errors.go @@ -0,0 +1,15 @@ +package ficha + +import "errors" + +var ( + // ErrInvalidToken means the token is malformed, has an unsupported + // version, fails descryption, or has been tampered with. + ErrInvalidToken = errors.New("ficha: invalid token") + // ErrExpiredToken means the token's expiry time has passed. + ErrExpiredToken = errors.New("ficha: token expired") + // ErrRevokedToken means the token's ID is in the revocation store. + ErrRevokedToken = errors.New("ficha: token revoked") + // ErrUnknownKey means the token references a key ID not in the keyring. + ErrUnknownKey = errors.New("ficha: unknown key id") +) diff --git a/framing.go b/framing.go new file mode 100644 index 0000000..6a8426d --- /dev/null +++ b/framing.go @@ -0,0 +1,58 @@ +package ficha + +import ( + "encoding/base64" + "strings" +) + +const tokenVersion = "v1" + +// encodeToken builds the wire format: v1.. +func encodeToken(keyID string, nonce, ciphertext []byte) string { + body := make([]byte, 0, len(nonce)+len(ciphertext)) + body = append(body, nonce...) + body = append(body, ciphertext...) + + var b strings.Builder + b.Grow(len(tokenVersion) + 1 + len(keyID) + 1 + base64.RawURLEncoding.EncodedLen(len(body))) + b.WriteString(tokenVersion) + b.WriteByte('.') + b.WriteString(keyID) + b.WriteByte('.') + b.WriteString(base64.RawURLEncoding.EncodeToString(body)) + + return b.String() +} + +// decodeToken parses the wire format and splits the body back into nonce and ciphertext. +// Returns ErrInvalidToken for any malformed input. +func decodeToken(token string) (keyID string, nonce, ciphertext []byte, err error) { + parts := strings.SplitN(token, ".", 3) + if len(parts) != 3 { + return "", nil, nil, ErrInvalidToken + } + if parts[0] != tokenVersion { + return "", nil, nil, ErrInvalidToken + } + if parts[1] == "" { + return "", nil, nil, ErrInvalidToken + } + + body, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return "", nil, nil, ErrInvalidToken + } + + // body must be at least nonce + auth tag (16 bytes) + if len(body) < nonceSize+16 { + return "", nil, nil, ErrInvalidToken + } + + return parts[1], body[:nonceSize], body[nonceSize:], nil +} + +// aadFor builds the AAD bytes for a given version+keyID. Both encrypt +// and decrypt must use the same construction. +func aadFor(keyID string) []byte { + return []byte(tokenVersion + "." + keyID) +} diff --git a/framing_test.go b/framing_test.go new file mode 100644 index 0000000..62a7707 --- /dev/null +++ b/framing_test.go @@ -0,0 +1,164 @@ +package ficha + +import ( + "bytes" + "errors" + "strings" + "testing" +) + +func TestEncodeDecodeRoundtrip(t *testing.T) { + keyID := "k1" + nonce := bytes.Repeat([]byte{0xAB}, nonceSize) + ciphertext := []byte("not-real-ciphertext-but-long-enough-for-tag") + + token := encodeToken(keyID, nonce, ciphertext) + + if !strings.HasPrefix(token, "v1.k1.") { + t.Errorf("unexpected prefix: %s", token) + } + if strings.ContainsAny(token, "+/=") { + t.Errorf("token should be url-safe with no padding: %s", token) + } + + gotKeyID, gotNonce, gotCT, err := decodeToken(token) + if err != nil { + t.Fatalf("decodeToken: %v", err) + } + if gotKeyID != keyID { + t.Errorf("keyID: got %q, want %q", gotKeyID, keyID) + } + if !bytes.Equal(gotNonce, nonce) { + t.Errorf("nonce: got %x, want %x", gotNonce, nonce) + } + if !bytes.Equal(gotCT, ciphertext) { + t.Errorf("ciphertext: got %q, want %q", gotCT, ciphertext) + } +} + +func TestDecodeTokenMalformed(t *testing.T) { + // Build a baseline valid token to mutate. + validNonce := bytes.Repeat([]byte{0x01}, nonceSize) + validCT := bytes.Repeat([]byte{0x02}, 32) // > 16 byte tag minimum + valid := encodeToken("k1", validNonce, validCT) + + tests := []struct { + name string + input string + }{ + {"empty", ""}, + {"only version", "v1"}, + {"missing body", "v1.k1"}, + {"too few parts", "v1.k1"}, + {"unknown version", "v9.k1.YWFhYQ"}, + {"empty key id", "v1..YWFhYQ"}, + {"invalid base64", "v1.k1.!!!not_base64!!!"}, + {"body too short", "v1.k1.YWFhYQ"}, // 4 bytes < nonceSize + 16 + {"body too short after truncation", valid[:len("v1.k1.")+8]}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, _, err := decodeToken(tc.input) + if !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken, got %v", err) + } + }) + } +} + +func TestAadIsDeterministic(t *testing.T) { + a := aadFor("k1") + b := aadFor("k1") + if !bytes.Equal(a, b) { + t.Errorf("aadFor not deterministic: %q vs %q", a, b) + } + + c := aadFor("k2") + if bytes.Equal(a, c) { + t.Error("different key IDs produced the same AAD") + } +} + +func TestFramingEndToEndWithCrypto(t *testing.T) { + // This is the first test that exercises crypto + framing together. + // It catches mismatches between encrypt's AAD and decrypt's AAD. + key := testKey(t) + keyID := "k1" + plaintext := []byte("framing meets crypto") + + aad := aadFor(keyID) + nonce, ct, err := encrypt(key, plaintext, aad) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + token := encodeToken(keyID, nonce, ct) + + gotKeyID, gotNonce, gotCT, err := decodeToken(token) + if err != nil { + t.Fatalf("decodeToken: %v", err) + } + + got, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID)) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Errorf("got %q, want %q", got, plaintext) + } +} + +func TestFramingTamperKeyIDFailsAuth(t *testing.T) { + // If an attacker swaps the key ID in the wire format, AAD changes, + // and decrypt should fail. This is the whole point of putting + // keyID into the AAD. + key := testKey(t) + plaintext := []byte("bind keyID to ciphertext") + + nonce, ct, err := encrypt(key, plaintext, aadFor("k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + // Build a token claiming to be from k1, then rewrite to k2. + original := encodeToken("k1", nonce, ct) + tampered := strings.Replace(original, "v1.k1.", "v1.k2.", 1) + + _, gotNonce, gotCT, err := decodeToken(tampered) + if err != nil { + t.Fatalf("decodeToken: %v", err) + } + + // Try to decrypt under the SAME key but with the tampered AAD. + if _, err := decrypt(key, gotNonce, gotCT, aadFor("k2")); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken on key ID swap, got %v", err) + } +} + +func TestFramingTruncationFailsAuth(t *testing.T) { + // A token that's structurally valid but had bytes lopped off + // should pass framing and fail authentication. + key := testKey(t) + plaintext := []byte("hello, ficha") + + nonce, ct, err := encrypt(key, plaintext, aadFor("k1")) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + token := encodeToken("k1", nonce, ct) + truncated := token[:len(token)-5] + + // Framing may or may not accept it (depends on how much got chopped), + // but if it does, decryption MUST fail. + gotKeyID, gotNonce, gotCT, err := decodeToken(truncated) + if err != nil { + // Truncated past the minimum — framing rejected it. Also fine. + return + } + + if _, err := decrypt(key, gotNonce, gotCT, aadFor(gotKeyID)); !errors.Is(err, ErrInvalidToken) { + t.Errorf("expected ErrInvalidToken on truncated ciphertext, got %v", err) + } +}