// Copyright 2015 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
	"errors"
	"io/ioutil"
	"net/http"
	"net/url"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
)

// staticHTTPClient is an HTTPClient stub that answers every Do call with the
// same canned response and error, regardless of the context or action given.
type staticHTTPClient struct {
	resp http.Response
	err  error
}

// Do ignores its arguments and returns a pointer to the stored response, a
// nil body, and the stored error.
func (c *staticHTTPClient) Do(_ context.Context, _ HTTPAction) (*http.Response, []byte, error) {
	resp, err := &c.resp, c.err
	return resp, nil, err
}

// staticHTTPAction is an HTTPAction stub whose HTTPRequest method always
// yields the same pre-built request, ignoring the endpoint it is given.
type staticHTTPAction struct {
	request http.Request
}

// HTTPRequest returns a pointer to the stored request; the endpoint argument
// is not consulted.
func (a *staticHTTPAction) HTTPRequest(_ url.URL) *http.Request {
	req := &a.request
	return req
}

// staticHTTPResponse is one canned (response, error) pair, replayed in order
// by multiStaticHTTPClient.
type staticHTTPResponse struct {
	resp http.Response
	err  error
}

// multiStaticHTTPClient is an HTTPClient stub that replays responses in
// order: the first Do call gets responses[0], the next responses[1], and so
// on. cur is the index of the next response to hand out. Calling Do more
// times than there are responses panics with an index-out-of-range error.
type multiStaticHTTPClient struct {
	responses []staticHTTPResponse
	cur       int
}

// Do returns the next canned response (with a nil body) and advances cur.
func (c *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) {
	next := c.responses[c.cur]
	c.cur = c.cur + 1
	return &next.resp, nil, next.err
}

// fakeTransport is a scripted round-tripper (with a CancelRequest method)
// whose behavior is driven entirely by the test through its channels. A value
// on respchan makes RoundTrip succeed with that response, and a value on
// errchan makes it fail with that error. A value on startCancel, sent by
// CancelRequest or by the test directly, makes RoundTrip block until
// finishCancel is signalled and then fail.
type fakeTransport struct {
	respchan     chan *http.Response
	errchan      chan error
	startCancel  chan struct{}
	finishCancel chan struct{}
}

// newFakeTransport returns a fakeTransport whose channels each have a buffer
// of one, so a test can pre-load a single outcome without blocking.
func newFakeTransport() *fakeTransport {
	ft := new(fakeTransport)
	ft.respchan = make(chan *http.Response, 1)
	ft.errchan = make(chan error, 1)
	ft.startCancel = make(chan struct{}, 1)
	ft.finishCancel = make(chan struct{}, 1)
	return ft
}

// RoundTrip ignores the request and waits for whichever outcome is supplied
// first. On the cancellation path it blocks on finishCancel, which simulates a
// CancelRequest that takes some time to complete, and then returns an error.
func (ft *fakeTransport) RoundTrip(*http.Request) (*http.Response, error) {
	var (
		resp *http.Response
		err  error
	)
	select {
	case resp = <-ft.respchan:
	case err = <-ft.errchan:
	case <-ft.startCancel:
		<-ft.finishCancel
		err = errors.New("cancelled")
	}
	return resp, err
}

// CancelRequest signals RoundTrip to begin its simulated cancellation. It does
// not look at which request is being cancelled. Because startCancel has a
// buffer of one, a second call before RoundTrip consumes the first will block.
func (ft *fakeTransport) CancelRequest(*http.Request) {
	ft.startCancel <- struct{}{}
}

// fakeAction is an HTTPAction stub that produces a fresh, zero-valued request
// on every call, ignoring the endpoint.
type fakeAction struct{}

// HTTPRequest returns a newly allocated empty *http.Request each time.
func (*fakeAction) HTTPRequest(_ url.URL) *http.Request {
	return new(http.Request)
}

// TestHTTPClientDoSuccess checks that httpClient.Do hands back the
// transport's response status code and returns the response body as bytes.
func TestHTTPClientDoSuccess(t *testing.T) {
	transport := newFakeTransport()
	hc := &httpClient{transport: transport}

	// Pre-load the transport so the next RoundTrip succeeds immediately.
	transport.respchan <- &http.Response{
		StatusCode: http.StatusTeapot,
		Body:       ioutil.NopCloser(strings.NewReader("foo")),
	}

	resp, body, err := hc.Do(context.Background(), &fakeAction{})
	if err != nil {
		t.Fatalf("incorrect error value: want=nil got=%v", err)
	}

	if got, want := resp.StatusCode, http.StatusTeapot; want != got {
		t.Fatalf("invalid response code: want=%d got=%d", want, got)
	}

	if want := []byte("foo"); !reflect.DeepEqual(want, body) {
		t.Fatalf("invalid response body: want=%q got=%q", want, body)
	}
}

