open-vault/vendor/github.com/centrify/cloud-golang-sdk/oauth/oauth.go
Jeff Mitchell 98b479ab58 Bump deps
2018-01-26 18:51:00 -05:00

183 lines
4.7 KiB
Go

package oauth
import (
"encoding/base64"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
)
// HttpClientFactory returns the *http.Client an OauthClient should use.
// When a non-nil factory is passed to GetNewClient it is called once, and the
// returned client gets the OauthClient's cookie jar installed on it.
type HttpClientFactory func() *http.Client

// TokenResponse represents a successful (HTTP 200) token endpoint reply.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"` // NOTE(review): presumably seconds per RFC 6749 — confirm with server
	RefreshToken string `json:"refresh_token"`
}

// ErrorResponse represents the JSON body of a token endpoint reply whose HTTP
// status was not 200.
type ErrorResponse struct {
	Error       string `json:"error"`
	Description string `json:"error_description"`
}

// OauthClient represents a stateful Oauth client bound to a single service.
type OauthClient struct {
	// Service is the base URL that request paths are appended to.
	Service string
	// Client performs the HTTP requests.
	Client *http.Client
	// Headers are extra headers added to every request.
	Headers map[string]string
	// ClientID and ClientSecret, when both are non-empty, are sent as an HTTP
	// Basic Authorization header.
	ClientID     string
	ClientSecret string
	// SourceHeader is sent as the X-CFY-SRC request header.
	SourceHeader string
}
// GetNewClient creates a new client for the specified endpoint
func GetNewClient(service string, httpFactory HttpClientFactory) (*OauthClient, error) {
jar, err := cookiejar.New(nil)
if err != nil {
return nil, err
}
// Munge on the service a little bit, force it to have no trailing / and always start with https://
url, err := url.Parse(service)
if err != nil {
return nil, err
}
url.Scheme = "https"
url.Path = ""
client := &OauthClient{}
client.Service = url.String()
if httpFactory != nil {
client.Client = httpFactory()
} else {
client.Client = &http.Client{}
}
client.Client.Jar = jar
client.Headers = make(map[string]string)
client.SourceHeader = "cloud-golang-sdk"
return client, err
}
// GetNewConfidentialClient builds a client for service exactly as
// GetNewClient does and then records clientID and clientSecret on it; when
// both are non-empty they are sent as HTTP Basic credentials on each request.
// On failure it returns a nil client and GetNewClient's error.
func GetNewConfidentialClient(service string, clientID string, clientSecret string, httpFactory HttpClientFactory) (*OauthClient, error) {
	oc, setupErr := GetNewClient(service, httpFactory)
	if setupErr == nil {
		oc.ClientID, oc.ClientSecret = clientID, clientSecret
		return oc, nil
	}
	return nil, setupErr
}
// tokenEndpoint returns the token endpoint path for appID. The ID is escaped
// as a single path segment so characters such as '/', '?' or ' ' cannot change
// which URL the request is sent to.
func tokenEndpoint(appID string) string {
	return "/oauth2/token/" + url.PathEscape(appID)
}

// ResourceOwner implements the ResourceOwner flow: it requests a token for
// appID with grant_type=password, authenticating as owner/ownerPassword and
// asking for scope.
//
// It returns a TokenResponse when the server replies with HTTP 200, an
// ErrorResponse for any other status, or an error if the request fails or the
// reply cannot be decoded.
func (c *OauthClient) ResourceOwner(appID string, scope string, owner string, ownerPassword string) (*TokenResponse, *ErrorResponse, error) {
	return c.postAndGetResponse(tokenEndpoint(appID), map[string]string{
		"grant_type": "password",
		"username":   owner,
		"password":   ownerPassword,
		"scope":      scope,
	})
}

// ClientCredentials requests a token for appID with
// grant_type=client_credentials and the given scope. Results are reported the
// same way as for ResourceOwner.
func (c *OauthClient) ClientCredentials(appID string, scope string) (*TokenResponse, *ErrorResponse, error) {
	return c.postAndGetResponse(tokenEndpoint(appID), map[string]string{
		"grant_type": "client_credentials",
		"scope":      scope,
	})
}

