f0279b0d31
This prevents SDK from having a circular import on the main Vault package. Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
531 lines
14 KiB
Go
531 lines
14 KiB
Go
// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
|
|
|
|
package ocsp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-retryablehttp"
|
|
lru "github.com/hashicorp/golang-lru"
|
|
"golang.org/x/crypto/ocsp"
|
|
)
|
|
|
|
func TestOCSP(t *testing.T) {
|
|
targetURL := []string{
|
|
"https://sfcdev1.blob.core.windows.net/",
|
|
"https://sfctest0.snowflakecomputing.com/",
|
|
"https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64",
|
|
}
|
|
|
|
conf := VerifyConfig{
|
|
OcspFailureMode: FailOpenFalse,
|
|
}
|
|
c := New(testLogFactory, 10)
|
|
transports := []*http.Transport{
|
|
newInsecureOcspTransport(nil),
|
|
c.NewTransport(&conf),
|
|
}
|
|
|
|
for _, tgt := range targetURL {
|
|
c.ocspResponseCache, _ = lru.New2Q(10)
|
|
for _, tr := range transports {
|
|
c := &http.Client{
|
|
Transport: tr,
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
|
|
if err != nil {
|
|
t.Fatalf("fail to create a request. err: %v", err)
|
|
}
|
|
res, err := c.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("failed to GET contents. err: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
_, err = ioutil.ReadAll(res.Body)
|
|
if err != nil {
|
|
t.Fatalf("failed to read content body for %v", tgt)
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
// Used for development, requires an active Vault with PKI setup
|
|
func TestMultiOCSP(t *testing.T) {
|
|
|
|
targetURL := []string{
|
|
"https://localhost:8200/v1/pki/ocsp",
|
|
"https://localhost:8200/v1/pki/ocsp",
|
|
"https://localhost:8200/v1/pki/ocsp",
|
|
}
|
|
|
|
b, _ := pem.Decode([]byte(vaultCert))
|
|
caCert, _ := x509.ParseCertificate(b.Bytes)
|
|
conf := VerifyConfig{
|
|
OcspFailureMode: FailOpenFalse,
|
|
QueryAllServers: true,
|
|
OcspServersOverride: targetURL,
|
|
ExtraCas: []*x509.Certificate{caCert},
|
|
}
|
|
c := New(testLogFactory, 10)
|
|
transports := []*http.Transport{
|
|
newInsecureOcspTransport(conf.ExtraCas),
|
|
c.NewTransport(&conf),
|
|
}
|
|
|
|
tgt := "https://localhost:8200/v1/pki/ca/pem"
|
|
c.ocspResponseCache, _ = lru.New2Q(10)
|
|
for _, tr := range transports {
|
|
c := &http.Client{
|
|
Transport: tr,
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
|
|
if err != nil {
|
|
t.Fatalf("fail to create a request. err: %v", err)
|
|
}
|
|
res, err := c.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("failed to GET contents. err: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
_, err = ioutil.ReadAll(res.Body)
|
|
if err != nil {
|
|
t.Fatalf("failed to read content body for %v", tgt)
|
|
}
|
|
}
|
|
}
|
|
*/
|
|
|
|
func TestUnitEncodeCertIDGood(t *testing.T) {
|
|
targetURLs := []string{
|
|
"faketestaccount.snowflakecomputing.com:443",
|
|
"s3-us-west-2.amazonaws.com:443",
|
|
"sfcdev1.blob.core.windows.net:443",
|
|
}
|
|
for _, tt := range targetURLs {
|
|
chainedCerts := getCert(tt)
|
|
for i := 0; i < len(chainedCerts)-1; i++ {
|
|
subject := chainedCerts[i]
|
|
issuer := chainedCerts[i+1]
|
|
ocspServers := subject.OCSPServer
|
|
if len(ocspServers) == 0 {
|
|
t.Fatalf("no OCSP server is found. cert: %v", subject.Subject)
|
|
}
|
|
ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
|
|
if err != nil {
|
|
t.Fatalf("failed to create OCSP request. err: %v", err)
|
|
}
|
|
var ost *ocspStatus
|
|
_, ost = extractCertIDKeyFromRequest(ocspReq)
|
|
if ost.err != nil {
|
|
t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
|
|
}
|
|
// better hash. Not sure if the actual OCSP server accepts this, though.
|
|
ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
|
|
if err != nil {
|
|
t.Fatalf("failed to create OCSP request. err: %v", err)
|
|
}
|
|
_, ost = extractCertIDKeyFromRequest(ocspReq)
|
|
if ost.err != nil {
|
|
t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
|
|
}
|
|
// tweaked request binary
|
|
ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
|
|
if err != nil {
|
|
t.Fatalf("failed to create OCSP request. err: %v", err)
|
|
}
|
|
ocspReq[10] = 0 // random change
|
|
_, ost = extractCertIDKeyFromRequest(ocspReq)
|
|
if ost.err == nil {
|
|
t.Fatal("should have failed")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestUnitCheckOCSPResponseCache(t *testing.T) {
|
|
c := New(testLogFactory, 10)
|
|
dummyKey0 := certIDKey{
|
|
NameHash: "dummy0",
|
|
IssuerKeyHash: "dummy0",
|
|
SerialNumber: "dummy0",
|
|
}
|
|
dummyKey := certIDKey{
|
|
NameHash: "dummy1",
|
|
IssuerKeyHash: "dummy1",
|
|
SerialNumber: "dummy1",
|
|
}
|
|
currentTime := float64(time.Now().UTC().Unix())
|
|
c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime})
|
|
subject := &x509.Certificate{}
|
|
issuer := &x509.Certificate{}
|
|
ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ost.code != ocspMissedCache {
|
|
t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
|
|
}
|
|
// old timestamp
|
|
c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)})
|
|
ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ost.code != ocspCacheExpired {
|
|
t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
|
|
}
|
|
|
|
// invalid validity
|
|
c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)})
|
|
ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil)
|
|
if err == nil && isValidOCSPStatus(ost.code) {
|
|
t.Fatalf("should have failed.")
|
|
}
|
|
}
|
|
|
|
func TestUnitValidateOCSP(t *testing.T) {
|
|
ocspRes := &ocsp.Response{}
|
|
ost, err := validateOCSP(ocspRes)
|
|
if err == nil && isValidOCSPStatus(ost.code) {
|
|
t.Fatalf("should have failed.")
|
|
}
|
|
|
|
currentTime := time.Now()
|
|
ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour)
|
|
ocspRes.NextUpdate = currentTime.Add(2 * time.Hour)
|
|
ocspRes.Status = ocsp.Revoked
|
|
ost, err = validateOCSP(ocspRes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if ost.code != ocspStatusRevoked {
|
|
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code)
|
|
}
|
|
ocspRes.Status = ocsp.Good
|
|
ost, err = validateOCSP(ocspRes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if ost.code != ocspStatusGood {
|
|
t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code)
|
|
}
|
|
ocspRes.Status = ocsp.Unknown
|
|
ost, err = validateOCSP(ocspRes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ost.code != ocspStatusUnknown {
|
|
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code)
|
|
}
|
|
ocspRes.Status = ocsp.ServerFailed
|
|
ost, err = validateOCSP(ocspRes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ost.code != ocspStatusOthers {
|
|
t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code)
|
|
}
|
|
}
|
|
|
|
func TestUnitEncodeCertID(t *testing.T) {
|
|
var st *ocspStatus
|
|
_, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2})
|
|
if st.code != ocspFailedDecomposeRequest {
|
|
t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code)
|
|
}
|
|
}
|
|
|
|
// getCert dials addr (a "host:port" string), completes a TLS handshake
// without verifying the peer, and returns the certificate chain the server
// presented. It panics on any dial, deadline, or handshake error; it is a
// test helper only.
func getCert(addr string) []*x509.Certificate {
	tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second)
	if err != nil {
		panic(err)
	}
	defer tcpConn.Close()

	if err := tcpConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
		panic(err)
	}

	// ServerName is sent as SNI and must be a bare host name. Passing
	// "host:port" puts the port inside the SNI extension, which is invalid
	// and, for IP literals, prevents crypto/tls from omitting SNI as it should.
	serverName, _, err := net.SplitHostPort(addr)
	if err != nil {
		// No port present; use addr as-is.
		serverName = addr
	}
	config := &tls.Config{InsecureSkipVerify: true, ServerName: serverName}

	conn := tls.Client(tcpConn, config)
	defer conn.Close()

	if err := conn.Handshake(); err != nil {
		panic(err)
	}

	return conn.ConnectionState().PeerCertificates
}
|
|
|
|
func TestOCSPRetry(t *testing.T) {
|
|
c := New(testLogFactory, 10)
|
|
certs := getCert("s3-us-west-2.amazonaws.com:443")
|
|
dummyOCSPHost := &url.URL{
|
|
Scheme: "https",
|
|
Host: "dummyOCSPHost",
|
|
}
|
|
client := &fakeHTTPClient{
|
|
cnt: 3,
|
|
success: true,
|
|
body: []byte{1, 2, 3},
|
|
logger: hclog.New(hclog.DefaultOptions),
|
|
t: t,
|
|
}
|
|
res, b, st, err := c.retryOCSP(
|
|
context.TODO(),
|
|
client, fakeRequestFunc,
|
|
dummyOCSPHost,
|
|
make(map[string]string), []byte{0}, certs[len(certs)-1])
|
|
if err == nil {
|
|
fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
|
|
}
|
|
client = &fakeHTTPClient{
|
|
cnt: 30,
|
|
success: true,
|
|
body: []byte{1, 2, 3},
|
|
logger: hclog.New(hclog.DefaultOptions),
|
|
t: t,
|
|
}
|
|
res, b, st, err = c.retryOCSP(
|
|
context.TODO(),
|
|
client, fakeRequestFunc,
|
|
dummyOCSPHost,
|
|
make(map[string]string), []byte{0}, certs[len(certs)-1])
|
|
if err == nil {
|
|
fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
|
|
}
|
|
}
|
|
|
|
type tcCanEarlyExit struct {
|
|
results []*ocspStatus
|
|
resultLen int
|
|
retFailOpen *ocspStatus
|
|
retFailClosed *ocspStatus
|
|
}
|
|
|
|
func TestCanEarlyExitForOCSP(t *testing.T) {
|
|
testcases := []tcCanEarlyExit{
|
|
{ // 0
|
|
results: []*ocspStatus{
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
},
|
|
retFailOpen: nil,
|
|
retFailClosed: nil,
|
|
},
|
|
{ // 1
|
|
results: []*ocspStatus{
|
|
{
|
|
code: ocspStatusRevoked,
|
|
err: errors.New("revoked"),
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
},
|
|
retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
|
|
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
|
|
},
|
|
{ // 2
|
|
results: []*ocspStatus{
|
|
{
|
|
code: ocspStatusUnknown,
|
|
err: errors.New("unknown"),
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
},
|
|
retFailOpen: nil,
|
|
retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")},
|
|
},
|
|
{ // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included.
|
|
results: []*ocspStatus{
|
|
{
|
|
code: ocspStatusRevoked,
|
|
err: errors.New("revoked"),
|
|
},
|
|
{
|
|
code: ocspInvalidValidity,
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
},
|
|
retFailOpen: nil,
|
|
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
|
|
},
|
|
{ // 4: not taken as revoked if the number of results don't match the expected results.
|
|
results: []*ocspStatus{
|
|
{
|
|
code: ocspStatusRevoked,
|
|
err: errors.New("revoked"),
|
|
},
|
|
{
|
|
code: ocspStatusGood,
|
|
},
|
|
},
|
|
resultLen: 3,
|
|
retFailOpen: nil,
|
|
retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
|
|
},
|
|
}
|
|
c := New(testLogFactory, 10)
|
|
for idx, tt := range testcases {
|
|
expectedLen := len(tt.results)
|
|
if tt.resultLen > 0 {
|
|
expectedLen = tt.resultLen
|
|
}
|
|
r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue})
|
|
if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) {
|
|
t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r)
|
|
}
|
|
r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse})
|
|
if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) {
|
|
t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
var testLogger = hclog.New(hclog.DefaultOptions)
|
|
|
|
func testLogFactory() hclog.Logger {
|
|
return testLogger
|
|
}
|
|
|
|
type fakeHTTPClient struct {
|
|
cnt int // number of retry
|
|
success bool // return success after retry in cnt times
|
|
timeout bool // timeout
|
|
body []byte // return body
|
|
t *testing.T
|
|
logger hclog.Logger
|
|
redirected bool
|
|
}
|
|
|
|
func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) {
|
|
c.cnt--
|
|
if c.cnt < 0 {
|
|
c.cnt = 0
|
|
}
|
|
c.t.Log("fakeHTTPClient.cnt", c.cnt)
|
|
|
|
var retcode int
|
|
if !c.redirected {
|
|
c.redirected = true
|
|
c.cnt++
|
|
retcode = 405
|
|
} else if c.success && c.cnt == 1 {
|
|
retcode = 200
|
|
} else {
|
|
if c.timeout {
|
|
// simulate timeout
|
|
time.Sleep(time.Second * 1)
|
|
return nil, &fakeHTTPError{
|
|
err: "Whatever reason (Client.Timeout exceeded while awaiting headers)",
|
|
timeout: true,
|
|
}
|
|
}
|
|
retcode = 0
|
|
}
|
|
|
|
ret := &http.Response{
|
|
StatusCode: retcode,
|
|
Body: &fakeResponseBody{body: c.body},
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// fakeHTTPError is an error exposing Timeout and Temporary methods (the
// net.Error method set), used to simulate client timeouts.
type fakeHTTPError struct {
	err     string // text returned by Error
	timeout bool   // value returned by Timeout
}

// Error returns the configured message.
func (e *fakeHTTPError) Error() string {
	return e.err
}

// Timeout reports the configured timeout flag.
func (e *fakeHTTPError) Timeout() bool {
	return e.timeout
}

// Temporary always reports true.
func (e *fakeHTTPError) Temporary() bool {
	return true
}
|
|
|
|
// fakeResponseBody is an in-memory io.ReadCloser that serves body. After it
// reports io.EOF it rewinds, so the same body can be read again.
type fakeResponseBody struct {
	body []byte
	cnt  int // number of body bytes already returned by Read
}

// Read copies at most len(p) unread bytes of body into p, honoring the
// io.Reader contract (0 <= n <= len(p)). Once body is exhausted it returns
// (0, io.EOF) and resets the read position to the start.
func (b *fakeResponseBody) Read(p []byte) (n int, err error) {
	if b.cnt >= len(b.body) {
		b.cnt = 0
		return 0, io.EOF
	}
	n = copy(p, b.body[b.cnt:])
	b.cnt += n
	return n, nil
}

// Close is a no-op; there is nothing to release.
func (b *fakeResponseBody) Close() error {
	return nil
}
|
|
|
|
func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// vaultCert is a PEM-encoded certificate. It is used only by the commented-out
// TestMultiOCSP development test in this file, which parses it and passes it
// as an extra CA (VerifyConfig.ExtraCas).
const vaultCert = `-----BEGIN CERTIFICATE-----
MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL
BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0
eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD
VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx
OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t
ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa
MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB
DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5
LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2
qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc
d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP
u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO
lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE
MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ
bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM
497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/
K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx
qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh
tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z
nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H
-----END CERTIFICATE-----
`
|