// TestHTTPClientDoError checks that an error returned by the transport's
// RoundTrip surfaces as a non-nil error from httpClient.Do.
func TestHTTPClientDoError(t *testing.T) {
	transport := newFakeTransport()
	transport.errchan <- errors.New("fixture")
	hc := &httpClient{transport: transport}

	if _, _, err := hc.Do(context.Background(), &fakeAction{}); err == nil {
		t.Fatalf("expected non-nil error, got nil")
	}
}

// TestHTTPClientDoCancelContext drives the fake transport down its
// cancellation path and checks that httpClient.Do reports an error.
//
// The context passed to Do is never cancelled here. Instead, both
// startCancel and finishCancel are pre-loaded, so when RoundTrip runs, the
// cancellation branch is the only ready case and it returns an error.
func TestHTTPClientDoCancelContext(t *testing.T) {
	transport := newFakeTransport()
	hc := &httpClient{transport: transport}

	transport.startCancel <- struct{}{}
	transport.finishCancel <- struct{}{}

	if _, _, err := hc.Do(context.Background(), &fakeAction{}); err == nil {
		t.Fatalf("expected non-nil error, got nil")
	}
}

// TestHTTPClientDoCancelContextWaitForRoundTrip checks that cancelling the
// context does not make httpClient.Do return until the in-flight RoundTrip has
// finished. The fake transport blocks inside its cancellation path until the
// test signals finishCancel.
func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
	transport := newFakeTransport()
	hc := &httpClient{transport: transport}

	done := make(chan struct{})
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer close(done)
		hc.Do(ctx, &fakeAction{})
	}()

	// This should call CancelRequest and begin the cancellation process.
	cancel()

	// NOTE(review): nothing synchronizes with the goroutine above, so this
	// check can also pass simply because Do has not started yet.
	select {
	case <-done:
		t.Fatalf("httpClient.do should not have exited yet")
	default:
	}

	// Let the fake transport's RoundTrip complete its simulated cancellation.
	transport.finishCancel <- struct{}{}

	select {
	case <-done:
		// expected behavior
	case <-time.After(time.Second):
		t.Fatalf("httpClient.do did not exit within 1s")
	}
}

// TestHTTPClusterClientDo is a table-driven test of httpClusterClient.Do's
// endpoint fallback. Each case lists the per-endpoint clients in order and
// the status code or error Do is expected to end up with.
func TestHTTPClusterClientDo(t *testing.T) {
	fakeErr := errors.New("fake!")
	cases := []struct {
		client   *httpClusterClient
		wantCode int
		wantErr  error
	}{
		// first good response short-circuits Do
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
					&staticHTTPClient{err: fakeErr},
				},
			},
			wantCode: http.StatusTeapot,
		},

		// fall through to good endpoint if err is arbitrary
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{err: fakeErr},
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
				},
			},
			wantCode: http.StatusTeapot,
		},

		// ErrTimeout short-circuits Do
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{err: ErrTimeout},
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
				},
			},
			wantErr: ErrTimeout,
		},

		// ErrCanceled short-circuits Do
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{err: ErrCanceled},
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
				},
			},
			wantErr: ErrCanceled,
		},

		// return err if there are no endpoints
		{
			client: &httpClusterClient{
				clients: []HTTPClient{},
			},
			wantErr: ErrNoEndpoints,
		},

		// return err if all endpoints return arbitrary errors
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{err: fakeErr},
					&staticHTTPClient{err: fakeErr},
				},
			},
			wantErr: fakeErr,
		},

		// 500-level errors cause Do to fallthrough to next endpoint
		{
			client: &httpClusterClient{
				clients: []HTTPClient{
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
					&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
				},
			},
			wantCode: http.StatusTeapot,
		},
	}

	for i, tc := range cases {
		resp, _, err := tc.client.Do(context.Background(), nil)
		switch {
		case !reflect.DeepEqual(tc.wantErr, err):
			t.Errorf("#%d: got err=%v, want=%v", i, err, tc.wantErr)
		case resp == nil:
			// A nil response is acceptable only when no status code is expected.
			if tc.wantCode != 0 {
				t.Errorf("#%d: resp is nil, want=%d", i, tc.wantCode)
			}
		case resp.StatusCode != tc.wantCode:
			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tc.wantCode)
		}
	}
}