// RefreshToken exchanges refreshToken for a new token for appID using
// grant_type=refresh_token. Results are reported the same way as for
// ResourceOwner.
func (c *OauthClient) RefreshToken(appID string, refreshToken string) (*TokenResponse, *ErrorResponse, error) {
	return c.postAndGetResponse(tokenEndpoint(appID), map[string]string{
		"grant_type":    "refresh_token",
		"refresh_token": refreshToken,
	})
}
// postAndGetResponse POSTs args to the path `method` (relative to c.Service)
// and decodes the reply. A body received with HTTP 200 is decoded as a
// TokenResponse; a body with any other status is decoded as an ErrorResponse.
// The error result is reserved for transport failures and bodies that cannot
// be decoded, in which case both response pointers are nil.
func (c *OauthClient) postAndGetResponse(method string, args map[string]string) (*TokenResponse, *ErrorResponse, error) {
	raw, code, reqErr := c.postAndGetBody(method, args)
	switch {
	case reqErr != nil:
		return nil, nil, reqErr
	case code != http.StatusOK:
		failure, decodeErr := bodyToErrorResponse(raw)
		if decodeErr != nil {
			return nil, nil, decodeErr
		}
		return nil, failure, nil
	}

	token, decodeErr := bodyToTokenResponse(raw)
	if decodeErr != nil {
		return nil, nil, decodeErr
	}
	return token, nil, nil
}
// postAndGetBody sends args as an application/x-www-form-urlencoded POST to
// c.Service+method and returns the raw response body and HTTP status code.
//
// Every request carries X-CENTRIFY-NATIVE-CLIENT: Yes, X-CFY-SRC set to
// c.SourceHeader, and each entry of c.Headers (added after the fixed headers).
// When both c.ClientID and c.ClientSecret are set, a Basic Authorization
// header is added as well. If the request cannot be sent, the status is 0. If
// the body cannot be read, the status is still returned alongside the error.
func (c *OauthClient) postAndGetBody(method string, args map[string]string) ([]byte, int, error) {
	req, buildErr := http.NewRequest("POST", c.Service+method, strings.NewReader(payloadFromMap(args)))
	if buildErr != nil {
		return nil, 0, buildErr
	}

	hdr := req.Header
	if c.ClientID != "" && c.ClientSecret != "" {
		creds := base64.StdEncoding.EncodeToString([]byte(c.ClientID + ":" + c.ClientSecret))
		hdr.Add("Authorization", "Basic "+creds)
	}
	hdr.Add("Content-Type", "application/x-www-form-urlencoded")
	hdr.Add("X-CENTRIFY-NATIVE-CLIENT", "Yes")
	hdr.Add("X-CFY-SRC", c.SourceHeader)
	for name, value := range c.Headers {
		hdr.Add(name, value)
	}

	resp, sendErr := c.Client.Do(req)
	if sendErr != nil {
		return nil, 0, sendErr
	}
	defer resp.Body.Close()

	code := resp.StatusCode
	payload, readErr := ioutil.ReadAll(resp.Body)
	if readErr != nil {
		return nil, code, readErr
	}
	return payload, code, nil
}
// payloadFromMap form-urlencodes input (url.Values.Encode, so keys come out
// sorted) for use as a POST body. A nil or empty map yields "".
func payloadFromMap(input map[string]string) string {
	form := make(url.Values, len(input))
	for key, value := range input {
		// Map keys are unique, so Set is equivalent to Add here.
		form.Set(key, value)
	}
	return form.Encode()
}
// bodyToTokenResponse decodes a JSON token endpoint reply into a
// TokenResponse, returning the json.Unmarshal error if body is not valid JSON
// for that type.
//
// The struct pointer itself is the decode target (not a pointer to it): with
// &reply, a literal JSON null set reply to nil, so the function returned
// (nil, nil) and callers saw neither a token nor an error.
func bodyToTokenResponse(body []byte) (*TokenResponse, error) {
	reply := &TokenResponse{}
	if err := json.Unmarshal(body, reply); err != nil {
		return nil, err
	}
	return reply, nil
}
// bodyToErrorResponse decodes a JSON error reply into an ErrorResponse,
// returning the json.Unmarshal error if body is not valid JSON for that type.
//
// The struct pointer itself is the decode target (not a pointer to it): with
// &reply, a literal JSON null set reply to nil, so a failed request could be
// reported to callers as (nil token, nil error response, nil error).
func bodyToErrorResponse(body []byte) (*ErrorResponse, error) {
	reply := &ErrorResponse{}
	if err := json.Unmarshal(body, reply); err != nil {
		return nil, err
	}
	return reply, nil
}