88"context"
99"crypto/rand"
1010"encoding/base64"
11- "errors"
1211"fmt"
1312"io"
1413"io/ioutil"
@@ -47,18 +46,27 @@ type DialOptions struct{
4746CompressionThreshold int
4847}
4948
50- func (opts * DialOptions ) cloneWithDefaults () * DialOptions {
49+ func (opts * DialOptions ) cloneWithDefaults (ctx context.Context ) (context.Context , context.CancelFunc , * DialOptions ){
50+ var cancel context.CancelFunc
51+
5152var o DialOptions
5253if opts != nil {
5354o = * opts
5455 }
5556if o .HTTPClient == nil {
5657o .HTTPClient = http .DefaultClient
58+ } else if opts .HTTPClient .Timeout > 0 {
59+ ctx , cancel = context .WithTimeout (ctx , opts .HTTPClient .Timeout )
60+
61+ newClient := * opts .HTTPClient
62+ newClient .Timeout = 0
63+ opts .HTTPClient = & newClient
5764 }
5865if o .HTTPHeader == nil {
5966o .HTTPHeader = http.Header {}
6067 }
61- return & o
68+
69+ return ctx , cancel , & o
6270}
6371
6472// Dial performs a WebSocket handshake on url.
@@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
8189func dial (ctx context.Context , urls string , opts * DialOptions , rand io.Reader ) (_ * Conn , _ * http.Response , err error ){
8290defer errd .Wrap (& err , "failed to WebSocket dial" )
8391
84- opts = opts .cloneWithDefaults ()
92+ var cancel context.CancelFunc
93+ ctx , cancel , opts = opts .cloneWithDefaults (ctx )
94+ if cancel != nil {
95+ defer cancel ()
96+ }
8597
8698secWebSocketKey , err := secWebSocketKey (rand )
8799if err != nil {
@@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
137149}
138150
139151func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ){
140- if opts .HTTPClient .Timeout > 0 {
141- return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
142- }
143-
144152u , err := url .Parse (urls )
145153if err != nil {
146154return nil , fmt .Errorf ("failed to parse url: %w" , err )
@@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
193201return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
194202 }
195203
196- if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ){
204+ if ! headerContainsTokenIgnoreCase (resp .Header , "Connection" , "Upgrade" ){
197205return nil , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
198206 }
199207
200- if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ){
208+ if ! headerContainsTokenIgnoreCase (resp .Header , "Upgrade" , "WebSocket" ){
201209return nil , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
202210 }
203211
@@ -242,7 +250,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
242250return nil , fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
243251 }
244252
245- copts = & * copts
253+ _copts := * copts
254+ copts = & _copts
246255
247256for _ , p := range ext .params {
248257switch p {
0 commit comments