// TestRedirectedHTTPAction checks that redirectedHTTPAction keeps the wrapped
// action's method but replaces its URL with the redirect location. The
// endpoint passed to HTTPRequest is deliberately different from both URLs, so
// the comparison also shows that the endpoint is not used.
func TestRedirectedHTTPAction(t *testing.T) {
	act := &redirectedHTTPAction{
		action: &staticHTTPAction{
			request: http.Request{
				Method: "DELETE",
				URL: &url.URL{
					Scheme: "https",
					Host:   "foo.example.com",
					Path:   "/ping",
				},
			},
		},
		location: url.URL{
			Scheme: "https",
			Host:   "bar.example.com",
			Path:   "/pong",
		},
	}

	want := &http.Request{
		Method: "DELETE",
		URL: &url.URL{
			Scheme: "https",
			Host:   "bar.example.com",
			Path:   "/pong",
		},
	}
	got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"})

	if !reflect.DeepEqual(want, got) {
		// Report the actual value first to match the "is ..., want ..." wording.
		t.Fatalf("HTTPRequest is %#v, want %#v", got, want)
	}
}

// TestRedirectFollowingHTTPClient is a table-driven test of
// redirectFollowingHTTPClient.Do. Each case scripts a sequence of responses
// from the wrapped client and states the final status code or error expected
// for the given redirect limit.
func TestRedirectFollowingHTTPClient(t *testing.T) {
	// redirectTo builds a 307 response whose Location header is loc. A new
	// Header map is allocated on every call, so no two fixtures share it.
	redirectTo := func(loc string) staticHTTPResponse {
		return staticHTTPResponse{
			resp: http.Response{
				StatusCode: http.StatusTemporaryRedirect,
				Header:     http.Header{"Location": []string{loc}},
			},
		}
	}
	teapot := staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}

	cases := []struct {
		max      int
		client   HTTPClient
		wantCode int
		wantErr  error
	}{
		// errors bubbled up
		{
			max: 2,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					{err: errors.New("fail!")},
				},
			},
			wantErr: errors.New("fail!"),
		},

		// no need to follow redirect if none given
		{
			max: 2,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{teapot},
			},
			wantCode: http.StatusTeapot,
		},

		// redirects if less than max
		{
			max: 2,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					redirectTo("http://example.com"),
					teapot,
				},
			},
			wantCode: http.StatusTeapot,
		},

		// succeed after reaching max redirects
		{
			max: 2,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					redirectTo("http://example.com"),
					redirectTo("http://example.com"),
					teapot,
				},
			},
			wantCode: http.StatusTeapot,
		},

		// fail at max+1 redirects
		{
			max: 1,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					redirectTo("http://example.com"),
					redirectTo("http://example.com"),
					teapot,
				},
			},
			wantErr: ErrTooManyRedirects,
		},

		// fail if Location header not set
		{
			max: 1,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					{resp: http.Response{StatusCode: http.StatusTemporaryRedirect}},
				},
			},
			wantErr: errors.New("Location header not set"),
		},

		// fail if Location header is invalid
		{
			max: 1,
			client: &multiStaticHTTPClient{
				responses: []staticHTTPResponse{
					redirectTo(":"),
				},
			},
			wantErr: errors.New("Location header not valid URL: :"),
		},
	}

	for i, tc := range cases {
		rfc := &redirectFollowingHTTPClient{client: tc.client, max: tc.max}
		resp, _, err := rfc.Do(context.Background(), nil)
		switch {
		case !reflect.DeepEqual(tc.wantErr, err):
			t.Errorf("#%d: got err=%v, want=%v", i, err, tc.wantErr)
		case resp == nil:
			// A nil response is acceptable only when no status code is expected.
			if tc.wantCode != 0 {
				t.Errorf("#%d: resp is nil, want=%d", i, tc.wantCode)
			}
		case resp.StatusCode != tc.wantCode:
			t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tc.wantCode)
		}
	}
}
