Add custom DNS resolver to ACME configuration (#20400)
* Handle caching of ACME config Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add DNS resolvers to ACME configuration Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add custom DNS resolver to challenge verification This required plumbing through the config, reloading it when necessary, and creating a custom net.Resolver instance. Not immediately clear is how we'd go about building a custom DNS validation mechanism that supported multiple resolvers. Likely we'd need to rely on meikg/dns and handle the resolution separately for each container and use a custom Dialer that assumes the address is already pre-resolved. Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Improvements to Docker harness - Expose additional service information, allowing callers to figure out both the local address and the network-specific address of the service container, and - Allow modifying permissions on uploaded container files. Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add infrastructure to run Bind9 in a container for tests Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Validate DNS-01 challenge works Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> --------- Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
This commit is contained in:
parent
155a32fc77
commit
e42fd09b47
|
@ -85,11 +85,11 @@ func (ace *ACMEChallengeEngine) LoadFromStorage(b *backend, sc *storageContext)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ace *ACMEChallengeEngine) Run(b *backend) {
|
||||
func (ace *ACMEChallengeEngine) Run(b *backend, state *acmeState) {
|
||||
for true {
|
||||
// err == nil on shutdown.
|
||||
b.Logger().Debug("Starting ACME challenge validation engine")
|
||||
err := ace._run(b)
|
||||
err := ace._run(b, state)
|
||||
if err != nil {
|
||||
b.Logger().Error("Got unexpected error from ACME challenge validation engine", "err", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
@ -99,7 +99,7 @@ func (ace *ACMEChallengeEngine) Run(b *backend) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ace *ACMEChallengeEngine) _run(b *backend) error {
|
||||
func (ace *ACMEChallengeEngine) _run(b *backend, state *acmeState) error {
|
||||
// This runner uses a background context for storage operations: we don't
|
||||
// want to tie it to a inbound request and we don't want to set a time
|
||||
// limit, so create a fresh background context.
|
||||
|
@ -177,6 +177,11 @@ func (ace *ACMEChallengeEngine) _run(b *backend) error {
|
|||
continue
|
||||
}
|
||||
|
||||
config, err := state.getConfigWithUpdate(runnerSC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed fetching ACME configuration: %w", err)
|
||||
}
|
||||
|
||||
// Since this work item was valid, we won't expect to see it in
|
||||
// the validation queue again until it is executed. Here, we
|
||||
// want to avoid infinite looping above (if we removed the one
|
||||
|
@ -190,7 +195,7 @@ func (ace *ACMEChallengeEngine) _run(b *backend) error {
|
|||
// could have a RetryAfter date we're not aware of (e.g., if the
|
||||
// cluster restarted as we do not read the entries there).
|
||||
channel := make(chan bool, 1)
|
||||
go ace.VerifyChallenge(runnerSC, task.Identifier, channel)
|
||||
go ace.VerifyChallenge(runnerSC, task.Identifier, channel, config)
|
||||
finishedWorkersChannels = append(finishedWorkersChannels, channel)
|
||||
startedWork = true
|
||||
}
|
||||
|
@ -279,11 +284,11 @@ func (ace *ACMEChallengeEngine) AcceptChallenge(sc *storageContext, account stri
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ace *ACMEChallengeEngine) VerifyChallenge(runnerSc *storageContext, id string, finished chan bool) {
|
||||
func (ace *ACMEChallengeEngine) VerifyChallenge(runnerSc *storageContext, id string, finished chan bool, config *acmeConfigEntry) {
|
||||
sc, _ /* cancel func */ := runnerSc.WithFreshTimeout(MaxChallengeTimeout)
|
||||
runnerSc.Backend.Logger().Debug("Starting verification of challenge: %v", id)
|
||||
|
||||
if retry, retryAfter, err := ace._verifyChallenge(sc, id); err != nil {
|
||||
if retry, retryAfter, err := ace._verifyChallenge(sc, id, config); err != nil {
|
||||
// Because verification of this challenge failed, we need to retry
|
||||
// it in the future. Log the error and re-add the item to the queue
|
||||
// to try again later.
|
||||
|
@ -315,7 +320,7 @@ func (ace *ACMEChallengeEngine) VerifyChallenge(runnerSc *storageContext, id str
|
|||
finished <- false
|
||||
}
|
||||
|
||||
func (ace *ACMEChallengeEngine) _verifyChallenge(sc *storageContext, id string) (bool, time.Time, error) {
|
||||
func (ace *ACMEChallengeEngine) _verifyChallenge(sc *storageContext, id string, config *acmeConfigEntry) (bool, time.Time, error) {
|
||||
now := time.Now()
|
||||
path := acmeValidationPrefix + id
|
||||
challengeEntry, err := sc.Storage.Get(sc.Context, path)
|
||||
|
@ -384,7 +389,7 @@ func (ace *ACMEChallengeEngine) _verifyChallenge(sc *storageContext, id string)
|
|||
return ace._verifyChallengeCleanup(sc, err, id)
|
||||
}
|
||||
|
||||
valid, err = ValidateHTTP01Challenge(authz.Identifier.Value, cv.Token, cv.Thumbprint)
|
||||
valid, err = ValidateHTTP01Challenge(authz.Identifier.Value, cv.Token, cv.Thumbprint, config)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error validating http-01 challenge %v: %w", id, err)
|
||||
return ace._verifyChallengeRetry(sc, cv, authz, err, id)
|
||||
|
@ -395,7 +400,7 @@ func (ace *ACMEChallengeEngine) _verifyChallenge(sc *storageContext, id string)
|
|||
return ace._verifyChallengeCleanup(sc, err, id)
|
||||
}
|
||||
|
||||
valid, err = ValidateDNS01Challenge(authz.Identifier.Value, cv.Token, cv.Thumbprint)
|
||||
valid, err = ValidateDNS01Challenge(authz.Identifier.Value, cv.Token, cv.Thumbprint, config)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error validating dns-01 challenge %v: %w", id, err)
|
||||
return ace._verifyChallengeRetry(sc, cv, authz, err, id)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -11,6 +12,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const DNSChallengePrefix = "_acme-challenge."
|
||||
|
||||
// ValidateKeyAuthorization validates that the given keyAuthz from a challenge
|
||||
// matches our expectation, returning (true, nil) if so, or (false, err) if
|
||||
// not.
|
||||
|
@ -47,12 +50,33 @@ func ValidateSHA256KeyAuthorization(keyAuthz string, token string, thumbprint st
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func buildResolver(config *acmeConfigEntry) (*net.Resolver, error) {
|
||||
if len(config.DNSResolver) == 0 {
|
||||
return net.DefaultResolver, nil
|
||||
}
|
||||
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
StrictErrors: false,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
return d.DialContext(ctx, network, config.DNSResolver)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validates a given ACME http-01 challenge against the specified domain,
|
||||
// per RFC 8555.
|
||||
//
|
||||
// We attempt to be defensive here against timeouts, extra redirects, &c.
|
||||
func ValidateHTTP01Challenge(domain string, token string, thumbprint string) (bool, error) {
|
||||
func ValidateHTTP01Challenge(domain string, token string, thumbprint string, config *acmeConfigEntry) (bool, error) {
|
||||
path := "http://" + domain + "/.well-known/acme-challenge/" + token
|
||||
resolver, err := buildResolver(config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to build resolver: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
// Only a single request is sent to this server as we do not do any
|
||||
|
@ -69,6 +93,7 @@ func ValidateHTTP01Challenge(domain string, token string, thumbprint string) (bo
|
|||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: -1 * time.Second,
|
||||
Resolver: resolver,
|
||||
}).DialContext,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
@ -129,7 +154,7 @@ func ValidateHTTP01Challenge(domain string, token string, thumbprint string) (bo
|
|||
return ValidateKeyAuthorization(keyAuthz, token, thumbprint)
|
||||
}
|
||||
|
||||
func ValidateDNS01Challenge(domain string, token string, thumbprint string) (bool, error) {
|
||||
func ValidateDNS01Challenge(domain string, token string, thumbprint string, config *acmeConfigEntry) (bool, error) {
|
||||
// Here, domain is the value from the post-wildcard-processed identifier.
|
||||
// Per RFC 8555, no difference in validation occurs if a wildcard entry
|
||||
// is requested or if a non-wildcard entry is requested.
|
||||
|
@ -140,10 +165,18 @@ func ValidateDNS01Challenge(domain string, token string, thumbprint string) (boo
|
|||
//
|
||||
// 1. To control the actual resolver via ACME configuration,
|
||||
// 2. To use a context to set stricter timeout limits.
|
||||
name := "_acme-challenge." + domain
|
||||
results, err := net.LookupTXT(name)
|
||||
resolver, err := buildResolver(config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("dns-01: failed to lookup TXT records for domain (%v): %w", name, err)
|
||||
return false, fmt.Errorf("failed to build resolver: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
name := DNSChallengePrefix + domain
|
||||
results, err := resolver.LookupTXT(ctx, name)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("dns-01: failed to lookup TXT records for domain (%v) via resolver %v: %w", name, config.DNSResolver, err)
|
||||
}
|
||||
|
||||
for _, keyAuthz := range results {
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
package pki
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/pki/dnstest"
|
||||
)
|
||||
|
||||
type keyAuthorizationTestCase struct {
|
||||
|
@ -126,9 +130,9 @@ func TestAcmeValidateHTTP01Challenge(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
host := ts.URL[7:]
|
||||
isValid, err := ValidateHTTP01Challenge(host, tc.token, tc.thumbprint)
|
||||
isValid, err := ValidateHTTP01Challenge(host, tc.token, tc.thumbprint, &acmeConfigEntry{})
|
||||
if !isValid && err == nil {
|
||||
t.Fatalf("[tc=%d/handler=%d] expected failure to give reason via err (%v / %v)", handlerIndex, index, isValid, err)
|
||||
t.Fatalf("[tc=%d/handler=%d] expected failure to give reason via err (%v / %v)", index, handlerIndex, isValid, err)
|
||||
}
|
||||
|
||||
expectedValid := !tc.shouldFail
|
||||
|
@ -159,7 +163,7 @@ func TestAcmeValidateHTTP01Challenge(t *testing.T) {
|
|||
}
|
||||
tooLarge := func(w http.ResponseWriter, r *http.Request) {
|
||||
for i := 0; i < 512; i++ {
|
||||
w.Write([]byte("my-token.my-thumbprint"))
|
||||
w.Write([]byte("my-token.my-thumbprint\n"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -175,10 +179,43 @@ func TestAcmeValidateHTTP01Challenge(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
host := ts.URL[7:]
|
||||
isValid, err := ValidateHTTP01Challenge(host, "my-token", "my-thumbprint")
|
||||
isValid, err := ValidateHTTP01Challenge(host, "my-token", "my-thumbprint", &acmeConfigEntry{})
|
||||
if isValid || err == nil {
|
||||
t.Fatalf("[handler=%d] expected failure validating challenge (%v / %v)", handlerIndex, isValid, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcmeValidateDNS01Challenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
host := "alsdkjfasldkj.com"
|
||||
resolver := dnstest.SetupResolver(t, host)
|
||||
defer resolver.Cleanup()
|
||||
|
||||
t.Logf("DNS Server Address: %v", resolver.GetLocalAddr())
|
||||
|
||||
config := &acmeConfigEntry{
|
||||
DNSResolver: resolver.GetLocalAddr(),
|
||||
}
|
||||
|
||||
for index, tc := range keyAuthorizationTestCases {
|
||||
checksum := sha256.Sum256([]byte(tc.keyAuthz))
|
||||
authz := base64.RawURLEncoding.EncodeToString(checksum[:])
|
||||
resolver.AddRecord(DNSChallengePrefix+host, "TXT", authz)
|
||||
resolver.PushConfig()
|
||||
|
||||
isValid, err := ValidateDNS01Challenge(host, tc.token, tc.thumbprint, config)
|
||||
if !isValid && err == nil {
|
||||
t.Fatalf("[tc=%d] expected failure to give reason via err (%v / %v)", index, isValid, err)
|
||||
}
|
||||
|
||||
expectedValid := !tc.shouldFail
|
||||
if expectedValid != isValid {
|
||||
t.Fatalf("[tc=%d] got ret=%v (err=%v), expected ret=%v (shouldFail=%v)", index, isValid, err, expectedValid, tc.shouldFail)
|
||||
}
|
||||
|
||||
resolver.RemoveAllRecords()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ type acmeState struct {
|
|||
nextExpiry *atomic.Int64
|
||||
nonces *sync.Map // map[string]time.Time
|
||||
validator *ACMEChallengeEngine
|
||||
|
||||
configDirty *atomic.Bool
|
||||
_config sync.RWMutex
|
||||
config acmeConfigEntry
|
||||
}
|
||||
|
||||
type acmeThumbprint struct {
|
||||
|
@ -49,23 +53,75 @@ type acmeThumbprint struct {
|
|||
}
|
||||
|
||||
func NewACMEState() *acmeState {
|
||||
return &acmeState{
|
||||
nextExpiry: new(atomic.Int64),
|
||||
nonces: new(sync.Map),
|
||||
validator: NewACMEChallengeEngine(),
|
||||
state := &acmeState{
|
||||
nextExpiry: new(atomic.Int64),
|
||||
nonces: new(sync.Map),
|
||||
validator: NewACMEChallengeEngine(),
|
||||
configDirty: new(atomic.Bool),
|
||||
}
|
||||
// Config hasn't been loaded yet; mark dirty.
|
||||
state.configDirty.Store(true)
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (a *acmeState) Initialize(b *backend, sc *storageContext) error {
|
||||
if err := a.validator.Initialize(b, sc); err != nil {
|
||||
// Load the ACME config.
|
||||
_, err := a.getConfigWithUpdate(sc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error initializing ACME engine: %w", err)
|
||||
}
|
||||
|
||||
go a.validator.Run(b)
|
||||
// Kick off our ACME challenge validation engine.
|
||||
if err := a.validator.Initialize(b, sc); err != nil {
|
||||
return fmt.Errorf("error initializing ACME engine: %w", err)
|
||||
}
|
||||
go a.validator.Run(b, a)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *acmeState) markConfigDirty() {
|
||||
a.configDirty.Store(true)
|
||||
}
|
||||
|
||||
func (a *acmeState) reloadConfigIfRequired(sc *storageContext) error {
|
||||
if !a.configDirty.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
a._config.Lock()
|
||||
defer a._config.Unlock()
|
||||
|
||||
if !a.configDirty.Load() {
|
||||
// Someone beat us to grabbing the above write lock and already
|
||||
// updated the config.
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := sc.getAcmeConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading config: %w", err)
|
||||
}
|
||||
|
||||
a.config = *config
|
||||
a.configDirty.Store(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *acmeState) getConfigWithUpdate(sc *storageContext) (*acmeConfigEntry, error) {
|
||||
if err := a.reloadConfigIfRequired(sc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a._config.RLock()
|
||||
defer a._config.RUnlock()
|
||||
|
||||
configCopy := a.config
|
||||
return &configCopy, nil
|
||||
}
|
||||
|
||||
func generateNonce() (string, error) {
|
||||
return generateRandomBase64(21)
|
||||
}
|
||||
|
|
|
@ -542,6 +542,8 @@ func (b *backend) invalidate(ctx context.Context, key string) {
|
|||
case key == "config/crl":
|
||||
// We may need to reload our OCSP status flag
|
||||
b.crlBuilder.markConfigDirty()
|
||||
case key == storageAcmeConfig:
|
||||
b.acmeState.markConfigDirty()
|
||||
case key == storageIssuerConfig:
|
||||
b.crlBuilder.invalidateCRLBuildTime()
|
||||
case strings.HasPrefix(key, crossRevocationPrefix):
|
||||
|
|
|
@ -0,0 +1,280 @@
|
|||
package dnstest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
|
||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type TestServer struct {
|
||||
t *testing.T
|
||||
ctx context.Context
|
||||
|
||||
runner *docker.Runner
|
||||
startup *docker.Service
|
||||
|
||||
serial int
|
||||
forwarders []string
|
||||
domains []string
|
||||
records map[string]map[string][]string // domain -> record -> value(s).
|
||||
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func SetupResolver(t *testing.T, domain string) *TestServer {
|
||||
return SetupResolverOnNetwork(t, domain, "")
|
||||
}
|
||||
|
||||
func SetupResolverOnNetwork(t *testing.T, domain string, network string) *TestServer {
|
||||
var ts TestServer
|
||||
ts.t = t
|
||||
ts.ctx = context.Background()
|
||||
ts.domains = []string{domain}
|
||||
ts.records = map[string]map[string][]string{}
|
||||
|
||||
ts.setupRunner(domain, network)
|
||||
ts.startContainer()
|
||||
ts.PushConfig()
|
||||
|
||||
return &ts
|
||||
}
|
||||
|
||||
func (ts *TestServer) setupRunner(domain string, network string) {
|
||||
var err error
|
||||
ts.runner, err = docker.NewServiceRunner(docker.RunOptions{
|
||||
ImageRepo: "ubuntu/bind9",
|
||||
ImageTag: "latest",
|
||||
ContainerName: "bind9-dns-" + strings.ReplaceAll(domain, ".", "-"),
|
||||
NetworkName: network,
|
||||
Ports: []string{"53/udp"},
|
||||
LogConsumer: func(s string) {
|
||||
ts.t.Logf(s)
|
||||
},
|
||||
})
|
||||
require.NoError(ts.t, err)
|
||||
}
|
||||
|
||||
func (ts *TestServer) startContainer() {
|
||||
connUpFunc := func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
||||
// Perform a simple connection to this resolver, even though the
|
||||
// default configuration doesn't do anything useful.
|
||||
peer, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve peer: %v / %v: %w", host, port, err)
|
||||
}
|
||||
|
||||
conn, err := net.DialUDP("udp", nil, peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial peer: %v / %v / %v: %w", host, port, peer, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write([]byte("garbage-in"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write to peer: %v / %v / %v: %w", host, port, peer, err)
|
||||
}
|
||||
|
||||
// Connection worked.
|
||||
return docker.NewServiceHostPort(host, port), nil
|
||||
}
|
||||
|
||||
result, err := ts.runner.StartService(ts.ctx, connUpFunc)
|
||||
require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0])
|
||||
ts.startup = result
|
||||
}
|
||||
|
||||
func (ts *TestServer) buildNamedConf() string {
|
||||
forwarders := "\n"
|
||||
if len(ts.forwarders) > 0 {
|
||||
forwarders = "\tforwarders {\n"
|
||||
for _, forwarder := range ts.forwarders {
|
||||
forwarders += "\t\t" + forwarder + ";\n"
|
||||
}
|
||||
forwarders += "\t};\n"
|
||||
}
|
||||
|
||||
zones := "\n"
|
||||
for _, domain := range ts.domains {
|
||||
zones += fmt.Sprintf("zone \"%s\" {\n", domain)
|
||||
zones += "\ttype primary;\n"
|
||||
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain)
|
||||
zones += "\tallow-update {\n\t\tnone;\n\t};\n"
|
||||
zones += "\tnotify no;\n"
|
||||
zones += "};\n\n"
|
||||
}
|
||||
|
||||
// Reverse lookups are not handles as they're not presently necessary.
|
||||
|
||||
cfg := `options {
|
||||
directory "/var/cache/bind";
|
||||
|
||||
dnssec-validation no;
|
||||
|
||||
` + forwarders + `
|
||||
};
|
||||
|
||||
` + zones
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (ts *TestServer) buildZoneFile(target string) string {
|
||||
// One second TTL by default to allow quick refreshes.
|
||||
zone := "$TTL 1;\n"
|
||||
|
||||
ts.serial += 1
|
||||
zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target)
|
||||
zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial)
|
||||
zone += fmt.Sprintf("@\tIN\tNS\tns%d.%v.\n", ts.serial, target)
|
||||
zone += fmt.Sprintf("ns%d.%v.\tIN\tA\t%v\n", ts.serial, target, "127.0.0.1")
|
||||
|
||||
for domain, records := range ts.records {
|
||||
if !strings.HasSuffix(domain, target) {
|
||||
continue
|
||||
}
|
||||
|
||||
for recordType, values := range records {
|
||||
for _, value := range values {
|
||||
zone += fmt.Sprintf("%s.\tIN\t%s\t%s\n", domain, recordType, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return zone
|
||||
}
|
||||
|
||||
func (ts *TestServer) PushConfig() {
|
||||
contents := docker.NewBuildContext()
|
||||
cfgPath := "/etc/bind/named.conf.options"
|
||||
namedCfg := ts.buildNamedConf()
|
||||
contents[cfgPath] = docker.PathContentsFromString(namedCfg)
|
||||
contents[cfgPath].SetOwners(0, 142) // root, bind
|
||||
|
||||
ts.t.Logf("Generated bind9 config (%s):\n%v\n", cfgPath, namedCfg)
|
||||
|
||||
for _, domain := range ts.domains {
|
||||
path := "/var/cache/bind/" + domain + ".zone"
|
||||
zoneFile := ts.buildZoneFile(domain)
|
||||
contents[path] = docker.PathContentsFromString(zoneFile)
|
||||
contents[path].SetOwners(0, 142) // root, bind
|
||||
|
||||
ts.t.Logf("Generated bind9 zone file for %v (%s):\n%v\n", domain, path, zoneFile)
|
||||
}
|
||||
|
||||
err := ts.runner.CopyTo(ts.startup.Container.ID, "/", contents)
|
||||
require.NoError(ts.t, err, "failed pushing updated configuration to container")
|
||||
|
||||
// Wait until our config has taken.
|
||||
corehelpers.RetryUntil(ts.t, 3*time.Second, func() error {
|
||||
// bind reloads based on file mtime, touch files before starting
|
||||
// to make sure it has been updated more recently than when the
|
||||
// last update was written. Then issue a new SIGHUP.
|
||||
for _, domain := range ts.domains {
|
||||
path := "/var/cache/bind/" + domain + ".zone"
|
||||
touchCmd := []string{"touch", path}
|
||||
|
||||
_, _, _, err := ts.runner.RunCmdWithOutput(ts.ctx, ts.startup.Container.ID, touchCmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update zone mtime: %w", err)
|
||||
}
|
||||
}
|
||||
ts.runner.DockerAPI.ContainerKill(ts.ctx, ts.startup.Container.ID, "SIGHUP")
|
||||
|
||||
// Connect to our bind resolver.
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
StrictErrors: false,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
return d.DialContext(ctx, network, ts.GetLocalAddr())
|
||||
},
|
||||
}
|
||||
|
||||
// last domain has the given serial number, which also appears in the
|
||||
// NS record so we can fetch it via Go.
|
||||
lastDomain := ts.domains[len(ts.domains)-1]
|
||||
records, err := resolver.LookupNS(ts.ctx, lastDomain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup NS record for %v: %w", lastDomain, err)
|
||||
}
|
||||
|
||||
if len(records) != 1 {
|
||||
return fmt.Errorf("expected only 1 NS record for %v, got %v/%v", lastDomain, len(records), records)
|
||||
}
|
||||
|
||||
expectedNS := fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain)
|
||||
if records[0].Host != expectedNS {
|
||||
return fmt.Errorf("expected to find NS %v, got %v indicating reload hadn't completed", expectedNS, records[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (ts *TestServer) GetLocalAddr() string {
|
||||
return ts.startup.Config.Address()
|
||||
}
|
||||
|
||||
func (ts *TestServer) GetRemoteAddr() string {
|
||||
return fmt.Sprintf("%s:%d", ts.startup.StartResult.RealIP, 53)
|
||||
}
|
||||
|
||||
func (ts *TestServer) AddDomain(domain string) {
|
||||
for _, existing := range ts.domains {
|
||||
if existing == domain {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ts.domains = append(ts.domains, domain)
|
||||
}
|
||||
|
||||
func (ts *TestServer) AddRecord(domain string, record string, value string) {
|
||||
foundDomain := false
|
||||
for _, existing := range ts.domains {
|
||||
if strings.HasSuffix(domain, existing) {
|
||||
foundDomain = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundDomain {
|
||||
ts.t.Fatalf("cannot add record %v/%v :: [%v] -- no domain zone matching (%v)", record, domain, value, ts.domains)
|
||||
}
|
||||
|
||||
value = strings.TrimSpace(value)
|
||||
if _, present := ts.records[domain]; !present {
|
||||
ts.records[domain] = map[string][]string{}
|
||||
}
|
||||
|
||||
if values, present := ts.records[domain][record]; present {
|
||||
for _, candidate := range values {
|
||||
if candidate == value {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ts.records[domain][record] = append(ts.records[domain][record], value)
|
||||
}
|
||||
|
||||
func (ts *TestServer) RemoveAllRecords() {
|
||||
ts.records = map[string]map[string][]string{}
|
||||
}
|
||||
|
||||
func (ts *TestServer) Cleanup() {
|
||||
if ts.cleanup != nil {
|
||||
ts.cleanup()
|
||||
}
|
||||
if ts.startup != nil && ts.startup.Cleanup != nil {
|
||||
ts.startup.Cleanup()
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package pki
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/errutil"
|
||||
|
@ -22,6 +23,7 @@ type acmeConfigEntry struct {
|
|||
DefaultRole string `json:"default_role"`
|
||||
AllowNoAllowedDomains bool `json:"allow_no_allowed_domains"`
|
||||
AllowAnyDomain bool `json:"allow_any_domain"`
|
||||
DNSResolver string `json:"dns_resolver"`
|
||||
}
|
||||
|
||||
func (sc *storageContext) getAcmeConfig() (*acmeConfigEntry, error) {
|
||||
|
@ -87,6 +89,10 @@ func pathAcmeConfig(b *backend) *framework.Path {
|
|||
Type: framework.TypeBool,
|
||||
Description: `whether ACME will allow the use of roles with allow_any_name=true set.`,
|
||||
},
|
||||
"dns_resolver": {
|
||||
Type: framework.TypeString,
|
||||
Description: `DNS resolver to use for domain resolution on this mount. Defaults to using the default system resolver. Must be in the format <host>:<port>, with both parts mandatory.`,
|
||||
},
|
||||
},
|
||||
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
|
@ -132,6 +138,7 @@ func genResponseFromAcmeConfig(config *acmeConfigEntry) *logical.Response {
|
|||
"allowed_issuers": config.AllowedIssuers,
|
||||
"default_role": config.DefaultRole,
|
||||
"enabled": config.Enabled,
|
||||
"dns_resolver": config.DNSResolver,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -142,54 +149,56 @@ func genResponseFromAcmeConfig(config *acmeConfigEntry) *logical.Response {
|
|||
|
||||
func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
sc := b.makeStorageContext(ctx, req.Storage)
|
||||
|
||||
config, err := sc.getAcmeConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabled := config.Enabled
|
||||
if enabledRaw, ok := d.GetOk("enabled"); ok {
|
||||
enabled = enabledRaw.(bool)
|
||||
config.Enabled = enabledRaw.(bool)
|
||||
}
|
||||
|
||||
allowAnyDomain := config.AllowAnyDomain
|
||||
if allowAnyDomainRaw, ok := d.GetOk("allow_any_domain"); ok {
|
||||
allowAnyDomain = allowAnyDomainRaw.(bool)
|
||||
config.AllowAnyDomain = allowAnyDomainRaw.(bool)
|
||||
}
|
||||
|
||||
allowedRoles := config.AllowedRoles
|
||||
if allowedRolesRaw, ok := d.GetOk("allowed_roles"); ok {
|
||||
allowedRoles = allowedRolesRaw.([]string)
|
||||
config.AllowedRoles = allowedRolesRaw.([]string)
|
||||
}
|
||||
|
||||
defaultRole := config.DefaultRole
|
||||
if defaultRoleRaw, ok := d.GetOk("default_role"); ok {
|
||||
defaultRole = defaultRoleRaw.(string)
|
||||
config.DefaultRole = defaultRoleRaw.(string)
|
||||
}
|
||||
|
||||
allowNoAllowedDomains := config.AllowNoAllowedDomains
|
||||
if allowNoAllowedDomainsRaw, ok := d.GetOk("allow_no_allowed_domains"); ok {
|
||||
allowNoAllowedDomains = allowNoAllowedDomainsRaw.(bool)
|
||||
config.AllowNoAllowedDomains = allowNoAllowedDomainsRaw.(bool)
|
||||
}
|
||||
|
||||
allowedIssuers := config.AllowedIssuers
|
||||
if allowedIssuersRaw, ok := d.GetOk("allowed_issuers"); ok {
|
||||
allowedIssuers = allowedIssuersRaw.([]string)
|
||||
config.AllowedIssuers = allowedIssuersRaw.([]string)
|
||||
}
|
||||
|
||||
newConfig := &acmeConfigEntry{
|
||||
Enabled: enabled,
|
||||
AllowAnyDomain: allowAnyDomain,
|
||||
AllowedRoles: allowedRoles,
|
||||
DefaultRole: defaultRole,
|
||||
AllowNoAllowedDomains: allowNoAllowedDomains,
|
||||
AllowedIssuers: allowedIssuers,
|
||||
if dnsResolverRaw, ok := d.GetOk("dns_resolver"); ok {
|
||||
config.DNSResolver = dnsResolverRaw.(string)
|
||||
if config.DNSResolver != "" {
|
||||
addr, _, err := net.SplitHostPort(config.DNSResolver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse DNS resolver address: %w", err)
|
||||
}
|
||||
if addr != "" {
|
||||
return nil, fmt.Errorf("failed to parse DNS resolver address: got empty address")
|
||||
}
|
||||
if net.ParseIP(addr) != nil {
|
||||
return nil, fmt.Errorf("failed to parse DNS resolver address: expected IPv4/IPv6 address, likely got hostname")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = sc.setAcmeConfig(newConfig)
|
||||
err = sc.setAcmeConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return genResponseFromAcmeConfig(newConfig), nil
|
||||
return genResponseFromAcmeConfig(config), nil
|
||||
}
|
||||
|
|
|
@ -276,7 +276,7 @@ func (d *Runner) StartNewService(ctx context.Context, addSuffix, forceLocalAddr
|
|||
wg.Wait()
|
||||
|
||||
if d.RunOptions.PostStart != nil {
|
||||
if err := d.RunOptions.PostStart(result.Container.ID, result.realIP); err != nil {
|
||||
if err := d.RunOptions.PostStart(result.Container.ID, result.RealIP); err != nil {
|
||||
return nil, "", fmt.Errorf("poststart failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -295,7 +295,7 @@ func (d *Runner) StartNewService(ctx context.Context, addSuffix, forceLocalAddr
|
|||
bo.MaxInterval = time.Second * 5
|
||||
bo.MaxElapsedTime = 2 * time.Minute
|
||||
|
||||
pieces := strings.Split(result.addrs[0], ":")
|
||||
pieces := strings.Split(result.Addrs[0], ":")
|
||||
portInt, err := strconv.Atoi(pieces[1])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
@ -327,25 +327,27 @@ func (d *Runner) StartNewService(ctx context.Context, addSuffix, forceLocalAddr
|
|||
}
|
||||
|
||||
return &Service{
|
||||
Config: config,
|
||||
Cleanup: cleanup,
|
||||
Container: result.Container,
|
||||
Config: config,
|
||||
Cleanup: cleanup,
|
||||
Container: result.Container,
|
||||
StartResult: result,
|
||||
}, result.Container.ID, nil
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Config ServiceConfig
|
||||
Cleanup func()
|
||||
Container *types.ContainerJSON
|
||||
Config ServiceConfig
|
||||
Cleanup func()
|
||||
Container *types.ContainerJSON
|
||||
StartResult *StartResult
|
||||
}
|
||||
|
||||
type startResult struct {
|
||||
type StartResult struct {
|
||||
Container *types.ContainerJSON
|
||||
addrs []string
|
||||
realIP string
|
||||
Addrs []string
|
||||
RealIP string
|
||||
}
|
||||
|
||||
func (d *Runner) Start(ctx context.Context, addSuffix, forceLocalAddr bool) (*startResult, error) {
|
||||
func (d *Runner) Start(ctx context.Context, addSuffix, forceLocalAddr bool) (*StartResult, error) {
|
||||
name := d.RunOptions.ContainerName
|
||||
if addSuffix {
|
||||
suffix, err := uuid.GenerateUUID()
|
||||
|
@ -458,10 +460,10 @@ func (d *Runner) Start(ctx context.Context, addSuffix, forceLocalAddr bool) (*st
|
|||
realIP = inspect.NetworkSettings.Networks[d.RunOptions.NetworkName].IPAddress
|
||||
}
|
||||
|
||||
return &startResult{
|
||||
return &StartResult{
|
||||
Container: &inspect,
|
||||
addrs: addrs,
|
||||
realIP: realIP,
|
||||
Addrs: addrs,
|
||||
RealIP: realIP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -619,6 +621,8 @@ func (d *Runner) RunCmdInBackground(ctx context.Context, container string, cmd [
|
|||
type PathContents interface {
|
||||
UpdateHeader(header *tar.Header) error
|
||||
Get() ([]byte, error)
|
||||
SetMode(mode int64)
|
||||
SetOwners(uid int, gid int)
|
||||
}
|
||||
|
||||
type FileContents struct {
|
||||
|
@ -639,8 +643,17 @@ func (b FileContents) Get() ([]byte, error) {
|
|||
return b.Data, nil
|
||||
}
|
||||
|
||||
func (b *FileContents) SetMode(mode int64) {
|
||||
b.Mode = mode
|
||||
}
|
||||
|
||||
func (b *FileContents) SetOwners(uid int, gid int) {
|
||||
b.UID = uid
|
||||
b.GID = gid
|
||||
}
|
||||
|
||||
func PathContentsFromBytes(data []byte) PathContents {
|
||||
return FileContents{
|
||||
return &FileContents{
|
||||
Data: data,
|
||||
Mode: 0o644,
|
||||
}
|
||||
|
@ -680,7 +693,7 @@ func BuildContextFromTarball(reader io.Reader) (BuildContext, error) {
|
|||
return nil, fmt.Errorf("unexpectedly short read on tarball: %v of %v", read, header.Size)
|
||||
}
|
||||
|
||||
bCtx[header.Name] = FileContents{
|
||||
bCtx[header.Name] = &FileContents{
|
||||
Data: data,
|
||||
Mode: header.Mode,
|
||||
UID: header.Uid,
|
||||
|
@ -697,8 +710,14 @@ func (bCtx *BuildContext) ToTarball() (io.Reader, error) {
|
|||
tarBuilder := tar.NewWriter(buffer)
|
||||
defer tarBuilder.Close()
|
||||
|
||||
now := time.Now()
|
||||
for filepath, contents := range *bCtx {
|
||||
fileHeader := &tar.Header{Name: filepath}
|
||||
fileHeader := &tar.Header{
|
||||
Name: filepath,
|
||||
ModTime: now,
|
||||
AccessTime: now,
|
||||
ChangeTime: now,
|
||||
}
|
||||
if contents == nil && !strings.HasSuffix(filepath, "/") {
|
||||
return nil, fmt.Errorf("expected file path (%v) to have trailing / due to nil contents, indicating directory", filepath)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue