backport of commit f8dd46acb830f4ef7baef759fb4d0e3752d03e9e (#22251)

This commit is contained in:
hc-github-team-secure-vault-core 2023-08-08 17:05:43 -04:00 committed by GitHub
parent 1c89ff215b
commit 0ecf0f300e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 125 additions and 37 deletions

4
changelog/22249.txt Normal file
View File

@ -0,0 +1,4 @@
```release-note:bug
sdk/ldaputil: Properly escape user filters when using UPN domains
sdk/ldaputil: use EscapeLDAPValue implementation from cap/ldap
```

View File

@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"net"
@ -227,7 +228,11 @@ func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (stri
}
if cfg.UPNDomain != "" {
context.UserAttr = "userPrincipalName"
context.Username = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
// Intentionally, calling EscapeFilter(...) (vs EscapeValue) since the
// username is being injected into a search filter.
// As an untrusted string, the username must be escaped according to RFC
// 4515, in order to prevent attackers from injecting characters that could modify the filter
context.Username = fmt.Sprintf("%s@%s", ldap.EscapeFilter(username), cfg.UPNDomain)
}
// Execute the template. Note that the template context contains escaped input and does
@ -595,42 +600,59 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string,
}
// EscapeLDAPValue is exported because a plugin uses it outside this package.
// EscapeLDAPValue will properly escape the input string as an ldap value
// rfc4514 states the following must be escaped:
// - leading space or hash
// - trailing space
// - special characters '"', '+', ',', ';', '<', '>', '\\'
// - hex
func EscapeLDAPValue(input string) string {
if input == "" {
return ""
}
// RFC4514 forbids un-escaped:
// - leading space or hash
// - trailing space
// - special characters '"', '+', ',', ';', '<', '>', '\\'
// - null
for i := 0; i < len(input); i++ {
escaped := false
if input[i] == '\\' && i+1 < len(input)-1 {
i++
escaped = true
}
switch input[i] {
case '"', '+', ',', ';', '<', '>', '\\':
if !escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
buf := bytes.Buffer{}
escFn := func(c byte) {
buf.WriteByte('\\')
buf.WriteByte(c)
}
inputLen := len(input)
for i := 0; i < inputLen; i++ {
char := input[i]
switch {
case i == 0 && char == ' ' || char == '#':
// leading space or hash.
escFn(char)
continue
}
if escaped {
input = input[0:i] + "\\" + input[i:]
i++
case i == inputLen-1 && char == ' ':
// trailing space.
escFn(char)
continue
case specialChar(char):
escFn(char)
continue
case char < ' ' || char > '~':
// anything that's not between the ascii space and tilde must be hex
buf.WriteByte('\\')
buf.WriteString(hex.EncodeToString([]byte{char}))
continue
default:
// everything remaining, doesn't need to be escaped
buf.WriteByte(char)
}
}
if input[0] == ' ' || input[0] == '#' {
input = "\\" + input
return buf.String()
}
func specialChar(char byte) bool {
switch char {
case '"', '+', ',', ';', '<', '>', '\\':
return true
default:
return false
}
if input[len(input)-1] == ' ' {
input = input[0:len(input)-1] + "\\ "
}
return input
}
/*

View File

@ -7,6 +7,8 @@ import (
"testing"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDialLDAP duplicates a potential panic that was
@ -29,15 +31,20 @@ func TestDialLDAP(t *testing.T) {
func TestLDAPEscape(t *testing.T) {
testcases := map[string]string{
"#test": "\\#test",
"test,hello": "test\\,hello",
"test,hel+lo": "test\\,hel\\+lo",
"test\\hello": "test\\\\hello",
" test ": "\\ test \\ ",
"": "",
"\\test": "\\\\test",
"test\\": "test\\\\",
"test\\ ": "test\\\\\\ ",
"#test": "\\#test",
"test,hello": "test\\,hello",
"test,hel+lo": "test\\,hel\\+lo",
"test\\hello": "test\\\\hello",
" test ": "\\ test \\ ",
"": "",
`\`: `\\`,
"trailing\000": `trailing\00`,
"mid\000dle": `mid\00dle`,
"\000": `\00`,
"multiple\000\000": `multiple\00\00`,
"backlash-before-null\\\000": `backlash-before-null\\\00`,
"trailing\\": `trailing\\`,
"double-escaping\\>": `double-escaping\\\>`,
}
for test, answer := range testcases {
@ -88,3 +95,58 @@ func TestSIDBytesToString(t *testing.T) {
}
}
}
func TestClient_renderUserSearchFilter(t *testing.T) {
t.Parallel()
tests := []struct {
name string
conf *ConfigEntry
username string
want string
errContains string
}{
{
name: "valid-default",
username: "alice",
conf: &ConfigEntry{
UserAttr: "cn",
},
want: "(cn=alice)",
},
{
name: "escaped-malicious-filter",
username: "foo@example.com)((((((((((((((((((((((((((((((((((((((userPrincipalName=foo",
conf: &ConfigEntry{
UPNDomain: "example.com",
UserFilter: "(&({{.UserAttr}}={{.Username}})({{.UserAttr}}=admin@example.com))",
},
want: "(&(userPrincipalName=foo@example.com\\29\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28userPrincipalName=foo@example.com)(userPrincipalName=admin@example.com))",
},
{
name: "bad-filter-unclosed-action",
username: "alice",
conf: &ConfigEntry{
UserFilter: "hello{{range",
},
errContains: "search failed due to template compilation error",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
c := Client{
Logger: hclog.NewNullLogger(),
LDAP: NewLDAP(),
}
f, err := c.RenderUserSearchFilter(tc.conf, tc.username)
if tc.errContains != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tc.errContains)
return
}
require.NoError(t, err)
assert.NotEmpty(t, f)
assert.Equal(t, tc.want, f)
})
}
}