open-vault/helper/testhelpers/pluginhelpers/pluginhelpers.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

147 lines
4 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
// Package pluginhelpers contains test helpers that don't depend on package
// vault, and thus can be used within vault (as well as elsewhere).
package pluginhelpers
import (
"crypto/sha256"
"fmt"
"os"
"os/exec"
"path"
"path/filepath"
"sync"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/mitchellh/go-testing-interface"
)
var (
	// testPluginCacheLock is held by CompilePlugin for the duration of each
	// call, which covers every read and write of testPluginCache.
	testPluginCacheLock sync.Mutex

	// testPluginCache holds the contents of plugin binaries that have already
	// been compiled, keyed by "<name> <type> <version>", so that each variant
	// is built at most once per process.
	testPluginCache = make(map[string][]byte)
)
// TestPlugin describes a plugin binary produced by CompilePlugin.
type TestPlugin struct {
// Name is the plugin's name without any version suffix, e.g.
// "vault-plugin-auth-approle".
Name string
// Typ is the plugin type the binary was compiled for.
Typ consts.PluginType
// Version is the version override passed to CompilePlugin; it is empty
// when no override was requested.
Version string
// FileName is the base name of the binary on disk: Name, followed by
// "-<Version>" when Version is non-empty.
FileName string
// Sha256 is the lowercase hex-encoded SHA-256 digest of the binary.
Sha256 string
}
// GetPlugin selects the plugin that stands in for the given plugin type in
// tests. It returns, in order:
//
//   - the plugin name (e.g. "vault-plugin-auth-approle"),
//   - the short plugin type name (e.g. "approle"),
//   - the relative path to the plugin's main.go,
//   - the fully qualified name of the plugin's ReportedVersion variable.
//
// Plugin types other than credential, secrets and database fail the test via
// t.Fatal.
func GetPlugin(t testing.T, typ consts.PluginType) (string, string, string, string) {
	t.Helper()

	switch typ {
	case consts.PluginTypeCredential:
		const backend = "approle"
		mainFile := filepath.Join("builtin", "credential", backend, "cmd", backend, "main.go")
		versionVar := fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", backend)
		return "vault-plugin-auth-" + backend, backend, mainFile, versionVar

	case consts.PluginTypeSecrets:
		const backend = "consul"
		mainFile := filepath.Join("builtin", "logical", backend, "cmd", backend, "main.go")
		versionVar := fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", backend)
		return "vault-plugin-secrets-" + backend, backend, mainFile, versionVar

	case consts.PluginTypeDatabase:
		const backend = "postgresql"
		mainFile := filepath.Join("plugins", "database", backend, fmt.Sprintf("%s-database-plugin", backend), "main.go")
		versionVar := fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", backend)
		return "vault-plugin-database-" + backend, backend, mainFile, versionVar
	}

	// Every supported type returned above; anything else is a test bug.
	t.Fatal(typ.String())
	return "", "", "", ""
}
// CompilePlugin builds a plugin binary of the given type so that tests have a
// working plugin to mount, and returns a description of it.
//
// pluginVersion, when non-empty, overrides the plugin's self-reported version
// (via "-ldflags -X") and is appended to the binary's file name. The binary is
// written to pluginDir.
//
// Compiled binaries are cached in memory per (name, type, version), so the
// build runs only once per variant; later calls re-create the file from the
// cache if it is missing from pluginDir.
//
// Side effect: the process working directory is moved up one level at a time
// until it contains the plugin source root ("builtin", or "plugins" for
// database plugins), and it is left there. Any failure aborts the test via
// t.Fatal.
func CompilePlugin(t testing.T, typ consts.PluginType, pluginVersion string, pluginDir string) TestPlugin {
	t.Helper()

	pluginName, pluginType, pluginMain, pluginVersionLocation := GetPlugin(t, typ)

	// Held for the whole call: covers the cache accesses as well as the
	// working-directory changes and the build below.
	testPluginCacheLock.Lock()
	defer testPluginCacheLock.Unlock()

	pluginRootDir := "builtin"
	if typ == consts.PluginTypeDatabase {
		pluginRootDir = "plugins"
	}

	// Detect whether we are in a subdirectory or the root directory and
	// compensate by walking up until the plugin root is visible.
	var dir, prevDir string
	for {
		var err error
		dir, err = os.Getwd()
		if err != nil {
			t.Fatal(err)
		}
		if _, err := os.Stat(pluginRootDir); !os.IsNotExist(err) {
			break
		}
		// "cd .." at the filesystem root succeeds without moving; if the
		// directory did not change since the last iteration there is nothing
		// left to search, so fail instead of looping forever.
		if dir == prevDir {
			t.Fatal(fmt.Errorf("directory %q not found in the working directory or any of its parents", pluginRootDir))
		}
		prevDir = dir
		if err := os.Chdir(".."); err != nil {
			t.Fatal(err)
		}
	}

	pluginPath := path.Join(pluginDir, pluginName)
	if pluginVersion != "" {
		pluginPath += "-" + pluginVersion
	}

	// Cache the compilation so it only runs once per variant.
	key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion)
	pluginBytes, ok := testPluginCache[key]
	if !ok {
		args := []string{"build"}
		if pluginVersion != "" {
			args = append(args, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion))
		}
		args = append(args, "-o", pluginPath, pluginMain)

		cmd := exec.Command("go", args...)
		cmd.Dir = dir
		if output, err := cmd.CombinedOutput(); err != nil {
			t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output))
		}

		built, err := os.ReadFile(pluginPath)
		if err != nil {
			t.Fatal(err)
		}
		// Only populate the cache after a successful read, so a failure
		// cannot leave an empty entry behind for later callers.
		testPluginCache[key] = built
		pluginBytes = built
	}

	// Write the cached plugin if it is not already present in pluginDir.
	if _, err := os.Stat(pluginPath); err != nil {
		if !os.IsNotExist(err) {
			t.Fatal(err)
		}
		if err := os.WriteFile(pluginPath, pluginBytes, 0o755); err != nil {
			t.Fatal(err)
		}
	}

	sum := sha256.Sum256(pluginBytes)

	return TestPlugin{
		Name:     pluginName,
		Typ:      typ,
		Version:  pluginVersion,
		FileName: path.Base(pluginPath),
		Sha256:   fmt.Sprintf("%x", sum[:]),
	}
}