148 lines
4.2 KiB
Go
148 lines
4.2 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
// protoc-gen-consul-rate-limit maintains the mapping of gRPC method names to
|
|
// a specification of how they should be rate-limited. This is used by the gRPC
|
|
// InTapHandle function (see agent/grpc-middleware/rate.go) to enforce relevant
|
|
// limits without having to call the handler.
|
|
//
|
|
// It works in two phases:
|
|
//
|
|
// 1. Buf/protoc invokes this plugin for each .proto file. We extract the rate
|
|
// limit specification from an annotation on the RPC:
|
|
//
|
|
// service Foo {
|
|
// rpc Bar(BarRequest) returns (BarResponse) {
|
|
// option (hashicorp.consul.internal.ratelimit.spec) = {
|
|
// operation_type: OPERATION_TYPE_WRITE,
|
|
// operation_category: OPERATION_CATEGORY_ACL
|
|
// };
|
|
// }
|
|
// }
|
|
//
|
|
// We write a JSON array of the limits to protobuf/package/path/.ratelimit.tmp:
|
|
//
|
|
// [
|
|
// {
|
|
// "MethodName": "/Foo/Bar",
|
|
// "OperationType": "OPERATION_TYPE_WRITE",
|
|
// "OperationCategory": "OPERATION_CATEGORY_ACL",
// "Enterprise": false
|
|
// }
|
|
// ]
|
|
//
|
|
// 2. The protobuf.sh script (invoked by make proto) runs our postprocess script
|
|
// which reads all of the .ratelimit.tmp files in proto and proto-public and
|
|
// generates a single Go map in agent/grpc-middleware/rate_limit_mappings.gen.go.
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"google.golang.org/protobuf/compiler/protogen"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/hashicorp/consul/proto-public/annotations/ratelimit"
|
|
)
|
|
|
|
const (
|
|
// outputFileName is the name of the file that main writes, via filepath.Join,
// into the directory of each processed .proto file.
outputFileName = ".ratelimit.tmp"
|
|
|
|
// missingSpecTmpl is the error format used by rateLimitSpecs when an RPC has
// no rate-limit annotation. Its three %s verbs receive, in order, the method
// name, the service name, and the method name again.
missingSpecTmpl = `RPC %s is missing rate-limit specification, fix it with:
|
|
|
|
import "proto-public/annotations/ratelimit/ratelimit.proto";
|
|
|
|
service %s {
|
|
rpc %s(...) returns (...) {
|
|
option (hashicorp.consul.internal.ratelimit.spec) = {
|
|
operation_type: OPERATION_TYPE_READ | OPERATION_TYPE_WRITE | OPERATION_TYPE_EXEMPT,
|
|
operation_category: OPERATION_CATEGORY_ACL | OPERATION_CATEGORY_PEER_STREAM | OPERATION_CATEGORY_CONNECT_CA | OPERATION_CATEGORY_PARTITION | OPERATION_CATEGORY_PEERING | OPERATION_CATEGORY_SERVER_DISCOVERY | OPERATION_CATEGORY_DATAPLANE | OPERATION_CATEGORY_DNS | OPERATION_CATEGORY_SUBSCRIBE | OPERATION_CATEGORY_OPERATOR | OPERATION_CATEGORY_RESOURCE,
|
|
};
|
|
}
|
|
}
|
|
`
|
|
|
|
// enterpriseBuildTag is the byte sequence isEnterpriseFile searches for in a
// .proto file's source; its presence sets the Enterprise field of that file's
// specs.
enterpriseBuildTag = "//go:build consulent"
|
|
)
|
|
|
|
// rateLimitSpec is the record emitted for a single RPC. A slice of these is
// JSON-encoded into the .ratelimit.tmp output file, so the field names and
// their declaration order determine the shape of that file.
type rateLimitSpec struct {
	// MethodName is the RPC in gRPC path form: "/<full service name>/<method>".
	MethodName string

	// OperationType is the string form of the annotation's operation_type.
	OperationType string

	// OperationCategory is the string form of the annotation's
	// operation_category.
	OperationCategory string

	// Enterprise is true when the source .proto file contains the enterprise
	// build tag.
	Enterprise bool
}
|
|
|
|
func main() {
|
|
var opts protogen.Options
|
|
opts.Run(func(plugin *protogen.Plugin) error {
|
|
for _, path := range plugin.Request.FileToGenerate {
|
|
file, ok := plugin.FilesByPath[path]
|
|
if !ok {
|
|
return fmt.Errorf("failed to get file descriptor: %s", path)
|
|
}
|
|
|
|
specs, err := rateLimitSpecs(file)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(specs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
outputPath := filepath.Join(filepath.Dir(path), outputFileName)
|
|
output := plugin.NewGeneratedFile(outputPath, "")
|
|
if err := json.NewEncoder(output).Encode(specs); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func rateLimitSpecs(file *protogen.File) ([]rateLimitSpec, error) {
|
|
enterprise, err := isEnterpriseFile(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var specs []rateLimitSpec
|
|
for _, service := range file.Services {
|
|
for _, method := range service.Methods {
|
|
spec := rateLimitSpec{
|
|
// Format the method name in gRPC/HTTP path format.
|
|
MethodName: fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()),
|
|
Enterprise: enterprise,
|
|
}
|
|
|
|
// Read the rate limit spec from the method options.
|
|
options := method.Desc.Options()
|
|
if !proto.HasExtension(options, ratelimit.E_Spec) {
|
|
err := fmt.Errorf(missingSpecTmpl,
|
|
method.Desc.Name(),
|
|
service.Desc.Name(),
|
|
method.Desc.Name())
|
|
return nil, err
|
|
}
|
|
|
|
def := proto.GetExtension(options, ratelimit.E_Spec).(*ratelimit.Spec)
|
|
spec.OperationType = def.OperationType.String()
|
|
spec.OperationCategory = def.OperationCategory.String()
|
|
|
|
specs = append(specs, spec)
|
|
}
|
|
}
|
|
return specs, nil
|
|
}
|
|
|
|
func isEnterpriseFile(file *protogen.File) (bool, error) {
|
|
source, err := os.ReadFile(file.Desc.Path())
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to read proto file: %w", err)
|
|
}
|
|
return bytes.Contains(source, []byte(enterpriseBuildTag)), nil
|
|
}
|