open-vault/serviceregistration/kubernetes/testing/testserver.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

242 lines
5.9 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package testing
import (
_ "embed"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path"
"strings"
"sync"
"testing"
"go.uber.org/atomic"
)
// ExpectedNamespace is the only namespace the test server recognizes;
// requests naming any other namespace receive a 404.
const ExpectedNamespace = "default"

// ExpectedPodName is the only pod name the test server recognizes;
// requests naming any other pod receive a 404.
const ExpectedPodName = "shell-demo"
// Pull real-life-based testing data in from files at compile time.
// We decided to embed them in the test binary because of past issues
// with reading files that we encountered on CI workers.

// caCrt holds the contents of ca.crt. Server writes it to a temporary file
// and reports that file's path as Conf.PathToRootCAFile.
//
//go:embed ca.crt
var caCrt string

// getPodResponse holds the contents of resp-get-pod.json, sent with a 200
// for a GET on the expected namespace and pod.
//
//go:embed resp-get-pod.json
var getPodResponse string

// notFoundResponse holds the contents of resp-not-found.json, sent with a
// 404 when the request names an unexpected namespace or pod.
//
//go:embed resp-not-found.json
var notFoundResponse string

// updatePodTagsResponse holds the contents of resp-update-pod.json, sent
// with a 200 after a PATCH has been recorded.
//
//go:embed resp-update-pod.json
var updatePodTagsResponse string

// token holds the contents of the token file. Server writes it to a
// temporary file and reports that file's path as Conf.PathToTokenFile.
//
//go:embed token
var token string
var (
	// ReturnGatewayTimeouts, when set to true, makes the test server answer
	// every request with a 504 Gateway Timeout and an empty body.
	ReturnGatewayTimeouts = atomic.NewBool(false)

	// pathToFiles is the directory, with a trailing slash, that holds this
	// package's testing data files. It is derived from the working directory
	// when the package is initialized.
	pathToFiles = func() string {
		// Best effort: if Getwd fails, wd is "" and testFilesDir yields a
		// path relative to the current directory.
		wd, _ := os.Getwd()
		return testFilesDir(wd)
	}()
)

// testFilesDir returns this package's testing directory, with a trailing
// slash, for a working directory located inside a vault or vault-enterprise
// checkout. Everything before the first occurrence of the repository name
// in wd is treated as the checkout's parent directory.
func testFilesDir(wd string) string {
	repoName := "vault-enterprise"
	if !strings.Contains(wd, repoName) {
		repoName = "vault"
	}
	parent := strings.Split(wd, repoName)[0]
	// Re-append the repository name that was actually matched. Hard-coding
	// "vault" here would send vault-enterprise checkouts to a sibling "vault"
	// directory.
	return parent + repoName + "/serviceregistration/kubernetes/testing/"
}
// Conf holds everything a client needs in order to talk to the test server.
// The caller must copy these values onto the client itself; doing it here
// would create an import cycle between the client and this package.
// Example usage:
//
//	client.Scheme = testConf.ClientScheme
//	client.TokenFile = testConf.PathToTokenFile
//	client.RootCAFile = testConf.PathToRootCAFile
//	if err := os.Setenv(client.EnvVarKubernetesServiceHost, testConf.ServiceHost); err != nil {
//		t.Fatal(err)
//	}
//	if err := os.Setenv(client.EnvVarKubernetesServicePort, testConf.ServicePort); err != nil {
//		t.Fatal(err)
//	}
type Conf struct {
	// ClientScheme is the scheme prefix the client should use, e.g. "http://".
	ClientScheme string
	// PathToTokenFile is the path of a temporary file holding the token.
	PathToTokenFile string
	// PathToRootCAFile is the path of a temporary file holding the CA cert.
	PathToRootCAFile string
	// ServiceHost is the host portion of the test server's address.
	ServiceHost string
	// ServicePort is the port portion of the test server's address.
	ServicePort string
}
// Server starts an HTTP test server that imitates the part of the Kubernetes
// API used by the client: GET and PATCH on
// /api/v1/namespaces/<namespace>/pods/<pod>. It returns the State recording
// PATCHed values, the Conf a client needs to reach the server, and a func
// that shuts the server down and removes the temporary token and CA files.
// Setup failures clean up whatever was already created, then call t.Fatal.
func Server(t *testing.T) (testState *State, testConf *Conf, closeFunc func()) {
	t.Helper()

	testState = &State{m: &sync.Map{}}
	testConf = &Conf{
		ClientScheme: "http://",
	}

	// We're going to have multiple close funcs to call.
	var closers []func()
	closeFunc = func() {
		for _, closer := range closers {
			closer()
		}
	}
	// fatal releases everything created so far, then fails the test.
	fatal := func(args ...interface{}) {
		t.Helper()
		closeFunc()
		t.Fatal(args...)
	}

	// Plant our token and CA cert where the client config can read them.
	tokenPath, err := writeTempFile("token", token)
	if tokenPath != "" {
		closers = append(closers, func() { os.Remove(tokenPath) })
	}
	if err != nil {
		fatal(err)
	}
	testConf.PathToTokenFile = tokenPath

	caPath, err := writeTempFile("ca.crt", caCrt)
	if caPath != "" {
		closers = append(closers, func() { os.Remove(caPath) })
	}
	if err != nil {
		fatal(err)
	}
	testConf.PathToRootCAFile = caPath

	ts := httptest.NewServer(newHandler(testState))
	closers = append(closers, ts.Close)

	// ts.URL example: http://127.0.0.1:35681
	urlFields := strings.Split(ts.URL, "://")
	if len(urlFields) != 2 {
		fatal("received unexpected test url: " + ts.URL)
	}
	urlFields = strings.Split(urlFields[1], ":")
	if len(urlFields) != 2 {
		fatal("received unexpected test url: " + ts.URL)
	}
	testConf.ServiceHost = urlFields[0]
	testConf.ServicePort = urlFields[1]
	return testState, testConf, closeFunc
}

