Check if sys view is missing in GRPC sys view (#18210)

And return an error instead of panicking.

This situation can occur if a plugin attempts to access the system
view during setup when Vault is checking the plugin metadata.

Fixes #17878.
This commit is contained in:
Christopher Swenson 2022-12-02 10:12:05 -08:00 committed by GitHub
parent 71b790bd0f
commit eba490ccef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 3 deletions

3
changelog/18210.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
sdk: Don't panic if system view or storage methods called during plugin setup.
```

View File

@ -10,6 +10,8 @@ import (
"github.com/hashicorp/vault/sdk/plugin/pb" "github.com/hashicorp/vault/sdk/plugin/pb"
) )
var errMissingStorage = errors.New("missing storage implementation: this method should not be called during plugin Setup, but only during and after Initialize")
func newGRPCStorageClient(conn *grpc.ClientConn) *GRPCStorageClient { func newGRPCStorageClient(conn *grpc.ClientConn) *GRPCStorageClient {
return &GRPCStorageClient{ return &GRPCStorageClient{
client: pb.NewStorageClient(conn), client: pb.NewStorageClient(conn),
@ -74,13 +76,16 @@ func (s *GRPCStorageClient) Delete(ctx context.Context, key string) error {
return nil return nil
} }
// StorageServer is a net/rpc compatible structure for serving // GRPCStorageServer is a net/rpc compatible structure for serving
type GRPCStorageServer struct { type GRPCStorageServer struct {
pb.UnimplementedStorageServer pb.UnimplementedStorageServer
impl logical.Storage impl logical.Storage
} }
func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs) (*pb.StorageListReply, error) { func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs) (*pb.StorageListReply, error) {
if s.impl == nil {
return nil, errMissingStorage
}
keys, err := s.impl.List(ctx, args.Prefix) keys, err := s.impl.List(ctx, args.Prefix)
return &pb.StorageListReply{ return &pb.StorageListReply{
Keys: keys, Keys: keys,
@ -89,6 +94,9 @@ func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs)
} }
func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*pb.StorageGetReply, error) { func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*pb.StorageGetReply, error) {
if s.impl == nil {
return nil, errMissingStorage
}
storageEntry, err := s.impl.Get(ctx, args.Key) storageEntry, err := s.impl.Get(ctx, args.Key)
if storageEntry == nil { if storageEntry == nil {
return &pb.StorageGetReply{ return &pb.StorageGetReply{
@ -103,6 +111,9 @@ func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*
} }
func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*pb.StoragePutReply, error) { func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*pb.StoragePutReply, error) {
if s.impl == nil {
return nil, errMissingStorage
}
err := s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry)) err := s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
return &pb.StoragePutReply{ return &pb.StoragePutReply{
Err: pb.ErrToString(err), Err: pb.ErrToString(err),
@ -110,6 +121,9 @@ func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*
} }
func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteArgs) (*pb.StorageDeleteReply, error) { func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteArgs) (*pb.StorageDeleteReply, error) {
if s.impl == nil {
return nil, errMissingStorage
}
err := s.impl.Delete(ctx, args.Key) err := s.impl.Delete(ctx, args.Key)
return &pb.StorageDeleteReply{ return &pb.StorageDeleteReply{
Err: pb.ErrToString(err), Err: pb.ErrToString(err),

View File

@ -18,6 +18,8 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
var errMissingSystemView = errors.New("missing system view implementation: this method should not be called during plugin Setup, but only during and after Initialize")
func newGRPCSystemView(conn *grpc.ClientConn) *gRPCSystemViewClient { func newGRPCSystemView(conn *grpc.ClientConn) *gRPCSystemViewClient {
return &gRPCSystemViewClient{ return &gRPCSystemViewClient{
client: pb.NewSystemViewClient(conn), client: pb.NewSystemViewClient(conn),
@ -193,6 +195,9 @@ type gRPCSystemViewServer struct {
} }
func (s *gRPCSystemViewServer) DefaultLeaseTTL(ctx context.Context, _ *pb.Empty) (*pb.TTLReply, error) { func (s *gRPCSystemViewServer) DefaultLeaseTTL(ctx context.Context, _ *pb.Empty) (*pb.TTLReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
ttl := s.impl.DefaultLeaseTTL() ttl := s.impl.DefaultLeaseTTL()
return &pb.TTLReply{ return &pb.TTLReply{
TTL: int64(ttl), TTL: int64(ttl),
@ -200,6 +205,9 @@ func (s *gRPCSystemViewServer) DefaultLeaseTTL(ctx context.Context, _ *pb.Empty)
} }
func (s *gRPCSystemViewServer) MaxLeaseTTL(ctx context.Context, _ *pb.Empty) (*pb.TTLReply, error) { func (s *gRPCSystemViewServer) MaxLeaseTTL(ctx context.Context, _ *pb.Empty) (*pb.TTLReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
ttl := s.impl.MaxLeaseTTL() ttl := s.impl.MaxLeaseTTL()
return &pb.TTLReply{ return &pb.TTLReply{
TTL: int64(ttl), TTL: int64(ttl),
@ -207,6 +215,9 @@ func (s *gRPCSystemViewServer) MaxLeaseTTL(ctx context.Context, _ *pb.Empty) (*p
} }
func (s *gRPCSystemViewServer) Tainted(ctx context.Context, _ *pb.Empty) (*pb.TaintedReply, error) { func (s *gRPCSystemViewServer) Tainted(ctx context.Context, _ *pb.Empty) (*pb.TaintedReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
tainted := s.impl.Tainted() tainted := s.impl.Tainted()
return &pb.TaintedReply{ return &pb.TaintedReply{
Tainted: tainted, Tainted: tainted,
@ -214,6 +225,9 @@ func (s *gRPCSystemViewServer) Tainted(ctx context.Context, _ *pb.Empty) (*pb.Ta
} }
func (s *gRPCSystemViewServer) CachingDisabled(ctx context.Context, _ *pb.Empty) (*pb.CachingDisabledReply, error) { func (s *gRPCSystemViewServer) CachingDisabled(ctx context.Context, _ *pb.Empty) (*pb.CachingDisabledReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
cachingDisabled := s.impl.CachingDisabled() cachingDisabled := s.impl.CachingDisabled()
return &pb.CachingDisabledReply{ return &pb.CachingDisabledReply{
Disabled: cachingDisabled, Disabled: cachingDisabled,
@ -221,6 +235,9 @@ func (s *gRPCSystemViewServer) CachingDisabled(ctx context.Context, _ *pb.Empty)
} }
func (s *gRPCSystemViewServer) ReplicationState(ctx context.Context, _ *pb.Empty) (*pb.ReplicationStateReply, error) { func (s *gRPCSystemViewServer) ReplicationState(ctx context.Context, _ *pb.Empty) (*pb.ReplicationStateReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
replicationState := s.impl.ReplicationState() replicationState := s.impl.ReplicationState()
return &pb.ReplicationStateReply{ return &pb.ReplicationStateReply{
State: int32(replicationState), State: int32(replicationState),
@ -228,6 +245,9 @@ func (s *gRPCSystemViewServer) ReplicationState(ctx context.Context, _ *pb.Empty
} }
func (s *gRPCSystemViewServer) ResponseWrapData(ctx context.Context, args *pb.ResponseWrapDataArgs) (*pb.ResponseWrapDataReply, error) { func (s *gRPCSystemViewServer) ResponseWrapData(ctx context.Context, args *pb.ResponseWrapDataArgs) (*pb.ResponseWrapDataReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
data := map[string]interface{}{} data := map[string]interface{}{}
err := json.Unmarshal([]byte(args.Data), &data) err := json.Unmarshal([]byte(args.Data), &data)
if err != nil { if err != nil {
@ -253,6 +273,9 @@ func (s *gRPCSystemViewServer) ResponseWrapData(ctx context.Context, args *pb.Re
} }
func (s *gRPCSystemViewServer) MlockEnabled(ctx context.Context, _ *pb.Empty) (*pb.MlockEnabledReply, error) { func (s *gRPCSystemViewServer) MlockEnabled(ctx context.Context, _ *pb.Empty) (*pb.MlockEnabledReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
enabled := s.impl.MlockEnabled() enabled := s.impl.MlockEnabled()
return &pb.MlockEnabledReply{ return &pb.MlockEnabledReply{
Enabled: enabled, Enabled: enabled,
@ -260,6 +283,9 @@ func (s *gRPCSystemViewServer) MlockEnabled(ctx context.Context, _ *pb.Empty) (*
} }
func (s *gRPCSystemViewServer) LocalMount(ctx context.Context, _ *pb.Empty) (*pb.LocalMountReply, error) { func (s *gRPCSystemViewServer) LocalMount(ctx context.Context, _ *pb.Empty) (*pb.LocalMountReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
local := s.impl.LocalMount() local := s.impl.LocalMount()
return &pb.LocalMountReply{ return &pb.LocalMountReply{
Local: local, Local: local,
@ -267,6 +293,9 @@ func (s *gRPCSystemViewServer) LocalMount(ctx context.Context, _ *pb.Empty) (*pb
} }
func (s *gRPCSystemViewServer) EntityInfo(ctx context.Context, args *pb.EntityInfoArgs) (*pb.EntityInfoReply, error) { func (s *gRPCSystemViewServer) EntityInfo(ctx context.Context, args *pb.EntityInfoArgs) (*pb.EntityInfoReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
entity, err := s.impl.EntityInfo(args.EntityID) entity, err := s.impl.EntityInfo(args.EntityID)
if err != nil { if err != nil {
return &pb.EntityInfoReply{ return &pb.EntityInfoReply{
@ -279,6 +308,9 @@ func (s *gRPCSystemViewServer) EntityInfo(ctx context.Context, args *pb.EntityIn
} }
func (s *gRPCSystemViewServer) GroupsForEntity(ctx context.Context, args *pb.EntityInfoArgs) (*pb.GroupsForEntityReply, error) { func (s *gRPCSystemViewServer) GroupsForEntity(ctx context.Context, args *pb.EntityInfoArgs) (*pb.GroupsForEntityReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
groups, err := s.impl.GroupsForEntity(args.EntityID) groups, err := s.impl.GroupsForEntity(args.EntityID)
if err != nil { if err != nil {
return &pb.GroupsForEntityReply{ return &pb.GroupsForEntityReply{
@ -291,6 +323,9 @@ func (s *gRPCSystemViewServer) GroupsForEntity(ctx context.Context, args *pb.Ent
} }
func (s *gRPCSystemViewServer) PluginEnv(ctx context.Context, _ *pb.Empty) (*pb.PluginEnvReply, error) { func (s *gRPCSystemViewServer) PluginEnv(ctx context.Context, _ *pb.Empty) (*pb.PluginEnvReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
pluginEnv, err := s.impl.PluginEnv(ctx) pluginEnv, err := s.impl.PluginEnv(ctx)
if err != nil { if err != nil {
return &pb.PluginEnvReply{ return &pb.PluginEnvReply{
@ -303,6 +338,9 @@ func (s *gRPCSystemViewServer) PluginEnv(ctx context.Context, _ *pb.Empty) (*pb.
} }
func (s *gRPCSystemViewServer) GeneratePasswordFromPolicy(ctx context.Context, req *pb.GeneratePasswordFromPolicyRequest) (*pb.GeneratePasswordFromPolicyReply, error) { func (s *gRPCSystemViewServer) GeneratePasswordFromPolicy(ctx context.Context, req *pb.GeneratePasswordFromPolicyRequest) (*pb.GeneratePasswordFromPolicyReply, error) {
if s.impl == nil {
return nil, errMissingSystemView
}
policyName := req.PolicyName policyName := req.PolicyName
if policyName == "" { if policyName == "" {
return &pb.GeneratePasswordFromPolicyReply{}, status.Errorf(codes.InvalidArgument, "no password policy specified") return &pb.GeneratePasswordFromPolicyReply{}, status.Errorf(codes.InvalidArgument, "no password policy specified")

View File

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin/pb" "github.com/hashicorp/vault/sdk/plugin/pb"
@ -14,6 +14,13 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
func TestSystem_GRPC_ReturnsErrIfSystemViewNil(t *testing.T) {
_, err := new(gRPCSystemViewServer).ReplicationState(context.Background(), nil)
if err == nil {
t.Error("Expected error when using server with no impl")
}
}
func TestSystem_GRPC_GRPC_impl(t *testing.T) { func TestSystem_GRPC_GRPC_impl(t *testing.T) {
var _ logical.SystemView = new(gRPCSystemViewClient) var _ logical.SystemView = new(gRPCSystemViewClient)
} }

View File

@ -1,15 +1,23 @@
package plugin package plugin
import ( import (
"context"
"testing" "testing"
"google.golang.org/grpc" "google.golang.org/grpc"
plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin/pb" "github.com/hashicorp/vault/sdk/plugin/pb"
) )
func TestStorage_GRPC_ReturnsErrIfStorageNil(t *testing.T) {
_, err := new(GRPCStorageServer).Get(context.Background(), nil)
if err == nil {
t.Error("Expected error when using server with no impl")
}
}
func TestStorage_impl(t *testing.T) { func TestStorage_impl(t *testing.T) {
var _ logical.Storage = new(GRPCStorageClient) var _ logical.Storage = new(GRPCStorageClient)
} }