diff --git a/commands/http/client.go b/commands/http/client.go index 8dccfcf21..a768205e8 100644 --- a/commands/http/client.go +++ b/commands/http/client.go @@ -3,6 +3,7 @@ package http import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -183,16 +184,16 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error res.SetCloser(httpRes.Body) - if len(httpRes.Header.Get(streamHeader)) > 0 { + if len(httpRes.Header.Get(streamHeader)) > 0 && contentType != "application/json" { // if output is a stream, we can just use the body reader - res.SetOutput(httpRes.Body) + res.SetOutput(&httpResponseReader{httpRes}) return res, nil } else if len(httpRes.Header.Get(channelHeader)) > 0 { // if output is coming from a channel, decode each chunk outChan := make(chan interface{}) go func() { - dec := json.NewDecoder(httpRes.Body) + dec := json.NewDecoder(&httpResponseReader{httpRes}) outputType := reflect.TypeOf(req.Command().Type) ctx := req.Context() @@ -237,7 +238,7 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error return res, nil } - dec := json.NewDecoder(httpRes.Body) + dec := json.NewDecoder(&httpResponseReader{httpRes}) if httpRes.StatusCode >= http.StatusBadRequest { e := cmds.Error{} @@ -284,3 +285,30 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error return res, nil } + +type httpResponseReader struct { + resp *http.Response +} + +func (r *httpResponseReader) Read(b []byte) (int, error) { + n, err := r.resp.Body.Read(b) + if err == io.EOF { + _ = r.resp.Body.Close() + trailerErr := r.checkError() + if trailerErr != nil { + return n, trailerErr + } + } + return n, err +} + +func (r *httpResponseReader) checkError() error { + if e := r.resp.Trailer.Get(StreamErrHeader); e != "" { + return errors.New(e) + } + return nil +} + +func (r *httpResponseReader) Close() error { + return r.resp.Body.Close() +} diff --git a/commands/http/handler.go b/commands/http/handler.go index 855c195ea..5cb349c50 100644 --- a/commands/http/handler.go +++ b/commands/http/handler.go @@ -118,43 +118,43 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // call the command res := i.root.Call(req) - // set the Content-Type based on res output + // now handle responding to the client properly + sendResponse(w, req, res) +} + +func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { + + var mime string if _, ok := res.Output().(io.Reader); ok { + mime = "" // we don't set the Content-Type for streams, so that browsers can MIME-sniff the type themselves // we set this header so clients have a way to know this is an output stream // (not marshalled command output) // TODO: set a specific Content-Type if the command response needs it to be a certain type - w.Header().Set(streamHeader, "1") - } else { - enc, found, err := req.Option(cmds.EncShort).String() + // Try to guess mimeType from the encoding option + enc, found, err := res.Request().Option(cmds.EncShort).String() if err != nil || !found { w.WriteHeader(http.StatusInternalServerError) return } - mime := mimeTypes[enc] - w.Header().Set(contentTypeHeader, mime) - } - - // set the Content-Length from the response length - if res.Length() > 0 { - w.Header().Set(contentLengthHeader, strconv.FormatUint(res.Length(), 10)) + mime = mimeTypes[enc] } + status := 200 // if response contains an error, write an HTTP error status code if e := res.Error(); e != nil { if e.Code == cmds.ErrClient { - w.WriteHeader(http.StatusBadRequest) + status = http.StatusBadRequest } else { - w.WriteHeader(http.StatusInternalServerError) + status = http.StatusInternalServerError } + // TODO: do we just ignore this error? or what? } out, err := res.Reader() if err != nil { - w.Header().Set(contentTypeHeader, "text/plain") - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) + http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -167,13 +167,11 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { streamChans, _, _ := req.Option("stream-channels").Bool() if isChan && streamChans { - if err := copyChunks(applicationJson, w, out); err != nil { - log.Error("error while writing stream", err) - } - return + // streaming output from a channel will always be json objects + mime = applicationJson } - if err := flushCopy(w, out); err != nil { + if err := copyChunks(mime, status, isChan, res.Length(), w, out); err != nil { log.Error("error while writing stream", err) } } @@ -183,20 +181,9 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { i.corsHandler.ServeHTTP(w, r) } -// flushCopy Copies from an io.Reader to a http.ResponseWriter. -// Flushes chunks over HTTP stream as they are read (if supported by transport). -func flushCopy(w http.ResponseWriter, out io.Reader) error { - if _, ok := w.(http.Flusher); !ok { - return copyChunks("", w, out) - } - - _, err := io.Copy(&flushResponse{w}, out) - return err -} - // Copies from an io.Reader to a http.ResponseWriter. // Flushes chunks over HTTP stream as they are read (if supported by transport). -func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error { +func copyChunks(contentType string, status int, channel bool, length uint64, w http.ResponseWriter, out io.Reader) error { hijacker, ok := w.(http.Hijacker) if !ok { return errors.New("Could not create hijacker") @@ -207,12 +194,20 @@ func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error } defer conn.Close() - writer.WriteString("HTTP/1.1 200 OK\r\n") + writer.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status))) + writer.WriteString(streamHeader + ": 1\r\n") if contentType != "" { writer.WriteString(contentTypeHeader + ": " + contentType + "\r\n") } + if channel { + writer.WriteString(channelHeader + ": 1\r\n") + } + if length > 0 { + w.Header().Set(contentLengthHeader, strconv.FormatUint(length, 10)) + } writer.WriteString(transferEncodingHeader + ": chunked\r\n") - writer.WriteString(channelHeader + ": 1\r\n\r\n") + + writer.WriteString("\r\n") writeChunks := func() error { buf := make([]byte, 32*1024) @@ -248,11 +243,11 @@ func copyChunks(contentType string, w http.ResponseWriter, out io.Reader) error // if there was a stream error, write out an error trailer. hopefully // the client will pick it up! if streamErr != nil { - writer.WriteString(StreamErrHeader + ": " + sanitizedErrStr(err) + "\r\n") + writer.WriteString(StreamErrHeader + ": " + sanitizedErrStr(streamErr) + "\r\n") } writer.WriteString("\r\n") // close response writer.Flush() - return nil + return streamErr } func sanitizedErrStr(err error) string {