Add reflective struct-tag config loader
Implements conf.Load to populate tagged structs from a chain of Sources (env, .env, YAML/JSON/TOML, custom). Supports default values, slice separators, nested structs, pointer fields, and a Validator hook. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3c806e6803
commit
c4ebd80669
15 changed files with 941 additions and 0 deletions
150
README.md
150
README.md
|
|
@ -1,2 +1,152 @@
|
|||
# conf
|
||||
|
||||
Tiny, reflective config loader for Go. Define a struct, tag the fields, call
|
||||
`conf.Load`. Values come from one or more `Source`s — env vars, `.env` files,
|
||||
YAML/JSON/TOML, or anything you implement yourself.
|
||||
|
||||
## Install
|
||||
|
||||
```sh
|
||||
go get git.juancwu.dev/juancwu/conf
|
||||
```
|
||||
|
||||
Requires Go 1.26+.
|
||||
|
||||
## Usage
|
||||
|
||||
Tag your struct and call `Load`. The first source that returns a value wins;
|
||||
if no source has it, the `default:` tag is used.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/conf"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BindAddr string `env:"BIND_ADDR" default:":8080"`
|
||||
DatabaseURL string `env:"DATABASE_URL"`
|
||||
SessionCookieSecure bool `env:"SESSION_COOKIE_SECURE" default:"true"`
|
||||
SessionIdleTTL time.Duration `env:"SESSION_IDLE_TTL" default:"24h"`
|
||||
JWTSecret []byte `env:"JWT_SECRET"`
|
||||
WorkerConcurrency int `env:"WORKER_CONCURRENCY" default:"4"`
|
||||
Tags []string `env:"TAGS" sep:","`
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
const op = "config.Validate"
|
||||
if strings.TrimSpace(c.DatabaseURL) == "" {
|
||||
return errx.New(op, "DATABASE_URL is required")
|
||||
}
|
||||
if len(c.JWTSecret) == 0 {
|
||||
return errx.New(op, "JWT_SECRET is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var cfg Config
|
||||
if err := conf.Load(&cfg, conf.EnvSource()); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_ = cfg
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple sources
|
||||
|
||||
Sources are tried in order — earliest wins. Put highest-precedence first:
|
||||
|
||||
```go
|
||||
dotenv, _ := conf.DotEnvFile(".env")
|
||||
yamlBase, _ := conf.YAMLFile("config.yaml")
|
||||
|
||||
err := conf.Load(&cfg,
|
||||
conf.EnvSource(), // process env wins
|
||||
dotenv, // .env overrides yaml
|
||||
yamlBase, // base defaults
|
||||
)
|
||||
```
|
||||
|
||||
### `.env` files
|
||||
|
||||
```go
|
||||
src, err := conf.DotEnvFile(".env")
|
||||
```
|
||||
|
||||
Supports `KEY=value`, `export KEY=value`, `#` comments, single/double quotes,
|
||||
and `\n \t \r \\ \"` escapes inside double quotes.
|
||||
|
||||
### YAML / JSON / TOML
|
||||
|
||||
File sources are flattened into env-style keys: nested maps join with `_` and
|
||||
keys are uppercased.
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
bind_addr: ":9000"
|
||||
session:
|
||||
idle_ttl: 24h
|
||||
cookie:
|
||||
secure: true
|
||||
```
|
||||
|
||||
becomes `BIND_ADDR`, `SESSION_IDLE_TTL`, `SESSION_COOKIE_SECURE` — matching the
|
||||
same `env:` tags on your struct. `JSONFile` / `TOMLFile` work the same way.
|
||||
Each loader has a `*Reader` variant for `io.Reader`.
|
||||
|
||||
### Custom sources
|
||||
|
||||
```go
|
||||
type Source interface {
|
||||
Lookup(key string) (string, bool)
|
||||
}
|
||||
```
|
||||
|
||||
Implement that to pull from Vault, SSM, a database, or wherever.
|
||||
|
||||
## Tags
|
||||
|
||||
| Tag | Purpose |
|
||||
|---------------|----------------------------------------------------------|
|
||||
| `env:"KEY"` | Key looked up in each `Source`. Required to bind. |
|
||||
| `env:"-"` | Skip the field. |
|
||||
| `default:"v"` | Used when no `Source` returns the key. |
|
||||
| `sep:","` | Slice separator (default `,`). |
|
||||
|
||||
Untagged struct fields are recursed into, so you can group related values:
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
HTTP HTTP
|
||||
}
|
||||
type HTTP struct {
|
||||
Addr string `env:"HTTP_ADDR" default:":8080"`
|
||||
}
|
||||
```
|
||||
|
||||
## Supported field types
|
||||
|
||||
`string`, `[]byte`, `bool`, all sized `int`/`uint`, `float32`/`float64`,
|
||||
`time.Duration`, `time.Time` (RFC3339), slices of any scalar above, pointers
|
||||
to any of the above (left `nil` if unset and no default), and nested structs.
|
||||
|
||||
## Validation
|
||||
|
||||
If your destination type implements `Validate() error`, it's called after
|
||||
fields are populated. Return an error to fail the load.
|
||||
|
||||
## Errors
|
||||
|
||||
All errors flow through [`errx`](https://git.juancwu.dev/juancwu/errx) with op
|
||||
codes like `conf.Load`, so `errors.Is` / `errors.As` work as usual.
|
||||
|
||||
## License
|
||||
|
||||
MIT — see `LICENSE`.
|
||||
|
|
|
|||
110
conf.go
Normal file
110
conf.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
// Package conf loads configuration values from one or more Sources into a
|
||||
// tagged Go struct.
|
||||
//
|
||||
// Field tags:
|
||||
//
|
||||
// env:"KEY" key looked up in each Source (required to bind a field)
|
||||
// default:"v" raw value used when no Source returns the key
|
||||
// sep:"," separator for slice fields (default ",")
|
||||
// env:"-" skip the field
|
||||
//
|
||||
// Sources are tried in order; the first one returning a value wins.
|
||||
// If the destination type implements Validator, Validate is called after the
|
||||
// fields are populated.
|
||||
package conf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// Validator is implemented by config types that want a post-load hook.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// Load populates dst from sources. dst must be a non-nil pointer to a struct.
|
||||
func Load(dst any, sources ...Source) error {
|
||||
const op = "conf.Load"
|
||||
|
||||
if dst == nil {
|
||||
return errx.New(op, "dst is nil")
|
||||
}
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Pointer || v.IsNil() {
|
||||
return errx.New(op, "dst must be a non-nil pointer to a struct")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return errx.New(op, "dst must point to a struct")
|
||||
}
|
||||
|
||||
if err := walk(v, sources); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
if val, ok := dst.(Validator); ok {
|
||||
if err := val.Validate(); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func walk(v reflect.Value, sources []Source) error {
|
||||
const op = "conf.walk"
|
||||
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
sf := t.Field(i)
|
||||
if !sf.IsExported() {
|
||||
continue
|
||||
}
|
||||
fv := v.Field(i)
|
||||
|
||||
// Recurse into nested structs (and pointers to structs) that have no env tag.
|
||||
key := sf.Tag.Get("env")
|
||||
if key == "" {
|
||||
switch {
|
||||
case fv.Kind() == reflect.Struct:
|
||||
if err := walk(fv, sources); err != nil {
|
||||
return err
|
||||
}
|
||||
case fv.Kind() == reflect.Pointer && fv.Type().Elem().Kind() == reflect.Struct:
|
||||
if fv.IsNil() {
|
||||
fv.Set(reflect.New(fv.Type().Elem()))
|
||||
}
|
||||
if err := walk(fv.Elem(), sources); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if key == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
raw, ok := lookup(sources, key)
|
||||
if !ok {
|
||||
raw, ok = sf.Tag.Lookup("default")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := assignString(fv, raw, sf.Tag.Get("sep")); err != nil {
|
||||
return errx.Wrapf(op, err, "field %s (%s)", sf.Name, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookup(sources []Source, key string) (string, bool) {
|
||||
for _, s := range sources {
|
||||
if v, ok := s.Lookup(key); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
134
conf_test.go
Normal file
134
conf_test.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type allTypes struct {
|
||||
S string `env:"S" default:"hi"`
|
||||
B bool `env:"B" default:"true"`
|
||||
I int `env:"I" default:"7"`
|
||||
I64 int64 `env:"I64" default:"100"`
|
||||
U uint32 `env:"U" default:"3"`
|
||||
F float64 `env:"F" default:"1.5"`
|
||||
D time.Duration `env:"D" default:"5s"`
|
||||
BS []byte `env:"BS" default:"abc"`
|
||||
List []string `env:"LIST" default:"a,b,c"`
|
||||
Skip string `env:"-"`
|
||||
None string
|
||||
}
|
||||
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
var c allTypes
|
||||
if err := Load(&c); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.S != "hi" || !c.B || c.I != 7 || c.I64 != 100 || c.U != 3 || c.F != 1.5 ||
|
||||
c.D != 5*time.Second || string(c.BS) != "abc" ||
|
||||
strings.Join(c.List, ",") != "a,b,c" {
|
||||
t.Fatalf("defaults not applied: %+v", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecedence(t *testing.T) {
|
||||
high := MapSource(map[string]string{"S": "high"})
|
||||
low := MapSource(map[string]string{"S": "low", "I": "42"})
|
||||
|
||||
var c allTypes
|
||||
if err := Load(&c, high, low); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.S != "high" {
|
||||
t.Fatalf("S=%q want high", c.S)
|
||||
}
|
||||
if c.I != 42 {
|
||||
t.Fatalf("I=%d want 42", c.I)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPointerField(t *testing.T) {
|
||||
type p struct {
|
||||
X *int `env:"X"`
|
||||
Y *int `env:"Y"`
|
||||
}
|
||||
var c p
|
||||
if err := Load(&c, MapSource(map[string]string{"X": "9"})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.X == nil || *c.X != 9 {
|
||||
t.Fatalf("X=%v", c.X)
|
||||
}
|
||||
if c.Y != nil {
|
||||
t.Fatalf("Y should be nil, got %v", *c.Y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedStruct(t *testing.T) {
|
||||
type Inner struct {
|
||||
A string `env:"A" default:"x"`
|
||||
}
|
||||
type Outer struct {
|
||||
Inner Inner
|
||||
B string `env:"B" default:"y"`
|
||||
}
|
||||
var o Outer
|
||||
if err := Load(&o); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if o.Inner.A != "x" || o.B != "y" {
|
||||
t.Fatalf("%+v", o)
|
||||
}
|
||||
}
|
||||
|
||||
type withValidate struct {
|
||||
URL string `env:"URL"`
|
||||
}
|
||||
|
||||
func (w *withValidate) Validate() error {
|
||||
if w.URL == "" {
|
||||
return errors.New("URL required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestValidatorFires(t *testing.T) {
|
||||
var w withValidate
|
||||
err := Load(&w)
|
||||
if err == nil || !strings.Contains(err.Error(), "URL required") {
|
||||
t.Fatalf("want URL required error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseError(t *testing.T) {
|
||||
type bad struct {
|
||||
N int `env:"N"`
|
||||
}
|
||||
var b bad
|
||||
err := Load(&b, MapSource(map[string]string{"N": "notanumber"}))
|
||||
if err == nil || !strings.Contains(err.Error(), "N") {
|
||||
t.Fatalf("want field N error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectsNonPointer(t *testing.T) {
|
||||
var c allTypes
|
||||
if err := Load(c); err == nil {
|
||||
t.Fatal("expected error for non-pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSepOverride(t *testing.T) {
|
||||
type s struct {
|
||||
L []string `env:"L" sep:"|"`
|
||||
}
|
||||
var x s
|
||||
if err := Load(&x, MapSource(map[string]string{"L": "a|b|c"})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Join(x.L, ",") != "a,b,c" {
|
||||
t.Fatalf("got %v", x.L)
|
||||
}
|
||||
}
|
||||
103
dotenv.go
Normal file
103
dotenv.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// DotEnvFile loads KEY=VALUE pairs from path. Lines starting with # and blank
|
||||
// lines are ignored. Values may be wrapped in single or double quotes; double
|
||||
// quotes honor \n, \t, \r, \\, and \" escapes.
|
||||
func DotEnvFile(path string) (Source, error) {
|
||||
const op = "conf.DotEnvFile"
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, errx.Wrapf(op, err, "open %s", path)
|
||||
}
|
||||
defer f.Close()
|
||||
return DotEnvReader(f)
|
||||
}
|
||||
|
||||
// DotEnvReader is DotEnvFile for an arbitrary reader.
|
||||
func DotEnvReader(r io.Reader) (Source, error) {
|
||||
const op = "conf.DotEnvReader"
|
||||
|
||||
m := map[string]string{}
|
||||
sc := bufio.NewScanner(r)
|
||||
lineNo := 0
|
||||
for sc.Scan() {
|
||||
lineNo++
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimPrefix(line, "export ")
|
||||
eq := strings.IndexByte(line, '=')
|
||||
if eq < 0 {
|
||||
return nil, errx.Newf(op, "line %d: missing '='", lineNo)
|
||||
}
|
||||
key := strings.TrimSpace(line[:eq])
|
||||
val := strings.TrimSpace(line[eq+1:])
|
||||
// strip trailing inline comment for unquoted values
|
||||
if !isQuoted(val) {
|
||||
if i := strings.Index(val, " #"); i >= 0 {
|
||||
val = strings.TrimSpace(val[:i])
|
||||
}
|
||||
}
|
||||
v, err := unquote(val)
|
||||
if err != nil {
|
||||
return nil, errx.Wrapf(op, err, "line %d", lineNo)
|
||||
}
|
||||
m[key] = v
|
||||
}
|
||||
if err := sc.Err(); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return MapSource(m), nil
|
||||
}
|
||||
|
||||
func isQuoted(s string) bool {
|
||||
if len(s) < 2 {
|
||||
return false
|
||||
}
|
||||
return (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')
|
||||
}
|
||||
|
||||
func unquote(s string) (string, error) {
|
||||
if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' {
|
||||
return s[1 : len(s)-1], nil
|
||||
}
|
||||
if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
|
||||
inner := s[1 : len(s)-1]
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(inner); i++ {
|
||||
c := inner[i]
|
||||
if c != '\\' || i+1 >= len(inner) {
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
i++
|
||||
switch inner[i] {
|
||||
case 'n':
|
||||
b.WriteByte('\n')
|
||||
case 't':
|
||||
b.WriteByte('\t')
|
||||
case 'r':
|
||||
b.WriteByte('\r')
|
||||
case '\\':
|
||||
b.WriteByte('\\')
|
||||
case '"':
|
||||
b.WriteByte('"')
|
||||
default:
|
||||
b.WriteByte('\\')
|
||||
b.WriteByte(inner[i])
|
||||
}
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
35
dotenv_test.go
Normal file
35
dotenv_test.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDotEnvReader(t *testing.T) {
|
||||
in := `# comment
|
||||
FOO=bar
|
||||
export QUOTED="hello\nworld"
|
||||
SINGLE='ab cd'
|
||||
INLINE=plain # trailing
|
||||
EMPTY=
|
||||
`
|
||||
src, err := DotEnvReader(strings.NewReader(in))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cases := map[string]string{
|
||||
"FOO": "bar",
|
||||
"QUOTED": "hello\nworld",
|
||||
"SINGLE": "ab cd",
|
||||
"INLINE": "plain",
|
||||
}
|
||||
for k, want := range cases {
|
||||
got, ok := src.Lookup(k)
|
||||
if !ok || got != want {
|
||||
t.Errorf("%s = %q ok=%v, want %q", k, got, ok, want)
|
||||
}
|
||||
}
|
||||
if _, ok := src.Lookup("EMPTY"); ok {
|
||||
t.Errorf("EMPTY should be absent")
|
||||
}
|
||||
}
|
||||
46
example_test.go
Normal file
46
example_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package conf_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/conf"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BindAddr string `env:"BIND_ADDR" default:":8080"`
|
||||
DatabaseURL string `env:"DATABASE_URL"`
|
||||
SessionCookieSecure bool `env:"SESSION_COOKIE_SECURE" default:"true"`
|
||||
SessionIdleTTL time.Duration `env:"SESSION_IDLE_TTL" default:"24h"`
|
||||
JWTSecret []byte `env:"JWT_SECRET"`
|
||||
WorkerConcurrency int `env:"WORKER_CONCURRENCY" default:"4"`
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
const op = "config.Validate"
|
||||
if strings.TrimSpace(c.DatabaseURL) == "" {
|
||||
return errx.New(op, "DATABASE_URL is required")
|
||||
}
|
||||
if len(c.JWTSecret) == 0 {
|
||||
return errx.New(op, "JWT_SECRET is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExampleLoad() {
|
||||
src := conf.MapSource(map[string]string{
|
||||
"DATABASE_URL": "postgres://localhost/x",
|
||||
"JWT_SECRET": "shh",
|
||||
})
|
||||
|
||||
var cfg Config
|
||||
if err := conf.Load(&cfg, src); err != nil {
|
||||
fmt.Println("err:", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("bind=%s workers=%d idle=%s secure=%v",
|
||||
cfg.BindAddr, cfg.WorkerConcurrency, cfg.SessionIdleTTL, cfg.SessionCookieSecure)
|
||||
// Output: bind=:8080 workers=4 idle=24h0m0s secure=true
|
||||
}
|
||||
82
files_test.go
Normal file
82
files_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestYAMLFlatten(t *testing.T) {
|
||||
src, err := YAMLReader(strings.NewReader(`
|
||||
bind_addr: ":9000"
|
||||
session:
|
||||
idle_ttl: 1h
|
||||
cookie:
|
||||
secure: true
|
||||
list: [a, b, c]
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checks := map[string]string{
|
||||
"BIND_ADDR": ":9000",
|
||||
"SESSION_IDLE_TTL": "1h0m0s", // yaml.v3 decodes as time.Duration string? actually it stays string
|
||||
"SESSION_COOKIE_SECURE": "true",
|
||||
"LIST": "a,b,c",
|
||||
}
|
||||
for k, want := range checks {
|
||||
got, ok := src.Lookup(k)
|
||||
if !ok {
|
||||
t.Errorf("%s missing", k)
|
||||
continue
|
||||
}
|
||||
if k == "SESSION_IDLE_TTL" {
|
||||
// yaml decodes "1h" as plain string
|
||||
if got != "1h" && got != want {
|
||||
t.Errorf("%s = %q", k, got)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("%s = %q want %q", k, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFlatten(t *testing.T) {
|
||||
src, err := JSONReader(strings.NewReader(`{"a":{"b":"c"},"n":7,"arr":[1,2,3]}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v, _ := src.Lookup("A_B"); v != "c" {
|
||||
t.Errorf("A_B=%q", v)
|
||||
}
|
||||
if v, _ := src.Lookup("N"); v != "7" {
|
||||
t.Errorf("N=%q", v)
|
||||
}
|
||||
if v, _ := src.Lookup("ARR"); v != "1,2,3" {
|
||||
t.Errorf("ARR=%q", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOMLFlatten(t *testing.T) {
|
||||
src, err := TOMLReader(strings.NewReader(`
|
||||
bind_addr = ":9000"
|
||||
[session]
|
||||
idle_ttl = "1h"
|
||||
[session.cookie]
|
||||
secure = true
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for k, want := range map[string]string{
|
||||
"BIND_ADDR": ":9000",
|
||||
"SESSION_IDLE_TTL": "1h",
|
||||
"SESSION_COOKIE_SECURE": "true",
|
||||
} {
|
||||
got, ok := src.Lookup(k)
|
||||
if !ok || got != want {
|
||||
t.Errorf("%s = %q ok=%v want %q", k, got, ok, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
flatten.go
Normal file
44
flatten.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// flatten walks a nested map[string]any (as produced by yaml/json/toml decoders)
|
||||
// and returns a flat key->string map. Keys are uppercased and joined with "_".
|
||||
// Scalar values are stringified via fmt.Sprint. Slices are joined with ",".
|
||||
func flatten(in map[string]any) map[string]string {
|
||||
out := map[string]string{}
|
||||
flattenInto(out, "", in)
|
||||
return out
|
||||
}
|
||||
|
||||
func flattenInto(out map[string]string, prefix string, in map[string]any) {
|
||||
for k, v := range in {
|
||||
key := strings.ToUpper(k)
|
||||
if prefix != "" {
|
||||
key = prefix + "_" + key
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
flattenInto(out, key, x)
|
||||
case map[any]any: // yaml.v3 with non-string keys
|
||||
m := make(map[string]any, len(x))
|
||||
for kk, vv := range x {
|
||||
m[fmt.Sprint(kk)] = vv
|
||||
}
|
||||
flattenInto(out, key, m)
|
||||
case []any:
|
||||
parts := make([]string, len(x))
|
||||
for i, e := range x {
|
||||
parts[i] = fmt.Sprint(e)
|
||||
}
|
||||
out[key] = strings.Join(parts, ",")
|
||||
case nil:
|
||||
// skip
|
||||
default:
|
||||
out[key] = fmt.Sprint(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
9
go.mod
Normal file
9
go.mod
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
module git.juancwu.dev/juancwu/conf
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
git.juancwu.dev/juancwu/errx v0.1.0
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/goccy/go-yaml v1.19.2
|
||||
)
|
||||
6
go.sum
Normal file
6
go.sum
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
||||
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
33
json.go
Normal file
33
json.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// JSONFile loads a JSON object and exposes it as a flat key->string Source.
|
||||
func JSONFile(path string) (Source, error) {
|
||||
const op = "conf.JSONFile"
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, errx.Wrapf(op, err, "open %s", path)
|
||||
}
|
||||
defer f.Close()
|
||||
return JSONReader(f)
|
||||
}
|
||||
|
||||
// JSONReader is JSONFile for an arbitrary reader.
|
||||
func JSONReader(r io.Reader) (Source, error) {
|
||||
const op = "conf.JSONReader"
|
||||
var raw map[string]any
|
||||
if err := json.NewDecoder(r).Decode(&raw); err != nil {
|
||||
if err == io.EOF {
|
||||
return MapSource(nil), nil
|
||||
}
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return MapSource(flatten(raw)), nil
|
||||
}
|
||||
92
parse.go
Normal file
92
parse.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
var (
|
||||
durationType = reflect.TypeOf(time.Duration(0))
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
byteSliceTyp = reflect.TypeOf([]byte(nil))
|
||||
)
|
||||
|
||||
// assignString parses raw into the value pointed to by v. v must be settable.
|
||||
func assignString(v reflect.Value, raw, sep string) error {
|
||||
const op = "conf.assignString"
|
||||
|
||||
if v.Kind() == reflect.Pointer {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
return assignString(v.Elem(), raw, sep)
|
||||
}
|
||||
|
||||
switch v.Type() {
|
||||
case durationType:
|
||||
d, err := time.ParseDuration(raw)
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse duration %q", raw)
|
||||
}
|
||||
v.SetInt(int64(d))
|
||||
return nil
|
||||
case timeType:
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse time %q", raw)
|
||||
}
|
||||
v.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
case byteSliceTyp:
|
||||
v.SetBytes([]byte(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.String:
|
||||
v.SetString(raw)
|
||||
case reflect.Bool:
|
||||
b, err := strconv.ParseBool(raw)
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse bool %q", raw)
|
||||
}
|
||||
v.SetBool(b)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
n, err := strconv.ParseInt(raw, 10, v.Type().Bits())
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse int %q", raw)
|
||||
}
|
||||
v.SetInt(n)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
n, err := strconv.ParseUint(raw, 10, v.Type().Bits())
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse uint %q", raw)
|
||||
}
|
||||
v.SetUint(n)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
f, err := strconv.ParseFloat(raw, v.Type().Bits())
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "parse float %q", raw)
|
||||
}
|
||||
v.SetFloat(f)
|
||||
case reflect.Slice:
|
||||
if sep == "" {
|
||||
sep = ","
|
||||
}
|
||||
parts := strings.Split(raw, sep)
|
||||
out := reflect.MakeSlice(v.Type(), len(parts), len(parts))
|
||||
for i, p := range parts {
|
||||
if err := assignString(out.Index(i), strings.TrimSpace(p), sep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v.Set(out)
|
||||
default:
|
||||
return errx.Newf(op, "unsupported field type %s", v.Type())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
34
source.go
Normal file
34
source.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package conf
|
||||
|
||||
import "os"
|
||||
|
||||
// Source returns a string value for a key, or false if the key is not present.
|
||||
type Source interface {
|
||||
Lookup(key string) (string, bool)
|
||||
}
|
||||
|
||||
type envSource struct{}
|
||||
|
||||
func (envSource) Lookup(key string) (string, bool) {
|
||||
v, ok := os.LookupEnv(key)
|
||||
if !ok || v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
// EnvSource reads from the process environment. Empty values are treated as absent.
|
||||
func EnvSource() Source { return envSource{} }
|
||||
|
||||
type mapSource map[string]string
|
||||
|
||||
func (m mapSource) Lookup(key string) (string, bool) {
|
||||
v, ok := m[key]
|
||||
if !ok || v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
// MapSource wraps a key/value map as a Source. The map is not copied.
|
||||
func MapSource(m map[string]string) Source { return mapSource(m) }
|
||||
30
toml.go
Normal file
30
toml.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// TOMLFile loads a TOML document and exposes it as a flat key->string Source.
|
||||
func TOMLFile(path string) (Source, error) {
|
||||
const op = "conf.TOMLFile"
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, errx.Wrapf(op, err, "open %s", path)
|
||||
}
|
||||
defer f.Close()
|
||||
return TOMLReader(f)
|
||||
}
|
||||
|
||||
// TOMLReader is TOMLFile for an arbitrary reader.
|
||||
func TOMLReader(r io.Reader) (Source, error) {
|
||||
const op = "conf.TOMLReader"
|
||||
var raw map[string]any
|
||||
if _, err := toml.NewDecoder(r).Decode(&raw); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return MapSource(flatten(raw)), nil
|
||||
}
|
||||
33
yaml.go
Normal file
33
yaml.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/goccy/go-yaml"
|
||||
)
|
||||
|
||||
// YAMLFile loads a YAML document and exposes it as a flat key->string Source.
|
||||
func YAMLFile(path string) (Source, error) {
|
||||
const op = "conf.YAMLFile"
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, errx.Wrapf(op, err, "open %s", path)
|
||||
}
|
||||
defer f.Close()
|
||||
return YAMLReader(f)
|
||||
}
|
||||
|
||||
// YAMLReader is YAMLFile for an arbitrary reader.
|
||||
func YAMLReader(r io.Reader) (Source, error) {
|
||||
const op = "conf.YAMLReader"
|
||||
var raw map[string]any
|
||||
if err := yaml.NewDecoder(r).Decode(&raw); err != nil {
|
||||
if err == io.EOF {
|
||||
return MapSource(nil), nil
|
||||
}
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return MapSource(flatten(raw)), nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue