Merge branch 'master-oss' into sys-tidy-leases
This commit is contained in:
commit
55ef4f2566
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -25,6 +25,14 @@ FEATURES:
|
|||
revoke-force) have also been relocated to `sys/leases/`, but they also work
|
||||
at the old paths for compatibility. Reading (but not listing) leases via
|
||||
`sys/leases/lookup` is now a part of the current `default` policy. [GH-2650]
|
||||
* **TOTP Secret Backend**: You can now store multi-factor authentication keys
|
||||
in Vault and use the API to retrieve time-based one-time use passwords on
|
||||
demand. The backend can also be used to generate a new key and validate
|
||||
passwords generated by that key. [GH-2492]
|
||||
* **Database Secret Backend & Secure Plugins**: This new secret backend
|
||||
combines the functionality of the MySQL, PostgreSQL, MSSQL, and Cassandra
|
||||
backends. It also provides a plugin interface for extendability through
|
||||
custom databases. [GH-2200]
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
|
@ -59,6 +67,8 @@ BUG FIXES:
|
|||
* auth/ldap: Don't lowercase groups attached to users [GH-2613]
|
||||
* cli: Don't panic if `vault write` is used with the `force` flag but no path
|
||||
[GH-2674]
|
||||
* core: Help operations should request forward since standbys may not have
|
||||
appropriate info [GH-2677]
|
||||
* secret/mssql: Update mssql driver to support queries with colons [GH-2610]
|
||||
* secret/pki: Don't lowercase O/OU values in certs [GH-2555]
|
||||
* secret/pki: Don't attempt to validate IP SANs if none are provided [GH-2574]
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/mitchellh/copystructure"
|
||||
"github.com/mitchellh/reflectwalk"
|
||||
|
@ -84,7 +85,7 @@ func Hash(salter *salt.Salt, raw interface{}) error {
|
|||
|
||||
s.Data = data.(map[string]interface{})
|
||||
|
||||
case *logical.ResponseWrapInfo:
|
||||
case *wrapping.ResponseWrapInfo:
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
@ -69,7 +70,7 @@ func TestCopy_response(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "foo",
|
||||
CreationTime: time.Now(),
|
||||
|
@ -140,7 +141,7 @@ func TestHash(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "bar",
|
||||
CreationTime: now,
|
||||
|
@ -151,7 +152,7 @@ func TestHash(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
CreationTime: now,
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const databaseConfigPath = "database/config/"
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend(conf).Setup(conf)
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
var b databaseBackend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigurePluginConnection(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
pathResetConnection(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Clean: b.closeAllDBs,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
b.connections = make(map[string]dbplugin.Database)
|
||||
return &b
|
||||
}
|
||||
|
||||
type databaseBackend struct {
|
||||
connections map[string]dbplugin.Database
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// closeAllDBs closes all connections from all database types
|
||||
func (b *databaseBackend) closeAllDBs() {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
for _, db := range b.connections {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
b.connections = make(map[string]dbplugin.Database)
|
||||
}
|
||||
|
||||
// This function is used to retrieve a database object either from the cached
|
||||
// connection map. The caller of this function needs to hold the backend's read
|
||||
// lock.
|
||||
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
||||
db, ok := b.connections[name]
|
||||
return db, ok
|
||||
}
|
||||
|
||||
// This function creates a new db object from the stored configuration and
|
||||
// caches it in the connections map. The caller of this function needs to hold
|
||||
// the backend's write lock
|
||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
config, err := b.DatabaseConfig(s, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) {
|
||||
entry, err := s.Get(fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration: %s", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
||||
}
|
||||
|
||||
var config DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) invalidate(key string) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(key, databaseConfigPath):
|
||||
name := strings.TrimPrefix(key, databaseConfigPath)
|
||||
b.clearConnection(name)
|
||||
}
|
||||
}
|
||||
|
||||
// clearConnection closes the database connection and
|
||||
// removes it from the b.connections map.
|
||||
func (b *databaseBackend) clearConnection(name string) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
db.Close()
|
||||
delete(b.connections, name)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
||||
// Plugin has shutdown, close it so next call can reconnect.
|
||||
if err == rpc.ErrShutdown {
|
||||
b.Lock()
|
||||
b.clearConnection(name)
|
||||
b.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The database backend supports using many different databases
|
||||
as secret backends, including but not limited to:
|
||||
cassandra, mssql, mysql, postgres
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "database/config/" path.
|
||||
`
|
|
@ -0,0 +1,766 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testImagePull sync.Once
|
||||
)
|
||||
|
||||
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
|
||||
if os.Getenv("PG_URL") != "" {
|
||||
return func() {}, os.Getenv("PG_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
// This will cause a validation to run
|
||||
resp, err := b.HandleRequest(&logical.Request{
|
||||
Storage: s,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/postgresql",
|
||||
Data: map[string]interface{}{
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"connection_url": retURL,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
// It's likely not up and running yet, so return error and try again
|
||||
return fmt.Errorf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected warning")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"database": Factory,
|
||||
},
|
||||
}
|
||||
|
||||
handler1 := stdhttp.NewServeMux()
|
||||
handler2 := stdhttp.NewServeMux()
|
||||
handler3 := stdhttp.NewServeMux()
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false)
|
||||
handler1.Handle("/", http.Handler(cores[0].Core))
|
||||
handler2.Handle("/", http.Handler(cores[1].Core))
|
||||
handler3.Handle("/", http.Handler(cores[2].Core))
|
||||
|
||||
core := cores[0]
|
||||
|
||||
sys := vault.TestDynamicSystemView(core.Core)
|
||||
vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain")
|
||||
|
||||
return cores, sys
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := []byte(vault.TestClusterCACert)
|
||||
tmpfile, err := ioutil.TempFile("", "example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := []string{"--ca-cert=" + tmpfile.Name()}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
postgresql.Run(apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"verify_connection": false,
|
||||
"allowed_roles": []string{"*"},
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
expected := map[string]interface{}{
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"connection_details": map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
},
|
||||
"allowed_roles": []string{"*"},
|
||||
}
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
|
||||
if !reflect.DeepEqual(expected, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
|
||||
// Configure a connection
|
||||
data := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Create a role
|
||||
data = map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err := b.HandleRequest(req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
// Revoke creds
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: config.StorageView,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": credsResp.Data["username"],
|
||||
"role": "plugin-role-test",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
if testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should not exist")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBackend_connectionCrud(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
|
||||
// Configure a connection
|
||||
data := map[string]interface{}{
|
||||
"connection_url": "test",
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"verify_connection": false,
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Create a role
|
||||
data = map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"revocation_statements": defaultRevocationSQL,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Update the connection
|
||||
data = map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Read connection
|
||||
expected := map[string]interface{}{
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"connection_details": map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
}
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
|
||||
if !reflect.DeepEqual(expected, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
|
||||
}
|
||||
|
||||
// Reset Connection
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "reset/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err := b.HandleRequest(req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
// Delete Connection
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Read connection
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Should be empty
|
||||
if resp != nil {
|
||||
t.Fatal("Expected response to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
|
||||
// Configure a connection
|
||||
data := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Create a role
|
||||
data = map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"revocation_statements": defaultRevocationSQL,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Read the role
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
expected := dbplugin.Statements{
|
||||
CreationStatements: testRole,
|
||||
RevocationStatements: defaultRevocationSQL,
|
||||
}
|
||||
|
||||
var actual dbplugin.Statements
|
||||
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual)
|
||||
}
|
||||
|
||||
// Delete the role
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Read the role
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Should be empty
|
||||
if resp != nil {
|
||||
t.Fatal("Expected response to be nil")
|
||||
}
|
||||
}
|
||||
func TestBackend_allowedRoles(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
|
||||
// Configure a connection
|
||||
data := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Create a denied and an allowed role
|
||||
data = map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/denied",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
data = map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/allowed",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds from denied role, should fail
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/denied",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err := b.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err)
|
||||
}
|
||||
|
||||
// update connection with * allowed roles connection
|
||||
data = map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": "*",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds, should work.
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/allowed",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err = b.HandleRequest(req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
// update connection with allowed roles
|
||||
data = map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": "allow, allowed",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds from denied role, should fail
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/denied",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err = b.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err)
|
||||
}
|
||||
|
||||
// Get creds from allowed role, should work.
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/allowed",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err = b.HandleRequest(req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
conn, err := pq.ParseURL(connURL)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn += " timezone=utc"
|
||||
|
||||
db, err := sql.Open("postgres", conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
returnedRows := func() int {
|
||||
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(d.Username)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
i++
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
return returnedRows() == 2
|
||||
}
|
||||
|
||||
const testRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const defaultRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
`
|
|
@ -0,0 +1,132 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
|
||||
// method to also call Kill() on the plugin.Client.
|
||||
type DatabasePluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
*databasePluginRPCClient
|
||||
}
|
||||
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.databasePluginRPCClient.Close()
|
||||
dc.client.Kill()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// newPluginClient returns a databaseRPCClient with a connection to a running
|
||||
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
|
||||
// plugin is killed on call of Close().
|
||||
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner) (Database, error) {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
var pluginMap = map[string]plugin.Plugin{
|
||||
"database": new(DatabasePlugin),
|
||||
}
|
||||
|
||||
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect via RPC
|
||||
rpcClient, err := client.Client()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request the plugin
|
||||
raw, err := rpcClient.Dispense("database")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
databaseRPC := raw.(*databasePluginRPCClient)
|
||||
|
||||
// Wrap RPC implimentation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
databasePluginRPCClient: databaseRPC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequest{
|
||||
Statements: statements,
|
||||
UsernamePrefix: usernamePrefix,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
|
||||
req := RevokeUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequest{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
// ---- Tracing Middleware Domain ----
|
||||
|
||||
// databaseTracingMiddleware wraps a implementation of Database and executes
|
||||
// trace logging on function call.
|
||||
type databaseTracingMiddleware struct {
|
||||
next Database
|
||||
logger log.Logger
|
||||
|
||||
typeStr string
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.CreateUser(statements, usernamePrefix, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
// ---- Metrics Middleware Domain ----
|
||||
|
||||
// databaseMetricsMiddleware wraps an implementation of Databases and on
|
||||
// function call logs metrics about this instance.
|
||||
type databaseMetricsMiddleware struct {
|
||||
next Database
|
||||
|
||||
typeStr string
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||
return mw.next.CreateUser(statements, usernamePrefix, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Close"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "Close"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
|
||||
return mw.next.Close()
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
// Database is the interface that all database objects must implement.
|
||||
type Database interface {
|
||||
Type() (string, error)
|
||||
CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
|
||||
Initialize(config map[string]interface{}, verifyConnection bool) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var db Database
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
// from the pluginRunner. Then cast it to a Database.
|
||||
dbRaw, err := pluginRunner.BuiltinFactory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting plugin type: %s", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
db, ok = dbRaw.(Database)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
||||
}
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = newPluginClient(sys, pluginRunner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
typeStr, err := db.Type()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting plugin type: %s", err)
|
||||
}
|
||||
|
||||
// Wrap with metrics middleware
|
||||
db = &databaseMetricsMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
if logger.IsTrace() {
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// handshakeConfigs are used to just do a basic handshake between
|
||||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||
// This prevents users from executing bad plugins or executing a plugin
|
||||
// directory. It is a UX feature, not a security feature.
|
||||
var handshakeConfig = plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
|
||||
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
||||
}
|
||||
|
||||
// DatabasePlugin implements go-plugin's Plugin interface. It has methods for
|
||||
// retrieving a server and a client instance of the plugin.
|
||||
type DatabasePlugin struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
|
||||
return &databasePluginRPCServer{impl: d.impl}, nil
|
||||
}
|
||||
|
||||
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
||||
return &databasePluginRPCClient{client: c}, nil
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements Statements
|
||||
UsernamePrefix string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
|
||||
// ---- RPC Response Args Domain ----
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
|
@ -0,0 +1,248 @@
|
|||
package dbplugin_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
type mockPlugin struct {
|
||||
users map[string][]string
|
||||
}
|
||||
|
||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
err = errors.New("err")
|
||||
if usernamePrefix == "" || expiration.IsZero() {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if _, ok := m.users[usernamePrefix]; ok {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
m.users[usernamePrefix] = []string{password}
|
||||
|
||||
return usernamePrefix, "test", nil
|
||||
}
|
||||
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
err := errors.New("err")
|
||||
if username == "" || expiration.IsZero() {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := m.users[username]; !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
err := errors.New("err")
|
||||
if username == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := m.users[username]; !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.users, username)
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
|
||||
err := errors.New("err")
|
||||
if len(conf) != 1 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Close() error {
|
||||
m.users = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) {
|
||||
coreConfig := &vault.CoreConfig{}
|
||||
|
||||
handler1 := stdhttp.NewServeMux()
|
||||
handler2 := stdhttp.NewServeMux()
|
||||
handler3 := stdhttp.NewServeMux()
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false)
|
||||
handler1.Handle("/", http.Handler(cores[0].Core))
|
||||
handler2.Handle("/", http.Handler(cores[1].Core))
|
||||
handler3.Handle("/", http.Handler(cores[2].Core))
|
||||
|
||||
core := cores[0]
|
||||
|
||||
sys := vault.TestDynamicSystemView(core.Core)
|
||||
vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main")
|
||||
|
||||
return cores, sys
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
plugin := &mockPlugin{
|
||||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
args := []string{"--tls-skip-verify=true"}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
func TestPlugin_Initialize(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_CreateUser(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if us != "test" || pw != "test" {
|
||||
t.Fatal("expected username and password to be 'test'")
|
||||
}
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_RenewUser(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_RevokeUser(t *testing.T) {
|
||||
cores, sys := getCore(t)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
)
|
||||
|
||||
// Serve is called from within a plugin and wraps the provided
|
||||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
dbPlugin := &DatabasePlugin{
|
||||
impl: db,
|
||||
}
|
||||
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
var pluginMap = map[string]plugin.Plugin{
|
||||
"database": dbPlugin,
|
||||
}
|
||||
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
TLSProvider: tlsProvider,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(args.Statements, args.Username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,270 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
var (
|
||||
respErrEmptyPluginName = "empty plugin name"
|
||||
respErrEmptyName = "empty name attribute given"
|
||||
)
|
||||
|
||||
// DatabaseConfig is used by the Factory function to configure a Database
|
||||
// object.
|
||||
type DatabaseConfig struct {
|
||||
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
|
||||
// ConnectionDetails stores the database specific connection settings needed
|
||||
// by each database type.
|
||||
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
|
||||
}
|
||||
|
||||
// pathResetConnection configures a path to reset a plugin.
|
||||
func pathResetConnection(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of this database connection",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionReset(),
|
||||
},
|
||||
|
||||
HelpSynopsis: pathResetConnectionHelpSyn,
|
||||
HelpDescription: pathResetConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionReset resets a plugin by closing the existing instance and
|
||||
// creating a new one.
|
||||
func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
// Grab the mutex lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Close plugin and delete the entry in the connections cache.
|
||||
b.clearConnection(name)
|
||||
|
||||
// Execute plugin again, we don't need the object so throw away.
|
||||
_, err := b.createDBObj(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// pathConfigurePluginConnection returns a configured framework.Path setup to
|
||||
// operate on plugins.
|
||||
func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: fmt.Sprintf("config/%s", framework.GenericNameRegex("name")),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of this database connection",
|
||||
},
|
||||
|
||||
"plugin_name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The name of a builtin or previously registered
|
||||
plugin known to vault. This endpoint will create an instance of
|
||||
that plugin type.`,
|
||||
},
|
||||
|
||||
"verify_connection": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If true, the connection details are verified by
|
||||
actually connecting to the database. Defaults to true.`,
|
||||
},
|
||||
|
||||
"allowed_roles": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `Comma separated string or array of the role names
|
||||
allowed to get creds from this database connection. If empty no
|
||||
roles are allowed. If "*" all roles are allowed.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.connectionWriteHandler(),
|
||||
logical.ReadOperation: b.connectionReadHandler(),
|
||||
logical.DeleteOperation: b.connectionDeleteHandler(),
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// connectionReadHandler reads out the connection configuration
|
||||
func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &logical.Response{
|
||||
Data: structs.New(config).Map(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// connectionDeleteHandler deletes the connection configuration
|
||||
func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
err := req.Storage.Delete(fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to delete connection configuration")
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
if _, ok := b.connections[name]; ok {
|
||||
err = b.connections[name].Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delete(b.connections, name)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// connectionWriteHandler returns a handler function for creating and updating
|
||||
// both builtin and plugin database types.
|
||||
func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
pluginName := data.Get("plugin_name").(string)
|
||||
if pluginName == "" {
|
||||
return logical.ErrorResponse(respErrEmptyPluginName), nil
|
||||
}
|
||||
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
|
||||
allowedRoles := data.Get("allowed_roles").([]string)
|
||||
|
||||
// Remove these entries from the data before we store it keyed under
|
||||
// ConnectionDetails.
|
||||
delete(data.Raw, "name")
|
||||
delete(data.Raw, "plugin_name")
|
||||
delete(data.Raw, "allowed_roles")
|
||||
delete(data.Raw, "verify_connection")
|
||||
|
||||
config := &DatabaseConfig{
|
||||
ConnectionDetails: data.Raw,
|
||||
PluginName: pluginName,
|
||||
AllowedRoles: allowedRoles,
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, verifyConnection)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
// Grab the mutex lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Close and remove the old connection
|
||||
b.clearConnection(name)
|
||||
|
||||
// Save the new connection
|
||||
b.connections[name] = db
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure connection details to a database plugin.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection details used to connect to a particular
|
||||
database. This path runs the provided plugin name and passes the configured
|
||||
connection details to the plugin. See the documentation for the plugin specified
|
||||
for a full list of accepted connection details.
|
||||
|
||||
In addition to the database specific connection details, this endpoint also
|
||||
accepts:
|
||||
|
||||
* "plugin_name" (required) - The name of a builtin or previously registered
|
||||
plugin known to vault. This endpoint will create an instance of that
|
||||
plugin type.
|
||||
|
||||
* "verify_connection" (default: true) - A boolean value denoting if the plugin should verify
|
||||
it is able to connect to the database using the provided connection
|
||||
details.
|
||||
`
|
||||
|
||||
const pathResetConnectionHelpSyn = `
|
||||
Resets a database plugin.
|
||||
`
|
||||
|
||||
const pathResetConnectionHelpDesc = `
|
||||
This path resets the database connection by closing the existing database plugin
|
||||
instance and running a new one.
|
||||
`
|
|
@ -0,0 +1,106 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathCredsCreate(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathCredsCreateRead(),
|
||||
},
|
||||
|
||||
HelpSynopsis: pathCredsCreateReadHelpSyn,
|
||||
HelpDescription: pathCredsCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If role name isn't in the database's allowed roles, send back a
|
||||
// permission denied.
|
||||
if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) {
|
||||
return nil, logical.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
var unlockFunc func() = b.RUnlock
|
||||
|
||||
// Get the Database object
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
expiration := time.Now().Add(role.DefaultTTL)
|
||||
|
||||
// Create the user
|
||||
username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
b.closeIfShutdown(role.DBName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
resp.Secret.TTL = role.DefaultTTL
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
const pathCredsCreateReadHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathCredsCreateReadHelpDesc = `
|
||||
This path reads database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
|
@ -0,0 +1,233 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathListRoles(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList(),
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
|
||||
"db_name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the database this role acts on.",
|
||||
},
|
||||
"creation_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Statements to be executed to create a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}',
|
||||
'{{password}}', and '{{expiration}}' values will be substituted.`,
|
||||
},
|
||||
"revocation_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
"renew_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Statements to be executed to renew a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
"rollback_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
|
||||
"default_ttl": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Default ttl for role.",
|
||||
},
|
||||
|
||||
"max_ttl": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Maximum time a credential is valid for",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead(),
|
||||
logical.UpdateOperation: b.pathRoleCreate(),
|
||||
logical.DeleteOperation: b.pathRoleDelete(),
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("role/" + data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"db_name": role.DBName,
|
||||
"creation_statements": role.Statements.CreationStatements,
|
||||
"revocation_statements": role.Statements.RevocationStatements,
|
||||
"rollback_statements": role.Statements.RollbackStatements,
|
||||
"renew_statements": role.Statements.RenewStatements,
|
||||
"default_ttl": role.DefaultTTL.Seconds(),
|
||||
"max_ttl": role.MaxTTL.Seconds(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleList() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("empty role name attribute given"), nil
|
||||
}
|
||||
|
||||
dbName := data.Get("db_name").(string)
|
||||
if dbName == "" {
|
||||
return logical.ErrorResponse("empty database name attribute given"), nil
|
||||
}
|
||||
|
||||
// Get statements
|
||||
creationStmts := data.Get("creation_statements").(string)
|
||||
revocationStmts := data.Get("revocation_statements").(string)
|
||||
rollbackStmts := data.Get("rollback_statements").(string)
|
||||
renewStmts := data.Get("renew_statements").(string)
|
||||
|
||||
// Get TTLs
|
||||
defaultTTLRaw := data.Get("default_ttl").(int)
|
||||
maxTTLRaw := data.Get("max_ttl").(int)
|
||||
defaultTTL := time.Duration(defaultTTLRaw) * time.Second
|
||||
maxTTL := time.Duration(maxTTLRaw) * time.Second
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: creationStmts,
|
||||
RevocationStatements: revocationStmts,
|
||||
RollbackStatements: rollbackStmts,
|
||||
RenewStatements: renewStmts,
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
DBName: dbName,
|
||||
Statements: statements,
|
||||
DefaultTTL: defaultTTL,
|
||||
MaxTTL: maxTTL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "db_name" parameter is required and configures the name of the database
|
||||
connection to use.
|
||||
|
||||
The "creation_statements" parameter customizes the string used to create the
|
||||
credentials. This can be a sequence of SQL queries, or other statement formats
|
||||
for a particular database type. Some substitution will be done to the statement
|
||||
strings for certain keys. The names of the variables must be surrounded by "{{"
|
||||
and "}}" to be replaced.
|
||||
|
||||
* "name" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
* "expiration" - The timestamp when this user will expire.
|
||||
|
||||
Example of a decent creation_statements for a postgresql database plugin:
|
||||
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
|
||||
The "revocation_statements" parameter customizes the statement string used to
|
||||
revoke a user. Example of a decent revocation_statements for a postgresql
|
||||
database plugin:
|
||||
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
|
||||
The "renew_statements" parameter customizes the statement string used to renew a
|
||||
user.
|
||||
The "rollback_statements' parameter customizes the statement string used to
|
||||
rollback a change if needed.
|
||||
`
|
|
@ -0,0 +1,139 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *databaseBackend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{},
|
||||
|
||||
Renew: b.secretCredsRenew(),
|
||||
Revoke: b.secretCredsRevoke(),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("error during renew: could not find role with name %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
||||
resp, err := f(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
var unlockFunc func() = b.RUnlock
|
||||
|
||||
// Get the Database object
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
err := db.RenewUser(role.Statements, username, expireTime)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
b.closeIfShutdown(role.DBName, err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
var resp *logical.Response
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no role name was provided")
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
var unlockFunc func() = b.RUnlock
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
b.closeIfShutdown(role.DBName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package totp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend(conf).Setup(conf)
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathListKeys(&b),
|
||||
pathKeys(&b),
|
||||
pathCode(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The TOTP backend dynamically generates time-based one-time use passwords.
|
||||
`
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,110 @@
|
|||
package totp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
otplib "github.com/pquerna/otp"
|
||||
totplib "github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
func pathCode(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "code/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key.",
|
||||
},
|
||||
"code": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "TOTP code to be validated.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathReadCode,
|
||||
logical.UpdateOperation: b.pathValidateCode,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathCodeHelpSyn,
|
||||
HelpDescription: pathCodeHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathReadCode(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the key
|
||||
key, err := b.Key(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown key: %s", name)), nil
|
||||
}
|
||||
|
||||
// Generate password using totp library
|
||||
totpToken, err := totplib.GenerateCodeCustom(key.Key, time.Now(), totplib.ValidateOpts{
|
||||
Period: key.Period,
|
||||
Digits: key.Digits,
|
||||
Algorithm: key.Algorithm,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"code": totpToken,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathValidateCode(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
code := data.Get("code").(string)
|
||||
|
||||
// Enforce input value requirements
|
||||
if code == "" {
|
||||
return logical.ErrorResponse("the code value is required"), nil
|
||||
}
|
||||
|
||||
// Get the key's stored values
|
||||
key, err := b.Key(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown key: %s", name)), nil
|
||||
}
|
||||
|
||||
valid, err := totplib.ValidateCustom(code, key.Key, time.Now(), totplib.ValidateOpts{
|
||||
Period: key.Period,
|
||||
Skew: key.Skew,
|
||||
Digits: key.Digits,
|
||||
Algorithm: key.Algorithm,
|
||||
})
|
||||
if err != nil && err != otplib.ErrValidateInputInvalidLength {
|
||||
return logical.ErrorResponse("an error occured while validating the code"), err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"valid": valid,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
const pathCodeHelpSyn = `
|
||||
Request time-based one-time use password or validate a password for a certain key .
|
||||
`
|
||||
const pathCodeHelpDesc = `
|
||||
This path generates and validates time-based one-time use passwords for a certain key.
|
||||
|
||||
`
|
|
@ -0,0 +1,424 @@
|
|||
package totp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
otplib "github.com/pquerna/otp"
|
||||
totplib "github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
func pathListKeys(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathKeyList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathKeyHelpSyn,
|
||||
HelpDescription: pathKeyHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathKeys(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key.",
|
||||
},
|
||||
|
||||
"generate": {
|
||||
Type: framework.TypeBool,
|
||||
Default: false,
|
||||
Description: "Determines if a key should be generated by Vault or if a key is being passed from another service.",
|
||||
},
|
||||
|
||||
"exported": {
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: "Determines if a QR code and url are returned upon generating a key. Only used if generate is true.",
|
||||
},
|
||||
|
||||
"key_size": {
|
||||
Type: framework.TypeInt,
|
||||
Default: 20,
|
||||
Description: "Determines the size in bytes of the generated key. Only used if generate is true.",
|
||||
},
|
||||
|
||||
"key": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The shared master key used to generate a TOTP token. Only used if generate is false.",
|
||||
},
|
||||
|
||||
"issuer": {
|
||||
Type: framework.TypeString,
|
||||
Description: `The name of the key's issuing organization. Required if generate is true.`,
|
||||
},
|
||||
|
||||
"account_name": {
|
||||
Type: framework.TypeString,
|
||||
Description: `The name of the account associated with the key. Required if generate is true.`,
|
||||
},
|
||||
|
||||
"period": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Default: 30,
|
||||
Description: `The length of time used to generate a counter for the TOTP token calculation.`,
|
||||
},
|
||||
|
||||
"algorithm": {
|
||||
Type: framework.TypeString,
|
||||
Default: "SHA1",
|
||||
Description: `The hashing algorithm used to generate the TOTP token. Options include SHA1, SHA256 and SHA512.`,
|
||||
},
|
||||
|
||||
"digits": {
|
||||
Type: framework.TypeInt,
|
||||
Default: 6,
|
||||
Description: `The number of digits in the generated TOTP token. This value can either be 6 or 8.`,
|
||||
},
|
||||
|
||||
"skew": {
|
||||
Type: framework.TypeInt,
|
||||
Default: 1,
|
||||
Description: `The number of delay periods that are allowed when validating a TOTP token. This value can either be 0 or 1. Only used if generate is true.`,
|
||||
},
|
||||
|
||||
"qr_size": {
|
||||
Type: framework.TypeInt,
|
||||
Default: 200,
|
||||
Description: `The pixel size of the generated square QR code. Only used if generate is true and exported is true. If this value is 0, a QR code will not be returned.`,
|
||||
},
|
||||
|
||||
"url": {
|
||||
Type: framework.TypeString,
|
||||
Description: `A TOTP url string containing all of the parameters for key setup. Only used if generate is false.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathKeyRead,
|
||||
logical.UpdateOperation: b.pathKeyCreate,
|
||||
logical.DeleteOperation: b.pathKeyDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathKeyHelpSyn,
|
||||
HelpDescription: pathKeyHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Key(s logical.Storage, n string) (*keyEntry, error) {
|
||||
entry, err := s.Get("key/" + n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result keyEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathKeyDelete(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("key/" + data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathKeyRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
key, err := b.Key(req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Translate algorithm back to string
|
||||
algorithm := key.Algorithm.String()
|
||||
|
||||
// Return values of key
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"issuer": key.Issuer,
|
||||
"account_name": key.AccountName,
|
||||
"period": key.Period,
|
||||
"algorithm": algorithm,
|
||||
"digits": key.Digits,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathKeyList(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("key/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathKeyCreate(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
generate := data.Get("generate").(bool)
|
||||
exported := data.Get("exported").(bool)
|
||||
keyString := data.Get("key").(string)
|
||||
issuer := data.Get("issuer").(string)
|
||||
accountName := data.Get("account_name").(string)
|
||||
period := data.Get("period").(int)
|
||||
algorithm := data.Get("algorithm").(string)
|
||||
digits := data.Get("digits").(int)
|
||||
skew := data.Get("skew").(int)
|
||||
qrSize := data.Get("qr_size").(int)
|
||||
keySize := data.Get("key_size").(int)
|
||||
inputURL := data.Get("url").(string)
|
||||
|
||||
if generate {
|
||||
if keyString != "" {
|
||||
return logical.ErrorResponse("a key should not be passed if generate is true"), nil
|
||||
}
|
||||
if inputURL != "" {
|
||||
return logical.ErrorResponse("a url should not be passed if generate is true"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read parameters from url if given
|
||||
if inputURL != "" {
|
||||
//Parse url
|
||||
urlObject, err := url.Parse(inputURL)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("an error occured while parsing url string"), err
|
||||
}
|
||||
|
||||
//Set up query object
|
||||
urlQuery := urlObject.Query()
|
||||
path := strings.TrimPrefix(urlObject.Path, "/")
|
||||
index := strings.Index(path, ":")
|
||||
|
||||
//Read issuer
|
||||
urlIssuer := urlQuery.Get("issuer")
|
||||
if urlIssuer != "" {
|
||||
issuer = urlIssuer
|
||||
} else {
|
||||
if index != -1 {
|
||||
issuer = path[:index]
|
||||
}
|
||||
}
|
||||
|
||||
//Read account name
|
||||
if index == -1 {
|
||||
accountName = path
|
||||
} else {
|
||||
accountName = path[index+1:]
|
||||
}
|
||||
|
||||
//Read key string
|
||||
keyString = urlQuery.Get("secret")
|
||||
|
||||
//Read period
|
||||
periodQuery := urlQuery.Get("period")
|
||||
if periodQuery != "" {
|
||||
periodInt, err := strconv.Atoi(periodQuery)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("an error occured while parsing period value in url"), err
|
||||
}
|
||||
period = periodInt
|
||||
}
|
||||
|
||||
//Read digits
|
||||
digitsQuery := urlQuery.Get("digits")
|
||||
if digitsQuery != "" {
|
||||
digitsInt, err := strconv.Atoi(digitsQuery)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("an error occured while parsing digits value in url"), err
|
||||
}
|
||||
digits = digitsInt
|
||||
}
|
||||
|
||||
//Read algorithm
|
||||
algorithmQuery := urlQuery.Get("algorithm")
|
||||
if algorithmQuery != "" {
|
||||
algorithm = algorithmQuery
|
||||
}
|
||||
}
|
||||
|
||||
// Translate digits and algorithm to a format the totp library understands
|
||||
var keyDigits otplib.Digits
|
||||
switch digits {
|
||||
case 6:
|
||||
keyDigits = otplib.DigitsSix
|
||||
case 8:
|
||||
keyDigits = otplib.DigitsEight
|
||||
default:
|
||||
return logical.ErrorResponse("the digits value can only be 6 or 8"), nil
|
||||
}
|
||||
|
||||
var keyAlgorithm otplib.Algorithm
|
||||
switch algorithm {
|
||||
case "SHA1":
|
||||
keyAlgorithm = otplib.AlgorithmSHA1
|
||||
case "SHA256":
|
||||
keyAlgorithm = otplib.AlgorithmSHA256
|
||||
case "SHA512":
|
||||
keyAlgorithm = otplib.AlgorithmSHA512
|
||||
default:
|
||||
return logical.ErrorResponse("the algorithm value is not valid"), nil
|
||||
}
|
||||
|
||||
// Enforce input value requirements
|
||||
if period <= 0 {
|
||||
return logical.ErrorResponse("the period value must be greater than zero"), nil
|
||||
}
|
||||
|
||||
switch skew {
|
||||
case 0:
|
||||
case 1:
|
||||
default:
|
||||
return logical.ErrorResponse("the skew value must be 0 or 1"), nil
|
||||
}
|
||||
|
||||
// QR size can be zero but it shouldn't be negative
|
||||
if qrSize < 0 {
|
||||
return logical.ErrorResponse("the qr_size value must be greater than or equal to zero"), nil
|
||||
}
|
||||
|
||||
if keySize <= 0 {
|
||||
return logical.ErrorResponse("the key_size value must be greater than zero"), nil
|
||||
}
|
||||
|
||||
// Period, Skew and Key Size need to be unsigned ints
|
||||
uintPeriod := uint(period)
|
||||
uintSkew := uint(skew)
|
||||
uintKeySize := uint(keySize)
|
||||
|
||||
var response *logical.Response
|
||||
|
||||
switch generate {
|
||||
case true:
|
||||
// If the key is generated, Account Name and Issuer are required.
|
||||
if accountName == "" {
|
||||
return logical.ErrorResponse("the account_name value is required for generated keys"), nil
|
||||
}
|
||||
|
||||
if issuer == "" {
|
||||
return logical.ErrorResponse("the issuer value is required for generated keys"), nil
|
||||
}
|
||||
|
||||
// Generate a new key
|
||||
keyObject, err := totplib.Generate(totplib.GenerateOpts{
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
Period: uintPeriod,
|
||||
Digits: keyDigits,
|
||||
Algorithm: keyAlgorithm,
|
||||
SecretSize: uintKeySize,
|
||||
})
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("an error occured while generating a key"), err
|
||||
}
|
||||
|
||||
// Get key string value
|
||||
keyString = keyObject.Secret()
|
||||
|
||||
// Skip returning the QR code and url if exported is set to false
|
||||
if exported {
|
||||
// Prepare the url and barcode
|
||||
urlString := keyObject.String()
|
||||
|
||||
// Don't include QR code is size is set to zero
|
||||
if qrSize == 0 {
|
||||
response = &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"url": urlString,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
barcode, err := keyObject.Image(qrSize, qrSize)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("an error occured while generating a QR code image"), err
|
||||
}
|
||||
|
||||
var buff bytes.Buffer
|
||||
png.Encode(&buff, barcode)
|
||||
b64Barcode := base64.StdEncoding.EncodeToString(buff.Bytes())
|
||||
response = &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"url": urlString,
|
||||
"barcode": b64Barcode,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
if keyString == "" {
|
||||
return logical.ErrorResponse("the key value is required"), nil
|
||||
}
|
||||
|
||||
_, err := base32.StdEncoding.DecodeString(keyString)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"invalid key value: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("key/"+name, &keyEntry{
|
||||
Key: keyString,
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
Period: uintPeriod,
|
||||
Algorithm: keyAlgorithm,
|
||||
Digits: keyDigits,
|
||||
Skew: uintSkew,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
type keyEntry struct {
|
||||
Key string `json:"key" mapstructure:"key" structs:"key"`
|
||||
Issuer string `json:"issuer" mapstructure:"issuer" structs:"issuer"`
|
||||
AccountName string `json:"account_name" mapstructure:"account_name" structs:"account_name"`
|
||||
Period uint `json:"period" mapstructure:"period" structs:"period"`
|
||||
Algorithm otplib.Algorithm `json:"algorithm" mapstructure:"algorithm" structs:"algorithm"`
|
||||
Digits otplib.Digits `json:"digits" mapstructure:"digits" structs:"digits"`
|
||||
Skew uint `json:"skew" mapstructure:"skew" structs:"skew"`
|
||||
}
|
||||
|
||||
const pathKeyHelpSyn = `
|
||||
Manage the keys that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathKeyHelpDesc = `
|
||||
This path lets you manage the keys that can be created with this backend.
|
||||
|
||||
`
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/hashicorp/vault/builtin/logical/aws"
|
||||
"github.com/hashicorp/vault/builtin/logical/cassandra"
|
||||
"github.com/hashicorp/vault/builtin/logical/consul"
|
||||
"github.com/hashicorp/vault/builtin/logical/database"
|
||||
"github.com/hashicorp/vault/builtin/logical/mongodb"
|
||||
"github.com/hashicorp/vault/builtin/logical/mssql"
|
||||
"github.com/hashicorp/vault/builtin/logical/mysql"
|
||||
|
@ -28,6 +29,7 @@ import (
|
|||
"github.com/hashicorp/vault/builtin/logical/postgresql"
|
||||
"github.com/hashicorp/vault/builtin/logical/rabbitmq"
|
||||
"github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
"github.com/hashicorp/vault/builtin/logical/totp"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
|
||||
"github.com/hashicorp/vault/audit"
|
||||
|
@ -91,6 +93,8 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
|||
"mysql": mysql.Factory,
|
||||
"ssh": ssh.Factory,
|
||||
"rabbitmq": rabbitmq.Factory,
|
||||
"database": database.Factory,
|
||||
"totp": totp.Factory,
|
||||
},
|
||||
ShutdownCh: command.MakeShutdownCh(),
|
||||
SighupCh: command.MakeSighupCh(),
|
||||
|
|
|
@ -238,6 +238,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
DefaultLeaseTTL: config.DefaultLeaseTTL,
|
||||
ClusterName: config.ClusterName,
|
||||
CacheSize: config.CacheSize,
|
||||
PluginDirectory: config.PluginDirectory,
|
||||
}
|
||||
if dev {
|
||||
coreConfig.DevToken = devRootTokenID
|
||||
|
|
|
@ -42,7 +42,8 @@ type Config struct {
|
|||
DefaultLeaseTTL time.Duration `hcl:"-"`
|
||||
DefaultLeaseTTLRaw interface{} `hcl:"default_lease_ttl"`
|
||||
|
||||
ClusterName string `hcl:"cluster_name"`
|
||||
ClusterName string `hcl:"cluster_name"`
|
||||
PluginDirectory string `hcl:"plugin_directory"`
|
||||
}
|
||||
|
||||
// DevConfig is a Config that is used for dev mode of Vault.
|
||||
|
@ -272,6 +273,11 @@ func (c *Config) Merge(c2 *Config) *Config {
|
|||
result.EnableUI = c2.EnableUI
|
||||
}
|
||||
|
||||
result.PluginDirectory = c.PluginDirectory
|
||||
if c2.PluginDirectory != "" {
|
||||
result.PluginDirectory = c2.PluginDirectory
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
@ -363,6 +369,7 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) {
|
|||
"default_lease_ttl",
|
||||
"max_lease_ttl",
|
||||
"cluster_name",
|
||||
"plugin_directory",
|
||||
}
|
||||
if err := checkHCLKeys(list, valid); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
package builtinplugins
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/plugins/database/cassandra"
|
||||
"github.com/hashicorp/vault/plugins/database/mssql"
|
||||
"github.com/hashicorp/vault/plugins/database/mysql"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
)
|
||||
|
||||
type BuiltinFactory func() (interface{}, error)
|
||||
|
||||
var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{
|
||||
// These four plugins all use the same mysql implementation but with
|
||||
// different username settings passed by the constructor.
|
||||
"mysql-database-plugin": mysql.New(mysql.DisplayNameLen, mysql.UsernameLen),
|
||||
"mysql-aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen),
|
||||
"mysql-rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen),
|
||||
"mysql-legacy-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen),
|
||||
|
||||
"postgresql-database-plugin": postgresql.New,
|
||||
"mssql-database-plugin": mssql.New,
|
||||
"cassandra-database-plugin": cassandra.New,
|
||||
}
|
||||
|
||||
func Get(name string) (BuiltinFactory, bool) {
|
||||
f, ok := plugins[name]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
func Keys() []string {
|
||||
keys := make([]string, len(plugins))
|
||||
|
||||
i := 0
|
||||
for k := range plugins {
|
||||
keys[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package pluginutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/mlock"
|
||||
)
|
||||
|
||||
var (
|
||||
// PluginUnwrapTokenEnv is the ENV name used to pass the configuration for
|
||||
// enabling mlock
|
||||
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"
|
||||
)
|
||||
|
||||
// OptionallyEnableMlock determines if mlock should be called, and if so enables
|
||||
// mlock.
|
||||
func OptionallyEnableMlock() error {
|
||||
if os.Getenv(PluginMlockEnabled) == "true" {
|
||||
return mlock.LockMemory()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package pluginutil
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
)
|
||||
|
||||
// Looker defines the plugin Lookup function that looks into the plugin catalog
|
||||
// for availible plugins and returns a PluginRunner
|
||||
type Looker interface {
|
||||
LookupPlugin(string) (*PluginRunner, error)
|
||||
}
|
||||
|
||||
// Wrapper interface defines the functions needed by the runner to wrap the
|
||||
// metadata needed to run a plugin process. This includes looking up Mlock
|
||||
// configuration and wrapping data in a respose wrapped token.
|
||||
type RunnerUtil interface {
|
||||
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
||||
MlockEnabled() bool
|
||||
}
|
||||
|
||||
// LookWrapper defines the functions for both Looker and Wrapper
|
||||
type LookRunnerUtil interface {
|
||||
Looker
|
||||
RunnerUtil
|
||||
}
|
||||
|
||||
// PluginRunner defines the metadata needed to run a plugin securely with
|
||||
// go-plugin.
|
||||
type PluginRunner struct {
|
||||
Name string `json:"name"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Sha256 []byte `json:"sha256"`
|
||||
Builtin bool `json:"builtin"`
|
||||
BuiltinFactory func() (interface{}, error) `json:"-"`
|
||||
}
|
||||
|
||||
// Run takes a wrapper instance, and the go-plugin paramaters and executes a
|
||||
// plugin.
|
||||
func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) {
|
||||
// Get a CA TLS Certificate
|
||||
certBytes, key, err := generateCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use CA to sign a client cert and return a configured TLS config
|
||||
clientTLSConfig, err := createClientTLSConfig(certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use CA to sign a server cert and wrap the values in a response wrapped
|
||||
// token.
|
||||
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := exec.Command(r.Command, r.Args...)
|
||||
cmd.Env = append(cmd.Env, env...)
|
||||
// Add the response wrap token to the ENV of the plugin
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
|
||||
// Add the mlock setting to the ENV of the plugin
|
||||
if wrapper.MlockEnabled() {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
||||
}
|
||||
|
||||
secureConfig := &plugin.SecureConfig{
|
||||
Checksum: r.Sha256,
|
||||
Hash: sha256.New(),
|
||||
}
|
||||
|
||||
client := plugin.NewClient(&plugin.ClientConfig{
|
||||
HandshakeConfig: hs,
|
||||
Plugins: pluginMap,
|
||||
Cmd: cmd,
|
||||
TLSConfig: clientTLSConfig,
|
||||
SecureConfig: secureConfig,
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type APIClientMeta struct {
|
||||
// These are set by the command line flags.
|
||||
flagCACert string
|
||||
flagCAPath string
|
||||
flagClientCert string
|
||||
flagClientKey string
|
||||
flagInsecure bool
|
||||
}
|
||||
|
||||
func (f *APIClientMeta) FlagSet() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("tls settings", flag.ContinueOnError)
|
||||
|
||||
fs.StringVar(&f.flagCACert, "ca-cert", "", "")
|
||||
fs.StringVar(&f.flagCAPath, "ca-path", "", "")
|
||||
fs.StringVar(&f.flagClientCert, "client-cert", "", "")
|
||||
fs.StringVar(&f.flagClientKey, "client-key", "", "")
|
||||
fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "")
|
||||
|
||||
return fs
|
||||
}
|
||||
|
||||
func (f *APIClientMeta) GetTLSConfig() *api.TLSConfig {
|
||||
// If we need custom TLS configuration, then set it
|
||||
if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure {
|
||||
t := &api.TLSConfig{
|
||||
CACert: f.flagCACert,
|
||||
CAPath: f.flagCAPath,
|
||||
ClientCert: f.flagClientCert,
|
||||
ClientKey: f.flagClientKey,
|
||||
TLSServerName: "",
|
||||
Insecure: f.flagInsecure,
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
package pluginutil
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose/jws"
|
||||
"github.com/hashicorp/errwrap"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
)
|
||||
|
||||
var (
|
||||
// PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the
|
||||
// plugin.
|
||||
PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN"
|
||||
)
|
||||
|
||||
// generateCert is used internally to create certificates for the plugin
|
||||
// client and server.
|
||||
func generateCert() ([]byte, *ecdsa.PrivateKey, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
host, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sn, err := certutil.GenerateSerialNumber()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
},
|
||||
DNSNames: []string{host},
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: sn,
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
|
||||
if err != nil {
|
||||
return nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err)
|
||||
}
|
||||
|
||||
return certBytes, key, nil
|
||||
}
|
||||
|
||||
// createClientTLSConfig creates a signed certificate and returns a configured
|
||||
// TLS config.
|
||||
func createClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) {
|
||||
clientCert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing generated plugin certificate: %v", err)
|
||||
}
|
||||
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{certBytes},
|
||||
PrivateKey: key,
|
||||
Leaf: clientCert,
|
||||
}
|
||||
|
||||
clientCertPool := x509.NewCertPool()
|
||||
clientCertPool.AddCert(clientCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: clientCertPool,
|
||||
ServerName: clientCert.Subject.CommonName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// wrapServerConfig is used to create a server certificate and private key, then
|
||||
// wrap them in an unwrap token for later retrieval by the plugin.
|
||||
func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) {
|
||||
rawKey, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
wrapInfo, err := sys.ResponseWrapData(map[string]interface{}{
|
||||
"ServerCert": certBytes,
|
||||
"ServerKey": rawKey,
|
||||
}, time.Second*10, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return wrapInfo.Token, nil
|
||||
}
|
||||
|
||||
// VaultPluginTLSProvider is run inside a plugin and retrives the response
|
||||
// wrapped TLS certificate from vault. It returns a configured TLS Config.
|
||||
func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) {
|
||||
return func() (*tls.Config, error) {
|
||||
unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
|
||||
|
||||
// Parse the JWT and retrieve the vault address
|
||||
wt, err := jws.ParseJWT([]byte(unwrapToken))
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error decoding token: %s", err))
|
||||
}
|
||||
if wt == nil {
|
||||
return nil, errors.New("nil decoded token")
|
||||
}
|
||||
|
||||
addrRaw := wt.Claims().Get("addr")
|
||||
if addrRaw == nil {
|
||||
return nil, errors.New("decoded token does not contain primary cluster address")
|
||||
}
|
||||
vaultAddr, ok := addrRaw.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("decoded token's address not valid")
|
||||
}
|
||||
if vaultAddr == "" {
|
||||
return nil, errors.New(`no address for the vault found`)
|
||||
}
|
||||
|
||||
// Sanity check the value
|
||||
if _, err := url.Parse(vaultAddr); err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err))
|
||||
}
|
||||
|
||||
// Unwrap the token
|
||||
clientConf := api.DefaultConfig()
|
||||
clientConf.Address = vaultAddr
|
||||
if apiTLSConfig != nil {
|
||||
clientConf.ConfigureTLS(apiTLSConfig)
|
||||
}
|
||||
client, err := api.NewClient(clientConf)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during api client creation: {{err}}", err)
|
||||
}
|
||||
|
||||
secret, err := client.Logical().Unwrap(unwrapToken)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||
}
|
||||
if secret == nil {
|
||||
return nil, errors.New("error during token unwrap request secret is nil")
|
||||
}
|
||||
|
||||
// Retrieve and parse the server's certificate
|
||||
serverCertBytesRaw, ok := secret.Data["ServerCert"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
}
|
||||
|
||||
serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(serverCertBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and parse the server's private key
|
||||
serverKeyB64, ok := secret.Data["ServerKey"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unmarshalling certificate")
|
||||
}
|
||||
|
||||
serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
serverKey, err := x509.ParseECPrivateKey(serverKeyRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
}
|
||||
|
||||
// Add CA cert to the cert pool
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(serverCert)
|
||||
|
||||
// Build a certificate object out of the server's cert and private key.
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{serverCertBytes},
|
||||
PrivateKey: serverKey,
|
||||
Leaf: serverCert,
|
||||
}
|
||||
|
||||
// Setup TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: caCertPool,
|
||||
RootCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// TLS 1.2 minimum
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
}
|
|
@ -29,6 +29,19 @@ func StrListSubset(super, sub []string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// Parses a comma separated list of strings into a slice of strings.
|
||||
// The return slice will be sorted and will not contain duplicate or
|
||||
// empty items.
|
||||
func ParseDedupAndSortStrings(input string, sep string) []string {
|
||||
input = strings.TrimSpace(input)
|
||||
parsed := []string{}
|
||||
if input == "" {
|
||||
// Don't return nil
|
||||
return parsed
|
||||
}
|
||||
return RemoveDuplicates(strings.Split(input, sep), false)
|
||||
}
|
||||
|
||||
// Parses a comma separated list of strings into a slice of strings.
|
||||
// The return slice will be sorted and will not contain duplicate or
|
||||
// empty items. The values will be converted to lower case.
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package wrapping
|
||||
|
||||
import "time"
|
||||
|
||||
type ResponseWrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
|
||||
// The token containing the wrapped response
|
||||
Token string `json:"token" structs:"token" mapstructure:"token"`
|
||||
|
||||
// The creation time. This can be used with the TTL to figure out an
|
||||
// expected expiration.
|
||||
CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"`
|
||||
|
||||
// If the contained response is the output of a token creation call, the
|
||||
// created token's accessor will be accessible here
|
||||
WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"`
|
||||
|
||||
// The format to use. This doesn't get returned, it's only internal.
|
||||
Format string `json:"format" structs:"format" mapstructure:"format"`
|
||||
}
|
|
@ -595,3 +595,33 @@ func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP_Forwarding_HelpOperation(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, &vault.CoreConfig{}, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
testHelp := func(client *api.Client) {
|
||||
help, err := client.Help("auth/token")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if help == nil {
|
||||
t.Fatal("help was nil")
|
||||
}
|
||||
}
|
||||
|
||||
testHelp(cores[0].Client)
|
||||
testHelp(cores[1].Client)
|
||||
}
|
||||
|
|
12
http/help.go
12
http/help.go
|
@ -8,14 +8,18 @@ import (
|
|||
)
|
||||
|
||||
func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
// If the help parameter is not blank, then show the help
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
|
||||
// If the help parameter is not blank, then show the help. We request
|
||||
// forward because standby nodes do not have mounts and other state.
|
||||
if v := req.URL.Query().Get("help"); v != "" || req.Method == "HELP" {
|
||||
handleHelp(core, w, req)
|
||||
handleRequestForwarding(core,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleHelp(core, w, r)
|
||||
})).ServeHTTP(writer, req)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
h.ServeHTTP(writer, req)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
||||
|
@ -28,26 +28,6 @@ const (
|
|||
HTTPStatusCode = "http_status_code"
|
||||
)
|
||||
|
||||
type ResponseWrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
|
||||
// The token containing the wrapped response
|
||||
Token string `json:"token" structs:"token" mapstructure:"token"`
|
||||
|
||||
// The creation time. This can be used with the TTL to figure out an
|
||||
// expected expiration.
|
||||
CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"`
|
||||
|
||||
// If the contained response is the output of a token creation call, the
|
||||
// created token's accessor will be accessible here
|
||||
WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"`
|
||||
|
||||
// The format to use. This doesn't get returned, it's only internal.
|
||||
Format string `json:"format" structs:"format" mapstructure:"format"`
|
||||
}
|
||||
|
||||
// Response is a struct that stores the response of a request.
|
||||
// It is used to abstract the details of the higher level request protocol.
|
||||
type Response struct {
|
||||
|
@ -78,7 +58,7 @@ type Response struct {
|
|||
warnings []string `json:"warnings" structs:"warnings" mapstructure:"warnings"`
|
||||
|
||||
// Information for wrapping the response in a cubbyhole
|
||||
WrapInfo *ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
|
||||
WrapInfo *wrapping.ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -123,7 +103,7 @@ func init() {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying WrapInfo: %v", err)
|
||||
}
|
||||
ret.WrapInfo = retWrapInfo.(*ResponseWrapInfo)
|
||||
ret.WrapInfo = retWrapInfo.(*wrapping.ResponseWrapInfo)
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package logical
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
)
|
||||
|
||||
// SystemView exposes system configuration information in a safe way
|
||||
|
@ -37,6 +40,18 @@ type SystemView interface {
|
|||
|
||||
// ReplicationState indicates the state of cluster replication
|
||||
ReplicationState() consts.ReplicationState
|
||||
|
||||
// ResponseWrapData wraps the given data in a cubbyhole and returns the
|
||||
// token used to unwrap.
|
||||
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
||||
|
||||
// LookupPlugin looks into the plugin catalog for a plugin with the given
|
||||
// name. Returns a PluginRunner or an error if a plugin can not be found.
|
||||
LookupPlugin(string) (*pluginutil.PluginRunner, error)
|
||||
|
||||
// MlockEnabled returns the configuration setting for enabling mlock on
|
||||
// plugins.
|
||||
MlockEnabled() bool
|
||||
}
|
||||
|
||||
type StaticSystemView struct {
|
||||
|
@ -46,6 +61,7 @@ type StaticSystemView struct {
|
|||
TaintedVal bool
|
||||
CachingDisabledVal bool
|
||||
Primary bool
|
||||
EnableMlock bool
|
||||
ReplicationStateVal consts.ReplicationState
|
||||
}
|
||||
|
||||
|
@ -72,3 +88,15 @@ func (d StaticSystemView) CachingDisabled() bool {
|
|||
func (d StaticSystemView) ReplicationState() consts.ReplicationState {
|
||||
return d.ReplicationStateVal
|
||||
}
|
||||
|
||||
func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView")
|
||||
}
|
||||
|
||||
func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
|
||||
return nil, errors.New("LookupPlugin is not implemented in StaticSystemView")
|
||||
}
|
||||
|
||||
func (d StaticSystemView) MlockEnabled() bool {
|
||||
return d.EnableMlock
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/cassandra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := cassandra.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
||||
defaultUserDeletionCQL = `DROP USER '{{username}}';`
|
||||
cassandraTypeName = "cassandra"
|
||||
)
|
||||
|
||||
// Cassandra is an implementation of Database interface
|
||||
type Cassandra struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
// New returns a new Cassandra instance
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.CassandraConnectionProducer{}
|
||||
connProducer.Type = cassandraTypeName
|
||||
|
||||
credsProducer := &credsutil.CassandraCredentialsProducer{}
|
||||
|
||||
dbType := &Cassandra{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
// Run instantiates a Cassandra object, and runs the RPC server for the plugin
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plugins.Serve(dbType.(*Cassandra), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the TypeName for this backend
|
||||
func (c *Cassandra) Type() (string, error) {
|
||||
return cassandraTypeName, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||
session, err := c.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.(*gocql.Session), nil
|
||||
}
|
||||
|
||||
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the connection
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
creationCQL := statements.CreationStatements
|
||||
if creationCQL == "" {
|
||||
creationCQL = defaultUserCreationCQL
|
||||
}
|
||||
rollbackCQL := statements.RollbackStatements
|
||||
if rollbackCQL == "" {
|
||||
rollbackCQL = defaultUserDeletionCQL
|
||||
}
|
||||
|
||||
username, err = c.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = c.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
})).Exec()
|
||||
}
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
// RenewUser is not supported on Cassandra, so this is a no-op.
|
||||
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user.
|
||||
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
revocationCQL := statements.RevocationStatements
|
||||
if revocationCQL == "" {
|
||||
revocationCQL = defaultUserDeletionCQL
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err := session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
})).Exec()
|
||||
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
func prepareCassandraTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("CASSANDRA_HOST") != "" {
|
||||
return func() {}, os.Getenv("CASSANDRA_HOST")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
cwd, _ := os.Getwd()
|
||||
cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd)
|
||||
|
||||
ro := &dockertest.RunOptions{
|
||||
Repository: "cassandra",
|
||||
Tag: "latest",
|
||||
Mounts: []string{cassandraMountPath},
|
||||
}
|
||||
resource, err := pool.RunWithOptions(ro)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local cassandra docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("localhost:%s", resource.GetPort("9042/tcp"))
|
||||
port, _ := strconv.Atoi(resource.GetPort("9042/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
clusterConfig := gocql.NewCluster(retURL)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: "cassandra",
|
||||
Password: "cassandra",
|
||||
}
|
||||
clusterConfig.ProtoVersion = 4
|
||||
clusterConfig.Port = port
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
defer session.Close()
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to cassandra docker container: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestCassandra_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"hosts": connURL,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 4,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCassandra_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"hosts": connURL,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 4,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testCassandraRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMyCassandra_RenewUser(t *testing.T) {
|
||||
cleanup, connURL := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"hosts": connURL,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 4,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testCassandraRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCassandra_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"hosts": connURL,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 4,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testCassandraRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
clusterConfig := gocql.NewCluster(connURL)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
clusterConfig.ProtoVersion = 4
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
defer session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
const testCassandraRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;
|
||||
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};`
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/mssql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := mssql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,318 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
)
|
||||
|
||||
const msSQLTypeName = "mssql"
|
||||
|
||||
// MSSQL is an implementation of Database interface
|
||||
type MSSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = msSQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 20,
|
||||
UsernameLen: 128,
|
||||
}
|
||||
|
||||
dbType := &MSSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plugins.Serve(dbType.(*MSSQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the TypeName for this backend
|
||||
func (m *MSSQL) Type() (string, error) {
|
||||
return msSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
username, err = m.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = m.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := m.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
if statements.RevocationStatements == "" {
|
||||
return m.revokeUserDefault(username)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) revokeUserDefault(username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Query for sessions for the login so that we can kill any outstanding
|
||||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionRows.Close()
|
||||
|
||||
var revokeStmts []string
|
||||
for sessionRows.Next() {
|
||||
var sessionID int
|
||||
err = sessionRows.Scan(&sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
|
||||
}
|
||||
|
||||
// Query for database users using undocumented stored procedure for now since
|
||||
// it is the easiest way to get this information;
|
||||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var loginName, dbName, qUsername string
|
||||
var aliasName sql.NullString
|
||||
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
|
||||
}
|
||||
|
||||
// we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const dropUserSQL = `
|
||||
USE [%s]
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM sys.database_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP USER [%s]
|
||||
END
|
||||
`
|
||||
|
||||
const dropLoginSQL = `
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM master.sys.server_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP LOGIN [%s]
|
||||
END
|
||||
`
|
|
@ -0,0 +1,167 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
)
|
||||
|
||||
var (
|
||||
testMSQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func TestMSSQL_Initialize(t *testing.T) {
|
||||
if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" {
|
||||
return
|
||||
}
|
||||
connURL := os.Getenv("MSSQL_URL")
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_CreateUser(t *testing.T) {
|
||||
if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" {
|
||||
return
|
||||
}
|
||||
connURL := os.Getenv("MSSQL_URL")
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_RevokeUser(t *testing.T) {
|
||||
if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" {
|
||||
return
|
||||
}
|
||||
connURL := os.Getenv("MSSQL_URL")
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statememt
|
||||
statements.RevocationStatements = testMSSQLDrop
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
parts := strings.Split(connURL, "@")
|
||||
connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1])
|
||||
db, err := sql.Open("mssql", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testMSSQLRole = `
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
|
||||
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`
|
||||
|
||||
const testMSSQLDrop = `
|
||||
DROP USER [{{name}}];
|
||||
DROP LOGIN [{{name}}];
|
||||
`
|
|
@ -0,0 +1,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/mysql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := mysql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMysqlRevocationStmts = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
mySQLTypeName = "mysql"
|
||||
)
|
||||
|
||||
var (
|
||||
DisplayNameLen int = 10
|
||||
LegacyDisplayNameLen int = 4
|
||||
UsernameLen int = 32
|
||||
LegacyUsernameLen int = 16
|
||||
)
|
||||
|
||||
type MySQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New(displayLen, usernameLen int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = mySQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: displayLen,
|
||||
UsernameLen: usernameLen,
|
||||
}
|
||||
|
||||
dbType := &MySQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Run instantiates a MySQL object, and runs the RPC server for the plugin
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
f := New(DisplayNameLen, UsernameLen)
|
||||
dbType, err := f()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plugins.Serve(dbType.(*MySQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) Type() (string, error) {
|
||||
return mySQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
username, err = m.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = m.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := m.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
revocationStmts := statements.RevocationStatements
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
if revocationStmts == "" {
|
||||
revocationStmts = defaultMysqlRevocationStmts
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// This is not a prepared statement because not all commands are supported
|
||||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.Replace(query, "{{name}}", username, -1)
|
||||
_, err = tx.Exec(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testMySQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMySQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("MYSQL_URL") != "" {
|
||||
return func() {}, os.Getenv("MYSQL_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("mysql", "latest", []string{"MYSQL_ROOT_PASSWORD=secret"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local MySQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("root:secret@(localhost:%s)/mysql?parseTime=true", resource.GetPort("3306/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("mysql", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to MySQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMySQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
f := New(DisplayNameLen, UsernameLen)
|
||||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
f := New(DisplayNameLen, UsernameLen)
|
||||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
f := New(DisplayNameLen, UsernameLen)
|
||||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
statements.CreationStatements = testMySQLRoleWildCard
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = testMySQLRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
connURL = strings.Replace(connURL, "root:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
||||
db, err := sql.Open("mysql", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testMySQLRoleWildCard = `
|
||||
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'%';
|
||||
`
|
||||
const testMySQLRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%';
|
||||
`
|
|
@ -0,0 +1,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
|
||||
err := postgresql.Run(apiClientMeta.GetTLSConfig())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,343 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const postgreSQLTypeName string = "postgres"
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = postgreSQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 10,
|
||||
UsernameLen: 63,
|
||||
}
|
||||
|
||||
dbType := &PostgreSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
|
||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type PostgreSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Type() (string, error) {
|
||||
return postgreSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
username, err = p.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = p.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
pq.QuoteIdentifier(username),
|
||||
expirationStr)
|
||||
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query for permissions; we need to revoke permissions before we can drop
|
||||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
const initialNumRevocations = 16
|
||||
revocationStmts := make([]string, 0, initialNumRevocations)
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
err = rows.Scan(&schema)
|
||||
if err != nil {
|
||||
// keep going; remove as many permissions as possible right now
|
||||
continue
|
||||
}
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// for good measure, revoke all privileges and usage on schema public
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE USAGE ON SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dbname.Valid {
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
|
||||
pq.QuoteIdentifier(dbname.String),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// again, here, we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,313 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testPostgresImagePull sync.Once
|
||||
)
|
||||
|
||||
func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("PG_URL") != "" {
|
||||
return func() {}, os.Getenv("PG_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("postgres", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestPostgreSQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Sleep longer than the inital expiration time
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
||||
db, err := sql.Open("postgres", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testPostgresRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresReadOnlyRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresBlockStatementRole = `
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
|
||||
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
|
||||
GRANT "foo-role" TO "{{name}}";
|
||||
ALTER ROLE "{{name}}" SET search_path = foo;
|
||||
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
|
||||
`
|
||||
|
||||
var testPostgresBlockStatementRoleSlice = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
`,
|
||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
|
||||
`GRANT "foo-role" TO "{{name}}";`,
|
||||
`ALTER ROLE "{{name}}" SET search_path = foo;`,
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultPostgresRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{name}}";
|
||||
`
|
|
@ -0,0 +1,226 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
|
||||
// CassandraConnectionProducer implements ConnectionProducer and provides an
|
||||
// interface for cassandra databases to make connections.
|
||||
type CassandraConnectionProducer struct {
|
||||
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||
ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
||||
PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
|
||||
PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`
|
||||
|
||||
connectTimeout time.Duration
|
||||
certificate string
|
||||
privateKey string
|
||||
issuingCA string
|
||||
|
||||
Initialized bool
|
||||
Type string
|
||||
session *gocql.Session
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Initialized = true
|
||||
|
||||
if c.ConnectTimeoutRaw == nil {
|
||||
c.ConnectTimeoutRaw = "0s"
|
||||
}
|
||||
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid connect_timeout: %s", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(c.Hosts) == 0:
|
||||
return fmt.Errorf("hosts cannot be empty")
|
||||
case len(c.Username) == 0:
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
case len(c.Password) == 0:
|
||||
return fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
var certBundle *certutil.CertBundle
|
||||
var parsedCertBundle *certutil.ParsedCertBundle
|
||||
switch {
|
||||
case len(c.PemJSON) != 0:
|
||||
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err)
|
||||
}
|
||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error marshaling PEM information: %s", err)
|
||||
}
|
||||
c.certificate = certBundle.Certificate
|
||||
c.privateKey = certBundle.PrivateKey
|
||||
c.issuingCA = certBundle.IssuingCA
|
||||
c.TLS = true
|
||||
|
||||
case len(c.PemBundle) != 0:
|
||||
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing the given PEM information: %s", err)
|
||||
}
|
||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error marshaling PEM information: %s", err)
|
||||
}
|
||||
c.certificate = certBundle.Certificate
|
||||
c.privateKey = certBundle.PrivateKey
|
||||
c.issuingCA = certBundle.IssuingCA
|
||||
c.TLS = true
|
||||
}
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
return fmt.Errorf("error Initalizing Connection: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) Connection() (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, errNotInitialized
|
||||
}
|
||||
|
||||
// If we already have a DB, return it
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
session, err := c.createSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
c.session = session
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
c.session.Close()
|
||||
}
|
||||
|
||||
c.session = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = c.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = c.connectTimeout
|
||||
if c.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
|
||||
if len(c.certificate) > 0 && len(c.privateKey) == 0 {
|
||||
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(c.certificate) > 0 {
|
||||
certBundle.Certificate = c.certificate
|
||||
certBundle.PrivateKey = c.privateKey
|
||||
}
|
||||
if len(c.issuingCA) > 0 {
|
||||
certBundle.IssuingCA = c.issuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
||||
|
||||
if c.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Set consistency
|
||||
if c.Consistency != "" {
|
||||
consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.SetConsistency(consistencyValue)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotInitialized = errors.New("connection has not been initalized")
|
||||
)
|
||||
|
||||
// ConnectionProducer can be used as an embeded interface in the Database
|
||||
// definition. It implements the methods dealing with individual database
|
||||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Initialize(map[string]interface{}, bool) error
|
||||
Connection() (interface{}, error)
|
||||
|
||||
sync.Locker
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
// Import sql drivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
type SQLConnectionProducer struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
||||
|
||||
Type string
|
||||
maxConnectionLifetime time.Duration
|
||||
Initialized bool
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(c.ConnectionURL) == 0 {
|
||||
return fmt.Errorf("connection_url cannot be empty")
|
||||
}
|
||||
|
||||
if c.MaxOpenConnections == 0 {
|
||||
c.MaxOpenConnections = 2
|
||||
}
|
||||
|
||||
if c.MaxIdleConnections == 0 {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxIdleConnections > c.MaxOpenConnections {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxConnectionLifetimeRaw == nil {
|
||||
c.MaxConnectionLifetimeRaw = "0s"
|
||||
}
|
||||
|
||||
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid max_connection_lifetime: %s", err)
|
||||
}
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
return fmt.Errorf("error initalizing connection: %s", err)
|
||||
}
|
||||
|
||||
if err := c.db.Ping(); err != nil {
|
||||
return fmt.Errorf("error initalizing connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Initialized = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.Ping(); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.Type
|
||||
if c.Type == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := c.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
c.db.SetMaxOpenConns(c.MaxOpenConnections)
|
||||
c.db.SetMaxIdleConns(c.MaxIdleConnections)
|
||||
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
|
||||
|
||||
return c.db, nil
|
||||
}
|
||||
|
||||
// Close attempts to close the connection
|
||||
func (c *SQLConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.db != nil {
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package credsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// CassandraCredentialsProducer implements CredentialsProducer and provides an
|
||||
// interface for cassandra databases to generate user information.
|
||||
type CassandraCredentialsProducer struct{}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
|
||||
username = strings.Replace(username, "-", "_", -1)
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
|
||||
return "", nil
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package credsutil
|
||||
|
||||
import "time"
|
||||
|
||||
// CredentialsProducer can be used as an embeded interface in the Database
|
||||
// definition. It implements the methods for generating user information for a
|
||||
// particular database type and is used in all the builtin database types.
|
||||
type CredentialsProducer interface {
|
||||
GenerateUsername(displayName string) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Time) (string, error)
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package credsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||
type SQLCredentialsProducer struct {
|
||||
DisplayNameLen int
|
||||
UsernameLen int
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
|
||||
displayName = displayName[:scp.DisplayNameLen]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("v-%s-%s", displayName, userUUID)
|
||||
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
|
||||
username = username[:scp.UsernameLen]
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
|
||||
return ttl.Format("2006-01-02 15:04:05-0700"), nil
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package dbutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func QueryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
// Serve is used to start a plugin's RPC server. It takes an interface that must
|
||||
// implement a known plugin interface to vault and an optional api.TLSConfig for
|
||||
// use during the inital unwrap request to vault. The api config is particulary
|
||||
// useful when vault is setup to require client cert checking.
|
||||
func Serve(plugin interface{}, tlsConfig *api.TLSConfig) {
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch p := plugin.(type) {
|
||||
case dbplugin.Database:
|
||||
dbplugin.Serve(p, tlsProvider)
|
||||
default:
|
||||
fmt.Println("Unsupported plugin type")
|
||||
}
|
||||
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -330,6 +331,14 @@ type Core struct {
|
|||
|
||||
// uiEnabled indicates whether Vault Web UI is enabled or not
|
||||
uiEnabled bool
|
||||
|
||||
// pluginDirectory is the location vault will look for plugin binaries
|
||||
pluginDirectory string
|
||||
|
||||
// pluginCatalog is used to manage plugin configurations
|
||||
pluginCatalog *PluginCatalog
|
||||
|
||||
enableMlock bool
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
|
@ -374,6 +383,8 @@ type CoreConfig struct {
|
|||
|
||||
EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"`
|
||||
|
||||
PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"`
|
||||
|
||||
ReloadFuncs *map[string][]ReloadFunc
|
||||
ReloadFuncsLock *sync.RWMutex
|
||||
}
|
||||
|
@ -430,6 +441,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
clusterName: conf.ClusterName,
|
||||
clusterListenerShutdownCh: make(chan struct{}),
|
||||
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
||||
enableMlock: !conf.DisableMlock,
|
||||
}
|
||||
|
||||
// Wrap the physical backend in a cache layer if enabled and not already wrapped
|
||||
|
@ -453,8 +465,15 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Construct a new AES-GCM barrier
|
||||
var err error
|
||||
if conf.PluginDirectory != "" {
|
||||
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("core setup failed, could not verify plugin directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Construct a new AES-GCM barrier
|
||||
c.barrier, err = NewAESGCMBarrier(c.physical)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier setup failed: %v", err)
|
||||
|
@ -1280,6 +1299,10 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
if err := c.setupAuditedHeadersConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.setupPluginCatalog(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ha != nil {
|
||||
if err := c.startClusterListener(); err != nil {
|
||||
return err
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -87,3 +90,49 @@ func (d dynamicSystemView) ReplicationState() consts.ReplicationState {
|
|||
d.core.clusterParamsLock.RUnlock()
|
||||
return state
|
||||
}
|
||||
|
||||
// ResponseWrapData wraps the given data in a cubbyhole and returns the
|
||||
// token used to unwrap.
|
||||
func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
req := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "sys/wrapping/wrap",
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: ttl,
|
||||
},
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if jwt {
|
||||
resp.WrapInfo.Format = "jwt"
|
||||
}
|
||||
|
||||
_, err := d.core.wrapInCubbyhole(req, resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp.WrapInfo, nil
|
||||
}
|
||||
|
||||
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
|
||||
// returns a PluginRunner or an error if no plugin was found.
|
||||
func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
|
||||
r, err := d.core.pluginCatalog.Get(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("no plugin found with name: %s", name)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// MlockEnabled returns the configuration setting for enabling mlock on plugins.
|
||||
func (d dynamicSystemView) MlockEnabled() bool {
|
||||
return d.core.enableMlock
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -62,6 +63,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
|
|||
"replication/reindex",
|
||||
"rotate",
|
||||
"config/auditing/*",
|
||||
"plugins/catalog/*",
|
||||
"revoke-prefix/*",
|
||||
"leases/revoke-prefix/*",
|
||||
"leases/revoke-force/*",
|
||||
|
@ -747,6 +749,48 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
|
|||
HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]),
|
||||
},
|
||||
&framework.Path{
|
||||
Pattern: "plugins/catalog/$",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.handlePluginCatalogList,
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
|
||||
},
|
||||
&framework.Path{
|
||||
Pattern: "plugins/catalog/(?P<name>.+)",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "The name of the plugin",
|
||||
},
|
||||
"sha_256": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The SHA256 sum of the executable used in the
|
||||
command field. This should be HEX encoded.`,
|
||||
},
|
||||
"command": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The command used to start the plugin. The
|
||||
executable defined in this command must exist in vault's
|
||||
plugin directory.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.handlePluginCatalogUpdate,
|
||||
logical.DeleteOperation: b.handlePluginCatalogDelete,
|
||||
logical.ReadOperation: b.handlePluginCatalogRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -788,6 +832,77 @@ func (b *SystemBackend) invalidate(key string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginCatalogList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
plugins, err := b.Core.pluginCatalog.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(plugins), nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
pluginName := d.Get("name").(string)
|
||||
if pluginName == "" {
|
||||
return logical.ErrorResponse("missing plugin name"), nil
|
||||
}
|
||||
|
||||
sha256 := d.Get("sha_256").(string)
|
||||
if sha256 == "" {
|
||||
return logical.ErrorResponse("missing SHA-256 value"), nil
|
||||
}
|
||||
|
||||
command := d.Get("command").(string)
|
||||
if command == "" {
|
||||
return logical.ErrorResponse("missing command value"), nil
|
||||
}
|
||||
|
||||
sha256Bytes, err := hex.DecodeString(sha256)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err
|
||||
}
|
||||
|
||||
err = b.Core.pluginCatalog.Set(pluginName, command, sha256Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
pluginName := d.Get("name").(string)
|
||||
if pluginName == "" {
|
||||
return logical.ErrorResponse("missing plugin name"), nil
|
||||
}
|
||||
plugin, err := b.Core.pluginCatalog.Get(pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if plugin == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"plugin": plugin,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
pluginName := d.Get("name").(string)
|
||||
if pluginName == "" {
|
||||
return logical.ErrorResponse("missing plugin name"), nil
|
||||
}
|
||||
err := b.Core.pluginCatalog.Delete(pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleAuditedHeaderUpdate creates or overwrites a header entry
|
||||
func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
header := d.Get("header").(string)
|
||||
|
@ -2094,7 +2209,7 @@ func (b *SystemBackend) handleWrappingRewrap(
|
|||
Data: map[string]interface{}{
|
||||
"response": response,
|
||||
},
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: time.Duration(creationTTL),
|
||||
},
|
||||
}, nil
|
||||
|
@ -2553,7 +2668,23 @@ This path responds to the following HTTP methods.
|
|||
"Lists the headers configured to be audited.",
|
||||
`Returns a list of headers that have been configured to be audited.`,
|
||||
},
|
||||
"plugins/catalog": {
|
||||
`Configures the plugins known to vault`,
|
||||
`
|
||||
This path responds to the following HTTP methods.
|
||||
LIST /
|
||||
Returns a list of names of configured plugins.
|
||||
|
||||
GET /<name>
|
||||
Retrieve the metadata for the named plugin.
|
||||
|
||||
PUT /<name>
|
||||
Add or update plugin.
|
||||
|
||||
DELETE /<name>
|
||||
Delete the plugin with the given name.
|
||||
`,
|
||||
},
|
||||
"leases": {
|
||||
`View or list lease metadata.`,
|
||||
`
|
||||
|
|
|
@ -2,6 +2,11 @@ package vault
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -9,6 +14,8 @@ import (
|
|||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -25,6 +32,7 @@ func TestSystemBackend_RootPaths(t *testing.T) {
|
|||
"replication/reindex",
|
||||
"rotate",
|
||||
"config/auditing/*",
|
||||
"plugins/catalog/*",
|
||||
"revoke-prefix/*",
|
||||
"leases/revoke-prefix/*",
|
||||
"leases/revoke-force/*",
|
||||
|
@ -1543,3 +1551,92 @@ func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) {
|
|||
}
|
||||
return c, b, root
|
||||
}
|
||||
|
||||
func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) {
|
||||
c, b, _ := testCoreSystemBackend(t)
|
||||
// Bootstrap the pluginCatalog
|
||||
sym, err := filepath.EvalSymlinks(os.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
c.pluginCatalog.directory = sym
|
||||
|
||||
req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/")
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Data["keys"].([]string)) != len(builtinplugins.Keys()) {
|
||||
t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys()))
|
||||
}
|
||||
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/mysql-database-plugin")
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
expectedBuiltin := &pluginutil.PluginRunner{
|
||||
Name: "mysql-database-plugin",
|
||||
Builtin: true,
|
||||
}
|
||||
expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin")
|
||||
|
||||
p := resp.Data["plugin"].(*pluginutil.PluginRunner)
|
||||
if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) {
|
||||
t.Fatal("expected BuiltinFactory did not match actual")
|
||||
}
|
||||
|
||||
expectedBuiltin.BuiltinFactory = nil
|
||||
p.BuiltinFactory = nil
|
||||
if !reflect.DeepEqual(p, expectedBuiltin) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin)
|
||||
}
|
||||
|
||||
// Set a plugin
|
||||
file, err := ioutil.TempFile(os.TempDir(), "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
command := fmt.Sprintf("%s --test", filepath.Base(file.Name()))
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/test-plugin")
|
||||
req.Data["sha_256"] = hex.EncodeToString([]byte{'1'})
|
||||
req.Data["command"] = command
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin")
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
expected := &pluginutil.PluginRunner{
|
||||
Name: "test-plugin",
|
||||
Command: filepath.Join(sym, filepath.Base(file.Name())),
|
||||
Args: []string{"--test"},
|
||||
Sha256: []byte{'1'},
|
||||
Builtin: false,
|
||||
}
|
||||
if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expected) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expected)
|
||||
}
|
||||
|
||||
// Delete plugin
|
||||
req = logical.TestRequest(t, logical.DeleteOperation, "plugins/catalog/test-plugin")
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin")
|
||||
resp, err = b.HandleRequest(req)
|
||||
if resp != nil || err != nil {
|
||||
t.Fatalf("expected nil response, plugin not deleted correctly got resp: %v, err: %v", resp, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
pluginCatalogPath = "core/plugin-catalog/"
|
||||
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
|
||||
)
|
||||
|
||||
// PluginCatalog keeps a record of plugins known to vault. External plugins need
|
||||
// to be registered to the catalog before they can be used in backends. Builtin
|
||||
// plugins are automatically detected and included in the catalog.
|
||||
type PluginCatalog struct {
|
||||
catalogView *BarrierView
|
||||
directory string
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *Core) setupPluginCatalog() error {
|
||||
c.pluginCatalog = &PluginCatalog{
|
||||
catalogView: NewBarrierView(c.barrier, pluginCatalogPath),
|
||||
directory: c.pluginDirectory,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a plugin with the specified name from the catalog. It first
|
||||
// looks for external plugins with this name and then looks for builtin plugins.
|
||||
// It returns a PluginRunner or an error if no plugin was found.
|
||||
func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
// If the directory isn't set only look for builtin plugins.
|
||||
if c.directory != "" {
|
||||
// Look for external plugins in the barrier
|
||||
out, err := c.catalogView.Get(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err)
|
||||
}
|
||||
if out != nil {
|
||||
entry := new(pluginutil.PluginRunner)
|
||||
if err := jsonutil.DecodeJSON(out.Value, entry); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode plugin entry: %v", err)
|
||||
}
|
||||
|
||||
// prepend the plugin directory to the command
|
||||
entry.Command = filepath.Join(c.directory, entry.Command)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
}
|
||||
// Look for builtin plugins
|
||||
if factory, ok := builtinplugins.Get(name); ok {
|
||||
return &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Builtin: true,
|
||||
BuiltinFactory: factory,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Set registers a new external plugin with the catalog, or updates an existing
|
||||
// external plugin. It takes the name, command and SHA256 of the plugin.
|
||||
func (c *PluginCatalog) Set(name, command string, sha256 []byte) error {
|
||||
if c.directory == "" {
|
||||
return ErrDirectoryNotConfigured
|
||||
}
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
parts := strings.Split(command, " ")
|
||||
|
||||
// Best effort check to make sure the command isn't breaking out of the
|
||||
// configured plugin directory.
|
||||
commandFull := filepath.Join(c.directory, parts[0])
|
||||
sym, err := filepath.EvalSymlinks(commandFull)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while validating the command path: %v", err)
|
||||
}
|
||||
symAbs, err := filepath.Abs(filepath.Dir(sym))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while validating the command path: %v", err)
|
||||
}
|
||||
|
||||
if symAbs != c.directory {
|
||||
return errors.New("can not execute files outside of configured plugin directory")
|
||||
}
|
||||
|
||||
entry := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Command: parts[0],
|
||||
Args: parts[1:],
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
}
|
||||
|
||||
buf, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode plugin entry: %v", err)
|
||||
}
|
||||
|
||||
logicalEntry := logical.StorageEntry{
|
||||
Key: name,
|
||||
Value: buf,
|
||||
}
|
||||
if err := c.catalogView.Put(&logicalEntry); err != nil {
|
||||
return fmt.Errorf("failed to persist plugin entry: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete is used to remove an external plugin from the catalog. Builtin plugins
|
||||
// can not be deleted.
|
||||
func (c *PluginCatalog) Delete(name string) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
return c.catalogView.Delete(name)
|
||||
}
|
||||
|
||||
// List returns a list of all the known plugin names. If an external and builtin
|
||||
// plugin share the same name, only one instance of the name will be returned.
|
||||
func (c *PluginCatalog) List() ([]string, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
// Collect keys for external plugins in the barrier.
|
||||
keys, err := logical.CollectKeys(c.catalogView)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the keys for builtin plugins
|
||||
builtinKeys := builtinplugins.Keys()
|
||||
|
||||
// Use a map to unique the two lists
|
||||
mapKeys := make(map[string]bool)
|
||||
|
||||
for _, plugin := range keys {
|
||||
mapKeys[plugin] = true
|
||||
}
|
||||
|
||||
for _, plugin := range builtinKeys {
|
||||
mapKeys[plugin] = true
|
||||
}
|
||||
|
||||
retList := make([]string, len(mapKeys))
|
||||
i := 0
|
||||
for k := range mapKeys {
|
||||
retList[i] = k
|
||||
i++
|
||||
}
|
||||
// sort for consistent ordering of builtin pluings
|
||||
sort.Strings(retList)
|
||||
|
||||
return retList, nil
|
||||
}
|
|
@ -0,0 +1,176 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
func TestPluginCatalog_CRUD(t *testing.T) {
|
||||
core, _, _ := TestCoreUnsealed(t)
|
||||
|
||||
sym, err := filepath.EvalSymlinks(os.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
core.pluginCatalog.directory = sym
|
||||
|
||||
// Get builtin plugin
|
||||
p, err := core.pluginCatalog.Get("mysql-database-plugin")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
expectedBuiltin := &pluginutil.PluginRunner{
|
||||
Name: "mysql-database-plugin",
|
||||
Builtin: true,
|
||||
}
|
||||
expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin")
|
||||
|
||||
if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) {
|
||||
t.Fatal("expected BuiltinFactory did not match actual")
|
||||
}
|
||||
expectedBuiltin.BuiltinFactory = nil
|
||||
p.BuiltinFactory = nil
|
||||
if !reflect.DeepEqual(p, expectedBuiltin) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin)
|
||||
}
|
||||
|
||||
// Set a plugin, test overwriting a builtin plugin
|
||||
file, err := ioutil.TempFile(os.TempDir(), "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
command := fmt.Sprintf("%s --test", filepath.Base(file.Name()))
|
||||
err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get the plugin
|
||||
p, err = core.pluginCatalog.Get("mysql-database-plugin")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
expected := &pluginutil.PluginRunner{
|
||||
Name: "mysql-database-plugin",
|
||||
Command: filepath.Join(sym, filepath.Base(file.Name())),
|
||||
Args: []string{"--test"},
|
||||
Sha256: []byte{'1'},
|
||||
Builtin: false,
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(p, expected) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expected)
|
||||
}
|
||||
|
||||
// Delete the plugin
|
||||
err = core.pluginCatalog.Delete("mysql-database-plugin")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
// Get builtin plugin
|
||||
p, err = core.pluginCatalog.Get("mysql-database-plugin")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
expectedBuiltin = &pluginutil.PluginRunner{
|
||||
Name: "mysql-database-plugin",
|
||||
Builtin: true,
|
||||
}
|
||||
expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin")
|
||||
|
||||
if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) {
|
||||
t.Fatal("expected BuiltinFactory did not match actual")
|
||||
}
|
||||
expectedBuiltin.BuiltinFactory = nil
|
||||
p.BuiltinFactory = nil
|
||||
if !reflect.DeepEqual(p, expectedBuiltin) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPluginCatalog_List(t *testing.T) {
|
||||
core, _, _ := TestCoreUnsealed(t)
|
||||
|
||||
sym, err := filepath.EvalSymlinks(os.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
core.pluginCatalog.directory = sym
|
||||
|
||||
// Get builtin plugins and sort them
|
||||
builtinKeys := builtinplugins.Keys()
|
||||
sort.Strings(builtinKeys)
|
||||
|
||||
// List only builtin plugins
|
||||
plugins, err := core.pluginCatalog.List()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
if len(plugins) != len(builtinKeys) {
|
||||
t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys), len(plugins))
|
||||
}
|
||||
|
||||
for i, p := range builtinKeys {
|
||||
if !reflect.DeepEqual(plugins[i], p) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i], p)
|
||||
}
|
||||
}
|
||||
|
||||
// Set a plugin, test overwriting a builtin plugin
|
||||
file, err := ioutil.TempFile(os.TempDir(), "temp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
command := fmt.Sprintf("%s --test", filepath.Base(file.Name()))
|
||||
err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set another plugin
|
||||
err = core.pluginCatalog.Set("aaaaaaa", command, []byte{'1'})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// List the plugins
|
||||
plugins, err = core.pluginCatalog.List()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
if len(plugins) != len(builtinKeys)+1 {
|
||||
t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys)+1, len(plugins))
|
||||
}
|
||||
|
||||
// verify the first plugin is the one we just created.
|
||||
if !reflect.DeepEqual(plugins[0], "aaaaaaa") {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[0], "aaaaaaa")
|
||||
}
|
||||
|
||||
// verify the builtin pluings are correct
|
||||
for i, p := range builtinKeys {
|
||||
if !reflect.DeepEqual(plugins[i+1], p) {
|
||||
t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i+1], p)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -216,7 +217,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
}
|
||||
|
||||
if wrapTTL > 0 {
|
||||
resp.WrapInfo = &logical.ResponseWrapInfo{
|
||||
resp.WrapInfo = &wrapping.ResponseWrapInfo{
|
||||
TTL: wrapTTL,
|
||||
Format: wrapFormat,
|
||||
}
|
||||
|
@ -362,7 +363,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
}
|
||||
|
||||
if wrapTTL > 0 {
|
||||
resp.WrapInfo = &logical.ResponseWrapInfo{
|
||||
resp.WrapInfo = &wrapping.ResponseWrapInfo{
|
||||
TTL: wrapTTL,
|
||||
Format: wrapFormat,
|
||||
}
|
||||
|
|
|
@ -8,9 +8,12 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -293,6 +296,45 @@ func TestKeyCopy(key []byte) []byte {
|
|||
return result
|
||||
}
|
||||
|
||||
func TestDynamicSystemView(c *Core) *dynamicSystemView {
|
||||
me := &MountEntry{
|
||||
Config: MountConfig{
|
||||
DefaultLeaseTTL: 24 * time.Hour,
|
||||
MaxLeaseTTL: 2 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
return &dynamicSystemView{c, me}
|
||||
}
|
||||
|
||||
func TestAddTestPlugin(t testing.TB, c *Core, name, testFunc string) {
|
||||
file, err := os.Open(os.Args[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
_, err = io.Copy(hash, file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sum := hash.Sum(nil)
|
||||
c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory)
|
||||
|
||||
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
|
||||
err = c.pluginCatalog.Set(name, command, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var testLogicalBackends = map[string]logical.Factory{}
|
||||
|
||||
// Starts the test server which responds to SSH authentication.
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Florian Sundermann
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,18 @@
|
|||
##Introduction##
|
||||
This is a package for GO which can be used to create different types of barcodes.
|
||||
|
||||
##Supported Barcode Types##
|
||||
* Aztec Code
|
||||
* Codabar
|
||||
* Code 128
|
||||
* Code 39
|
||||
* EAN 8
|
||||
* EAN 13
|
||||
* Datamatrix
|
||||
* QR Codes
|
||||
* 2 of 5
|
||||
|
||||
##Documentation##
|
||||
See [GoDoc](https://godoc.org/github.com/boombuler/barcode)
|
||||
|
||||
To create a barcode use the Encode function from one of the subpackages.
|
|
@ -0,0 +1,27 @@
|
|||
package barcode
|
||||
|
||||
import "image"
|
||||
|
||||
// Contains some meta information about a barcode
|
||||
type Metadata struct {
|
||||
// the name of the barcode kind
|
||||
CodeKind string
|
||||
// contains 1 for 1D barcodes or 2 for 2D barcodes
|
||||
Dimensions byte
|
||||
}
|
||||
|
||||
// a rendered and encoded barcode
|
||||
type Barcode interface {
|
||||
image.Image
|
||||
// returns some meta information about the barcode
|
||||
Metadata() Metadata
|
||||
// the data that was encoded in this barcode
|
||||
Content() string
|
||||
}
|
||||
|
||||
// Additional interface that some barcodes might implement to provide
|
||||
// the value of its checksum.
|
||||
type BarcodeIntCS interface {
|
||||
Barcode
|
||||
CheckSum() int
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
const charSet string = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ $%*+-./:"
|
||||
|
||||
func stringToAlphaIdx(content string) <-chan int {
|
||||
result := make(chan int)
|
||||
go func() {
|
||||
for _, r := range content {
|
||||
idx := strings.IndexRune(charSet, r)
|
||||
result <- idx
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(result)
|
||||
}()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func encodeAlphaNumeric(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) {
|
||||
|
||||
contentLenIsOdd := len(content)%2 == 1
|
||||
contentBitCount := (len(content) / 2) * 11
|
||||
if contentLenIsOdd {
|
||||
contentBitCount += 6
|
||||
}
|
||||
vi := findSmallestVersionInfo(ecl, alphaNumericMode, contentBitCount)
|
||||
if vi == nil {
|
||||
return nil, nil, errors.New("To much data to encode")
|
||||
}
|
||||
|
||||
res := new(utils.BitList)
|
||||
res.AddBits(int(alphaNumericMode), 4)
|
||||
res.AddBits(len(content), vi.charCountBits(alphaNumericMode))
|
||||
|
||||
encoder := stringToAlphaIdx(content)
|
||||
|
||||
for idx := 0; idx < len(content)/2; idx++ {
|
||||
c1 := <-encoder
|
||||
c2 := <-encoder
|
||||
if c1 < 0 || c2 < 0 {
|
||||
return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, AlphaNumeric)
|
||||
}
|
||||
res.AddBits(c1*45+c2, 11)
|
||||
}
|
||||
if contentLenIsOdd {
|
||||
c := <-encoder
|
||||
if c < 0 {
|
||||
return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, AlphaNumeric)
|
||||
}
|
||||
res.AddBits(c, 6)
|
||||
}
|
||||
|
||||
addPaddingAndTerminator(res, vi)
|
||||
|
||||
return res, vi, nil
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
func encodeAuto(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) {
|
||||
bits, vi, _ := Numeric.getEncoder()(content, ecl)
|
||||
if bits != nil && vi != nil {
|
||||
return bits, vi, nil
|
||||
}
|
||||
bits, vi, _ = AlphaNumeric.getEncoder()(content, ecl)
|
||||
if bits != nil && vi != nil {
|
||||
return bits, vi, nil
|
||||
}
|
||||
bits, vi, _ = Unicode.getEncoder()(content, ecl)
|
||||
if bits != nil && vi != nil {
|
||||
return bits, vi, nil
|
||||
}
|
||||
return nil, nil, fmt.Errorf("No encoding found to encode \"%s\"", content)
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package qr
|
||||
|
||||
type block struct {
|
||||
data []byte
|
||||
ecc []byte
|
||||
}
|
||||
type blockList []*block
|
||||
|
||||
func splitToBlocks(data <-chan byte, vi *versionInfo) blockList {
|
||||
result := make(blockList, vi.NumberOfBlocksInGroup1+vi.NumberOfBlocksInGroup2)
|
||||
|
||||
for b := 0; b < int(vi.NumberOfBlocksInGroup1); b++ {
|
||||
blk := new(block)
|
||||
blk.data = make([]byte, vi.DataCodeWordsPerBlockInGroup1)
|
||||
for cw := 0; cw < int(vi.DataCodeWordsPerBlockInGroup1); cw++ {
|
||||
blk.data[cw] = <-data
|
||||
}
|
||||
blk.ecc = ec.calcECC(blk.data, vi.ErrorCorrectionCodewordsPerBlock)
|
||||
result[b] = blk
|
||||
}
|
||||
|
||||
for b := 0; b < int(vi.NumberOfBlocksInGroup2); b++ {
|
||||
blk := new(block)
|
||||
blk.data = make([]byte, vi.DataCodeWordsPerBlockInGroup2)
|
||||
for cw := 0; cw < int(vi.DataCodeWordsPerBlockInGroup2); cw++ {
|
||||
blk.data[cw] = <-data
|
||||
}
|
||||
blk.ecc = ec.calcECC(blk.data, vi.ErrorCorrectionCodewordsPerBlock)
|
||||
result[int(vi.NumberOfBlocksInGroup1)+b] = blk
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (bl blockList) interleave(vi *versionInfo) []byte {
|
||||
var maxCodewordCount int
|
||||
if vi.DataCodeWordsPerBlockInGroup1 > vi.DataCodeWordsPerBlockInGroup2 {
|
||||
maxCodewordCount = int(vi.DataCodeWordsPerBlockInGroup1)
|
||||
} else {
|
||||
maxCodewordCount = int(vi.DataCodeWordsPerBlockInGroup2)
|
||||
}
|
||||
resultLen := (vi.DataCodeWordsPerBlockInGroup1+vi.ErrorCorrectionCodewordsPerBlock)*vi.NumberOfBlocksInGroup1 +
|
||||
(vi.DataCodeWordsPerBlockInGroup2+vi.ErrorCorrectionCodewordsPerBlock)*vi.NumberOfBlocksInGroup2
|
||||
|
||||
result := make([]byte, 0, resultLen)
|
||||
for i := 0; i < maxCodewordCount; i++ {
|
||||
for b := 0; b < len(bl); b++ {
|
||||
if len(bl[b].data) > i {
|
||||
result = append(result, bl[b].data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < int(vi.ErrorCorrectionCodewordsPerBlock); i++ {
|
||||
for b := 0; b < len(bl); b++ {
|
||||
result = append(result, bl[b].ecc[i])
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,416 @@
|
|||
// Package qr can be used to create QR barcodes.
|
||||
package qr
|
||||
|
||||
import (
|
||||
"image"
|
||||
|
||||
"github.com/boombuler/barcode"
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
type encodeFn func(content string, eccLevel ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error)
|
||||
|
||||
// Encoding mode for QR Codes.
|
||||
type Encoding byte
|
||||
|
||||
const (
|
||||
// Auto will choose ths best matching encoding
|
||||
Auto Encoding = iota
|
||||
// Numeric encoding only encodes numbers [0-9]
|
||||
Numeric
|
||||
// AlphaNumeric encoding only encodes uppercase letters, numbers and [Space], $, %, *, +, -, ., /, :
|
||||
AlphaNumeric
|
||||
// Unicode encoding encodes the string as utf-8
|
||||
Unicode
|
||||
// only for testing purpose
|
||||
unknownEncoding
|
||||
)
|
||||
|
||||
func (e Encoding) getEncoder() encodeFn {
|
||||
switch e {
|
||||
case Auto:
|
||||
return encodeAuto
|
||||
case Numeric:
|
||||
return encodeNumeric
|
||||
case AlphaNumeric:
|
||||
return encodeAlphaNumeric
|
||||
case Unicode:
|
||||
return encodeUnicode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Encoding) String() string {
|
||||
switch e {
|
||||
case Auto:
|
||||
return "Auto"
|
||||
case Numeric:
|
||||
return "Numeric"
|
||||
case AlphaNumeric:
|
||||
return "AlphaNumeric"
|
||||
case Unicode:
|
||||
return "Unicode"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Encode returns a QR barcode with the given content, error correction level and uses the given encoding
|
||||
func Encode(content string, level ErrorCorrectionLevel, mode Encoding) (barcode.Barcode, error) {
|
||||
bits, vi, err := mode.getEncoder()(content, level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blocks := splitToBlocks(bits.IterateBytes(), vi)
|
||||
data := blocks.interleave(vi)
|
||||
result := render(data, vi)
|
||||
result.content = content
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func render(data []byte, vi *versionInfo) *qrcode {
|
||||
dim := vi.modulWidth()
|
||||
results := make([]*qrcode, 8)
|
||||
for i := 0; i < 8; i++ {
|
||||
results[i] = newBarcode(dim)
|
||||
}
|
||||
|
||||
occupied := newBarcode(dim)
|
||||
|
||||
setAll := func(x int, y int, val bool) {
|
||||
occupied.Set(x, y, true)
|
||||
for i := 0; i < 8; i++ {
|
||||
results[i].Set(x, y, val)
|
||||
}
|
||||
}
|
||||
|
||||
drawFinderPatterns(vi, setAll)
|
||||
drawAlignmentPatterns(occupied, vi, setAll)
|
||||
|
||||
//Timing Pattern:
|
||||
var i int
|
||||
for i = 0; i < dim; i++ {
|
||||
if !occupied.Get(i, 6) {
|
||||
setAll(i, 6, i%2 == 0)
|
||||
}
|
||||
if !occupied.Get(6, i) {
|
||||
setAll(6, i, i%2 == 0)
|
||||
}
|
||||
}
|
||||
// Dark Module
|
||||
setAll(8, dim-8, true)
|
||||
|
||||
drawVersionInfo(vi, setAll)
|
||||
drawFormatInfo(vi, -1, occupied.Set)
|
||||
for i := 0; i < 8; i++ {
|
||||
drawFormatInfo(vi, i, results[i].Set)
|
||||
}
|
||||
|
||||
// Write the data
|
||||
var curBitNo int
|
||||
|
||||
for pos := range iterateModules(occupied) {
|
||||
var curBit bool
|
||||
if curBitNo < len(data)*8 {
|
||||
curBit = ((data[curBitNo/8] >> uint(7-(curBitNo%8))) & 1) == 1
|
||||
} else {
|
||||
curBit = false
|
||||
}
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
setMasked(pos.X, pos.Y, curBit, i, results[i].Set)
|
||||
}
|
||||
curBitNo++
|
||||
}
|
||||
|
||||
lowestPenalty := ^uint(0)
|
||||
lowestPenaltyIdx := -1
|
||||
for i := 0; i < 8; i++ {
|
||||
p := results[i].calcPenalty()
|
||||
if p < lowestPenalty {
|
||||
lowestPenalty = p
|
||||
lowestPenaltyIdx = i
|
||||
}
|
||||
}
|
||||
return results[lowestPenaltyIdx]
|
||||
}
|
||||
|
||||
func setMasked(x, y int, val bool, mask int, set func(int, int, bool)) {
|
||||
switch mask {
|
||||
case 0:
|
||||
val = val != (((y + x) % 2) == 0)
|
||||
break
|
||||
case 1:
|
||||
val = val != ((y % 2) == 0)
|
||||
break
|
||||
case 2:
|
||||
val = val != ((x % 3) == 0)
|
||||
break
|
||||
case 3:
|
||||
val = val != (((y + x) % 3) == 0)
|
||||
break
|
||||
case 4:
|
||||
val = val != (((y/2 + x/3) % 2) == 0)
|
||||
break
|
||||
case 5:
|
||||
val = val != (((y*x)%2)+((y*x)%3) == 0)
|
||||
break
|
||||
case 6:
|
||||
val = val != ((((y*x)%2)+((y*x)%3))%2 == 0)
|
||||
break
|
||||
case 7:
|
||||
val = val != ((((y+x)%2)+((y*x)%3))%2 == 0)
|
||||
}
|
||||
set(x, y, val)
|
||||
}
|
||||
|
||||
func iterateModules(occupied *qrcode) <-chan image.Point {
|
||||
result := make(chan image.Point)
|
||||
allPoints := make(chan image.Point)
|
||||
go func() {
|
||||
curX := occupied.dimension - 1
|
||||
curY := occupied.dimension - 1
|
||||
isUpward := true
|
||||
|
||||
for true {
|
||||
if isUpward {
|
||||
allPoints <- image.Pt(curX, curY)
|
||||
allPoints <- image.Pt(curX-1, curY)
|
||||
curY--
|
||||
if curY < 0 {
|
||||
curY = 0
|
||||
curX -= 2
|
||||
if curX == 6 {
|
||||
curX--
|
||||
}
|
||||
if curX < 0 {
|
||||
break
|
||||
}
|
||||
isUpward = false
|
||||
}
|
||||
} else {
|
||||
allPoints <- image.Pt(curX, curY)
|
||||
allPoints <- image.Pt(curX-1, curY)
|
||||
curY++
|
||||
if curY >= occupied.dimension {
|
||||
curY = occupied.dimension - 1
|
||||
curX -= 2
|
||||
if curX == 6 {
|
||||
curX--
|
||||
}
|
||||
isUpward = true
|
||||
if curX < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(allPoints)
|
||||
}()
|
||||
go func() {
|
||||
for pt := range allPoints {
|
||||
if !occupied.Get(pt.X, pt.Y) {
|
||||
result <- pt
|
||||
}
|
||||
}
|
||||
close(result)
|
||||
}()
|
||||
return result
|
||||
}
|
||||
|
||||
func drawFinderPatterns(vi *versionInfo, set func(int, int, bool)) {
|
||||
dim := vi.modulWidth()
|
||||
drawPattern := func(xoff int, yoff int) {
|
||||
for x := -1; x < 8; x++ {
|
||||
for y := -1; y < 8; y++ {
|
||||
val := (x == 0 || x == 6 || y == 0 || y == 6 || (x > 1 && x < 5 && y > 1 && y < 5)) && (x <= 6 && y <= 6 && x >= 0 && y >= 0)
|
||||
|
||||
if x+xoff >= 0 && x+xoff < dim && y+yoff >= 0 && y+yoff < dim {
|
||||
set(x+xoff, y+yoff, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
drawPattern(0, 0)
|
||||
drawPattern(0, dim-7)
|
||||
drawPattern(dim-7, 0)
|
||||
}
|
||||
|
||||
func drawAlignmentPatterns(occupied *qrcode, vi *versionInfo, set func(int, int, bool)) {
|
||||
drawPattern := func(xoff int, yoff int) {
|
||||
for x := -2; x <= 2; x++ {
|
||||
for y := -2; y <= 2; y++ {
|
||||
val := x == -2 || x == 2 || y == -2 || y == 2 || (x == 0 && y == 0)
|
||||
set(x+xoff, y+yoff, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
positions := vi.alignmentPatternPlacements()
|
||||
|
||||
for _, x := range positions {
|
||||
for _, y := range positions {
|
||||
if occupied.Get(x, y) {
|
||||
continue
|
||||
}
|
||||
drawPattern(x, y)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatInfos = map[ErrorCorrectionLevel]map[int][]bool{
|
||||
L: {
|
||||
0: []bool{true, true, true, false, true, true, true, true, true, false, false, false, true, false, false},
|
||||
1: []bool{true, true, true, false, false, true, false, true, true, true, true, false, false, true, true},
|
||||
2: []bool{true, true, true, true, true, false, true, true, false, true, false, true, false, true, false},
|
||||
3: []bool{true, true, true, true, false, false, false, true, false, false, true, true, true, false, true},
|
||||
4: []bool{true, true, false, false, true, true, false, false, false, true, false, true, true, true, true},
|
||||
5: []bool{true, true, false, false, false, true, true, false, false, false, true, true, false, false, false},
|
||||
6: []bool{true, true, false, true, true, false, false, false, true, false, false, false, false, false, true},
|
||||
7: []bool{true, true, false, true, false, false, true, false, true, true, true, false, true, true, false},
|
||||
},
|
||||
M: {
|
||||
0: []bool{true, false, true, false, true, false, false, false, false, false, true, false, false, true, false},
|
||||
1: []bool{true, false, true, false, false, false, true, false, false, true, false, false, true, false, true},
|
||||
2: []bool{true, false, true, true, true, true, false, false, true, true, true, true, true, false, false},
|
||||
3: []bool{true, false, true, true, false, true, true, false, true, false, false, true, false, true, true},
|
||||
4: []bool{true, false, false, false, true, false, true, true, true, true, true, true, false, false, true},
|
||||
5: []bool{true, false, false, false, false, false, false, true, true, false, false, true, true, true, false},
|
||||
6: []bool{true, false, false, true, true, true, true, true, false, false, true, false, true, true, true},
|
||||
7: []bool{true, false, false, true, false, true, false, true, false, true, false, false, false, false, false},
|
||||
},
|
||||
Q: {
|
||||
0: []bool{false, true, true, false, true, false, true, false, true, false, true, true, true, true, true},
|
||||
1: []bool{false, true, true, false, false, false, false, false, true, true, false, true, false, false, false},
|
||||
2: []bool{false, true, true, true, true, true, true, false, false, true, true, false, false, false, true},
|
||||
3: []bool{false, true, true, true, false, true, false, false, false, false, false, false, true, true, false},
|
||||
4: []bool{false, true, false, false, true, false, false, true, false, true, true, false, true, false, false},
|
||||
5: []bool{false, true, false, false, false, false, true, true, false, false, false, false, false, true, true},
|
||||
6: []bool{false, true, false, true, true, true, false, true, true, false, true, true, false, true, false},
|
||||
7: []bool{false, true, false, true, false, true, true, true, true, true, false, true, true, false, true},
|
||||
},
|
||||
H: {
|
||||
0: []bool{false, false, true, false, true, true, false, true, false, false, false, true, false, false, true},
|
||||
1: []bool{false, false, true, false, false, true, true, true, false, true, true, true, true, true, false},
|
||||
2: []bool{false, false, true, true, true, false, false, true, true, true, false, false, true, true, true},
|
||||
3: []bool{false, false, true, true, false, false, true, true, true, false, true, false, false, false, false},
|
||||
4: []bool{false, false, false, false, true, true, true, false, true, true, false, false, false, true, false},
|
||||
5: []bool{false, false, false, false, false, true, false, false, true, false, true, false, true, false, true},
|
||||
6: []bool{false, false, false, true, true, false, true, false, false, false, false, true, true, false, false},
|
||||
7: []bool{false, false, false, true, false, false, false, false, false, true, true, true, false, true, true},
|
||||
},
|
||||
}
|
||||
|
||||
func drawFormatInfo(vi *versionInfo, usedMask int, set func(int, int, bool)) {
|
||||
var formatInfo []bool
|
||||
|
||||
if usedMask == -1 {
|
||||
formatInfo = []bool{true, true, true, true, true, true, true, true, true, true, true, true, true, true, true} // Set all to true cause -1 --> occupied mask.
|
||||
} else {
|
||||
formatInfo = formatInfos[vi.Level][usedMask]
|
||||
}
|
||||
|
||||
if len(formatInfo) == 15 {
|
||||
dim := vi.modulWidth()
|
||||
set(0, 8, formatInfo[0])
|
||||
set(1, 8, formatInfo[1])
|
||||
set(2, 8, formatInfo[2])
|
||||
set(3, 8, formatInfo[3])
|
||||
set(4, 8, formatInfo[4])
|
||||
set(5, 8, formatInfo[5])
|
||||
set(7, 8, formatInfo[6])
|
||||
set(8, 8, formatInfo[7])
|
||||
set(8, 7, formatInfo[8])
|
||||
set(8, 5, formatInfo[9])
|
||||
set(8, 4, formatInfo[10])
|
||||
set(8, 3, formatInfo[11])
|
||||
set(8, 2, formatInfo[12])
|
||||
set(8, 1, formatInfo[13])
|
||||
set(8, 0, formatInfo[14])
|
||||
|
||||
set(8, dim-1, formatInfo[0])
|
||||
set(8, dim-2, formatInfo[1])
|
||||
set(8, dim-3, formatInfo[2])
|
||||
set(8, dim-4, formatInfo[3])
|
||||
set(8, dim-5, formatInfo[4])
|
||||
set(8, dim-6, formatInfo[5])
|
||||
set(8, dim-7, formatInfo[6])
|
||||
set(dim-8, 8, formatInfo[7])
|
||||
set(dim-7, 8, formatInfo[8])
|
||||
set(dim-6, 8, formatInfo[9])
|
||||
set(dim-5, 8, formatInfo[10])
|
||||
set(dim-4, 8, formatInfo[11])
|
||||
set(dim-3, 8, formatInfo[12])
|
||||
set(dim-2, 8, formatInfo[13])
|
||||
set(dim-1, 8, formatInfo[14])
|
||||
}
|
||||
}
|
||||
|
||||
var versionInfoBitsByVersion = map[byte][]bool{
|
||||
7: []bool{false, false, false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false},
|
||||
8: []bool{false, false, true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false},
|
||||
9: []bool{false, false, true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true},
|
||||
10: []bool{false, false, true, false, true, false, false, true, false, false, true, true, false, true, false, false, true, true},
|
||||
11: []bool{false, false, true, false, true, true, true, false, true, true, true, true, true, true, false, true, true, false},
|
||||
12: []bool{false, false, true, true, false, false, false, true, true, true, false, true, true, false, false, false, true, false},
|
||||
13: []bool{false, false, true, true, false, true, true, false, false, false, false, true, false, false, false, true, true, true},
|
||||
14: []bool{false, false, true, true, true, false, false, true, true, false, false, false, false, false, true, true, false, true},
|
||||
15: []bool{false, false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false, false},
|
||||
16: []bool{false, true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false, false},
|
||||
17: []bool{false, true, false, false, false, true, false, true, false, false, false, true, false, true, true, true, false, true},
|
||||
18: []bool{false, true, false, false, true, false, true, false, true, false, false, false, false, true, false, true, true, true},
|
||||
19: []bool{false, true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true, false},
|
||||
20: []bool{false, true, false, true, false, false, true, false, false, true, true, false, true, false, false, true, true, false},
|
||||
21: []bool{false, true, false, true, false, true, false, true, true, false, true, false, false, false, false, false, true, true},
|
||||
22: []bool{false, true, false, true, true, false, true, false, false, false, true, true, false, false, true, false, false, true},
|
||||
23: []bool{false, true, false, true, true, true, false, true, true, true, true, true, true, false, true, true, false, false},
|
||||
24: []bool{false, true, true, false, false, false, true, true, true, false, true, true, false, false, false, true, false, false},
|
||||
25: []bool{false, true, true, false, false, true, false, false, false, true, true, true, true, false, false, false, false, true},
|
||||
26: []bool{false, true, true, false, true, false, true, true, true, true, true, false, true, false, true, false, true, true},
|
||||
27: []bool{false, true, true, false, true, true, false, false, false, false, true, false, false, false, true, true, true, false},
|
||||
28: []bool{false, true, true, true, false, false, true, true, false, false, false, false, false, true, true, false, true, false},
|
||||
29: []bool{false, true, true, true, false, true, false, false, true, true, false, false, true, true, true, true, true, true},
|
||||
30: []bool{false, true, true, true, true, false, true, true, false, true, false, true, true, true, false, true, false, true},
|
||||
31: []bool{false, true, true, true, true, true, false, false, true, false, false, true, false, true, false, false, false, false},
|
||||
32: []bool{true, false, false, false, false, false, true, false, false, true, true, true, false, true, false, true, false, true},
|
||||
33: []bool{true, false, false, false, false, true, false, true, true, false, true, true, true, true, false, false, false, false},
|
||||
34: []bool{true, false, false, false, true, false, true, false, false, false, true, false, true, true, true, false, true, false},
|
||||
35: []bool{true, false, false, false, true, true, false, true, true, true, true, false, false, true, true, true, true, true},
|
||||
36: []bool{true, false, false, true, false, false, true, false, true, true, false, false, false, false, true, false, true, true},
|
||||
37: []bool{true, false, false, true, false, true, false, true, false, false, false, false, true, false, true, true, true, false},
|
||||
38: []bool{true, false, false, true, true, false, true, false, true, false, false, true, true, false, false, true, false, false},
|
||||
39: []bool{true, false, false, true, true, true, false, true, false, true, false, true, false, false, false, false, false, true},
|
||||
40: []bool{true, false, true, false, false, false, true, true, false, false, false, true, true, false, true, false, false, true},
|
||||
}
|
||||
|
||||
func drawVersionInfo(vi *versionInfo, set func(int, int, bool)) {
|
||||
versionInfoBits, ok := versionInfoBitsByVersion[vi.Version]
|
||||
|
||||
if ok && len(versionInfoBits) > 0 {
|
||||
for i := 0; i < len(versionInfoBits); i++ {
|
||||
x := (vi.modulWidth() - 11) + i%3
|
||||
y := i / 3
|
||||
set(x, y, versionInfoBits[len(versionInfoBits)-i-1])
|
||||
set(y, x, versionInfoBits[len(versionInfoBits)-i-1])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func addPaddingAndTerminator(bl *utils.BitList, vi *versionInfo) {
|
||||
for i := 0; i < 4 && bl.Len() < vi.totalDataBytes()*8; i++ {
|
||||
bl.AddBit(false)
|
||||
}
|
||||
|
||||
for bl.Len()%8 != 0 {
|
||||
bl.AddBit(false)
|
||||
}
|
||||
|
||||
for i := 0; bl.Len() < vi.totalDataBytes()*8; i++ {
|
||||
if i%2 == 0 {
|
||||
bl.AddByte(236)
|
||||
} else {
|
||||
bl.AddByte(17)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
type errorCorrection struct {
|
||||
rs *utils.ReedSolomonEncoder
|
||||
}
|
||||
|
||||
var ec = newErrorCorrection()
|
||||
|
||||
func newErrorCorrection() *errorCorrection {
|
||||
fld := utils.NewGaloisField(285, 256, 0)
|
||||
return &errorCorrection{utils.NewReedSolomonEncoder(fld)}
|
||||
}
|
||||
|
||||
func (ec *errorCorrection) calcECC(data []byte, eccCount byte) []byte {
|
||||
dataInts := make([]int, len(data))
|
||||
for i := 0; i < len(data); i++ {
|
||||
dataInts[i] = int(data[i])
|
||||
}
|
||||
res := ec.rs.Encode(dataInts, int(eccCount))
|
||||
result := make([]byte, len(res))
|
||||
for i := 0; i < len(res); i++ {
|
||||
result[i] = byte(res[i])
|
||||
}
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
func encodeNumeric(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) {
|
||||
contentBitCount := (len(content) / 3) * 10
|
||||
switch len(content) % 3 {
|
||||
case 1:
|
||||
contentBitCount += 4
|
||||
case 2:
|
||||
contentBitCount += 7
|
||||
}
|
||||
vi := findSmallestVersionInfo(ecl, numericMode, contentBitCount)
|
||||
if vi == nil {
|
||||
return nil, nil, errors.New("To much data to encode")
|
||||
}
|
||||
res := new(utils.BitList)
|
||||
res.AddBits(int(numericMode), 4)
|
||||
res.AddBits(len(content), vi.charCountBits(numericMode))
|
||||
|
||||
for pos := 0; pos < len(content); pos += 3 {
|
||||
var curStr string
|
||||
if pos+3 <= len(content) {
|
||||
curStr = content[pos : pos+3]
|
||||
} else {
|
||||
curStr = content[pos:]
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(curStr)
|
||||
if err != nil || i < 0 {
|
||||
return nil, nil, fmt.Errorf("\"%s\" can not be encoded as %s", content, Numeric)
|
||||
}
|
||||
var bitCnt byte
|
||||
switch len(curStr) % 3 {
|
||||
case 0:
|
||||
bitCnt = 10
|
||||
case 1:
|
||||
bitCnt = 4
|
||||
break
|
||||
case 2:
|
||||
bitCnt = 7
|
||||
break
|
||||
}
|
||||
|
||||
res.AddBits(i, bitCnt)
|
||||
}
|
||||
|
||||
addPaddingAndTerminator(res, vi)
|
||||
return res, vi, nil
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
|
||||
"github.com/boombuler/barcode"
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
type qrcode struct {
|
||||
dimension int
|
||||
data *utils.BitList
|
||||
content string
|
||||
}
|
||||
|
||||
func (qr *qrcode) Content() string {
|
||||
return qr.content
|
||||
}
|
||||
|
||||
func (qr *qrcode) Metadata() barcode.Metadata {
|
||||
return barcode.Metadata{"QR Code", 2}
|
||||
}
|
||||
|
||||
func (qr *qrcode) ColorModel() color.Model {
|
||||
return color.Gray16Model
|
||||
}
|
||||
|
||||
func (qr *qrcode) Bounds() image.Rectangle {
|
||||
return image.Rect(0, 0, qr.dimension, qr.dimension)
|
||||
}
|
||||
|
||||
func (qr *qrcode) At(x, y int) color.Color {
|
||||
if qr.Get(x, y) {
|
||||
return color.Black
|
||||
}
|
||||
return color.White
|
||||
}
|
||||
|
||||
func (qr *qrcode) Get(x, y int) bool {
|
||||
return qr.data.GetBit(x*qr.dimension + y)
|
||||
}
|
||||
|
||||
func (qr *qrcode) Set(x, y int, val bool) {
|
||||
qr.data.SetBit(x*qr.dimension+y, val)
|
||||
}
|
||||
|
||||
func (qr *qrcode) calcPenalty() uint {
|
||||
return qr.calcPenaltyRule1() + qr.calcPenaltyRule2() + qr.calcPenaltyRule3() + qr.calcPenaltyRule4()
|
||||
}
|
||||
|
||||
func (qr *qrcode) calcPenaltyRule1() uint {
|
||||
var result uint
|
||||
for x := 0; x < qr.dimension; x++ {
|
||||
checkForX := false
|
||||
var cntX uint
|
||||
checkForY := false
|
||||
var cntY uint
|
||||
|
||||
for y := 0; y < qr.dimension; y++ {
|
||||
if qr.Get(x, y) == checkForX {
|
||||
cntX++
|
||||
} else {
|
||||
checkForX = !checkForX
|
||||
if cntX >= 5 {
|
||||
result += cntX - 2
|
||||
}
|
||||
cntX = 1
|
||||
}
|
||||
|
||||
if qr.Get(y, x) == checkForY {
|
||||
cntY++
|
||||
} else {
|
||||
checkForY = !checkForY
|
||||
if cntY >= 5 {
|
||||
result += cntY - 2
|
||||
}
|
||||
cntY = 1
|
||||
}
|
||||
}
|
||||
|
||||
if cntX >= 5 {
|
||||
result += cntX - 2
|
||||
}
|
||||
if cntY >= 5 {
|
||||
result += cntY - 2
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (qr *qrcode) calcPenaltyRule2() uint {
|
||||
var result uint
|
||||
for x := 0; x < qr.dimension-1; x++ {
|
||||
for y := 0; y < qr.dimension-1; y++ {
|
||||
check := qr.Get(x, y)
|
||||
if qr.Get(x, y+1) == check && qr.Get(x+1, y) == check && qr.Get(x+1, y+1) == check {
|
||||
result += 3
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (qr *qrcode) calcPenaltyRule3() uint {
|
||||
pattern1 := []bool{true, false, true, true, true, false, true, false, false, false, false}
|
||||
pattern2 := []bool{false, false, false, false, true, false, true, true, true, false, true}
|
||||
|
||||
var result uint
|
||||
for x := 0; x <= qr.dimension-len(pattern1); x++ {
|
||||
for y := 0; y < qr.dimension; y++ {
|
||||
pattern1XFound := true
|
||||
pattern2XFound := true
|
||||
pattern1YFound := true
|
||||
pattern2YFound := true
|
||||
|
||||
for i := 0; i < len(pattern1); i++ {
|
||||
iv := qr.Get(x+i, y)
|
||||
if iv != pattern1[i] {
|
||||
pattern1XFound = false
|
||||
}
|
||||
if iv != pattern2[i] {
|
||||
pattern2XFound = false
|
||||
}
|
||||
iv = qr.Get(y, x+i)
|
||||
if iv != pattern1[i] {
|
||||
pattern1YFound = false
|
||||
}
|
||||
if iv != pattern2[i] {
|
||||
pattern2YFound = false
|
||||
}
|
||||
}
|
||||
if pattern1XFound || pattern2XFound {
|
||||
result += 40
|
||||
}
|
||||
if pattern1YFound || pattern2YFound {
|
||||
result += 40
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (qr *qrcode) calcPenaltyRule4() uint {
|
||||
totalNum := qr.data.Len()
|
||||
trueCnt := 0
|
||||
for i := 0; i < totalNum; i++ {
|
||||
if qr.data.GetBit(i) {
|
||||
trueCnt++
|
||||
}
|
||||
}
|
||||
percDark := float64(trueCnt) * 100 / float64(totalNum)
|
||||
floor := math.Abs(math.Floor(percDark/5) - 10)
|
||||
ceil := math.Abs(math.Ceil(percDark/5) - 10)
|
||||
return uint(math.Min(floor, ceil) * 10)
|
||||
}
|
||||
|
||||
func newBarcode(dim int) *qrcode {
|
||||
res := new(qrcode)
|
||||
res.dimension = dim
|
||||
res.data = utils.NewBitList(dim * dim)
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package qr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/boombuler/barcode/utils"
|
||||
)
|
||||
|
||||
func encodeUnicode(content string, ecl ErrorCorrectionLevel) (*utils.BitList, *versionInfo, error) {
|
||||
data := []byte(content)
|
||||
|
||||
vi := findSmallestVersionInfo(ecl, byteMode, len(data)*8)
|
||||
if vi == nil {
|
||||
return nil, nil, errors.New("To much data to encode")
|
||||
}
|
||||
|
||||
// It's not correct to add the unicode bytes to the result directly but most readers can't handle the
|
||||
// required ECI header...
|
||||
res := new(utils.BitList)
|
||||
res.AddBits(int(byteMode), 4)
|
||||
res.AddBits(len(content), vi.charCountBits(byteMode))
|
||||
for _, b := range data {
|
||||
res.AddByte(b)
|
||||
}
|
||||
addPaddingAndTerminator(res, vi)
|
||||
return res, vi, nil
|
||||
}
|
|
@ -0,0 +1,310 @@
|
|||
package qr
|
||||
|
||||
import "math"
|
||||
|
||||
// ErrorCorrectionLevel indicates the amount of "backup data" stored in the QR code
|
||||
type ErrorCorrectionLevel byte
|
||||
|
||||
const (
|
||||
// L recovers 7% of data
|
||||
L ErrorCorrectionLevel = iota
|
||||
// M recovers 15% of data
|
||||
M
|
||||
// Q recovers 25% of data
|
||||
Q
|
||||
// H recovers 30% of data
|
||||
H
|
||||
)
|
||||
|
||||
func (ecl ErrorCorrectionLevel) String() string {
|
||||
switch ecl {
|
||||
case L:
|
||||
return "L"
|
||||
case M:
|
||||
return "M"
|
||||
case Q:
|
||||
return "Q"
|
||||
case H:
|
||||
return "H"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type encodingMode byte
|
||||
|
||||
const (
|
||||
numericMode encodingMode = 1
|
||||
alphaNumericMode encodingMode = 2
|
||||
byteMode encodingMode = 4
|
||||
kanjiMode encodingMode = 8
|
||||
)
|
||||
|
||||
type versionInfo struct {
|
||||
Version byte
|
||||
Level ErrorCorrectionLevel
|
||||
ErrorCorrectionCodewordsPerBlock byte
|
||||
NumberOfBlocksInGroup1 byte
|
||||
DataCodeWordsPerBlockInGroup1 byte
|
||||
NumberOfBlocksInGroup2 byte
|
||||
DataCodeWordsPerBlockInGroup2 byte
|
||||
}
|
||||
|
||||
var versionInfos = []*versionInfo{
|
||||
&versionInfo{1, L, 7, 1, 19, 0, 0},
|
||||
&versionInfo{1, M, 10, 1, 16, 0, 0},
|
||||
&versionInfo{1, Q, 13, 1, 13, 0, 0},
|
||||
&versionInfo{1, H, 17, 1, 9, 0, 0},
|
||||
&versionInfo{2, L, 10, 1, 34, 0, 0},
|
||||
&versionInfo{2, M, 16, 1, 28, 0, 0},
|
||||
&versionInfo{2, Q, 22, 1, 22, 0, 0},
|
||||
&versionInfo{2, H, 28, 1, 16, 0, 0},
|
||||
&versionInfo{3, L, 15, 1, 55, 0, 0},
|
||||
&versionInfo{3, M, 26, 1, 44, 0, 0},
|
||||
&versionInfo{3, Q, 18, 2, 17, 0, 0},
|
||||
&versionInfo{3, H, 22, 2, 13, 0, 0},
|
||||
&versionInfo{4, L, 20, 1, 80, 0, 0},
|
||||
&versionInfo{4, M, 18, 2, 32, 0, 0},
|
||||
&versionInfo{4, Q, 26, 2, 24, 0, 0},
|
||||
&versionInfo{4, H, 16, 4, 9, 0, 0},
|
||||
&versionInfo{5, L, 26, 1, 108, 0, 0},
|
||||
&versionInfo{5, M, 24, 2, 43, 0, 0},
|
||||
&versionInfo{5, Q, 18, 2, 15, 2, 16},
|
||||
&versionInfo{5, H, 22, 2, 11, 2, 12},
|
||||
&versionInfo{6, L, 18, 2, 68, 0, 0},
|
||||
&versionInfo{6, M, 16, 4, 27, 0, 0},
|
||||
&versionInfo{6, Q, 24, 4, 19, 0, 0},
|
||||
&versionInfo{6, H, 28, 4, 15, 0, 0},
|
||||
&versionInfo{7, L, 20, 2, 78, 0, 0},
|
||||
&versionInfo{7, M, 18, 4, 31, 0, 0},
|
||||
&versionInfo{7, Q, 18, 2, 14, 4, 15},
|
||||
&versionInfo{7, H, 26, 4, 13, 1, 14},
|
||||
&versionInfo{8, L, 24, 2, 97, 0, 0},
|
||||
&versionInfo{8, M, 22, 2, 38, 2, 39},
|
||||
&versionInfo{8, Q, 22, 4, 18, 2, 19},
|
||||
&versionInfo{8, H, 26, 4, 14, 2, 15},
|
||||
&versionInfo{9, L, 30, 2, 116, 0, 0},
|
||||
&versionInfo{9, M, 22, 3, 36, 2, 37},
|
||||
&versionInfo{9, Q, 20, 4, 16, 4, 17},
|
||||
&versionInfo{9, H, 24, 4, 12, 4, 13},
|
||||
&versionInfo{10, L, 18, 2, 68, 2, 69},
|
||||
&versionInfo{10, M, 26, 4, 43, 1, 44},
|
||||
&versionInfo{10, Q, 24, 6, 19, 2, 20},
|
||||
&versionInfo{10, H, 28, 6, 15, 2, 16},
|
||||
&versionInfo{11, L, 20, 4, 81, 0, 0},
|
||||
&versionInfo{11, M, 30, 1, 50, 4, 51},
|
||||
&versionInfo{11, Q, 28, 4, 22, 4, 23},
|
||||
&versionInfo{11, H, 24, 3, 12, 8, 13},
|
||||
&versionInfo{12, L, 24, 2, 92, 2, 93},
|
||||
&versionInfo{12, M, 22, 6, 36, 2, 37},
|
||||
&versionInfo{12, Q, 26, 4, 20, 6, 21},
|
||||
&versionInfo{12, H, 28, 7, 14, 4, 15},
|
||||
&versionInfo{13, L, 26, 4, 107, 0, 0},
|
||||
&versionInfo{13, M, 22, 8, 37, 1, 38},
|
||||
&versionInfo{13, Q, 24, 8, 20, 4, 21},
|
||||
&versionInfo{13, H, 22, 12, 11, 4, 12},
|
||||
&versionInfo{14, L, 30, 3, 115, 1, 116},
|
||||
&versionInfo{14, M, 24, 4, 40, 5, 41},
|
||||
&versionInfo{14, Q, 20, 11, 16, 5, 17},
|
||||
&versionInfo{14, H, 24, 11, 12, 5, 13},
|
||||
&versionInfo{15, L, 22, 5, 87, 1, 88},
|
||||
&versionInfo{15, M, 24, 5, 41, 5, 42},
|
||||
&versionInfo{15, Q, 30, 5, 24, 7, 25},
|
||||
&versionInfo{15, H, 24, 11, 12, 7, 13},
|
||||
&versionInfo{16, L, 24, 5, 98, 1, 99},
|
||||
&versionInfo{16, M, 28, 7, 45, 3, 46},
|
||||
&versionInfo{16, Q, 24, 15, 19, 2, 20},
|
||||
&versionInfo{16, H, 30, 3, 15, 13, 16},
|
||||
&versionInfo{17, L, 28, 1, 107, 5, 108},
|
||||
&versionInfo{17, M, 28, 10, 46, 1, 47},
|
||||
&versionInfo{17, Q, 28, 1, 22, 15, 23},
|
||||
&versionInfo{17, H, 28, 2, 14, 17, 15},
|
||||
&versionInfo{18, L, 30, 5, 120, 1, 121},
|
||||
&versionInfo{18, M, 26, 9, 43, 4, 44},
|
||||
&versionInfo{18, Q, 28, 17, 22, 1, 23},
|
||||
&versionInfo{18, H, 28, 2, 14, 19, 15},
|
||||
&versionInfo{19, L, 28, 3, 113, 4, 114},
|
||||
&versionInfo{19, M, 26, 3, 44, 11, 45},
|
||||
&versionInfo{19, Q, 26, 17, 21, 4, 22},
|
||||
&versionInfo{19, H, 26, 9, 13, 16, 14},
|
||||
&versionInfo{20, L, 28, 3, 107, 5, 108},
|
||||
&versionInfo{20, M, 26, 3, 41, 13, 42},
|
||||
&versionInfo{20, Q, 30, 15, 24, 5, 25},
|
||||
&versionInfo{20, H, 28, 15, 15, 10, 16},
|
||||
&versionInfo{21, L, 28, 4, 116, 4, 117},
|
||||
&versionInfo{21, M, 26, 17, 42, 0, 0},
|
||||
&versionInfo{21, Q, 28, 17, 22, 6, 23},
|
||||
&versionInfo{21, H, 30, 19, 16, 6, 17},
|
||||
&versionInfo{22, L, 28, 2, 111, 7, 112},
|
||||
&versionInfo{22, M, 28, 17, 46, 0, 0},
|
||||
&versionInfo{22, Q, 30, 7, 24, 16, 25},
|
||||
&versionInfo{22, H, 24, 34, 13, 0, 0},
|
||||
&versionInfo{23, L, 30, 4, 121, 5, 122},
|
||||
&versionInfo{23, M, 28, 4, 47, 14, 48},
|
||||
&versionInfo{23, Q, 30, 11, 24, 14, 25},
|
||||
&versionInfo{23, H, 30, 16, 15, 14, 16},
|
||||
&versionInfo{24, L, 30, 6, 117, 4, 118},
|
||||
&versionInfo{24, M, 28, 6, 45, 14, 46},
|
||||
&versionInfo{24, Q, 30, 11, 24, 16, 25},
|
||||
&versionInfo{24, H, 30, 30, 16, 2, 17},
|
||||
&versionInfo{25, L, 26, 8, 106, 4, 107},
|
||||
&versionInfo{25, M, 28, 8, 47, 13, 48},
|
||||
&versionInfo{25, Q, 30, 7, 24, 22, 25},
|
||||
&versionInfo{25, H, 30, 22, 15, 13, 16},
|
||||
&versionInfo{26, L, 28, 10, 114, 2, 115},
|
||||
&versionInfo{26, M, 28, 19, 46, 4, 47},
|
||||
&versionInfo{26, Q, 28, 28, 22, 6, 23},
|
||||
&versionInfo{26, H, 30, 33, 16, 4, 17},
|
||||
&versionInfo{27, L, 30, 8, 122, 4, 123},
|
||||
&versionInfo{27, M, 28, 22, 45, 3, 46},
|
||||
&versionInfo{27, Q, 30, 8, 23, 26, 24},
|
||||
&versionInfo{27, H, 30, 12, 15, 28, 16},
|
||||
&versionInfo{28, L, 30, 3, 117, 10, 118},
|
||||
&versionInfo{28, M, 28, 3, 45, 23, 46},
|
||||
&versionInfo{28, Q, 30, 4, 24, 31, 25},
|
||||
&versionInfo{28, H, 30, 11, 15, 31, 16},
|
||||
&versionInfo{29, L, 30, 7, 116, 7, 117},
|
||||
&versionInfo{29, M, 28, 21, 45, 7, 46},
|
||||
&versionInfo{29, Q, 30, 1, 23, 37, 24},
|
||||
&versionInfo{29, H, 30, 19, 15, 26, 16},
|
||||
&versionInfo{30, L, 30, 5, 115, 10, 116},
|
||||
&versionInfo{30, M, 28, 19, 47, 10, 48},
|
||||
&versionInfo{30, Q, 30, 15, 24, 25, 25},
|
||||
&versionInfo{30, H, 30, 23, 15, 25, 16},
|
||||
&versionInfo{31, L, 30, 13, 115, 3, 116},
|
||||
&versionInfo{31, M, 28, 2, 46, 29, 47},
|
||||
&versionInfo{31, Q, 30, 42, 24, 1, 25},
|
||||
&versionInfo{31, H, 30, 23, 15, 28, 16},
|
||||
&versionInfo{32, L, 30, 17, 115, 0, 0},
|
||||
&versionInfo{32, M, 28, 10, 46, 23, 47},
|
||||
&versionInfo{32, Q, 30, 10, 24, 35, 25},
|
||||
&versionInfo{32, H, 30, 19, 15, 35, 16},
|
||||
&versionInfo{33, L, 30, 17, 115, 1, 116},
|
||||
&versionInfo{33, M, 28, 14, 46, 21, 47},
|
||||
&versionInfo{33, Q, 30, 29, 24, 19, 25},
|
||||
&versionInfo{33, H, 30, 11, 15, 46, 16},
|
||||
&versionInfo{34, L, 30, 13, 115, 6, 116},
|
||||
&versionInfo{34, M, 28, 14, 46, 23, 47},
|
||||
&versionInfo{34, Q, 30, 44, 24, 7, 25},
|
||||
&versionInfo{34, H, 30, 59, 16, 1, 17},
|
||||
&versionInfo{35, L, 30, 12, 121, 7, 122},
|
||||
&versionInfo{35, M, 28, 12, 47, 26, 48},
|
||||
&versionInfo{35, Q, 30, 39, 24, 14, 25},
|
||||
&versionInfo{35, H, 30, 22, 15, 41, 16},
|
||||
&versionInfo{36, L, 30, 6, 121, 14, 122},
|
||||
&versionInfo{36, M, 28, 6, 47, 34, 48},
|
||||
&versionInfo{36, Q, 30, 46, 24, 10, 25},
|
||||
&versionInfo{36, H, 30, 2, 15, 64, 16},
|
||||
&versionInfo{37, L, 30, 17, 122, 4, 123},
|
||||
&versionInfo{37, M, 28, 29, 46, 14, 47},
|
||||
&versionInfo{37, Q, 30, 49, 24, 10, 25},
|
||||
&versionInfo{37, H, 30, 24, 15, 46, 16},
|
||||
&versionInfo{38, L, 30, 4, 122, 18, 123},
|
||||
&versionInfo{38, M, 28, 13, 46, 32, 47},
|
||||
&versionInfo{38, Q, 30, 48, 24, 14, 25},
|
||||
&versionInfo{38, H, 30, 42, 15, 32, 16},
|
||||
&versionInfo{39, L, 30, 20, 117, 4, 118},
|
||||
&versionInfo{39, M, 28, 40, 47, 7, 48},
|
||||
&versionInfo{39, Q, 30, 43, 24, 22, 25},
|
||||
&versionInfo{39, H, 30, 10, 15, 67, 16},
|
||||
&versionInfo{40, L, 30, 19, 118, 6, 119},
|
||||
&versionInfo{40, M, 28, 18, 47, 31, 48},
|
||||
&versionInfo{40, Q, 30, 34, 24, 34, 25},
|
||||
&versionInfo{40, H, 30, 20, 15, 61, 16},
|
||||
}
|
||||
|
||||
func (vi *versionInfo) totalDataBytes() int {
|
||||
g1Data := int(vi.NumberOfBlocksInGroup1) * int(vi.DataCodeWordsPerBlockInGroup1)
|
||||
g2Data := int(vi.NumberOfBlocksInGroup2) * int(vi.DataCodeWordsPerBlockInGroup2)
|
||||
return (g1Data + g2Data)
|
||||
}
|
||||
|
||||
func (vi *versionInfo) charCountBits(m encodingMode) byte {
|
||||
switch m {
|
||||
case numericMode:
|
||||
if vi.Version < 10 {
|
||||
return 10
|
||||
} else if vi.Version < 27 {
|
||||
return 12
|
||||
}
|
||||
return 14
|
||||
|
||||
case alphaNumericMode:
|
||||
if vi.Version < 10 {
|
||||
return 9
|
||||
} else if vi.Version < 27 {
|
||||
return 11
|
||||
}
|
||||
return 13
|
||||
|
||||
case byteMode:
|
||||
if vi.Version < 10 {
|
||||
return 8
|
||||
}
|
||||
return 16
|
||||
|
||||
case kanjiMode:
|
||||
if vi.Version < 10 {
|
||||
return 8
|
||||
} else if vi.Version < 27 {
|
||||
return 10
|
||||
}
|
||||
return 12
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *versionInfo) modulWidth() int {
|
||||
return ((int(vi.Version) - 1) * 4) + 21
|
||||
}
|
||||
|
||||
func (vi *versionInfo) alignmentPatternPlacements() []int {
|
||||
if vi.Version == 1 {
|
||||
return make([]int, 0)
|
||||
}
|
||||
|
||||
first := 6
|
||||
last := vi.modulWidth() - 7
|
||||
space := float64(last - first)
|
||||
count := int(math.Ceil(space/28)) + 1
|
||||
|
||||
result := make([]int, count)
|
||||
result[0] = first
|
||||
result[len(result)-1] = last
|
||||
if count > 2 {
|
||||
step := int(math.Ceil(float64(last-first) / float64(count-1)))
|
||||
if step%2 == 1 {
|
||||
frac := float64(last-first) / float64(count-1)
|
||||
_, x := math.Modf(frac)
|
||||
if x >= 0.5 {
|
||||
frac = math.Ceil(frac)
|
||||
} else {
|
||||
frac = math.Floor(frac)
|
||||
}
|
||||
|
||||
if int(frac)%2 == 0 {
|
||||
step--
|
||||
} else {
|
||||
step++
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i <= count-2; i++ {
|
||||
result[i] = last - (step * (count - 1 - i))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func findSmallestVersionInfo(ecl ErrorCorrectionLevel, mode encodingMode, dataBits int) *versionInfo {
|
||||
dataBits = dataBits + 4 // mode indicator
|
||||
for _, vi := range versionInfos {
|
||||
if vi.Level == ecl {
|
||||
if (vi.totalDataBytes() * 8) >= (dataBits + int(vi.charCountBits(mode))) {
|
||||
return vi
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
package barcode
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
)
|
||||
|
||||
type wrapFunc func(x, y int) color.Color
|
||||
|
||||
type scaledBarcode struct {
|
||||
wrapped Barcode
|
||||
wrapperFunc wrapFunc
|
||||
rect image.Rectangle
|
||||
}
|
||||
|
||||
type intCSscaledBC struct {
|
||||
scaledBarcode
|
||||
}
|
||||
|
||||
func (bc *scaledBarcode) Content() string {
|
||||
return bc.wrapped.Content()
|
||||
}
|
||||
|
||||
func (bc *scaledBarcode) Metadata() Metadata {
|
||||
return bc.wrapped.Metadata()
|
||||
}
|
||||
|
||||
func (bc *scaledBarcode) ColorModel() color.Model {
|
||||
return bc.wrapped.ColorModel()
|
||||
}
|
||||
|
||||
func (bc *scaledBarcode) Bounds() image.Rectangle {
|
||||
return bc.rect
|
||||
}
|
||||
|
||||
func (bc *scaledBarcode) At(x, y int) color.Color {
|
||||
return bc.wrapperFunc(x, y)
|
||||
}
|
||||
|
||||
func (bc *intCSscaledBC) CheckSum() int {
|
||||
if cs, ok := bc.wrapped.(BarcodeIntCS); ok {
|
||||
return cs.CheckSum()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Scale returns a resized barcode with the given width and height.
|
||||
func Scale(bc Barcode, width, height int) (Barcode, error) {
|
||||
switch bc.Metadata().Dimensions {
|
||||
case 1:
|
||||
return scale1DCode(bc, width, height)
|
||||
case 2:
|
||||
return scale2DCode(bc, width, height)
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported barcode format")
|
||||
}
|
||||
|
||||
func newScaledBC(wrapped Barcode, wrapperFunc wrapFunc, rect image.Rectangle) Barcode {
|
||||
result := &scaledBarcode{
|
||||
wrapped: wrapped,
|
||||
wrapperFunc: wrapperFunc,
|
||||
rect: rect,
|
||||
}
|
||||
|
||||
if _, ok := wrapped.(BarcodeIntCS); ok {
|
||||
return &intCSscaledBC{*result}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func scale2DCode(bc Barcode, width, height int) (Barcode, error) {
|
||||
orgBounds := bc.Bounds()
|
||||
orgWidth := orgBounds.Max.X - orgBounds.Min.X
|
||||
orgHeight := orgBounds.Max.Y - orgBounds.Min.Y
|
||||
|
||||
factor := int(math.Min(float64(width)/float64(orgWidth), float64(height)/float64(orgHeight)))
|
||||
if factor <= 0 {
|
||||
return nil, fmt.Errorf("can not scale barcode to an image smaller than %dx%d", orgWidth, orgHeight)
|
||||
}
|
||||
|
||||
offsetX := (width - (orgWidth * factor)) / 2
|
||||
offsetY := (height - (orgHeight * factor)) / 2
|
||||
|
||||
wrap := func(x, y int) color.Color {
|
||||
if x < offsetX || y < offsetY {
|
||||
return color.White
|
||||
}
|
||||
x = (x - offsetX) / factor
|
||||
y = (y - offsetY) / factor
|
||||
if x >= orgWidth || y >= orgHeight {
|
||||
return color.White
|
||||
}
|
||||
return bc.At(x, y)
|
||||
}
|
||||
|
||||
return newScaledBC(
|
||||
bc,
|
||||
wrap,
|
||||
image.Rect(0, 0, width, height),
|
||||
), nil
|
||||
}
|
||||
|
||||
func scale1DCode(bc Barcode, width, height int) (Barcode, error) {
|
||||
orgBounds := bc.Bounds()
|
||||
orgWidth := orgBounds.Max.X - orgBounds.Min.X
|
||||
factor := int(float64(width) / float64(orgWidth))
|
||||
|
||||
if factor <= 0 {
|
||||
return nil, fmt.Errorf("can not scale barcode to an image smaller than %dx1", orgWidth)
|
||||
}
|
||||
offsetX := (width - (orgWidth * factor)) / 2
|
||||
|
||||
wrap := func(x, y int) color.Color {
|
||||
if x < offsetX {
|
||||
return color.White
|
||||
}
|
||||
x = (x - offsetX) / factor
|
||||
|
||||
if x >= orgWidth {
|
||||
return color.White
|
||||
}
|
||||
return bc.At(x, 0)
|
||||
}
|
||||
|
||||
return newScaledBC(
|
||||
bc,
|
||||
wrap,
|
||||
image.Rect(0, 0, width, height),
|
||||
), nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
// Package utils contain some utilities which are needed to create barcodes
|
||||
package utils
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
|
||||
"github.com/boombuler/barcode"
|
||||
)
|
||||
|
||||
type base1DCode struct {
|
||||
*BitList
|
||||
kind string
|
||||
content string
|
||||
}
|
||||
|
||||
type base1DCodeIntCS struct {
|
||||
base1DCode
|
||||
checksum int
|
||||
}
|
||||
|
||||
func (c *base1DCode) Content() string {
|
||||
return c.content
|
||||
}
|
||||
|
||||
func (c *base1DCode) Metadata() barcode.Metadata {
|
||||
return barcode.Metadata{c.kind, 1}
|
||||
}
|
||||
|
||||
func (c *base1DCode) ColorModel() color.Model {
|
||||
return color.Gray16Model
|
||||
}
|
||||
|
||||
func (c *base1DCode) Bounds() image.Rectangle {
|
||||
return image.Rect(0, 0, c.Len(), 1)
|
||||
}
|
||||
|
||||
func (c *base1DCode) At(x, y int) color.Color {
|
||||
if c.GetBit(x) {
|
||||
return color.Black
|
||||
}
|
||||
return color.White
|
||||
}
|
||||
|
||||
func (c *base1DCodeIntCS) CheckSum() int {
|
||||
return c.checksum
|
||||
}
|
||||
|
||||
// New1DCode creates a new 1D barcode where the bars are represented by the bits in the bars BitList
|
||||
func New1DCodeIntCheckSum(codeKind, content string, bars *BitList, checksum int) barcode.BarcodeIntCS {
|
||||
return &base1DCodeIntCS{base1DCode{bars, codeKind, content}, checksum}
|
||||
}
|
||||
|
||||
// New1DCode creates a new 1D barcode where the bars are represented by the bits in the bars BitList
|
||||
func New1DCode(codeKind, content string, bars *BitList) barcode.Barcode {
|
||||
return &base1DCode{bars, codeKind, content}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package utils
|
||||
|
||||
// BitList is a list that contains bits
|
||||
type BitList struct {
|
||||
count int
|
||||
data []int32
|
||||
}
|
||||
|
||||
// NewBitList returns a new BitList with the given length
|
||||
// all bits are initialize with false
|
||||
func NewBitList(capacity int) *BitList {
|
||||
bl := new(BitList)
|
||||
bl.count = capacity
|
||||
x := 0
|
||||
if capacity%32 != 0 {
|
||||
x = 1
|
||||
}
|
||||
bl.data = make([]int32, capacity/32+x)
|
||||
return bl
|
||||
}
|
||||
|
||||
// Len returns the number of contained bits
|
||||
func (bl *BitList) Len() int {
|
||||
return bl.count
|
||||
}
|
||||
|
||||
func (bl *BitList) grow() {
|
||||
growBy := len(bl.data)
|
||||
if growBy < 128 {
|
||||
growBy = 128
|
||||
} else if growBy >= 1024 {
|
||||
growBy = 1024
|
||||
}
|
||||
|
||||
nd := make([]int32, len(bl.data)+growBy)
|
||||
copy(nd, bl.data)
|
||||
bl.data = nd
|
||||
}
|
||||
|
||||
// AddBit appends the given bits to the end of the list
|
||||
func (bl *BitList) AddBit(bits ...bool) {
|
||||
for _, bit := range bits {
|
||||
itmIndex := bl.count / 32
|
||||
for itmIndex >= len(bl.data) {
|
||||
bl.grow()
|
||||
}
|
||||
bl.SetBit(bl.count, bit)
|
||||
bl.count++
|
||||
}
|
||||
}
|
||||
|
||||
// SetBit sets the bit at the given index to the given value
|
||||
func (bl *BitList) SetBit(index int, value bool) {
|
||||
itmIndex := index / 32
|
||||
itmBitShift := 31 - (index % 32)
|
||||
if value {
|
||||
bl.data[itmIndex] = bl.data[itmIndex] | 1<<uint(itmBitShift)
|
||||
} else {
|
||||
bl.data[itmIndex] = bl.data[itmIndex] & ^(1 << uint(itmBitShift))
|
||||
}
|
||||
}
|
||||
|
||||
// GetBit returns the bit at the given index
|
||||
func (bl *BitList) GetBit(index int) bool {
|
||||
itmIndex := index / 32
|
||||
itmBitShift := 31 - (index % 32)
|
||||
return ((bl.data[itmIndex] >> uint(itmBitShift)) & 1) == 1
|
||||
}
|
||||
|
||||
// AddByte appends all 8 bits of the given byte to the end of the list
|
||||
func (bl *BitList) AddByte(b byte) {
|
||||
for i := 7; i >= 0; i-- {
|
||||
bl.AddBit(((b >> uint(i)) & 1) == 1)
|
||||
}
|
||||
}
|
||||
|
||||
// AddBits appends the last (LSB) 'count' bits of 'b' the the end of the list
|
||||
func (bl *BitList) AddBits(b int, count byte) {
|
||||
for i := int(count) - 1; i >= 0; i-- {
|
||||
bl.AddBit(((b >> uint(i)) & 1) == 1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetBytes returns all bits of the BitList as a []byte
|
||||
func (bl *BitList) GetBytes() []byte {
|
||||
len := bl.count >> 3
|
||||
if (bl.count % 8) != 0 {
|
||||
len++
|
||||
}
|
||||
result := make([]byte, len)
|
||||
for i := 0; i < len; i++ {
|
||||
shift := (3 - (i % 4)) * 8
|
||||
result[i] = (byte)((bl.data[i/4] >> uint(shift)) & 0xFF)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IterateBytes iterates through all bytes contained in the BitList
|
||||
func (bl *BitList) IterateBytes() <-chan byte {
|
||||
res := make(chan byte)
|
||||
|
||||
go func() {
|
||||
c := bl.count
|
||||
shift := 24
|
||||
i := 0
|
||||
for c > 0 {
|
||||
res <- byte((bl.data[i] >> uint(shift)) & 0xFF)
|
||||
shift -= 8
|
||||
if shift < 0 {
|
||||
shift = 24
|
||||
i++
|
||||
}
|
||||
c -= 8
|
||||
}
|
||||
close(res)
|
||||
}()
|
||||
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package utils
|
||||
|
||||
// GaloisField encapsulates galois field arithmetics
|
||||
type GaloisField struct {
|
||||
Size int
|
||||
Base int
|
||||
ALogTbl []int
|
||||
LogTbl []int
|
||||
}
|
||||
|
||||
// NewGaloisField creates a new galois field
|
||||
func NewGaloisField(pp, fieldSize, b int) *GaloisField {
|
||||
result := new(GaloisField)
|
||||
|
||||
result.Size = fieldSize
|
||||
result.Base = b
|
||||
result.ALogTbl = make([]int, fieldSize)
|
||||
result.LogTbl = make([]int, fieldSize)
|
||||
|
||||
x := 1
|
||||
for i := 0; i < fieldSize; i++ {
|
||||
result.ALogTbl[i] = x
|
||||
x = x * 2
|
||||
if x >= fieldSize {
|
||||
x = (x ^ pp) & (fieldSize - 1)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < fieldSize; i++ {
|
||||
result.LogTbl[result.ALogTbl[i]] = int(i)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (gf *GaloisField) Zero() *GFPoly {
|
||||
return NewGFPoly(gf, []int{0})
|
||||
}
|
||||
|
||||
// AddOrSub add or substract two numbers
|
||||
func (gf *GaloisField) AddOrSub(a, b int) int {
|
||||
return a ^ b
|
||||
}
|
||||
|
||||
// Multiply multiplys two numbers
|
||||
func (gf *GaloisField) Multiply(a, b int) int {
|
||||
if a == 0 || b == 0 {
|
||||
return 0
|
||||
}
|
||||
return gf.ALogTbl[(gf.LogTbl[a]+gf.LogTbl[b])%(gf.Size-1)]
|
||||
}
|
||||
|
||||
// Divide divides two numbers
|
||||
func (gf *GaloisField) Divide(a, b int) int {
|
||||
if b == 0 {
|
||||
panic("divide by zero")
|
||||
} else if a == 0 {
|
||||
return 0
|
||||
}
|
||||
return gf.ALogTbl[(gf.LogTbl[a]-gf.LogTbl[b])%(gf.Size-1)]
|
||||
}
|
||||
|
||||
func (gf *GaloisField) Invers(num int) int {
|
||||
return gf.ALogTbl[(gf.Size-1)-gf.LogTbl[num]]
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
package utils
|
||||
|
||||
type GFPoly struct {
|
||||
gf *GaloisField
|
||||
Coefficients []int
|
||||
}
|
||||
|
||||
func (gp *GFPoly) Degree() int {
|
||||
return len(gp.Coefficients) - 1
|
||||
}
|
||||
|
||||
func (gp *GFPoly) Zero() bool {
|
||||
return gp.Coefficients[0] == 0
|
||||
}
|
||||
|
||||
// GetCoefficient returns the coefficient of x ^ degree
|
||||
func (gp *GFPoly) GetCoefficient(degree int) int {
|
||||
return gp.Coefficients[gp.Degree()-degree]
|
||||
}
|
||||
|
||||
func (gp *GFPoly) AddOrSubstract(other *GFPoly) *GFPoly {
|
||||
if gp.Zero() {
|
||||
return other
|
||||
} else if other.Zero() {
|
||||
return gp
|
||||
}
|
||||
smallCoeff := gp.Coefficients
|
||||
largeCoeff := other.Coefficients
|
||||
if len(smallCoeff) > len(largeCoeff) {
|
||||
largeCoeff, smallCoeff = smallCoeff, largeCoeff
|
||||
}
|
||||
sumDiff := make([]int, len(largeCoeff))
|
||||
lenDiff := len(largeCoeff) - len(smallCoeff)
|
||||
copy(sumDiff, largeCoeff[:lenDiff])
|
||||
for i := lenDiff; i < len(largeCoeff); i++ {
|
||||
sumDiff[i] = int(gp.gf.AddOrSub(int(smallCoeff[i-lenDiff]), int(largeCoeff[i])))
|
||||
}
|
||||
return NewGFPoly(gp.gf, sumDiff)
|
||||
}
|
||||
|
||||
func (gp *GFPoly) MultByMonominal(degree int, coeff int) *GFPoly {
|
||||
if coeff == 0 {
|
||||
return gp.gf.Zero()
|
||||
}
|
||||
size := len(gp.Coefficients)
|
||||
result := make([]int, size+degree)
|
||||
for i := 0; i < size; i++ {
|
||||
result[i] = int(gp.gf.Multiply(int(gp.Coefficients[i]), int(coeff)))
|
||||
}
|
||||
return NewGFPoly(gp.gf, result)
|
||||
}
|
||||
|
||||
func (gp *GFPoly) Multiply(other *GFPoly) *GFPoly {
|
||||
if gp.Zero() || other.Zero() {
|
||||
return gp.gf.Zero()
|
||||
}
|
||||
aCoeff := gp.Coefficients
|
||||
aLen := len(aCoeff)
|
||||
bCoeff := other.Coefficients
|
||||
bLen := len(bCoeff)
|
||||
product := make([]int, aLen+bLen-1)
|
||||
for i := 0; i < aLen; i++ {
|
||||
ac := int(aCoeff[i])
|
||||
for j := 0; j < bLen; j++ {
|
||||
bc := int(bCoeff[j])
|
||||
product[i+j] = int(gp.gf.AddOrSub(int(product[i+j]), gp.gf.Multiply(ac, bc)))
|
||||
}
|
||||
}
|
||||
return NewGFPoly(gp.gf, product)
|
||||
}
|
||||
|
||||
func (gp *GFPoly) Divide(other *GFPoly) (quotient *GFPoly, remainder *GFPoly) {
|
||||
quotient = gp.gf.Zero()
|
||||
remainder = gp
|
||||
fld := gp.gf
|
||||
denomLeadTerm := other.GetCoefficient(other.Degree())
|
||||
inversDenomLeadTerm := fld.Invers(int(denomLeadTerm))
|
||||
for remainder.Degree() >= other.Degree() && !remainder.Zero() {
|
||||
degreeDiff := remainder.Degree() - other.Degree()
|
||||
scale := int(fld.Multiply(int(remainder.GetCoefficient(remainder.Degree())), inversDenomLeadTerm))
|
||||
term := other.MultByMonominal(degreeDiff, scale)
|
||||
itQuot := NewMonominalPoly(fld, degreeDiff, scale)
|
||||
quotient = quotient.AddOrSubstract(itQuot)
|
||||
remainder = remainder.AddOrSubstract(term)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewMonominalPoly(field *GaloisField, degree int, coeff int) *GFPoly {
|
||||
if coeff == 0 {
|
||||
return field.Zero()
|
||||
}
|
||||
result := make([]int, degree+1)
|
||||
result[0] = coeff
|
||||
return NewGFPoly(field, result)
|
||||
}
|
||||
|
||||
func NewGFPoly(field *GaloisField, coefficients []int) *GFPoly {
|
||||
for len(coefficients) > 1 && coefficients[0] == 0 {
|
||||
coefficients = coefficients[1:]
|
||||
}
|
||||
return &GFPoly{field, coefficients}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ReedSolomonEncoder struct {
|
||||
gf *GaloisField
|
||||
polynomes []*GFPoly
|
||||
m *sync.Mutex
|
||||
}
|
||||
|
||||
func NewReedSolomonEncoder(gf *GaloisField) *ReedSolomonEncoder {
|
||||
return &ReedSolomonEncoder{
|
||||
gf, []*GFPoly{NewGFPoly(gf, []int{1})}, new(sync.Mutex),
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *ReedSolomonEncoder) getPolynomial(degree int) *GFPoly {
|
||||
rs.m.Lock()
|
||||
defer rs.m.Unlock()
|
||||
|
||||
if degree >= len(rs.polynomes) {
|
||||
last := rs.polynomes[len(rs.polynomes)-1]
|
||||
for d := len(rs.polynomes); d <= degree; d++ {
|
||||
next := last.Multiply(NewGFPoly(rs.gf, []int{1, rs.gf.ALogTbl[d-1+rs.gf.Base]}))
|
||||
rs.polynomes = append(rs.polynomes, next)
|
||||
last = next
|
||||
}
|
||||
}
|
||||
return rs.polynomes[degree]
|
||||
}
|
||||
|
||||
func (rs *ReedSolomonEncoder) Encode(data []int, eccCount int) []int {
|
||||
generator := rs.getPolynomial(eccCount)
|
||||
info := NewGFPoly(rs.gf, data)
|
||||
info = info.MultByMonominal(eccCount, 1)
|
||||
_, remainder := info.Divide(generator)
|
||||
|
||||
result := make([]int, eccCount)
|
||||
numZero := int(eccCount) - len(remainder.Coefficients)
|
||||
copy(result[numZero:], remainder.Coefficients)
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package utils
|
||||
|
||||
// RuneToInt converts a rune between '0' and '9' to an integer between 0 and 9
|
||||
// If the rune is outside of this range -1 is returned.
|
||||
func RuneToInt(r rune) int {
|
||||
if r >= '0' && r <= '9' {
|
||||
return int(r - '0')
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// IntToRune converts a digit 0 - 9 to the rune '0' - '9'. If the given int is outside
|
||||
// of this range 'F' is returned!
|
||||
func IntToRune(i int) rune {
|
||||
if i >= 0 && i <= 9 {
|
||||
return rune(i + '0')
|
||||
}
|
||||
return 'F'
|
||||
}
|
|
@ -0,0 +1,353 @@
|
|||
Mozilla Public License, version 2.0
|
||||
|
||||
1. Definitions
|
||||
|
||||
1.1. “Contributor”
|
||||
|
||||
means each individual or legal entity that creates, contributes to the
|
||||
creation of, or owns Covered Software.
|
||||
|
||||
1.2. “Contributor Version”
|
||||
|
||||
means the combination of the Contributions of others (if any) used by a
|
||||
Contributor and that particular Contributor’s Contribution.
|
||||
|
||||
1.3. “Contribution”
|
||||
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. “Covered Software”
|
||||
|
||||
means Source Code Form to which the initial Contributor has attached the
|
||||
notice in Exhibit A, the Executable Form of such Source Code Form, and
|
||||
Modifications of such Source Code Form, in each case including portions
|
||||
thereof.
|
||||
|
||||
1.5. “Incompatible With Secondary Licenses”
|
||||
means
|
||||
|
||||
a. that the initial Contributor has attached the notice described in
|
||||
Exhibit B to the Covered Software; or
|
||||
|
||||
b. that the Covered Software was made available under the terms of version
|
||||
1.1 or earlier of the License, but not also under the terms of a
|
||||
Secondary License.
|
||||
|
||||
1.6. “Executable Form”
|
||||
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. “Larger Work”
|
||||
|
||||
means a work that combines Covered Software with other material, in a separate
|
||||
file or files, that is not Covered Software.
|
||||
|
||||
1.8. “License”
|
||||
|
||||
means this document.
|
||||
|
||||
1.9. “Licensable”
|
||||
|
||||
means having the right to grant, to the maximum extent possible, whether at the
|
||||
time of the initial grant or subsequently, any and all of the rights conveyed by
|
||||
this License.
|
||||
|
||||
1.10. “Modifications”
|
||||
|
||||
means any of the following:
|
||||
|
||||
a. any file in Source Code Form that results from an addition to, deletion
|
||||
from, or modification of the contents of Covered Software; or
|
||||
|
||||
b. any new file in Source Code Form that contains any Covered Software.
|
||||
|
||||
1.11. “Patent Claims” of a Contributor
|
||||
|
||||
means any patent claim(s), including without limitation, method, process,
|
||||
and apparatus claims, in any patent Licensable by such Contributor that
|
||||
would be infringed, but for the grant of the License, by the making,
|
||||
using, selling, offering for sale, having made, import, or transfer of
|
||||
either its Contributions or its Contributor Version.
|
||||
|
||||
1.12. “Secondary License”
|
||||
|
||||
means either the GNU General Public License, Version 2.0, the GNU Lesser
|
||||
General Public License, Version 2.1, the GNU Affero General Public
|
||||
License, Version 3.0, or any later versions of those licenses.
|
||||
|
||||
1.13. “Source Code Form”
|
||||
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. “You” (or “Your”)
|
||||
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, “You” includes any entity that controls, is
|
||||
controlled by, or is under common control with You. For purposes of this
|
||||
definition, “control” means (a) the power, direct or indirect, to cause
|
||||
the direction or management of such entity, whether by contract or
|
||||
otherwise, or (b) ownership of more than fifty percent (50%) of the
|
||||
outstanding shares or beneficial ownership of such entity.
|
||||
|
||||
|
||||
2. License Grants and Conditions
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
a. under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or as
|
||||
part of a Larger Work; and
|
||||
|
||||
b. under Patent Claims of such Contributor to make, use, sell, offer for
|
||||
sale, have made, import, and otherwise transfer either its Contributions
|
||||
or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution become
|
||||
effective for each Contribution on the date the Contributor first distributes
|
||||
such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under this
|
||||
License. No additional rights or licenses will be implied from the distribution
|
||||
or licensing of Covered Software under this License. Notwithstanding Section
|
||||
2.1(b) above, no patent license is granted by a Contributor:
|
||||
|
||||
a. for any code that a Contributor has removed from Covered Software; or
|
||||
|
||||
b. for infringements caused by: (i) Your and any other third party’s
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
c. under Patent Claims infringed by Covered Software in the absence of its
|
||||
Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks, or
|
||||
logos of any Contributor (except as may be necessary to comply with the
|
||||
notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this License
|
||||
(see Section 10.2) or under the terms of a Secondary License (if permitted
|
||||
under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its Contributions
|
||||
are its original creation(s) or it has sufficient rights to grant the
|
||||
rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under applicable
|
||||
copyright doctrines of fair use, fair dealing, or other equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
|
||||
Section 2.1.
|
||||
|
||||
|
||||
3. Responsibilities
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under the
|
||||
terms of this License. You must inform recipients that the Source Code Form
|
||||
of the Covered Software is governed by the terms of this License, and how
|
||||
they can obtain a copy of this License. You may not attempt to alter or
|
||||
restrict the recipients’ rights in the Source Code Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
a. such Covered Software must also be made available in Source Code Form,
|
||||
as described in Section 3.1, and You must inform recipients of the
|
||||
Executable Form how they can obtain a copy of such Source Code Form by
|
||||
reasonable means in a timely manner, at a charge no more than the cost
|
||||
of distribution to the recipient; and
|
||||
|
||||
b. You may distribute such Executable Form under the terms of this License,
|
||||
or sublicense it under different terms, provided that the license for
|
||||
the Executable Form does not attempt to limit or alter the recipients’
|
||||
rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for the
|
||||
Covered Software. If the Larger Work is a combination of Covered Software
|
||||
with a work governed by one or more Secondary Licenses, and the Covered
|
||||
Software is not Incompatible With Secondary Licenses, this License permits
|
||||
You to additionally distribute such Covered Software under the terms of
|
||||
such Secondary License(s), so that the recipient of the Larger Work may, at
|
||||
their option, further distribute the Covered Software under the terms of
|
||||
either this License or such Secondary License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices (including
|
||||
copyright notices, patent notices, disclaimers of warranty, or limitations
|
||||
of liability) contained within the Source Code Form of the Covered
|
||||
Software, except that You may alter any license notices to the extent
|
||||
required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on behalf
|
||||
of any Contributor. You must make it absolutely clear that any such
|
||||
warranty, support, indemnity, or liability obligation is offered by You
|
||||
alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this License
|
||||
with respect to some or all of the Covered Software due to statute, judicial
|
||||
order, or regulation then You must: (a) comply with the terms of this License
|
||||
to the maximum extent possible; and (b) describe the limitations and the code
|
||||
they affect. Such description must be placed in a text file included with all
|
||||
distributions of the Covered Software under this License. Except to the
|
||||
extent prohibited by statute or regulation, such description must be
|
||||
sufficiently detailed for a recipient of ordinary skill to be able to
|
||||
understand it.
|
||||
|
||||
5. Termination
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically if You
|
||||
fail to comply with any of its terms. However, if You become compliant,
|
||||
then the rights granted under this License from a particular Contributor
|
||||
are reinstated (a) provisionally, unless and until such Contributor
|
||||
explicitly and finally terminates Your grants, and (b) on an ongoing basis,
|
||||
if such Contributor fails to notify You of the non-compliance by some
|
||||
reasonable means prior to 60 days after You have come back into compliance.
|
||||
Moreover, Your grants from a particular Contributor are reinstated on an
|
||||
ongoing basis if such Contributor notifies You of the non-compliance by
|
||||
some reasonable means, this is the first time You have received notice of
|
||||
non-compliance with this License from such Contributor, and You become
|
||||
compliant prior to 30 days after Your receipt of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions, counter-claims,
|
||||
and cross-claims) alleging that a Contributor Version directly or
|
||||
indirectly infringes any patent, then the rights granted to You by any and
|
||||
all Contributors for the Covered Software under Section 2.1 of this License
|
||||
shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
|
||||
license agreements (excluding distributors and resellers) which have been
|
||||
validly granted by You or Your distributors under this License prior to
|
||||
termination shall survive termination.
|
||||
|
||||
6. Disclaimer of Warranty
|
||||
|
||||
Covered Software is provided under this License on an “as is” basis, without
|
||||
warranty of any kind, either expressed, implied, or statutory, including,
|
||||
without limitation, warranties that the Covered Software is free of defects,
|
||||
merchantable, fit for a particular purpose or non-infringing. The entire
|
||||
risk as to the quality and performance of the Covered Software is with You.
|
||||
Should any Covered Software prove defective in any respect, You (not any
|
||||
Contributor) assume the cost of any necessary servicing, repair, or
|
||||
correction. This disclaimer of warranty constitutes an essential part of this
|
||||
License. No use of any Covered Software is authorized under this License
|
||||
except under this disclaimer.
|
||||
|
||||
7. Limitation of Liability
|
||||
|
||||
Under no circumstances and under no legal theory, whether tort (including
|
||||
negligence), contract, or otherwise, shall any Contributor, or anyone who
|
||||
distributes Covered Software as permitted above, be liable to You for any
|
||||
direct, indirect, special, incidental, or consequential damages of any
|
||||
character including, without limitation, damages for lost profits, loss of
|
||||
goodwill, work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses, even if such party shall have been
|
||||
informed of the possibility of such damages. This limitation of liability
|
||||
shall not apply to liability for death or personal injury resulting from such
|
||||
party’s negligence to the extent applicable law prohibits such limitation.
|
||||
Some jurisdictions do not allow the exclusion or limitation of incidental or
|
||||
consequential damages, so this exclusion and limitation may not apply to You.
|
||||
|
||||
8. Litigation
|
||||
|
||||
Any litigation relating to this License may be brought only in the courts of
|
||||
a jurisdiction where the defendant maintains its principal place of business
|
||||
and such litigation shall be governed by laws of that jurisdiction, without
|
||||
reference to its conflict-of-law provisions. Nothing in this Section shall
|
||||
prevent a party’s ability to bring cross-claims or counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
|
||||
This License represents the complete agreement concerning the subject matter
|
||||
hereof. If any provision of this License is held to be unenforceable, such
|
||||
provision shall be reformed only to the extent necessary to make it
|
||||
enforceable. Any law or regulation which provides that the language of a
|
||||
contract shall be construed against the drafter shall not be used to construe
|
||||
this License against a Contributor.
|
||||
|
||||
|
||||
10. Versions of the License
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version of
|
||||
the License under which You originally received the Covered Software, or
|
||||
under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a modified
|
||||
version of this License if you rename the license and remove any
|
||||
references to the name of the license steward (except to note that such
|
||||
modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses
|
||||
If You choose to distribute Source Code Form that is Incompatible With
|
||||
Secondary Licenses under the terms of this version of the License, the
|
||||
notice described in Exhibit B of this License must be attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
|
||||
This Source Code Form is subject to the
|
||||
terms of the Mozilla Public License, v.
|
||||
2.0. If a copy of the MPL was not
|
||||
distributed with this file, You can
|
||||
obtain one at
|
||||
http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular file, then
|
||||
You may include the notice in a location (such as a LICENSE file in a relevant
|
||||
directory) where a recipient would be likely to look for such a notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - “Incompatible With Secondary Licenses” Notice
|
||||
|
||||
This Source Code Form is “Incompatible
|
||||
With Secondary Licenses”, as defined by
|
||||
the Mozilla Public License, v. 2.0.
|
|
@ -0,0 +1,161 @@
|
|||
# Go Plugin System over RPC
|
||||
|
||||
`go-plugin` is a Go (golang) plugin system over RPC. It is the plugin system
|
||||
that has been in use by HashiCorp tooling for over 3 years. While initially
|
||||
created for [Packer](https://www.packer.io), it has since been used by
|
||||
[Terraform](https://www.terraform.io) and [Otto](https://www.ottoproject.io),
|
||||
with plans to also use it for [Nomad](https://www.nomadproject.io) and
|
||||
[Vault](https://www.vaultproject.io).
|
||||
|
||||
While the plugin system is over RPC, it is currently only designed to work
|
||||
over a local [reliable] network. Plugins over a real network are not supported
|
||||
and will lead to unexpected behavior.
|
||||
|
||||
This plugin system has been used on millions of machines across many different
|
||||
projects and has proven to be battle hardened and ready for production use.
|
||||
|
||||
## Features
|
||||
|
||||
The HashiCorp plugin system supports a number of features:
|
||||
|
||||
**Plugins are Go interface implementations.** This makes writing and consuming
|
||||
plugins feel very natural. To a plugin author: you just implement an
|
||||
interface as if it were going to run in the same process. For a plugin user:
|
||||
you just use and call functions on an interface as if it were in the same
|
||||
process. This plugin system handles the communication in between.
|
||||
|
||||
**Complex arguments and return values are supported.** This library
|
||||
provides APIs for handling complex arguments and return values such
|
||||
as interfaces, `io.Reader/Writer`, etc. We do this by giving you a library
|
||||
(`MuxBroker`) for creating new connections between the client/server to
|
||||
serve additional interfaces or transfer raw data.
|
||||
|
||||
**Bidirectional communication.** Because the plugin system supports
|
||||
complex arguments, the host process can send it interface implementations
|
||||
and the plugin can call back into the host process.
|
||||
|
||||
**Built-in Logging.** Any plugins that use the `log` standard library
|
||||
will have log data automatically sent to the host process. The host
|
||||
process will mirror this output prefixed with the path to the plugin
|
||||
binary. This makes debugging with plugins simple.
|
||||
|
||||
**Protocol Versioning.** A very basic "protocol version" is supported that
|
||||
can be incremented to invalidate any previous plugins. This is useful when
|
||||
interface signatures are changing, protocol level changes are necessary,
|
||||
etc. When a protocol version is incompatible, a human friendly error
|
||||
message is shown to the end user.
|
||||
|
||||
**Stdout/Stderr Syncing.** While plugins are subprocesses, they can continue
|
||||
to use stdout/stderr as usual and the output will get mirrored back to
|
||||
the host process. The host process can control what `io.Writer` these
|
||||
streams go to to prevent this from happening.
|
||||
|
||||
**TTY Preservation.** Plugin subprocesses are connected to the identical
|
||||
stdin file descriptor as the host process, allowing software that requires
|
||||
a TTY to work. For example, a plugin can execute `ssh` and even though there
|
||||
are multiple subprocesses and RPC happening, it will look and act perfectly
|
||||
to the end user.
|
||||
|
||||
**Host upgrade while a plugin is running.** Plugins can be "reattached"
|
||||
so that the host process can be upgraded while the plugin is still running.
|
||||
This requires the host/plugin to know this is possible and daemonize
|
||||
properly. `NewClient` takes a `ReattachConfig` to determine if and how to
|
||||
reattach.
|
||||
|
||||
## Architecture
|
||||
|
||||
The HashiCorp plugin system works by launching subprocesses and communicating
|
||||
over RPC (using standard `net/rpc`). A single connection is made between
|
||||
any plugin and the host process, and we use a
|
||||
[connection multiplexing](https://github.com/hashicorp/yamux)
|
||||
library to multiplex any other connections on top.
|
||||
|
||||
This architecture has a number of benefits:
|
||||
|
||||
* Plugins can't crash your host process: A panic in a plugin doesn't
|
||||
panic the plugin user.
|
||||
|
||||
* Plugins are very easy to write: just write a Go application and `go build`.
|
||||
Theoretically you could also use another language as long as it can
|
||||
communicate the Go `net/rpc` protocol but this hasn't yet been tried.
|
||||
|
||||
* Plugins are very easy to install: just put the binary in a location where
|
||||
the host will find it (depends on the host but this library also provides
|
||||
helpers), and the plugin host handles the rest.
|
||||
|
||||
* Plugins can be relatively secure: The plugin only has access to the
|
||||
interfaces and args given to it, not to the entire memory space of the
|
||||
process. More security features are planned (see the coming soon section
|
||||
below).
|
||||
|
||||
## Usage
|
||||
|
||||
To use the plugin system, you must take the following steps. These are
|
||||
high-level steps that must be done. Examples are available in the
|
||||
`examples/` directory.
|
||||
|
||||
1. Choose the interface(s) you want to expose for plugins.
|
||||
|
||||
2. For each interface, implement an implementation of that interface
|
||||
that communicates over an `*rpc.Client` (from the standard `net/rpc`
|
||||
package) for every function call. Likewise, implement the RPC server
|
||||
struct this communicates to which is then communicating to a real,
|
||||
concrete implementation.
|
||||
|
||||
3. Create a `Plugin` implementation that knows how to create the RPC
|
||||
client/server for a given plugin type.
|
||||
|
||||
4. Plugin authors call `plugin.Serve` to serve a plugin from the
|
||||
`main` function.
|
||||
|
||||
5. Plugin users use `plugin.Client` to launch a subprocess and request
|
||||
an interface implementation over RPC.
|
||||
|
||||
That's it! In practice, step 2 is the most tedious and time consuming step.
|
||||
Even so, it isn't very difficult and you can see examples in the `examples/`
|
||||
directory as well as throughout our various open source projects.
|
||||
|
||||
For complete API documentation, see [GoDoc](https://godoc.org/github.com/hashicorp/go-plugin).
|
||||
|
||||
## Roadmap
|
||||
|
||||
Our plugin system is constantly evolving. As we use the plugin system for
|
||||
new projects or for new features in existing projects, we constantly find
|
||||
improvements we can make.
|
||||
|
||||
At this point in time, the roadmap for the plugin system is:
|
||||
|
||||
**Cryptographically Secure Plugins.** We'll implement signing plugins
|
||||
and loading signed plugins in order to allow Vault to make use of multi-process
|
||||
in a secure way.
|
||||
|
||||
**Semantic Versioning.** Plugins will be able to implement a semantic version.
|
||||
This plugin system will give host processes a system for constraining
|
||||
versions. This is in addition to the protocol versioning already present
|
||||
which is more for larger underlying changes.
|
||||
|
||||
**Plugin fetching.** We will integrate with [go-getter](https://github.com/hashicorp/go-getter)
|
||||
to support automatic download + install of plugins. Paired with cryptographically
|
||||
secure plugins (above), we can make this a safe operation for an amazing
|
||||
user experience.
|
||||
|
||||
## What About Shared Libraries?
|
||||
|
||||
When we started using plugins (late 2012, early 2013), plugins over RPC
|
||||
were the only option since Go didn't support dynamic library loading. Today,
|
||||
Go still doesn't support dynamic library loading, but they do intend to.
|
||||
Since 2012, our plugin system has stabilized from millions of users using it,
|
||||
and has many benefits we've come to value greatly.
|
||||
|
||||
For example, we intend to use this plugin system in
|
||||
[Vault](https://www.vaultproject.io), and dynamic library loading will
|
||||
simply never be acceptable in Vault for security reasons. That is an extreme
|
||||
example, but we believe our library system has more upsides than downsides
|
||||
over dynamic library loading and since we've had it built and tested for years,
|
||||
we'll likely continue to use it.
|
||||
|
||||
Shared libraries have one major advantage over our system which is much
|
||||
higher performance. In real world scenarios across our various tools,
|
||||
we've never required any more performance out of our plugin system and it
|
||||
has seen very high throughput, so this isn't a concern for us at the moment.
|
||||
|
|
@ -0,0 +1,666 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// If this is 1, then we've called CleanupClients. This can be used
|
||||
// by plugin RPC implementations to change error behavior since you
|
||||
// can expected network connection errors at this point. This should be
|
||||
// read by using sync/atomic.
|
||||
var Killed uint32 = 0
|
||||
|
||||
// This is a slice of the "managed" clients which are cleaned up when
|
||||
// calling Cleanup
|
||||
var managedClients = make([]*Client, 0, 5)
|
||||
var managedClientsLock sync.Mutex
|
||||
|
||||
// Error types
|
||||
var (
|
||||
// ErrProcessNotFound is returned when a client is instantiated to
|
||||
// reattach to an existing process and it isn't found.
|
||||
ErrProcessNotFound = errors.New("Reattachment process not found")
|
||||
|
||||
// ErrChecksumsDoNotMatch is returned when binary's checksum doesn't match
|
||||
// the one provided in the SecureConfig.
|
||||
ErrChecksumsDoNotMatch = errors.New("checksums did not match")
|
||||
|
||||
// ErrSecureNoChecksum is returned when an empty checksum is provided to the
|
||||
// SecureConfig.
|
||||
ErrSecureConfigNoChecksum = errors.New("no checksum provided")
|
||||
|
||||
// ErrSecureNoHash is returned when a nil Hash object is provided to the
|
||||
// SecureConfig.
|
||||
ErrSecureConfigNoHash = errors.New("no hash implementation provided")
|
||||
|
||||
// ErrSecureConfigAndReattach is returned when both Reattach and
|
||||
// SecureConfig are set.
|
||||
ErrSecureConfigAndReattach = errors.New("only one of Reattach or SecureConfig can be set")
|
||||
)
|
||||
|
||||
// Client handles the lifecycle of a plugin application. It launches
|
||||
// plugins, connects to them, dispenses interface implementations, and handles
|
||||
// killing the process.
|
||||
//
|
||||
// Plugin hosts should use one Client for each plugin executable. To
|
||||
// dispense a plugin type, use the `Client.Client` function, and then
|
||||
// cal `Dispense`. This awkward API is mostly historical but is used to split
|
||||
// the client that deals with subprocess management and the client that
|
||||
// does RPC management.
|
||||
//
|
||||
// See NewClient and ClientConfig for using a Client.
|
||||
type Client struct {
|
||||
config *ClientConfig
|
||||
exited bool
|
||||
doneLogging chan struct{}
|
||||
l sync.Mutex
|
||||
address net.Addr
|
||||
process *os.Process
|
||||
client *RPCClient
|
||||
}
|
||||
|
||||
// ClientConfig is the configuration used to initialize a new
|
||||
// plugin client. After being used to initialize a plugin client,
|
||||
// that configuration must not be modified again.
|
||||
type ClientConfig struct {
|
||||
// HandshakeConfig is the configuration that must match servers.
|
||||
HandshakeConfig
|
||||
|
||||
// Plugins are the plugins that can be consumed.
|
||||
Plugins map[string]Plugin
|
||||
|
||||
// One of the following must be set, but not both.
|
||||
//
|
||||
// Cmd is the unstarted subprocess for starting the plugin. If this is
|
||||
// set, then the Client starts the plugin process on its own and connects
|
||||
// to it.
|
||||
//
|
||||
// Reattach is configuration for reattaching to an existing plugin process
|
||||
// that is already running. This isn't common.
|
||||
Cmd *exec.Cmd
|
||||
Reattach *ReattachConfig
|
||||
|
||||
// SecureConfig is configuration for verifying the integrity of the
|
||||
// executable. It can not be used with Reattach.
|
||||
SecureConfig *SecureConfig
|
||||
|
||||
// TLSConfig is used to enable TLS on the RPC client.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// Managed represents if the client should be managed by the
|
||||
// plugin package or not. If true, then by calling CleanupClients,
|
||||
// it will automatically be cleaned up. Otherwise, the client
|
||||
// user is fully responsible for making sure to Kill all plugin
|
||||
// clients. By default the client is _not_ managed.
|
||||
Managed bool
|
||||
|
||||
// The minimum and maximum port to use for communicating with
|
||||
// the subprocess. If not set, this defaults to 10,000 and 25,000
|
||||
// respectively.
|
||||
MinPort, MaxPort uint
|
||||
|
||||
// StartTimeout is the timeout to wait for the plugin to say it
|
||||
// has started successfully.
|
||||
StartTimeout time.Duration
|
||||
|
||||
// If non-nil, then the stderr of the client will be written to here
|
||||
// (as well as the log). This is the original os.Stderr of the subprocess.
|
||||
// This isn't the output of synced stderr.
|
||||
Stderr io.Writer
|
||||
|
||||
// SyncStdout, SyncStderr can be set to override the
|
||||
// respective os.Std* values in the plugin. Care should be taken to
|
||||
// avoid races here. If these are nil, then this will automatically be
|
||||
// hooked up to os.Stdin, Stdout, and Stderr, respectively.
|
||||
//
|
||||
// If the default values (nil) are used, then this package will not
|
||||
// sync any of these streams.
|
||||
SyncStdout io.Writer
|
||||
SyncStderr io.Writer
|
||||
}
|
||||
|
||||
// ReattachConfig is used to configure a client to reattach to an
|
||||
// already-running plugin process. You can retrieve this information by
|
||||
// calling ReattachConfig on Client.
|
||||
type ReattachConfig struct {
|
||||
Addr net.Addr
|
||||
Pid int
|
||||
}
|
||||
|
||||
// SecureConfig is used to configure a client to verify the integrity of an
|
||||
// executable before running. It does this by verifying the checksum is
|
||||
// expected. Hash is used to specify the hashing method to use when checksumming
|
||||
// the file. The configuration is verified by the client by calling the
|
||||
// SecureConfig.Check() function.
|
||||
//
|
||||
// The host process should ensure the checksum was provided by a trusted and
|
||||
// authoritative source. The binary should be installed in such a way that it
|
||||
// can not be modified by an unauthorized user between the time of this check
|
||||
// and the time of execution.
|
||||
type SecureConfig struct {
|
||||
Checksum []byte
|
||||
Hash hash.Hash
|
||||
}
|
||||
|
||||
// Check takes the filepath to an executable and returns true if the checksum of
|
||||
// the file matches the checksum provided in the SecureConfig.
|
||||
func (s *SecureConfig) Check(filePath string) (bool, error) {
|
||||
if len(s.Checksum) == 0 {
|
||||
return false, ErrSecureConfigNoChecksum
|
||||
}
|
||||
|
||||
if s.Hash == nil {
|
||||
return false, ErrSecureConfigNoHash
|
||||
}
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(s.Hash, file)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
sum := s.Hash.Sum(nil)
|
||||
|
||||
return subtle.ConstantTimeCompare(sum, s.Checksum) == 1, nil
|
||||
}
|
||||
|
||||
// This makes sure all the managed subprocesses are killed and properly
|
||||
// logged. This should be called before the parent process running the
|
||||
// plugins exits.
|
||||
//
|
||||
// This must only be called _once_.
|
||||
func CleanupClients() {
|
||||
// Set the killed to true so that we don't get unexpected panics
|
||||
atomic.StoreUint32(&Killed, 1)
|
||||
|
||||
// Kill all the managed clients in parallel and use a WaitGroup
|
||||
// to wait for them all to finish up.
|
||||
var wg sync.WaitGroup
|
||||
managedClientsLock.Lock()
|
||||
for _, client := range managedClients {
|
||||
wg.Add(1)
|
||||
|
||||
go func(client *Client) {
|
||||
client.Kill()
|
||||
wg.Done()
|
||||
}(client)
|
||||
}
|
||||
managedClientsLock.Unlock()
|
||||
|
||||
log.Println("[DEBUG] plugin: waiting for all plugin processes to complete...")
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Creates a new plugin client which manages the lifecycle of an external
|
||||
// plugin and gets the address for the RPC connection.
|
||||
//
|
||||
// The client must be cleaned up at some point by calling Kill(). If
|
||||
// the client is a managed client (created with NewManagedClient) you
|
||||
// can just call CleanupClients at the end of your program and they will
|
||||
// be properly cleaned.
|
||||
func NewClient(config *ClientConfig) (c *Client) {
|
||||
if config.MinPort == 0 && config.MaxPort == 0 {
|
||||
config.MinPort = 10000
|
||||
config.MaxPort = 25000
|
||||
}
|
||||
|
||||
if config.StartTimeout == 0 {
|
||||
config.StartTimeout = 1 * time.Minute
|
||||
}
|
||||
|
||||
if config.Stderr == nil {
|
||||
config.Stderr = ioutil.Discard
|
||||
}
|
||||
|
||||
if config.SyncStdout == nil {
|
||||
config.SyncStdout = ioutil.Discard
|
||||
}
|
||||
if config.SyncStderr == nil {
|
||||
config.SyncStderr = ioutil.Discard
|
||||
}
|
||||
|
||||
c = &Client{config: config}
|
||||
if config.Managed {
|
||||
managedClientsLock.Lock()
|
||||
managedClients = append(managedClients, c)
|
||||
managedClientsLock.Unlock()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Client returns an RPC client for the plugin.
|
||||
//
|
||||
// Subsequent calls to this will return the same RPC client.
|
||||
func (c *Client) Client() (*RPCClient, error) {
|
||||
addr, err := c.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
|
||||
if c.client != nil {
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
// Connect to the client
|
||||
conn, err := net.Dial(addr.Network(), addr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
// Make sure to set keep alive so that the connection doesn't die
|
||||
tcpConn.SetKeepAlive(true)
|
||||
}
|
||||
|
||||
if c.config.TLSConfig != nil {
|
||||
conn = tls.Client(conn, c.config.TLSConfig)
|
||||
}
|
||||
|
||||
// Create the actual RPC client
|
||||
c.client, err = NewRPCClient(conn, c.config.Plugins)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Begin the stream syncing so that stdin, out, err work properly
|
||||
err = c.client.SyncStreams(
|
||||
c.config.SyncStdout,
|
||||
c.config.SyncStderr)
|
||||
if err != nil {
|
||||
c.client.Close()
|
||||
c.client = nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
// Tells whether or not the underlying process has exited.
|
||||
func (c *Client) Exited() bool {
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
return c.exited
|
||||
}
|
||||
|
||||
// End the executing subprocess (if it is running) and perform any cleanup
|
||||
// tasks necessary such as capturing any remaining logs and so on.
|
||||
//
|
||||
// This method blocks until the process successfully exits.
|
||||
//
|
||||
// This method can safely be called multiple times.
|
||||
func (c *Client) Kill() {
|
||||
// Grab a lock to read some private fields.
|
||||
c.l.Lock()
|
||||
process := c.process
|
||||
addr := c.address
|
||||
doneCh := c.doneLogging
|
||||
c.l.Unlock()
|
||||
|
||||
// If there is no process, we never started anything. Nothing to kill.
|
||||
if process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// We need to check for address here. It is possible that the plugin
|
||||
// started (process != nil) but has no address (addr == nil) if the
|
||||
// plugin failed at startup. If we do have an address, we need to close
|
||||
// the plugin net connections.
|
||||
graceful := false
|
||||
if addr != nil {
|
||||
// Close the client to cleanly exit the process.
|
||||
client, err := c.Client()
|
||||
if err == nil {
|
||||
err = client.Close()
|
||||
|
||||
// If there is no error, then we attempt to wait for a graceful
|
||||
// exit. If there was an error, we assume that graceful cleanup
|
||||
// won't happen and just force kill.
|
||||
graceful = err == nil
|
||||
if err != nil {
|
||||
// If there was an error just log it. We're going to force
|
||||
// kill in a moment anyways.
|
||||
log.Printf(
|
||||
"[WARN] plugin: error closing client during Kill: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we're attempting a graceful exit, then we wait for a short period
|
||||
// of time to allow that to happen. To wait for this we just wait on the
|
||||
// doneCh which would be closed if the process exits.
|
||||
if graceful {
|
||||
select {
|
||||
case <-doneCh:
|
||||
return
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// If graceful exiting failed, just kill it
|
||||
process.Kill()
|
||||
|
||||
// Wait for the client to finish logging so we have a complete log
|
||||
<-doneCh
|
||||
}
|
||||
|
||||
// Starts the underlying subprocess, communicating with it to negotiate
|
||||
// a port for RPC connections, and returning the address to connect via RPC.
|
||||
//
|
||||
// This method is safe to call multiple times. Subsequent calls have no effect.
|
||||
// Once a client has been started once, it cannot be started again, even if
|
||||
// it was killed.
|
||||
func (c *Client) Start() (addr net.Addr, err error) {
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
|
||||
if c.address != nil {
|
||||
return c.address, nil
|
||||
}
|
||||
|
||||
// If one of cmd or reattach isn't set, then it is an error. We wrap
|
||||
// this in a {} for scoping reasons, and hopeful that the escape
|
||||
// analysis will pop the stock here.
|
||||
{
|
||||
cmdSet := c.config.Cmd != nil
|
||||
attachSet := c.config.Reattach != nil
|
||||
secureSet := c.config.SecureConfig != nil
|
||||
if cmdSet == attachSet {
|
||||
return nil, fmt.Errorf("Only one of Cmd or Reattach must be set")
|
||||
}
|
||||
|
||||
if secureSet && attachSet {
|
||||
return nil, ErrSecureConfigAndReattach
|
||||
}
|
||||
}
|
||||
|
||||
// Create the logging channel for when we kill
|
||||
c.doneLogging = make(chan struct{})
|
||||
|
||||
if c.config.Reattach != nil {
|
||||
// Verify the process still exists. If not, then it is an error
|
||||
p, err := os.FindProcess(c.config.Reattach.Pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Attempt to connect to the addr since on Unix systems FindProcess
|
||||
// doesn't actually return an error if it can't find the process.
|
||||
conn, err := net.Dial(
|
||||
c.config.Reattach.Addr.Network(),
|
||||
c.config.Reattach.Addr.String())
|
||||
if err != nil {
|
||||
p.Kill()
|
||||
return nil, ErrProcessNotFound
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
// Goroutine to mark exit status
|
||||
go func(pid int) {
|
||||
// Wait for the process to die
|
||||
pidWait(pid)
|
||||
|
||||
// Log so we can see it
|
||||
log.Printf("[DEBUG] plugin: reattached plugin process exited\n")
|
||||
|
||||
// Mark it
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
c.exited = true
|
||||
|
||||
// Close the logging channel since that doesn't work on reattach
|
||||
close(c.doneLogging)
|
||||
}(p.Pid)
|
||||
|
||||
// Set the address and process
|
||||
c.address = c.config.Reattach.Addr
|
||||
c.process = p
|
||||
|
||||
return c.address, nil
|
||||
}
|
||||
|
||||
env := []string{
|
||||
fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue),
|
||||
fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort),
|
||||
fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort),
|
||||
}
|
||||
|
||||
stdout_r, stdout_w := io.Pipe()
|
||||
stderr_r, stderr_w := io.Pipe()
|
||||
|
||||
cmd := c.config.Cmd
|
||||
cmd.Env = append(cmd.Env, os.Environ()...)
|
||||
cmd.Env = append(cmd.Env, env...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stderr = stderr_w
|
||||
cmd.Stdout = stdout_w
|
||||
|
||||
if c.config.SecureConfig != nil {
|
||||
if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil {
|
||||
return nil, fmt.Errorf("error verifying checksum: %s", err)
|
||||
} else if !ok {
|
||||
return nil, ErrChecksumsDoNotMatch
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] plugin: starting plugin: %s %#v", cmd.Path, cmd.Args)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Set the process
|
||||
c.process = cmd.Process
|
||||
|
||||
// Make sure the command is properly cleaned up if there is an error
|
||||
defer func() {
|
||||
r := recover()
|
||||
|
||||
if err != nil || r != nil {
|
||||
cmd.Process.Kill()
|
||||
}
|
||||
|
||||
if r != nil {
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start goroutine to wait for process to exit
|
||||
exitCh := make(chan struct{})
|
||||
go func() {
|
||||
// Make sure we close the write end of our stderr/stdout so
|
||||
// that the readers send EOF properly.
|
||||
defer stderr_w.Close()
|
||||
defer stdout_w.Close()
|
||||
|
||||
// Wait for the command to end.
|
||||
cmd.Wait()
|
||||
|
||||
// Log and make sure to flush the logs write away
|
||||
log.Printf("[DEBUG] plugin: %s: plugin process exited\n", cmd.Path)
|
||||
os.Stderr.Sync()
|
||||
|
||||
// Mark that we exited
|
||||
close(exitCh)
|
||||
|
||||
// Set that we exited, which takes a lock
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
c.exited = true
|
||||
}()
|
||||
|
||||
// Start goroutine that logs the stderr
|
||||
go c.logStderr(stderr_r)
|
||||
|
||||
// Start a goroutine that is going to be reading the lines
|
||||
// out of stdout
|
||||
linesCh := make(chan []byte)
|
||||
go func() {
|
||||
defer close(linesCh)
|
||||
|
||||
buf := bufio.NewReader(stdout_r)
|
||||
for {
|
||||
line, err := buf.ReadBytes('\n')
|
||||
if line != nil {
|
||||
linesCh <- line
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Make sure after we exit we read the lines from stdout forever
|
||||
// so they don't block since it is an io.Pipe
|
||||
defer func() {
|
||||
go func() {
|
||||
for _ = range linesCh {
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
// Some channels for the next step
|
||||
timeout := time.After(c.config.StartTimeout)
|
||||
|
||||
// Start looking for the address
|
||||
log.Printf("[DEBUG] plugin: waiting for RPC address for: %s", cmd.Path)
|
||||
select {
|
||||
case <-timeout:
|
||||
err = errors.New("timeout while waiting for plugin to start")
|
||||
case <-exitCh:
|
||||
err = errors.New("plugin exited before we could connect")
|
||||
case lineBytes := <-linesCh:
|
||||
// Trim the line and split by "|" in order to get the parts of
|
||||
// the output.
|
||||
line := strings.TrimSpace(string(lineBytes))
|
||||
parts := strings.SplitN(line, "|", 4)
|
||||
if len(parts) < 4 {
|
||||
err = fmt.Errorf(
|
||||
"Unrecognized remote plugin message: %s\n\n"+
|
||||
"This usually means that the plugin is either invalid or simply\n"+
|
||||
"needs to be recompiled to support the latest protocol.", line)
|
||||
return
|
||||
}
|
||||
|
||||
// Check the core protocol. Wrapped in a {} for scoping.
|
||||
{
|
||||
var coreProtocol int64
|
||||
coreProtocol, err = strconv.ParseInt(parts[0], 10, 0)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error parsing core protocol version: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if int(coreProtocol) != CoreProtocolVersion {
|
||||
err = fmt.Errorf("Incompatible core API version with plugin. "+
|
||||
"Plugin version: %s, Ours: %d\n\n"+
|
||||
"To fix this, the plugin usually only needs to be recompiled.\n"+
|
||||
"Please report this to the plugin author.", parts[0], CoreProtocolVersion)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the protocol version
|
||||
var protocol int64
|
||||
protocol, err = strconv.ParseInt(parts[1], 10, 0)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error parsing protocol version: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test the API version
|
||||
if uint(protocol) != c.config.ProtocolVersion {
|
||||
err = fmt.Errorf("Incompatible API version with plugin. "+
|
||||
"Plugin version: %s, Ours: %d", parts[1], c.config.ProtocolVersion)
|
||||
return
|
||||
}
|
||||
|
||||
switch parts[2] {
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr("tcp", parts[3])
|
||||
case "unix":
|
||||
addr, err = net.ResolveUnixAddr("unix", parts[3])
|
||||
default:
|
||||
err = fmt.Errorf("Unknown address type: %s", parts[3])
|
||||
}
|
||||
}
|
||||
|
||||
c.address = addr
|
||||
return
|
||||
}
|
||||
|
||||
// ReattachConfig returns the information that must be provided to NewClient
|
||||
// to reattach to the plugin process that this client started. This is
|
||||
// useful for plugins that detach from their parent process.
|
||||
//
|
||||
// If this returns nil then the process hasn't been started yet. Please
|
||||
// call Start or Client before calling this.
|
||||
func (c *Client) ReattachConfig() *ReattachConfig {
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
|
||||
if c.address == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.config.Cmd != nil && c.config.Cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we connected via reattach, just return the information as-is
|
||||
if c.config.Reattach != nil {
|
||||
return c.config.Reattach
|
||||
}
|
||||
|
||||
return &ReattachConfig{
|
||||
Addr: c.address,
|
||||
Pid: c.config.Cmd.Process.Pid,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) logStderr(r io.Reader) {
|
||||
bufR := bufio.NewReader(r)
|
||||
for {
|
||||
line, err := bufR.ReadString('\n')
|
||||
if line != "" {
|
||||
c.config.Stderr.Write([]byte(line))
|
||||
|
||||
line = strings.TrimRightFunc(line, unicode.IsSpace)
|
||||
log.Printf("[DEBUG] plugin: %s: %s", filepath.Base(c.config.Cmd.Path), line)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Flag that we've completed logging for others
|
||||
close(c.doneLogging)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Discover discovers plugins that are in a given directory.
|
||||
//
|
||||
// The directory doesn't need to be absolute. For example, "." will work fine.
|
||||
//
|
||||
// This currently assumes any file matching the glob is a plugin.
|
||||
// In the future this may be smarter about checking that a file is
|
||||
// executable and so on.
|
||||
//
|
||||
// TODO: test
|
||||
func Discover(glob, dir string) ([]string, error) {
|
||||
var err error
|
||||
|
||||
// Make the directory absolute if it isn't already
|
||||
if !filepath.IsAbs(dir) {
|
||||
dir, err = filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Glob(filepath.Join(dir, glob))
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package plugin
|
||||
|
||||
// This is a type that wraps error types so that they can be messaged
|
||||
// across RPC channels. Since "error" is an interface, we can't always
|
||||
// gob-encode the underlying structure. This is a valid error interface
|
||||
// implementer that we will push across.
|
||||
type BasicError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// NewBasicError is used to create a BasicError.
|
||||
//
|
||||
// err is allowed to be nil.
|
||||
func NewBasicError(err error) *BasicError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BasicError{err.Error()}
|
||||
}
|
||||
|
||||
func (e *BasicError) Error() string {
|
||||
return e.Message
|
||||
}
|
|
@ -0,0 +1,204 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
// MuxBroker is responsible for brokering multiplexed connections by unique ID.
|
||||
//
|
||||
// It is used by plugins to multiplex multiple RPC connections and data
|
||||
// streams on top of a single connection between the plugin process and the
|
||||
// host process.
|
||||
//
|
||||
// This allows a plugin to request a channel with a specific ID to connect to
|
||||
// or accept a connection from, and the broker handles the details of
|
||||
// holding these channels open while they're being negotiated.
|
||||
//
|
||||
// The Plugin interface has access to these for both Server and Client.
|
||||
// The broker can be used by either (optionally) to reserve and connect to
|
||||
// new multiplexed streams. This is useful for complex args and return values,
|
||||
// or anything else you might need a data stream for.
|
||||
type MuxBroker struct {
|
||||
nextId uint32
|
||||
session *yamux.Session
|
||||
streams map[uint32]*muxBrokerPending
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
type muxBrokerPending struct {
|
||||
ch chan net.Conn
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
func newMuxBroker(s *yamux.Session) *MuxBroker {
|
||||
return &MuxBroker{
|
||||
session: s,
|
||||
streams: make(map[uint32]*muxBrokerPending),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept accepts a connection by ID.
|
||||
//
|
||||
// This should not be called multiple times with the same ID at one time.
|
||||
func (m *MuxBroker) Accept(id uint32) (net.Conn, error) {
|
||||
var c net.Conn
|
||||
p := m.getStream(id)
|
||||
select {
|
||||
case c = <-p.ch:
|
||||
close(p.doneCh)
|
||||
case <-time.After(5 * time.Second):
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
delete(m.streams, id)
|
||||
|
||||
return nil, fmt.Errorf("timeout waiting for accept")
|
||||
}
|
||||
|
||||
// Ack our connection
|
||||
if err := binary.Write(c, binary.LittleEndian, id); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// AcceptAndServe is used to accept a specific stream ID and immediately
|
||||
// serve an RPC server on that stream ID. This is used to easily serve
|
||||
// complex arguments.
|
||||
//
|
||||
// The served interface is always registered to the "Plugin" name.
|
||||
func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) {
|
||||
conn, err := m.Accept(id)
|
||||
if err != nil {
|
||||
log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
serve(conn, "Plugin", v)
|
||||
}
|
||||
|
||||
// Close closes the connection and all sub-connections.
|
||||
func (m *MuxBroker) Close() error {
|
||||
return m.session.Close()
|
||||
}
|
||||
|
||||
// Dial opens a connection by ID.
|
||||
func (m *MuxBroker) Dial(id uint32) (net.Conn, error) {
|
||||
// Open the stream
|
||||
stream, err := m.session.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the stream ID onto the wire.
|
||||
if err := binary.Write(stream, binary.LittleEndian, id); err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the ack that we connected. Then we're off!
|
||||
var ack uint32
|
||||
if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if ack != id {
|
||||
stream.Close()
|
||||
return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id)
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// NextId returns a unique ID to use next.
|
||||
//
|
||||
// It is possible for very long-running plugin hosts to wrap this value,
|
||||
// though it would require a very large amount of RPC calls. In practice
|
||||
// we've never seen it happen.
|
||||
func (m *MuxBroker) NextId() uint32 {
|
||||
return atomic.AddUint32(&m.nextId, 1)
|
||||
}
|
||||
|
||||
// Run starts the brokering and should be executed in a goroutine, since it
|
||||
// blocks forever, or until the session closes.
|
||||
//
|
||||
// Uses of MuxBroker never need to call this. It is called internally by
|
||||
// the plugin host/client.
|
||||
func (m *MuxBroker) Run() {
|
||||
for {
|
||||
stream, err := m.session.AcceptStream()
|
||||
if err != nil {
|
||||
// Once we receive an error, just exit
|
||||
break
|
||||
}
|
||||
|
||||
// Read the stream ID from the stream
|
||||
var id uint32
|
||||
if err := binary.Read(stream, binary.LittleEndian, &id); err != nil {
|
||||
stream.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Initialize the waiter
|
||||
p := m.getStream(id)
|
||||
select {
|
||||
case p.ch <- stream:
|
||||
default:
|
||||
}
|
||||
|
||||
// Wait for a timeout
|
||||
go m.timeoutWait(id, p)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MuxBroker) getStream(id uint32) *muxBrokerPending {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
p, ok := m.streams[id]
|
||||
if ok {
|
||||
return p
|
||||
}
|
||||
|
||||
m.streams[id] = &muxBrokerPending{
|
||||
ch: make(chan net.Conn, 1),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
return m.streams[id]
|
||||
}
|
||||
|
||||
func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) {
|
||||
// Wait for the stream to either be picked up and connected, or
|
||||
// for a timeout.
|
||||
timeout := false
|
||||
select {
|
||||
case <-p.doneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
timeout = true
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Delete the stream so no one else can grab it
|
||||
delete(m.streams, id)
|
||||
|
||||
// If we timed out, then check if we have a channel in the buffer,
|
||||
// and if so, close it.
|
||||
if timeout {
|
||||
select {
|
||||
case s := <-p.ch:
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
// The plugin package exposes functions and helpers for communicating to
|
||||
// plugins which are implemented as standalone binary applications.
|
||||
//
|
||||
// plugin.Client fully manages the lifecycle of executing the application,
|
||||
// connecting to it, and returning the RPC client for dispensing plugins.
|
||||
//
|
||||
// plugin.Serve fully manages listeners to expose an RPC server from a binary
|
||||
// that plugin.Client can connect to.
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"net/rpc"
|
||||
)
|
||||
|
||||
// Plugin is the interface that is implemented to serve/connect to an
|
||||
// inteface implementation.
|
||||
type Plugin interface {
|
||||
// Server should return the RPC server compatible struct to serve
|
||||
// the methods that the Client calls over net/rpc.
|
||||
Server(*MuxBroker) (interface{}, error)
|
||||
|
||||
// Client returns an interface implementation for the plugin you're
|
||||
// serving that communicates to the server end of the plugin.
|
||||
Client(*MuxBroker, *rpc.Client) (interface{}, error)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// pidAlive checks whether a pid is alive.
|
||||
func pidAlive(pid int) bool {
|
||||
return _pidAlive(pid)
|
||||
}
|
||||
|
||||
// pidWait blocks for a process to exit.
|
||||
func pidWait(pid int) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if !pidAlive(pid) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// +build !windows
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// _pidAlive tests whether a process is alive or not by sending it Signal 0,
|
||||
// since Go otherwise has no way to test this.
|
||||
func _pidAlive(pid int) bool {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err == nil {
|
||||
err = proc.Signal(syscall.Signal(0))
|
||||
}
|
||||
|
||||
return err == nil
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
// Weird name but matches the MSDN docs
|
||||
exit_STILL_ACTIVE = 259
|
||||
|
||||
processDesiredAccess = syscall.STANDARD_RIGHTS_READ |
|
||||
syscall.PROCESS_QUERY_INFORMATION |
|
||||
syscall.SYNCHRONIZE
|
||||
)
|
||||
|
||||
// _pidAlive tests whether a process is alive or not
|
||||
func _pidAlive(pid int) bool {
|
||||
h, err := syscall.OpenProcess(processDesiredAccess, false, uint32(pid))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var ec uint32
|
||||
if e := syscall.GetExitCodeProcess(h, &ec); e != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return ec == exit_STILL_ACTIVE
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/rpc"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
// RPCClient connects to an RPCServer over net/rpc to dispense plugin types.
|
||||
type RPCClient struct {
|
||||
broker *MuxBroker
|
||||
control *rpc.Client
|
||||
plugins map[string]Plugin
|
||||
|
||||
// These are the streams used for the various stdout/err overrides
|
||||
stdout, stderr net.Conn
|
||||
}
|
||||
|
||||
// NewRPCClient creates a client from an already-open connection-like value.
|
||||
// Dial is typically used instead.
|
||||
func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) {
|
||||
// Create the yamux client so we can multiplex
|
||||
mux, err := yamux.Client(conn, nil)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect to the control stream.
|
||||
control, err := mux.Open()
|
||||
if err != nil {
|
||||
mux.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect stdout, stderr streams
|
||||
stdstream := make([]net.Conn, 2)
|
||||
for i, _ := range stdstream {
|
||||
stdstream[i], err = mux.Open()
|
||||
if err != nil {
|
||||
mux.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Create the broker and start it up
|
||||
broker := newMuxBroker(mux)
|
||||
go broker.Run()
|
||||
|
||||
// Build the client using our broker and control channel.
|
||||
return &RPCClient{
|
||||
broker: broker,
|
||||
control: rpc.NewClient(control),
|
||||
plugins: plugins,
|
||||
stdout: stdstream[0],
|
||||
stderr: stdstream[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncStreams should be called to enable syncing of stdout,
|
||||
// stderr with the plugin.
|
||||
//
|
||||
// This will return immediately and the syncing will continue to happen
|
||||
// in the background. You do not need to launch this in a goroutine itself.
|
||||
//
|
||||
// This should never be called multiple times.
|
||||
func (c *RPCClient) SyncStreams(stdout io.Writer, stderr io.Writer) error {
|
||||
go copyStream("stdout", stdout, c.stdout)
|
||||
go copyStream("stderr", stderr, c.stderr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection. The client is no longer usable after this
|
||||
// is called.
|
||||
func (c *RPCClient) Close() error {
|
||||
// Call the control channel and ask it to gracefully exit. If this
|
||||
// errors, then we save it so that we always return an error but we
|
||||
// want to try to close the other channels anyways.
|
||||
var empty struct{}
|
||||
returnErr := c.control.Call("Control.Quit", true, &empty)
|
||||
|
||||
// Close the other streams we have
|
||||
if err := c.control.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.stdout.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.stderr.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.broker.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return back the error we got from Control.Quit. This is very important
|
||||
// since we MUST return non-nil error if this fails so that Client.Kill
|
||||
// will properly try a process.Kill.
|
||||
return returnErr
|
||||
}
|
||||
|
||||
func (c *RPCClient) Dispense(name string) (interface{}, error) {
|
||||
p, ok := c.plugins[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown plugin type: %s", name)
|
||||
}
|
||||
|
||||
var id uint32
|
||||
if err := c.control.Call(
|
||||
"Dispenser.Dispense", name, &id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := c.broker.Dial(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.Client(c.broker, rpc.NewClient(conn))
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
// RPCServer listens for network connections and then dispenses interface
|
||||
// implementations over net/rpc.
|
||||
//
|
||||
// After setting the fields below, they shouldn't be read again directly
|
||||
// from the structure which may be reading/writing them concurrently.
|
||||
type RPCServer struct {
|
||||
Plugins map[string]Plugin
|
||||
|
||||
// Stdout, Stderr are what this server will use instead of the
|
||||
// normal stdin/out/err. This is because due to the multi-process nature
|
||||
// of our plugin system, we can't use the normal process values so we
|
||||
// make our own custom one we pipe across.
|
||||
Stdout io.Reader
|
||||
Stderr io.Reader
|
||||
|
||||
// DoneCh should be set to a non-nil channel that will be closed
|
||||
// when the control requests the RPC server to end.
|
||||
DoneCh chan<- struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// Accept accepts connections on a listener and serves requests for
|
||||
// each incoming connection. Accept blocks; the caller typically invokes
|
||||
// it in a go statement.
|
||||
func (s *RPCServer) Accept(lis net.Listener) {
|
||||
for {
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
log.Printf("[ERR] plugin: plugin server: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go s.ServeConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeConn runs a single connection.
|
||||
//
|
||||
// ServeConn blocks, serving the connection until the client hangs up.
|
||||
func (s *RPCServer) ServeConn(conn io.ReadWriteCloser) {
|
||||
// First create the yamux server to wrap this connection
|
||||
mux, err := yamux.Server(conn, nil)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
log.Printf("[ERR] plugin: error creating yamux server: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Accept the control connection
|
||||
control, err := mux.Accept()
|
||||
if err != nil {
|
||||
mux.Close()
|
||||
if err != io.EOF {
|
||||
log.Printf("[ERR] plugin: error accepting control connection: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Connect the stdstreams (in, out, err)
|
||||
stdstream := make([]net.Conn, 2)
|
||||
for i, _ := range stdstream {
|
||||
stdstream[i], err = mux.Accept()
|
||||
if err != nil {
|
||||
mux.Close()
|
||||
log.Printf("[ERR] plugin: accepting stream %d: %s", i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Copy std streams out to the proper place
|
||||
go copyStream("stdout", stdstream[0], s.Stdout)
|
||||
go copyStream("stderr", stdstream[1], s.Stderr)
|
||||
|
||||
// Create the broker and start it up
|
||||
broker := newMuxBroker(mux)
|
||||
go broker.Run()
|
||||
|
||||
// Use the control connection to build the dispenser and serve the
|
||||
// connection.
|
||||
server := rpc.NewServer()
|
||||
server.RegisterName("Control", &controlServer{
|
||||
server: s,
|
||||
})
|
||||
server.RegisterName("Dispenser", &dispenseServer{
|
||||
broker: broker,
|
||||
plugins: s.Plugins,
|
||||
})
|
||||
server.ServeConn(control)
|
||||
}
|
||||
|
||||
// done is called internally by the control server to trigger the
|
||||
// doneCh to close which is listened to by the main process to cleanly
|
||||
// exit.
|
||||
func (s *RPCServer) done() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.DoneCh != nil {
|
||||
close(s.DoneCh)
|
||||
s.DoneCh = nil
|
||||
}
|
||||
}
|
||||
|
||||
// dispenseServer dispenses variousinterface implementations for Terraform.
|
||||
type controlServer struct {
|
||||
server *RPCServer
|
||||
}
|
||||
|
||||
func (c *controlServer) Quit(
|
||||
null bool, response *struct{}) error {
|
||||
// End the server
|
||||
c.server.done()
|
||||
|
||||
// Always return true
|
||||
*response = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dispenseServer dispenses variousinterface implementations for Terraform.
|
||||
type dispenseServer struct {
|
||||
broker *MuxBroker
|
||||
plugins map[string]Plugin
|
||||
}
|
||||
|
||||
func (d *dispenseServer) Dispense(
|
||||
name string, response *uint32) error {
|
||||
// Find the function to create this implementation
|
||||
p, ok := d.plugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown plugin type: %s", name)
|
||||
}
|
||||
|
||||
// Create the implementation first so we know if there is an error.
|
||||
impl, err := p.Server(d.broker)
|
||||
if err != nil {
|
||||
// We turn the error into an errors error so that it works across RPC
|
||||
return errors.New(err.Error())
|
||||
}
|
||||
|
||||
// Reserve an ID for our implementation
|
||||
id := d.broker.NextId()
|
||||
*response = id
|
||||
|
||||
// Run the rest in a goroutine since it can only happen once this RPC
|
||||
// call returns. We wait for a connection for the plugin implementation
|
||||
// and serve it.
|
||||
go func() {
|
||||
conn, err := d.broker.Accept(id)
|
||||
if err != nil {
|
||||
log.Printf("[ERR] go-plugin: plugin dispense error: %s: %s", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
serve(conn, "Plugin", impl)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serve(conn io.ReadWriteCloser, name string, v interface{}) {
|
||||
server := rpc.NewServer()
|
||||
if err := server.RegisterName(name, v); err != nil {
|
||||
log.Printf("[ERR] go-plugin: plugin dispense error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
server.ServeConn(conn)
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// CoreProtocolVersion is the ProtocolVersion of the plugin system itself.
|
||||
// We will increment this whenever we change any protocol behavior. This
|
||||
// will invalidate any prior plugins but will at least allow us to iterate
|
||||
// on the core in a safe way. We will do our best to do this very
|
||||
// infrequently.
|
||||
const CoreProtocolVersion = 1
|
||||
|
||||
// HandshakeConfig is the configuration used by client and servers to
|
||||
// handshake before starting a plugin connection. This is embedded by
|
||||
// both ServeConfig and ClientConfig.
|
||||
//
|
||||
// In practice, the plugin host creates a HandshakeConfig that is exported
|
||||
// and plugins then can easily consume it.
|
||||
type HandshakeConfig struct {
|
||||
// ProtocolVersion is the version that clients must match on to
|
||||
// agree they can communicate. This should match the ProtocolVersion
|
||||
// set on ClientConfig when using a plugin.
|
||||
ProtocolVersion uint
|
||||
|
||||
// MagicCookieKey and value are used as a very basic verification
|
||||
// that a plugin is intended to be launched. This is not a security
|
||||
// measure, just a UX feature. If the magic cookie doesn't match,
|
||||
// we show human-friendly output.
|
||||
MagicCookieKey string
|
||||
MagicCookieValue string
|
||||
}
|
||||
|
||||
// ServeConfig configures what sorts of plugins are served.
|
||||
type ServeConfig struct {
|
||||
// HandshakeConfig is the configuration that must match clients.
|
||||
HandshakeConfig
|
||||
|
||||
// Plugins are the plugins that are served.
|
||||
Plugins map[string]Plugin
|
||||
|
||||
// TLSProvider is a function that returns a configured tls.Config.
|
||||
TLSProvider func() (*tls.Config, error)
|
||||
}
|
||||
|
||||
// Serve serves the plugins given by ServeConfig.
|
||||
//
|
||||
// Serve doesn't return until the plugin is done being executed. Any
|
||||
// errors will be outputted to the log.
|
||||
//
|
||||
// This is the method that plugins should call in their main() functions.
|
||||
func Serve(opts *ServeConfig) {
|
||||
// Validate the handshake config
|
||||
if opts.MagicCookieKey == "" || opts.MagicCookieValue == "" {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"Misconfigured ServeConfig given to serve this plugin: no magic cookie\n"+
|
||||
"key or value was set. Please notify the plugin author and report\n"+
|
||||
"this as a bug.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// First check the cookie
|
||||
if os.Getenv(opts.MagicCookieKey) != opts.MagicCookieValue {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"This binary is a plugin. These are not meant to be executed directly.\n"+
|
||||
"Please execute the program that consumes these plugins, which will\n"+
|
||||
"load any plugins automatically\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Logging goes to the original stderr
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
// Create our new stdout, stderr files. These will override our built-in
|
||||
// stdout/stderr so that it works across the stream boundary.
|
||||
stdout_r, stdout_w, err := os.Pipe()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
stderr_r, stderr_w, err := os.Pipe()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Register a listener so we can accept a connection
|
||||
listener, err := serverListener()
|
||||
if err != nil {
|
||||
log.Printf("[ERR] plugin: plugin init: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if opts.TLSProvider != nil {
|
||||
tlsConfig, err := opts.TLSProvider()
|
||||
if err != nil {
|
||||
log.Printf("[ERR] plugin: plugin tls init: %s", err)
|
||||
return
|
||||
}
|
||||
listener = tls.NewListener(listener, tlsConfig)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create the channel to tell us when we're done
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
// Create the RPC server to dispense
|
||||
server := &RPCServer{
|
||||
Plugins: opts.Plugins,
|
||||
Stdout: stdout_r,
|
||||
Stderr: stderr_r,
|
||||
DoneCh: doneCh,
|
||||
}
|
||||
|
||||
// Output the address and service name to stdout so that core can bring it up.
|
||||
log.Printf("[DEBUG] plugin: plugin address: %s %s\n",
|
||||
listener.Addr().Network(), listener.Addr().String())
|
||||
fmt.Printf("%d|%d|%s|%s\n",
|
||||
CoreProtocolVersion,
|
||||
opts.ProtocolVersion,
|
||||
listener.Addr().Network(),
|
||||
listener.Addr().String())
|
||||
os.Stdout.Sync()
|
||||
|
||||
// Eat the interrupts
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, os.Interrupt)
|
||||
go func() {
|
||||
var count int32 = 0
|
||||
for {
|
||||
<-ch
|
||||
newCount := atomic.AddInt32(&count, 1)
|
||||
log.Printf(
|
||||
"[DEBUG] plugin: received interrupt signal (count: %d). Ignoring.",
|
||||
newCount)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set our new out, err
|
||||
os.Stdout = stdout_w
|
||||
os.Stderr = stderr_w
|
||||
|
||||
// Serve
|
||||
go server.Accept(listener)
|
||||
|
||||
// Wait for the graceful exit
|
||||
<-doneCh
|
||||
}
|
||||
|
||||
func serverListener() (net.Listener, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return serverListener_tcp()
|
||||
}
|
||||
|
||||
return serverListener_unix()
|
||||
}
|
||||
|
||||
func serverListener_tcp() (net.Listener, error) {
|
||||
minPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MIN_PORT"), 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MAX_PORT"), 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for port := minPort; port <= maxPort; port++ {
|
||||
address := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("Couldn't bind plugin TCP listener")
|
||||
}
|
||||
|
||||
func serverListener_unix() (net.Listener, error) {
|
||||
tf, err := ioutil.TempFile("", "plugin")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := tf.Name()
|
||||
|
||||
// Close the file and remove it because it has to not exist for
|
||||
// the domain socket.
|
||||
if err := tf.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.Remove(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap the listener in rmListener so that the Unix domain socket file
|
||||
// is removed on close.
|
||||
return &rmListener{
|
||||
Listener: l,
|
||||
Path: path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// rmListener is an implementation of net.Listener that forwards most
|
||||
// calls to the listener but also removes a file as part of the close. We
|
||||
// use this to cleanup the unix domain socket on close.
|
||||
type rmListener struct {
|
||||
net.Listener
|
||||
Path string
|
||||
}
|
||||
|
||||
func (l *rmListener) Close() error {
|
||||
// Close the listener itself
|
||||
if err := l.Listener.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the file
|
||||
return os.Remove(l.Path)
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ServeMuxMap is the type that is used to configure ServeMux
|
||||
type ServeMuxMap map[string]*ServeConfig
|
||||
|
||||
// ServeMux is like Serve, but serves multiple types of plugins determined
|
||||
// by the argument given on the command-line.
|
||||
//
|
||||
// This command doesn't return until the plugin is done being executed. Any
|
||||
// errors are logged or output to stderr.
|
||||
func ServeMux(m ServeMuxMap) {
|
||||
if len(os.Args) != 2 {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"Invoked improperly. This is an internal command that shouldn't\n"+
|
||||
"be manually invoked.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
opts, ok := m[os.Args[1]]
|
||||
if !ok {
|
||||
fmt.Fprintf(os.Stderr, "Unknown plugin: %s\n", os.Args[1])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
Serve(opts)
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
func copyStream(name string, dst io.Writer, src io.Reader) {
|
||||
if src == nil {
|
||||
panic(name + ": src is nil")
|
||||
}
|
||||
if dst == nil {
|
||||
panic(name + ": dst is nil")
|
||||
}
|
||||
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
|
||||
log.Printf("[ERR] plugin: stream copy '%s' error: %s", name, err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/rpc"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// The testing file contains test helpers that you can use outside of
|
||||
// this package for making it easier to test plugins themselves.
|
||||
|
||||
// TestConn is a helper function for returning a client and server
|
||||
// net.Conn connected to each other.
|
||||
func TestConn(t *testing.T) (net.Conn, net.Conn) {
|
||||
// Listen to any local port. This listener will be closed
|
||||
// after a single connection is established.
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Start a goroutine to accept our client connection
|
||||
var serverConn net.Conn
|
||||
doneCh := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
defer l.Close()
|
||||
var err error
|
||||
serverConn, err = l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Connect to the server
|
||||
clientConn, err := net.Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Wait for the server side to acknowledge it has connected
|
||||
<-doneCh
|
||||
|
||||
return clientConn, serverConn
|
||||
}
|
||||
|
||||
// TestRPCConn returns a rpc client and server connected to each other.
|
||||
func TestRPCConn(t *testing.T) (*rpc.Client, *rpc.Server) {
|
||||
clientConn, serverConn := TestConn(t)
|
||||
|
||||
server := rpc.NewServer()
|
||||
go server.ServeConn(serverConn)
|
||||
|
||||
client := rpc.NewClient(clientConn)
|
||||
return client, server
|
||||
}
|
||||
|
||||
// TestPluginRPCConn returns a plugin RPC client and server that are connected
|
||||
// together and configured.
|
||||
func TestPluginRPCConn(t *testing.T, ps map[string]Plugin) (*RPCClient, *RPCServer) {
|
||||
// Create two net.Conns we can use to shuttle our control connection
|
||||
clientConn, serverConn := TestConn(t)
|
||||
|
||||
// Start up the server
|
||||
server := &RPCServer{Plugins: ps, Stdout: new(bytes.Buffer), Stderr: new(bytes.Buffer)}
|
||||
go server.ServeConn(serverConn)
|
||||
|
||||
// Connect the client to the server
|
||||
client, err := NewRPCClient(clientConn, ps)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
return client, server
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,5 @@
|
|||
otp
|
||||
Copyright (c) 2014, Paul Querna
|
||||
|
||||
This product includes software developed by
|
||||
Paul Querna (http://paul.querna.org/).
|
|
@ -0,0 +1,60 @@
|
|||
# otp: One Time Password utilities Go / Golang
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/pquerna/otp?status.svg)](https://godoc.org/github.com/pquerna/otp) [![Build Status](https://travis-ci.org/pquerna/otp.svg?branch=master)](https://travis-ci.org/pquerna/otp)
|
||||
|
||||
# Why One Time Passwords?
|
||||
|
||||
One Time Passwords (OTPs) are an mechanism to improve security over passwords alone. When a Time-based OTP (TOTP) is stored on a user's phone, and combined with something the user knows (Password), you have an easy on-ramp to [Multi-factor authentication](http://en.wikipedia.org/wiki/Multi-factor_authentication) without adding a dependency on a SMS provider. This Password and TOTP combination is used by many popular websites including Google, Github, Facebook, Salesforce and many others.
|
||||
|
||||
The `otp` library enables you to easily add TOTPs to your own application, increasing your user's security against mass-password breaches and malware.
|
||||
|
||||
Because TOTP is standardized and widely deployed, there are many [mobile clients and software implementations](http://en.wikipedia.org/wiki/Time-based_One-time_Password_Algorithm#Client_implementations).
|
||||
|
||||
## `otp` Supports:
|
||||
|
||||
* Generating QR Code images for easy user enrollment.
|
||||
* Time-based One-time Password Algorithm (TOTP) (RFC 6238): Time based OTP, the most commonly used method.
|
||||
* HMAC-based One-time Password Algorithm (HOTP) (RFC 4226): Counter based OTP, which TOTP is based upon.
|
||||
* Generation and Validation of codes for either algorithm.
|
||||
|
||||
## Implementing TOTP in your application:
|
||||
|
||||
### User Enrollment
|
||||
|
||||
For an example of a working enrollment work flow, [Github has documented theirs](https://help.github.com/articles/configuring-two-factor-authentication-via-a-totp-mobile-app/
|
||||
), but the basics are:
|
||||
|
||||
1. Generate new TOTP Key for a User. `key,_ := totp.Generate(...)`.
|
||||
1. Display the Key's Secret and QR-Code for the User. `key.Secret()` and `key.Image(...)`.
|
||||
1. Test that the user can successfully use their TOTP. `totp.Validate(...)`.
|
||||
1. Store TOTP Secret for the User in your backend. `key.Secret()`
|
||||
1. Provide the user with "recovery codes". (See Recovery Codes bellow)
|
||||
|
||||
### Code Generation
|
||||
|
||||
* In either TOTP or HOTP cases, use the `GenerateCode` function and a counter or
|
||||
`time.Time` struct to generate a valid code compatible with most implementations.
|
||||
* For uncommon or custom settings, or to catch unlikely errors, use `GenerateCodeCustom`
|
||||
in either module.
|
||||
|
||||
### Validation
|
||||
|
||||
1. Prompt and validate User's password as normal.
|
||||
1. If the user has TOTP enabled, prompt for TOTP passcode.
|
||||
1. Retrieve the User's TOTP Secret from your backend.
|
||||
1. Validate the user's passcode. `totp.Validate(...)`
|
||||
|
||||
|
||||
### Recovery Codes
|
||||
|
||||
When a user loses access to their TOTP device, they would no longer have access to their account. Because TOTPs are often configured on mobile devices that can be lost, stolen or damaged, this is a common problem. For this reason many providers give their users "backup codes" or "recovery codes". These are a set of one time use codes that can be used instead of the TOTP. These can simply be randomly generated strings that you store in your backend. [Github's documentation provides an overview of the user experience](
|
||||
https://help.github.com/articles/downloading-your-two-factor-authentication-recovery-codes/).
|
||||
|
||||
|
||||
## Improvements, bugs, adding feature, etc:
|
||||
|
||||
Please [open issues in Github](https://github.com/pquerna/otp/issues) for ideas, bugs, and general thoughts. Pull requests are of course preferred :)
|
||||
|
||||
## License
|
||||
|
||||
`otp` is licensed under the [Apache License, Version 2.0](./LICENSE)
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2014 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package otp implements both HOTP and TOTP based
|
||||
// one time passcodes in a Google Authenticator compatible manner.
|
||||
//
|
||||
// When adding a TOTP for a user, you must store the "secret" value
|
||||
// persistently. It is recommend to store the secret in an encrypted field in your
|
||||
// datastore. Due to how TOTP works, it is not possible to store a hash
|
||||
// for the secret value like you would a password.
|
||||
//
|
||||
// To enroll a user, you must first generate an OTP for them. Google
|
||||
// Authenticator supports using a QR code as an enrollment method:
|
||||
//
|
||||
// import (
|
||||
// "github.com/pquerna/otp/totp"
|
||||
//
|
||||
// "bytes"
|
||||
// "image/png"
|
||||
// )
|
||||
//
|
||||
// key, err := totp.Generate(totp.GenerateOpts{
|
||||
// Issuer: "Example.com",
|
||||
// AccountName: "alice@example.com",
|
||||
// })
|
||||
//
|
||||
// // Convert TOTP key into a QR code encoded as a PNG image.
|
||||
// var buf bytes.Buffer
|
||||
// img, err := key.Image(200, 200)
|
||||
// png.Encode(&buf, img)
|
||||
//
|
||||
// // display the QR code to the user.
|
||||
// display(buf.Bytes())
|
||||
//
|
||||
// // Now Validate that the user's successfully added the passcode.
|
||||
// passcode := promptForPasscode()
|
||||
// valid := totp.Validate(passcode, key.Secret())
|
||||
//
|
||||
// if valid {
|
||||
// // User successfully used their TOTP, save it to your backend!
|
||||
// storeSecret("alice@example.com", key.Secret())
|
||||
// }
|
||||
//
|
||||
// Validating a TOTP passcode is very easy, just prompt the user for a passcode
|
||||
// and retrieve the associated user's previously stored secret.
|
||||
// import "github.com/pquerna/otp/totp"
|
||||
//
|
||||
// passcode := promptForPasscode()
|
||||
// secret := getSecret("alice@example.com")
|
||||
//
|
||||
// valid := totp.Validate(passcode, secret)
|
||||
//
|
||||
// if valid {
|
||||
// // Success! continue login process.
|
||||
// }
|
||||
package otp
|
|
@ -0,0 +1,63 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func display(key *otp.Key, data []byte) {
|
||||
fmt.Printf("Issuer: %s\n", key.Issuer())
|
||||
fmt.Printf("Account Name: %s\n", key.AccountName())
|
||||
fmt.Printf("Secret: %s\n", key.Secret())
|
||||
fmt.Println("Writing PNG to qr-code.png....")
|
||||
ioutil.WriteFile("qr-code.png", data, 0644)
|
||||
fmt.Println("")
|
||||
fmt.Println("Please add your TOTP to your OTP Application now!")
|
||||
fmt.Println("")
|
||||
}
|
||||
|
||||
func promptForPasscode() string {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
fmt.Print("Enter Passcode: ")
|
||||
text, _ := reader.ReadString('\n')
|
||||
return text
|
||||
}
|
||||
|
||||
func main() {
|
||||
key, err := totp.Generate(totp.GenerateOpts{
|
||||
Issuer: "Example.com",
|
||||
AccountName: "alice@example.com",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Convert TOTP key into a PNG
|
||||
var buf bytes.Buffer
|
||||
img, err := key.Image(200, 200)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
png.Encode(&buf, img)
|
||||
|
||||
// display the QR code to the user.
|
||||
display(key, buf.Bytes())
|
||||
|
||||
// Now Validate that the user's successfully added the passcode.
|
||||
fmt.Println("Validating TOTP...")
|
||||
passcode := promptForPasscode()
|
||||
valid := totp.Validate(passcode, key.Secret())
|
||||
if valid {
|
||||
println("Valid passcode!")
|
||||
os.Exit(0)
|
||||
} else {
|
||||
println("Invalid passocde!")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue