456 lines
12 KiB
Go
456 lines
12 KiB
Go
|
// Copyright (c) HashiCorp, Inc.
|
||
|
// SPDX-License-Identifier: MPL-2.0
|
||
|
|
||
|
package totp
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/base32"
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"image/png"
|
||
|
"net/url"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/hashicorp/vault/sdk/framework"
|
||
|
"github.com/hashicorp/vault/sdk/logical"
|
||
|
otplib "github.com/pquerna/otp"
|
||
|
totplib "github.com/pquerna/otp/totp"
|
||
|
)
|
||
|
|
||
|
func pathListKeys(b *backend) *framework.Path {
|
||
|
return &framework.Path{
|
||
|
Pattern: "keys/?$",
|
||
|
|
||
|
DisplayAttrs: &framework.DisplayAttributes{
|
||
|
OperationPrefix: operationPrefixTOTP,
|
||
|
OperationSuffix: "keys",
|
||
|
},
|
||
|
|
||
|
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||
|
logical.ListOperation: b.pathKeyList,
|
||
|
},
|
||
|
|
||
|
HelpSynopsis: pathKeyHelpSyn,
|
||
|
HelpDescription: pathKeyHelpDesc,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func pathKeys(b *backend) *framework.Path {
|
||
|
return &framework.Path{
|
||
|
Pattern: "keys/" + framework.GenericNameWithAtRegex("name"),
|
||
|
|
||
|
DisplayAttrs: &framework.DisplayAttributes{
|
||
|
OperationPrefix: operationPrefixTOTP,
|
||
|
OperationSuffix: "key",
|
||
|
},
|
||
|
|
||
|
Fields: map[string]*framework.FieldSchema{
|
||
|
"name": {
|
||
|
Type: framework.TypeString,
|
||
|
Description: "Name of the key.",
|
||
|
},
|
||
|
|
||
|
"generate": {
|
||
|
Type: framework.TypeBool,
|
||
|
Default: false,
|
||
|
Description: "Determines if a key should be generated by Vault or if a key is being passed from another service.",
|
||
|
},
|
||
|
|
||
|
"exported": {
|
||
|
Type: framework.TypeBool,
|
||
|
Default: true,
|
||
|
Description: "Determines if a QR code and url are returned upon generating a key. Only used if generate is true.",
|
||
|
},
|
||
|
|
||
|
"key_size": {
|
||
|
Type: framework.TypeInt,
|
||
|
Default: 20,
|
||
|
Description: "Determines the size in bytes of the generated key. Only used if generate is true.",
|
||
|
},
|
||
|
|
||
|
"key": {
|
||
|
Type: framework.TypeString,
|
||
|
Description: "The shared master key used to generate a TOTP token. Only used if generate is false.",
|
||
|
},
|
||
|
|
||
|
"issuer": {
|
||
|
Type: framework.TypeString,
|
||
|
Description: `The name of the key's issuing organization. Required if generate is true.`,
|
||
|
},
|
||
|
|
||
|
"account_name": {
|
||
|
Type: framework.TypeString,
|
||
|
Description: `The name of the account associated with the key. Required if generate is true.`,
|
||
|
},
|
||
|
|
||
|
"period": {
|
||
|
Type: framework.TypeDurationSecond,
|
||
|
Default: 30,
|
||
|
Description: `The length of time used to generate a counter for the TOTP token calculation.`,
|
||
|
},
|
||
|
|
||
|
"algorithm": {
|
||
|
Type: framework.TypeString,
|
||
|
Default: "SHA1",
|
||
|
Description: `The hashing algorithm used to generate the TOTP token. Options include SHA1, SHA256 and SHA512.`,
|
||
|
},
|
||
|
|
||
|
"digits": {
|
||
|
Type: framework.TypeInt,
|
||
|
Default: 6,
|
||
|
Description: `The number of digits in the generated TOTP token. This value can either be 6 or 8.`,
|
||
|
},
|
||
|
|
||
|
"skew": {
|
||
|
Type: framework.TypeInt,
|
||
|
Default: 1,
|
||
|
Description: `The number of delay periods that are allowed when validating a TOTP token. This value can either be 0 or 1. Only used if generate is true.`,
|
||
|
},
|
||
|
|
||
|
"qr_size": {
|
||
|
Type: framework.TypeInt,
|
||
|
Default: 200,
|
||
|
Description: `The pixel size of the generated square QR code. Only used if generate is true and exported is true. If this value is 0, a QR code will not be returned.`,
|
||
|
},
|
||
|
|
||
|
"url": {
|
||
|
Type: framework.TypeString,
|
||
|
Description: `A TOTP url string containing all of the parameters for key setup. Only used if generate is false.`,
|
||
|
},
|
||
|
},
|
||
|
|
||
|
Operations: map[logical.Operation]framework.OperationHandler{
|
||
|
logical.ReadOperation: &framework.PathOperation{
|
||
|
Callback: b.pathKeyRead,
|
||
|
DisplayAttrs: &framework.DisplayAttributes{
|
||
|
OperationVerb: "read",
|
||
|
},
|
||
|
},
|
||
|
logical.UpdateOperation: &framework.PathOperation{
|
||
|
Callback: b.pathKeyCreate,
|
||
|
DisplayAttrs: &framework.DisplayAttributes{
|
||
|
OperationVerb: "create",
|
||
|
},
|
||
|
},
|
||
|
logical.DeleteOperation: &framework.PathOperation{
|
||
|
Callback: b.pathKeyDelete,
|
||
|
DisplayAttrs: &framework.DisplayAttributes{
|
||
|
OperationVerb: "delete",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
|
||
|
HelpSynopsis: pathKeyHelpSyn,
|
||
|
HelpDescription: pathKeyHelpDesc,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *backend) Key(ctx context.Context, s logical.Storage, n string) (*keyEntry, error) {
|
||
|
entry, err := s.Get(ctx, "key/"+n)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if entry == nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
var result keyEntry
|
||
|
if err := entry.DecodeJSON(&result); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &result, nil
|
||
|
}
|
||
|
|
||
|
func (b *backend) pathKeyDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||
|
err := req.Storage.Delete(ctx, "key/"+data.Get("name").(string))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
func (b *backend) pathKeyRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||
|
key, err := b.Key(ctx, req.Storage, data.Get("name").(string))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if key == nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
// Translate algorithm back to string
|
||
|
algorithm := key.Algorithm.String()
|
||
|
|
||
|
// Return values of key
|
||
|
return &logical.Response{
|
||
|
Data: map[string]interface{}{
|
||
|
"issuer": key.Issuer,
|
||
|
"account_name": key.AccountName,
|
||
|
"period": key.Period,
|
||
|
"algorithm": algorithm,
|
||
|
"digits": key.Digits,
|
||
|
},
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (b *backend) pathKeyList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||
|
entries, err := req.Storage.List(ctx, "key/")
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return logical.ListResponse(entries), nil
|
||
|
}
|
||
|
|
||
|
func (b *backend) pathKeyCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||
|
name := data.Get("name").(string)
|
||
|
generate := data.Get("generate").(bool)
|
||
|
exported := data.Get("exported").(bool)
|
||
|
keyString := data.Get("key").(string)
|
||
|
issuer := data.Get("issuer").(string)
|
||
|
accountName := data.Get("account_name").(string)
|
||
|
period := data.Get("period").(int)
|
||
|
algorithm := data.Get("algorithm").(string)
|
||
|
digits := data.Get("digits").(int)
|
||
|
skew := data.Get("skew").(int)
|
||
|
qrSize := data.Get("qr_size").(int)
|
||
|
keySize := data.Get("key_size").(int)
|
||
|
inputURL := data.Get("url").(string)
|
||
|
|
||
|
if generate {
|
||
|
if keyString != "" {
|
||
|
return logical.ErrorResponse("a key should not be passed if generate is true"), nil
|
||
|
}
|
||
|
if inputURL != "" {
|
||
|
return logical.ErrorResponse("a url should not be passed if generate is true"), nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Read parameters from url if given
|
||
|
if inputURL != "" {
|
||
|
// Parse url
|
||
|
urlObject, err := url.Parse(inputURL)
|
||
|
if err != nil {
|
||
|
return logical.ErrorResponse("an error occurred while parsing url string"), err
|
||
|
}
|
||
|
|
||
|
// Set up query object
|
||
|
urlQuery := urlObject.Query()
|
||
|
path := strings.TrimPrefix(urlObject.Path, "/")
|
||
|
index := strings.Index(path, ":")
|
||
|
|
||
|
// Read issuer
|
||
|
urlIssuer := urlQuery.Get("issuer")
|
||
|
if urlIssuer != "" {
|
||
|
issuer = urlIssuer
|
||
|
} else {
|
||
|
if index != -1 {
|
||
|
issuer = path[:index]
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Read account name
|
||
|
if index == -1 {
|
||
|
accountName = path
|
||
|
} else {
|
||
|
accountName = path[index+1:]
|
||
|
}
|
||
|
|
||
|
// Read key string
|
||
|
keyString = urlQuery.Get("secret")
|
||
|
|
||
|
// Read period
|
||
|
periodQuery := urlQuery.Get("period")
|
||
|
if periodQuery != "" {
|
||
|
periodInt, err := strconv.Atoi(periodQuery)
|
||
|
if err != nil {
|
||
|
return logical.ErrorResponse("an error occurred while parsing period value in url"), err
|
||
|
}
|
||
|
period = periodInt
|
||
|
}
|
||
|
|
||
|
// Read digits
|
||
|
digitsQuery := urlQuery.Get("digits")
|
||
|
if digitsQuery != "" {
|
||
|
digitsInt, err := strconv.Atoi(digitsQuery)
|
||
|
if err != nil {
|
||
|
return logical.ErrorResponse("an error occurred while parsing digits value in url"), err
|
||
|
}
|
||
|
digits = digitsInt
|
||
|
}
|
||
|
|
||
|
// Read algorithm
|
||
|
algorithmQuery := urlQuery.Get("algorithm")
|
||
|
if algorithmQuery != "" {
|
||
|
algorithm = algorithmQuery
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Translate digits and algorithm to a format the totp library understands
|
||
|
var keyDigits otplib.Digits
|
||
|
switch digits {
|
||
|
case 6:
|
||
|
keyDigits = otplib.DigitsSix
|
||
|
case 8:
|
||
|
keyDigits = otplib.DigitsEight
|
||
|
default:
|
||
|
return logical.ErrorResponse("the digits value can only be 6 or 8"), nil
|
||
|
}
|
||
|
|
||
|
var keyAlgorithm otplib.Algorithm
|
||
|
switch algorithm {
|
||
|
case "SHA1":
|
||
|
keyAlgorithm = otplib.AlgorithmSHA1
|
||
|
case "SHA256":
|
||
|
keyAlgorithm = otplib.AlgorithmSHA256
|
||
|
case "SHA512":
|
||
|
keyAlgorithm = otplib.AlgorithmSHA512
|
||
|
default:
|
||
|
return logical.ErrorResponse("the algorithm value is not valid"), nil
|
||
|
}
|
||
|
|
||
|
// Enforce input value requirements
|
||
|
if period <= 0 {
|
||
|
return logical.ErrorResponse("the period value must be greater than zero"), nil
|
||
|
}
|
||
|
|
||
|
switch skew {
|
||
|
case 0:
|
||
|
case 1:
|
||
|
default:
|
||
|
return logical.ErrorResponse("the skew value must be 0 or 1"), nil
|
||
|
}
|
||
|
|
||
|
// QR size can be zero but it shouldn't be negative
|
||
|
if qrSize < 0 {
|
||
|
return logical.ErrorResponse("the qr_size value must be greater than or equal to zero"), nil
|
||
|
}
|
||
|
|
||
|
if keySize <= 0 {
|
||
|
return logical.ErrorResponse("the key_size value must be greater than zero"), nil
|
||
|
}
|
||
|
|
||
|
// Period, Skew and Key Size need to be unsigned ints
|
||
|
uintPeriod := uint(period)
|
||
|
uintSkew := uint(skew)
|
||
|
uintKeySize := uint(keySize)
|
||
|
|
||
|
var response *logical.Response
|
||
|
|
||
|
switch generate {
|
||
|
case true:
|
||
|
// If the key is generated, Account Name and Issuer are required.
|
||
|
if accountName == "" {
|
||
|
return logical.ErrorResponse("the account_name value is required for generated keys"), nil
|
||
|
}
|
||
|
|
||
|
if issuer == "" {
|
||
|
return logical.ErrorResponse("the issuer value is required for generated keys"), nil
|
||
|
}
|
||
|
|
||
|
// Generate a new key
|
||
|
keyObject, err := totplib.Generate(totplib.GenerateOpts{
|
||
|
Issuer: issuer,
|
||
|
AccountName: accountName,
|
||
|
Period: uintPeriod,
|
||
|
Digits: keyDigits,
|
||
|
Algorithm: keyAlgorithm,
|
||
|
SecretSize: uintKeySize,
|
||
|
Rand: b.GetRandomReader(),
|
||
|
})
|
||
|
if err != nil {
|
||
|
return logical.ErrorResponse("an error occurred while generating a key"), err
|
||
|
}
|
||
|
|
||
|
// Get key string value
|
||
|
keyString = keyObject.Secret()
|
||
|
|
||
|
// Skip returning the QR code and url if exported is set to false
|
||
|
if exported {
|
||
|
// Prepare the url and barcode
|
||
|
urlString := keyObject.String()
|
||
|
|
||
|
// Don't include QR code if size is set to zero
|
||
|
if qrSize == 0 {
|
||
|
response = &logical.Response{
|
||
|
Data: map[string]interface{}{
|
||
|
"url": urlString,
|
||
|
},
|
||
|
}
|
||
|
} else {
|
||
|
barcode, err := keyObject.Image(qrSize, qrSize)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to generate QR code image: %w", err)
|
||
|
}
|
||
|
|
||
|
var buff bytes.Buffer
|
||
|
png.Encode(&buff, barcode)
|
||
|
b64Barcode := base64.StdEncoding.EncodeToString(buff.Bytes())
|
||
|
response = &logical.Response{
|
||
|
Data: map[string]interface{}{
|
||
|
"url": urlString,
|
||
|
"barcode": b64Barcode,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
if keyString == "" {
|
||
|
return logical.ErrorResponse("the key value is required"), nil
|
||
|
}
|
||
|
|
||
|
if i := len(keyString) % 8; i != 0 {
|
||
|
keyString += strings.Repeat("=", 8-i)
|
||
|
}
|
||
|
|
||
|
_, err := base32.StdEncoding.DecodeString(strings.ToUpper(keyString))
|
||
|
if err != nil {
|
||
|
return logical.ErrorResponse(fmt.Sprintf(
|
||
|
"invalid key value: %s", err)), nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Store it
|
||
|
entry, err := logical.StorageEntryJSON("key/"+name, &keyEntry{
|
||
|
Key: keyString,
|
||
|
Issuer: issuer,
|
||
|
AccountName: accountName,
|
||
|
Period: uintPeriod,
|
||
|
Algorithm: keyAlgorithm,
|
||
|
Digits: keyDigits,
|
||
|
Skew: uintSkew,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := req.Storage.Put(ctx, entry); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return response, nil
|
||
|
}
|
||
|
|
||
|
type keyEntry struct {
|
||
|
Key string `json:"key" mapstructure:"key" structs:"key"`
|
||
|
Issuer string `json:"issuer" mapstructure:"issuer" structs:"issuer"`
|
||
|
AccountName string `json:"account_name" mapstructure:"account_name" structs:"account_name"`
|
||
|
Period uint `json:"period" mapstructure:"period" structs:"period"`
|
||
|
Algorithm otplib.Algorithm `json:"algorithm" mapstructure:"algorithm" structs:"algorithm"`
|
||
|
Digits otplib.Digits `json:"digits" mapstructure:"digits" structs:"digits"`
|
||
|
Skew uint `json:"skew" mapstructure:"skew" structs:"skew"`
|
||
|
}
|
||
|
|
||
|
// pathKeyHelpSyn is the HelpSynopsis shared by the "keys/" list path and the
// "keys/<name>" path.
// NOTE(review): as seen here, the raw string also contains lines holding only
// "|" or "||". They look like copy/paste residue rather than intended help
// text — confirm against the canonical source.
const pathKeyHelpSyn = `
|
||
|
Manage the keys that can be created with this backend.
|
||
|
`
|
||
|
|
||
|
// pathKeyHelpDesc is the HelpDescription shared by the "keys/" list path and
// the "keys/<name>" path.
// NOTE(review): as seen here, the raw string also contains lines holding only
// "|" or "||". They look like copy/paste residue rather than intended help
// text — confirm against the canonical source.
const pathKeyHelpDesc = `
|
||
|
This path lets you manage the keys that can be created with this backend.
|
||
|
|
||
|
`
|