implement http trailers for errors after headers are sent

refactor http handler and copyChunks to get this all to work correctly
License: MIT
Signed-off-by: Jeromy <jeromyj@gmail.com>
This commit is contained in:
Jeromy 2015-07-24 17:41:59 -07:00
parent 814f437fb4
commit a7e50f1fbc
2 changed files with 64 additions and 41 deletions

View File

@ -3,6 +3,7 @@ package http
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -183,16 +184,16 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error
res.SetCloser(httpRes.Body)
if len(httpRes.Header.Get(streamHeader)) > 0 {
if len(httpRes.Header.Get(streamHeader)) > 0 && contentType != "application/json" {
// if output is a stream, we can just use the body reader
res.SetOutput(httpRes.Body)
res.SetOutput(&httpResponseReader{httpRes})
return res, nil
} else if len(httpRes.Header.Get(channelHeader)) > 0 {
// if output is coming from a channel, decode each chunk
outChan := make(chan interface{})
go func() {
dec := json.NewDecoder(httpRes.Body)
dec := json.NewDecoder(&httpResponseReader{httpRes})
outputType := reflect.TypeOf(req.Command().Type)
ctx := req.Context()
@ -237,7 +238,7 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error
return res, nil
}
dec := json.NewDecoder(httpRes.Body)
dec := json.NewDecoder(&httpResponseReader{httpRes})
if httpRes.StatusCode >= http.StatusBadRequest {
e := cmds.Error{}
@ -284,3 +285,30 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error
return res, nil
}
type httpResponseReader struct {
resp *http.Response
}
func (r *httpResponseReader) Read(b []byte) (int, error) {
n, err := r.resp.Body.Read(b)
if err == io.EOF {
_ = r.resp.Body.Close()
trailerErr := r.checkError()
if trailerErr != nil {
return n, trailerErr
}
}
return n, err
}
func (r *httpResponseReader) checkError() error {
if e := r.resp.Trailer.Get(StreamErrHeader); e != "" {
return errors.New(e)
}
return nil
}
func (r *httpResponseReader) Close() error {
return r.resp.Body.Close()
}

View File

@ -118,43 +118,43 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// call the command
res := i.root.Call(req)
// set the Content-Type based on res output
// now handle responding to the client properly
sendResponse(w, req, res)
}
func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) {
var mime string
if _, ok := res.Output().(io.Reader); ok {
mime = ""
// we don't set the Content-Type for streams, so that browsers can MIME-sniff the type themselves
// we set this header so clients have a way to know this is an output stream
// (not marshalled command output)
// TODO: set a specific Content-Type if the command response needs it to be a certain type
w.Header().Set(streamHeader, "1")
} else {
enc, found, err := req.Option(cmds.EncShort).String()
// Try to guess mimeType from the encoding option
enc, found, err := res.Request().Option(cmds.EncShort).String()
if err != nil || !found {
w.WriteHeader(http.StatusInternalServerError)
return
}
mime := mimeTypes[enc]
w.Header().Set(contentTypeHeader, mime)
}
// set the Content-Length from the response length
if res.Length() > 0 {
w.Header().Set(contentLengthHeader, strconv.FormatUint(res.Length(), 10))
mime = mimeTypes[enc]
}
status := 200
// if response contains an error, write an HTTP error status code
if e := res.Error(); e != nil {
if e.Code == cmds.ErrClient {
w.WriteHeader(http.StatusBadRequest)
status = http.StatusBadRequest
} else {
w.WriteHeader(http.StatusInternalServerError)
status = http.StatusInternalServerError
}
// TODO: do we just ignore this error? or what?
}
out, err := res.Reader()
if err != nil {
w.Header().Set(contentTypeHeader, "text/plain")
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -167,13 +167,11 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
streamChans, _, _ := req.Option("stream-channels").Bool()
if isChan && streamChans {
if err := copyChunks(applicationJson, w, out); err != nil {
log.Error("error while writing stream", err)
}
return
// streaming output from a channel will always be json objects
mime = applicationJson
}
if err := flushCopy(w, out); err != nil {
if err := copyChunks(mime, status, isChan, res.Length(), w, out); err != nil {
log.Error("error while writing stream", err)
}
}
@ -183,20 +181,9 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
i.corsHandler.ServeHTTP(w, r)
}
// flushCopy Copies from an io.Reader to a http.ResponseWriter.
// Flushes chunks over HTTP stream as they are read (if supported by transport).
func flushCopy(w http.ResponseWriter, out io.Reader) error {
if _, ok := w.(http.Flusher); !ok {
return copyChunks("", w, out)
}
_, err := io.Copy(&flushResponse{w}, out)
return err
}
// Copies from an io.Reader to a http.ResponseWriter.
// Flushes chunks over HTTP stream as they are read (if supported by transport).
func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error {
func copyChunks(contentType string, status int, channel bool, length uint64, w http.ResponseWriter, out io.Reader) error {
hijacker, ok := w.(http.Hijacker)
if !ok {
return errors.New("Could not create hijacker")
@ -207,12 +194,20 @@ func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error
}
defer conn.Close()
writer.WriteString("HTTP/1.1 200 OK\r\n")
writer.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status)))
writer.WriteString(streamHeader + ": 1\r\n")
if contentType != "" {
writer.WriteString(contentTypeHeader + ": " + contentType + "\r\n")
}
if channel {
writer.WriteString(channelHeader + ": 1\r\n")
}
if length > 0 {
w.Header().Set(contentLengthHeader, strconv.FormatUint(length, 10))
}
writer.WriteString(transferEncodingHeader + ": chunked\r\n")
writer.WriteString(channelHeader + ": 1\r\n\r\n")
writer.WriteString("\r\n")
writeChunks := func() error {
buf := make([]byte, 32*1024)
@ -248,11 +243,11 @@ func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error
// if there was a stream error, write out an error trailer. hopefully
// the client will pick it up!
if streamErr != nil {
writer.WriteString(StreamErrHeader + ": " + sanitizedErrStr(err) + "\r\n")
writer.WriteString(StreamErrHeader + ": " + sanitizedErrStr(streamErr) + "\r\n")
}
writer.WriteString("\r\n") // close response
writer.Flush()
return nil
return streamErr
}
func sanitizedErrStr(err error) string {