Merge pull request #524 from amalaviy/session_ttl

Consul Session TTLs
This commit is contained in:
Armon Dadgar 2014-12-12 14:42:25 -08:00
commit 8dbfe7c9a8
13 changed files with 816 additions and 13 deletions

View File

@ -188,6 +188,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
s.mux.HandleFunc("/v1/session/create", s.wrap(s.SessionCreate))
s.mux.HandleFunc("/v1/session/destroy/", s.wrap(s.SessionDestroy))
s.mux.HandleFunc("/v1/session/renew/", s.wrap(s.SessionRenew))
s.mux.HandleFunc("/v1/session/info/", s.wrap(s.SessionGet))
s.mux.HandleFunc("/v1/session/node/", s.wrap(s.SessionsForNode))
s.mux.HandleFunc("/v1/session/list", s.wrap(s.SessionList))

View File

@ -40,6 +40,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
Checks: []string{consul.SerfCheckID},
LockDelay: 15 * time.Second,
Behavior: structs.SessionKeysRelease,
TTL: "",
},
}
s.parseDC(req, &args.Datacenter)
@ -51,6 +52,21 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
return nil, nil
}
if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
return nil, nil
}
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
return nil, nil
}
}
}
// Create the session, get the ID
@ -130,6 +146,39 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request)
return true, nil
}
// SessionRenew is used to renew the TTL on an existing TTL session
func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Mandate a PUT request
if req.Method != "PUT" {
resp.WriteHeader(405)
return nil, nil
}
args := structs.SessionSpecificRequest{}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
// Pull out the session id
args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/")
if args.Session == "" {
resp.WriteHeader(400)
resp.Write([]byte("Missing session"))
return nil, nil
}
var out structs.IndexedSessions
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
return nil, err
} else if out.Sessions == nil {
resp.WriteHeader(404)
resp.Write([]byte(fmt.Sprintf("Session id '%s' not found", args.Session)))
return nil, nil
}
return out.Sessions, nil
}
// SessionGet is used to get info for a particular session
func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := structs.SessionSpecificRequest{}

View File

@ -176,6 +176,28 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string {
return sessResp.ID
}
func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string {
// Create Session with TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": ttl,
}
enc.Encode(raw)
req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
sessResp := obj.(sessionCreateResponse)
return sessResp.ID
}
func TestSessionDestroy(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
id := makeTestSession(t, srv)
@ -192,6 +214,206 @@ func TestSessionDestroy(t *testing.T) {
})
}
func TestSessionTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second
id := makeTestSessionTTL(t, srv, TTL)
req, err := http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp := httptest.NewRecorder()
obj, err := srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok := obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}
if respObj[0].TTL != TTL {
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
}
time.Sleep(ttl*structs.SessionTTLMultiplier + ttl)
req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if len(respObj) != 0 {
t.Fatalf("session '%s' should have been destroyed", id)
}
})
}
func TestSessionBadTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
badTTL := "10z"
// Create Session with illegal TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": badTTL,
}
enc.Encode(raw)
req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// less than SessionTTLMin
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "5s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
// more than SessionTTLMax
body = bytes.NewBuffer(nil)
enc = json.NewEncoder(body)
raw = map[string]interface{}{
"TTL": "4000s",
}
enc.Encode(raw)
req, err = http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = httptest.NewRecorder()
obj, err = srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if obj != nil {
t.Fatalf("illegal TTL '%s' allowed", badTTL)
}
if resp.Code != 400 {
t.Fatalf("Bad response code, should be 400")
}
})
}
func TestSessionTTLRenew(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second
id := makeTestSessionTTL(t, srv, TTL)
req, err := http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp := httptest.NewRecorder()
obj, err := srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok := obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}
if respObj[0].TTL != TTL {
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
}
// Sleep to consume some time before renew
time.Sleep(ttl * (structs.SessionTTLMultiplier / 2))
req, err = http.NewRequest("PUT",
"/v1/session/renew/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionRenew(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}
// Sleep for ttl * TTL Multiplier
time.Sleep(ttl * structs.SessionTTLMultiplier)
req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("session '%s' should have renewed", id)
}
if len(respObj) != 1 {
t.Fatalf("session '%s' should have renewed", id)
}
// now wait for timeout and expect session to get destroyed
time.Sleep(ttl * structs.SessionTTLMultiplier)
req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("session '%s' should have destroyed", id)
}
if len(respObj) != 0 {
t.Fatalf("session '%s' should have destroyed", id)
}
})
}
func TestSessionGet(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
id := makeTestSession(t, srv)

View File

@ -61,6 +61,13 @@ func (s *Server) leaderLoop(stopCh chan struct{}) {
s.logger.Printf("[ERR] consul: ACL initialization failed: %v", err)
}
// Setup Session Timers if we are the leader and need to
if err := s.initializeSessionTimers(); err != nil {
s.logger.Printf("[ERR] consul: Session Timers initialization failed: %v", err)
}
// clear the session timers if we are no longer leader and exit the leaderLoop
defer s.clearAllSessionTimers()
// Reconcile channel is only used once initial reconcile
// has succeeded
var reconcileCh chan serf.Member

View File

@ -370,6 +370,9 @@ func TestLeader_LeftLeader(t *testing.T) {
break
}
}
if leader == nil {
t.Fatalf("Should have a leader")
}
leader.Leave()
leader.Shutdown()
time.Sleep(100 * time.Millisecond)

