From 68d79142c4838056210d3f48ee00605b89b8ce1e Mon Sep 17 00:00:00 2001 From: Matt Keeler Date: Mon, 25 Nov 2019 12:07:04 -0500 Subject: [PATCH] OSS Modifications necessary for sessions namespacing --- agent/consul/fsm/commands_oss.go | 2 +- agent/consul/fsm/commands_oss_test.go | 4 +- agent/consul/fsm/snapshot_oss_test.go | 3 +- agent/consul/kvs_endpoint_test.go | 4 +- agent/consul/prepared_query_endpoint_test.go | 17 ++-- agent/consul/session_endpoint.go | 28 +++++-- agent/consul/session_endpoint_test.go | 35 +++++--- agent/consul/session_ttl.go | 16 ++-- agent/consul/session_ttl_test.go | 19 ++--- agent/consul/snapshot_endpoint_test.go | 2 +- agent/consul/state/catalog.go | 20 +++-- agent/consul/state/kvs.go | 2 +- agent/consul/state/kvs_test.go | 2 +- agent/consul/state/operations_oss.go | 26 ++++++ agent/consul/state/prepared_query.go | 4 +- agent/consul/state/prepared_query_test.go | 2 +- agent/consul/state/session.go | 85 +++++--------------- agent/consul/state/session_oss.go | 74 +++++++++++++++++ agent/consul/state/session_test.go | 82 ++++++++++--------- agent/consul/txn_endpoint_test.go | 5 +- agent/http.go | 4 +- agent/session_endpoint.go | 18 +++-- agent/session_endpoint_test.go | 32 ++++++-- agent/structs/structs.go | 7 +- agent/structs/structs_oss.go | 5 ++ website/source/api/session.html.md | 37 ++++++++- 26 files changed, 353 insertions(+), 182 deletions(-) create mode 100644 agent/consul/state/operations_oss.go create mode 100644 agent/consul/state/session_oss.go diff --git a/agent/consul/fsm/commands_oss.go b/agent/consul/fsm/commands_oss.go index 5e48c1a6b..ea18ac4e7 100644 --- a/agent/consul/fsm/commands_oss.go +++ b/agent/consul/fsm/commands_oss.go @@ -141,7 +141,7 @@ func (c *FSM) applySessionOperation(buf []byte, index uint64) interface{} { } return req.Session.ID case structs.SessionDestroy: - return c.state.SessionDestroy(index, req.Session.ID) + return c.state.SessionDestroy(index, req.Session.ID, &req.Session.EnterpriseMeta) default: c.logger.Printf("[WARN] consul.fsm: Invalid Session operation '%s'", req.Op) return fmt.Errorf("Invalid Session operation '%s'", req.Op) diff --git a/agent/consul/fsm/commands_oss_test.go b/agent/consul/fsm/commands_oss_test.go index 60f5289a0..b4c51a778 100644 --- a/agent/consul/fsm/commands_oss_test.go +++ b/agent/consul/fsm/commands_oss_test.go @@ -743,7 +743,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { // Get the session id := resp.(string) - _, session, err := fsm.state.SessionGet(nil, id) + _, session, err := fsm.state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -779,7 +779,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { t.Fatalf("resp: %v", resp) } - _, session, err = fsm.state.SessionGet(nil, id) + _, session, err = fsm.state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index 39e46a9a7..120f8fa1b 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -79,6 +79,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { }) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(9, session) + policy := &structs.ACLPolicy{ ID: structs.ACLPolicyGlobalManagementID, Name: "global-management", @@ -359,7 +360,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } // Verify session is restored - idx, s, err := fsm2.state.SessionGet(nil, session.ID) + idx, s, err := fsm2.state.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/kvs_endpoint_test.go b/agent/consul/kvs_endpoint_test.go index e3ba380df..9599f2281 100644 --- a/agent/consul/kvs_endpoint_test.go +++ b/agent/consul/kvs_endpoint_test.go @@ -764,6 +764,7 @@ func TestKVS_Apply_LockDelay(t *testing.T) { // Create and invalidate a session with a lock. state := s1.fsm.State() + if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %v", err) } @@ -783,7 +784,8 @@ func TestKVS_Apply_LockDelay(t *testing.T) { if ok, err := state.KVSLock(3, d); err != nil || !ok { t.Fatalf("err: %v", err) } - if err := state.SessionDestroy(4, id); err != nil { + + if err := state.SessionDestroy(4, id, nil); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/prepared_query_endpoint_test.go b/agent/consul/prepared_query_endpoint_test.go index c4ae24233..1a0c051da 100644 --- a/agent/consul/prepared_query_endpoint_test.go +++ b/agent/consul/prepared_query_endpoint_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/structs" + tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" @@ -77,7 +78,7 @@ func TestPreparedQuery_Apply(t *testing.T) { query.Query.Service.Failover.NearestN = 0 query.Query.Session = "nope" err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply) - if err == nil || !strings.Contains(err.Error(), "failed session lookup") { + if err == nil || !strings.Contains(err.Error(), "invalid session") { t.Fatalf("bad: %v", err) } @@ -852,7 +853,7 @@ func TestPreparedQuery_Get(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create an ACL with write permissions for redis queries. var token string @@ -1105,7 +1106,7 @@ func TestPreparedQuery_List(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create an ACL with write permissions for redis queries. var token string @@ -1461,16 +1462,16 @@ func TestPreparedQuery_Execute(t *testing.T) { codec2 := rpcClient(t, s2) defer codec2.Close() + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) testrpc.WaitForLeader(t, s1.RPC, "dc1") - testrpc.WaitForLeader(t, s2.RPC, "dc2") - - // Try to WAN join. joinWAN(t, s2, s1) + // Try to WAN join. retry.Run(t, func(r *retry.R) { if got, want := len(s1.WANMembers()), 2; got != want { r.Fatalf("got %d WAN members want %d", got, want) } }) + testrpc.WaitForLeader(t, s2.RPC, "dc2") // Create an ACL with read permission to the service. var execToken string @@ -2957,11 +2958,11 @@ func TestPreparedQuery_Wrapper(t *testing.T) { defer os.RemoveAll(dir2) defer s2.Shutdown() + s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) testrpc.WaitForLeader(t, s1.RPC, "dc1") - testrpc.WaitForLeader(t, s2.RPC, "dc2") - // Try to WAN join. joinWAN(t, s2, s1) + testrpc.WaitForLeader(t, s2.RPC, "dc2") // Try all the operations on a real server via the wrapper. wrapper := &queryServerWrapper{s1} diff --git a/agent/consul/session_endpoint.go b/agent/consul/session_endpoint.go index 072cfce23..17155da0c 100644 --- a/agent/consul/session_endpoint.go +++ b/agent/consul/session_endpoint.go @@ -38,11 +38,15 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { if err != nil { return err } + // TODO (namespaces) (acls) infer entmeta if not provided. + // The entMeta to populate will be the one in the Session struct, not SessionRequest + // This is because the Session is what is passed to downstream functions like raftApply + if rule != nil && s.srv.config.ACLEnforceVersion8 { switch args.Op { case structs.SessionDestroy: state := s.srv.fsm.State() - _, existing, err := state.SessionGet(nil, args.Session.ID) + _, existing, err := state.SessionGet(nil, args.Session.ID, &args.Session.EnterpriseMeta) if err != nil { return fmt.Errorf("Session lookup failed: %v", err) } @@ -102,7 +106,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err) return err } - _, sess, err := state.SessionGet(nil, args.Session.ID) + _, sess, err := state.SessionGet(nil, args.Session.ID, &args.Session.EnterpriseMeta) if err != nil { s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err) return err @@ -147,11 +151,13 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, return err } + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, session, err := state.SessionGet(ws, args.Session) + index, session, err := state.SessionGet(ws, args.SessionID, &args.EnterpriseMeta) if err != nil { return err } @@ -170,17 +176,19 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, } // List is used to list all the active sessions -func (s *Session) List(args *structs.DCSpecificRequest, +func (s *Session) List(args *structs.SessionSpecificRequest, reply *structs.IndexedSessions) error { if done, err := s.srv.forward("Session.List", args, args, reply); done { return err } + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, sessions, err := state.SessionList(ws) + index, sessions, err := state.SessionList(ws, &args.EnterpriseMeta) if err != nil { return err } @@ -200,11 +208,13 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, return err } + // TODO (namespaces) TODO (acls) infer args.entmeta if not provided + return s.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - index, sessions, err := state.NodeSessions(ws, args.Node) + index, sessions, err := state.NodeSessions(ws, args.Node, &args.EnterpriseMeta) if err != nil { return err } @@ -225,9 +235,11 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, } defer metrics.MeasureSince([]string{"session", "renew"}, time.Now()) + // TODO (namespaces) (freddy):infer args.entmeta if not provided + // Get the session, from local state. state := s.srv.fsm.State() - index, session, err := state.SessionGet(nil, args.Session) + index, session, err := state.SessionGet(nil, args.SessionID, &args.EnterpriseMeta) if err != nil { return err } @@ -251,7 +263,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, // Reset the session TTL timer. reply.Sessions = structs.Sessions{session} - if err := s.srv.resetSessionTimer(args.Session, session); err != nil { + if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil { s.srv.logger.Printf("[ERR] consul.session: Session renew failed: %v", err) return err } diff --git a/agent/consul/session_endpoint_test.go b/agent/consul/session_endpoint_test.go index 3528284e2..bd42febc1 100644 --- a/agent/consul/session_endpoint_test.go +++ b/agent/consul/session_endpoint_test.go @@ -17,6 +17,7 @@ func TestSession_Apply(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -41,7 +42,7 @@ func TestSession_Apply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(nil, out) + _, s, err := state.SessionGet(nil, out, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -63,7 +64,7 @@ func TestSession_Apply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(nil, id) + _, s, err = state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -77,6 +78,7 @@ func TestSession_DeleteApply(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -102,7 +104,7 @@ func TestSession_DeleteApply(t *testing.T) { // Verify state := s1.fsm.State() - _, s, err := state.SessionGet(nil, out) + _, s, err := state.SessionGet(nil, out, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -127,7 +129,7 @@ func TestSession_DeleteApply(t *testing.T) { } // Verify - _, s, err = state.SessionGet(nil, id) + _, s, err = state.SessionGet(nil, id, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -147,6 +149,7 @@ func TestSession_Apply_ACLDeny(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -237,6 +240,7 @@ func TestSession_Get(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -257,7 +261,7 @@ func TestSession_Get(t *testing.T) { getR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: out, + SessionID: out, } var sessions structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { @@ -281,6 +285,7 @@ func TestSession_List(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -339,6 +344,7 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -384,7 +390,7 @@ session "foo" { // 8 ACL enforcement isn't enabled. getR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: out, + SessionID: out, } { var sessions structs.IndexedSessions @@ -486,7 +492,7 @@ session "foo" { // Try to get a session that doesn't exist to make sure that's handled // correctly by the filter (it will get passed a nil slice). - getR.Session = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" + getR.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" { var sessions structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Get", &getR, &sessions); err != nil { @@ -503,10 +509,12 @@ func TestSession_ApplyTimers(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", @@ -551,6 +559,7 @@ func TestSession_Renew(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() @@ -613,7 +622,7 @@ func TestSession_Renew(t *testing.T) { for i := 0; i < 3; i++ { renewR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: ids[i], + SessionID: ids[i], } var session structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Renew", &renewR, &session); err != nil { @@ -714,10 +723,12 @@ func TestSession_Renew_ACLDeny(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + codec := rpcClient(t, s1) defer codec.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") + // Create the ACL. req := structs.ACLRequest{ Datacenter: "dc1", @@ -761,7 +772,7 @@ session "foo" { // enforcement. renewR := structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: id, + SessionID: id, } var session structs.IndexedSessions if err := msgpackrpc.CallWithCodec(codec, "Session.Renew", &renewR, &session); err != nil { @@ -787,6 +798,7 @@ func TestSession_NodeSessions(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() @@ -846,6 +858,7 @@ func TestSession_Apply_BadTTL(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() + codec := rpcClient(t, s1) defer codec.Close() diff --git a/agent/consul/session_ttl.go b/agent/consul/session_ttl.go index fd12701fb..e5ca429f7 100644 --- a/agent/consul/session_ttl.go +++ b/agent/consul/session_ttl.go @@ -22,7 +22,8 @@ const ( func (s *Server) initializeSessionTimers() error { // Scan all sessions and reset their timer state := s.fsm.State() - _, sessions, err := state.SessionList(nil) + + _, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMeta()) if err != nil { return err } @@ -41,7 +42,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { // Fault the session in if not given if session == nil { state := s.fsm.State() - _, s, err := state.SessionGet(nil, id) + _, s, err := state.SessionGet(nil, id, nil) if err != nil { return err } @@ -66,11 +67,11 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { return nil } - s.createSessionTimer(session.ID, ttl) + s.createSessionTimer(session.ID, ttl, &session.EnterpriseMeta) return nil } -func (s *Server) createSessionTimer(id string, ttl time.Duration) { +func (s *Server) createSessionTimer(id string, ttl time.Duration, entMeta *structs.EnterpriseMeta) { // Reset the session timer // Adjust the given TTL by the TTL multiplier. This is done // to give a client a grace period and to compensate for network @@ -78,12 +79,12 @@ func (s *Server) createSessionTimer(id string, ttl time.Duration) { // before the TTL, but there is no explicit promise about the upper // bound so this is allowable. ttl = ttl * structs.SessionTTLMultiplier - s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) }) + s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id, entMeta) }) } // invalidateSession is invoked when a session TTL is reached and we // need to invalidate the session. -func (s *Server) invalidateSession(id string) { +func (s *Server) invalidateSession(id string, entMeta *structs.EnterpriseMeta) { defer metrics.MeasureSince([]string{"session_ttl", "invalidate"}, time.Now()) // Clear the session timer @@ -97,6 +98,9 @@ func (s *Server) invalidateSession(id string) { ID: id, }, } + if entMeta != nil { + args.Session.EnterpriseMeta = *entMeta + } // Retry with exponential backoff to invalidate the session for attempt := uint(0); attempt < maxInvalidateAttempts; attempt++ { diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index dfa1b32e5..92e9ae2c5 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -157,7 +157,7 @@ func TestResetSessionTimerLocked(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - s1.createSessionTimer("foo", 5*time.Millisecond) + s1.createSessionTimer("foo", 5*time.Millisecond, nil) if s1.sessionTimers.Get("foo") == nil { t.Fatalf("missing timer") } @@ -178,7 +178,7 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) { retry.Run(t, func(r *retry.R) { // create the timer and make verify it was created - s1.createSessionTimer("foo", ttl) + s1.createSessionTimer("foo", ttl, nil) if s1.sessionTimers.Get("foo") == nil { r.Fatalf("missing timer") } @@ -194,7 +194,7 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) { retry.Run(t, func(r *retry.R) { // renew the session which will reset the TTL to 2*ttl // since that is the current SessionTTLMultiplier - s1.createSessionTimer("foo", ttl) + s1.createSessionTimer("foo", ttl, nil) if s1.sessionTimers.Get("foo") == nil { r.Fatal("missing timer") } @@ -231,6 +231,7 @@ func TestInvalidateSession(t *testing.T) { if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %s", err) } + session := &structs.Session{ ID: generateUUID(), Node: "foo", @@ -241,10 +242,10 @@ func TestInvalidateSession(t *testing.T) { } // This should cause a destroy - s1.invalidateSession(session.ID) + s1.invalidateSession(session.ID, nil) // Check it is gone - _, sess, err := state.SessionGet(nil, session.ID) + _, sess, err := state.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -259,7 +260,7 @@ func TestClearSessionTimer(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.createSessionTimer("foo", 5*time.Millisecond) + s1.createSessionTimer("foo", 5*time.Millisecond, nil) err := s1.clearSessionTimer("foo") if err != nil { @@ -277,9 +278,9 @@ func TestClearAllSessionTimers(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.createSessionTimer("foo", 10*time.Millisecond) - s1.createSessionTimer("bar", 10*time.Millisecond) - s1.createSessionTimer("baz", 10*time.Millisecond) + s1.createSessionTimer("foo", 10*time.Millisecond, nil) + s1.createSessionTimer("bar", 10*time.Millisecond, nil) + s1.createSessionTimer("baz", 10*time.Millisecond, nil) s1.clearAllSessionTimers() diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index f3717b292..24af70180 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -169,7 +169,7 @@ func TestSnapshot_LeaderState(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") codec := rpcClient(t, s1) defer codec.Close() diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 7e7e89fb8..e2f0b88b7 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -669,6 +669,7 @@ func (s *Store) deleteNodeCASTxn(tx *memdb.Txn, idx, cidx uint64, nodeName strin // deleteNodeTxn is the inner method used for removing a node from // the store within a given transaction. +// TODO (namespaces) (catalog) access to catalog tables needs to become namespace aware for services/checks func (s *Store) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) error { // Look up the node. node, err := tx.First("nodes", "id", nodeName) @@ -748,15 +749,16 @@ func (s *Store) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) error if err != nil { return fmt.Errorf("failed session lookup: %s", err) } - var ids []string + + var toDelete []*structs.Session for sess := sessions.Next(); sess != nil; sess = sessions.Next() { - ids = append(ids, sess.(*structs.Session).ID) + session := sess.(*structs.Session) + toDelete = append(toDelete, session) } - // Do the delete in a separate loop so we don't trash the iterator. - for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { - return fmt.Errorf("failed session delete: %s", err) + for _, session := range toDelete { + if err := s.deleteSessionTxn(tx, idx, session.ID, &session.EnterpriseMeta); err != nil { + return fmt.Errorf("failed to delete session '%s': %v", session.ID, err) } } @@ -1605,7 +1607,8 @@ func (s *Store) ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthChec // Delete the session in a separate loop so we don't trash the // iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { + // TODO (namespaces): Update when structs.HealthCheck supports Namespaces (&hc.EnterpriseMeta) + if err := s.deleteSessionTxn(tx, idx, id, nil); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } @@ -1917,7 +1920,8 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t // Do the delete in a separate loop so we don't trash the iterator. for _, id := range ids { - if err := s.deleteSessionTxn(tx, idx, id); err != nil { + // TODO (namespaces): Update when structs.HealthCheck supports Namespaces (&hc.EnterpriseMeta) + if err := s.deleteSessionTxn(tx, idx, id, nil); err != nil { return fmt.Errorf("failed deleting session: %s", err) } } diff --git a/agent/consul/state/kvs.go b/agent/consul/state/kvs.go index eea081c91..6983c6b27 100644 --- a/agent/consul/state/kvs.go +++ b/agent/consul/state/kvs.go @@ -529,7 +529,7 @@ func (s *Store) kvsLockTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) ( } // Verify that the session exists. - sess, err := tx.First("sessions", "id", entry.Session) + sess, err := firstWithTxn(tx, "sessions", "id", entry.Session, &entry.EnterpriseMeta) if err != nil { return false, fmt.Errorf("failed session lookup: %s", err) } diff --git a/agent/consul/state/kvs_test.go b/agent/consul/state/kvs_test.go index e9f088c6d..4b3c0d3d8 100644 --- a/agent/consul/state/kvs_test.go +++ b/agent/consul/state/kvs_test.go @@ -91,7 +91,7 @@ func TestStateStore_GC(t *testing.T) { if ok, err := s.KVSLock(11, d); !ok || err != nil { t.Fatalf("err: %v", err) } - if err := s.SessionDestroy(12, session.ID); err != nil { + if err := s.SessionDestroy(12, session.ID, nil); err != nil { t.Fatalf("err: %s", err) } select { diff --git a/agent/consul/state/operations_oss.go b/agent/consul/state/operations_oss.go new file mode 100644 index 000000000..48deec786 --- /dev/null +++ b/agent/consul/state/operations_oss.go @@ -0,0 +1,26 @@ +// +build !consulent + +package state + +import ( + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-memdb" +) + +func firstWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (interface{}, error) { + + return tx.First(table, index, idxVal) +} + +func firstWatchWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (<-chan struct{}, interface{}, error) { + + return tx.FirstWatch(table, index, idxVal) +} + +func getWithTxn(tx *memdb.Txn, + table, index, idxVal string, entMeta *structs.EnterpriseMeta) (memdb.ResultIterator, error) { + + return tx.Get(table, index, idxVal) +} diff --git a/agent/consul/state/prepared_query.go b/agent/consul/state/prepared_query.go index 285355785..89a8f8349 100644 --- a/agent/consul/state/prepared_query.go +++ b/agent/consul/state/prepared_query.go @@ -210,9 +210,9 @@ func (s *Store) preparedQuerySetTxn(tx *memdb.Txn, idx uint64, query *structs.Pr // Verify that the session exists. if query.Session != "" { - sess, err := tx.First("sessions", "id", query.Session) + sess, err := firstWithTxn(tx, "sessions", "id", query.Session, nil) if err != nil { - return fmt.Errorf("failed session lookup: %s", err) + return fmt.Errorf("invalid session: %v", err) } if sess == nil { return fmt.Errorf("invalid session %#v", query.Session) diff --git a/agent/consul/state/prepared_query_test.go b/agent/consul/state/prepared_query_test.go index 44495819e..55eb77459 100644 --- a/agent/consul/state/prepared_query_test.go +++ b/agent/consul/state/prepared_query_test.go @@ -68,7 +68,7 @@ func TestStateStore_PreparedQuerySet_PreparedQueryGet(t *testing.T) { // The set will still fail because the session is bogus. err = s.PreparedQuerySet(1, query) - if err == nil || !strings.Contains(err.Error(), "failed session lookup") { + if err == nil || !strings.Contains(err.Error(), "invalid session") { t.Fatalf("bad: %v", err) } diff --git a/agent/consul/state/session.go b/agent/consul/state/session.go index 9775ff639..3769f9214 100644 --- a/agent/consul/state/session.go +++ b/agent/consul/state/session.go @@ -19,9 +19,7 @@ func sessionsTableSchema() *memdb.TableSchema { Name: "id", AllowMissing: false, Unique: true, - Indexer: &memdb.UUIDFieldIndex{ - Field: "ID", - }, + Indexer: sessionIndexer(), }, "node": &memdb.IndexSchema{ Name: "node", @@ -108,28 +106,10 @@ func (s *Snapshot) Sessions() (memdb.ResultIterator, error) { // Session is used when restoring from a snapshot. For general inserts, use // SessionCreate. func (s *Restore) Session(sess *structs.Session) error { - // Insert the session. - if err := s.tx.Insert("sessions", sess); err != nil { + if err := s.store.insertSessionTxn(s.tx, sess, sess.ModifyIndex, true); err != nil { return fmt.Errorf("failed inserting session: %s", err) } - // Insert the check mappings. - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := s.tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index. - if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - return nil } @@ -206,44 +186,30 @@ func (s *Store) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Sessio } // Insert the session - if err := tx.Insert("sessions", sess); err != nil { + if err := s.insertSessionTxn(tx, sess, idx, false); err != nil { return fmt.Errorf("failed inserting session: %s", err) } - // Insert the check mappings - for _, checkID := range sess.Checks { - mapping := &sessionCheck{ - Node: sess.Node, - CheckID: checkID, - Session: sess.ID, - } - if err := tx.Insert("session_checks", mapping); err != nil { - return fmt.Errorf("failed inserting session check mapping: %s", err) - } - } - - // Update the index - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) - } - return nil } // SessionGet is used to retrieve an active session from the state store. -func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) { +func (s *Store) SessionGet(ws memdb.WatchSet, + sessionID string, entMeta *structs.EnterpriseMeta) (uint64, *structs.Session, error) { + tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Look up the session by its ID - watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID) + watchCh, session, err := firstWatchWithTxn(tx, "sessions", "id", sessionID, entMeta) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } ws.Add(watchCh) + if session != nil { return idx, session.(*structs.Session), nil } @@ -251,15 +217,15 @@ func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *struct } // SessionList returns a slice containing all of the active sessions. -func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) { +func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Query all of the active sessions. - sessions, err := tx.Get("sessions", "id") + sessions, err := getWithTxn(tx, "sessions", "id_prefix", "", entMeta) if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } @@ -276,12 +242,12 @@ func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) // NodeSessions returns a set of active sessions associated // with the given node ID. The returned index is the highest // index seen from the result set. -func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) { +func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. - idx := maxIndexTxn(tx, "sessions") + idx := s.sessionMaxIndex(tx, entMeta) // Get all of the sessions which belong to the node sessions, err := tx.Get("sessions", "node", nodeID) @@ -290,23 +256,19 @@ func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs. } ws.Add(sessions.WatchCh()) - // Go over all of the sessions and return them as a slice - var result structs.Sessions - for session := sessions.Next(); session != nil; session = sessions.Next() { - result = append(result, session.(*structs.Session)) - } + result := s.collectNodeSessions(sessions, entMeta) return idx, result, nil } // SessionDestroy is used to remove an active session. This will // implicitly invalidate the session and invoke the specified // session destroy behavior. -func (s *Store) SessionDestroy(idx uint64, sessionID string) error { +func (s *Store) SessionDestroy(idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error { tx := s.db.Txn(true) defer tx.Abort() // Call the session deletion. - if err := s.deleteSessionTxn(tx, idx, sessionID); err != nil { + if err := s.deleteSessionTxn(tx, idx, sessionID, entMeta); err != nil { return err } @@ -316,9 +278,9 @@ func (s *Store) SessionDestroy(idx uint64, sessionID string) error { // deleteSessionTxn is the inner method, which is used to do the actual // session deletion and handle session invalidation, etc. -func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) error { +func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error { // Look up the session. - sess, err := tx.First("sessions", "id", sessionID) + sess, err := firstWithTxn(tx, "sessions", "id", sessionID, entMeta) if err != nil { return fmt.Errorf("failed session lookup: %s", err) } @@ -327,15 +289,12 @@ func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) er } // Delete the session and write the new index. - if err := tx.Delete("sessions", sess); err != nil { - return fmt.Errorf("failed deleting session: %s", err) - } - if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { - return fmt.Errorf("failed updating index: %s", err) + session := sess.(*structs.Session) + if err := s.sessionDeleteWithSession(tx, session, idx); err != nil { + return fmt.Errorf("failed deleting session: %v", err) } // Enforce the max lock delay. - session := sess.(*structs.Session) delay := session.LockDelay if delay > structs.MaxLockDelay { delay = structs.MaxLockDelay diff --git a/agent/consul/state/session_oss.go b/agent/consul/state/session_oss.go new file mode 100644 index 000000000..4f05789ba --- /dev/null +++ b/agent/consul/state/session_oss.go @@ -0,0 +1,74 @@ +// +build !consulent + +package state + +import ( + "fmt" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-memdb" +) + +func sessionIndexer() *memdb.UUIDFieldIndex { + return &memdb.UUIDFieldIndex{ + Field: "ID", + } +} + +func (s *Store) collectNodeSessions(sessions memdb.ResultIterator, entMeta *structs.EnterpriseMeta) structs.Sessions { + // Go over all of the sessions and return them as a slice + var result structs.Sessions + for s := sessions.Next(); s != nil; s = sessions.Next() { + result = append(result, s.(*structs.Session)) + } + return result +} + +func (s *Store) sessionDeleteWithSession(tx *memdb.Txn, session *structs.Session, idx uint64) error { + if err := tx.Delete("sessions", session); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + + // Update the indexes + err := tx.Insert("index", &IndexEntry{"sessions", idx}) + if err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + return nil +} + +func (s *Store) insertSessionTxn(tx *memdb.Txn, session *structs.Session, idx uint64, updateMax bool) error { + if err := tx.Insert("sessions", session); err != nil { + return err + } + + // Insert the check mappings + for _, checkID := range session.Checks { + mapping := &sessionCheck{ + Node: session.Node, + CheckID: checkID, + Session: session.ID, + } + if err := tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index + if updateMax { + if err := indexUpdateMaxTxn(tx, idx, "sessions"); err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + } else { + err := tx.Insert("index", &IndexEntry{"sessions", idx}) + if err != nil { + return fmt.Errorf("failed updating sessions index: %v", err) + } + } + + return nil +} + +func (s *Store) sessionMaxIndex(tx *memdb.Txn, entMeta *structs.EnterpriseMeta) uint64 { + return maxIndexTxn(tx, "sessions") +} diff --git a/agent/consul/state/session_test.go b/agent/consul/state/session_test.go index 1e638f6ec..c20e68229 100644 --- a/agent/consul/state/session_test.go +++ b/agent/consul/state/session_test.go @@ -2,6 +2,7 @@ package state import ( "fmt" + "github.com/stretchr/testify/assert" "reflect" "strings" "testing" @@ -18,7 +19,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { // SessionGet returns nil if the session doesn't exist ws := memdb.NewWatchSet() - idx, session, err := s.SessionGet(ws, testUUID()) + idx, session, err := s.SessionGet(ws, testUUID(), nil) if session != nil || err != nil { t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err) } @@ -74,7 +75,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { // Retrieve the session again ws = memdb.NewWatchSet() - idx, session, err = s.SessionGet(ws, sess.ID) + idx, session, err = s.SessionGet(ws, sess.ID, nil) if err != nil { t.Fatalf("err: %s", err) } @@ -88,13 +89,15 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { ID: sess.ID, Behavior: structs.SessionKeysRelease, Node: "node1", - RaftIndex: structs.RaftIndex{ - CreateIndex: 2, - ModifyIndex: 2, - }, } - if !reflect.DeepEqual(expect, session) { - t.Fatalf("bad session: %#v", session) + if session.ID != expect.ID { + t.Fatalf("bad session ID: expected %s, got %s", expect.ID, session.ID) + } + if session.Node != expect.Node { + t.Fatalf("bad session Node: expected %s, got %s", expect.Node, session.Node) + } + if session.Behavior != expect.Behavior { + t.Fatalf("bad session Behavior: expected %s, got %s", expect.Behavior, session.Behavior) } // Registering with a non-existent check is disallowed @@ -176,7 +179,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { } // Pulling a nonexistent session gives the table index. - idx, session, err = s.SessionGet(nil, testUUID()) + idx, session, err = s.SessionGet(nil, testUUID(), nil) if err != nil { t.Fatalf("err: %s", err) } @@ -188,12 +191,12 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) { } } -func TegstStateStore_SessionList(t *testing.T) { +func TestStateStore_SessionList(t *testing.T) { s := testStateStore(t) // Listing when no sessions exist returns nil ws := memdb.NewWatchSet() - idx, res, err := s.SessionList(ws) + idx, res, err := s.SessionList(ws, nil) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -231,15 +234,20 @@ func TegstStateStore_SessionList(t *testing.T) { } // List out all of the sessions - idx, sessionList, err := s.SessionList(nil) + idx, sessionList, err := s.SessionList(nil, nil) if err != nil { t.Fatalf("err: %s", err) } if idx != 6 { t.Fatalf("bad index: %d", idx) } - if !reflect.DeepEqual(sessionList, sessions) { - t.Fatalf("bad: %#v", sessions) + sessionMap := make(map[string]*structs.Session) + for _, session := range sessionList { + sessionMap[session.ID] = session + } + + for _, expect := range sessions { + assert.Equal(t, expect, sessionMap[expect.ID]) } } @@ -248,7 +256,7 @@ func TestStateStore_NodeSessions(t *testing.T) { // Listing sessions with no results returns nil ws := memdb.NewWatchSet() - idx, res, err := s.NodeSessions(ws, "node1") + idx, res, err := s.NodeSessions(ws, "node1", nil) if idx != 0 || res != nil || err != nil { t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) } @@ -290,7 +298,7 @@ func TestStateStore_NodeSessions(t *testing.T) { // Query all of the sessions associated with a specific // node in the state store. ws1 := memdb.NewWatchSet() - idx, res, err = s.NodeSessions(ws1, "node1") + idx, res, err = s.NodeSessions(ws1, "node1", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -302,7 +310,7 @@ func TestStateStore_NodeSessions(t *testing.T) { } ws2 := memdb.NewWatchSet() - idx, res, err = s.NodeSessions(ws2, "node2") + idx, res, err = s.NodeSessions(ws2, "node2", nil) if err != nil { t.Fatalf("err: %s", err) } @@ -314,7 +322,7 @@ func TestStateStore_NodeSessions(t *testing.T) { } // Destroying a session on node1 should not affect node2's watch. - if err := s.SessionDestroy(100, sessions1[0].ID); err != nil { + if err := s.SessionDestroy(100, sessions1[0].ID, nil); err != nil { t.Fatalf("err: %s", err) } if !watchFired(ws1) { @@ -330,7 +338,7 @@ func TestStateStore_SessionDestroy(t *testing.T) { // Session destroy is idempotent and returns no error // if the session doesn't exist. - if err := s.SessionDestroy(1, testUUID()); err != nil { + if err := s.SessionDestroy(1, testUUID(), nil); err != nil { t.Fatalf("err: %s", err) } @@ -352,7 +360,7 @@ func TestStateStore_SessionDestroy(t *testing.T) { } // Destroy the session. - if err := s.SessionDestroy(3, sess.ID); err != nil { + if err := s.SessionDestroy(3, sess.ID, nil); err != nil { t.Fatalf("err: %s", err) } @@ -412,7 +420,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { defer snap.Close() // Alter the real state store. - if err := s.SessionDestroy(8, session1); err != nil { + if err := s.SessionDestroy(8, session1, nil); err != nil { t.Fatalf("err: %s", err) } @@ -456,7 +464,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) { // Read the restored sessions back out and verify that they // match. - idx, res, err := s.SessionList(nil) + idx, res, err := s.SessionList(nil, nil) if err != nil { t.Fatalf("err: %s", err) } @@ -522,7 +530,7 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { // Delete the node and make sure the watch fires. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -534,7 +542,7 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -577,7 +585,7 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { // Delete the service and make sure the watch fires. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -589,7 +597,7 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -627,7 +635,7 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { // Invalidate the check and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -640,7 +648,7 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -678,7 +686,7 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { // Delete the check and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -690,7 +698,7 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -746,7 +754,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { // Delete the node and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -758,7 +766,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -828,7 +836,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { // Delete the node and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -840,7 +848,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { } // Lookup by ID, should be nil. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -896,11 +904,11 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { // Invalidate the session and make sure the watches fire. ws := memdb.NewWatchSet() - idx, s2, err := s.SessionGet(ws, session.ID) + idx, s2, err := s.SessionGet(ws, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } - if err := s.SessionDestroy(5, session.ID); err != nil { + if err := s.SessionDestroy(5, session.ID, nil); err != nil { t.Fatalf("err: %v", err) } if !watchFired(ws) { @@ -908,7 +916,7 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) { } // Make sure the session is gone. - idx, s2, err = s.SessionGet(nil, session.ID) + idx, s2, err = s.SessionGet(nil, session.ID, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/txn_endpoint_test.go b/agent/consul/txn_endpoint_test.go index 5cbfb56f2..de030d3dd 100644 --- a/agent/consul/txn_endpoint_test.go +++ b/agent/consul/txn_endpoint_test.go @@ -620,7 +620,7 @@ func TestTxn_Apply_LockDelay(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Create and invalidate a session with a lock. state := s1.fsm.State() @@ -643,7 +643,8 @@ func TestTxn_Apply_LockDelay(t *testing.T) { if ok, err := state.KVSLock(3, d); err != nil || !ok { t.Fatalf("err: %v", err) } - if err := state.SessionDestroy(4, id); err != nil { + + if err := state.SessionDestroy(4, id, nil); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/http.go b/agent/http.go index a875e76aa..9e4941457 100644 --- a/agent/http.go +++ b/agent/http.go @@ -19,13 +19,13 @@ import ( "time" "github.com/NYTimes/gziphandler" - metrics "github.com/armon/go-metrics" + "github.com/armon/go-metrics" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" ) diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index aba7dc1b2..7961aa1de 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -31,6 +31,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) + s.parseEntMeta(req, &args.Session.EnterpriseMeta) // Handle optional request body if req.ContentLength > 0 { @@ -79,6 +80,7 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request) } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) + s.parseEntMeta(req, &args.Session.EnterpriseMeta) // Pull out the session id args.Session.ID = strings.TrimPrefix(req.URL.Path, "/v1/session/destroy/") @@ -101,10 +103,11 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the session id - args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") - if args.Session == "" { + args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") + if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") return nil, nil @@ -115,7 +118,7 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( return nil, err } else if out.Sessions == nil { resp.WriteHeader(http.StatusNotFound) - fmt.Fprintf(resp, "Session id '%s' not found", args.Session) + fmt.Fprintf(resp, "Session id '%s' not found", args.SessionID) return nil, nil } @@ -128,10 +131,11 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the session id - args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/info/") - if args.Session == "" { + args.SessionID = strings.TrimPrefix(req.URL.Path, "/v1/session/info/") + if args.SessionID == "" { resp.WriteHeader(http.StatusBadRequest) fmt.Fprint(resp, "Missing session") return nil, nil @@ -152,10 +156,11 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in // SessionList is used to list all the sessions func (s *HTTPServer) SessionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - args := structs.DCSpecificRequest{} + args := structs.SessionSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) var out structs.IndexedSessions defer setMeta(resp, &out.QueryMeta) @@ -176,6 +181,7 @@ func (s *HTTPServer) SessionsForNode(resp http.ResponseWriter, req *http.Request if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil } + s.parseEntMeta(req, &args.EnterpriseMeta) // Pull out the node name args.Node = strings.TrimPrefix(req.URL.Path, "/v1/session/node/") diff --git a/agent/session_endpoint_test.go b/agent/session_endpoint_test.go index 6485e0254..2e4c7e999 100644 --- a/agent/session_endpoint_test.go +++ b/agent/session_endpoint_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" "time" @@ -13,13 +14,14 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/types" - "github.com/pascaldekloe/goe/verify" ) func verifySession(t *testing.T, r *retry.R, a *TestAgent, want structs.Session) { + t.Helper() + args := &structs.SessionSpecificRequest{ Datacenter: "dc1", - Session: want.ID, + SessionID: want.ID, } var out structs.IndexedSessions if err := a.RPC("Session.Get", args, &out); err != nil { @@ -34,7 +36,22 @@ func verifySession(t *testing.T, r *retry.R, a *TestAgent, want structs.Session) got := *(out.Sessions[0]) got.CreateIndex = 0 got.ModifyIndex = 0 - verify.Values(t, "", got, want) + + if got.ID != want.ID { + t.Fatalf("bad session ID: expected %s, got %s", want.ID, got.ID) + } + if got.Node != want.Node { + t.Fatalf("bad session Node: expected %s, got %s", want.Node, got.Node) + } + if got.Behavior != want.Behavior { + t.Fatalf("bad session Behavior: expected %s, got %s", want.Behavior, got.Behavior) + } + if got.LockDelay != want.LockDelay { + t.Fatalf("bad session LockDelay: expected %s, got %s", want.LockDelay, got.LockDelay) + } + if !reflect.DeepEqual(got.Checks, want.Checks) { + t.Fatalf("bad session Checks: expected %+v, got %+v", want.Checks, got.Checks) + } } func TestSessionCreate(t *testing.T) { @@ -224,7 +241,8 @@ func TestSessionCreate_NoCheck(t *testing.T) { } func makeTestSession(t *testing.T, srv *HTTPServer) string { - req, _ := http.NewRequest("PUT", "/v1/session/create", nil) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, nil) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { @@ -243,7 +261,8 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string { } enc.Encode(raw) - req, _ := http.NewRequest("PUT", "/v1/session/create", body) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, body) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { @@ -262,7 +281,8 @@ func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string { } enc.Encode(raw) - req, _ := http.NewRequest("PUT", "/v1/session/create", body) + url := "/v1/session/create" + req, _ := http.NewRequest("PUT", url, body) resp := httptest.NewRecorder() obj, err := srv.SessionCreate(resp, req) if err != nil { diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 3551cecac..ac0c89c3d 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -515,6 +515,7 @@ func (r *ServiceSpecificRequest) CacheMinIndex() uint64 { type NodeSpecificRequest struct { Datacenter string Node string + EnterpriseMeta QueryOptions } @@ -1620,6 +1621,7 @@ type DirEntry struct { Value []byte Session string `json:",omitempty"` + EnterpriseMeta RaftIndex } @@ -1664,6 +1666,7 @@ func (r *KVSRequest) RequestDatacenter() string { type KeyRequest struct { Datacenter string Key string + EnterpriseMeta QueryOptions } @@ -1718,6 +1721,7 @@ type Session struct { Behavior SessionBehavior // What to do when session is invalidated TTL string + EnterpriseMeta RaftIndex } @@ -1773,7 +1777,8 @@ func (r *SessionRequest) RequestDatacenter() string { // SessionSpecificRequest is used to request a session by ID type SessionSpecificRequest struct { Datacenter string - Session string + SessionID string + EnterpriseMeta QueryOptions } diff --git a/agent/structs/structs_oss.go b/agent/structs/structs_oss.go index e5220ef9b..1d146be24 100644 --- a/agent/structs/structs_oss.go +++ b/agent/structs/structs_oss.go @@ -19,6 +19,11 @@ func (m *EnterpriseMeta) addToHash(hasher hash.Hash) { // do nothing } +// WildcardEnterpriseMeta stub +func WildcardEnterpriseMeta() *EnterpriseMeta { + return nil +} + // ReplicationEnterpriseMeta stub func ReplicationEnterpriseMeta() *EnterpriseMeta { return nil diff --git a/website/source/api/session.html.md b/website/source/api/session.html.md index 8c0063656..6329d2086 100644 --- a/website/source/api/session.html.md +++ b/website/source/api/session.html.md @@ -8,9 +8,7 @@ description: |- # Session HTTP Endpoint -The `/session` endpoints create, destroy, and query sessions in Consul. A -conceptual overview of sessions is found at the -[Session Internals](/docs/internals/sessions.html) page. +The `/session` endpoints create, destroy, and query sessions in Consul. ## Create Session @@ -33,11 +31,17 @@ The table below shows this endpoint's support for ### Parameters +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. + - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. -- `LockDelay` `(string: "15s")` - Specifies the duration for the lock delay. +- `LockDelay` `(string: "15s")` - Specifies the duration for the lock delay. This + must be greater than `0`. - `Node` `(string: "")` - Specifies the name of the node. This must refer to a node that is already registered. @@ -126,6 +130,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -167,6 +176,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -223,6 +237,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -274,6 +293,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request @@ -329,6 +353,11 @@ The table below shows this endpoint's support for - `dc` `(string: "")` - Specifies the datacenter to query. This will default to the datacenter of the agent being queried. This is specified as part of the URL as a query parameter. Using this across datacenters is not recommended. + +- `ns` `(string: "")` - **Enterprise Only** Specifies the namespace to query. + If not provided, the namespace will be inferred from the request's ACL token, + or will default to the `default` namespace. This is specified as part of the + URL as a query parameter. ### Sample Request