diff --git a/http/handler.go b/http/handler.go index fbcae36a0..2baf42a88 100644 --- a/http/handler.go +++ b/http/handler.go @@ -15,7 +15,8 @@ func Handler(core *vault.Core) http.Handler { mux.Handle("/v1/sys/seal-status", handleSysSealStatus(core)) mux.Handle("/v1/sys/seal", handleSysSeal(core)) mux.Handle("/v1/sys/unseal", handleSysUnseal(core)) - mux.Handle("/v1/sys/mounts", handleSysMounts(core)) + mux.Handle("/v1/sys/mounts", handleSysListMounts(core)) + mux.Handle("/v1/sys/mount/", handleSysMount(core)) mux.Handle("/v1/", handleLogical(core)) return mux } diff --git a/http/http_test.go b/http/http_test.go index bb5540d92..06a4dc37a 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -8,7 +8,15 @@ import ( "testing" ) +func testHttpPost(t *testing.T, addr string, body interface{}) *http.Response { + return testHttpData(t, "POST", addr, body) +} + func testHttpPut(t *testing.T, addr string, body interface{}) *http.Response { + return testHttpData(t, "PUT", addr, body) +} + +func testHttpData(t *testing.T, method string, addr string, body interface{}) *http.Response { bodyReader := new(bytes.Buffer) if body != nil { enc := json.NewEncoder(bodyReader) @@ -17,7 +25,7 @@ func testHttpPut(t *testing.T, addr string, body interface{}) *http.Response { } } - req, err := http.NewRequest("PUT", addr, bodyReader) + req, err := http.NewRequest(method, addr, bodyReader) if err != nil { t.Fatalf("err: %s", err) } diff --git a/http/sys_mount.go b/http/sys_mount.go index 85b5cd99a..ec4a2657e 100644 --- a/http/sys_mount.go +++ b/http/sys_mount.go @@ -2,12 +2,13 @@ package http import ( "net/http" + "strings" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" ) -func handleSysMounts(core *vault.Core) http.Handler { +func handleSysListMounts(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { respondError(w, http.StatusMethodNotAllowed, nil) @@ -26,3 +27,51 @@ func handleSysMounts(core *vault.Core) http.Handler { respondOk(w, resp.Data) }) } + +func handleSysMount(core *vault.Core) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + respondError(w, http.StatusMethodNotAllowed, nil) + return + } + + // Determine the path... + prefix := "/v1/sys/mount/" + if !strings.HasPrefix(r.URL.Path, prefix) { + respondError(w, http.StatusNotFound, nil) + return + } + path := r.URL.Path[len(prefix):] + if path == "" { + respondError(w, http.StatusNotFound, nil) + return + } + + // Parse the request if we can + var req MountRequest + if err := parseRequest(r, &req); err != nil { + respondError(w, http.StatusBadRequest, err) + return + } + + _, err := core.HandleRequest(&logical.Request{ + Operation: logical.WriteOperation, + Path: "sys/mount/" + path, + Data: map[string]interface{}{ + "type": req.Type, + "description": req.Description, + }, + }) + if err != nil { + respondError(w, http.StatusInternalServerError, err) + return + } + + respondOk(w, nil) + }) +} + +type MountRequest struct { + Type string `json:"type"` + Description string `json:"description"` +} diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index 04bda9395..bf03611d0 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -35,3 +35,41 @@ func TestSysMounts(t *testing.T) { t.Fatalf("bad: %#v", actual) } } + +func TestSysMount(t *testing.T) { + core, _ := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + + resp := testHttpPost(t, addr+"/v1/sys/mount/foo", map[string]interface{}{ + "type": "generic", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + + resp, err := http.Get(addr + "/v1/sys/mounts") + if err != nil { + t.Fatalf("err: %s", err) + } + + var actual map[string]interface{} + expected := map[string]interface{}{ + "foo/": map[string]interface{}{ + "description": "foo", + "type": "generic", + }, + "secret/": map[string]interface{}{ + "description": "generic secret storage", + "type": "generic", + }, + "sys/": map[string]interface{}{ + "description": "system endpoints used for control, policy and debugging", + "type": "system", + }, + } + testResponseStatus(t, resp, 200) + testResponseBody(t, resp, &actual) + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: %#v", actual) + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 60e3186ed..37b88a7f7 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -29,7 +29,7 @@ func NewSystemBackend(core *Core) logical.Backend { }, &framework.Path{ - Pattern: "mount/(?P.+?)", + Pattern: "mount/(?P.+)", Fields: map[string]*framework.FieldSchema{ "path": &framework.FieldSchema{