View File

@ -128,6 +128,12 @@ type Server struct {
// which SHOULD only consist of Consul servers
serfWAN *serf.Serf
// sessionTimers track the expiration time of each Session that has
// a TTL. On expiration, a SessionDestroy event will occur, and
// destroy the session via standard session destory processing
sessionTimers map[string]*time.Timer
sessionTimersLock sync.RWMutex
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex

View File

@ -36,6 +36,16 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
default:
return fmt.Errorf("Invalid Behavior setting '%s'", args.Session.Behavior)
}
if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
}
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", ttl, structs.SessionTTLMin, structs.SessionTTLMax)
}
}
// If this is a create, we must generate the Session ID. This must
// be done prior to appending to the raft log, because the ID is not
@ -63,6 +73,13 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
s.srv.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
return err
}
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
s.srv.resetSessionTimer(args.Session.ID, nil)
} else if args.Op == structs.SessionDestroy && args.Session.TTL != "" {
s.srv.clearSessionTimer(args.Session.ID)
}
if respErr, ok := resp.(error); ok {
return respErr
}
@ -133,3 +150,24 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest,
return err
})
}
// Renew is used to renew the TTL on a single session
func (s *Session) Renew(args *structs.SessionSpecificRequest,
reply *structs.IndexedSessions) error {
if done, err := s.srv.forward("Session.Renew", args, args, reply); done {
return err
}
// Get the local state
state := s.srv.fsm.State()
// Get the session, from local state
index, session, err := state.SessionGet(args.Session)
reply.Index = index
if session != nil {
reply.Sessions = structs.Sessions{session}
// reset the session TTL timer
err = s.srv.resetSessionTimer(args.Session, session)
}
return err
}

View File

