package consul

import (
	"context"
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/go-hclog"
	"github.com/miekg/dns"
)

const (
	// dummyTrustDomain is the placeholder trust domain put into the SPIFFE
	// ID and Common Name of the CSR built by autoEncryptCSR.
	dummyTrustDomain = "dummy.trustdomain"

	// retryJitterWindow is the window handed to lib.RandomStagger when
	// computing the delay between AutoEncrypt.Sign attempts.
	retryJitterWindow = 30 * time.Second
)

// autoEncryptCSR generates a fresh private key and a certificate signing
// request for this agent.
//
// extraDNSSANs and extraIPSANs are appended to the default SANs
// ("localhost", 127.0.0.1 and ::1). It returns the PEM-encoded private key,
// the CSR, and any error from reading the CA config, generating the key, or
// creating the CSR.
func (c *Client) autoEncryptCSR(extraDNSSANs []string, extraIPSANs []net.IP) (string, string, error) {
	// We don't provide the correct host here, because we don't know any
	// better at this point. Apart from the domain, we would need the
	// ClusterID, which we don't have. This is why we go with
	// dummyTrustDomain the first time. Subsequent CSRs will have the
	// correct TrustDomain.
	id := &connect.SpiffeIDAgent{
		Host:       dummyTrustDomain,
		Datacenter: c.config.Datacenter,
		Agent:      c.config.NodeName,
	}

	conf, err := c.config.CAConfig.GetCommonConfig()
	if err != nil {
		return "", "", err
	}

	// Fall back to the connect package defaults when the CA config leaves
	// the key type or size unset.
	if conf.PrivateKeyType == "" {
		conf.PrivateKeyType = connect.DefaultPrivateKeyType
	}
	if conf.PrivateKeyBits == 0 {
		conf.PrivateKeyBits = connect.DefaultPrivateKeyBits
	}

	// Create a new private key.
	pk, pkPEM, err := connect.GeneratePrivateKeyWithConfig(conf.PrivateKeyType, conf.PrivateKeyBits)
	if err != nil {
		return "", "", err
	}

	dnsNames := append([]string{"localhost"}, extraDNSSANs...)
	ipAddresses := append([]net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, extraIPSANs...)

	// Create a CSR.
	//
	// The Common Name includes the dummy trust domain for now but Server will
	// override this when it is signed anyway so it's OK.
	cn := connect.AgentCN(c.config.NodeName, dummyTrustDomain)
	csr, err := connect.CreateCSR(id, cn, pk, dnsNames, ipAddresses)
	if err != nil {
		return "", "", err
	}

	return pkPEM, csr, nil
}

// RequestAutoEncryptCerts asks a server to sign a CSR for this agent via the
// AutoEncrypt.Sign RPC and retries until it succeeds or is interrupted.
//
// servers is the list of candidate hosts and port the RPC port to dial on
// each resolved IP; if the router already knows a server, that server
// replaces the list. token is sent as the request token. extraDNSSANs and
// extraIPSANs are added to the CSR.
//
// On success it returns the signed response and the PEM-encoded private key
// matching the CSR. It returns an error when no servers are available, when
// the CSR cannot be created, when ctx is done, or when the client is
// shutting down.
func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string, extraDNSSANs []string, extraIPSANs []net.IP) (*structs.SignedResponse, string, error) {
	errFn := func(err error) (*structs.SignedResponse, string, error) {
		return nil, "", err
	}

	// Check if we know about a server already through gossip. Depending on
	// how the agent joined, there might already be one. Also in case this
	// gets called because the cert expired.
	server := c.routers.FindServer()
	if server != nil {
		servers = []string{server.Addr.String()}
	}

	if len(servers) == 0 {
		return errFn(fmt.Errorf("No servers to request AutoEncrypt.Sign"))
	}

	pkPEM, csr, err := c.autoEncryptCSR(extraDNSSANs, extraIPSANs)
	if err != nil {
		return errFn(err)
	}

	// Prepare request and response so that it can be passed to
	// RPCInsecure.
	args := structs.CASignRequest{
		WriteRequest: structs.WriteRequest{Token: token},
		Datacenter:   c.config.Datacenter,
		CSR:          csr,
	}
	var reply structs.SignedResponse

	// Retry implementation modeled after https://github.com/hashicorp/consul/pull/5228.
	// TLDR; there is a 30s window from which a random time is picked.
	// Repeat until the call is successful.
	attempts := 0
	for {
		// Bail out before doing any work if the caller already gave up.
		select {
		case <-ctx.Done():
			return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
		default:
		}

		// Translate host to net.TCPAddr to make life easier for
		// RPCInsecure.
		for _, s := range servers {
			ips, err := resolveAddr(s, c.logger)
			if err != nil {
				c.logger.Warn("AutoEncrypt resolveAddr failed", "error", err)
				continue
			}

			for _, ip := range ips {
				addr := net.TCPAddr{IP: ip, Port: port}

				err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, "AutoEncrypt.Sign", &args, &reply)
				if err == nil {
					return &reply, pkPEM, nil
				}
				c.logger.Warn("AutoEncrypt failed", "error", err)
			}
		}

		// Every server failed: wait before the next round. The wait is a
		// fresh random stagger multiplied by (attempts + 1), so it tends to
		// grow with the number of failed rounds.
		attempts++
		delay := lib.RandomStagger(retryJitterWindow)
		interval := (time.Duration(attempts) * delay) + delay
		c.logger.Warn("retrying AutoEncrypt", "retry_interval", interval)

		select {
		case <-time.After(interval):
			continue
		case <-ctx.Done():
			return errFn(fmt.Errorf("aborting AutoEncrypt because interrupted: %w", ctx.Err()))
		case <-c.shutdownCh:
			return errFn(fmt.Errorf("aborting AutoEncrypt because shutting down"))
		}
	}
}

