27bb03bbc0
* adding copyright header * fix fmt and a test
242 lines
5.9 KiB
Go
242 lines
5.9 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package testing
|
|
|
|
import (
|
|
_ "embed"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"go.uber.org/atomic"
|
|
)
|
|
|
|
const (
|
|
ExpectedNamespace = "default"
|
|
ExpectedPodName = "shell-demo"
|
|
)
|
|
|
|
// Pull real-life-based testing data in from files at compile time.
|
|
// We decided to embed them in the test binary because of past issues
|
|
// with reading files that we encountered on CI workers.
|
|
|
|
//go:embed ca.crt
|
|
var caCrt string
|
|
|
|
//go:embed resp-get-pod.json
|
|
var getPodResponse string
|
|
|
|
//go:embed resp-not-found.json
|
|
var notFoundResponse string
|
|
|
|
//go:embed resp-update-pod.json
|
|
var updatePodTagsResponse string
|
|
|
|
//go:embed token
|
|
var token string
|
|
|
|
var (
|
|
// ReturnGatewayTimeouts toggles whether the test server should return,
|
|
// well, gateway timeouts...
|
|
ReturnGatewayTimeouts = atomic.NewBool(false)
|
|
|
|
pathToFiles = func() string {
|
|
wd, _ := os.Getwd()
|
|
repoName := "vault-enterprise"
|
|
if !strings.Contains(wd, repoName) {
|
|
repoName = "vault"
|
|
}
|
|
pathParts := strings.Split(wd, repoName)
|
|
return pathParts[0] + "vault/serviceregistration/kubernetes/testing/"
|
|
}()
|
|
)
|
|
|
|
// Conf returns the info needed to configure the client to point at
|
|
// the test server. This must be done by the caller to avoid an import
|
|
// cycle between the client and the testserver. Example usage:
|
|
//
|
|
// client.Scheme = testConf.ClientScheme
|
|
// client.TokenFile = testConf.PathToTokenFile
|
|
// client.RootCAFile = testConf.PathToRootCAFile
|
|
// if err := os.Setenv(client.EnvVarKubernetesServiceHost, testConf.ServiceHost); err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
// if err := os.Setenv(client.EnvVarKubernetesServicePort, testConf.ServicePort); err != nil {
|
|
// t.Fatal(err)
|
|
// }
|
|
type Conf struct {
|
|
ClientScheme, PathToTokenFile, PathToRootCAFile, ServiceHost, ServicePort string
|
|
}
|
|
|
|
// Server returns an http test server that can be used to test
|
|
// Kubernetes client code. It also retains the current state,
|
|
// and a func to close the server and to clean up any temporary
|
|
// files.
|
|
func Server(t *testing.T) (testState *State, testConf *Conf, closeFunc func()) {
|
|
testState = &State{m: &sync.Map{}}
|
|
testConf = &Conf{
|
|
ClientScheme: "http://",
|
|
}
|
|
|
|
// We're going to have multiple close funcs to call.
|
|
var closers []func()
|
|
closeFunc = func() {
|
|
for _, closer := range closers {
|
|
closer()
|
|
}
|
|
}
|
|
|
|
// Plant our token in a place where it can be read for the config.
|
|
tmpToken, err := ioutil.TempFile("", "token")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
closers = append(closers, func() {
|
|
os.Remove(tmpToken.Name())
|
|
})
|
|
if _, err = tmpToken.WriteString(token); err != nil {
|
|
closeFunc()
|
|
t.Fatal(err)
|
|
}
|
|
if err := tmpToken.Close(); err != nil {
|
|
closeFunc()
|
|
t.Fatal(err)
|
|
}
|
|
testConf.PathToTokenFile = tmpToken.Name()
|
|
|
|
tmpCACrt, err := ioutil.TempFile("", "ca.crt")
|
|
if err != nil {
|
|
closeFunc()
|
|
t.Fatal(err)
|
|
}
|
|
closers = append(closers, func() {
|
|
os.Remove(tmpCACrt.Name())
|
|
})
|
|
if _, err = tmpCACrt.WriteString(caCrt); err != nil {
|
|
closeFunc()
|
|
t.Fatal(err)
|
|
}
|
|
if err := tmpCACrt.Close(); err != nil {
|
|
closeFunc()
|
|
t.Fatal(err)
|
|
}
|
|
testConf.PathToRootCAFile = tmpCACrt.Name()
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if ReturnGatewayTimeouts.Load() {
|
|
w.WriteHeader(504)
|
|
return
|
|
}
|
|
namespace, podName, err := parsePath(r.URL.Path)
|
|
if err != nil {
|
|
w.WriteHeader(400)
|
|
w.Write([]byte(fmt.Sprintf("unable to parse %s: %s", r.URL.Path, err.Error())))
|
|
return
|
|
}
|
|
|
|
switch {
|
|
case namespace != ExpectedNamespace, podName != ExpectedPodName:
|
|
w.WriteHeader(404)
|
|
w.Write([]byte(notFoundResponse))
|
|
return
|
|
case r.Method == http.MethodGet:
|
|
w.WriteHeader(200)
|
|
w.Write([]byte(getPodResponse))
|
|
return
|
|
case r.Method == http.MethodPatch:
|
|
var patches []interface{}
|
|
if err := json.NewDecoder(r.Body).Decode(&patches); err != nil {
|
|
w.WriteHeader(400)
|
|
w.Write([]byte(fmt.Sprintf("unable to decode patches %s: %s", r.URL.Path, err.Error())))
|
|
return
|
|
}
|
|
for _, patch := range patches {
|
|
patchMap := patch.(map[string]interface{})
|
|
p := patchMap["path"].(string)
|
|
testState.store(p, patchMap)
|
|
}
|
|
w.WriteHeader(200)
|
|
w.Write([]byte(updatePodTagsResponse))
|
|
return
|
|
default:
|
|
w.WriteHeader(400)
|
|
w.Write([]byte(fmt.Sprintf("unexpected request method: %s", r.Method)))
|
|
}
|
|
}))
|
|
closers = append(closers, ts.Close)
|
|
|
|
// ts.URL example: http://127.0.0.1:35681
|
|
urlFields := strings.Split(ts.URL, "://")
|
|
if len(urlFields) != 2 {
|
|
closeFunc()
|
|
t.Fatal("received unexpected test url: " + ts.URL)
|
|
}
|
|
urlFields = strings.Split(urlFields[1], ":")
|
|
if len(urlFields) != 2 {
|
|
closeFunc()
|
|
t.Fatal("received unexpected test url: " + ts.URL)
|
|
}
|
|
testConf.ServiceHost = urlFields[0]
|
|
testConf.ServicePort = urlFields[1]
|
|
return testState, testConf, closeFunc
|
|
}
|
|
|
|
type State struct {
|
|
m *sync.Map
|
|
}
|
|
|
|
func (s *State) NumPatches() int {
|
|
l := 0
|
|
f := func(key, value interface{}) bool {
|
|
l++
|
|
return true
|
|
}
|
|
s.m.Range(f)
|
|
return l
|
|
}
|
|
|
|
func (s *State) Get(key string) map[string]interface{} {
|
|
v, ok := s.m.Load(key)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
patch, ok := v.(map[string]interface{})
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return patch
|
|
}
|
|
|
|
func (s *State) store(k string, p map[string]interface{}) {
|
|
s.m.Store(k, p)
|
|
}
|
|
|
|
// The path should be formatted like this:
|
|
// fmt.Sprintf("/api/v1/namespaces/%s/pods/%s", namespace, podName)
|
|
func parsePath(urlPath string) (namespace, podName string, err error) {
|
|
original := urlPath
|
|
podName = path.Base(urlPath)
|
|
urlPath = strings.TrimSuffix(urlPath, "/pods/"+podName)
|
|
namespace = path.Base(urlPath)
|
|
if original != fmt.Sprintf("/api/v1/namespaces/%s/pods/%s", namespace, podName) {
|
|
return "", "", fmt.Errorf("received unexpected path: %s", original)
|
|
}
|
|
return namespace, podName, nil
|
|
}
|
|
|
|
func readFile(fileName string) (string, error) {
|
|
b, err := ioutil.ReadFile(pathToFiles + fileName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(b), nil
|
|
}
|