Only generate default workload identity once per alloc task - 1.6.x (#18783)

this can save a bit of cpu when
running plans for tasks that already exist,
and prevents Nomad tokens from changing,
which can cause task template{}s to restart
unnecessarily.
This commit is contained in:
Daniel Bennett 2023-10-17 13:06:20 -05:00 committed by GitHub
parent 657c430e0b
commit b6298dc073
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 104 additions and 17 deletions

View File

@ -55,13 +55,6 @@ func (h *identityHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar
return h.setToken()
}
func (h *identityHook) Update(_ context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
h.lock.Lock()
defer h.lock.Unlock()
return h.setToken()
}
// setToken adds the Nomad token to the task's environment and writes it to a
// file if requested by the jobsepc.
func (h *identityHook) setToken() error {

View File

@ -6,6 +6,5 @@ package taskrunner
import "github.com/hashicorp/nomad/client/allocrunner/interfaces"
var _ interfaces.TaskPrestartHook = (*identityHook)(nil)
var _ interfaces.TaskUpdateHook = (*identityHook)(nil)
// See task_runner_test.go:TestTaskRunner_IdentityHook

View File

@ -30,6 +30,12 @@ import (
const nomadKeystoreExtension = ".nks.json"
type claimSigner interface {
SignClaims(*structs.IdentityClaims) (string, string, error)
}
var _ claimSigner = &Encrypter{}
// Encrypter is the keyring for encrypting variables and signing workload
// identities.
type Encrypter struct {

View File

@ -21,6 +21,18 @@ import (
"github.com/hashicorp/nomad/testutil"
)
type mockSigner struct {
calls []*structs.IdentityClaims
nextToken, nextKeyID string
nextErr error
}
func (s *mockSigner) SignClaims(c *structs.IdentityClaims) (token, keyID string, err error) {
s.calls = append(s.calls, c)
return s.nextToken, s.nextKeyID, s.nextErr
}
// TestEncrypter_LoadSave exercises round-tripping keys to disk
func TestEncrypter_LoadSave(t *testing.T) {
ci.Parallel(t)

View File

@ -279,7 +279,7 @@ func (p *planner) applyPlan(plan *structs.Plan, result *structs.PlanResult, snap
// to approximate the scheduling time.
updateAllocTimestamps(req.AllocsUpdated, now)
err := p.signAllocIdentities(plan.Job, req.AllocsUpdated)
err := signAllocIdentities(p.Server.encrypter, plan.Job, req.AllocsUpdated)
if err != nil {
return nil, err
}
@ -409,16 +409,19 @@ func updateAllocTimestamps(allocations []*structs.Allocation, timestamp int64) {
}
}
func (p *planner) signAllocIdentities(job *structs.Job, allocations []*structs.Allocation) error {
encrypter := p.Server.encrypter
func signAllocIdentities(signer claimSigner, job *structs.Job, allocations []*structs.Allocation) error {
for _, alloc := range allocations {
alloc.SignedIdentities = map[string]string{}
if alloc.SignedIdentities == nil {
alloc.SignedIdentities = map[string]string{}
}
tg := job.LookupTaskGroup(alloc.TaskGroup)
for _, task := range tg.Tasks {
// skip tasks that already have an identity
if _, ok := alloc.SignedIdentities[task.Name]; ok {
continue
}
claims := alloc.ToTaskIdentityClaims(job, task.Name)
token, keyID, err := encrypter.SignClaims(claims)
token, keyID, err := signer.SignClaims(claims)
if err != nil {
return err
}

View File

@ -4,6 +4,7 @@
package nomad
import (
"errors"
"reflect"
"testing"
"time"
@ -16,6 +17,7 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/hashicorp/raft"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -393,6 +395,78 @@ func TestPlanApply_applyPlanWithNormalizedAllocs(t *testing.T) {
assert.Equal(index, evalOut.ModifyIndex)
}
func TestPlanApply_signAllocIdentities(t *testing.T) {
// note: this is mutated by the method under test
alloc := mockAlloc()
job := alloc.Job
taskName := job.TaskGroups[0].Tasks[0].Name // "web"
allocs := []*structs.Allocation{alloc}
signErr := errors.New("could not sign the thing")
cases := []struct {
name string
signer *mockSigner
expectErr error
callNum int
}{
{
name: "signer error",
signer: &mockSigner{
nextErr: signErr,
},
expectErr: signErr,
callNum: 1,
},
{
name: "first signing",
signer: &mockSigner{
nextToken: "first-token",
nextKeyID: "first-key",
},
callNum: 1,
},
{
name: "second signing",
signer: &mockSigner{
nextToken: "dont-sign-token",
nextKeyID: "dont-sign-key",
},
callNum: 0,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := signAllocIdentities(tc.signer, job, allocs)
if tc.expectErr != nil {
must.Error(t, err)
must.ErrorIs(t, err, tc.expectErr)
} else {
must.NoError(t, err)
// assert mutations happened
must.MapLen(t, 1, allocs[0].SignedIdentities)
// we should always keep the first signing
must.Eq(t, "first-token", allocs[0].SignedIdentities[taskName])
must.Eq(t, "first-key", allocs[0].SigningKeyID)
}
must.Len(t, tc.callNum, tc.signer.calls, must.Sprint("unexpected call count"))
if tc.callNum > 0 {
call := tc.signer.calls[tc.callNum-1]
must.NotNil(t, call)
must.Eq(t, call.AllocationID, alloc.ID)
must.Eq(t, call.Namespace, alloc.Namespace)
must.Eq(t, call.JobID, job.ID)
must.Eq(t, call.TaskName, taskName)
}
})
}
}
func TestPlanApply_EvalPlan_Simple(t *testing.T) {
ci.Parallel(t)
state := testStateStore(t)

View File

@ -888,7 +888,7 @@ func TestServiceRegistration_List(t *testing.T) {
job.Namespace = "platform"
allocs[0].Namespace = "platform"
require.NoError(t, s.State().UpsertJob(structs.MsgTypeTestSetup, 10, nil, job))
s.signAllocIdentities(job, allocs)
signAllocIdentities(s.encrypter, job, allocs)
require.NoError(t, s.State().UpsertAllocs(structs.MsgTypeTestSetup, 15, allocs))
signedToken := allocs[0].SignedIdentities["web"]
@ -1175,7 +1175,7 @@ func TestServiceRegistration_GetService(t *testing.T) {
allocs := []*structs.Allocation{mock.Alloc()}
job := allocs[0].Job
require.NoError(t, s.State().UpsertJob(structs.MsgTypeTestSetup, 10, nil, job))
s.signAllocIdentities(job, allocs)
signAllocIdentities(s.encrypter, job, allocs)
require.NoError(t, s.State().UpsertAllocs(structs.MsgTypeTestSetup, 15, allocs))
signedToken := allocs[0].SignedIdentities["web"]