package httppeeridauth import ( "bytes" "crypto/rand" "crypto/tls" "fmt" "io" "log/slog" "net/http" "net/http/httptest" "os" "strings" "sync" "testing" "time" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // TestMutualAuth tests that we can do a mutually authenticated round trip func TestMutualAuth(t *testing.T) { originalLogger := log defer func() { log = originalLogger }() // Override to print debug logs log = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelDebug, })) zeroBytes := make([]byte, 64) serverKey, _, err := crypto.GenerateEd25519Key(bytes.NewReader(zeroBytes)) require.NoError(t, err) type clientTestCase struct { name string clientKeyGen func(t *testing.T) crypto.PrivKey } clientTestCases := []clientTestCase{ { name: "ED25519", clientKeyGen: func(t *testing.T) crypto.PrivKey { t.Helper() clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) return clientKey }, }, { name: "RSA", clientKeyGen: func(t *testing.T) crypto.PrivKey { t.Helper() clientKey, _, err := crypto.GenerateRSAKeyPair(2048, rand.Reader) require.NoError(t, err) return clientKey }, }, } type serverTestCase struct { name string serverGen func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) } serverTestCases := []serverTestCase{ { name: "no TLS", serverGen: func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) { t.Helper() auth := ServerPeerIDAuth{ PrivKey: serverKey, ValidHostnameFn: func(s string) bool { return s == "example.com" }, TokenTTL: time.Hour, NoTLS: true, } ts := httptest.NewServer(&auth) t.Cleanup(ts.Close) return ts, &auth }, }, { name: "TLS", serverGen: func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) { t.Helper() auth := ServerPeerIDAuth{ PrivKey: serverKey, ValidHostnameFn: func(s string) bool { return s == "example.com" }, TokenTTL: time.Hour, } ts := httptest.NewTLSServer(&auth) t.Cleanup(ts.Close) return ts, &auth }, }, } for _, ctc := range clientTestCases { for _, stc := range serverTestCases { t.Run(ctc.name+"+"+stc.name, func(t *testing.T) { ts, server := stc.serverGen(t) client := ts.Client() roundTripper := instrumentedRoundTripper{client.Transport, 0} client.Transport = &roundTripper requestsSent := func() int { defer func() { roundTripper.timesRoundtripped = 0 }() return roundTripper.timesRoundtripped } tlsClientConfig := roundTripper.TLSClientConfig() if tlsClientConfig != nil { // If we're using TLS, we need to set the SNI so that the // server can verify the request Host matches it. tlsClientConfig.ServerName = "example.com" } clientKey := ctc.clientKeyGen(t) clientAuth := ClientPeerIDAuth{PrivKey: clientKey} expectedServerID, err := peer.IDFromPrivateKey(serverKey) require.NoError(t, err) req, err := http.NewRequest("POST", ts.URL, nil) require.NoError(t, err) req.Host = "example.com" serverID, resp, err := clientAuth.AuthenticatedDo(client, req) require.NoError(t, err) require.Equal(t, expectedServerID, serverID) require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, 2, requestsSent()) // Once more with the auth token req, err = http.NewRequest("POST", ts.URL, nil) require.NoError(t, err) req.Host = "example.com" serverID, resp, err = clientAuth.AuthenticatedDo(client, req) require.NotEmpty(t, req.Header.Get("Authorization")) require.True(t, HasAuthHeader(req)) require.NoError(t, err) require.Equal(t, expectedServerID, serverID) require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, 1, requestsSent(), "should only call newRequest once since we have a token") t.Run("Tokens Expired", func(t *testing.T) { // Clear the auth token on the server side server.TokenTTL = 1 // Small TTL time.Sleep(100 * time.Millisecond) resetServerTokenTTL := sync.OnceFunc(func() { server.TokenTTL = time.Hour }) req, err := http.NewRequest("POST", ts.URL, nil) require.NoError(t, err) req.Host = "example.com" req.GetBody = func() (io.ReadCloser, error) { resetServerTokenTTL() return nil, nil } serverID, resp, err = clientAuth.AuthenticatedDo(client, req) require.NoError(t, err) require.NotEmpty(t, req.Header.Get("Authorization")) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, expectedServerID, serverID) require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) require.Equal(t, 3, requestsSent(), "should call newRequest 3x since our token expired") }) t.Run("Tokens Invalidated", func(t *testing.T) { // Clear the auth token on the server side key := make([]byte, 32) _, err := rand.Read(key) if err != nil { panic(err) } server.hmacPool = newHmacPool(key) req, err := http.NewRequest("POST", ts.URL, nil) req.GetBody = func() (io.ReadCloser, error) { return nil, nil } require.NoError(t, err) req.Host = "example.com" serverID, resp, err = clientAuth.AuthenticatedDo(client, req) require.NoError(t, err) require.NotEmpty(t, req.Header.Get("Authorization")) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, expectedServerID, serverID) require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) require.Equal(t, 3, requestsSent(), "should call have sent 3 reqs since our token expired") }) }) } } } func TestBodyNotSentDuringRedirect(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) assert.NoError(t, err) assert.Empty(t, string(b)) if r.URL.Path != "/redirected" { w.Header().Set("Location", "/redirected") w.WriteHeader(http.StatusTemporaryRedirect) return } })) t.Cleanup(ts.Close) client := ts.Client() clientKey, _, _ := crypto.GenerateEd25519Key(rand.Reader) clientAuth := ClientPeerIDAuth{PrivKey: clientKey} req, err := http.NewRequest( "POST", ts.URL, strings.NewReader("Only for authenticated servers"), ) req.Host = "example.com" require.NoError(t, err) _, _, err = clientAuth.AuthenticatedDo(client, req) require.ErrorContains(t, err, "signature not set") // server doesn't actually handshake } type instrumentedRoundTripper struct { http.RoundTripper timesRoundtripped int } func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { irt.timesRoundtripped++ return irt.RoundTripper.RoundTrip(req) } func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config { return irt.RoundTripper.(*http.Transport).TLSClientConfig } func TestConcurrentAuth(t *testing.T) { serverKey, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) auth := ServerPeerIDAuth{ PrivKey: serverKey, ValidHostnameFn: func(s string) bool { return s == "example.com" }, TokenTTL: time.Hour, NoTLS: true, Next: func(_ peer.ID, w http.ResponseWriter, r *http.Request) { reqBody, err := io.ReadAll(r.Body) require.NoError(t, err) _, err = w.Write(reqBody) require.NoError(t, err) }, } ts := httptest.NewServer(&auth) t.Cleanup(ts.Close) wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) clientAuth := ClientPeerIDAuth{PrivKey: clientKey} reqBody := []byte(fmt.Sprintf("echo %d", i)) req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(reqBody)) require.NoError(t, err) req.Host = "example.com" client := ts.Client() _, resp, err := clientAuth.AuthenticatedDo(client, req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, reqBody, respBody) }() } wg.Wait() }