package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"log"
	"os"
	"os/exec"
	"strings"
)

var (
	flagPath = flag.String("path", "", "path of file to load")
	verbose  = flag.Bool("v", false, "verbose output")
)

const (
	// annotationPrefix marks the doc-comment line that carries the glue
	// annotation for a type declaration.
	annotationPrefix = "@consul-rpc-glue:"

	// outputFileSuffix replaces the ".pb.go" suffix of the input file to
	// form the name of the generated file.
	outputFileSuffix = ".rpcglue.pb.go"
)

// main is the entry point of the generator. It reads one *.pb.go file named
// by -path, looks for type declarations whose doc comment contains an
// "@consul-rpc-glue:" annotation, and writes the matching glue methods to a
// sibling *.rpcglue.pb.go file. Any failure terminates the process via
// log.Fatal.
func main() {
	flag.Parse()

	log.SetFlags(0)

	if *flagPath == "" {
		log.Fatal("missing required -path argument")
	}

	if err := run(*flagPath); err != nil {
		log.Fatal(err)
	}
}

// run validates that path names an existing regular *.pb.go file and then
// processes it. The returned error wraps the underlying cause so callers can
// inspect it with errors.Is / errors.As.
func run(path string) error {
	fi, err := os.Stat(path)
	if err != nil {
		return err
	}

	if fi.IsDir() {
		return fmt.Errorf("argument must be a file: %s", path)
	}

	if !strings.HasSuffix(path, ".pb.go") {
		return fmt.Errorf("file must end with .pb.go: %s", path)
	}

	if err := processFile(path); err != nil {
		return fmt.Errorf("error processing file %q: %w", path, err)
	}

	return nil
}

// processFile parses the Go file at path, collects annotated types, renders
// the glue code for them, writes it next to the input file, and finally runs
// "gofmt -s -w" on the result. When the file contains no annotated types
// nothing is written and nil is returned.
func processFile(path string) error {
	if *verbose {
		log.Printf("visiting file %q", path)
	}

	fset := token.NewFileSet()
	tree, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	v := visitor{}
	ast.Walk(&v, tree)
	if err := v.Err(); err != nil {
		return err
	}

	if len(v.Types) == 0 {
		return nil
	}

	if *verbose {
		logTypes(&v)
	}

	// write to disk
	outFile := strings.TrimSuffix(path, ".pb.go") + outputFileSuffix
	if err := os.WriteFile(outFile, renderGlue(&v), 0644); err != nil {
		return err
	}

	// clean up
	cmd := exec.Command("gofmt", "-s", "-w", outFile)
	cmd.Stdout = nil
	cmd.Stderr = os.Stderr
	cmd.Stdin = nil
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("error running 'gofmt -s -w %q': %w", outFile, err)
	}

	return nil
}

// logTypes prints the package, build tags and per-type annotations that were
// collected by the visitor.
func logTypes(v *visitor) {
	log.Printf("Package: %s", v.Package)
	log.Printf("BuildTags: %v", v.BuildTags)
	log.Println()
	for _, typ := range v.Types {
		log.Printf("Type: %s", typ.Name)
		ann := typ.Annotation
		if ann.ReadRequest != "" {
			log.Printf(" ReadRequest from %s", ann.ReadRequest)
		}
		if ann.WriteRequest != "" {
			log.Printf(" WriteRequest from %s", ann.WriteRequest)
		}
		if ann.TargetDatacenter != "" {
			log.Printf(" TargetDatacenter from %s", ann.TargetDatacenter)
		}
		if ann.QueryOptions != "" {
			log.Printf(" QueryOptions from %s", ann.QueryOptions)
		}
		if ann.QueryMeta != "" {
			log.Printf(" QueryMeta from %s", ann.QueryMeta)
		}
	}
}

