cli: improve the file safety of 'consul tls' subcommands (#7186)

- also fixing the signature of file.WriteAtomicWithPerms
This commit is contained in:
R.B. Boyer 2020-01-31 10:12:36 -06:00 committed by GitHub
parent 0fd7b4e969
commit 1d7e4f7de5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 359 additions and 127 deletions

View File

@ -1256,7 +1256,7 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
return nil, fmt.Errorf("Failed to marshal tokens for persistence: %v", err)
}
if err := file.WriteAtomicWithPerms(filepath.Join(s.agent.config.DataDir, tokensPath), data, 0600); err != nil {
if err := file.WriteAtomicWithPerms(filepath.Join(s.agent.config.DataDir, tokensPath), data, 0700, 0600); err != nil {
s.agent.logger.Warn("failed to persist tokens", "error", err)
return nil, fmt.Errorf("Failed to persist tokens - %v", err)
}

View File

@ -126,7 +126,7 @@ func (c *cmd) Run(args []string) int {
func (c *cmd) writeToSink(tok *api.ACLToken) error {
payload := []byte(tok.SecretID)
return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0600)
return file.WriteAtomicWithPerms(c.tokenSinkFile, payload, 0755, 0600)
}
func (c *cmd) Synopsis() string {

View File

@ -3,10 +3,10 @@ package create
import (
"flag"
"fmt"
"os"
"github.com/hashicorp/consul/command/flags"
"github.com/hashicorp/consul/command/tls"
"github.com/hashicorp/consul/lib/file"
"github.com/hashicorp/consul/tlsutil"
"github.com/mitchellh/cli"
)
@ -70,26 +70,30 @@ func (c *cmd) Run(args []string) int {
s, pk, err := tlsutil.GeneratePrivateKey()
if err != nil {
c.UI.Error(err.Error())
return 1
}
constraints := []string{}
if c.constraint {
constraints = append(c.additionalConstraints, []string{c.domain, "localhost"}...)
}
ca, err := tlsutil.GenerateCA(s, sn, c.days, constraints)
if err != nil {
c.UI.Error(err.Error())
return 1
}
caFile, err := os.Create(certFileName)
if err != nil {
if err := file.WriteAtomicWithPerms(certFileName, []byte(ca), 0755, 0666); err != nil {
c.UI.Error(err.Error())
return 1
}
caFile.WriteString(ca)
c.UI.Output("==> Saved " + certFileName)
pkFile, err := os.Create(pkFileName)
if err != nil {
if err := file.WriteAtomicWithPerms(pkFileName, []byte(pk), 0755, 0666); err != nil {
c.UI.Error(err.Error())
return 1
}
pkFile.WriteString(pk)
c.UI.Output("==> Saved " + pkFileName)
return 0

View File

@ -1,9 +1,10 @@
package create
import (
"crypto"
"crypto/x509"
"io/ioutil"
"os"
"path"
"strings"
"testing"
"time"
@ -22,81 +23,104 @@ func TestValidateCommand_noTabs(t *testing.T) {
}
func TestCACreateCommand(t *testing.T) {
require := require.New(t)
previousDirectory, err := os.Getwd()
require.NoError(err)
testDir := testutil.TempDir(t, "ca-create")
defer os.RemoveAll(testDir)
defer os.Chdir(previousDirectory)
os.Chdir(testDir)
defer switchToTempDir(t, testDir)()
ui := cli.NewMockUi()
cmd := New(ui)
require.Equal(0, cmd.Run(nil), "ca create should exit 0")
errOutput := ui.ErrorWriter.String()
require.Equal("", errOutput)
caPem := path.Join(testDir, "consul-agent-ca.pem")
require.FileExists(caPem)
certData, err := ioutil.ReadFile(caPem)
require.NoError(err)
cert, err := connect.ParseCert(string(certData))
require.NoError(err)
require.NotNil(cert)
require.Equal(1825*24*time.Hour, time.Until(cert.NotAfter).Round(24*time.Hour))
require.False(cert.PermittedDNSDomainsCritical)
require.Len(cert.PermittedDNSDomains, 0)
}
func TestCACreateCommandWithOptions(t *testing.T) {
require := require.New(t)
previousDirectory, err := os.Getwd()
require.NoError(err)
testDir := testutil.TempDir(t, "ca-create")
defer os.RemoveAll(testDir)
defer os.Chdir(previousDirectory)
os.Chdir(testDir)
ui := cli.NewMockUi()
cmd := New(ui)
args := []string{
type testcase struct {
name string
args []string
caPath string
keyPath string
extraCheck func(t *testing.T, cert *x509.Certificate)
}
// The following subtests must run serially.
cases := []testcase{
{"ca defaults",
nil,
"consul-agent-ca.pem",
"consul-agent-ca-key.pem",
func(t *testing.T, cert *x509.Certificate) {
require.Equal(t, 1825*24*time.Hour, time.Until(cert.NotAfter).Round(24*time.Hour))
require.False(t, cert.PermittedDNSDomainsCritical)
require.Len(t, cert.PermittedDNSDomains, 0)
},
},
{"ca options",
[]string{
"-days=365",
"-name-constraint=true",
"-domain=foo",
"-additional-name-constraint=bar",
},
"foo-agent-ca.pem",
"foo-agent-ca-key.pem",
func(t *testing.T, cert *x509.Certificate) {
require.Equal(t, 365*24*time.Hour, time.Until(cert.NotAfter).Round(24*time.Hour))
require.True(t, cert.PermittedDNSDomainsCritical)
require.Len(t, cert.PermittedDNSDomains, 3)
require.ElementsMatch(t, cert.PermittedDNSDomains, []string{"foo", "localhost", "bar"})
},
},
}
for _, tc := range cases {
tc := tc
require.True(t, t.Run(tc.name, func(t *testing.T) {
ui := cli.NewMockUi()
cmd := New(ui)
require.Equal(t, 0, cmd.Run(tc.args))
require.Equal(t, "", ui.ErrorWriter.String())
cert, _ := expectFiles(t, tc.caPath, tc.keyPath)
require.Contains(t, cert.Subject.CommonName, "Consul Agent CA")
require.True(t, cert.BasicConstraintsValid)
require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage)
require.True(t, cert.IsCA)
require.Equal(t, cert.AuthorityKeyId, cert.SubjectKeyId)
tc.extraCheck(t, cert)
}))
}
require.Equal(0, cmd.Run(args), "ca create should exit 0")
errOutput := ui.ErrorWriter.String()
require.Equal("", errOutput)
caPem := path.Join(testDir, "foo-agent-ca.pem")
require.FileExists(caPem)
certData, err := ioutil.ReadFile(caPem)
require.NoError(err)
cert, err := connect.ParseCert(string(certData))
require.NoError(err)
require.NotNil(cert)
require.Equal(365*24*time.Hour, time.Until(cert.NotAfter).Round(24*time.Hour))
require.True(cert.PermittedDNSDomainsCritical)
require.Len(cert.PermittedDNSDomains, 3)
require.ElementsMatch(cert.PermittedDNSDomains, []string{"foo", "localhost", "bar"})
}
func expectFiles(t *testing.T, caPath, keyPath string) (*x509.Certificate, crypto.Signer) {
t.Helper()
require.FileExists(t, caPath)
require.FileExists(t, keyPath)
caData, err := ioutil.ReadFile(caPath)
require.NoError(t, err)
keyData, err := ioutil.ReadFile(keyPath)
require.NoError(t, err)
ca, err := connect.ParseCert(string(caData))
require.NoError(t, err)
require.NotNil(t, ca)
signer, err := connect.ParseSigner(string(keyData))
require.NoError(t, err)
require.NotNil(t, signer)
return ca, signer
}
// switchToTempDir is meant to be used in a defer statement like:
//
// defer switchToTempDir(t, testDir)()
//
// This exploits the fact that the body of a defer is evaluated
// EXCEPT for the final function call invocation inline with the code
// where it is found. Only the final evaluation happens in the defer
// at a later time. In this case it means we switch to the temp
// directory immediately and defer switching back in one line of test
// code.
func switchToTempDir(t *testing.T, testDir string) func() {
previousDirectory, err := os.Getwd()
require.NoError(t, err)
require.NoError(t, os.Chdir(testDir))
return func() {
os.Chdir(previousDirectory)
}
}

View File

@ -6,11 +6,11 @@ import (
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"github.com/hashicorp/consul/command/flags"
"github.com/hashicorp/consul/command/tls"
"github.com/hashicorp/consul/lib/file"
"github.com/hashicorp/consul/tlsutil"
"github.com/mitchellh/cli"
)
@ -98,10 +98,14 @@ func (c *cmd) Run(args []string) int {
if c.server {
name = fmt.Sprintf("server.%s.%s", c.dc, c.domain)
DNSNames = append(DNSNames, []string{name, "localhost"}...)
DNSNames = append(DNSNames, name)
DNSNames = append(DNSNames, "localhost")
IPAddresses = append(IPAddresses, net.ParseIP("127.0.0.1"))
extKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
prefix = fmt.Sprintf("%s-server-%s", c.dc, c.domain)
} else if c.client {
name = fmt.Sprintf("client.%s.%s", c.dc, c.domain)
DNSNames = append(DNSNames, []string{name, "localhost"}...)
@ -174,24 +178,20 @@ func (c *cmd) Run(args []string) int {
}
if err = tlsutil.Verify(string(cert), pub, name); err != nil {
c.UI.Error("==> " + err.Error())
return 1
}
certFile, err := os.Create(certFileName)
if err != nil {
c.UI.Error(err.Error())
return 1
}
certFile.WriteString(pub)
if err := file.WriteAtomicWithPerms(certFileName, []byte(pub), 0755, 0666); err != nil {
c.UI.Error(err.Error())
return 1
}
c.UI.Output("==> Saved " + certFileName)
pkFile, err := os.Create(pkFileName)
if err != nil {
if err := file.WriteAtomicWithPerms(pkFileName, []byte(priv), 0755, 0666); err != nil {
c.UI.Error(err.Error())
return 1
}
pkFile.WriteString(priv)
c.UI.Output("==> Saved " + pkFileName)
return 0

View File

@ -1,9 +1,11 @@
package create
import (
"crypto"
"crypto/x509"
"io/ioutil"
"net"
"os"
"path"
"strings"
"testing"
@ -12,7 +14,7 @@ import (
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/command/tls/ca/create"
caCreate "github.com/hashicorp/consul/command/tls/ca/create"
)
func TestValidateCommand_noTabs(t *testing.T) {
@ -22,59 +24,240 @@ func TestValidateCommand_noTabs(t *testing.T) {
}
}
func TestTlsCertCreateCommand_fileCreate(t *testing.T) {
require := require.New(t)
func TestTlsCertCreateCommand_InvalidArgs(t *testing.T) {
t.Parallel()
previousDirectory, err := os.Getwd()
require.NoError(err)
type testcase struct {
args []string
expectErr string
}
testDir := testutil.TempDir(t, "tls")
defer os.RemoveAll(testDir)
defer os.Chdir(previousDirectory)
cases := map[string]testcase{
"no args (ca/key inferred)": {[]string{},
"Please provide either -server, -client, or -cli"},
"no ca": {[]string{"-ca", "", "-key", ""},
"Please provide the ca"},
"no key": {[]string{"-ca", "foo.pem", "-key", ""},
"Please provide the key"},
os.Chdir(testDir)
"server+client+cli": {[]string{"-server", "-client", "-cli"},
"Please provide either -server, -client, or -cli"},
"server+client": {[]string{"-server", "-client"},
"Please provide either -server, -client, or -cli"},
"server+cli": {[]string{"-server", "-cli"},
"Please provide either -server, -client, or -cli"},
"client+cli": {[]string{"-client", "-cli"},
"Please provide either -server, -client, or -cli"},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ui := cli.NewMockUi()
cmd := New(ui)
require.NotEqual(t, 0, cmd.Run(tc.args))
got := ui.ErrorWriter.String()
if tc.expectErr == "" {
require.NotEmpty(t, got) // don't care
} else {
require.Contains(t, got, tc.expectErr)
}
})
}
}
func TestTlsCertCreateCommand_fileCreate(t *testing.T) {
testDir := testutil.TempDir(t, "tls")
defer os.RemoveAll(testDir)
defer switchToTempDir(t, testDir)()
// Setup CA keys
createCA(t, "consul")
createCA(t, "nomad")
caPath := path.Join(testDir, "consul-agent-ca.pem")
require.FileExists(caPath)
args := []string{
"-server",
type testcase struct {
name string
typ string
args []string
certPath string
keyPath string
expectCN string
expectDNS []string
expectIP []net.IP
}
require.Equal(0, cmd.Run(args))
require.Equal("", ui.ErrorWriter.String())
// The following subtests must run serially.
cases := []testcase{
{"server0",
"server",
[]string{"-server"},
"dc1-server-consul-0.pem",
"dc1-server-consul-0-key.pem",
"server.dc1.consul",
[]string{
"server.dc1.consul",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"server1",
"server",
[]string{"-server"},
"dc1-server-consul-1.pem",
"dc1-server-consul-1-key.pem",
"server.dc1.consul",
[]string{
"server.dc1.consul",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"server0-dc2-altdomain",
"server",
[]string{"-server", "-dc", "dc2", "-domain", "nomad"},
"dc2-server-nomad-0.pem",
"dc2-server-nomad-0-key.pem",
"server.dc2.nomad",
[]string{
"server.dc2.nomad",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"client0",
"client",
[]string{"-client"},
"dc1-client-consul-0.pem",
"dc1-client-consul-0-key.pem",
"client.dc1.consul",
[]string{
"client.dc1.consul",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"client1",
"client",
[]string{"-client"},
"dc1-client-consul-1.pem",
"dc1-client-consul-1-key.pem",
"client.dc1.consul",
[]string{
"client.dc1.consul",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"client0-dc2-altdomain",
"client",
[]string{"-client", "-dc", "dc2", "-domain", "nomad"},
"dc2-client-nomad-0.pem",
"dc2-client-nomad-0-key.pem",
"client.dc2.nomad",
[]string{
"client.dc2.nomad",
"localhost",
},
[]net.IP{{127, 0, 0, 1}},
},
{"cli0",
"cli",
[]string{"-cli"},
"dc1-cli-consul-0.pem",
"dc1-cli-consul-0-key.pem",
"cli.dc1.consul",
[]string{
"cli.dc1.consul",
"localhost",
},
nil,
},
{"cli1",
"cli",
[]string{"-cli"},
"dc1-cli-consul-1.pem",
"dc1-cli-consul-1-key.pem",
"cli.dc1.consul",
[]string{
"cli.dc1.consul",
"localhost",
},
nil,
},
{"cli0-dc2-altdomain",
"cli",
[]string{"-cli", "-dc", "dc2", "-domain", "nomad"},
"dc2-cli-nomad-0.pem",
"dc2-cli-nomad-0-key.pem",
"cli.dc2.nomad",
[]string{
"cli.dc2.nomad",
"localhost",
},
nil,
},
}
certPath := path.Join(testDir, "dc1-server-consul-0.pem")
keyPath := path.Join(testDir, "dc1-server-consul-0-key.pem")
for _, tc := range cases {
tc := tc
require.True(t, t.Run(tc.name, func(t *testing.T) {
ui := cli.NewMockUi()
cmd := New(ui)
require.Equal(t, 0, cmd.Run(tc.args))
require.Equal(t, "", ui.ErrorWriter.String())
require.FileExists(certPath)
require.FileExists(keyPath)
cert, _ := expectFiles(t, tc.certPath, tc.keyPath)
require.Equal(t, tc.expectCN, cert.Subject.CommonName)
require.True(t, cert.BasicConstraintsValid)
require.Equal(t, x509.KeyUsageDigitalSignature|x509.KeyUsageKeyEncipherment, cert.KeyUsage)
switch tc.typ {
case "server":
require.Equal(t,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
cert.ExtKeyUsage)
case "client":
require.Equal(t,
[]x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
cert.ExtKeyUsage)
case "cli":
require.Len(t, cert.ExtKeyUsage, 0)
}
require.False(t, cert.IsCA)
require.Equal(t, tc.expectDNS, cert.DNSNames)
require.Equal(t, tc.expectIP, cert.IPAddresses)
}))
}
}
func expectFiles(t *testing.T, certPath, keyPath string) (*x509.Certificate, crypto.Signer) {
t.Helper()
require.FileExists(t, certPath)
require.FileExists(t, keyPath)
certData, err := ioutil.ReadFile(certPath)
require.NoError(err)
require.NoError(t, err)
keyData, err := ioutil.ReadFile(keyPath)
require.NoError(err)
require.NoError(t, err)
cert, err := connect.ParseCert(string(certData))
require.NoError(err)
require.NotNil(cert)
require.NoError(t, err)
require.NotNil(t, cert)
signer, err := connect.ParseSigner(string(keyData))
require.NoError(err)
require.NotNil(signer)
require.NoError(t, err)
require.NotNil(t, signer)
// TODO - maybe we should validate some certs here.
return cert, signer
}
func createCA(t *testing.T, domain string) {
t.Helper()
ui := cli.NewMockUi()
caCmd := create.New(ui)
caCmd := caCreate.New(ui)
args := []string{
"-domain=" + domain,
@ -82,4 +265,25 @@ func createCA(t *testing.T, domain string) {
require.Equal(t, 0, caCmd.Run(args))
require.Equal(t, "", ui.ErrorWriter.String())
require.FileExists(t, "consul-agent-ca.pem")
}
// switchToTempDir is meant to be used in a defer statement like:
//
// defer switchToTempDir(t, testDir)()
//
// This exploits the fact that the body of a defer is evaluated
// EXCEPT for the final function call invocation inline with the code
// where it is found. Only the final evaluation happens in the defer
// at a later time. In this case it means we switch to the temp
// directory immediately and defer switching back in one line of test
// code.
func switchToTempDir(t *testing.T, testDir string) func() {
previousDirectory, err := os.Getwd()
require.NoError(t, err)
require.NoError(t, os.Chdir(testDir))
return func() {
os.Chdir(previousDirectory)
}
}

View File

@ -11,10 +11,10 @@ import (
// WriteAtomic writes the given contents to a temporary file in the same
// directory, does an fsync and then renames the file to its real path
func WriteAtomic(path string, contents []byte) error {
return WriteAtomicWithPerms(path, contents, 0700)
return WriteAtomicWithPerms(path, contents, 0700, 0600)
}
func WriteAtomicWithPerms(path string, contents []byte, permissions os.FileMode) error {
func WriteAtomicWithPerms(path string, contents []byte, dirPerms, filePerms os.FileMode) error {
uuid, err := uuid.GenerateUUID()
if err != nil {
@ -22,10 +22,10 @@ func WriteAtomicWithPerms(path string, contents []byte, permissions os.FileMode)
}
tempPath := fmt.Sprintf("%s-%s.tmp", path, uuid)
if err := os.MkdirAll(filepath.Dir(path), permissions); err != nil {
if err := os.MkdirAll(filepath.Dir(path), dirPerms); err != nil {
return err
}
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, filePerms)
if err != nil {
return err
}