Go语言http.Get()超时设置(更新)

Go by 达达 at 2014-02-26

上次写了一遍日志分享http.Get()设置超时的方案,后来自己过了一遍代码发现逻辑上有问题。

在Dail之后设置了Deadline,之后就没再重新设置。这对于不重用连接的http请求是没有问题的,但是Go的http库是支持keep-alive的,可以重用TCP/IP连接。这意味着一个连接过了超时时间后再被使用,就会出现超时错误,因为没有再重置超时时间。

拿上次的实验代码,在发送请求时加一个keep-alive头,然后每次发送请求前加个Sleep,就可以重新以上情况。

怎样做到每次使用一个连接发送和接收前就设置超时呢?我想了个办法是在Dial回调返回自己包装过的TimeoutConn,间接的调用真实的Conn,这样就可以再每次Read和Write之前设置超时时间了。

以下是修改后的实验代码:

//
// How to set timeout for http.Get() in golang
//
package main

import (
    "io"
    "io/ioutil"
    "log"
    "net"
    "net/http"
    "sync"
    "time"
)

type TimeoutConn struct {
    conn    net.Conn
    timeout time.Duration
}

func NewTimeoutConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
    return &TimeoutConn{
        conn:    conn,
        timeout: timeout,
    }
}

func (c *TimeoutConn) Read(b []byte) (n int, err error) {
    c.SetReadDeadline(time.Now().Add(c.timeout))
    return c.conn.Read(b)
}

func (c *TimeoutConn) Write(b []byte) (n int, err error) {
    c.SetWriteDeadline(time.Now().Add(c.timeout))
    return c.conn.Write(b)
}

func (c *TimeoutConn) Close() error {
    return c.conn.Close()
}

func (c *TimeoutConn) LocalAddr() net.Addr {
    return c.conn.LocalAddr()
}

func (c *TimeoutConn) RemoteAddr() net.Addr {
    return c.conn.RemoteAddr()
}

func (c *TimeoutConn) SetDeadline(t time.Time) error {
    return c.conn.SetDeadline(t)
}

func (c *TimeoutConn) SetReadDeadline(t time.Time) error {
    return c.conn.SetReadDeadline(t)
}

func (c *TimeoutConn) SetWriteDeadline(t time.Time) error {
    return c.conn.SetWriteDeadline(t)
}

func main() {
    client := &http.Client{
        Transport: &http.Transport{
            Dial: func(netw, addr string) (net.Conn, error) {
                log.Printf("dial to %s://%s", netw, addr)

                conn, err := net.DialTimeout(netw, addr, time.Second*2)

                if err != nil {
                    return nil, err
                }

                return NewTimeoutConn(conn, time.Second*2), nil
            },
            ResponseHeaderTimeout: time.Second * 2,
        },
    }

    addr := StartTestServer()

    SendTestRequest(client, "1st", addr, "normal")
    SendTestRequest(client, "2st", addr, "normal")
    SendTestRequest(client, "3st", addr, "timeout")
    SendTestRequest(client, "4st", addr, "normal")

    time.Sleep(time.Second * 3)

    SendTestRequest(client, "5st", addr, "normal")
}

func StartTestServer() string {
    listener, err := net.Listen("tcp", ":0")

    if err != nil {
        log.Fatalf("failed to listen - %s", err.Error())
    }

    wg := new(sync.WaitGroup)
    wg.Add(1)

    go func() {
        http.HandleFunc("/normal", func(w http.ResponseWriter, req *http.Request) {
            time.Sleep(1000 * time.Millisecond)
            io.WriteString(w, "ok")
        })

        http.HandleFunc("/timeout", func(w http.ResponseWriter, req *http.Request) {
            time.Sleep(2500 * time.Millisecond)
            io.WriteString(w, "ok")
        })

        wg.Done()

        err = http.Serve(listener, nil)

        if err != nil {
            log.Fatalf("failed to start HTTP server - %s", err.Error())
        }
    }()

    wg.Wait()

    log.Printf("start http server at http://%s/", listener.Addr())

    return listener.Addr().String()
}

func SendTestRequest(client *http.Client, id, addr, path string) {
    req, err := http.NewRequest("GET", "http://"+addr+"/"+path, nil)

    if err != nil {
        log.Fatalf("new request failed - %s", err)
    }

    req.Header.Add("Connection", "keep-alive")

    switch path {
    case "normal":
        if resp, err := client.Do(req); err != nil {
            log.Fatalf("%s request failed - %s", id, err)
        } else {
            result, err2 := ioutil.ReadAll(resp.Body)
            if err2 != nil {
                log.Fatalf("%s response read failed - %s", id, err2)
            }
            resp.Body.Close()
            log.Printf("%s request - %s", id, result)
        }
    case "timeout":
        if _, err := client.Do(req); err == nil {
            log.Fatalf("%s request not timeout", id)
        } else {
            log.Printf("%s request - %s", id, err)
        }
    }
}