open-vault/builtin/logical/pki/util.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

504 lines
13 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
"crypto"
"crypto/x509"
"fmt"
"math/big"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
managedKeyNameArg = "managed_key_name"
managedKeyIdArg = "managed_key_id"
defaultRef = "default"
// Constants for If-Modified-Since operation
headerIfModifiedSince = "If-Modified-Since"
headerLastModified = "Last-Modified"
)
var (
nameMatcher = regexp.MustCompile("^" + framework.GenericNameRegex(issuerRefParam) + "$")
errIssuerNameInUse = errutil.UserError{Err: "issuer name already in use"}
errIssuerNameIsEmpty = errutil.UserError{Err: "expected non-empty issuer name"}
errKeyNameInUse = errutil.UserError{Err: "key name already in use"}
)
func serialFromCert(cert *x509.Certificate) string {
return serialFromBigInt(cert.SerialNumber)
}
func serialFromBigInt(serial *big.Int) string {
return strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":"))
}
func normalizeSerialFromBigInt(serial *big.Int) string {
return strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), "-"))
}
func normalizeSerial(serial string) string {
return strings.ReplaceAll(strings.ToLower(serial), ":", "-")
}
func denormalizeSerial(serial string) string {
return strings.ReplaceAll(strings.ToLower(serial), "-", ":")
}
func serialToBigInt(serial string) (*big.Int, bool) {
norm := normalizeSerial(serial)
hex := strings.ReplaceAll(norm, "-", "")
return big.NewInt(0).SetString(hex, 16)
}
func kmsRequested(input *inputBundle) bool {
return kmsRequestedFromFieldData(input.apiData)
}
func kmsRequestedFromFieldData(data *framework.FieldData) bool {
exportedStr, ok := data.GetOk("exported")
if !ok {
return false
}
return exportedStr.(string) == "kms"
}
func existingKeyRequested(input *inputBundle) bool {
return existingKeyRequestedFromFieldData(input.apiData)
}
func existingKeyRequestedFromFieldData(data *framework.FieldData) bool {
exportedStr, ok := data.GetOk("exported")
if !ok {
return false
}
return exportedStr.(string) == "existing"
}
type managedKeyId interface {
String() string
}
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}
type managedKeyInfo struct {
publicKey crypto.PublicKey
keyType certutil.PrivateKeyType
name NameKey
uuid UUIDKey
}
// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
// request API data.
func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
name, UUID, err := getManagedKeyNameOrUUID(data)
if err != nil {
return nil, err
}
var keyId managedKeyId = NameKey(name)
if len(UUID) > 0 {
keyId = UUIDKey(UUID)
}
return keyId, nil
}
func getKeyRefWithErr(data *framework.FieldData) (string, error) {
keyRef := getKeyRef(data)
if len(keyRef) == 0 {
return "", errutil.UserError{Err: "missing argument key_ref for existing type"}
}
return keyRef, nil
}
func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) {
getApiData := func(argName string) (string, error) {
arg, ok := data.GetOk(argName)
if !ok {
return "", nil
}
argValue, ok := arg.(string)
if !ok {
return "", errutil.UserError{Err: fmt.Sprintf("invalid type for argument %s", argName)}
}
return strings.TrimSpace(argValue), nil
}
keyName, err := getApiData(managedKeyNameArg)
keyUUID, err2 := getApiData(managedKeyIdArg)
switch {
case err != nil:
return "", "", err
case err2 != nil:
return "", "", err2
case len(keyName) == 0 && len(keyUUID) == 0:
return "", "", errutil.UserError{Err: fmt.Sprintf("missing argument %s or %s", managedKeyNameArg, managedKeyIdArg)}
case len(keyName) > 0 && len(keyUUID) > 0:
return "", "", errutil.UserError{Err: fmt.Sprintf("only one argument of %s or %s should be specified", managedKeyNameArg, managedKeyIdArg)}
}
return keyName, keyUUID, nil
}
func getIssuerName(sc *storageContext, data *framework.FieldData) (string, error) {
issuerName := ""
issuerNameIface, ok := data.GetOk("issuer_name")
if ok {
issuerName = strings.TrimSpace(issuerNameIface.(string))
if len(issuerName) == 0 {
return issuerName, errIssuerNameIsEmpty
}
if strings.ToLower(issuerName) == defaultRef {
return issuerName, errutil.UserError{Err: "reserved keyword 'default' can not be used as issuer name"}
}
if !nameMatcher.MatchString(issuerName) {
return issuerName, errutil.UserError{Err: "issuer name contained invalid characters"}
}
issuerId, err := sc.resolveIssuerReference(issuerName)
if err == nil {
return issuerName, errIssuerNameInUse
}
if err != nil && issuerId != IssuerRefNotFound {
return issuerName, errutil.InternalError{Err: err.Error()}
}
}
return issuerName, nil
}
func getKeyName(sc *storageContext, data *framework.FieldData) (string, error) {
keyName := ""
keyNameIface, ok := data.GetOk(keyNameParam)
if ok {
keyName = strings.TrimSpace(keyNameIface.(string))
if strings.ToLower(keyName) == defaultRef {
return "", errutil.UserError{Err: "reserved keyword 'default' can not be used as key name"}
}
if !nameMatcher.MatchString(keyName) {
return "", errutil.UserError{Err: "key name contained invalid characters"}
}
keyId, err := sc.resolveKeyReference(keyName)
if err == nil {
return "", errKeyNameInUse
}
if err != nil && keyId != KeyRefNotFound {
return "", errutil.InternalError{Err: err.Error()}
}
}
return keyName, nil
}
func getIssuerRef(data *framework.FieldData) string {
return extractRef(data, issuerRefParam)
}
func getKeyRef(data *framework.FieldData) string {
return extractRef(data, keyRefParam)
}
func extractRef(data *framework.FieldData, paramName string) string {
value := strings.TrimSpace(data.Get(paramName).(string))
if strings.EqualFold(value, defaultRef) {
return defaultRef
}
return value
}
func isStringArrayDifferent(a, b []string) bool {
if len(a) != len(b) {
return true
}
for i, v := range a {
if v != b[i] {
return true
}
}
return false
}
func hasHeader(header string, req *logical.Request) bool {
var hasHeader bool
headerValue := req.Headers[header]
if len(headerValue) > 0 {
hasHeader = true
}
return hasHeader
}
func parseIfNotModifiedSince(req *logical.Request) (time.Time, error) {
var headerTimeValue time.Time
headerValue := req.Headers[headerIfModifiedSince]
headerTimeValue, err := time.Parse(time.RFC1123, headerValue[0])
if err != nil {
return headerTimeValue, fmt.Errorf("failed to parse given value for '%s' header: %w", headerIfModifiedSince, err)
}
return headerTimeValue, nil
}
type ifModifiedReqType int
const (
ifModifiedUnknown ifModifiedReqType = iota
ifModifiedCA = iota
ifModifiedCRL = iota
ifModifiedDeltaCRL = iota
ifModifiedUnifiedCRL = iota
ifModifiedUnifiedDeltaCRL = iota
)
type IfModifiedSinceHelper struct {
req *logical.Request
reqType ifModifiedReqType
issuerRef issuerID
}
func sendNotModifiedResponseIfNecessary(helper *IfModifiedSinceHelper, sc *storageContext, resp *logical.Response) (bool, error) {
responseHeaders := map[string][]string{}
if !hasHeader(headerIfModifiedSince, helper.req) {
return false, nil
}
before, err := sc.isIfModifiedSinceBeforeLastModified(helper, responseHeaders)
if err != nil {
return false, err
}
if !before {
return false, nil
}
// Fill response
resp.Data = map[string]interface{}{
logical.HTTPContentType: "",
logical.HTTPStatusCode: 304,
}
resp.Headers = responseHeaders
return true, nil
}
func (sc *storageContext) isIfModifiedSinceBeforeLastModified(helper *IfModifiedSinceHelper, responseHeaders map[string][]string) (bool, error) {
// False return --> we were last modified _before_ the requester's
// time --> keep using the cached copy and return 304.
var err error
var lastModified time.Time
ifModifiedSince, err := parseIfNotModifiedSince(helper.req)
if err != nil {
return false, err
}
switch helper.reqType {
case ifModifiedCRL, ifModifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() {
// When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL....
//
// We do this earlier, ahead of config load, as it saves us a
// potential error condition.
return false, nil
}
crlConfig, err := sc.getLocalCRLConfig()
if err != nil {
return false, err
}
lastModified = crlConfig.LastModified
if helper.reqType == ifModifiedDeltaCRL {
lastModified = crlConfig.DeltaLastModified
}
case ifModifiedUnifiedCRL, ifModifiedUnifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() {
// When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL....
//
// We do this earlier, ahead of config load, as it saves us a
// potential error condition.
return false, nil
}
crlConfig, err := sc.getUnifiedCRLConfig()
if err != nil {
return false, err
}
lastModified = crlConfig.LastModified
if helper.reqType == ifModifiedUnifiedDeltaCRL {
lastModified = crlConfig.DeltaLastModified
}
case ifModifiedCA:
issuerId, err := sc.resolveIssuerReference(string(helper.issuerRef))
if err != nil {
return false, err
}
issuer, err := sc.fetchIssuerById(issuerId)
if err != nil {
return false, err
}
lastModified = issuer.LastModified
default:
return false, fmt.Errorf("unknown if-modified-since request type: %v", helper.reqType)
}
if !lastModified.IsZero() && lastModified.Before(ifModifiedSince) {
responseHeaders[headerLastModified] = []string{lastModified.Format(http.TimeFormat)}
return true, nil
}
return false, nil
}
func addWarnings(resp *logical.Response, warnings []string) *logical.Response {
for _, warning := range warnings {
resp.AddWarning(warning)
}
return resp
}
// revocationQueue is a type for allowing invalidateFunc to continue operating
// quickly, while letting periodicFunc slowly sort through all open
// revocations to process. In particular, we do not wish to be holding this
// lock while periodicFunc is running, so iteration returns a full copy of
// the data in this queue. We use a map from serial->[]clusterId, allowing us
// to quickly insert and remove items, without using a slice of tuples. One
// serial might be present on two clusters, if two clusters both have the cert
// stored locally (e.g., via BYOC), which would result in two confirmation
// entries and thus dictating the need for []clusterId. This also lets us
// avoid having duplicate entries.
type revocationQueue struct {
_l sync.Mutex
queue map[string][]string
}
func newRevocationQueue() *revocationQueue {
return &revocationQueue{
queue: make(map[string][]string),
}
}
func (q *revocationQueue) Add(items ...*revocationQueueEntry) {
q._l.Lock()
defer q._l.Unlock()
for _, item := range items {
var found bool
for _, cluster := range q.queue[item.Serial] {
if cluster == item.Cluster {
found = true
break
}
}
if !found {
q.queue[item.Serial] = append(q.queue[item.Serial], item.Cluster)
}
}
}
func (q *revocationQueue) Remove(item *revocationQueueEntry) {
q._l.Lock()
defer q._l.Unlock()
clusters, present := q.queue[item.Serial]
if !present {
return
}
if len(clusters) == 0 || (len(clusters) == 1 && clusters[0] == item.Cluster) {
delete(q.queue, item.Serial)
return
}
result := clusters
for index, cluster := range clusters {
if cluster == item.Cluster {
result = append(clusters[0:index], clusters[index+1:]...)
break
}
}
q.queue[item.Serial] = result
}
// As this doesn't depend on any internal state, it should not be called
// unless it is OK to remove any items added since the last Iterate()
// function call.
func (q *revocationQueue) RemoveAll() {
q._l.Lock()
defer q._l.Unlock()
q.queue = make(map[string][]string)
}
func (q *revocationQueue) Iterate() []*revocationQueueEntry {
q._l.Lock()
defer q._l.Unlock()
// Heuristic: by storing by serial, occasionally we'll get double entires
// if it was already revoked, but otherwise we'll be off by fewer when
// building this list.
ret := make([]*revocationQueueEntry, 0, len(q.queue))
for serial, clusters := range q.queue {
for _, cluster := range clusters {
ret = append(ret, &revocationQueueEntry{
Serial: serial,
Cluster: cluster,
})
}
}
return ret
}
// sliceToMapKey return a map that who's keys are entries in a map.
func sliceToMapKey(s []string) map[string]struct{} {
var empty struct{}
myMap := make(map[string]struct{}, len(s))
for _, s := range s {
myMap[s] = empty
}
return myMap
}