You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

145 lines
3.2 KiB

  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package oauth2
  5. import (
  6. "errors"
  7. "io"
  8. "net/http"
  9. "sync"
  10. )
  11. // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
  12. // wrapping a base RoundTripper and adding an Authorization header
  13. // with a token from the supplied Sources.
  14. //
  15. // Transport is a low-level mechanism. Most code will use the
  16. // higher-level Config.Client method instead.
  17. type Transport struct {
  18. // Source supplies the token to add to outgoing requests'
  19. // Authorization headers.
  20. Source TokenSource
  21. // Base is the base RoundTripper used to make HTTP requests.
  22. // If nil, http.DefaultTransport is used.
  23. Base http.RoundTripper
  24. mu sync.Mutex // guards modReq
  25. modReq map[*http.Request]*http.Request // original -> modified
  26. }
  27. // RoundTrip authorizes and authenticates the request with an
  28. // access token from Transport's Source.
  29. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  30. reqBodyClosed := false
  31. if req.Body != nil {
  32. defer func() {
  33. if !reqBodyClosed {
  34. req.Body.Close()
  35. }
  36. }()
  37. }
  38. if t.Source == nil {
  39. return nil, errors.New("oauth2: Transport's Source is nil")
  40. }
  41. token, err := t.Source.Token()
  42. if err != nil {
  43. return nil, err
  44. }
  45. req2 := cloneRequest(req) // per RoundTripper contract
  46. token.SetAuthHeader(req2)
  47. t.setModReq(req, req2)
  48. res, err := t.base().RoundTrip(req2)
  49. // req.Body is assumed to have been closed by the base RoundTripper.
  50. reqBodyClosed = true
  51. if err != nil {
  52. t.setModReq(req, nil)
  53. return nil, err
  54. }
  55. res.Body = &onEOFReader{
  56. rc: res.Body,
  57. fn: func() { t.setModReq(req, nil) },
  58. }
  59. return res, nil
  60. }
  61. // CancelRequest cancels an in-flight request by closing its connection.
  62. func (t *Transport) CancelRequest(req *http.Request) {
  63. type canceler interface {
  64. CancelRequest(*http.Request)
  65. }
  66. if cr, ok := t.base().(canceler); ok {
  67. t.mu.Lock()
  68. modReq := t.modReq[req]
  69. delete(t.modReq, req)
  70. t.mu.Unlock()
  71. cr.CancelRequest(modReq)
  72. }
  73. }
  74. func (t *Transport) base() http.RoundTripper {
  75. if t.Base != nil {
  76. return t.Base
  77. }
  78. return http.DefaultTransport
  79. }
  80. func (t *Transport) setModReq(orig, mod *http.Request) {
  81. t.mu.Lock()
  82. defer t.mu.Unlock()
  83. if t.modReq == nil {
  84. t.modReq = make(map[*http.Request]*http.Request)
  85. }
  86. if mod == nil {
  87. delete(t.modReq, orig)
  88. } else {
  89. t.modReq[orig] = mod
  90. }
  91. }
  92. // cloneRequest returns a clone of the provided *http.Request.
  93. // The clone is a shallow copy of the struct and its Header map.
  94. func cloneRequest(r *http.Request) *http.Request {
  95. // shallow copy of the struct
  96. r2 := new(http.Request)
  97. *r2 = *r
  98. // deep copy of the Header
  99. r2.Header = make(http.Header, len(r.Header))
  100. for k, s := range r.Header {
  101. r2.Header[k] = append([]string(nil), s...)
  102. }
  103. return r2
  104. }
  105. type onEOFReader struct {
  106. rc io.ReadCloser
  107. fn func()
  108. }
  109. func (r *onEOFReader) Read(p []byte) (n int, err error) {
  110. n, err = r.rc.Read(p)
  111. if err == io.EOF {
  112. r.runFunc()
  113. }
  114. return
  115. }
  116. func (r *onEOFReader) Close() error {
  117. err := r.rc.Close()
  118. r.runFunc()
  119. return err
  120. }
  121. func (r *onEOFReader) runFunc() {
  122. if fn := r.fn; fn != nil {
  123. fn()
  124. r.fn = nil
  125. }
  126. }