// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package forwarding import ( "bytes" "crypto/tls" "crypto/x509" "errors" "io" "io/ioutil" "net/http" "net/url" "os" "github.com/golang/protobuf/proto" "github.com/hashicorp/vault/sdk/helper/compressutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" ) type bufCloser struct { *bytes.Buffer } func (b bufCloser) Close() error { b.Reset() return nil } // GenerateForwardedRequest generates a new http.Request that contains the // original requests's information in the new request's body. func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request, error) { fq, err := GenerateForwardedRequest(req) if err != nil { return nil, err } var newBody []byte switch os.Getenv("VAULT_MESSAGE_TYPE") { case "json": newBody, err = jsonutil.EncodeJSON(fq) case "json_compress": newBody, err = jsonutil.EncodeJSONAndCompress(fq, &compressutil.CompressionConfig{ Type: compressutil.CompressionTypeLZW, }) case "proto3": fallthrough default: newBody, err = proto.Marshal(fq) } if err != nil { return nil, err } ret, err := http.NewRequest("POST", addr, bytes.NewBuffer(newBody)) if err != nil { return nil, err } return ret, nil } func GenerateForwardedRequest(req *http.Request) (*Request, error) { var reader io.Reader = req.Body ctx := req.Context() maxRequestSize := ctx.Value("max_request_size") if maxRequestSize != nil { max, ok := maxRequestSize.(int64) if !ok { return nil, errors.New("could not parse max_request_size from request context") } if max > 0 { reader = io.LimitReader(req.Body, max) } } body, err := ioutil.ReadAll(reader) if err != nil { return nil, err } fq := Request{ Method: req.Method, HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)), Host: req.Host, RemoteAddr: req.RemoteAddr, Body: body, } reqURL := req.URL fq.Url = &URL{ Scheme: reqURL.Scheme, Opaque: reqURL.Opaque, Host: reqURL.Host, Path: reqURL.Path, RawPath: reqURL.RawPath, RawQuery: reqURL.RawQuery, Fragment: reqURL.Fragment, } for k, v := range req.Header { fq.HeaderEntries[k] = &HeaderEntry{ Values: v, } } if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 { fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates)) for i, cert := range req.TLS.PeerCertificates { fq.PeerCertificates[i] = cert.Raw } } return &fq, nil } // ParseForwardedRequest generates a new http.Request that is comprised of the // values in the given request's body, assuming it correctly parses into a // ForwardedRequest. func ParseForwardedHTTPRequest(req *http.Request) (*http.Request, error) { buf := bytes.NewBuffer(nil) _, err := buf.ReadFrom(req.Body) if err != nil { return nil, err } fq := new(Request) switch os.Getenv("VAULT_MESSAGE_TYPE") { case "json", "json_compress": err = jsonutil.DecodeJSON(buf.Bytes(), fq) default: err = proto.Unmarshal(buf.Bytes(), fq) } if err != nil { return nil, err } return ParseForwardedRequest(fq) } func ParseForwardedRequest(fq *Request) (*http.Request, error) { buf := bufCloser{ Buffer: bytes.NewBuffer(fq.Body), } ret := &http.Request{ Method: fq.Method, Header: make(map[string][]string, len(fq.HeaderEntries)), Body: buf, Host: fq.Host, RemoteAddr: fq.RemoteAddr, } ret.URL = &url.URL{ Scheme: fq.Url.Scheme, Opaque: fq.Url.Opaque, Host: fq.Url.Host, Path: fq.Url.Path, RawPath: fq.Url.RawPath, RawQuery: fq.Url.RawQuery, Fragment: fq.Url.Fragment, } for k, v := range fq.HeaderEntries { ret.Header[k] = v.Values } if fq.PeerCertificates != nil && len(fq.PeerCertificates) > 0 { ret.TLS = &tls.ConnectionState{ PeerCertificates: make([]*x509.Certificate, len(fq.PeerCertificates)), } for i, certBytes := range fq.PeerCertificates { cert, err := x509.ParseCertificate(certBytes) if err != nil { return nil, err } ret.TLS.PeerCertificates[i] = cert } } return ret, nil } type RPCResponseWriter struct { statusCode int header http.Header body *bytes.Buffer } // NewRPCResponseWriter returns an initialized RPCResponseWriter func NewRPCResponseWriter() *RPCResponseWriter { w := &RPCResponseWriter{ header: make(http.Header), body: new(bytes.Buffer), statusCode: 200, } // w.header.Set("Content-Type", "application/octet-stream") return w } func (w *RPCResponseWriter) Header() http.Header { return w.header } func (w *RPCResponseWriter) Write(buf []byte) (int, error) { w.body.Write(buf) return len(buf), nil } func (w *RPCResponseWriter) WriteHeader(code int) { w.statusCode = code } func (w *RPCResponseWriter) StatusCode() int { return w.statusCode } func (w *RPCResponseWriter) Body() *bytes.Buffer { return w.body }