open-vault/builtin/logical/totp/path_keys.go

456 lines
12 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package totp
import (
"bytes"
"context"
"encoding/base32"
"encoding/base64"
"fmt"
"image/png"
"net/url"
"strconv"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
otplib "github.com/pquerna/otp"
totplib "github.com/pquerna/otp/totp"
)
// pathListKeys builds the framework path matching "keys/?$" and wires the
// list operation to b.pathKeyList.
func pathListKeys(b *backend) *framework.Path {
	display := &framework.DisplayAttributes{
		OperationPrefix: operationPrefixTOTP,
		OperationSuffix: "keys",
	}

	// Only listing is supported on this path.
	callbacks := map[logical.Operation]framework.OperationFunc{
		logical.ListOperation: b.pathKeyList,
	}

	listPath := &framework.Path{
		Pattern:         "keys/?$",
		DisplayAttrs:    display,
		Callbacks:       callbacks,
		HelpSynopsis:    pathKeyHelpSyn,
		HelpDescription: pathKeyHelpDesc,
	}
	return listPath
}
// pathKeys builds the framework path for "keys/<name>". It declares the field
// schema accepted when creating a key and registers the read, update (create)
// and delete handlers for a single named key.
func pathKeys(b *backend) *framework.Path {
	pattern := "keys/" + framework.GenericNameWithAtRegex("name")

	// Request fields accepted on this path, with their defaults.
	schema := map[string]*framework.FieldSchema{
		"name": {
			Type:        framework.TypeString,
			Description: "Name of the key.",
		},

		"generate": {
			Type:        framework.TypeBool,
			Default:     false,
			Description: "Determines if a key should be generated by Vault or if a key is being passed from another service.",
		},

		"exported": {
			Type:        framework.TypeBool,
			Default:     true,
			Description: "Determines if a QR code and url are returned upon generating a key. Only used if generate is true.",
		},

		"key_size": {
			Type:        framework.TypeInt,
			Default:     20,
			Description: "Determines the size in bytes of the generated key. Only used if generate is true.",
		},

		"key": {
			Type:        framework.TypeString,
			Description: "The shared master key used to generate a TOTP token. Only used if generate is false.",
		},

		"issuer": {
			Type:        framework.TypeString,
			Description: `The name of the key's issuing organization. Required if generate is true.`,
		},

		"account_name": {
			Type:        framework.TypeString,
			Description: `The name of the account associated with the key. Required if generate is true.`,
		},

		"period": {
			Type:        framework.TypeDurationSecond,
			Default:     30,
			Description: `The length of time used to generate a counter for the TOTP token calculation.`,
		},

		"algorithm": {
			Type:        framework.TypeString,
			Default:     "SHA1",
			Description: `The hashing algorithm used to generate the TOTP token. Options include SHA1, SHA256 and SHA512.`,
		},

		"digits": {
			Type:        framework.TypeInt,
			Default:     6,
			Description: `The number of digits in the generated TOTP token. This value can either be 6 or 8.`,
		},

		"skew": {
			Type:        framework.TypeInt,
			Default:     1,
			Description: `The number of delay periods that are allowed when validating a TOTP token. This value can either be 0 or 1. Only used if generate is true.`,
		},

		"qr_size": {
			Type:        framework.TypeInt,
			Default:     200,
			Description: `The pixel size of the generated square QR code. Only used if generate is true and exported is true. If this value is 0, a QR code will not be returned.`,
		},

		"url": {
			Type:        framework.TypeString,
			Description: `A TOTP url string containing all of the parameters for key setup. Only used if generate is false.`,
		},
	}

	// operation pairs a callback with the verb shown in its display attributes.
	operation := func(callback framework.OperationFunc, verb string) framework.OperationHandler {
		return &framework.PathOperation{
			Callback: callback,
			DisplayAttrs: &framework.DisplayAttributes{
				OperationVerb: verb,
			},
		}
	}

	handlers := map[logical.Operation]framework.OperationHandler{
		logical.ReadOperation:   operation(b.pathKeyRead, "read"),
		logical.UpdateOperation: operation(b.pathKeyCreate, "create"),
		logical.DeleteOperation: operation(b.pathKeyDelete, "delete"),
	}

	return &framework.Path{
		Pattern: pattern,
		DisplayAttrs: &framework.DisplayAttributes{
			OperationPrefix: operationPrefixTOTP,
			OperationSuffix: "key",
		},
		Fields:          schema,
		Operations:      handlers,
		HelpSynopsis:    pathKeyHelpSyn,
		HelpDescription: pathKeyHelpDesc,
	}
}
// Key loads the entry stored under "key/<n>" and decodes it from JSON.
// It returns (nil, nil) when no entry exists for that name, and a non-nil
// error when the storage read or the JSON decoding fails.
func (b *backend) Key(ctx context.Context, s logical.Storage, n string) (*keyEntry, error) {
	raw, err := s.Get(ctx, "key/"+n)
	switch {
	case err != nil:
		return nil, err
	case raw == nil:
		// Nothing stored under this name.
		return nil, nil
	}

	decoded := new(keyEntry)
	if decodeErr := raw.DecodeJSON(decoded); decodeErr != nil {
		return nil, decodeErr
	}
	return decoded, nil
}
// pathKeyDelete removes the storage entry "key/<name>" for the key named in
// the request. It returns a nil response on success and the storage error
// otherwise.
func (b *backend) pathKeyDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("name").(string)
	if err := req.Storage.Delete(ctx, "key/"+name); err != nil {
		return nil, err
	}
	return nil, nil
}
// pathKeyRead returns the stored settings of the named key: issuer, account
// name, period, algorithm (as a string) and digits. The key material itself
// is not part of the response. A nil response is returned when the key does
// not exist.
func (b *backend) pathKeyRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("name").(string)

	stored, err := b.Key(ctx, req.Storage, name)
	if err != nil {
		return nil, err
	}
	if stored == nil {
		return nil, nil
	}

	// The algorithm is stored in the library's own type; report it by name.
	settings := map[string]interface{}{
		"issuer":       stored.Issuer,
		"account_name": stored.AccountName,
		"period":       stored.Period,
		"algorithm":    stored.Algorithm.String(),
		"digits":       stored.Digits,
	}
	return &logical.Response{Data: settings}, nil
}
// pathKeyList lists the entries stored under the "key/" prefix and wraps them
// in a list response.
func (b *backend) pathKeyList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	names, listErr := req.Storage.List(ctx, "key/")
	if listErr != nil {
		return nil, listErr
	}

	resp := logical.ListResponse(names)
	return resp, nil
}
// pathKeyCreate handles the update operation on "keys/<name>". It either
// generates a new TOTP key (generate=true) or accepts an existing one, given
// directly through the "key" field or through a TOTP "url", validates the
// settings, and stores the result under "key/<name>".
//
// Invalid input is reported as a logical error response. For generated keys
// with exported=true the response carries the key "url" and, unless qr_size
// is 0, a base64-encoded PNG "barcode"; in every other successful case the
// response is nil.
func (b *backend) pathKeyCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("name").(string)
	generate := data.Get("generate").(bool)
	exported := data.Get("exported").(bool)
	keyString := data.Get("key").(string)
	issuer := data.Get("issuer").(string)
	accountName := data.Get("account_name").(string)
	period := data.Get("period").(int)
	algorithm := data.Get("algorithm").(string)
	digits := data.Get("digits").(int)
	skew := data.Get("skew").(int)
	qrSize := data.Get("qr_size").(int)
	keySize := data.Get("key_size").(int)
	inputURL := data.Get("url").(string)

	// A generated key cannot be combined with caller-supplied key material.
	if generate {
		if keyString != "" {
			return logical.ErrorResponse("a key should not be passed if generate is true"), nil
		}
		if inputURL != "" {
			return logical.ErrorResponse("a url should not be passed if generate is true"), nil
		}
	}

	// Read parameters from url if given. Values found in the url override the
	// corresponding request fields.
	if inputURL != "" {
		// Parse url
		urlObject, err := url.Parse(inputURL)
		if err != nil {
			return logical.ErrorResponse("an error occurred while parsing url string"), err
		}

		// Set up query object
		urlQuery := urlObject.Query()
		path := strings.TrimPrefix(urlObject.Path, "/")
		index := strings.Index(path, ":")

		// Read issuer: the query parameter wins, otherwise fall back to the
		// "issuer:account" prefix of the path when one is present.
		urlIssuer := urlQuery.Get("issuer")
		if urlIssuer != "" {
			issuer = urlIssuer
		} else if index != -1 {
			issuer = path[:index]
		}

		// Read account name
		if index == -1 {
			accountName = path
		} else {
			accountName = path[index+1:]
		}

		// Read key string
		keyString = urlQuery.Get("secret")

		// Read period
		periodQuery := urlQuery.Get("period")
		if periodQuery != "" {
			periodInt, err := strconv.Atoi(periodQuery)
			if err != nil {
				return logical.ErrorResponse("an error occurred while parsing period value in url"), err
			}
			period = periodInt
		}

		// Read digits
		digitsQuery := urlQuery.Get("digits")
		if digitsQuery != "" {
			digitsInt, err := strconv.Atoi(digitsQuery)
			if err != nil {
				return logical.ErrorResponse("an error occurred while parsing digits value in url"), err
			}
			digits = digitsInt
		}

		// Read algorithm
		algorithmQuery := urlQuery.Get("algorithm")
		if algorithmQuery != "" {
			algorithm = algorithmQuery
		}
	}

	// Translate digits and algorithm to a format the totp library understands
	var keyDigits otplib.Digits
	switch digits {
	case 6:
		keyDigits = otplib.DigitsSix
	case 8:
		keyDigits = otplib.DigitsEight
	default:
		return logical.ErrorResponse("the digits value can only be 6 or 8"), nil
	}

	var keyAlgorithm otplib.Algorithm
	switch algorithm {
	case "SHA1":
		keyAlgorithm = otplib.AlgorithmSHA1
	case "SHA256":
		keyAlgorithm = otplib.AlgorithmSHA256
	case "SHA512":
		keyAlgorithm = otplib.AlgorithmSHA512
	default:
		return logical.ErrorResponse("the algorithm value is not valid"), nil
	}

	// Enforce input value requirements
	if period <= 0 {
		return logical.ErrorResponse("the period value must be greater than zero"), nil
	}

	if skew != 0 && skew != 1 {
		return logical.ErrorResponse("the skew value must be 0 or 1"), nil
	}

	// QR size can be zero but it shouldn't be negative
	if qrSize < 0 {
		return logical.ErrorResponse("the qr_size value must be greater than or equal to zero"), nil
	}

	if keySize <= 0 {
		return logical.ErrorResponse("the key_size value must be greater than zero"), nil
	}

	// Period, Skew and Key Size need to be unsigned ints. The checks above
	// guarantee none of them is negative, so the conversions cannot wrap.
	uintPeriod := uint(period)
	uintSkew := uint(skew)
	uintKeySize := uint(keySize)

	var response *logical.Response

	if generate {
		// If the key is generated, Account Name and Issuer are required.
		if accountName == "" {
			return logical.ErrorResponse("the account_name value is required for generated keys"), nil
		}
		if issuer == "" {
			return logical.ErrorResponse("the issuer value is required for generated keys"), nil
		}

		// Generate a new key
		keyObject, err := totplib.Generate(totplib.GenerateOpts{
			Issuer:      issuer,
			AccountName: accountName,
			Period:      uintPeriod,
			Digits:      keyDigits,
			Algorithm:   keyAlgorithm,
			SecretSize:  uintKeySize,
			Rand:        b.GetRandomReader(),
		})
		if err != nil {
			return logical.ErrorResponse("an error occurred while generating a key"), err
		}

		// Get key string value
		keyString = keyObject.Secret()

		// Skip returning the QR code and url if exported is set to false
		if exported {
			// Prepare the url and barcode
			urlString := keyObject.String()

			// Don't include QR code if size is set to zero
			if qrSize == 0 {
				response = &logical.Response{
					Data: map[string]interface{}{
						"url": urlString,
					},
				}
			} else {
				barcode, err := keyObject.Image(qrSize, qrSize)
				if err != nil {
					return nil, fmt.Errorf("failed to generate QR code image: %w", err)
				}

				// Fail the request rather than hand back an empty or
				// truncated barcode if the PNG cannot be encoded.
				var buff bytes.Buffer
				if err := png.Encode(&buff, barcode); err != nil {
					return nil, fmt.Errorf("failed to encode QR code image: %w", err)
				}

				b64Barcode := base64.StdEncoding.EncodeToString(buff.Bytes())
				response = &logical.Response{
					Data: map[string]interface{}{
						"url":     urlString,
						"barcode": b64Barcode,
					},
				}
			}
		}
	} else {
		if keyString == "" {
			return logical.ErrorResponse("the key value is required"), nil
		}

		// Pad the key to a multiple of 8 characters so that unpadded base32
		// input is accepted; the padded form is what gets stored below.
		if i := len(keyString) % 8; i != 0 {
			keyString += strings.Repeat("=", 8-i)
		}

		// Validate only: the decoded bytes are discarded.
		_, err := base32.StdEncoding.DecodeString(strings.ToUpper(keyString))
		if err != nil {
			return logical.ErrorResponse(fmt.Sprintf(
				"invalid key value: %s", err)), nil
		}
	}

	// Store it
	entry, err := logical.StorageEntryJSON("key/"+name, &keyEntry{
		Key:         keyString,
		Issuer:      issuer,
		AccountName: accountName,
		Period:      uintPeriod,
		Algorithm:   keyAlgorithm,
		Digits:      keyDigits,
		Skew:        uintSkew,
	})
	if err != nil {
		return nil, err
	}
	if err := req.Storage.Put(ctx, entry); err != nil {
		return nil, err
	}

	return response, nil
}
// keyEntry is the record stored as JSON under "key/<name>" for each TOTP key.
// It is written by pathKeyCreate and decoded by Key.
type keyEntry struct {
// Key is the key string: the generated secret, or the caller-supplied key
// after pathKeyCreate has padded it with "=" to a multiple of 8 characters.
Key string `json:"key" mapstructure:"key" structs:"key"`
// Issuer is the name of the key's issuing organization.
Issuer string `json:"issuer" mapstructure:"issuer" structs:"issuer"`
// AccountName is the name of the account associated with the key.
AccountName string `json:"account_name" mapstructure:"account_name" structs:"account_name"`
// Period is the length of time used to generate a counter for the TOTP
// token calculation; pathKeyCreate rejects values that are not positive.
Period uint `json:"period" mapstructure:"period" structs:"period"`
// Algorithm is the hashing algorithm (SHA1, SHA256 or SHA512) in the otp
// library's representation.
Algorithm otplib.Algorithm `json:"algorithm" mapstructure:"algorithm" structs:"algorithm"`
// Digits is the token length (6 or 8) in the otp library's representation.
Digits otplib.Digits `json:"digits" mapstructure:"digits" structs:"digits"`
// Skew is the number of delay periods allowed when validating a token;
// pathKeyCreate only accepts 0 or 1.
Skew uint `json:"skew" mapstructure:"skew" structs:"skew"`
}
// Help text attached to the key paths: pathKeyHelpSyn is the short synopsis
// and pathKeyHelpDesc the longer description.
const (
	pathKeyHelpSyn = `
Manage the keys that can be created with this backend.
`

	pathKeyHelpDesc = `
This path lets you manage the keys that can be created with this backend.
`
)