Add cached OCSP client support to Cert Auth (#17093)

* wip

* Add cached OCSP client support to Cert Auth

* ->pointer

* Code cleanup

* Fix unit tests

* Use an LRU cache, and only persist up to 1000 of the most recently used values to stay under the storage entry limit

* Fix caching, add fail open mode parameter to cert auth roles

* reduce logging

* Add the retry client and GET then POST logic

* Drop persisted cache, make cache size configurable, allow for parallel testing of multiple servers

* dead code

* Update builtin/credential/cert/path_certs.go

Co-authored-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Hook invalidate to reinit the ocsp cache size

* locking

* Conditionally init the ocsp client

* Remove cache size config from cert configs, it's a backend global

* Add field

* Remove strangely complex validity logic

* Address more feedback

* Rework error returning logic

* More edge cases

* MORE edge cases

* Add a test matrix with a builtin responder

* changelog

* Use an atomic for configUpdated

* Actually use ocsp_enabled, and bind to a random port for testing

* Update builtin/credential/cert/path_login.go

Co-authored-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Refactor unit tests

* Add status to cache

* Make some functions private

* Rename for testing, and attribute

* Up to date gofumpt

* remove hash from key, and disable the vault dependent unit test

* Comment out TestMultiOCSP

* imports

* more imports

* Address semgrep results

* Attempt to pass some sort of logging to test_responder

* fix overzealous search&replace

Co-authored-by: Alexander Scheel <alex.scheel@hashicorp.com>
This commit is contained in:
Scott Miller 2022-11-21 10:39:24 -06:00 committed by GitHub
parent f58990677f
commit b51b2a7027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 2287 additions and 60 deletions

View File

@ -8,10 +8,13 @@ import (
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/ocsp"
"github.com/hashicorp/vault/sdk/logical"
)
@ -20,6 +23,13 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
bConf, err := b.Config(ctx, conf.StorageView)
if err != nil {
return nil, err
}
if bConf != nil {
b.updatedConfig(bConf)
}
if err := b.lockThenpopulateCRLs(ctx, conf.StorageView); err != nil {
return nil, err
}
@ -50,7 +60,6 @@ func Backend() *backend {
}
b.crlUpdateMutex = &sync.RWMutex{}
return &b
}
@ -58,8 +67,11 @@ type backend struct {
*framework.Backend
MapCertId *framework.PathMap
crls map[string]CRLInfo
crlUpdateMutex *sync.RWMutex
crls map[string]CRLInfo
crlUpdateMutex *sync.RWMutex
ocspClientMutex sync.RWMutex
ocspClient *ocsp.Client
configUpdated atomic.Bool
}
func (b *backend) invalidate(_ context.Context, key string) {
@ -68,9 +80,25 @@ func (b *backend) invalidate(_ context.Context, key string) {
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
b.crls = nil
case key == "config":
b.configUpdated.Store(true)
}
}
func (b *backend) initOCSPClient(cacheSize int) {
b.ocspClient = ocsp.New(func() hclog.Logger {
return b.Logger()
}, cacheSize)
}
func (b *backend) updatedConfig(config *config) error {
b.ocspClientMutex.Lock()
defer b.ocspClientMutex.Unlock()
b.initOCSPClient(config.OcspCacheSize)
b.configUpdated.Store(false)
return nil
}
func (b *backend) fetchCRL(ctx context.Context, storage logical.Storage, name string, crl *CRLInfo) error {
response, err := http.Get(crl.CDP.Url)
if err != nil {
@ -105,6 +133,19 @@ func (b *backend) updateCRLs(ctx context.Context, req *logical.Request) error {
return errs.ErrorOrNil()
}
func (b *backend) storeConfig(ctx context.Context, storage logical.Storage, config *config) error {
entry, err := logical.StorageEntryJSON("config", config)
if err != nil {
return err
}
if err := storage.Put(ctx, entry); err != nil {
return err
}
b.updatedConfig(config)
return nil
}
const backendHelp = `
The "cert" credential provider allows authentication using
TLS client certificates. A client connects to Vault and uses

View File

@ -1092,12 +1092,13 @@ func TestBackend_CRLs(t *testing.T) {
}
func testFactory(t *testing.T) logical.Backend {
storage := &logical.InmemStorage{}
b, err := Factory(context.Background(), &logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 1000 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,
},
StorageView: &logical.InmemStorage{},
StorageView: storage,
})
if err != nil {
t.Fatalf("error: %s", err)
@ -1923,27 +1924,33 @@ type allowed struct {
metadata_ext string // allowed metadata extensions to add to identity alias
}
func testAccStepCert(
t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool,
) logicaltest.TestStep {
func testAccStepCert(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool) logicaltest.TestStep {
return testAccStepCertWithExtraParams(t, name, cert, policies, testData, expectError, nil)
}
func testAccStepCertWithExtraParams(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool, extraParams map[string]interface{}) logicaltest.TestStep {
data := map[string]interface{}{
"certificate": string(cert),
"policies": policies,
"display_name": name,
"allowed_names": testData.names,
"allowed_common_names": testData.common_names,
"allowed_dns_sans": testData.dns,
"allowed_email_sans": testData.emails,
"allowed_uri_sans": testData.uris,
"allowed_organizational_units": testData.organizational_units,
"required_extensions": testData.ext,
"allowed_metadata_extensions": testData.metadata_ext,
"lease": 1000,
}
for k, v := range extraParams {
data[k] = v
}
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "certs/" + name,
ErrorOk: expectError,
Data: map[string]interface{}{
"certificate": string(cert),
"policies": policies,
"display_name": name,
"allowed_names": testData.names,
"allowed_common_names": testData.common_names,
"allowed_dns_sans": testData.dns,
"allowed_email_sans": testData.emails,
"allowed_uri_sans": testData.uris,
"allowed_organizational_units": testData.organizational_units,
"required_extensions": testData.ext,
"allowed_metadata_extensions": testData.metadata_ext,
"lease": 1000,
},
Data: data,
Check: func(resp *logical.Response) error {
if resp == nil && expectError {
return fmt.Errorf("expected error but received nil")

View File

@ -7,7 +7,8 @@ import (
"strings"
"time"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
@ -47,7 +48,32 @@ Must be x509 PEM encoded.`,
EditType: "file",
},
},
"ocsp_enabled": {
Type: framework.TypeBool,
Description: `Whether to attempt OCSP verification of certificates at login`,
},
"ocsp_ca_certificates": {
Type: framework.TypeString,
Description: `Any additional CA certificates needed to communicate with OCSP servers`,
DisplayAttrs: &framework.DisplayAttributes{
EditType: "file",
},
},
"ocsp_servers_override": {
Type: framework.TypeCommaStringSlice,
Description: `A comma-separated list of OCSP server addresses. If unset, the OCSP server is determined
from the AuthorityInformationAccess extension on the certificate being inspected.`,
},
"ocsp_fail_open": {
Type: framework.TypeBool,
Default: false,
Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather than failing. If false, failing to get an OCSP status fails the request.",
},
"ocsp_query_all_servers": {
Type: framework.TypeBool,
Default: false,
Description: "If set to true, rather than accepting the first successful OCSP response, query all servers and consider the certificate valid only if all servers agree.",
},
"allowed_names": {
Type: framework.TypeCommaStringSlice,
Description: `A comma-separated list of names.
@ -294,6 +320,21 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
if certificateRaw, ok := d.GetOk("certificate"); ok {
cert.Certificate = certificateRaw.(string)
}
if ocspCertificatesRaw, ok := d.GetOk("ocsp_ca_certificates"); ok {
cert.OcspCaCertificates = ocspCertificatesRaw.(string)
}
if ocspEnabledRaw, ok := d.GetOk("ocsp_enabled"); ok {
cert.OcspEnabled = ocspEnabledRaw.(bool)
}
if ocspServerOverrides, ok := d.GetOk("ocsp_servers_override"); ok {
cert.OcspServersOverride = ocspServerOverrides.([]string)
}
if ocspFailOpen, ok := d.GetOk("ocsp_fail_open"); ok {
cert.OcspFailOpen = ocspFailOpen.(bool)
}
if ocspQueryAll, ok := d.GetOk("ocsp_query_all_servers"); ok {
cert.OcspQueryAllServers = ocspQueryAll.(bool)
}
if displayNameRaw, ok := d.GetOk("display_name"); ok {
cert.DisplayName = displayNameRaw.(string)
}
@ -399,7 +440,7 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
}
}
if !clientAuth {
return logical.ErrorResponse("non-CA certificates should have TLS client authentication set as an extended key usage"), nil
return logical.ErrorResponse("nonCA certificates should have TLS client authentication set as an extended key usage"), nil
}
}
@ -438,6 +479,12 @@ type CertEntry struct {
RequiredExtensions []string
AllowedMetadataExtensions []string
BoundCIDRs []*sockaddr.SockAddrMarshaler
OcspCaCertificates string
OcspEnabled bool
OcspServersOverride []string
OcspFailOpen bool
OcspQueryAllServers bool
}
const pathCertHelpSyn = `
@ -449,6 +496,7 @@ This endpoint allows you to create, read, update, and delete trusted certificate
that are allowed to authenticate.
Deleting a certificate will not revoke auth for prior authenticated connections.
To do this, do a revoke on "login". If you don't need to revoke login immediately,
To do this, do a revoke on "login". If you don'log need to revoke login immediately,
then the next renew will cause the lease to expire.
`

View File

@ -8,6 +8,8 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
const maxCacheSize = 100000
func pathConfig(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config",
@ -22,6 +24,11 @@ func pathConfig(b *backend) *framework.Path {
Default: false,
Description: `If set, metadata of the certificate including the metadata corresponding to allowed_metadata_extensions will be stored in the alias. Defaults to false.`,
},
"ocsp_cache_size": {
Type: framework.TypeInt,
Default: 100,
Description: `The size of the in memory OCSP response cache, shared by all configured certs`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -32,18 +39,25 @@ func pathConfig(b *backend) *framework.Path {
}
func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
disableBinding := data.Get("disable_binding").(bool)
enableIdentityAliasMetadata := data.Get("enable_identity_alias_metadata").(bool)
entry, err := logical.StorageEntryJSON("config", config{
DisableBinding: disableBinding,
EnableIdentityAliasMetadata: enableIdentityAliasMetadata,
})
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
if disableBindingRaw, ok := data.GetOk("disable_binding"); ok {
config.DisableBinding = disableBindingRaw.(bool)
}
if enableIdentityAliasMetadataRaw, ok := data.GetOk("enable_identity_alias_metadata"); ok {
config.EnableIdentityAliasMetadata = enableIdentityAliasMetadataRaw.(bool)
}
if cacheSizeRaw, ok := data.GetOk("ocsp_cache_size"); ok {
cacheSize := cacheSizeRaw.(int)
if cacheSize < 2 || cacheSize > maxCacheSize {
return logical.ErrorResponse("invalid cache size, must be >= 2 and <= %d", maxCacheSize), nil
}
config.OcspCacheSize = cacheSize
}
if err := b.storeConfig(ctx, req.Storage, config); err != nil {
return nil, err
}
return nil, nil
@ -58,6 +72,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
data := map[string]interface{}{
"disable_binding": cfg.DisableBinding,
"enable_identity_alias_metadata": cfg.EnableIdentityAliasMetadata,
"ocsp_cache_size": cfg.OcspCacheSize,
}
return &logical.Response{
@ -85,4 +100,5 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error
type config struct {
DisableBinding bool `json:"disable_binding"`
EnableIdentityAliasMetadata bool `json:"enable_identity_alias_metadata"`
OcspCacheSize int `json:"ocsp_cache_size"`
}

View File

@ -12,6 +12,8 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/sdk/helper/ocsp"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/policyutil"
@ -84,6 +86,9 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
if err != nil {
return nil, err
}
if b.configUpdated.Load() {
b.updatedConfig(config)
}
if b.crls == nil {
// Probably invalidated due to replication, but we need these to proceed
@ -164,6 +169,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if b.configUpdated.Load() {
b.updatedConfig(config)
}
if b.crls == nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
@ -240,8 +248,8 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d
certName = d.Get("name").(string)
}
// Load the trusted certificates
roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName)
// Load the trusted certificates and other details
roots, trusted, trustedNonCAs, verifyConf := b.loadTrustedCerts(ctx, req.Storage, certName)
// Get the list of full chains matching the connection and validates the
// certificate itself
@ -250,6 +258,11 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d
return nil, nil, err
}
var extraCas []*x509.Certificate
for _, t := range trusted {
extraCas = append(extraCas, t.Certificates...)
}
// If trustedNonCAs is not empty it means that client had registered a non-CA cert
// with the backend.
if len(trustedNonCAs) != 0 {
@ -257,9 +270,14 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d
tCert := trustedNonCA.Certificates[0]
// Check for client cert being explicitly listed in the config (and matching other constraints)
if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 &&
bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) &&
b.matchesConstraints(clientCert, trustedNonCA.Certificates, trustedNonCA) {
return trustedNonCA, nil, nil
bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) {
matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, verifyConf)
if err != nil {
return nil, nil, err
}
if matches {
return trustedNonCA, nil, nil
}
}
}
}
@ -276,10 +294,15 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d
for _, tCert := range trust.Certificates { // For each certificate in the entry
for _, chain := range trustedChains { // For each root chain that we matched
for _, cCert := range chain { // For each cert in the matched chain
if tCert.Equal(cCert) && // ParsedCert intersects with matched chain
b.matchesConstraints(clientCert, chain, trust) { // validate client cert + matched chain against the config
// Add the match to the list
matches = append(matches, trust)
if tCert.Equal(cCert) { // ParsedCert intersects with matched chain
match, err := b.matchesConstraints(ctx, clientCert, chain, trust, verifyConf) // validate client cert + matched chain against the config
if err != nil {
return nil, nil, err
}
if match {
// Add the match to the list
matches = append(matches, trust)
}
}
}
}
@ -295,8 +318,10 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d
return matches[0], nil, nil
}
func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool {
return !b.checkForChainInCRLs(trustedChain) &&
func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate,
config *ParsedCert, conf *ocsp.VerifyConfig,
) (bool, error) {
soFar := !b.checkForChainInCRLs(trustedChain) &&
b.matchesNames(clientCert, config) &&
b.matchesCommonName(clientCert, config) &&
b.matchesDNSSANs(clientCert, config) &&
@ -304,6 +329,14 @@ func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain
b.matchesURISANs(clientCert, config) &&
b.matchesOrganizationalUnits(clientCert, config) &&
b.matchesCertificateExtensions(clientCert, config)
if config.Entry.OcspEnabled {
ocspGood, err := b.checkForCertInOCSP(ctx, clientCert, trustedChain, conf)
if err != nil {
return false, err
}
soFar = soFar && ocspGood
}
return soFar, nil
}
// matchesNames verifies that the certificate matches at least one configured
@ -450,7 +483,7 @@ func (b *backend) matchesCertificateExtensions(clientCert *x509.Certificate, con
asn1.Unmarshal(ext.Value, &parsedValue)
clientExtMap[ext.Id.String()] = parsedValue
}
// If any of the required extensions don't match the constraint fails
// If any of the required extensions don'log match the constraint fails
for _, requiredExt := range config.Entry.RequiredExtensions {
reqExt := strings.SplitN(requiredExt, ":", 2)
clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]]
@ -494,7 +527,7 @@ func (b *backend) certificateExtensionsMetadata(clientCert *x509.Certificate, co
}
// loadTrustedCerts is used to load all the trusted certificates from the backend
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, conf *ocsp.VerifyConfig) {
pool = x509.NewCertPool()
trusted = make([]*ParsedCert, 0)
trustedNonCAs = make([]*ParsedCert, 0)
@ -511,6 +544,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage,
}
}
conf = &ocsp.VerifyConfig{}
for _, name := range names {
entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/"))
if err != nil {
@ -518,7 +552,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage,
continue
}
if entry == nil {
// This could happen when the certName was provided and the cert doesn't exist,
// This could happen when the certName was provided and the cert doesn'log exist,
// or just if between the LIST and the GET the cert was deleted.
continue
}
@ -528,6 +562,8 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage,
b.Logger().Error("failed to parse certificate", "name", name)
continue
}
parsed = append(parsed, parsePEM([]byte(entry.OcspCaCertificates))...)
if !parsed[0].IsCA {
trustedNonCAs = append(trustedNonCAs, &ParsedCert{
Entry: entry,
@ -544,10 +580,33 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage,
Certificates: parsed,
})
}
if entry.OcspEnabled {
conf.OcspEnabled = true
conf.OcspServersOverride = append(conf.OcspServersOverride, entry.OcspServersOverride...)
if entry.OcspFailOpen {
conf.OcspFailureMode = ocsp.FailOpenTrue
} else {
conf.OcspFailureMode = ocsp.FailOpenFalse
}
conf.QueryAllServers = conf.QueryAllServers || entry.OcspQueryAllServers
}
}
return
}
func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certificate, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) {
if !conf.OcspEnabled || len(chain) < 2 {
return true, nil
}
b.ocspClientMutex.RLock()
defer b.ocspClientMutex.RUnlock()
err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[1], conf)
if err != nil {
return false, nil
}
return true, nil
}
func (b *backend) checkForChainInCRLs(chain []*x509.Certificate) bool {
badChain := false
for _, cert := range chain {

View File

@ -1,24 +1,61 @@
package cert
import (
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io/ioutil"
"math/big"
mathrand "math/rand"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/hashicorp/vault/sdk/helper/certutil"
"golang.org/x/crypto/ocsp"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/logical"
)
var ocspPort int
var source InMemorySource
type testLogger struct{}
func (t *testLogger) Log(args ...any) {
fmt.Printf("%v", args)
}
func TestMain(m *testing.M) {
source = make(InMemorySource)
listener, err := net.Listen("tcp", ":0")
if err != nil {
return
}
ocspPort = listener.Addr().(*net.TCPAddr).Port
srv := &http.Server{
Addr: "localhost:0",
Handler: NewResponder(&testLogger{}, source, nil),
}
go func() {
srv.Serve(listener)
}()
defer srv.Shutdown(context.Background())
m.Run()
}
func TestCert_RoleResolve(t *testing.T) {
certTemplate := &x509.Certificate{
Subject: pkix.Name{
@ -159,6 +196,34 @@ func testAccStepResolveRoleExpectRoleResolutionToFail(t *testing.T, connState tl
}
}
func testAccStepResolveRoleOCSPFail(t *testing.T, connState tls.ConnectionState, certName string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ResolveRoleOperation,
Path: "login",
Unauthenticated: true,
ConnState: &connState,
ErrorOk: true,
Check: func(resp *logical.Response) error {
if resp == nil || !resp.IsError() {
t.Fatalf("Response was not an error: resp:%#v", resp)
}
errString, ok := resp.Data["error"].(string)
if !ok {
t.Fatal("Error not part of response.")
}
if !strings.Contains(errString, "no chain matching") {
t.Fatalf("Error was not due to OCSP failure. Error: %s", errString)
}
return nil
},
Data: map[string]interface{}{
"name": certName,
},
}
}
func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) {
certTemplate := &x509.Certificate{
Subject: pkix.Name{
@ -197,3 +262,97 @@ func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) {
},
})
}
func TestCert_RoleResolveOCSP(t *testing.T) {
cases := []struct {
name string
failOpen bool
certStatus int
errExpected bool
}{
{"failFalseGoodCert", false, ocsp.Good, false},
{"failFalseRevokedCert", false, ocsp.Revoked, true},
{"failFalseUnknownCert", false, ocsp.Unknown, true},
{"failTrueGoodCert", true, ocsp.Good, false},
{"failTrueRevokedCert", true, ocsp.Revoked, true},
{"failTrueUnknownCert", true, ocsp.Unknown, false},
}
certTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "example.com",
},
DNSNames: []string{"example.com"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
OCSPServer: []string{fmt.Sprintf("http://localhost:%d", ocspPort)},
}
tempDir, connState, err := generateTestCertAndConnState(t, certTemplate)
if tempDir != "" {
defer os.RemoveAll(tempDir)
}
if err != nil {
t.Fatalf("error testing connection state: %v", err)
}
ca, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_cert.pem"))
if err != nil {
t.Fatalf("err: %v", err)
}
issuer := parsePEM(ca)
pkf, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_key.pem"))
if err != nil {
t.Fatalf("err: %v", err)
}
pk, err := certutil.ParsePEMBundle(string(pkf))
if err != nil {
t.Fatalf("err: %v", err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
resp, err := ocsp.CreateResponse(issuer[0], issuer[0], ocsp.Response{
Status: c.certStatus,
SerialNumber: certTemplate.SerialNumber,
ProducedAt: time.Now(),
ThisUpdate: time.Now(),
NextUpdate: time.Now().Add(time.Hour),
}, pk.PrivateKey)
if err != nil {
t.Fatal(err)
}
source[certTemplate.SerialNumber.String()] = resp
b := testFactory(t)
b.(*backend).ocspClient.ClearCache()
var resolveStep logicaltest.TestStep
var loginStep logicaltest.TestStep
if c.errExpected {
loginStep = testAccStepLoginWithNameInvalid(t, connState, "web")
resolveStep = testAccStepResolveRoleOCSPFail(t, connState, "web")
} else {
loginStep = testAccStepLoginWithName(t, connState, "web")
resolveStep = testAccStepResolveRoleWithName(t, connState, "web")
}
logicaltest.Test(t, logicaltest.TestCase{
CredentialBackend: b,
Steps: []logicaltest.TestStep{
testAccStepCertWithExtraParams(t, "web", ca, "foo", allowed{dns: "example.com"}, false,
map[string]interface{}{"ocsp_enabled": true, "ocsp_fail_open": c.failOpen}),
loginStep,
resolveStep,
},
})
})
}
}
func serialFromBigInt(serial *big.Int) string {
return strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":"))
}

View File

@ -0,0 +1,301 @@
// Package ocsp implements an OCSP responder based on a generic storage backend.
// It provides a couple of sample implementations.
// Because OCSP responders handle high query volumes, we have to be careful
// about how much logging we do. Error-level logs are reserved for problems
// internal to the server, that can be fixed by an administrator. Any type of
// incorrect input from a user should be logged and Info or below. For things
// that are logged on every request, Debug is the appropriate level.
//
// From https://github.com/cloudflare/cfssl/blob/master/ocsp/responder.go
package cert
import (
"crypto"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"golang.org/x/crypto/ocsp"
)
var (
malformedRequestErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x01}
internalErrorErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x02}
tryLaterErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x03}
sigRequredErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x05}
unauthorizedErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x06}
// ErrNotFound indicates the request OCSP response was not found. It is used to
// indicate that the responder should reply with unauthorizedErrorResponse.
ErrNotFound = errors.New("Request OCSP Response not found")
)
// Source represents the logical source of OCSP responses, i.e.,
// the logic that actually chooses a response based on a request. In
// order to create an actual responder, wrap one of these in a Responder
// object and pass it to http.Handle. By default the Responder will set
// the headers Cache-Control to "max-age=(response.NextUpdate-now), public, no-transform, must-revalidate",
// Last-Modified to response.ThisUpdate, Expires to response.NextUpdate,
// ETag to the SHA256 hash of the response, and Content-Type to
// application/ocsp-response. If you want to override these headers,
// or set extra headers, your source should return a http.Header
// with the headers you wish to set. If you don'log want to set any
// extra headers you may return nil instead.
type Source interface {
Response(*ocsp.Request) ([]byte, http.Header, error)
}
// An InMemorySource is a map from serialNumber -> der(response)
type InMemorySource map[string][]byte
// Response looks up an OCSP response to provide for a given request.
// InMemorySource looks up a response purely based on serial number,
// without regard to what issuer the request is asking for.
func (src InMemorySource) Response(request *ocsp.Request) ([]byte, http.Header, error) {
response, present := src[request.SerialNumber.String()]
if !present {
return nil, nil, ErrNotFound
}
return response, nil, nil
}
// Stats is a basic interface that allows users to record information
// about returned responses
type Stats interface {
ResponseStatus(ocsp.ResponseStatus)
}
type logger interface {
Log(args ...any)
}
// A Responder object provides the HTTP logic to expose a
// Source of OCSP responses.
type Responder struct {
log logger
Source Source
stats Stats
}
// NewResponder instantiates a Responder with the give Source.
func NewResponder(t logger, source Source, stats Stats) *Responder {
return &Responder{
Source: source,
stats: stats,
log: t,
}
}
func overrideHeaders(response http.ResponseWriter, headers http.Header) {
for k, v := range headers {
if len(v) == 1 {
response.Header().Set(k, v[0])
} else if len(v) > 1 {
response.Header().Del(k)
for _, e := range v {
response.Header().Add(k, e)
}
}
}
}
// hashToString contains mappings for the only hash functions
// x/crypto/ocsp supports
var hashToString = map[crypto.Hash]string{
crypto.SHA1: "SHA1",
crypto.SHA256: "SHA256",
crypto.SHA384: "SHA384",
crypto.SHA512: "SHA512",
}
// A Responder can process both GET and POST requests. The mapping
// from an OCSP request to an OCSP response is done by the Source;
// the Responder simply decodes the request, and passes back whatever
// response is provided by the source.
// Note: The caller must use http.StripPrefix to strip any path components
// (including '/') on GET requests.
// Do not use this responder in conjunction with http.NewServeMux, because the
// default handler will try to canonicalize path components by changing any
// strings of repeated '/' into a single '/', which will break the base64
// encoding.
func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) {
// By default we set a 'max-age=0, no-cache' Cache-Control header, this
// is only returned to the client if a valid authorized OCSP response
// is not found or an error is returned. If a response if found the header
// will be altered to contain the proper max-age and modifiers.
response.Header().Add("Cache-Control", "max-age=0, no-cache")
// Read response from request
var requestBody []byte
var err error
switch request.Method {
case "GET":
base64Request, err := url.QueryUnescape(request.URL.Path)
if err != nil {
rs.log.Log("Error decoding URL:", request.URL.Path)
response.WriteHeader(http.StatusBadRequest)
return
}
// url.QueryUnescape not only unescapes %2B escaping, but it additionally
// turns the resulting '+' into a space, which makes base64 decoding fail.
// So we go back afterwards and turn ' ' back into '+'. This means we
// accept some malformed input that includes ' ' or %20, but that's fine.
base64RequestBytes := []byte(base64Request)
for i := range base64RequestBytes {
if base64RequestBytes[i] == ' ' {
base64RequestBytes[i] = '+'
}
}
// In certain situations a UA may construct a request that has a double
// slash between the host name and the base64 request body due to naively
// constructing the request URL. In that case strip the leading slash
// so that we can still decode the request.
if len(base64RequestBytes) > 0 && base64RequestBytes[0] == '/' {
base64RequestBytes = base64RequestBytes[1:]
}
requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes))
if err != nil {
rs.log.Log("Error decoding base64 from URL", string(base64RequestBytes))
response.WriteHeader(http.StatusBadRequest)
return
}
case "POST":
requestBody, err = ioutil.ReadAll(request.Body)
if err != nil {
rs.log.Log("Problem reading body of POST", err)
response.WriteHeader(http.StatusBadRequest)
return
}
default:
response.WriteHeader(http.StatusMethodNotAllowed)
return
}
b64Body := base64.StdEncoding.EncodeToString(requestBody)
rs.log.Log("Received OCSP request", b64Body)
// All responses after this point will be OCSP.
// We could check for the content type of the request, but that
// seems unnecessariliy restrictive.
response.Header().Add("Content-Type", "application/ocsp-response")
// Parse response as an OCSP request
// XXX: This fails if the request contains the nonce extension.
// We don'log intend to support nonces anyway, but maybe we
// should return unauthorizedRequest instead of malformed.
ocspRequest, err := ocsp.ParseRequest(requestBody)
if err != nil {
rs.log.Log("Error decoding request body", b64Body)
response.WriteHeader(http.StatusBadRequest)
response.Write(malformedRequestErrorResponse)
if rs.stats != nil {
rs.stats.ResponseStatus(ocsp.Malformed)
}
return
}
// Look up OCSP response from source
ocspResponse, headers, err := rs.Source.Response(ocspRequest)
if err != nil {
if err == ErrNotFound {
rs.log.Log("No response found for request: serial %x, request body %s",
ocspRequest.SerialNumber, b64Body)
response.Write(unauthorizedErrorResponse)
if rs.stats != nil {
rs.stats.ResponseStatus(ocsp.Unauthorized)
}
return
}
rs.log.Log("Error retrieving response for request: serial %x, request body %s, error",
ocspRequest.SerialNumber, b64Body, err)
response.WriteHeader(http.StatusInternalServerError)
response.Write(internalErrorErrorResponse)
if rs.stats != nil {
rs.stats.ResponseStatus(ocsp.InternalError)
}
return
}
parsedResponse, err := ocsp.ParseResponse(ocspResponse, nil)
if err != nil {
rs.log.Log("Error parsing response for serial %x",
ocspRequest.SerialNumber, err)
response.Write(internalErrorErrorResponse)
if rs.stats != nil {
rs.stats.ResponseStatus(ocsp.InternalError)
}
return
}
// Write OCSP response to response
response.Header().Add("Last-Modified", parsedResponse.ThisUpdate.Format(time.RFC1123))
response.Header().Add("Expires", parsedResponse.NextUpdate.Format(time.RFC1123))
now := time.Now()
maxAge := 0
if now.Before(parsedResponse.NextUpdate) {
maxAge = int(parsedResponse.NextUpdate.Sub(now) / time.Second)
} else {
// TODO(#530): we want max-age=0 but this is technically an authorized OCSP response
// (despite being stale) and 5019 forbids attaching no-cache
maxAge = 0
}
response.Header().Set(
"Cache-Control",
fmt.Sprintf(
"max-age=%d, public, no-transform, must-revalidate",
maxAge,
),
)
responseHash := sha256.Sum256(ocspResponse)
response.Header().Add("ETag", fmt.Sprintf("\"%X\"", responseHash))
if headers != nil {
overrideHeaders(response, headers)
}
// RFC 7232 says that a 304 response must contain the above
// headers if they would also be sent for a 200 for the same
// request, so we have to wait until here to do this
if etag := request.Header.Get("If-None-Match"); etag != "" {
if etag == fmt.Sprintf("\"%X\"", responseHash) {
response.WriteHeader(http.StatusNotModified)
return
}
}
response.WriteHeader(http.StatusOK)
response.Write(ocspResponse)
if rs.stats != nil {
rs.stats.ResponseStatus(ocsp.Success)
}
}
/*
Copyright (c) 2014 CloudFlare Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

View File

@ -41,7 +41,7 @@ func TestOcsp_Disabled(t *testing.T) {
"ocsp_disable": "true",
})
requireSuccessNilResponse(t, resp, err)
resp, err = sendOcspRequest(t, b, s, localTT.reqType, testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, localTT.reqType, testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
require.NoError(t, err)
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 401, resp.Data["http_status_code"])
@ -63,7 +63,7 @@ func TestOcsp_UnknownIssuerWithNoDefault(t *testing.T) {
// Create another completely empty mount so the created issuer/certificate above is unknown
b, s := CreateBackendWithStorage(t)
resp, err := sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err := SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
require.NoError(t, err)
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 401, resp.Data["http_status_code"])
@ -85,7 +85,7 @@ func TestOcsp_WrongIssuerInRequest(t *testing.T) {
})
requireSuccessNonNilResponse(t, resp, err, "revoke")
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer2, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer2, crypto.SHA1)
require.NoError(t, err)
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -167,7 +167,7 @@ func TestOcsp_InvalidIssuerIdInRevocationEntry(t *testing.T) {
require.NoError(t, err, "failed writing out new revocation entry: %v", revEntry)
// Send the request
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
require.NoError(t, err)
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -220,7 +220,7 @@ func TestOcsp_UnknownIssuerIdWithDefaultHavingOcspUsageRemoved(t *testing.T) {
requireSuccessNonNilResponse(t, resp, err, "failed resetting usage flags on issuer2")
// Send the request
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
require.NoError(t, err)
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 401, resp.Data["http_status_code"])
@ -257,7 +257,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutOcspUsage(t *testing.T) {
require.False(t, usages.HasUsage(OCSPSigningUsage))
// Request an OCSP request from it, we should get an Unauthorized response back
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 401, resp.Data["http_status_code"])
@ -296,7 +296,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutAKey(t *testing.T) {
requireSuccessNonNilResponse(t, resp, err, "failed deleting key")
// Request an OCSP request from it, we should get an Unauthorized response back
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 401, resp.Data["http_status_code"])
@ -342,7 +342,7 @@ func TestOcsp_MultipleMatchingIssuersOneWithoutSigningUsage(t *testing.T) {
require.False(t, usages.HasUsage(OCSPSigningUsage))
// Request an OCSP request from it, we should get a Good response back, from the rotated cert
resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -410,7 +410,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe
b, s, testEnv := setupOcspEnvWithCaKeyConfig(t, caKeyType, caKeyBits, caKeySigBits)
// Non-revoked cert
resp, err := sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash)
resp, err := SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -435,7 +435,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe
})
requireSuccessNonNilResponse(t, resp, err, "revoke")
resp, err = sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash)
resp, err = SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request with revoked")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -455,7 +455,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe
requireOcspResponseSignedBy(t, ocspResp, testEnv.issuer1)
// Request status for our second issuer
resp, err = sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer2, testEnv.issuer2, requestHash)
resp, err = SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer2, testEnv.issuer2, requestHash)
requireSuccessNonNilResponse(t, resp, err, "ocsp get request")
requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body")
require.Equal(t, 200, resp.Data["http_status_code"])
@ -569,7 +569,7 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca
return b, s, testEnv
}
func sendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost string, cert, issuer *x509.Certificate, requestHash crypto.Hash) (*logical.Response, error) {
func SendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost string, cert, issuer *x509.Certificate, requestHash crypto.Hash) (*logical.Response, error) {
ocspRequest := generateRequest(t, requestHash, cert, issuer)
switch strings.ToLower(getOrPost) {
@ -578,7 +578,7 @@ func sendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost stri
case "post":
return sendOcspPostRequest(b, s, ocspRequest)
default:
t.Fatalf("unsupported value for sendOcspRequest getOrPost arg: %s", getOrPost)
t.Fatalf("unsupported value for SendOcspRequest getOrPost arg: %s", getOrPost)
}
return nil, nil
}

3
changelog/17093.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
auth/cert: Add configurable support for validating client certs with OCSP.
```

View File

@ -17,6 +17,7 @@ require (
github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-plugin v1.4.5
github.com/hashicorp/go-retryablehttp v0.5.3
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1
github.com/hashicorp/go-secure-stdlib/mlock v0.1.1
github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6
@ -45,6 +46,7 @@ require (
github.com/fatih/color v1.7.0 // indirect
github.com/frankban/quicktest v1.10.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.3.1 // indirect
github.com/hashicorp/go-cleanhttp v0.5.0 // indirect
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.6 // indirect

View File

@ -87,6 +87,7 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.0 h1:wvCrVc9TjDls6+YGAF2hAifE1E5U1+b4tH6KdvN3Gig=
github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
github.com/hashicorp/go-hclog v0.16.2 h1:K4ev2ib4LdQETX5cSZBG0DVLk1jwGqSPXBjdah3veNs=
github.com/hashicorp/go-hclog v0.16.2/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
@ -100,6 +101,7 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-plugin v1.4.5 h1:oTE/oQR4eghggRg8VY7PAz3dr++VwDNBGCcOfIvHpBo=
github.com/hashicorp/go-plugin v1.4.5/go.mod h1:viDMjcLJuDui6pXb8U4HVfb8AamCWhHGUjr2IrTF67s=
github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s=
github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1 h1:6KMBnfEv0/kLAz0O76sliN5mXbCDcLfs2kP7ssP7+DQ=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=

1059
sdk/helper/ocsp/client.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,530 @@
// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
package ocsp
import (
"bytes"
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-retryablehttp"
lru "github.com/hashicorp/golang-lru"
"golang.org/x/crypto/ocsp"
)
func TestOCSP(t *testing.T) {
targetURL := []string{
"https://sfcdev1.blob.core.windows.net/",
"https://sfctest0.snowflakecomputing.com/",
"https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64",
}
conf := VerifyConfig{
OcspFailureMode: FailOpenFalse,
}
c := New(testLogFactory, 10)
transports := []*http.Transport{
newInsecureOcspTransport(nil),
c.NewTransport(&conf),
}
for _, tgt := range targetURL {
c.ocspResponseCache, _ = lru.New2Q(10)
for _, tr := range transports {
c := &http.Client{
Transport: tr,
Timeout: 30 * time.Second,
}
req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
if err != nil {
t.Fatalf("fail to create a request. err: %v", err)
}
res, err := c.Do(req)
if err != nil {
t.Fatalf("failed to GET contents. err: %v", err)
}
defer res.Body.Close()
_, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read content body for %v", tgt)
}
}
}
}
/**
// Used for development, requires an active Vault with PKI setup
func TestMultiOCSP(t *testing.T) {
targetURL := []string{
"https://localhost:8200/v1/pki/ocsp",
"https://localhost:8200/v1/pki/ocsp",
"https://localhost:8200/v1/pki/ocsp",
}
b, _ := pem.Decode([]byte(vaultCert))
caCert, _ := x509.ParseCertificate(b.Bytes)
conf := VerifyConfig{
OcspFailureMode: FailOpenFalse,
QueryAllServers: true,
OcspServersOverride: targetURL,
ExtraCas: []*x509.Certificate{caCert},
}
c := New(testLogFactory, 10)
transports := []*http.Transport{
newInsecureOcspTransport(conf.ExtraCas),
c.NewTransport(&conf),
}
tgt := "https://localhost:8200/v1/pki/ca/pem"
c.ocspResponseCache, _ = lru.New2Q(10)
for _, tr := range transports {
c := &http.Client{
Transport: tr,
Timeout: 30 * time.Second,
}
req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
if err != nil {
t.Fatalf("fail to create a request. err: %v", err)
}
res, err := c.Do(req)
if err != nil {
t.Fatalf("failed to GET contents. err: %v", err)
}
defer res.Body.Close()
_, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read content body for %v", tgt)
}
}
}
*/
func TestUnitEncodeCertIDGood(t *testing.T) {
targetURLs := []string{
"faketestaccount.snowflakecomputing.com:443",
"s3-us-west-2.amazonaws.com:443",
"sfcdev1.blob.core.windows.net:443",
}
for _, tt := range targetURLs {
chainedCerts := getCert(tt)
for i := 0; i < len(chainedCerts)-1; i++ {
subject := chainedCerts[i]
issuer := chainedCerts[i+1]
ocspServers := subject.OCSPServer
if len(ocspServers) == 0 {
t.Fatalf("no OCSP server is found. cert: %v", subject.Subject)
}
ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
if err != nil {
t.Fatalf("failed to create OCSP request. err: %v", err)
}
var ost *ocspStatus
_, ost = extractCertIDKeyFromRequest(ocspReq)
if ost.err != nil {
t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
}
// better hash. Not sure if the actual OCSP server accepts this, though.
ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
if err != nil {
t.Fatalf("failed to create OCSP request. err: %v", err)
}
_, ost = extractCertIDKeyFromRequest(ocspReq)
if ost.err != nil {
t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
}
// tweaked request binary
ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
if err != nil {
t.Fatalf("failed to create OCSP request. err: %v", err)
}
ocspReq[10] = 0 // random change
_, ost = extractCertIDKeyFromRequest(ocspReq)
if ost.err == nil {
t.Fatal("should have failed")
}
}
}
}
func TestUnitCheckOCSPResponseCache(t *testing.T) {
c := New(testLogFactory, 10)
dummyKey0 := certIDKey{
NameHash: "dummy0",
IssuerKeyHash: "dummy0",
SerialNumber: "dummy0",
}
dummyKey := certIDKey{
NameHash: "dummy1",
IssuerKeyHash: "dummy1",
SerialNumber: "dummy1",
}
currentTime := float64(time.Now().UTC().Unix())
c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime})
subject := &x509.Certificate{}
issuer := &x509.Certificate{}
ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspMissedCache {
t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
}
// old timestamp
c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)})
ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspCacheExpired {
t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
}
// invalid validity
c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)})
ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil)
if err == nil && isValidOCSPStatus(ost.code) {
t.Fatalf("should have failed.")
}
}
func TestUnitValidateOCSP(t *testing.T) {
ocspRes := &ocsp.Response{}
ost, err := validateOCSP(ocspRes)
if err == nil && isValidOCSPStatus(ost.code) {
t.Fatalf("should have failed.")
}
currentTime := time.Now()
ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour)
ocspRes.NextUpdate = currentTime.Add(2 * time.Hour)
ocspRes.Status = ocsp.Revoked
ost, err = validateOCSP(ocspRes)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspStatusRevoked {
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code)
}
ocspRes.Status = ocsp.Good
ost, err = validateOCSP(ocspRes)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspStatusGood {
t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code)
}
ocspRes.Status = ocsp.Unknown
ost, err = validateOCSP(ocspRes)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspStatusUnknown {
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code)
}
ocspRes.Status = ocsp.ServerFailed
ost, err = validateOCSP(ocspRes)
if err != nil {
t.Fatal(err)
}
if ost.code != ocspStatusOthers {
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code)
}
}
func TestUnitEncodeCertID(t *testing.T) {
var st *ocspStatus
_, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2})
if st.code != ocspFailedDecomposeRequest {
t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code)
}
}
func getCert(addr string) []*x509.Certificate {
tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second)
if err != nil {
panic(err)
}
defer tcpConn.Close()
err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
panic(err)
}
config := tls.Config{InsecureSkipVerify: true, ServerName: addr}
conn := tls.Client(tcpConn, &config)
defer conn.Close()
err = conn.Handshake()
if err != nil {
panic(err)
}
state := conn.ConnectionState()
return state.PeerCertificates
}
func TestOCSPRetry(t *testing.T) {
c := New(testLogFactory, 10)
certs := getCert("s3-us-west-2.amazonaws.com:443")
dummyOCSPHost := &url.URL{
Scheme: "https",
Host: "dummyOCSPHost",
}
client := &fakeHTTPClient{
cnt: 3,
success: true,
body: []byte{1, 2, 3},
logger: hclog.New(hclog.DefaultOptions),
t: t,
}
res, b, st, err := c.retryOCSP(
context.TODO(),
client, fakeRequestFunc,
dummyOCSPHost,
make(map[string]string), []byte{0}, certs[len(certs)-1])
if err == nil {
fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
}
client = &fakeHTTPClient{
cnt: 30,
success: true,
body: []byte{1, 2, 3},
logger: hclog.New(hclog.DefaultOptions),
t: t,
}
res, b, st, err = c.retryOCSP(
context.TODO(),
client, fakeRequestFunc,
dummyOCSPHost,
make(map[string]string), []byte{0}, certs[len(certs)-1])
if err == nil {
fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
}
}
type tcCanEarlyExit struct {
results []*ocspStatus
resultLen int
retFailOpen *ocspStatus
retFailClosed *ocspStatus
}
func TestCanEarlyExitForOCSP(t *testing.T) {
testcases := []tcCanEarlyExit{
{ // 0
results: []*ocspStatus{
{
code: ocspStatusGood,
},
{
code: ocspStatusGood,
},
{
code: ocspStatusGood,
},
},
retFailOpen: nil,
retFailClosed: nil,
},
{ // 1
results: []*ocspStatus{
{
code: ocspStatusRevoked,
err: errors.New("revoked"),
},
{
code: ocspStatusGood,
},
{
code: ocspStatusGood,
},
},
retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
},
{ // 2
results: []*ocspStatus{
{
code: ocspStatusUnknown,
err: errors.New("unknown"),
},
{
code: ocspStatusGood,
},
{
code: ocspStatusGood,
},
},
retFailOpen: nil,
retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")},
},
{ // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included.
results: []*ocspStatus{
{
code: ocspStatusRevoked,
err: errors.New("revoked"),
},
{
code: ocspInvalidValidity,
},
{
code: ocspStatusGood,
},
},
retFailOpen: nil,
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
},
{ // 4: not taken as revoked if the number of results don't match the expected results.
results: []*ocspStatus{
{
code: ocspStatusRevoked,
err: errors.New("revoked"),
},
{
code: ocspStatusGood,
},
},
resultLen: 3,
retFailOpen: nil,
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
},
}
c := New(testLogFactory, 10)
for idx, tt := range testcases {
expectedLen := len(tt.results)
if tt.resultLen > 0 {
expectedLen = tt.resultLen
}
r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue})
if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) {
t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r)
}
r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse})
if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) {
t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r)
}
}
}
var testLogger = hclog.New(hclog.DefaultOptions)
func testLogFactory() hclog.Logger {
return testLogger
}
type fakeHTTPClient struct {
cnt int // number of retry
success bool // return success after retry in cnt times
timeout bool // timeout
body []byte // return body
t *testing.T
logger hclog.Logger
redirected bool
}
func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) {
c.cnt--
if c.cnt < 0 {
c.cnt = 0
}
c.t.Log("fakeHTTPClient.cnt", c.cnt)
var retcode int
if !c.redirected {
c.redirected = true
c.cnt++
retcode = 405
} else if c.success && c.cnt == 1 {
retcode = 200
} else {
if c.timeout {
// simulate timeout
time.Sleep(time.Second * 1)
return nil, &fakeHTTPError{
err: "Whatever reason (Client.Timeout exceeded while awaiting headers)",
timeout: true,
}
}
retcode = 0
}
ret := &http.Response{
StatusCode: retcode,
Body: &fakeResponseBody{body: c.body},
}
return ret, nil
}
type fakeHTTPError struct {
err string
timeout bool
}
func (e *fakeHTTPError) Error() string { return e.err }
func (e *fakeHTTPError) Timeout() bool { return e.timeout }
func (e *fakeHTTPError) Temporary() bool { return true }
type fakeResponseBody struct {
body []byte
cnt int
}
func (b *fakeResponseBody) Read(p []byte) (n int, err error) {
if b.cnt == 0 {
copy(p, b.body)
b.cnt = 1
return len(b.body), nil
}
b.cnt = 0
return 0, io.EOF
}
func (b *fakeResponseBody) Close() error {
return nil
}
func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) {
return nil, nil
}
const vaultCert = `-----BEGIN CERTIFICATE-----
MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL
BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0
eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD
VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx
OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t
ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa
MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB
DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5
LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2
qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc
d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP
u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO
lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE
MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ
bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM
497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/
K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx
qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh
tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z
nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H
-----END CERTIFICATE-----
`