backport of commit f8dd46acb830f4ef7baef759fb4d0e3752d03e9e (#22251)
This commit is contained in:
parent
1c89ff215b
commit
0ecf0f300e
|
@ -0,0 +1,4 @@
|
|||
```release-note:bug
|
||||
sdk/ldaputil: Properly escape user filters when using UPN domains
|
||||
sdk/ldaputil: use EscapeLDAPValue implementation from cap/ldap
|
||||
```
|
|
@ -8,6 +8,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
|
@ -227,7 +228,11 @@ func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (stri
|
|||
}
|
||||
if cfg.UPNDomain != "" {
|
||||
context.UserAttr = "userPrincipalName"
|
||||
context.Username = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
|
||||
// Intentionally, calling EscapeFilter(...) (vs EscapeValue) since the
|
||||
// username is being injected into a search filter.
|
||||
// As an untrusted string, the username must be escaped according to RFC
|
||||
// 4515, in order to prevent attackers from injecting characters that could modify the filter
|
||||
context.Username = fmt.Sprintf("%s@%s", ldap.EscapeFilter(username), cfg.UPNDomain)
|
||||
}
|
||||
|
||||
// Execute the template. Note that the template context contains escaped input and does
|
||||
|
@ -595,42 +600,59 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string,
|
|||
}
|
||||
|
||||
// EscapeLDAPValue is exported because a plugin uses it outside this package.
|
||||
// EscapeLDAPValue will properly escape the input string as an ldap value
|
||||
// rfc4514 states the following must be escaped:
|
||||
// - leading space or hash
|
||||
// - trailing space
|
||||
// - special characters '"', '+', ',', ';', '<', '>', '\\'
|
||||
// - hex
|
||||
func EscapeLDAPValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RFC4514 forbids un-escaped:
|
||||
// - leading space or hash
|
||||
// - trailing space
|
||||
// - special characters '"', '+', ',', ';', '<', '>', '\\'
|
||||
// - null
|
||||
for i := 0; i < len(input); i++ {
|
||||
escaped := false
|
||||
if input[i] == '\\' && i+1 < len(input)-1 {
|
||||
i++
|
||||
escaped = true
|
||||
}
|
||||
switch input[i] {
|
||||
case '"', '+', ',', ';', '<', '>', '\\':
|
||||
if !escaped {
|
||||
input = input[0:i] + "\\" + input[i:]
|
||||
i++
|
||||
}
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
escFn := func(c byte) {
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte(c)
|
||||
}
|
||||
|
||||
inputLen := len(input)
|
||||
for i := 0; i < inputLen; i++ {
|
||||
char := input[i]
|
||||
switch {
|
||||
case i == 0 && char == ' ' || char == '#':
|
||||
// leading space or hash.
|
||||
escFn(char)
|
||||
continue
|
||||
}
|
||||
if escaped {
|
||||
input = input[0:i] + "\\" + input[i:]
|
||||
i++
|
||||
case i == inputLen-1 && char == ' ':
|
||||
// trailing space.
|
||||
escFn(char)
|
||||
continue
|
||||
case specialChar(char):
|
||||
escFn(char)
|
||||
continue
|
||||
case char < ' ' || char > '~':
|
||||
// anything that's not between the ascii space and tilde must be hex
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteString(hex.EncodeToString([]byte{char}))
|
||||
continue
|
||||
default:
|
||||
// everything remaining, doesn't need to be escaped
|
||||
buf.WriteByte(char)
|
||||
}
|
||||
}
|
||||
if input[0] == ' ' || input[0] == '#' {
|
||||
input = "\\" + input
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func specialChar(char byte) bool {
|
||||
switch char {
|
||||
case '"', '+', ',', ';', '<', '>', '\\':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
if input[len(input)-1] == ' ' {
|
||||
input = input[0:len(input)-1] + "\\ "
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDialLDAP duplicates a potential panic that was
|
||||
|
@ -29,15 +31,20 @@ func TestDialLDAP(t *testing.T) {
|
|||
|
||||
func TestLDAPEscape(t *testing.T) {
|
||||
testcases := map[string]string{
|
||||
"#test": "\\#test",
|
||||
"test,hello": "test\\,hello",
|
||||
"test,hel+lo": "test\\,hel\\+lo",
|
||||
"test\\hello": "test\\\\hello",
|
||||
" test ": "\\ test \\ ",
|
||||
"": "",
|
||||
"\\test": "\\\\test",
|
||||
"test\\": "test\\\\",
|
||||
"test\\ ": "test\\\\\\ ",
|
||||
"#test": "\\#test",
|
||||
"test,hello": "test\\,hello",
|
||||
"test,hel+lo": "test\\,hel\\+lo",
|
||||
"test\\hello": "test\\\\hello",
|
||||
" test ": "\\ test \\ ",
|
||||
"": "",
|
||||
`\`: `\\`,
|
||||
"trailing\000": `trailing\00`,
|
||||
"mid\000dle": `mid\00dle`,
|
||||
"\000": `\00`,
|
||||
"multiple\000\000": `multiple\00\00`,
|
||||
"backlash-before-null\\\000": `backlash-before-null\\\00`,
|
||||
"trailing\\": `trailing\\`,
|
||||
"double-escaping\\>": `double-escaping\\\>`,
|
||||
}
|
||||
|
||||
for test, answer := range testcases {
|
||||
|
@ -88,3 +95,58 @@ func TestSIDBytesToString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_renderUserSearchFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
conf *ConfigEntry
|
||||
username string
|
||||
want string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid-default",
|
||||
username: "alice",
|
||||
conf: &ConfigEntry{
|
||||
UserAttr: "cn",
|
||||
},
|
||||
want: "(cn=alice)",
|
||||
},
|
||||
{
|
||||
name: "escaped-malicious-filter",
|
||||
username: "foo@example.com)((((((((((((((((((((((((((((((((((((((userPrincipalName=foo",
|
||||
conf: &ConfigEntry{
|
||||
UPNDomain: "example.com",
|
||||
UserFilter: "(&({{.UserAttr}}={{.Username}})({{.UserAttr}}=admin@example.com))",
|
||||
},
|
||||
want: "(&(userPrincipalName=foo@example.com\\29\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28userPrincipalName=foo@example.com)(userPrincipalName=admin@example.com))",
|
||||
},
|
||||
{
|
||||
name: "bad-filter-unclosed-action",
|
||||
username: "alice",
|
||||
conf: &ConfigEntry{
|
||||
UserFilter: "hello{{range",
|
||||
},
|
||||
errContains: "search failed due to template compilation error",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := Client{
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LDAP: NewLDAP(),
|
||||
}
|
||||
|
||||
f, err := c.RenderUserSearchFilter(tc.conf, tc.username)
|
||||
if tc.errContains != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tc.errContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, f)
|
||||
assert.Equal(t, tc.want, f)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue