fix: leaky test db

This commit is contained in:
juancwu 2026-05-06 15:18:17 +00:00
commit d6d968d209

View file

@ -6,6 +6,9 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"sync"
"syscall"
"testing" "testing"
"time" "time"
@ -59,12 +62,28 @@ func PostgresMain(m *testing.M) {
os.Exit(1) os.Exit(1)
} }
var stopOnce sync.Once
stop := func() { stop := func() {
// `docker rm -f` because --rm only fires on a clean exit; force-stop the stopOnce.Do(func() {
// container regardless of state so leftover containers don't accumulate. // `docker rm -f` because --rm only fires on a clean exit; force-stop the
_ = exec.Command("docker", "rm", "-f", containerName).Run() // container regardless of state so leftover containers don't accumulate.
_ = exec.Command("docker", "rm", "-f", containerName).Run()
})
} }
// Defers don't run when the test binary is killed by a signal — Ctrl+C from
// `task test`, `go test` timeout (SIGQUIT then SIGKILL), or SIGPIPE when a
// piped consumer like tparse exits. Trap the common ones so we still clean up.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
sig := <-sigCh
stop()
// Re-raise with default handler so the parent sees the real exit cause.
signal.Reset(sig.(syscall.Signal))
_ = syscall.Kill(os.Getpid(), sig.(syscall.Signal))
}()
url := fmt.Sprintf("postgres://budgit_test:testpass@127.0.0.1:%d/budgit_test?sslmode=disable", port) url := fmt.Sprintf("postgres://budgit_test:testpass@127.0.0.1:%d/budgit_test?sslmode=disable", port)
if err := waitForPostgres(url, 60*time.Second); err != nil { if err := waitForPostgres(url, 60*time.Second); err != nil {
stop() stop()
@ -79,11 +98,10 @@ func PostgresMain(m *testing.M) {
} }
// Run tests, then ALWAYS stop the container — including on panic. // Run tests, then ALWAYS stop the container — including on panic.
code := func() int { os.Exit(func() int {
defer stop() defer stop()
return m.Run() return m.Run()
}() }())
os.Exit(code)
} }
func freePort() (int, error) { func freePort() (int, error) {