package rest import ( "bytes" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "reflect" "strconv" "strings" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/private/protocol" ) // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal} // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta} // Unmarshal unmarshals the REST component of a response in a REST service. func Unmarshal(r *request.Request) { if r.DataFilled() { v := reflect.Indirect(reflect.ValueOf(r.Data)) unmarshalBody(r, v) } } // UnmarshalMeta unmarshals the REST metadata of a response in a REST service func UnmarshalMeta(r *request.Request) { r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") if r.RequestID == "" { // Alternative version of request id in the header r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") } if r.DataFilled() { v := reflect.Indirect(reflect.ValueOf(r.Data)) unmarshalLocationElements(r, v) } } func unmarshalBody(r *request.Request, v reflect.Value) { if field, ok := v.Type().FieldByName("_"); ok { if payloadName := field.Tag.Get("payload"); payloadName != "" { pfield, _ := v.Type().FieldByName(payloadName) if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" { payload := v.FieldByName(payloadName) if payload.IsValid() { switch payload.Interface().(type) { case []byte: defer r.HTTPResponse.Body.Close() b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) } else { payload.Set(reflect.ValueOf(b)) } case *string: defer r.HTTPResponse.Body.Close() b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) } else { str := string(b) payload.Set(reflect.ValueOf(&str)) } default: switch payload.Type().String() { case "io.ReadCloser": payload.Set(reflect.ValueOf(r.HTTPResponse.Body)) case "io.ReadSeeker": b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { r.Error = awserr.New(request.ErrCodeSerialization, "failed to read response body", err) return } payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b)))) default: io.Copy(ioutil.Discard, r.HTTPResponse.Body) defer r.HTTPResponse.Body.Close() r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", fmt.Errorf("unknown payload type %s", payload.Type())) } } } } } } } func unmarshalLocationElements(r *request.Request, v reflect.Value) { for i := 0; i < v.NumField(); i++ { m, field := v.Field(i), v.Type().Field(i) if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) { continue } if m.IsValid() { name := field.Tag.Get("locationName") if name == "" { name = field.Name } switch field.Tag.Get("location") { case "statusCode": unmarshalStatusCode(m, r.HTTPResponse.StatusCode) case "header": err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag) if err != nil { r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) break } case "headers": prefix := field.Tag.Get("locationName") err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix) if err != nil { r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) break } } } if r.Error != nil { return } } } func unmarshalStatusCode(v reflect.Value, statusCode int) { if !v.IsValid() { return } switch v.Interface().(type) { case *int64: s := int64(statusCode) v.Set(reflect.ValueOf(&s)) } } func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error { if len(headers) == 0 { return nil } switch r.Interface().(type) { case map[string]*string: // we only support string map value types out := map[string]*string{} for k, v := range headers { k = http.CanonicalHeaderKey(k) if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) { out[k[len(prefix):]] = &v[0] } } if len(out) != 0 { r.Set(reflect.ValueOf(out)) } } return nil } func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error { switch tag.Get("type") { case "jsonvalue": if len(header) == 0 { return nil } case "blob": if len(header) == 0 { return nil } default: if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) { return nil } } switch v.Interface().(type) { case *string: v.Set(reflect.ValueOf(&header)) case []byte: b, err := base64.StdEncoding.DecodeString(header) if err != nil { return err } v.Set(reflect.ValueOf(b)) case *bool: b, err := strconv.ParseBool(header) if err != nil { return err } v.Set(reflect.ValueOf(&b)) case *int64: i, err := strconv.ParseInt(header, 10, 64) if err != nil { return err } v.Set(reflect.ValueOf(&i)) case *float64: f, err := strconv.ParseFloat(header, 64) if err != nil { return err } v.Set(reflect.ValueOf(&f)) case *time.Time: format := tag.Get("timestampFormat") if len(format) == 0 { format = protocol.RFC822TimeFormatName } t, err := protocol.ParseTime(format, header) if err != nil { return err } v.Set(reflect.ValueOf(&t)) case aws.JSONValue: escaping := protocol.NoEscape if tag.Get("location") == "header" { escaping = protocol.Base64Escape } m, err := protocol.DecodeJSONValue(header, escaping) if err != nil { return err } v.Set(reflect.ValueOf(m)) default: err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type()) return err } return nil }