Sync some ns stuff to api/command

This commit is contained in:
Jeff Mitchell 2018-08-22 14:37:40 -04:00
parent d5a3010498
commit 66a0029195
14 changed files with 237 additions and 233 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/go-cleanhttp"
retryablehttp "github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/parseutil"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
@ -474,7 +475,7 @@ func (c *Client) SetNamespace(namespace string) {
c.headers = make(http.Header)
}
c.headers.Set("X-Vault-Namespace", namespace)
c.headers.Set(consts.NamespaceHeaderName, namespace)
}
// Token returns the access token being used by this client. It will

View File

@ -119,4 +119,6 @@ type GenerateRootStatusResponse struct {
EncodedToken string `json:"encoded_token"`
EncodedRootToken string `json:"encoded_root_token"`
PGPFingerprint string `json:"pgp_fingerprint"`
OTP string `json:"otp"`
OTPLength int `json:"otp_length"`
}

View File

@ -1,97 +0,0 @@
package api
import (
"fmt"
"net/http"
)
// ListNamespacesResponse is the response from the ListNamespaces call.
type ListNamespacesResponse struct {
// NamespacePaths is the list of child namespace paths
NamespacePaths []string `json:"namespace_paths"`
}
type GetNamespaceResponse struct {
Path string `json:"path"`
}
// ListNamespaces lists any existing namespace relative to the namespace
// provided in the client's namespace header.
func (c *Sys) ListNamespaces() (*ListNamespacesResponse, error) {
r := c.c.NewRequest("LIST", "/v1/sys/namespaces")
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result struct {
Data struct {
Keys []string `json:"keys"`
} `json:"data"`
}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
return &ListNamespacesResponse{NamespacePaths: result.Data.Keys}, nil
}
// GetNamespace returns namespace information
func (c *Sys) GetNamespace(path string) (*GetNamespaceResponse, error) {
r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
ret := &GetNamespaceResponse{}
result := map[string]interface{}{
"data": map[string]interface{}{},
}
if err := resp.DecodeJSON(&result); err != nil {
return nil, err
}
if data, ok := result["data"]; ok {
if pathOk, ok := data.(map[string]interface{})["path"]; ok {
if pathRaw, ok := pathOk.(string); ok {
ret.Path = pathRaw
}
}
}
return ret, nil
}
// CreateNamespace creates a new namespace relative to the namespace provided
// in the client's namespace header.
func (c *Sys) CreateNamespace(path string) error {
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}
// DeleteNamespace delete an existing namespace relative to the namespace
// provided in the client's namespace header.
func (c *Sys) DeleteNamespace(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}

View File

