diff --git a/consul/fsm.go b/consul/fsm.go index 8b4fd3d65..4c82357e6 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -69,6 +69,8 @@ func (c *consulFSM) Apply(log *raft.Log) interface{} { return c.applyKVSOperation(buf[1:], log.Index) case structs.SessionRequestType: return c.applySessionOperation(buf[1:], log.Index) + case structs.ACLRequestType: + return c.applyACLOperation(buf[1:], log.Index) default: panic(fmt.Errorf("failed to apply request: %#v", buf)) } @@ -196,6 +198,27 @@ func (c *consulFSM) applySessionOperation(buf []byte, index uint64) interface{} return nil } +func (c *consulFSM) applyACLOperation(buf []byte, index uint64) interface{} { + var req structs.ACLRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + switch req.Op { + case structs.ACLSet: + if err := c.state.ACLSet(index, &req.ACL); err != nil { + return err + } else { + return req.ACL.ID + } + case structs.ACLDelete: + return c.state.ACLDelete(index, req.ACL.ID) + default: + c.logger.Printf("[WARN] consul.fsm: Invalid ACL operation '%s'", req.Op) + return fmt.Errorf("Invalid ACL operation '%s'", req.Op) + } + return nil +} + func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { defer func(start time.Time) { c.logger.Printf("[INFO] consul.fsm: snapshot created in %v", time.Now().Sub(start)) @@ -267,6 +290,15 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { return err } + case structs.ACLRequestType: + var req structs.ACL + if err := dec.Decode(&req); err != nil { + return err + } + if err := c.state.ACLRestore(&req); err != nil { + return err + } + default: return fmt.Errorf("Unrecognized msg type: %v", msgType) } @@ -298,6 +330,11 @@ func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { return err } + if err := s.persistACLs(sink, encoder); err != nil { + sink.Cancel() + return err + } + if err := s.persistKV(sink, encoder); err != nil { sink.Cancel() return err @@ -364,6 +401,22 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, return nil } +func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + acls, err := s.state.ACLList() + if err != nil { + return err + } + + for _, s := range acls { + sink.Write([]byte{byte(structs.ACLRequestType)}) + if err := encoder.Encode(s); err != nil { + return err + } + } + return nil +} + func (s *consulSnapshot) persistKV(sink raft.SnapshotSink, encoder *codec.Encoder) error { streamCh := make(chan interface{}, 256) diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 5e5d086d8..f47c00653 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -328,6 +328,8 @@ func TestFSM_SnapshotRestore(t *testing.T) { }) session := &structs.Session{Node: "foo"} fsm.state.SessionCreate(9, session) + acl := &structs.ACL{Name: "User Token"} + fsm.state.ACLSet(10, acl) // Snapshot snap, err := fsm.Snapshot() @@ -392,7 +394,16 @@ func TestFSM_SnapshotRestore(t *testing.T) { t.Fatalf("err: %v", err) } if s.Node != "foo" { - t.Fatalf("bad: %v", d) + t.Fatalf("bad: %v", s) + } + + // Verify ACL is restored + _, a, err := fsm.state.ACLGet(acl.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if a.Name != "User Token" { + t.Fatalf("bad: %v", a) } } @@ -767,3 +778,75 @@ func TestFSM_KVSUnlock(t *testing.T) { t.Fatalf("bad: %v", *d) } } + +func TestFSM_ACL_Set_Delete(t *testing.T) { + fsm, err := NewFSM(os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + defer fsm.Close() + + // Create a new ACL + req := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLSet, + ACL: structs.ACL{ + Name: "User token", + Type: structs.ACLTypeClient, + }, + } + buf, err := structs.Encode(structs.ACLRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := fsm.Apply(makeLog(buf)) + if err, ok := resp.(error); ok { + t.Fatalf("resp: %v", err) + } + + // Get the ACL + id := resp.(string) + _, acl, err := fsm.state.ACLGet(id) + if err != nil { + t.Fatalf("err: %v", err) + } + if acl == nil { + t.Fatalf("missing") + } + + // Verify the ACL + if acl.ID != id { + t.Fatalf("bad: %v", *acl) + } + if acl.Name != "User token" { + t.Fatalf("bad: %v", *acl) + } + if acl.Type != structs.ACLTypeClient { + t.Fatalf("bad: %v", *acl) + } + + // Try to destroy + destroy := structs.ACLRequest{ + Datacenter: "dc1", + Op: structs.ACLDelete, + ACL: structs.ACL{ + ID: id, + }, + } + buf, err = structs.Encode(structs.ACLRequestType, destroy) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + _, acl, err = fsm.state.ACLGet(id) + if err != nil { + t.Fatalf("err: %v", err) + } + if acl != nil { + t.Fatalf("should be destroyed") + } +}