27bb03bbc0
* adding copyright header * fix fmt and a test
271 lines
6.6 KiB
Go
271 lines
6.6 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package sink
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"os"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
hclog "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/helper/dhutil"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
)
|
|
|
|
// Sink is a destination that tokens are written to.
type Sink interface {
	// WriteToken writes the given token to the sink. SinkServer.Run passes
	// the token after any configured wrapping/encryption, and retries the
	// write after a backoff when a non-nil error is returned.
	WriteToken(string) error
}
|
|
|
|
// SinkReader is implemented by sinks that can report a token back.
type SinkReader interface {
	// Token returns a token held by the sink.
	// NOTE(review): not called in this file; presumably the most recently
	// written token — confirm against implementations.
	Token() string
}
|
|
|
|
type SinkConfig struct {
|
|
Sink
|
|
Logger hclog.Logger
|
|
Config map[string]interface{}
|
|
Client *api.Client
|
|
WrapTTL time.Duration
|
|
DHType string
|
|
DHPath string
|
|
DeriveKey bool
|
|
AAD string
|
|
cachedRemotePubKey []byte
|
|
cachedPubKey []byte
|
|
cachedPriKey []byte
|
|
}
|
|
|
|
type SinkServerConfig struct {
|
|
Logger hclog.Logger
|
|
Client *api.Client
|
|
Context context.Context
|
|
ExitAfterAuth bool
|
|
}
|
|
|
|
// SinkServer is responsible for pushing tokens to sinks
|
|
type SinkServer struct {
|
|
logger hclog.Logger
|
|
client *api.Client
|
|
random *rand.Rand
|
|
exitAfterAuth bool
|
|
remaining *int32
|
|
}
|
|
|
|
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
|
|
ss := &SinkServer{
|
|
logger: conf.Logger,
|
|
client: conf.Client,
|
|
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
|
|
exitAfterAuth: conf.ExitAfterAuth,
|
|
remaining: new(int32),
|
|
}
|
|
|
|
return ss
|
|
}
|
|
|
|
// Run executes the server's run loop, which is responsible for reading
|
|
// in new tokens and pushing them out to the various sinks.
|
|
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) error {
|
|
latestToken := new(string)
|
|
writeSink := func(currSink *SinkConfig, currToken string) error {
|
|
if currToken != *latestToken {
|
|
return nil
|
|
}
|
|
var err error
|
|
|
|
if currSink.WrapTTL != 0 {
|
|
if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if currSink.DHType != "" {
|
|
if currToken, err = currSink.encryptToken(currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return currSink.WriteToken(currToken)
|
|
}
|
|
|
|
if incoming == nil {
|
|
return errors.New("sink server: incoming channel is nil")
|
|
}
|
|
|
|
ss.logger.Info("starting sink server")
|
|
defer func() {
|
|
ss.logger.Info("sink server stopped")
|
|
}()
|
|
|
|
type sinkToken struct {
|
|
sink *SinkConfig
|
|
token string
|
|
}
|
|
sinkCh := make(chan sinkToken, len(sinks))
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
|
|
case token := <-incoming:
|
|
if len(sinks) > 0 {
|
|
if token != *latestToken {
|
|
|
|
// Drain the existing funcs
|
|
drainLoop:
|
|
for {
|
|
select {
|
|
case <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
default:
|
|
break drainLoop
|
|
}
|
|
}
|
|
|
|
*latestToken = token
|
|
|
|
for _, s := range sinks {
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- sinkToken{s, token}
|
|
}
|
|
}
|
|
} else {
|
|
ss.logger.Trace("no sinks, ignoring new token")
|
|
if ss.exitAfterAuth {
|
|
ss.logger.Trace("no sinks, exitAfterAuth, bye")
|
|
return nil
|
|
}
|
|
}
|
|
case st := <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
if err := writeSink(st.sink, st.token); err != nil {
|
|
backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
|
|
ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
|
|
timer := time.NewTimer(backoff)
|
|
select {
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return nil
|
|
case <-timer.C:
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- st
|
|
}
|
|
} else {
|
|
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SinkConfig) encryptToken(token string) (string, error) {
|
|
var aesKey []byte
|
|
var err error
|
|
resp := new(dhutil.Envelope)
|
|
switch s.DHType {
|
|
case "curve25519":
|
|
if len(s.cachedRemotePubKey) == 0 {
|
|
_, err = os.Lstat(s.DHPath)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
return "", fmt.Errorf("error stat-ing dh parameters file: %w", err)
|
|
}
|
|
return "", errors.New("no dh parameters file found, and no cached pub key")
|
|
}
|
|
fileBytes, err := ioutil.ReadFile(s.DHPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error reading file for dh parameters: %w", err)
|
|
}
|
|
theirPubKey := new(dhutil.PublicKeyInfo)
|
|
if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
|
|
return "", fmt.Errorf("error decoding public key: %w", err)
|
|
}
|
|
if len(theirPubKey.Curve25519PublicKey) == 0 {
|
|
return "", errors.New("public key is nil")
|
|
}
|
|
s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
|
|
}
|
|
if len(s.cachedPubKey) == 0 {
|
|
s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
|
|
if err != nil {
|
|
return "", fmt.Errorf("error generating pub/pri curve25519 keys: %w", err)
|
|
}
|
|
}
|
|
resp.Curve25519PublicKey = s.cachedPubKey
|
|
}
|
|
|
|
secret, err := dhutil.GenerateSharedSecret(s.cachedPriKey, s.cachedRemotePubKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error calculating shared key: %w", err)
|
|
}
|
|
if s.DeriveKey {
|
|
aesKey, err = dhutil.DeriveSharedKey(secret, s.cachedPubKey, s.cachedRemotePubKey)
|
|
} else {
|
|
aesKey = secret
|
|
}
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("error deriving shared key: %w", err)
|
|
}
|
|
if len(aesKey) == 0 {
|
|
return "", errors.New("derived AES key is empty")
|
|
}
|
|
|
|
resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error encrypting with shared key: %w", err)
|
|
}
|
|
m, err := jsonutil.EncodeJSON(resp)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error encoding encrypted payload: %w", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|
|
|
|
func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
|
|
wrapClient, err := client.CloneWithHeaders()
|
|
if err != nil {
|
|
return "", fmt.Errorf("error deriving client for wrapping, not writing out to sink: %w)", err)
|
|
}
|
|
|
|
wrapClient.SetToken(token)
|
|
wrapClient.SetWrappingLookupFunc(func(string, string) string {
|
|
return wrapTTL.String()
|
|
})
|
|
|
|
secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
|
|
"token": token,
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("error wrapping token, not writing out to sink: %w)", err)
|
|
}
|
|
if secret == nil {
|
|
return "", errors.New("nil secret returned, not writing out to sink")
|
|
}
|
|
if secret.WrapInfo == nil {
|
|
return "", errors.New("nil wrap info returned, not writing out to sink")
|
|
}
|
|
|
|
m, err := jsonutil.EncodeJSON(secret.WrapInfo)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error marshaling token, not writing out to sink: %w)", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|