diff --git a/README.md b/README.md index 28aec1b..6215c6f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,152 @@ # conf +Tiny, reflective config loader for Go. Define a struct, tag the fields, call +`conf.Load`. Values come from one or more `Source`s — env vars, `.env` files, +YAML/JSON/TOML, or anything you implement yourself. + +## Install + +```sh +go get git.juancwu.dev/juancwu/conf +``` + +Requires Go 1.26+. + +## Usage + +Tag your struct and call `Load`. The first source that returns a value wins; +if no source has it, the `default:` tag is used. + +```go +package main + +import ( + "log" + "strings" + "time" + + "git.juancwu.dev/juancwu/conf" + "git.juancwu.dev/juancwu/errx" +) + +type Config struct { + BindAddr string `env:"BIND_ADDR" default:":8080"` + DatabaseURL string `env:"DATABASE_URL"` + SessionCookieSecure bool `env:"SESSION_COOKIE_SECURE" default:"true"` + SessionIdleTTL time.Duration `env:"SESSION_IDLE_TTL" default:"24h"` + JWTSecret []byte `env:"JWT_SECRET"` + WorkerConcurrency int `env:"WORKER_CONCURRENCY" default:"4"` + Tags []string `env:"TAGS" sep:","` +} + +func (c *Config) Validate() error { + const op = "config.Validate" + if strings.TrimSpace(c.DatabaseURL) == "" { + return errx.New(op, "DATABASE_URL is required") + } + if len(c.JWTSecret) == 0 { + return errx.New(op, "JWT_SECRET is required") + } + return nil +} + +func main() { + var cfg Config + if err := conf.Load(&cfg, conf.EnvSource()); err != nil { + log.Fatal(err) + } + _ = cfg +} +``` + +### Multiple sources + +Sources are tried in order — earliest wins. Put highest-precedence first: + +```go +dotenv, _ := conf.DotEnvFile(".env") +yamlBase, _ := conf.YAMLFile("config.yaml") + +err := conf.Load(&cfg, + conf.EnvSource(), // process env wins + dotenv, // .env overrides yaml + yamlBase, // base defaults +) +``` + +### `.env` files + +```go +src, err := conf.DotEnvFile(".env") +``` + +Supports `KEY=value`, `export KEY=value`, `#` comments, single/double quotes, +and `\n \t \r \\ \"` escapes inside double quotes. + +### YAML / JSON / TOML + +File sources are flattened into env-style keys: nested maps join with `_` and +keys are uppercased. + +```yaml +# config.yaml +bind_addr: ":9000" +session: + idle_ttl: 24h + cookie: + secure: true +``` + +becomes `BIND_ADDR`, `SESSION_IDLE_TTL`, `SESSION_COOKIE_SECURE` — matching the +same `env:` tags on your struct. `JSONFile` / `TOMLFile` work the same way. +Each loader has a `*Reader` variant for `io.Reader`. + +### Custom sources + +```go +type Source interface { + Lookup(key string) (string, bool) +} +``` + +Implement that to pull from Vault, SSM, a database, or wherever. + +## Tags + +| Tag | Purpose | +|---------------|----------------------------------------------------------| +| `env:"KEY"` | Key looked up in each `Source`. Required to bind. | +| `env:"-"` | Skip the field. | +| `default:"v"` | Used when no `Source` returns the key. | +| `sep:","` | Slice separator (default `,`). | + +Untagged struct fields are recursed into, so you can group related values: + +```go +type Config struct { + HTTP HTTP +} +type HTTP struct { + Addr string `env:"HTTP_ADDR" default:":8080"` +} +``` + +## Supported field types + +`string`, `[]byte`, `bool`, all sized `int`/`uint`, `float32`/`float64`, +`time.Duration`, `time.Time` (RFC3339), slices of any scalar above, pointers +to any of the above (left `nil` if unset and no default), and nested structs. + +## Validation + +If your destination type implements `Validate() error`, it's called after +fields are populated. Return an error to fail the load. + +## Errors + +All errors flow through [`errx`](https://git.juancwu.dev/juancwu/errx) with op +codes like `conf.Load`, so `errors.Is` / `errors.As` work as usual. + +## License + +MIT — see `LICENSE`. diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..0648911 --- /dev/null +++ b/conf.go @@ -0,0 +1,110 @@ +// Package conf loads configuration values from one or more Sources into a +// tagged Go struct. +// +// Field tags: +// +// env:"KEY" key looked up in each Source (required to bind a field) +// default:"v" raw value used when no Source returns the key +// sep:"," separator for slice fields (default ",") +// env:"-" skip the field +// +// Sources are tried in order; the first one returning a value wins. +// If the destination type implements Validator, Validate is called after the +// fields are populated. +package conf + +import ( + "reflect" + + "git.juancwu.dev/juancwu/errx" +) + +// Validator is implemented by config types that want a post-load hook. +type Validator interface { + Validate() error +} + +// Load populates dst from sources. dst must be a non-nil pointer to a struct. +func Load(dst any, sources ...Source) error { + const op = "conf.Load" + + if dst == nil { + return errx.New(op, "dst is nil") + } + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return errx.New(op, "dst must be a non-nil pointer to a struct") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return errx.New(op, "dst must point to a struct") + } + + if err := walk(v, sources); err != nil { + return errx.Wrap(op, err) + } + + if val, ok := dst.(Validator); ok { + if err := val.Validate(); err != nil { + return errx.Wrap(op, err) + } + } + return nil +} + +func walk(v reflect.Value, sources []Source) error { + const op = "conf.walk" + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + sf := t.Field(i) + if !sf.IsExported() { + continue + } + fv := v.Field(i) + + // Recurse into nested structs (and pointers to structs) that have no env tag. + key := sf.Tag.Get("env") + if key == "" { + switch { + case fv.Kind() == reflect.Struct: + if err := walk(fv, sources); err != nil { + return err + } + case fv.Kind() == reflect.Pointer && fv.Type().Elem().Kind() == reflect.Struct: + if fv.IsNil() { + fv.Set(reflect.New(fv.Type().Elem())) + } + if err := walk(fv.Elem(), sources); err != nil { + return err + } + } + continue + } + if key == "-" { + continue + } + + raw, ok := lookup(sources, key) + if !ok { + raw, ok = sf.Tag.Lookup("default") + if !ok { + continue + } + } + + if err := assignString(fv, raw, sf.Tag.Get("sep")); err != nil { + return errx.Wrapf(op, err, "field %s (%s)", sf.Name, key) + } + } + return nil +} + +func lookup(sources []Source, key string) (string, bool) { + for _, s := range sources { + if v, ok := s.Lookup(key); ok { + return v, true + } + } + return "", false +} diff --git a/conf_test.go b/conf_test.go new file mode 100644 index 0000000..1c678b3 --- /dev/null +++ b/conf_test.go @@ -0,0 +1,134 @@ +package conf + +import ( + "errors" + "strings" + "testing" + "time" +) + +type allTypes struct { + S string `env:"S" default:"hi"` + B bool `env:"B" default:"true"` + I int `env:"I" default:"7"` + I64 int64 `env:"I64" default:"100"` + U uint32 `env:"U" default:"3"` + F float64 `env:"F" default:"1.5"` + D time.Duration `env:"D" default:"5s"` + BS []byte `env:"BS" default:"abc"` + List []string `env:"LIST" default:"a,b,c"` + Skip string `env:"-"` + None string +} + +func TestLoadDefaults(t *testing.T) { + var c allTypes + if err := Load(&c); err != nil { + t.Fatal(err) + } + if c.S != "hi" || !c.B || c.I != 7 || c.I64 != 100 || c.U != 3 || c.F != 1.5 || + c.D != 5*time.Second || string(c.BS) != "abc" || + strings.Join(c.List, ",") != "a,b,c" { + t.Fatalf("defaults not applied: %+v", c) + } +} + +func TestPrecedence(t *testing.T) { + high := MapSource(map[string]string{"S": "high"}) + low := MapSource(map[string]string{"S": "low", "I": "42"}) + + var c allTypes + if err := Load(&c, high, low); err != nil { + t.Fatal(err) + } + if c.S != "high" { + t.Fatalf("S=%q want high", c.S) + } + if c.I != 42 { + t.Fatalf("I=%d want 42", c.I) + } +} + +func TestPointerField(t *testing.T) { + type p struct { + X *int `env:"X"` + Y *int `env:"Y"` + } + var c p + if err := Load(&c, MapSource(map[string]string{"X": "9"})); err != nil { + t.Fatal(err) + } + if c.X == nil || *c.X != 9 { + t.Fatalf("X=%v", c.X) + } + if c.Y != nil { + t.Fatalf("Y should be nil, got %v", *c.Y) + } +} + +func TestNestedStruct(t *testing.T) { + type Inner struct { + A string `env:"A" default:"x"` + } + type Outer struct { + Inner Inner + B string `env:"B" default:"y"` + } + var o Outer + if err := Load(&o); err != nil { + t.Fatal(err) + } + if o.Inner.A != "x" || o.B != "y" { + t.Fatalf("%+v", o) + } +} + +type withValidate struct { + URL string `env:"URL"` +} + +func (w *withValidate) Validate() error { + if w.URL == "" { + return errors.New("URL required") + } + return nil +} + +func TestValidatorFires(t *testing.T) { + var w withValidate + err := Load(&w) + if err == nil || !strings.Contains(err.Error(), "URL required") { + t.Fatalf("want URL required error, got %v", err) + } +} + +func TestParseError(t *testing.T) { + type bad struct { + N int `env:"N"` + } + var b bad + err := Load(&b, MapSource(map[string]string{"N": "notanumber"})) + if err == nil || !strings.Contains(err.Error(), "N") { + t.Fatalf("want field N error, got %v", err) + } +} + +func TestRejectsNonPointer(t *testing.T) { + var c allTypes + if err := Load(c); err == nil { + t.Fatal("expected error for non-pointer") + } +} + +func TestSepOverride(t *testing.T) { + type s struct { + L []string `env:"L" sep:"|"` + } + var x s + if err := Load(&x, MapSource(map[string]string{"L": "a|b|c"})); err != nil { + t.Fatal(err) + } + if strings.Join(x.L, ",") != "a,b,c" { + t.Fatalf("got %v", x.L) + } +} diff --git a/dotenv.go b/dotenv.go new file mode 100644 index 0000000..e00abf2 --- /dev/null +++ b/dotenv.go @@ -0,0 +1,103 @@ +package conf + +import ( + "bufio" + "io" + "os" + "strings" + + "git.juancwu.dev/juancwu/errx" +) + +// DotEnvFile loads KEY=VALUE pairs from path. Lines starting with # and blank +// lines are ignored. Values may be wrapped in single or double quotes; double +// quotes honor \n, \t, \r, \\, and \" escapes. +func DotEnvFile(path string) (Source, error) { + const op = "conf.DotEnvFile" + f, err := os.Open(path) + if err != nil { + return nil, errx.Wrapf(op, err, "open %s", path) + } + defer f.Close() + return DotEnvReader(f) +} + +// DotEnvReader is DotEnvFile for an arbitrary reader. +func DotEnvReader(r io.Reader) (Source, error) { + const op = "conf.DotEnvReader" + + m := map[string]string{} + sc := bufio.NewScanner(r) + lineNo := 0 + for sc.Scan() { + lineNo++ + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + line = strings.TrimPrefix(line, "export ") + eq := strings.IndexByte(line, '=') + if eq < 0 { + return nil, errx.Newf(op, "line %d: missing '='", lineNo) + } + key := strings.TrimSpace(line[:eq]) + val := strings.TrimSpace(line[eq+1:]) + // strip trailing inline comment for unquoted values + if !isQuoted(val) { + if i := strings.Index(val, " #"); i >= 0 { + val = strings.TrimSpace(val[:i]) + } + } + v, err := unquote(val) + if err != nil { + return nil, errx.Wrapf(op, err, "line %d", lineNo) + } + m[key] = v + } + if err := sc.Err(); err != nil { + return nil, errx.Wrap(op, err) + } + return MapSource(m), nil +} + +func isQuoted(s string) bool { + if len(s) < 2 { + return false + } + return (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') +} + +func unquote(s string) (string, error) { + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + return s[1 : len(s)-1], nil + } + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + inner := s[1 : len(s)-1] + var b strings.Builder + for i := 0; i < len(inner); i++ { + c := inner[i] + if c != '\\' || i+1 >= len(inner) { + b.WriteByte(c) + continue + } + i++ + switch inner[i] { + case 'n': + b.WriteByte('\n') + case 't': + b.WriteByte('\t') + case 'r': + b.WriteByte('\r') + case '\\': + b.WriteByte('\\') + case '"': + b.WriteByte('"') + default: + b.WriteByte('\\') + b.WriteByte(inner[i]) + } + } + return b.String(), nil + } + return s, nil +} diff --git a/dotenv_test.go b/dotenv_test.go new file mode 100644 index 0000000..650f90a --- /dev/null +++ b/dotenv_test.go @@ -0,0 +1,35 @@ +package conf + +import ( + "strings" + "testing" +) + +func TestDotEnvReader(t *testing.T) { + in := `# comment +FOO=bar +export QUOTED="hello\nworld" +SINGLE='ab cd' +INLINE=plain # trailing +EMPTY= +` + src, err := DotEnvReader(strings.NewReader(in)) + if err != nil { + t.Fatal(err) + } + cases := map[string]string{ + "FOO": "bar", + "QUOTED": "hello\nworld", + "SINGLE": "ab cd", + "INLINE": "plain", + } + for k, want := range cases { + got, ok := src.Lookup(k) + if !ok || got != want { + t.Errorf("%s = %q ok=%v, want %q", k, got, ok, want) + } + } + if _, ok := src.Lookup("EMPTY"); ok { + t.Errorf("EMPTY should be absent") + } +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..3116069 --- /dev/null +++ b/example_test.go @@ -0,0 +1,46 @@ +package conf_test + +import ( + "fmt" + "strings" + "time" + + "git.juancwu.dev/juancwu/conf" + "git.juancwu.dev/juancwu/errx" +) + +type Config struct { + BindAddr string `env:"BIND_ADDR" default:":8080"` + DatabaseURL string `env:"DATABASE_URL"` + SessionCookieSecure bool `env:"SESSION_COOKIE_SECURE" default:"true"` + SessionIdleTTL time.Duration `env:"SESSION_IDLE_TTL" default:"24h"` + JWTSecret []byte `env:"JWT_SECRET"` + WorkerConcurrency int `env:"WORKER_CONCURRENCY" default:"4"` +} + +func (c *Config) Validate() error { + const op = "config.Validate" + if strings.TrimSpace(c.DatabaseURL) == "" { + return errx.New(op, "DATABASE_URL is required") + } + if len(c.JWTSecret) == 0 { + return errx.New(op, "JWT_SECRET is required") + } + return nil +} + +func ExampleLoad() { + src := conf.MapSource(map[string]string{ + "DATABASE_URL": "postgres://localhost/x", + "JWT_SECRET": "shh", + }) + + var cfg Config + if err := conf.Load(&cfg, src); err != nil { + fmt.Println("err:", err) + return + } + fmt.Printf("bind=%s workers=%d idle=%s secure=%v", + cfg.BindAddr, cfg.WorkerConcurrency, cfg.SessionIdleTTL, cfg.SessionCookieSecure) + // Output: bind=:8080 workers=4 idle=24h0m0s secure=true +} diff --git a/files_test.go b/files_test.go new file mode 100644 index 0000000..0a25c17 --- /dev/null +++ b/files_test.go @@ -0,0 +1,82 @@ +package conf + +import ( + "strings" + "testing" +) + +func TestYAMLFlatten(t *testing.T) { + src, err := YAMLReader(strings.NewReader(` +bind_addr: ":9000" +session: + idle_ttl: 1h + cookie: + secure: true +list: [a, b, c] +`)) + if err != nil { + t.Fatal(err) + } + checks := map[string]string{ + "BIND_ADDR": ":9000", + "SESSION_IDLE_TTL": "1h0m0s", // yaml.v3 decodes as time.Duration string? actually it stays string + "SESSION_COOKIE_SECURE": "true", + "LIST": "a,b,c", + } + for k, want := range checks { + got, ok := src.Lookup(k) + if !ok { + t.Errorf("%s missing", k) + continue + } + if k == "SESSION_IDLE_TTL" { + // yaml decodes "1h" as plain string + if got != "1h" && got != want { + t.Errorf("%s = %q", k, got) + } + continue + } + if got != want { + t.Errorf("%s = %q want %q", k, got, want) + } + } +} + +func TestJSONFlatten(t *testing.T) { + src, err := JSONReader(strings.NewReader(`{"a":{"b":"c"},"n":7,"arr":[1,2,3]}`)) + if err != nil { + t.Fatal(err) + } + if v, _ := src.Lookup("A_B"); v != "c" { + t.Errorf("A_B=%q", v) + } + if v, _ := src.Lookup("N"); v != "7" { + t.Errorf("N=%q", v) + } + if v, _ := src.Lookup("ARR"); v != "1,2,3" { + t.Errorf("ARR=%q", v) + } +} + +func TestTOMLFlatten(t *testing.T) { + src, err := TOMLReader(strings.NewReader(` +bind_addr = ":9000" +[session] +idle_ttl = "1h" +[session.cookie] +secure = true +`)) + if err != nil { + t.Fatal(err) + } + for k, want := range map[string]string{ + "BIND_ADDR": ":9000", + "SESSION_IDLE_TTL": "1h", + "SESSION_COOKIE_SECURE": "true", + } { + got, ok := src.Lookup(k) + if !ok || got != want { + t.Errorf("%s = %q ok=%v want %q", k, got, ok, want) + } + } +} diff --git a/flatten.go b/flatten.go new file mode 100644 index 0000000..7291aaa --- /dev/null +++ b/flatten.go @@ -0,0 +1,44 @@ +package conf + +import ( + "fmt" + "strings" +) + +// flatten walks a nested map[string]any (as produced by yaml/json/toml decoders) +// and returns a flat key->string map. Keys are uppercased and joined with "_". +// Scalar values are stringified via fmt.Sprint. Slices are joined with ",". +func flatten(in map[string]any) map[string]string { + out := map[string]string{} + flattenInto(out, "", in) + return out +} + +func flattenInto(out map[string]string, prefix string, in map[string]any) { + for k, v := range in { + key := strings.ToUpper(k) + if prefix != "" { + key = prefix + "_" + key + } + switch x := v.(type) { + case map[string]any: + flattenInto(out, key, x) + case map[any]any: // yaml.v3 with non-string keys + m := make(map[string]any, len(x)) + for kk, vv := range x { + m[fmt.Sprint(kk)] = vv + } + flattenInto(out, key, m) + case []any: + parts := make([]string, len(x)) + for i, e := range x { + parts[i] = fmt.Sprint(e) + } + out[key] = strings.Join(parts, ",") + case nil: + // skip + default: + out[key] = fmt.Sprint(x) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a0bdaea --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.juancwu.dev/juancwu/conf + +go 1.26.2 + +require ( + git.juancwu.dev/juancwu/errx v0.1.0 + github.com/BurntSushi/toml v1.6.0 + github.com/goccy/go-yaml v1.19.2 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..10f09f0 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w= +git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk= +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= +github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= diff --git a/json.go b/json.go new file mode 100644 index 0000000..cefbca8 --- /dev/null +++ b/json.go @@ -0,0 +1,33 @@ +package conf + +import ( + "encoding/json" + "io" + "os" + + "git.juancwu.dev/juancwu/errx" +) + +// JSONFile loads a JSON object and exposes it as a flat key->string Source. +func JSONFile(path string) (Source, error) { + const op = "conf.JSONFile" + f, err := os.Open(path) + if err != nil { + return nil, errx.Wrapf(op, err, "open %s", path) + } + defer f.Close() + return JSONReader(f) +} + +// JSONReader is JSONFile for an arbitrary reader. +func JSONReader(r io.Reader) (Source, error) { + const op = "conf.JSONReader" + var raw map[string]any + if err := json.NewDecoder(r).Decode(&raw); err != nil { + if err == io.EOF { + return MapSource(nil), nil + } + return nil, errx.Wrap(op, err) + } + return MapSource(flatten(raw)), nil +} diff --git a/parse.go b/parse.go new file mode 100644 index 0000000..56ad9e5 --- /dev/null +++ b/parse.go @@ -0,0 +1,92 @@ +package conf + +import ( + "reflect" + "strconv" + "strings" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +var ( + durationType = reflect.TypeOf(time.Duration(0)) + timeType = reflect.TypeOf(time.Time{}) + byteSliceTyp = reflect.TypeOf([]byte(nil)) +) + +// assignString parses raw into the value pointed to by v. v must be settable. +func assignString(v reflect.Value, raw, sep string) error { + const op = "conf.assignString" + + if v.Kind() == reflect.Pointer { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return assignString(v.Elem(), raw, sep) + } + + switch v.Type() { + case durationType: + d, err := time.ParseDuration(raw) + if err != nil { + return errx.Wrapf(op, err, "parse duration %q", raw) + } + v.SetInt(int64(d)) + return nil + case timeType: + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return errx.Wrapf(op, err, "parse time %q", raw) + } + v.Set(reflect.ValueOf(t)) + return nil + case byteSliceTyp: + v.SetBytes([]byte(raw)) + return nil + } + + switch v.Kind() { + case reflect.String: + v.SetString(raw) + case reflect.Bool: + b, err := strconv.ParseBool(raw) + if err != nil { + return errx.Wrapf(op, err, "parse bool %q", raw) + } + v.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(raw, 10, v.Type().Bits()) + if err != nil { + return errx.Wrapf(op, err, "parse int %q", raw) + } + v.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n, err := strconv.ParseUint(raw, 10, v.Type().Bits()) + if err != nil { + return errx.Wrapf(op, err, "parse uint %q", raw) + } + v.SetUint(n) + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(raw, v.Type().Bits()) + if err != nil { + return errx.Wrapf(op, err, "parse float %q", raw) + } + v.SetFloat(f) + case reflect.Slice: + if sep == "" { + sep = "," + } + parts := strings.Split(raw, sep) + out := reflect.MakeSlice(v.Type(), len(parts), len(parts)) + for i, p := range parts { + if err := assignString(out.Index(i), strings.TrimSpace(p), sep); err != nil { + return err + } + } + v.Set(out) + default: + return errx.Newf(op, "unsupported field type %s", v.Type()) + } + return nil +} diff --git a/source.go b/source.go new file mode 100644 index 0000000..5f4ea94 --- /dev/null +++ b/source.go @@ -0,0 +1,34 @@ +package conf + +import "os" + +// Source returns a string value for a key, or false if the key is not present. +type Source interface { + Lookup(key string) (string, bool) +} + +type envSource struct{} + +func (envSource) Lookup(key string) (string, bool) { + v, ok := os.LookupEnv(key) + if !ok || v == "" { + return "", false + } + return v, true +} + +// EnvSource reads from the process environment. Empty values are treated as absent. +func EnvSource() Source { return envSource{} } + +type mapSource map[string]string + +func (m mapSource) Lookup(key string) (string, bool) { + v, ok := m[key] + if !ok || v == "" { + return "", false + } + return v, true +} + +// MapSource wraps a key/value map as a Source. The map is not copied. +func MapSource(m map[string]string) Source { return mapSource(m) } diff --git a/toml.go b/toml.go new file mode 100644 index 0000000..c5083ed --- /dev/null +++ b/toml.go @@ -0,0 +1,30 @@ +package conf + +import ( + "io" + "os" + + "git.juancwu.dev/juancwu/errx" + "github.com/BurntSushi/toml" +) + +// TOMLFile loads a TOML document and exposes it as a flat key->string Source. +func TOMLFile(path string) (Source, error) { + const op = "conf.TOMLFile" + f, err := os.Open(path) + if err != nil { + return nil, errx.Wrapf(op, err, "open %s", path) + } + defer f.Close() + return TOMLReader(f) +} + +// TOMLReader is TOMLFile for an arbitrary reader. +func TOMLReader(r io.Reader) (Source, error) { + const op = "conf.TOMLReader" + var raw map[string]any + if _, err := toml.NewDecoder(r).Decode(&raw); err != nil { + return nil, errx.Wrap(op, err) + } + return MapSource(flatten(raw)), nil +} diff --git a/yaml.go b/yaml.go new file mode 100644 index 0000000..3fb295b --- /dev/null +++ b/yaml.go @@ -0,0 +1,33 @@ +package conf + +import ( + "io" + "os" + + "git.juancwu.dev/juancwu/errx" + "github.com/goccy/go-yaml" +) + +// YAMLFile loads a YAML document and exposes it as a flat key->string Source. +func YAMLFile(path string) (Source, error) { + const op = "conf.YAMLFile" + f, err := os.Open(path) + if err != nil { + return nil, errx.Wrapf(op, err, "open %s", path) + } + defer f.Close() + return YAMLReader(f) +} + +// YAMLReader is YAMLFile for an arbitrary reader. +func YAMLReader(r io.Reader) (Source, error) { + const op = "conf.YAMLReader" + var raw map[string]any + if err := yaml.NewDecoder(r).Decode(&raw); err != nil { + if err == io.EOF { + return MapSource(nil), nil + } + return nil, errx.Wrap(op, err) + } + return MapSource(flatten(raw)), nil +}