// writeTempFile creates a temporary file named after pattern, writes
// contents to it and closes it. The file's name is returned whenever the
// file was created, even alongside an error, so the caller can remove it.
func writeTempFile(pattern, contents string) (string, error) {
	f, err := ioutil.TempFile("", pattern)
	if err != nil {
		return "", err
	}
	if _, err := f.WriteString(contents); err != nil {
		// Close so the descriptor isn't leaked; the write error is the one
		// worth reporting.
		f.Close()
		return f.Name(), err
	}
	return f.Name(), f.Close()
}

// newHandler returns the fake Kubernetes API handler behind Server.
// Successful PATCH requests are recorded in state.
func newHandler(state *State) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if ReturnGatewayTimeouts.Load() {
			w.WriteHeader(http.StatusGatewayTimeout)
			return
		}
		namespace, podName, err := parsePath(r.URL.Path)
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "unable to parse %s: %s", r.URL.Path, err.Error())
			return
		}
		switch {
		case namespace != ExpectedNamespace, podName != ExpectedPodName:
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte(notFoundResponse))
		case r.Method == http.MethodGet:
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(getPodResponse))
		case r.Method == http.MethodPatch:
			handlePatchRequest(w, r, state)
		default:
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "unexpected request method: %s", r.Method)
		}
	})
}

// handlePatchRequest decodes a JSON array of patch objects and stores each
// one in state under its "path" value. A body that is not an array of
// objects, or any patch lacking a string "path", gets a 400 and nothing is
// stored.
func handlePatchRequest(w http.ResponseWriter, r *http.Request, state *State) {
	var patches []map[string]interface{}
	if err := json.NewDecoder(r.Body).Decode(&patches); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "unable to decode patches %s: %s", r.URL.Path, err.Error())
		return
	}
	// Validate every patch before storing any so that a bad request leaves
	// the recorded state untouched.
	paths := make([]string, len(patches))
	for i, patch := range patches {
		p, ok := patch["path"].(string)
		if !ok {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "patch %d in %s has no string \"path\"", i, r.URL.Path)
			return
		}
		paths[i] = p
	}
	for i, patch := range patches {
		state.store(paths[i], patch)
	}
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(updatePodTagsResponse))
}
// State records the JSON patches the test server has received, keyed by
// each patch's "path" value. It must be built with a non-nil m; the zero
// value would panic on use.
type State struct {
	m *sync.Map
}

// NumPatches reports how many distinct keys currently hold a stored value.
func (s *State) NumPatches() int {
	total := 0
	s.m.Range(func(_, _ interface{}) bool {
		total++
		return true
	})
	return total
}

// Get returns the patch most recently stored under key. It returns nil when
// nothing is stored there or the stored value is not a patch map.
func (s *State) Get(key string) map[string]interface{} {
	if stored, found := s.m.Load(key); found {
		if patch, isPatch := stored.(map[string]interface{}); isPatch {
			return patch
		}
	}
	return nil
}

// store records p under k, replacing any value previously stored for k.
func (s *State) store(k string, p map[string]interface{}) {
	s.m.Store(k, p)
}
// parsePath extracts the namespace and pod name from a request path of the
// form
//
//	/api/v1/namespaces/<namespace>/pods/<podName>
//
// Any other path yields empty names and a "received unexpected path" error.
func parsePath(urlPath string) (namespace, podName string, err error) {
	pod := path.Base(urlPath)
	ns := path.Base(strings.TrimSuffix(urlPath, "/pods/"+pod))
	// Rebuild the canonical form from the extracted pieces; any mismatch
	// means the input had extra, missing or misplaced segments.
	if canonical := "/api/v1/namespaces/" + ns + "/pods/" + pod; canonical != urlPath {
		return "", "", fmt.Errorf("received unexpected path: %s", urlPath)
	}
	return ns, pod, nil
}
// readFile returns the contents of fileName, resolved relative to
// pathToFiles. On failure it returns an empty string and the read error.
func readFile(fileName string) (string, error) {
	contents, err := ioutil.ReadFile(pathToFiles + fileName)
	if err == nil {
		return string(contents), nil
	}
	return "", err
}