@ -5,6 +5,7 @@ import (
"github.com/hashicorp/consul/testutil"
"os"
"testing"
"time"
)
func TestSessionEndpoint_Apply(t *testing.T) {
@ -223,6 +224,161 @@ func TestSessionEndpoint_List(t *testing.T) {
}
}
func TestSessionEndpoint_Renew(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
client := rpcClient(t, s1)
defer client.Close()
testutil.WaitForLeader(t, client.Call, "dc1")
TTL := "10s" // the minimum allowed ttl
ttl := 10 * time.Second
s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
ids := []string{}
for i := 0; i < 5; i++ {
arg := structs.SessionRequest{
Datacenter: "dc1",
Op: structs.SessionCreate,
Session: structs.Session{
Node: "foo",
TTL: TTL,
},
}
var out string
if err := client.Call("Session.Apply", &arg, &out); err != nil {
t.Fatalf("err: %v", err)
}
ids = append(ids, out)
}
getR := structs.DCSpecificRequest{
Datacenter: "dc1",
}
var sessions structs.IndexedSessions
if err := client.Call("Session.List", &getR, &sessions); err != nil {
t.Fatalf("err: %v", err)
}
if sessions.Index == 0 {
t.Fatalf("Bad: %v", sessions)
}
if len(sessions.Sessions) != 5 {
t.Fatalf("Bad: %v", sessions.Sessions)
}
for i := 0; i < len(sessions.Sessions); i++ {
s := sessions.Sessions[i]
if !strContains(ids, s.ID) {
t.Fatalf("bad: %v", s)
}
if s.Node != "foo" {
t.Fatalf("bad: %v", s)
}
if s.TTL != TTL {
t.Fatalf("bad session TTL: %s %v", s.TTL, s)
}
t.Logf("Created session '%s'", s.ID)
}
// Sleep for time shorter than internal destroy ttl
time.Sleep(ttl * structs.SessionTTLMultiplier / 2)
// renew 3 out of 5 sessions
for i := 0; i < 3; i++ {
renewR := structs.SessionSpecificRequest{
Datacenter: "dc1",
Session: ids[i],
}
var session structs.IndexedSessions
if err := client.Call("Session.Renew", &renewR, &session); err != nil {
t.Fatalf("err: %v", err)
}
if session.Index == 0 {
t.Fatalf("Bad: %v", session)
}
if len(session.Sessions) != 1 {
t.Fatalf("Bad: %v", session.Sessions)
}
s := session.Sessions[0]
if !strContains(ids, s.ID) {
t.Fatalf("bad: %v", s)
}
if s.Node != "foo" {
t.Fatalf("bad: %v", s)
}
t.Logf("Renewed session '%s'", s.ID)
}
// now sleep for 2/3 the internal destroy TTL time for renewed sessions
// which is more than the internal destroy TTL time for the non-renewed sessions
time.Sleep((ttl * structs.SessionTTLMultiplier) * 2.0 / 3.0)
var sessionsL1 structs.IndexedSessions
if err := client.Call("Session.List", &getR, &sessionsL1); err != nil {
t.Fatalf("err: %v", err)
}
if sessionsL1.Index == 0 {
t.Fatalf("Bad: %v", sessionsL1)
}
t.Logf("Expect 2 sessions to be destroyed")
for i := 0; i < len(sessionsL1.Sessions); i++ {
s := sessionsL1.Sessions[i]
if !strContains(ids, s.ID) {
t.Fatalf("bad: %v", s)
}
if s.Node != "foo" {
t.Fatalf("bad: %v", s)
}
if s.TTL != TTL {
t.Fatalf("bad: %v", s)
}
if i > 2 {
t.Errorf("session '%s' should be destroyed", s.ID)
}
}
if len(sessionsL1.Sessions) > 3 {
t.Fatalf("Bad: %v", sessionsL1.Sessions)
}
// now sleep again for ttl*2 - no sessions should still be alive
time.Sleep(ttl * structs.SessionTTLMultiplier)
var sessionsL2 structs.IndexedSessions
if err := client.Call("Session.List", &getR, &sessionsL2); err != nil {
t.Fatalf("err: %v", err)
}
if sessionsL2.Index == 0 {
t.Fatalf("Bad: %v", sessionsL2)
}
if len(sessionsL2.Sessions) != 0 {
for i := 0; i < len(sessionsL2.Sessions); i++ {
s := sessionsL2.Sessions[i]
if !strContains(ids, s.ID) {
t.Fatalf("bad: %v", s)
}
if s.Node != "foo" {
t.Fatalf("bad: %v", s)
}
if s.TTL != TTL {
t.Fatalf("bad: %v", s)
}
t.Errorf("session '%s' should be destroyed", s.ID)
}
t.Fatalf("Bad: %v", sessionsL2.Sessions)
}
}
func TestSessionEndpoint_NodeSessions(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)

106
consul/session_ttl.go Normal file
View File