// renderGlue produces the (not yet gofmt-ed) contents of the generated file:
// the input file's build tags, the generated-code header, the package clause,
// the imports, and one template expansion per annotation set on each type.
func renderGlue(v *visitor) []byte {
	var buf bytes.Buffer

	if len(v.BuildTags) > 0 {
		for _, line := range v.BuildTags {
			buf.WriteString(line + "\n")
		}
		buf.WriteString("\n")
	}
	buf.WriteString("// Code generated by proto-gen-rpc-glue. DO NOT EDIT.\n\n")
	buf.WriteString("package " + v.Package + "\n")
	// Both imports are emitted unconditionally, so both need a blank
	// reference: a file with only TargetDatacenter annotations would
	// otherwise leave "time" unused and fail to compile.
	buf.WriteString(`
import (
	"time"

	"github.com/hashicorp/consul/agent/structs"
)

// Reference imports to suppress errors if they are not otherwise used.
var _ structs.RPCInfo
var _ time.Month
`)

	for _, typ := range v.Types {
		ann := typ.Annotation
		if ann.WriteRequest != "" {
			fmt.Fprintf(&buf, tmplWriteRequest, typ.Name, ann.WriteRequest)
		}
		if ann.ReadRequest != "" {
			fmt.Fprintf(&buf, tmplReadRequest, typ.Name, ann.ReadRequest)
		}
		if ann.TargetDatacenter != "" {
			fmt.Fprintf(&buf, tmplTargetDatacenter, typ.Name, ann.TargetDatacenter)
		}
		if ann.QueryOptions != "" {
			fmt.Fprintf(&buf, tmplQueryOptions, typ.Name, ann.QueryOptions)
		}
		if ann.QueryMeta != "" {
			fmt.Fprintf(&buf, tmplQueryMeta, typ.Name, ann.QueryMeta)
		}
	}

	return buf.Bytes()
}

// TypeInfo pairs the name of an annotated type with its parsed annotation.
type TypeInfo struct {
	Name       string
	Annotation Annotation
}

// visitor is an ast.Visitor that records the package name, the raw build-tag
// lines, every annotated type, and any annotation parse errors found in a
// file.
type visitor struct {
	Package   string
	BuildTags []string
	Types     []TypeInfo
	Errs      []error
}

// Err folds the collected errors into one: nil when there are none, the sole
// error when there is exactly one, otherwise a new error whose message is the
// individual messages joined with "; ".
func (v *visitor) Err() error {
	switch len(v.Errs) {
	case 0:
		return nil
	case 1:
		return v.Errs[0]
	default:
		msgs := make([]string, 0, len(v.Errs))
		for _, e := range v.Errs {
			msgs = append(msgs, e.Error())
		}
		return errors.New(strings.Join(msgs, "; "))
	}
}

var _ ast.Visitor = (*visitor)(nil)

// Visit implements ast.Visitor. Only *ast.File nodes are inspected: each
// top-level declaration that is a documented GenDecl holding exactly one
// TypeSpec has its doc comment checked for an annotation. Annotation parse
// errors are appended to v.Errs rather than aborting the scan.
func (v *visitor) Visit(node ast.Node) ast.Visitor {
	file, ok := node.(*ast.File)
	if !ok {
		// Also covers the nil node passed at the end of a subtree.
		return v
	}

	v.Package = file.Name.Name
	v.BuildTags = getRawBuildTags(file)

	for _, d := range file.Decls {
		gd, ok := d.(*ast.GenDecl)
		if !ok || gd.Doc == nil || len(gd.Specs) != 1 {
			continue
		}

		typeSpec, ok := gd.Specs[0].(*ast.TypeSpec)
		if !ok {
			continue
		}

		ann, err := getAnnotation(gd.Doc.List)
		if err != nil {
			v.Errs = append(v.Errs, err)
			continue
		}
		if ann.IsZero() {
			continue
		}

		v.Types = append(v.Types, TypeInfo{
			Name:       typeSpec.Name.Name,
			Annotation: ann,
		})
	}

	// Everything needed hangs directly off the file node, so there is no
	// reason to descend into its children.
	return nil
}

// Annotation holds, for each supported glue kind, the name of the struct
// field the generated methods delegate to. An empty string means the kind was
// not requested.
type Annotation struct {
	QueryMeta        string
	QueryOptions     string
	ReadRequest      string
	WriteRequest     string
	TargetDatacenter string
}

// IsZero reports whether no glue kind is set.
func (a Annotation) IsZero() bool {
	return a == Annotation{}
}

