open-vault/helper/forwarding/util.go

222 lines
4.8 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package forwarding
import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/vault/sdk/helper/compressutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
// bufCloser turns a *bytes.Buffer into an io.ReadCloser (Read comes from the
// embedded buffer) so that it can be used as an http.Request body.
type bufCloser struct {
	*bytes.Buffer
}

// Close empties the underlying buffer, discarding any bytes that have not
// been read yet. It always returns nil.
func (bc bufCloser) Close() error {
	bc.Buffer.Reset()
	return nil
}
// GenerateForwardedHTTPRequest serializes req (via GenerateForwardedRequest)
// and returns a new POST request to addr whose body carries the original
// request's information.
//
// The wire encoding is chosen by the VAULT_MESSAGE_TYPE environment variable:
//   - "json": plain JSON
//   - "json_compress": JSON compressed with LZW
//   - anything else, including "proto3" or unset: protobuf
func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request, error) {
	fq, err := GenerateForwardedRequest(req)
	if err != nil {
		return nil, err
	}

	var newBody []byte
	switch os.Getenv("VAULT_MESSAGE_TYPE") {
	case "json":
		newBody, err = jsonutil.EncodeJSON(fq)
	case "json_compress":
		newBody, err = jsonutil.EncodeJSONAndCompress(fq, &compressutil.CompressionConfig{
			Type: compressutil.CompressionTypeLZW,
		})
	default: // "proto3" or unset
		newBody, err = proto.Marshal(fq)
	}
	if err != nil {
		return nil, err
	}

	// http.NewRequest already returns a nil request on error.
	return http.NewRequest(http.MethodPost, addr, bytes.NewBuffer(newBody))
}
// GenerateForwardedRequest captures req's method, URL, headers, host, remote
// address, body and TLS peer certificates in a Request message. That message
// can be turned back into an *http.Request with ParseForwardedRequest.
//
// The body is read but not closed; a nil body is treated as empty. If the
// request context holds a "max_request_size" value, it must be an int64. When
// that value is greater than zero, at most that many body bytes are read.
//
// It returns an error if req has no URL, if "max_request_size" has the wrong
// type, or if reading the body fails.
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
	// Fail before consuming the body instead of panicking on the URL later.
	if req.URL == nil {
		return nil, errors.New("cannot forward a request with a nil URL")
	}

	// NOTE(review): plain string context key; presumably it must match the
	// key used wherever this value is stored, so it is kept as-is.
	var limit int64
	if v := req.Context().Value("max_request_size"); v != nil {
		l, ok := v.(int64)
		if !ok {
			return nil, errors.New("could not parse max_request_size from request context")
		}
		limit = l
	}

	// A nil Body (allowed for client-constructed requests) would make
	// ReadAll panic, so it is forwarded as an empty body instead.
	var body []byte
	if req.Body != nil {
		var reader io.Reader = req.Body
		if limit > 0 {
			// NOTE(review): LimitReader silently truncates an oversized body
			// rather than rejecting it; confirm the limit is enforced upstream.
			reader = io.LimitReader(req.Body, limit)
		}
		var err error
		if body, err = ioutil.ReadAll(reader); err != nil {
			return nil, err
		}
	}

	fq := &Request{
		Method:        req.Method,
		HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)),
		Host:          req.Host,
		RemoteAddr:    req.RemoteAddr,
		Body:          body,
		Url: &URL{
			Scheme:   req.URL.Scheme,
			Opaque:   req.URL.Opaque,
			Host:     req.URL.Host,
			Path:     req.URL.Path,
			RawPath:  req.URL.RawPath,
			RawQuery: req.URL.RawQuery,
			Fragment: req.URL.Fragment,
		},
	}

	for k, v := range req.Header {
		fq.HeaderEntries[k] = &HeaderEntry{
			Values: v,
		}
	}

	// Only the raw DER bytes are forwarded; the receiver re-parses them.
	if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
		fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates))
		for i, cert := range req.TLS.PeerCertificates {
			fq.PeerCertificates[i] = cert.Raw
		}
	}

	return fq, nil
}
// ParseForwardedHTTPRequest decodes req's body into a Request, as produced by
// GenerateForwardedHTTPRequest. It then rebuilds the original http.Request
// from it with ParseForwardedRequest.
//
// The decoder is chosen by the VAULT_MESSAGE_TYPE environment variable:
//   - "json" and "json_compress": jsonutil.DecodeJSON
//   - anything else: protobuf
//
// Sender and receiver must therefore agree on VAULT_MESSAGE_TYPE.
func ParseForwardedHTTPRequest(req *http.Request) (*http.Request, error) {
	// ReadFrom would panic on a nil Body; report it as an error instead.
	if req.Body == nil {
		return nil, errors.New("forwarded request has no body")
	}

	buf := bytes.NewBuffer(nil)
	if _, err := buf.ReadFrom(req.Body); err != nil {
		return nil, err
	}
	data := buf.Bytes()

	fq := new(Request)
	var err error
	switch os.Getenv("VAULT_MESSAGE_TYPE") {
	case "json", "json_compress":
		// NOTE(review): relies on jsonutil.DecodeJSON to accept the
		// compressed form emitted for "json_compress" — confirm.
		err = jsonutil.DecodeJSON(data, fq)
	default:
		err = proto.Unmarshal(data, fq)
	}
	if err != nil {
		return nil, err
	}

	return ParseForwardedRequest(fq)
}
// ParseForwardedRequest rebuilds an *http.Request from the forwarded Request
// fq. It restores the method, URL, headers, host, remote address, body and
// TLS peer certificates.
//
// The returned request's Body reads from fq.Body; closing it discards any
// unread data. It returns an error if fq or its URL is nil, or if a peer
// certificate cannot be parsed.
func ParseForwardedRequest(fq *Request) (*http.Request, error) {
	// fq is decoded from bytes received from another node, so missing fields
	// are reported as errors instead of being dereferenced.
	if fq == nil {
		return nil, errors.New("forwarded request is nil")
	}
	if fq.Url == nil {
		return nil, errors.New("forwarded request is missing its URL")
	}

	ret := &http.Request{
		Method:     fq.Method,
		Header:     make(http.Header, len(fq.HeaderEntries)),
		Body:       bufCloser{Buffer: bytes.NewBuffer(fq.Body)},
		Host:       fq.Host,
		RemoteAddr: fq.RemoteAddr,
		URL: &url.URL{
			Scheme:   fq.Url.Scheme,
			Opaque:   fq.Url.Opaque,
			Host:     fq.Url.Host,
			Path:     fq.Url.Path,
			RawPath:  fq.Url.RawPath,
			RawQuery: fq.Url.RawQuery,
			Fragment: fq.Url.Fragment,
		},
	}

	for k, v := range fq.HeaderEntries {
		// A nil entry carries no values; skip it rather than panic on v.Values.
		if v == nil {
			continue
		}
		ret.Header[k] = v.Values
	}

	if len(fq.PeerCertificates) > 0 {
		certs := make([]*x509.Certificate, len(fq.PeerCertificates))
		for i, certBytes := range fq.PeerCertificates {
			cert, err := x509.ParseCertificate(certBytes)
			if err != nil {
				return nil, err
			}
			certs[i] = cert
		}
		ret.TLS = &tls.ConnectionState{
			PeerCertificates: certs,
		}
	}

	return ret, nil
}
// RPCResponseWriter is an in-memory http.ResponseWriter. It records the status
// code, headers and body written to it; StatusCode, Header and Body expose
// them afterwards.
type RPCResponseWriter struct {
	statusCode int
	header     http.Header
	body       *bytes.Buffer
}

// NewRPCResponseWriter returns an RPCResponseWriter with no headers, an empty
// body and a status code of 200 (http.StatusOK).
func NewRPCResponseWriter() *RPCResponseWriter {
	return &RPCResponseWriter{
		statusCode: http.StatusOK,
		header:     http.Header{},
		body:       &bytes.Buffer{},
	}
}

// Header returns the recorded header map; callers modify it in place.
func (w *RPCResponseWriter) Header() http.Header { return w.header }

// Write appends buf to the recorded body and reports len(buf) bytes written
// with a nil error.
func (w *RPCResponseWriter) Write(buf []byte) (int, error) {
	return w.body.Write(buf)
}

// WriteHeader records code as the response status. Unlike net/http's writer,
// every call overwrites the previously recorded code.
func (w *RPCResponseWriter) WriteHeader(code int) { w.statusCode = code }

// StatusCode returns the most recently recorded status (200 by default).
func (w *RPCResponseWriter) StatusCode() int { return w.statusCode }

// Body returns the buffer holding everything passed to Write.
func (w *RPCResponseWriter) Body() *bytes.Buffer { return w.body }