From 2b06ffaa7d42499abda471097c0159ee2faf3d76 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Sat, 25 Jul 2015 18:46:44 -0700 Subject: [PATCH] better refactor of http handler code License: MIT Signed-off-by: Jeromy --- commands/http/handler.go | 151 ++++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 65 deletions(-) diff --git a/commands/http/handler.go b/commands/http/handler.go index 5cb349c50..57ef0a447 100644 --- a/commands/http/handler.go +++ b/commands/http/handler.go @@ -1,6 +1,7 @@ package http import ( + "bufio" "errors" "fmt" "io" @@ -71,6 +72,11 @@ func NewHandler(ctx cmds.Context, root *cmds.Command, allowedOrigin string) *Han return &Handler{internal, c.Handler(internal)} } +func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Call the CORS handler which wraps the internal handler. + i.corsHandler.ServeHTTP(w, r) +} + func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Debug("Incoming API request: ", r.URL) @@ -102,8 +108,8 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // get the node's context to pass into the commands. node, err := i.ctx.GetNode() if err != nil { - err = fmt.Errorf("cmds/http: couldn't GetNode(): %s", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + s := fmt.Sprintf("cmds/http: couldn't GetNode(): %s", err) + http.Error(w, s, http.StatusInternalServerError) return } @@ -122,23 +128,32 @@ func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sendResponse(w, req, res) } -func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { - - var mime string +func guessMimeType(res cmds.Response) (string, error) { if _, ok := res.Output().(io.Reader); ok { - mime = "" // we don't set the Content-Type for streams, so that browsers can MIME-sniff the type themselves // we set this header so clients have a way to know this is an output stream // (not marshalled command output) // TODO: set a specific Content-Type if the command response needs it to be a certain type - } else { - // Try to guess mimeType from the encoding option - enc, found, err := res.Request().Option(cmds.EncShort).String() - if err != nil || !found { - w.WriteHeader(http.StatusInternalServerError) - return - } - mime = mimeTypes[enc] + return "", nil + } + + // Try to guess mimeType from the encoding option + enc, found, err := res.Request().Option(cmds.EncShort).String() + if err != nil { + return "", err + } + if !found { + return "", errors.New("no encoding option set") + } + + return mimeTypes[enc], nil +} + +func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { + mime, err := guessMimeType(res) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } status := 200 @@ -149,7 +164,7 @@ func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { } else { status = http.StatusInternalServerError } - // TODO: do we just ignore this error? or what? + // NOTE: The error will actually be written out by the reader below } out, err := res.Reader() @@ -158,6 +173,11 @@ func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { return } + h := w.Header() + if res.Length() > 0 { + h.Set(contentLengthHeader, strconv.FormatUint(res.Length(), 10)) + } + // if output is a channel and user requested streaming channels, // use chunk copier for the output _, isChan := res.Output().(chan interface{}) @@ -166,26 +186,30 @@ func sendResponse(w http.ResponseWriter, req cmds.Request, res cmds.Response) { } streamChans, _, _ := req.Option("stream-channels").Bool() - if isChan && streamChans { - // streaming output from a channel will always be json objects - mime = applicationJson + if isChan { + h.Set(channelHeader, "1") + if streamChans { + // streaming output from a channel will always be json objects + mime = applicationJson + } } + if mime != "" { + h.Set(contentTypeHeader, mime) + } + h.Set(streamHeader, "1") + h.Set(transferEncodingHeader, "chunked") - if err := copyChunks(mime, status, isChan, res.Length(), w, out); err != nil { + if err := copyChunks(status, w, out); err != nil { log.Error("error while writing stream", err) } } -func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Call the CORS handler which wraps the internal handler. - i.corsHandler.ServeHTTP(w, r) -} - // Copies from an io.Reader to a http.ResponseWriter. // Flushes chunks over HTTP stream as they are read (if supported by transport). -func copyChunks(contentType string, status int, channel bool, length uint64, w http.ResponseWriter, out io.Reader) error { +func copyChunks(status int, w http.ResponseWriter, out io.Reader) error { hijacker, ok := w.(http.Hijacker) if !ok { + log.Error("Failed to create hijacker! cannot continue!") return errors.New("Could not create hijacker") } conn, writer, err := hijacker.Hijack() @@ -194,51 +218,20 @@ func copyChunks(contentType string, status int, channel bool, length uint64, w h } defer conn.Close() + // write status writer.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status))) - writer.WriteString(streamHeader + ": 1\r\n") - if contentType != "" { - writer.WriteString(contentTypeHeader + ": " + contentType + "\r\n") - } - if channel { - writer.WriteString(channelHeader + ": 1\r\n") - } - if length > 0 { - w.Header().Set(contentLengthHeader, strconv.FormatUint(length, 10)) - } - writer.WriteString(transferEncodingHeader + ": chunked\r\n") + // Write out headers + w.Header().Write(writer) + + // end of headers writer.WriteString("\r\n") - writeChunks := func() error { - buf := make([]byte, 32*1024) - for { - n, err := out.Read(buf) + // write body + streamErr := writeChunks(out, writer) - if n > 0 { - length := fmt.Sprintf("%x\r\n", n) - writer.WriteString(length) - - _, err := writer.Write(buf[0:n]) - if err != nil { - return err - } - - writer.WriteString("\r\n") - writer.Flush() - } - - if err != nil && err != io.EOF { - return err - } - if err == io.EOF { - break - } - } - return nil - } - - streamErr := writeChunks() - writer.WriteString("0\r\n") // close body + // close body + writer.WriteString("0\r\n") // if there was a stream error, write out an error trailer. hopefully // the client will pick it up! @@ -250,6 +243,34 @@ func copyChunks(contentType string, status int, channel bool, length uint64, w h return streamErr } +func writeChunks(r io.Reader, w *bufio.ReadWriter) error { + buf := make([]byte, 32*1024) + for { + n, err := r.Read(buf) + + if n > 0 { + length := fmt.Sprintf("%x\r\n", n) + w.WriteString(length) + + _, err := w.Write(buf[0:n]) + if err != nil { + return err + } + + w.WriteString("\r\n") + w.Flush() + } + + if err != nil && err != io.EOF { + return err + } + if err == io.EOF { + break + } + } + return nil +} + func sanitizedErrStr(err error) string { s := err.Error() s = strings.Split(s, "\n")[0]