// getAnnotation parses the annotation found in doc, if any. The annotation is
// a comma-separated list of parts, each either "Kind" (the field name
// defaults to the kind name) or "Kind=Field". It returns the zero Annotation
// and a nil error when doc carries no annotation line, and an error when a
// part names an unknown kind.
func getAnnotation(doc []*ast.Comment) (Annotation, error) {
	raw, ok := getRawStructAnnotation(doc)
	if !ok {
		return Annotation{}, nil
	}

	var ann Annotation

	for _, part := range strings.Split(raw, ",") {
		part = strings.TrimSpace(part)

		// A bare kind uses its own name as the field name.
		kind, field := part, part
		if i := strings.Index(part, "="); i >= 0 {
			kind, field = part[:i], part[i+1:]
		}

		switch kind {
		case "ReadRequest":
			ann.ReadRequest = field
		case "WriteRequest":
			ann.WriteRequest = field
		case "TargetDatacenter":
			ann.TargetDatacenter = field
		case "QueryOptions":
			ann.QueryOptions = field
		case "QueryMeta":
			ann.QueryMeta = field
		default:
			return Annotation{}, fmt.Errorf("unexpected annotation part: %s", part)
		}
	}

	return ann, nil
}

// getRawStructAnnotation returns the text following annotationPrefix on the
// first comment line that starts with it (after stripping leading slashes and
// surrounding whitespace). The boolean is false when no line matches.
func getRawStructAnnotation(doc []*ast.Comment) (string, bool) {
	for _, line := range doc {
		text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))

		// TrimPrefix is a no-op when the prefix is absent, so a changed
		// string means this line carried the annotation.
		ann := strings.TrimSpace(strings.TrimPrefix(text, annotationPrefix))

		if text != ann {
			return ann, true
		}
	}
	return "", false
}

// getRawBuildTags returns the leading run of build-constraint lines
// ("go:build " or "+build") from the first comment group of file, exactly as
// written. It returns nil when the file has no comments.
func getRawBuildTags(file *ast.File) []string {
	// build tags are always the first group, at the very top
	if len(file.Comments) == 0 {
		return nil
	}
	cg := file.Comments[0]

	var out []string
	for _, line := range cg.List {
		text := strings.TrimSpace(strings.TrimLeft(line.Text, "/"))

		if !strings.HasPrefix(text, "go:build ") && !strings.HasPrefix(text, "+build") {
			break // stop at first non-build-tag
		}

		out = append(out, line.Text)
	}

	return out
}

// tmplWriteRequest is the glue emitted for a WriteRequest annotation.
// %[1]s is the annotated type name and %[2]s the delegated field name.
const tmplWriteRequest = `

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return false
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return false
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`

// tmplReadRequest is the glue emitted for a ReadRequest annotation.
// %[1]s is the annotated type name and %[2]s the delegated field name.
const tmplReadRequest = `

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	// TODO: initialize if nil
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`

// tmplTargetDatacenter is the glue emitted for a TargetDatacenter annotation.
// %[1]s is the annotated type name and %[2]s the delegated field name.
const tmplTargetDatacenter = `

// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.GetDatacenter()
}
`

// tmplQueryOptions is the glue emitted for a QueryOptions annotation.
// %[1]s is the annotated type name and %[2]s the delegated field name.
const tmplQueryOptions = `

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}

// GetToken is required to implement blockingQueryOptions
func (msg *%[1]s) GetToken() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}

	return msg.%[2]s.GetToken()
}

// GetMinQueryIndex is required to implement blockingQueryOptions
func (msg *%[1]s) GetMinQueryIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}

	return msg.%[2]s.GetMinQueryIndex()
}

// GetMaxQueryTime is required to implement blockingQueryOptions
func (msg *%[1]s) GetMaxQueryTime() (time.Duration, error) {
	if msg == nil || msg.%[2]s == nil {
		return 0, nil
	}

	return structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime()), nil
}

// GetRequireConsistent is required to implement blockingQueryOptions
func (msg *%[1]s) GetRequireConsistent() bool {
	if msg == nil || msg.%[2]s == nil {
		return false
	}
	return msg.%[2]s.RequireConsistent
}
`

// tmplQueryMeta is the glue emitted for a QueryMeta annotation.
// %[1]s is the annotated type name and %[2]s the delegated field name.
const tmplQueryMeta = `

// SetLastContact is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetLastContact(d time.Duration) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetLastContact(d)
}

// SetKnownLeader is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetKnownLeader(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetKnownLeader(b)
}

// GetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) GetIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}
	return msg.%[2]s.GetIndex()
}

// SetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetIndex(i uint64) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetIndex(i)
}

// SetResultsFilteredByACLs is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetResultsFilteredByACLs(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetResultsFilteredByACLs(b)
}
`