@ -0,0 +1,106 @@
package consul
import (
"fmt"
"github.com/hashicorp/consul/consul/structs"
"time"
)
func (s *Server) initializeSessionTimers() error {
s.sessionTimersLock.Lock()
s.sessionTimers = make(map[string]*time.Timer)
s.sessionTimersLock.Unlock()
// walk the TTL index and resetSessionTimer for each non-zero TTL
state := s.fsm.State()
_, sessions, err := state.SessionListTTL()
if err != nil {
return err
}
for _, session := range sessions {
err := s.resetSessionTimer(session.ID, session)
if err != nil {
return err
}
}
return nil
}
// invalidate the session when timer expires, called by AfterFunc
func (s *Server) invalidateSession(id string) {
args := structs.SessionRequest{
Datacenter: s.config.Datacenter,
Op: structs.SessionDestroy,
}
args.Session.ID = id
// Apply the update to destroy the session
_, err := s.raftApply(structs.SessionRequestType, args)
if err != nil {
s.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
}
}
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
if session == nil {
var err error
// find the session
state := s.fsm.State()
_, session, err = state.SessionGet(id)
if err != nil || session == nil {
return fmt.Errorf("Could not find session for '%s'\n", id)
}
}
if session.TTL == "" {
return nil
}
ttl, err := time.ParseDuration(session.TTL)
if err != nil {
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
}
if ttl == 0 {
return nil
}
s.sessionTimersLock.Lock()
if s.sessionTimers == nil {
s.sessionTimers = make(map[string]*time.Timer)
}
defer s.sessionTimersLock.Unlock()
if t := s.sessionTimers[id]; t != nil {
// TBD may modify the session's active TTL based on load here
t.Reset(ttl * structs.SessionTTLMultiplier)
} else {
s.sessionTimers[session.ID] = time.AfterFunc(ttl*structs.SessionTTLMultiplier, func() {
s.invalidateSession(session.ID)
})
}
return nil
}
func (s *Server) clearSessionTimer(id string) error {
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
if s.sessionTimers[id] != nil {
// stop the session timer and delete from the map
s.sessionTimers[id].Stop()
delete(s.sessionTimers, id)
}
return nil
}
func (s *Server) clearAllSessionTimers() error {
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
// stop all timers and clear out the map
for _, t := range s.sessionTimers {
t.Stop()
}
s.sessionTimers = nil
return nil
}

168
consul/session_ttl_test.go Normal file
View File

