222 lines
4.8 KiB
Go
222 lines
4.8 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package forwarding
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
)
|
|
|
|
// bufCloser wraps a *bytes.Buffer and adds a Close method, so the buffer can
// be used where both reading and closing are required (ParseForwardedRequest
// uses it as the body of the reconstructed http.Request).
type bufCloser struct {
	*bytes.Buffer
}

// Close resets the embedded buffer, discarding any unread data. It always
// returns nil.
func (b bufCloser) Close() error {
	b.Buffer.Reset()
	return nil
}
|
|
|
|
// GenerateForwardedRequest generates a new http.Request that contains the
|
|
// original requests's information in the new request's body.
|
|
func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request, error) {
|
|
fq, err := GenerateForwardedRequest(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var newBody []byte
|
|
switch os.Getenv("VAULT_MESSAGE_TYPE") {
|
|
case "json":
|
|
newBody, err = jsonutil.EncodeJSON(fq)
|
|
case "json_compress":
|
|
newBody, err = jsonutil.EncodeJSONAndCompress(fq, &compressutil.CompressionConfig{
|
|
Type: compressutil.CompressionTypeLZW,
|
|
})
|
|
case "proto3":
|
|
fallthrough
|
|
default:
|
|
newBody, err = proto.Marshal(fq)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ret, err := http.NewRequest("POST", addr, bytes.NewBuffer(newBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
|
|
var reader io.Reader = req.Body
|
|
ctx := req.Context()
|
|
maxRequestSize := ctx.Value("max_request_size")
|
|
if maxRequestSize != nil {
|
|
max, ok := maxRequestSize.(int64)
|
|
if !ok {
|
|
return nil, errors.New("could not parse max_request_size from request context")
|
|
}
|
|
if max > 0 {
|
|
reader = io.LimitReader(req.Body, max)
|
|
}
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fq := Request{
|
|
Method: req.Method,
|
|
HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)),
|
|
Host: req.Host,
|
|
RemoteAddr: req.RemoteAddr,
|
|
Body: body,
|
|
}
|
|
|
|
reqURL := req.URL
|
|
fq.Url = &URL{
|
|
Scheme: reqURL.Scheme,
|
|
Opaque: reqURL.Opaque,
|
|
Host: reqURL.Host,
|
|
Path: reqURL.Path,
|
|
RawPath: reqURL.RawPath,
|
|
RawQuery: reqURL.RawQuery,
|
|
Fragment: reqURL.Fragment,
|
|
}
|
|
|
|
for k, v := range req.Header {
|
|
fq.HeaderEntries[k] = &HeaderEntry{
|
|
Values: v,
|
|
}
|
|
}
|
|
|
|
if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 {
|
|
fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates))
|
|
for i, cert := range req.TLS.PeerCertificates {
|
|
fq.PeerCertificates[i] = cert.Raw
|
|
}
|
|
}
|
|
|
|
return &fq, nil
|
|
}
|
|
|
|
// ParseForwardedRequest generates a new http.Request that is comprised of the
|
|
// values in the given request's body, assuming it correctly parses into a
|
|
// ForwardedRequest.
|
|
func ParseForwardedHTTPRequest(req *http.Request) (*http.Request, error) {
|
|
buf := bytes.NewBuffer(nil)
|
|
_, err := buf.ReadFrom(req.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fq := new(Request)
|
|
switch os.Getenv("VAULT_MESSAGE_TYPE") {
|
|
case "json", "json_compress":
|
|
err = jsonutil.DecodeJSON(buf.Bytes(), fq)
|
|
default:
|
|
err = proto.Unmarshal(buf.Bytes(), fq)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ParseForwardedRequest(fq)
|
|
}
|
|
|
|
func ParseForwardedRequest(fq *Request) (*http.Request, error) {
|
|
buf := bufCloser{
|
|
Buffer: bytes.NewBuffer(fq.Body),
|
|
}
|
|
|
|
ret := &http.Request{
|
|
Method: fq.Method,
|
|
Header: make(map[string][]string, len(fq.HeaderEntries)),
|
|
Body: buf,
|
|
Host: fq.Host,
|
|
RemoteAddr: fq.RemoteAddr,
|
|
}
|
|
|
|
ret.URL = &url.URL{
|
|
Scheme: fq.Url.Scheme,
|
|
Opaque: fq.Url.Opaque,
|
|
Host: fq.Url.Host,
|
|
Path: fq.Url.Path,
|
|
RawPath: fq.Url.RawPath,
|
|
RawQuery: fq.Url.RawQuery,
|
|
Fragment: fq.Url.Fragment,
|
|
}
|
|
|
|
for k, v := range fq.HeaderEntries {
|
|
ret.Header[k] = v.Values
|
|
}
|
|
|
|
if fq.PeerCertificates != nil && len(fq.PeerCertificates) > 0 {
|
|
ret.TLS = &tls.ConnectionState{
|
|
PeerCertificates: make([]*x509.Certificate, len(fq.PeerCertificates)),
|
|
}
|
|
for i, certBytes := range fq.PeerCertificates {
|
|
cert, err := x509.ParseCertificate(certBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ret.TLS.PeerCertificates[i] = cert
|
|
}
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
// RPCResponseWriter records a response in memory: the status code, the header
// map and every byte written to the body. Its Header, Write and WriteHeader
// methods have the signatures required by http.ResponseWriter.
type RPCResponseWriter struct {
	statusCode int
	header     http.Header
	body       *bytes.Buffer
}

// NewRPCResponseWriter returns an initialized RPCResponseWriter with an empty
// header map, an empty body and a status code of 200.
func NewRPCResponseWriter() *RPCResponseWriter {
	// w.header.Set("Content-Type", "application/octet-stream")
	return &RPCResponseWriter{
		statusCode: http.StatusOK,
		header:     http.Header{},
		body:       &bytes.Buffer{},
	}
}

// Header returns the writer's header map. Changes made to the returned map
// are visible through later calls to Header.
func (w *RPCResponseWriter) Header() http.Header {
	return w.header
}

// Write appends buf to the recorded body. bytes.Buffer.Write always reports
// len(buf) and a nil error, so that is what Write returns.
func (w *RPCResponseWriter) Write(buf []byte) (int, error) {
	return w.body.Write(buf)
}

// WriteHeader records code as the response status, replacing any previously
// recorded value.
func (w *RPCResponseWriter) WriteHeader(code int) {
	w.statusCode = code
}

// StatusCode returns the most recently recorded status code.
func (w *RPCResponseWriter) StatusCode() int {
	return w.statusCode
}

// Body returns the buffer holding everything written so far. The buffer
// itself is returned, not a copy.
func (w *RPCResponseWriter) Body() *bytes.Buffer {
	return w.body
}
|