// missingPortError reports whether err is the "missing port in address"
// error for host, judged by comparing err's message against the text
// "address <host>: missing port in address".
func missingPortError(host string, err error) bool {
	return err != nil && err.Error() == fmt.Sprintf("address %s: missing port in address", host)
}

// resolveAddr is used to resolve the host into IPs and error.
//
// rawHost may be an IP literal (IPv4, or IPv6 with or without brackets), a
// host name, or either of those followed by ":port"; any port is ignored.
// IP literals are returned directly. Host names are looked up with a
// TCP-first DNS query and then with the standard Go resolver.
func resolveAddr(rawHost string, logger hclog.Logger) ([]net.IP, error) {
	// A bare IP literal needs no splitting. Checking this first also accepts
	// unbracketed IPv6 addresses such as "::1", which net.SplitHostPort
	// rejects with "too many colons in address".
	if ip := net.ParseIP(rawHost); ip != nil {
		return []net.IP{ip}, nil
	}

	host, _, err := net.SplitHostPort(rawHost)
	if err != nil {
		// In case we encounter this error, we proceed with the
		// rawHost. This is fine since -start-join and -retry-join
		// take only hosts anyways and this is an expected case.
		if !missingPortError(rawHost, err) {
			return nil, err
		}
		host = rawHost

		// net.SplitHostPort strips the brackets of "[host]:port"; do the
		// same for a bracketed host without a port so that "[::1]" parses
		// as an IP below instead of being sent to DNS with its brackets.
		if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
			host = host[1 : len(host)-1]
		}
	}

	if ip := net.ParseIP(host); ip != nil {
		return []net.IP{ip}, nil
	}

	// First try TCP so we have the best chance for the largest list of
	// hosts to join. If this fails it's not fatal since this isn't a standard
	// way to query DNS, and we have a fallback below.
	if ips, err := tcpLookupIP(host, logger); err != nil {
		logger.Debug("TCP-first lookup failed for host, falling back to UDP", "host", host, "error", err)
	} else if len(ips) > 0 {
		return ips, nil
	}

	// If TCP didn't yield anything then use the normal Go resolver which
	// will try UDP, then might possibly try TCP again if the UDP response
	// indicates it was truncated.
	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, err
	}
	return ips, nil
}

// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host.
// The built-in Go resolver will do a UDP lookup first, and will only use TCP if
// the response has the truncate bit set, which isn't common on DNS servers like
// Consul's. By doing the TCP lookup directly, we get the best chance for the
// largest list of hosts to join. Since joins are relatively rare events, it's ok
// to do this rather expensive operation.
//
// It returns (nil, nil) when host contains no dot or when /etc/resolv.conf
// lists no servers; callers are expected to fall back to another resolver in
// that case. Only the first listed server is queried.
func tcpLookupIP(host string, logger hclog.Logger) ([]net.IP, error) {
	// Don't attempt any TCP lookups against non-fully qualified domain
	// names, since those will likely come from the resolv.conf file.
	if !strings.Contains(host, ".") {
		return nil, nil
	}

	// Make sure the domain name is terminated with a dot (we know there's
	// at least one character at this point).
	dn := host
	if dn[len(dn)-1] != '.' {
		dn = dn + "."
	}

	// See if we can find a server to try.
	cc, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err != nil {
		return nil, err
	}
	if len(cc.Servers) == 0 {
		return nil, nil
	}

	// The client dials the address it is given, so it must be host:port.
	// resolv.conf nameserver entries normally carry no port; add the
	// configured one (JoinHostPort also brackets IPv6 servers).
	server := cc.Servers[0]
	if _, _, err := net.SplitHostPort(server); err != nil {
		server = net.JoinHostPort(server, cc.Port)
	}

	// Do the lookup.
	c := new(dns.Client)
	c.Net = "tcp"
	msg := new(dns.Msg)
	msg.SetQuestion(dn, dns.TypeANY)
	in, _, err := c.Exchange(msg, server)
	if err != nil {
		return nil, err
	}

	// Handle any IPs we get back that we can attempt to join.
	var ips []net.IP
	for _, r := range in.Answer {
		switch rr := r.(type) {
		case (*dns.A):
			ips = append(ips, rr.A)
		case (*dns.AAAA):
			ips = append(ips, rr.AAAA)
		case (*dns.CNAME):
			logger.Debug("Ignoring CNAME RR in TCP-first answer for host", "host", host)
		}
	}
	return ips, nil
}