@ -0,0 +1,168 @@
package consul
import (
"errors"
"fmt"
"os"
"testing"
"time"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/testutil"
)
func TestServer_sessionTTL(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
dir2, s2 := testServerDCBootstrap(t, "dc1", false)
defer os.RemoveAll(dir2)
defer s2.Shutdown()
dir3, s3 := testServerDCBootstrap(t, "dc1", false)
defer os.RemoveAll(dir3)
defer s3.Shutdown()
servers := []*Server{s1, s2, s3}
// Try to join
addr := fmt.Sprintf("127.0.0.1:%d",
s1.config.SerfLANConfig.MemberlistConfig.BindPort)
if _, err := s2.JoinLAN([]string{addr}); err != nil {
t.Fatalf("err: %v", err)
}
if _, err := s3.JoinLAN([]string{addr}); err != nil {
t.Fatalf("err: %v", err)
}
for _, s := range servers {
testutil.WaitForResult(func() (bool, error) {
peers, _ := s.raftPeers.Peers()
return len(peers) == 3, nil
}, func(err error) {
t.Fatalf("should have 3 peers")
})
}
// Find the leader
var leader *Server
for _, s := range servers {
// check that s.sessionTimers is empty
if len(s.sessionTimers) != 0 {
t.Fatalf("should have no sessionTimers")
}
// find the leader too
if s.IsLeader() {
leader = s
}
}
if leader == nil {
t.Fatalf("Should have a leader")
}
client := rpcClient(t, leader)
defer client.Close()
leader.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
// create a TTL session
arg := structs.SessionRequest{
Datacenter: "dc1",
Op: structs.SessionCreate,
Session: structs.Session{
Node: "foo",
TTL: "10s",
},
}
var id1 string
if err := client.Call("Session.Apply", &arg, &id1); err != nil {
t.Fatalf("err: %v", err)
}
// check that leader.sessionTimers has the session id in it
// means initializeSessionTimers was called and resetSessionTimer was called
if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil {
t.Fatalf("sessionTimers not initialized and does not contain session timer for session")
}
time.Sleep(100 * time.Millisecond)
leader.Leave()
leader.Shutdown()
// leader.sessionTimers should be empty due to clearAllSessionTimers getting called
if len(leader.sessionTimers) != 0 {
t.Fatalf("session timers should be empty on the shutdown leader")
}
time.Sleep(100 * time.Millisecond)
var remain *Server
for _, s := range servers {
if s == leader {
continue
}
remain = s
testutil.WaitForResult(func() (bool, error) {
peers, _ := s.raftPeers.Peers()
return len(peers) == 2, errors.New(fmt.Sprintf("%v", peers))
}, func(err error) {
t.Fatalf("should have 2 peers: %v", err)
})
}
// Verify the old leader is deregistered
state := remain.fsm.State()
testutil.WaitForResult(func() (bool, error) {
_, found, _ := state.GetNode(leader.config.NodeName)
return !found, nil
}, func(err error) {
t.Fatalf("leader should be deregistered")
})
// Find the new leader
leader = nil
for _, s := range servers {
// find the leader too
if s.IsLeader() {
leader = s
}
}
if leader == nil {
t.Fatalf("Should have a new leader")
}
// check that new leader.sessionTimers has the session id in it
if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil {
t.Fatalf("sessionTimers not initialized and does not contain session timer for session")
}
// create another TTL session with the same parameters
var id2 string
if err := client.Call("Session.Apply", &arg, &id2); err != nil {
t.Fatalf("err: %v", err)
}
if len(leader.sessionTimers) != 2 {
t.Fatalf("sessionTimes length should be 2")
}
// destroy the via invalidateSession as if on TTL expiry
leader.invalidateSession(id2)
if len(leader.sessionTimers) != 1 {
t.Fatalf("sessionTimers length should 1")
}
// destroy the id2 session (test clearSessionTimer)
arg.Op = structs.SessionDestroy
arg.Session.ID = id2
if err := client.Call("Session.Apply", &arg, &id2); err != nil {
t.Fatalf("err: %v", err)
}
if len(leader.sessionTimers) != 0 {
t.Fatalf("sessionTimers length should be 0")
}
}

View File