@ -20,8 +20,13 @@ import (
"github.com/posener/complete"
)
// maxLineLength is the maximum width of any line.
const maxLineLength int = 78
const (
// maxLineLength is the maximum width of any line.
maxLineLength int = 78
// notSetNamespace is a flag value for a not-set namespace
notSetNamespace = "(not set)"
)
// reRemoveWhitespace is a regular expression for stripping whitespace from
// a string.
@ -39,6 +44,7 @@ type BaseCommand struct {
flagClientCert string
flagClientKey string
flagNamespace string
flagNS string
flagTLSServerName string
flagTLSSkipVerify bool
flagWrapTTL time.Duration
@ -120,7 +126,12 @@ func (c *BaseCommand) Client() (*api.Client, error) {
}
client.SetMFACreds(c.flagMFA)
client.SetNamespace(namespace.Canonicalize(c.flagNamespace))
switch {
case c.flagNS != notSetNamespace:
client.SetNamespace(namespace.Canonicalize(c.flagNS))
case c.flagNamespace != notSetNamespace:
client.SetNamespace(namespace.Canonicalize(c.flagNamespace))
}
c.client = client
@ -242,11 +253,21 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
f.StringVar(&StringVar{
Name: "namespace",
Target: &c.flagNamespace,
Default: "",
Default: notSetNamespace, // this can never be a real value
EnvVar: "VAULT_NAMESPACE",
Completion: complete.PredictAnything,
Usage: "The namespace to use for the command. Setting this is not " +
"necessary but allows using relative paths.",
"necessary but allows using relative paths. -ns can be used as " +
"shortcut.",
})
f.StringVar(&StringVar{
Name: "ns",
Target: &c.flagNS,
Default: notSetNamespace, // this can never be a real value
Completion: complete.PredictAnything,
Hidden: true,
Usage: "Alias for -namespace.",
})
f.StringVar(&StringVar{

View File

@ -43,7 +43,7 @@ type BoolVar struct {
func (f *FlagSet) BoolVar(i *BoolVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if b, err := strconv.ParseBool(v); err == nil {
def = b
}
@ -104,7 +104,7 @@ type IntVar struct {
func (f *FlagSet) IntVar(i *IntVar) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseInt(v, 0, 64); err == nil {
initial = int(i)
}
@ -168,7 +168,7 @@ type Int64Var struct {
func (f *FlagSet) Int64Var(i *Int64Var) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseInt(v, 0, 64); err == nil {
initial = i
}
@ -232,7 +232,7 @@ type UintVar struct {
func (f *FlagSet) UintVar(i *UintVar) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseUint(v, 0, 64); err == nil {
initial = uint(i)
}
@ -296,7 +296,7 @@ type Uint64Var struct {
func (f *FlagSet) Uint64Var(i *Uint64Var) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseUint(v, 0, 64); err == nil {
initial = i
}
@ -360,7 +360,7 @@ type StringVar struct {
func (f *FlagSet) StringVar(i *StringVar) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
initial = v
}
@ -417,7 +417,7 @@ type Float64Var struct {
func (f *FlagSet) Float64Var(i *Float64Var) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseFloat(v, 64); err == nil {
initial = i
}
@ -481,7 +481,7 @@ type DurationVar struct {
func (f *FlagSet) DurationVar(i *DurationVar) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
if d, err := time.ParseDuration(appendDurationSuffix(v)); err == nil {
initial = d
}
@ -558,7 +558,7 @@ type StringSliceVar struct {
func (f *FlagSet) StringSliceVar(i *StringSliceVar) {
initial := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if v, exist := os.LookupEnv(i.EnvVar); exist {
parts := strings.Split(v, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
@ -751,14 +751,14 @@ func (f *FlagSet) Var(value flag.Value, name, usage string) {
// -- helpers
func envDefault(key, def string) string {
if v := os.Getenv(key); v != "" {
if v, exist := os.LookupEnv(key); exist {
return v
}
return def
}
func envBoolDefault(key string, def bool) bool {
if v := os.Getenv(key); v != "" {
if v, exist := os.LookupEnv(key); exist {
b, err := strconv.ParseBool(v)
if err != nil {
panic(err)
@ -769,7 +769,7 @@ func envBoolDefault(key string, def bool) bool {
}
func envDurationDefault(key string, def time.Duration) time.Duration {
if v := os.Getenv(key); v != "" {
if v, exist := os.LookupEnv(key); exist {
d, err := time.ParseDuration(v)
if err != nil {
panic(err)

View File

@ -21,9 +21,9 @@ func (c *NamespaceCommand) Help() string {
Usage: vault namespace <subcommand> [options] [args]
This command groups subcommands for interacting with Vault namespaces.
These set of subcommands operate on the context of the namespace that the
current logged in token belongs to.
These subcommands operate in the context of the namespace that the
currently logged in token belongs to.
List enabled child namespaces:
$ vault namespace list

View File

@ -79,14 +79,19 @@ func (c *NamespaceCreateCommand) Run(args []string) int {
return 2
}
err = client.Sys().CreateNamespace(namespacePath)
_, err = client.Logical().Write("sys/namespaces/"+namespacePath, nil)
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating namespace: %s", err))
return 2
}
if !strings.HasSuffix(namespacePath, "/") {
namespacePath = namespacePath + "/"
}
if c.flagNamespace != notSetNamespace {
namespacePath = path.Join(c.flagNamespace, namespacePath)
}
// Output full path
fullPath := path.Join(c.flagNamespace, namespacePath) + "/"
c.UI.Output(fmt.Sprintf("Success! Namespace created at: %s", fullPath))
c.UI.Output(fmt.Sprintf("Success! Namespace created at: %s", namespacePath))
return 0
}

View File

@ -79,14 +79,24 @@ func (c *NamespaceDeleteCommand) Run(args []string) int {
return 2
}
err = client.Sys().DeleteNamespace(namespacePath)
secret, err := client.Logical().Delete("sys/namespaces/" + namespacePath)
if err != nil {
c.UI.Error(fmt.Sprintf("Error deleting namespace: %s", err))
return 2
}
// Output full path
fullPath := path.Join(c.flagNamespace, namespacePath) + "/"
c.UI.Output(fmt.Sprintf("Success! Namespace deleted at: %s", fullPath))
if secret != nil {
// Likely, we have warnings
return OutputSecret(c.UI, secret)
}
if !strings.HasSuffix(namespacePath, "/") {
namespacePath = namespacePath + "/"
}
if c.flagNamespace != notSetNamespace {
namespacePath = path.Join(c.flagNamespace, namespacePath)
}
c.UI.Output(fmt.Sprintf("Success! Namespace deleted at: %s", namespacePath))
return 0
}

View File

@ -66,19 +66,29 @@ func (c *NamespaceListCommand) Run(args []string) int {
return 2
}
namespaces, err := client.Sys().ListNamespaces()
secret, err := client.Logical().List("sys/namespaces")
if err != nil {
c.UI.Error(fmt.Sprintf("Error listing namespaces: %s", err))
return 2
}
switch Format(c.UI) {
case "table":
for _, ns := range namespaces.NamespacePaths {
c.UI.Output(ns)
}
return 0
default:
return OutputData(c.UI, namespaces)
if secret == nil {
c.UI.Error(fmt.Sprintf("No namespaces found"))
return 2
}
// There could be e.g. warnings
if secret.Data == nil {
return OutputSecret(c.UI, secret)
}
if secret.WrapInfo != nil && secret.WrapInfo.TTL != 0 {
return OutputSecret(c.UI, secret)
}
if _, ok := extractListData(secret); !ok {
c.UI.Error(fmt.Sprintf("No entries found"))
return 2
}
return OutputList(c.UI, secret)
}

View File

@ -16,14 +16,14 @@ type NamespaceLookupCommand struct {
}
func (c *NamespaceLookupCommand) Synopsis() string {
return "Create a new namespace"
return "Look up an existing namespace"
}
func (c *NamespaceLookupCommand) Help() string {
helpText := `
Usage: vault namespace create [options] PATH
Create a child namespace. The namespace created will be relative to the
Create a child namespace. The namespace created will be relative to the
namespace provided in either VAULT_NAMESPACE environemnt variable or
-namespace CLI flag.
@ -33,7 +33,7 @@ Usage: vault namespace create [options] PATH
Get information about the namespace of a particular child token (e.g. ns1/ns2/):
$ vault namespace create -namespace=ns1 ns2
$ vault namespace lookup -namespace=ns1 ns2
` + c.Flags().Help()
@ -78,19 +78,15 @@ func (c *NamespaceLookupCommand) Run(args []string) int {
return 2
}
resp, err := client.Sys().GetNamespace(namespacePath)
secret, err := client.Logical().Read("sys/namespaces/" + namespacePath)
if err != nil {
c.UI.Error(fmt.Sprintf("Error looking up namespace: %s", err))
return 2
}
switch Format(c.UI) {
case "table":
data := map[string]interface{}{
"path": resp.Path,
}
return OutputData(c.UI, data)
default:
return OutputData(c.UI, resp)
if secret == nil {
c.UI.Error("Namespace not found")
return 2
}
return OutputSecret(c.UI, secret)
}

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/base62"
"github.com/hashicorp/vault/helper/password"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/helper/xor"
@ -123,8 +124,7 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets {
Default: "",
EnvVar: "",
Completion: complete.PredictAnything,
Usage: "Decode and output the generated root token. This option requires " +
"the \"-otp\" flag be set to the OTP used during initialization.",
Usage: "The value to decode; setting this triggers a decode operation.",
})
f.BoolVar(&BoolVar{
@ -233,9 +233,13 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
switch {
case c.flagGenerateOTP:
return c.generateOTP()
otp, code := c.generateOTP(client, c.flagDRToken)
if code == 0 {
return PrintRaw(c.UI, otp)
}
return code
case c.flagDecode != "":
return c.decode(c.flagDecode, c.flagOTP)
return c.decode(client, c.flagDecode, c.flagOTP, c.flagDRToken)
case c.flagCancel:
return c.cancel(client, c.flagDRToken)
case c.flagInit:
@ -252,41 +256,48 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
}
}
// verifyOTP verifies the given OTP code is exactly 16 bytes.
func (c *OperatorGenerateRootCommand) verifyOTP(otp string) error {
if len(otp) == 0 {
return fmt.Errorf("no OTP passed in")
}
otpBytes, err := base64.StdEncoding.DecodeString(otp)
if err != nil {
return errwrap.Wrapf("error decoding base64 OTP value: {{err}}", err)
}
if otpBytes == nil || len(otpBytes) != 16 {
return fmt.Errorf("decoded OTP value is invalid or wrong length")
}
return nil
}
// generateOTP generates a suitable OTP code for generating a root token.
func (c *OperatorGenerateRootCommand) generateOTP() int {
buf := make([]byte, 16)
readLen, err := rand.Read(buf)
func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bool) (string, int) {
f := client.Sys().GenerateRootStatus
if drToken {
f = client.Sys().GenerateDROperationTokenStatus
}
status, err := f()
if err != nil {
c.UI.Error(fmt.Sprintf("Error reading random bytes: %s", err))
return 2
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
return "", 2
}
if readLen != 16 {
c.UI.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen))
return 2
}
switch status.OTPLength {
case 0:
// This is the fallback case
buf := make([]byte, 16)
readLen, err := rand.Read(buf)
if err != nil {
c.UI.Error(fmt.Sprintf("Error reading random bytes: %s", err))
return "", 2
}
return PrintRaw(c.UI, base64.StdEncoding.EncodeToString(buf))
if readLen != 16 {
c.UI.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen))
return "", 2
}
return base64.StdEncoding.EncodeToString(buf), 0
default:
otp, err := base62.Random(status.OTPLength, true)
if err != nil {
c.UI.Error(errwrap.Wrapf("Error reading random bytes: {{err}}", err).Error())
return "", 2
}
return otp, 0
}
}
// decode decodes the given value using the otp.
func (c *OperatorGenerateRootCommand) decode(encoded, otp string) int {
func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, drToken bool) int {
if encoded == "" {
c.UI.Error("Missing encoded value: use -decode=<string> to supply it")
return 1
@ -296,38 +307,56 @@ func (c *OperatorGenerateRootCommand) decode(encoded, otp string) int {
return 1
}
tokenBytes, err := xor.XORBase64(encoded, otp)
f := client.Sys().GenerateRootStatus
if drToken {
f = client.Sys().GenerateDROperationTokenStatus
}
status, err := f()
if err != nil {
c.UI.Error(fmt.Sprintf("Error xoring token: %s", err))
return 1
c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err))
return 2
}
token, err := uuid.FormatUUID(tokenBytes)
if err != nil {
c.UI.Error(fmt.Sprintf("Error formatting base64 token value: %s", err))
return 1
}
switch status.OTPLength {
case 0:
// Backwards compat
tokenBytes, err := xor.XORBase64(encoded, otp)
if err != nil {
c.UI.Error(fmt.Sprintf("Error xoring token: %s", err))
return 1
}
return PrintRaw(c.UI, strings.TrimSpace(token))
token, err := uuid.FormatUUID(tokenBytes)
if err != nil {
c.UI.Error(fmt.Sprintf("Error formatting base64 token value: %s", err))
return 1
}
return PrintRaw(c.UI, strings.TrimSpace(token))
default:
tokenBytes, err := base64.RawStdEncoding.DecodeString(encoded)
if err != nil {
c.UI.Error(errwrap.Wrapf("Error decoding base64'd token: {{err}}", err).Error())
return 1
}
tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
if err != nil {
c.UI.Error(errwrap.Wrapf("Error xoring token: {{err}}", err).Error())
return 1
}
return PrintRaw(c.UI, string(tokenBytes))
}
}
// init is used to start the generation process
func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, drToken bool) int {
// Validate incoming fields. Either OTP OR PGP keys must be supplied.
switch {
case otp == "" && pgpKey == "":
c.UI.Error("Error initializing: must specify either -otp or -pgp-key")
return 1
case otp != "" && pgpKey != "":
if otp != "" && pgpKey != "" {
c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key")
return 1
case otp != "":
if err := c.verifyOTP(otp); err != nil {
c.UI.Error(fmt.Sprintf("Error initializing: invalid OTP: %s", err))
return 1
}
case pgpKey != "":
// OK
}
// Start the root generation
@ -368,6 +397,10 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr
c.UI.Error(wrapAtLength(
"No root generation is in progress. Start a root generation by " +
"running \"vault operator generate-root -init\"."))
c.UI.Warn(wrapAtLength(fmt.Sprintf(
"If starting root generation using the OTP method and generating "+
"your own OTP, the length of the OTP string needs to be %d "+
"characters in length.", status.OTPLength)))
return 1
}
@ -494,6 +527,13 @@ func (c *OperatorGenerateRootCommand) printStatus(status *api.GenerateRootStatus
case status.EncodedRootToken != "":
out = append(out, fmt.Sprintf("Encoded Root Token | %s", status.EncodedRootToken))
}
if status.OTP != "" {
c.UI.Warn(wrapAtLength("A One-Time-Password has been generated for you and is shown in the OTP field. You will need this value to decode the resulting root token, so keep it safe."))
out = append(out, fmt.Sprintf("OTP | %s", status.OTP))
}
if status.OTPLength != 0 {
out = append(out, fmt.Sprintf("OTP Length | %d", status.OTPLength))
}
output := columnOutput(out, nil)
c.UI.Output(output)

View File

@ -34,22 +34,14 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) {
out string
code int
}{
{
"init_no_args",
[]string{
"-init",
},
"must specify either -otp or -pgp-key",
1,
},
{
"init_invalid_otp",
[]string{
"-init",
"-otp", "not-a-valid-otp",
},
"Error initializing: invalid OTP:",
1,
"illegal base64 data at input",
2,
},
{
"init_pgp_multi",
@ -99,12 +91,12 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) {
code := cmd.Run(tc.args)
if code != tc.code {
t.Errorf("expected %d to be %d", code, tc.code)
t.Errorf("%s: expected %d to be %d", tc.name, code, tc.code)
}
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, tc.out) {
t.Errorf("expected %q to contain %q", combined, tc.out)
t.Errorf("%s: expected %q to contain %q", tc.name, combined, tc.out)
}
})
}
@ -116,7 +108,7 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) {
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testOperatorGenerateRootCommand(t)
_, cmd := testOperatorGenerateRootCommand(t)
cmd.client = client
code := cmd.Run([]string{
@ -125,11 +117,6 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) {
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
output := ui.OutputWriter.String() + ui.ErrorWriter.String()
if err := cmd.verifyOTP(output); err != nil {
t.Fatal(err)
}
})
t.Run("decode", func(t *testing.T) {

View File

@ -4,4 +4,11 @@ const (
// ExpirationRestoreWorkerCount specifies the number of workers to use while
// restoring leases into the expiration manager
ExpirationRestoreWorkerCount = 64
// NamespaceHeaderName is the header set to specify which namespace the
// request is indented for.
NamespaceHeaderName = "X-Vault-Namespace"
// AuthHeaderName is the name of the header containing the token.
AuthHeaderName = "X-Vault-Token"
)

View File

@ -3,6 +3,7 @@ package namespace
import (
"context"
"errors"
"net/http"
"strings"
)
@ -16,6 +17,11 @@ type nsContext struct {
type contextValues struct{}
type Namespace struct {
ID string `json:"id"`
Path string `json:"path"`
}
const (
RootNamespaceID = "root"
)
@ -23,18 +29,14 @@ const (
var (
contextNamespace contextValues = struct{}{}
ErrNoNamespace error = errors.New("no namespace")
RootNamespace *Namespace = &Namespace{
ID: RootNamespaceID,
Path: "",
}
)
type Namespace struct {
ID string `json:"id"`
Path string `json:"path"`
}
func New(id, path string) *Namespace {
return &Namespace{
ID: id,
Path: path,
}
var AdjustRequest = func(r *http.Request) (*http.Request, int) {
return r.WithContext(ContextWithNamespace(r.Context(), RootNamespace)), 0
}
func (n *Namespace) HasParent(possibleParent *Namespace) bool {
@ -60,6 +62,18 @@ func ContextWithNamespace(ctx context.Context, ns *Namespace) context.Context {
}
}
func RootContext(ctx context.Context) context.Context {
if ctx == nil {
return ContextWithNamespace(context.Background(), RootNamespace)
}
return ContextWithNamespace(ctx, RootNamespace)
}
// This function caches the ns to avoid doing a .Value lookup over and over,
// because it's called a *lot* in the request critical path. .Value is
// concurrency-safe so uses some kind of locking/atomicity, but it should never
// be read before first write, plus we don't believe this will be called from
// different goroutines, so it should be safe.
func FromContext(ctx context.Context) (*Namespace, error) {
if ctx == nil {
return nil, errors.New("context was nil")
@ -72,20 +86,28 @@ func FromContext(ctx context.Context) (*Namespace, error) {
}
}
ns := ctx.Value(contextNamespace)
nsRaw := ctx.Value(contextNamespace)
if nsRaw == nil {
return nil, ErrNoNamespace
}
ns := nsRaw.(*Namespace)
if ns == nil {
return nil, ErrNoNamespace
}
if ok {
nsCtx.cachedNS = ns.(*Namespace)
nsCtx.cachedNS = ns
}
return ns.(*Namespace), nil
return ns, nil
}
func TestContext() context.Context {
return ContextWithNamespace(context.Background(), New(RootNamespaceID, ""))
return ContextWithNamespace(context.Background(), TestNamespace())
}
func TestNamespace() *Namespace {
return RootNamespace
}
// Canonicalize trims any prefix '/' and adds a trailing '/' to the