regenerate rpc glue stubs in protobuf files using comments (#12625)

This commit is contained in:
R.B. Boyer 2022-03-25 15:55:40 -05:00 committed by GitHub
parent 9d3df6b08b
commit f531f1e87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 395 additions and 7 deletions

View File

@ -79,6 +79,7 @@ function main {
local proto_go_path=${proto_path%%.proto}.pb.go
local proto_go_bin_path=${proto_path%%.proto}.pb.binary.go
local proto_go_rpcglue_path=${proto_path%%.proto}.rpcglue.pb.go
local go_proto_out="paths=source_relative"
if is_set "${grpc}"
@ -132,13 +133,13 @@ function main {
return 1
fi
echo "debug_run protoc \
-I=\"${golang_proto_path}\" \
-I=\"${golang_proto_mod_path}\" \
-I=\"${SOURCE_DIR}\" \
--go_out=\"${go_proto_out}${SOURCE_DIR}\" \
--go-binary_out=\"${SOURCE_DIR}\" \
\"${proto_path}\""
echo "debug_run protoc \
-I=\"${golang_proto_path}\" \
-I=\"${golang_proto_mod_path}\" \
-I=\"${SOURCE_DIR}\" \
--go_out=\"${go_proto_out}${SOURCE_DIR}\" \
--go-binary_out=\"${SOURCE_DIR}\" \
\"${proto_path}\""
BUILD_TAGS=$(sed -e '/^[[:space:]]*$/,$d' < "${proto_path}" | grep '// +build')
if test -n "${BUILD_TAGS}"
@ -152,6 +153,15 @@ function main {
mv "${proto_go_bin_path}.new" "${proto_go_bin_path}"
fi
# note: this has to run after we fix up the build tags above
rm -f "${proto_go_rpcglue_path}"
debug_run go run ./internal/tools/proto-gen-rpc-glue/main.go -path "${proto_go_path}"
if test $? -ne 0
then
err "Failed to generate consul rpc glue outputs from ${proto_path}"
return 1
fi
return 0
}

View File

@ -0,0 +1,3 @@
module github.com/hashicorp/consul/internal/tools/proto-gen-rpc-glue
go 1.17

View File

View File

@ -0,0 +1,375 @@
package main
import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"os"
"os/exec"
"strings"
)
var (
flagPath = flag.String("path", "", "path of file to load")
verbose = flag.Bool("v", false, "verbose output")
)
const (
annotationPrefix = "@consul-rpc-glue:"
outputFileSuffix = ".rpcglue.pb.go"
)
func main() {
flag.Parse()
log.SetFlags(0)
if *flagPath == "" {
log.Fatal("missing required -path argument")
}
if err := run(*flagPath); err != nil {
log.Fatal(err)
}
}
func run(path string) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if fi.IsDir() {
return fmt.Errorf("argument must be a file: %s", path)
}
if !strings.HasSuffix(path, ".pb.go") {
return fmt.Errorf("file must end with .pb.go: %s", path)
}
if err := processFile(path); err != nil {
return fmt.Errorf("error processing file %q: %v", path, err)
}
return nil
}
func processFile(path string) error {
if *verbose {
log.Printf("visiting file %q", path)
}
fset := token.NewFileSet()
tree, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
if err != nil {
return err
}
v := visitor{}
ast.Walk(&v, tree)
if err := v.Err(); err != nil {
return err
}
if len(v.Types) == 0 {
return nil
}
if *verbose {
log.Printf("Package: %s", v.Package)
log.Printf("BuildTags: %v", v.BuildTags)
log.Println()
for _, typ := range v.Types {
log.Printf("Type: %s", typ.Name)
ann := typ.Annotation
if ann.ReadRequest != "" {
log.Printf(" ReadRequest from %s", ann.ReadRequest)
}
if ann.WriteRequest != "" {
log.Printf(" WriteRequest from %s", ann.WriteRequest)
}
if ann.TargetDatacenter != "" {
log.Printf(" TargetDatacenter from %s", ann.TargetDatacenter)
}
}
}
// generate output
var buf bytes.Buffer
if len(v.BuildTags) > 0 {
for _, line := range v.BuildTags {
buf.WriteString(line + "\n")
}
buf.WriteString("\n")
}
buf.WriteString("// Code generated by proto-gen-rpc-glue. DO NOT EDIT.\n\n")
buf.WriteString("package " + v.Package + "\n")
buf.WriteString(`
import (
"time"
)
`)
for _, typ := range v.Types {
if typ.Annotation.WriteRequest != "" {
buf.WriteString(fmt.Sprintf(`
func (msg *%[1]s) AllowStaleRead() bool {
return false
}
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
if msg == nil || msg.%[2]s == nil {
return false, nil
}
return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}
func (msg *%[1]s) IsRead() bool {
return false
}
func (msg *%[1]s) SetTokenSecret(s string) {
msg.%[2]s.SetTokenSecret(s)
}
func (msg *%[1]s) TokenSecret() string {
if msg == nil || msg.%[2]s == nil {
return ""
}
return msg.%[2]s.TokenSecret()
}
func (msg *%[1]s) Token() string {
if msg.%[2]s == nil {
return ""
}
return msg.%[2]s.Token
}
`, typ.Name, typ.Annotation.WriteRequest))
}
if typ.Annotation.ReadRequest != "" {
buf.WriteString(fmt.Sprintf(`
func (msg *%[1]s) IsRead() bool {
return true
}
func (msg *%[1]s) AllowStaleRead() bool {
return msg.%[2]s.AllowStaleRead()
}
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
if msg == nil || msg.%[2]s == nil {
return false, nil
}
return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}
func (msg *%[1]s) SetTokenSecret(s string) {
msg.%[2]s.SetTokenSecret(s)
}
func (msg *%[1]s) TokenSecret() string {
if msg == nil || msg.%[2]s == nil {
return ""
}
return msg.%[2]s.TokenSecret()
}
func (msg *%[1]s) Token() string {
if msg.%[2]s == nil {
return ""
}
return msg.%[2]s.Token
}
`, typ.Name, typ.Annotation.ReadRequest))
}
if typ.Annotation.TargetDatacenter != "" {
buf.WriteString(fmt.Sprintf(`
func (msg *%[1]s) RequestDatacenter() string {
if msg == nil || msg.%[2]s == nil {
return ""
}
return msg.%[2]s.GetDatacenter()
}
`, typ.Name, typ.Annotation.TargetDatacenter))
}
}
// write to disk
outFile := strings.TrimSuffix(path, ".pb.go") + outputFileSuffix
if err := os.WriteFile(outFile, buf.Bytes(), 0644); err != nil {
return err
}
// clean up
cmd := exec.Command("gofmt", "-s", "-w", outFile)
cmd.Stdout = nil
cmd.Stderr = os.Stderr
cmd.Stdin = nil
if err := cmd.Run(); err != nil {
return fmt.Errorf("error running 'gofmt -s -w %q': %v", outFile, err)
}
return nil
}
type TypeInfo struct {
Name string
Annotation Annotation
}
type visitor struct {
Package string
BuildTags []string
Types []TypeInfo
Errs []error
}
func (v *visitor) Err() error {
switch len(v.Errs) {
case 0:
return nil
case 1:
return v.Errs[0]
default:
//
var s []string
for _, e := range v.Errs {
s = append(s, e.Error())
}
return errors.New(strings.Join(s, "; "))
}
}
var _ ast.Visitor = (*visitor)(nil)
func (v *visitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return v
}
switch x := node.(type) {
case *ast.File:
v.Package = x.Name.Name
v.BuildTags = getRawBuildTags(x)
for _, d := range x.Decls {
gd, ok := d.(*ast.GenDecl)
if !ok {
continue
}
if gd.Doc == nil {
continue
} else if len(gd.Specs) != 1 {
continue
}
spec := gd.Specs[0]
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
ann, err := getAnnotation(gd.Doc.List)
if err != nil {
v.Errs = append(v.Errs, err)
continue
} else if ann.IsZero() {
continue
}
v.Types = append(v.Types, TypeInfo{
Name: typeSpec.Name.Name,
Annotation: ann,
})
}
}
return v
}
type Annotation struct {
ReadRequest string
WriteRequest string
TargetDatacenter string
}
func (a Annotation) IsZero() bool {
return a == Annotation{}
}
func getAnnotation(doc []*ast.Comment) (Annotation, error) {
raw, ok := getRawStructAnnotation(doc)
if !ok {
return Annotation{}, nil
}
var ann Annotation
parts := strings.Split(raw, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
switch {
case part == "ReadRequest":
ann.ReadRequest = "ReadRequest"
case strings.HasPrefix(part, "ReadRequest"):
ann.TargetDatacenter = strings.TrimPrefix(part, "ReadRequest")
case part == "WriteRequest":
ann.WriteRequest = "WriteRequest"
case strings.HasPrefix(part, "WriteRequest"):
ann.TargetDatacenter = strings.TrimPrefix(part, "WriteRequest")
case part == "TargetDatacenter":
ann.TargetDatacenter = "TargetDatacenter"
case strings.HasPrefix(part, "TargetDatacenter"):
ann.TargetDatacenter = strings.TrimPrefix(part, "TargetDatacenter")
default:
return Annotation{}, fmt.Errorf("unexpected annotation part: %s", part)
}
}
return ann, nil
}
func getRawStructAnnotation(doc []*ast.Comment) (string, bool) {
for _, line := range doc {
text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))
ann := strings.TrimSpace(strings.TrimPrefix(text, annotationPrefix))
if text != ann {
return ann, true
}
}
return "", false
}
func getRawBuildTags(file *ast.File) []string {
// build tags are always the first group, at the very top
if len(file.Comments) == 0 {
return nil
}
cg := file.Comments[0]
var out []string
for _, line := range cg.List {
text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))
if !strings.HasPrefix(text, "go:build ") && !strings.HasPrefix(text, "+build") {
break // stop at first non-build-tag
}
out = append(out, line.Text)
}
return out
}