@ -294,6 +294,10 @@ func (s *StateStore) initialize() error {
AllowBlank: true,
Fields: []string{"Node"},
},
"ttl": &MDBIndex{
AllowBlank: true,
Fields: []string{"TTL"},
},
},
Decoder: func(buf []byte) interface{} {
out := new(structs.Session)
@ -369,6 +373,7 @@ func (s *StateStore) initialize() error {
"KVSListKeys": MDBTables{s.kvsTable},
"SessionGet": MDBTables{s.sessionTable},
"SessionList": MDBTables{s.sessionTable},
"SessionListTTL": MDBTables{s.sessionTable},
"NodeSessions": MDBTables{s.sessionTable},
"ACLGet": MDBTables{s.aclTable},
"ACLList": MDBTables{s.aclTable},
@ -1336,6 +1341,17 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error
return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior)
}
if session.TTL != "" {
ttl, err := time.ParseDuration(session.TTL)
if err != nil {
return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err)
}
if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]", session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)
}
}
// Assign the create index
session.CreateIndex = index
@ -1445,6 +1461,16 @@ func (s *StateStore) SessionList() (uint64, []*structs.Session, error) {
return idx, out, err
}
// SessionListTTL is used to list all the open ttl sessions
func (s *StateStore) SessionListTTL() (uint64, []*structs.Session, error) {
idx, res, err := s.sessionTable.Get("ttl")
out := make([]*structs.Session, len(res))
for i, raw := range res {
out[i] = raw.(*structs.Session)
}
return idx, out, err
}
// NodeSessions is used to list all the open sessions for a node
func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, error) {
idx, res, err := s.sessionTable.Get("node", node)

View File

@ -703,13 +703,17 @@ func TestStoreSnapshot(t *testing.T) {
if ok, err := store.KVSLock(18, d); err != nil || !ok {
t.Fatalf("err: %v", err)
}
session = &structs.Session{ID: generateUUID(), Node: "baz", TTL: "60s"}
if err := store.SessionCreate(19, session); err != nil {
t.Fatalf("err: %v", err)
}
a1 := &structs.ACL{
ID: generateUUID(),
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(19, a1); err != nil {
if err := store.ACLSet(20, a1); err != nil {
t.Fatalf("err: %v", err)
}
@ -718,7 +722,7 @@ func TestStoreSnapshot(t *testing.T) {
Name: "User token",
Type: structs.ACLTypeClient,
}
if err := store.ACLSet(20, a2); err != nil {
if err := store.ACLSet(21, a2); err != nil {
t.Fatalf("err: %v", err)
}
@ -730,7 +734,7 @@ func TestStoreSnapshot(t *testing.T) {
defer snap.Close()
// Check the last nodes
if idx := snap.LastIndex(); idx != 20 {
if idx := snap.LastIndex(); idx != 21 {
t.Fatalf("bad: %v", idx)
}
@ -785,15 +789,25 @@ func TestStoreSnapshot(t *testing.T) {
t.Fatalf("missing KVS entries!")
}
// Check there are 2 sessions
// Check there are 3 sessions
sessions, err := snap.SessionList()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(sessions) != 2 {
if len(sessions) != 3 {
t.Fatalf("missing sessions")
}
ttls := 0
for _, session := range sessions {
if session.TTL != "" {
ttls++
}
}
if ttls != 1 {
t.Fatalf("Wrong number of sessions with TTL")
}
// Check for an acl
acls, err := snap.ACLList()
if err != nil {
@ -804,13 +818,13 @@ func TestStoreSnapshot(t *testing.T) {
}
// Make some changes!
if err := store.EnsureService(21, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil {
if err := store.EnsureService(22, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil {
t.Fatalf("err: %v", err)
}
if err := store.EnsureService(22, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil {
if err := store.EnsureService(23, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil {
t.Fatalf("err: %v", err)
}
if err := store.EnsureNode(23, structs.Node{"baz", "127.0.0.3"}); err != nil {
if err := store.EnsureNode(24, structs.Node{"baz", "127.0.0.3"}); err != nil {
t.Fatalf("err: %v", err)
}
checkAfter := &structs.HealthCheck{
@ -820,16 +834,16 @@ func TestStoreSnapshot(t *testing.T) {
Status: structs.HealthCritical,
ServiceID: "db",
}
if err := store.EnsureCheck(24, checkAfter); err != nil {
if err := store.EnsureCheck(26, checkAfter); err != nil {
t.Fatalf("err: %v", err)
}
if err := store.KVSDelete(25, "/web/b"); err != nil {
if err := store.KVSDelete(26, "/web/b"); err != nil {
t.Fatalf("err: %v", err)
}
// Nuke an ACL
if err := store.ACLDelete(26, a1.ID); err != nil {
if err := store.ACLDelete(27, a1.ID); err != nil {
t.Fatalf("err: %v", err)
}
@ -883,12 +897,12 @@ func TestStoreSnapshot(t *testing.T) {
t.Fatalf("missing KVS entries!")
}
// Check there are 2 sessions
// Check there are 3 sessions
sessions, err = snap.SessionList()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(sessions) != 2 {
if len(sessions) != 3 {
t.Fatalf("missing sessions")
}

View File

@ -385,6 +385,12 @@ const (
SessionKeysDelete = "delete"
)
const (
SessionTTLMin = 10 * time.Second
SessionTTLMax = 3600 * time.Second
SessionTTLMultiplier = 2
)
// Session is used to represent an open session in the KV store.
// This issued to associate node checks with acquired locks.
type Session struct {
@ -395,6 +401,7 @@ type Session struct {
Checks []string
LockDelay time.Duration
Behavior SessionBehavior // What to do when session is invalidated
TTL string
}
type Sessions []*Session