From ce1d185a63ac3342916c2ad6b7acf5ceef05b493 Mon Sep 17 00:00:00 2001 From: Artur Troian Date: Mon, 29 Apr 2024 10:27:59 -0400 Subject: [PATCH] refactor: use reference tls verification from akash-api Signed-off-by: Artur Troian --- cmd/provider-services/cmd/leaseEvents.go | 2 +- cmd/provider-services/cmd/leaseLogs.go | 2 +- cmd/provider-services/cmd/leaseStatus.go | 2 +- cmd/provider-services/cmd/manifest.go | 4 +- .../cmd/migrate_endpoints.go | 2 +- .../cmd/migrate_hostnames.go | 4 +- cmd/provider-services/cmd/serviceStatus.go | 2 +- cmd/provider-services/cmd/shell.go | 2 +- cmd/provider-services/cmd/status.go | 2 +- gateway/grpc/server.go | 56 +------ gateway/rest/client.go | 150 ++++++++---------- gateway/rest/client_shell.go | 4 +- gateway/rest/integration_test.go | 31 ++-- gateway/rest/router_test.go | 70 +++++--- gateway/utils/utils.go | 66 ++------ go.mod | 4 +- go.sum | 8 +- 17 files changed, 158 insertions(+), 253 deletions(-) diff --git a/cmd/provider-services/cmd/leaseEvents.go b/cmd/provider-services/cmd/leaseEvents.go index 810ed781..ddbba1eb 100644 --- a/cmd/provider-services/cmd/leaseEvents.go +++ b/cmd/provider-services/cmd/leaseEvents.go @@ -88,7 +88,7 @@ func doLeaseEvents(cmd *cobra.Command) error { for _, lid := range leases { stream := result{lid: lid} prov, _ := sdk.AccAddressFromBech32(lid.Provider) - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err == nil { stream.stream, stream.error = gclient.LeaseEvents(ctx, lid, svcs, follow) } else { diff --git a/cmd/provider-services/cmd/leaseLogs.go b/cmd/provider-services/cmd/leaseLogs.go index cecf4924..f6fe1f92 100644 --- a/cmd/provider-services/cmd/leaseLogs.go +++ b/cmd/provider-services/cmd/leaseLogs.go @@ -109,7 +109,7 @@ func doLeaseLogs(cmd *cobra.Command) error { for _, lid := range leases { stream := result{lid: lid} prov, _ := sdk.AccAddressFromBech32(lid.Provider) - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err == nil { stream.stream, stream.error = gclient.LeaseLogs(ctx, lid, svcs, follow, tailLines) } else { diff --git a/cmd/provider-services/cmd/leaseStatus.go b/cmd/provider-services/cmd/leaseStatus.go index 954fc424..42e5e931 100644 --- a/cmd/provider-services/cmd/leaseStatus.go +++ b/cmd/provider-services/cmd/leaseStatus.go @@ -59,7 +59,7 @@ func doLeaseStatus(cmd *cobra.Command) error { return markRPCServerError(err) } - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } diff --git a/cmd/provider-services/cmd/manifest.go b/cmd/provider-services/cmd/manifest.go index 56379dfe..79147623 100644 --- a/cmd/provider-services/cmd/manifest.go +++ b/cmd/provider-services/cmd/manifest.go @@ -100,12 +100,12 @@ func doSendManifest(cmd *cobra.Command, sdlpath string) error { for i, lid := range leases { prov, _ := sdk.AccAddressFromBech32(lid.Provider) - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } - err = gclient.SubmitManifest(cmd.Context(), dseq, mani) + err = gclient.SubmitManifest(ctx, dseq, mani) res := result{ Provider: prov, Status: "PASS", diff --git a/cmd/provider-services/cmd/migrate_endpoints.go b/cmd/provider-services/cmd/migrate_endpoints.go index cd7408db..9c6ea179 100644 --- a/cmd/provider-services/cmd/migrate_endpoints.go +++ b/cmd/provider-services/cmd/migrate_endpoints.go @@ -44,7 +44,7 @@ func migrateEndpoints(cmd *cobra.Command, args []string) error { return markRPCServerError(err) } - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } diff --git a/cmd/provider-services/cmd/migrate_hostnames.go b/cmd/provider-services/cmd/migrate_hostnames.go index 4f388911..61b5d5b8 100644 --- a/cmd/provider-services/cmd/migrate_hostnames.go +++ b/cmd/provider-services/cmd/migrate_hostnames.go @@ -43,7 +43,7 @@ func migrateHostnames(cmd *cobra.Command, args []string) error { return markRPCServerError(err) } - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } @@ -58,7 +58,7 @@ func migrateHostnames(cmd *cobra.Command, args []string) error { return err } - err = gclient.MigrateHostnames(cmd.Context(), hostnames, dseq, gseq) + err = gclient.MigrateHostnames(ctx, hostnames, dseq, gseq) if err != nil { return showErrorToUser(err) } diff --git a/cmd/provider-services/cmd/serviceStatus.go b/cmd/provider-services/cmd/serviceStatus.go index c64fe67c..86a3162c 100644 --- a/cmd/provider-services/cmd/serviceStatus.go +++ b/cmd/provider-services/cmd/serviceStatus.go @@ -67,7 +67,7 @@ func doServiceStatus(cmd *cobra.Command) error { return markRPCServerError(err) } - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } diff --git a/cmd/provider-services/cmd/shell.go b/cmd/provider-services/cmd/shell.go index 36d5327b..77abd61d 100644 --- a/cmd/provider-services/cmd/shell.go +++ b/cmd/provider-services/cmd/shell.go @@ -129,7 +129,7 @@ func doLeaseShell(cmd *cobra.Command, args []string) error { return markRPCServerError(err) } - gclient, err := gwrest.NewClient(cl, prov, []tls.Certificate{cert}) + gclient, err := gwrest.NewClient(ctx, cl, prov, []tls.Certificate{cert}) if err != nil { return err } diff --git a/cmd/provider-services/cmd/status.go b/cmd/provider-services/cmd/status.go index 272bded9..026c7651 100644 --- a/cmd/provider-services/cmd/status.go +++ b/cmd/provider-services/cmd/status.go @@ -43,7 +43,7 @@ func doStatus(cmd *cobra.Command, addr sdk.Address) error { return err } - gclient, err := gwrest.NewClient(cl, addr, nil) + gclient, err := gwrest.NewClient(ctx, cl, addr, nil) if err != nil { return err } diff --git a/gateway/grpc/server.go b/gateway/grpc/server.go index 941f7cc8..4dfb8cd0 100644 --- a/gateway/grpc/server.go +++ b/gateway/grpc/server.go @@ -3,11 +3,11 @@ package grpc import ( "crypto/tls" "crypto/x509" - "errors" "fmt" "net" "time" + atls "github.com/akash-network/akash-api/go/util/tls" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -121,61 +121,11 @@ func mtlsInterceptor() grpc.UnaryServerInterceptor { certificates := mtls.State.PeerCertificates if len(certificates) > 0 { - if len(certificates) != 1 { - return nil, fmt.Errorf("tls: invalid certificate chain") // nolint: goerr113 - } - cquery := QueryClientFromCtx(ctx) - cert := certificates[0] - - // validation - var owner sdk.Address - if owner, err = sdk.AccAddressFromBech32(cert.Subject.CommonName); err != nil { - return nil, fmt.Errorf("tls: invalid certificate's subject common name: %w", err) - } - - // 1. CommonName in issuer and Subject must match and be as Bech32 format - if cert.Subject.CommonName != cert.Issuer.CommonName { - return nil, fmt.Errorf("tls: invalid certificate's issuer common name: %w", err) - } - - // 2. serial number must be in - if cert.SerialNumber == nil { - return nil, fmt.Errorf("tls: invalid certificate serial number: %w", err) - } - - // 3. look up certificate on chain - var resp *ctypes.QueryCertificatesResponse - resp, err = cquery.Certificates( - ctx, - &ctypes.QueryCertificatesRequest{ - Filter: ctypes.CertificateFilter{ - Owner: owner.String(), - Serial: cert.SerialNumber.String(), - State: "valid", - }, - }, - ) + owner, _, err := atls.ValidatePeerCertificates(ctx, cquery, certificates, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) if err != nil { - return nil, fmt.Errorf("tls: unable to fetch certificate from chain: %w", err) - } - if (len(resp.Certificates) != 1) || !resp.Certificates[0].Certificate.IsState(ctypes.CertificateValid) { - return nil, errors.New("tls: attempt to use non-existing or revoked certificate") // nolint: goerr113 - } - - clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(cert) - - opts := x509.VerifyOptions{ - Roots: clientCertPool, - CurrentTime: time.Now(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - MaxConstraintComparisions: 0, - } - - if _, err = cert.Verify(opts); err != nil { - return nil, fmt.Errorf("tls: unable to verify certificate: %w", err) + return nil, err } ctx = ContextWithOwner(ctx, owner) diff --git a/gateway/rest/client.go b/gateway/rest/client.go index 44b9f80d..718f2f46 100644 --- a/gateway/rest/client.go +++ b/gateway/rest/client.go @@ -17,6 +17,7 @@ import ( "time" aclient "github.com/akash-network/akash-api/go/node/client/v1beta2" + atls "github.com/akash-network/akash-api/go/util/tls" "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" "github.com/pkg/errors" @@ -85,9 +86,29 @@ type ServiceLogs struct { OnClose <-chan string } +type httpClient interface { + Do(*http.Request) (*http.Response, error) +} + +type client struct { + host *url.URL + addr sdk.Address + cclient ctypes.QueryClient + certs []tls.Certificate +} + +type reqClient struct { + ctx context.Context + host *url.URL + hclient httpClient + wsclient *websocket.Dialer + addr sdk.Address + cclient ctypes.QueryClient +} + // NewClient returns a new Client -func NewClient(qclient aclient.QueryClient, addr sdk.Address, certs []tls.Certificate) (Client, error) { - res, err := qclient.Provider(context.Background(), &ptypes.QueryProviderRequest{Owner: addr.String()}) +func NewClient(ctx context.Context, qclient aclient.QueryClient, addr sdk.Address, certs []tls.Certificate) (Client, error) { + res, err := qclient.Provider(ctx, &ptypes.QueryProviderRequest{Owner: addr.String()}) if err != nil { return nil, err } @@ -97,20 +118,26 @@ func NewClient(qclient aclient.QueryClient, addr sdk.Address, certs []tls.Certif return nil, err } - return newClient(qclient, addr, certs, uri), nil -} - -func newClient(qclient aclient.QueryClient, addr sdk.Address, certs []tls.Certificate, uri *url.URL) *client { - cl := &client{ + return &client{ host: uri, addr: addr, cclient: qclient, + certs: certs, + }, nil +} + +func (c *client) newReqClient(ctx context.Context) *reqClient { + cl := &reqClient{ + ctx: ctx, + host: c.host, + addr: c.addr, + cclient: c.cclient, } tlsConfig := &tls.Config{ // must use Hostname rather than Host field as certificate is issued for host without port - ServerName: uri.Hostname(), - Certificates: certs, + ServerName: cl.host.Hostname(), + Certificates: c.certs, InsecureSkipVerify: true, // nolint: gosec VerifyPeerCertificate: cl.verifyPeerCertificate, MinVersion: tls.VersionTLS13, @@ -187,18 +214,6 @@ func NewClientDirectory(ctx context.Context, cctx cosmosclient.Context) (*Client }, nil } -type httpClient interface { - Do(*http.Request) (*http.Response, error) -} - -type client struct { - host *url.URL - hclient httpClient - wsclient *websocket.Dialer - addr sdk.Address - cclient ctypes.QueryClient -} - type ClientCustomClaims struct { AkashNamespace *AkashNamespace `json:"https://akash.network/"` jwt.RegisteredClaims @@ -244,7 +259,8 @@ func (c *client) GetJWT(ctx context.Context) (*jwt.Token, error) { return nil, err } - resp, err := c.hclient.Do(req) + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) if err != nil { return nil, err } @@ -288,68 +304,21 @@ func (err ClientResponseError) ClientError() string { return fmt.Sprintf("Remote Server returned %d\n%s", err.Status, err.Message) } -func (c *client) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certificate) error { - if len(certificates) != 1 { - return errors.Errorf("tls: invalid certificate chain") - } +func (c *reqClient) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certificate) error { + peerCerts := make([]*x509.Certificate, 0, len(certificates)) - cert, err := x509.ParseCertificate(certificates[0]) - if err != nil { - return errors.Wrap(err, "tls: failed to parse certificate") - } - - // validation - var prov sdk.Address - if prov, err = sdk.AccAddressFromBech32(cert.Subject.CommonName); err != nil { - return errors.Wrap(err, "tls: invalid certificate's subject common name") - } - - // 1. CommonName in issuer and Subject must be the same - if cert.Subject.CommonName != cert.Issuer.CommonName { - return errors.Wrap(err, "tls: invalid certificate's issuer common name") - } - - if !c.addr.Equals(prov) { - return errors.Errorf("tls: hijacked certificate") - } + for idx := range certificates { + cert, err := x509.ParseCertificate(certificates[idx]) + if err != nil { + return err + } - // 2. serial number must be in - if cert.SerialNumber == nil { - return errors.Wrap(err, "tls: invalid certificate serial number") + peerCerts = append(peerCerts, cert) } - // 3. look up certificate on chain. it must not be revoked - var resp *ctypes.QueryCertificatesResponse - resp, err = c.cclient.Certificates( - context.Background(), - &ctypes.QueryCertificatesRequest{ - Filter: ctypes.CertificateFilter{ - Owner: prov.String(), - Serial: cert.SerialNumber.String(), - State: "valid", - }, - }, - ) + _, _, err := atls.ValidatePeerCertificates(c.ctx, c.cclient, peerCerts, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}) if err != nil { - return errors.Wrap(err, "tls: unable to fetch certificate from chain") - } - if (len(resp.Certificates) != 1) || !resp.Certificates[0].Certificate.IsState(ctypes.CertificateValid) { - return errors.New("tls: attempt to use non-existing or revoked certificate") - } - - certPool := x509.NewCertPool() - certPool.AddCert(cert) - - opts := x509.VerifyOptions{ - DNSName: c.host.Hostname(), - Roots: certPool, - CurrentTime: time.Now(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - MaxConstraintComparisions: 0, - } - - if _, err = cert.Verify(opts); err != nil { - return errors.Wrap(err, "tls: unable to verify certificate") + return err } return nil @@ -390,7 +359,8 @@ func (c *client) Validate(ctx context.Context, gspec dtypes.GroupSpec) (provider } req.Header.Set("Content-Type", contentTypeJSON) - resp, err := c.hclient.Do(req) + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) if err != nil { return provider.ValidateGroupSpecResult{}, err } @@ -435,7 +405,10 @@ func (c *client) SubmitManifest(ctx context.Context, dseq uint64, mani manifest. } req.Header.Set("Content-Type", contentTypeJSON) - resp, err := c.hclient.Do(req) + + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) + if err != nil { return err } @@ -475,7 +448,8 @@ func (c *client) MigrateEndpoints(ctx context.Context, endpoints []string, dseq } req.Header.Set("Content-Type", contentTypeJSON) - resp, err := c.hclient.Do(req) + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) if err != nil { return err } @@ -515,7 +489,8 @@ func (c *client) MigrateHostnames(ctx context.Context, hostnames []string, dseq } req.Header.Set("Content-Type", contentTypeJSON) - resp, err := c.hclient.Do(req) + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) if err != nil { return err } @@ -563,7 +538,8 @@ func (c *client) LeaseEvents(ctx context.Context, id mtypes.LeaseID, _ string, f query.Set("follow", strconv.FormatBool(follow)) endpoint.RawQuery = query.Encode() - conn, response, err := c.wsclient.DialContext(ctx, endpoint.String(), nil) + rCl := c.newReqClient(ctx) + conn, response, err := rCl.wsclient.DialContext(ctx, endpoint.String(), nil) if err != nil { if errors.Is(err, websocket.ErrBadHandshake) { buf := &bytes.Buffer{} @@ -663,7 +639,8 @@ func (c *client) getStatus(ctx context.Context, uri string, obj interface{}) err } req.Header.Set("Content-Type", contentTypeJSON) - resp, err := c.hclient.Do(req) + rCl := c.newReqClient(ctx) + resp, err := rCl.hclient.Do(req) if err != nil { return err } @@ -737,7 +714,8 @@ func (c *client) LeaseLogs(ctx context.Context, endpoint.RawQuery = query.Encode() - conn, response, err := c.wsclient.DialContext(ctx, endpoint.String(), nil) + rCl := c.newReqClient(ctx) + conn, response, err := rCl.wsclient.DialContext(ctx, endpoint.String(), nil) if err != nil { if errors.Is(err, websocket.ErrBadHandshake) { buf := &bytes.Buffer{} diff --git a/gateway/rest/client_shell.go b/gateway/rest/client_shell.go index f8ef6a3b..47afbf92 100644 --- a/gateway/rest/client_shell.go +++ b/gateway/rest/client_shell.go @@ -63,7 +63,9 @@ func (c *client) LeaseShell(ctx context.Context, lID mtypes.LeaseID, service str endpoint.RawQuery = query.Encode() subctx, subcancel := context.WithCancel(ctx) - conn, response, err := c.wsclient.DialContext(subctx, endpoint.String(), nil) + + rCl := c.newReqClient(ctx) + conn, response, err := rCl.wsclient.DialContext(subctx, endpoint.String(), nil) if err != nil { if errors.Is(err, websocket.ErrBadHandshake) { buf := &bytes.Buffer{} diff --git a/gateway/rest/integration_test.go b/gateway/rest/integration_test.go index b7f956b6..d3fb20e7 100644 --- a/gateway/rest/integration_test.go +++ b/gateway/rest/integration_test.go @@ -6,16 +6,17 @@ import ( "crypto/tls" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + akashmanifest "github.com/akash-network/akash-api/go/manifest/v2beta2" qmock "github.com/akash-network/akash-api/go/node/client/v1beta2/mocks" dtypes "github.com/akash-network/akash-api/go/node/deployment/v1beta3" mtypes "github.com/akash-network/akash-api/go/node/market/v1beta4" providertypes "github.com/akash-network/akash-api/go/node/provider/v1beta3" - "github.com/akash-network/node/testutil" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" + "github.com/akash-network/akash-api/go/testutil" "github.com/akash-network/provider" pcmock "github.com/akash-network/provider/cluster/mocks" @@ -35,7 +36,7 @@ func Test_router_Status(t *testing.T) { mocks.pclient.On("Status", mock.Anything).Return(expected, nil) withServer(t, addr, mocks.pclient, mocks.qclient, nil, func(host string) { - client, err := NewClient(mocks.qclient, addr, nil) + client, err := NewClient(context.Background(), mocks.qclient, addr, nil) assert.NoError(t, err) result, err := client.Status(context.Background()) assert.NoError(t, err) @@ -49,7 +50,7 @@ func Test_router_Status(t *testing.T) { mocks := createMocks() mocks.pclient.On("Status", mock.Anything).Return(nil, errors.New("oops")) withServer(t, addr, mocks.pclient, mocks.qclient, nil, func(host string) { - client, err := NewClient(mocks.qclient, addr, nil) + client, err := NewClient(context.Background(), mocks.qclient, addr, nil) assert.NoError(t, err) _, err = client.Status(context.Background()) assert.Error(t, err) @@ -68,7 +69,7 @@ func Test_router_Validate(t *testing.T) { mocks.pclient.On("Validate", mock.Anything, mock.Anything, mock.Anything).Return(expected, nil) withServer(t, addr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, testutil.AccAddress(t), testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, addr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, addr, cert.Cert) assert.NoError(t, err) result, err := client.Validate(context.Background(), testutil.GroupSpec(t)) assert.NoError(t, err) @@ -83,7 +84,7 @@ func Test_router_Validate(t *testing.T) { mocks.pclient.On("Validate", mock.Anything, mock.Anything, mock.Anything).Return(provider.ValidateGroupSpecResult{}, errors.New("oops")) withServer(t, addr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, testutil.AccAddress(t), testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, addr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, addr, cert.Cert) assert.NoError(t, err) _, err = client.Validate(context.Background(), dtypes.GroupSpec{}) assert.Error(t, err) @@ -105,7 +106,7 @@ func Test_router_Manifest(t *testing.T) { mocks.pmclient.On("Submit", mock.Anything, did, akashmanifest.Manifest(nil)).Return(nil) withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) err = client.SubmitManifest(context.Background(), did.DSeq, nil) assert.NoError(t, err) @@ -124,7 +125,7 @@ func Test_router_Manifest(t *testing.T) { mocks.pmclient.On("Submit", mock.Anything, did, akashmanifest.Manifest(nil)).Return(errors.New("ded")) withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) err = client.SubmitManifest(context.Background(), did.DSeq, nil) assert.Error(t, err) @@ -208,7 +209,7 @@ func Test_router_LeaseStatus(t *testing.T) { withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) status, err := client.LeaseStatus(context.Background(), id) expected := LeaseStatus{ @@ -246,7 +247,7 @@ func Test_router_LeaseStatus(t *testing.T) { withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) status, err := client.LeaseStatus(context.Background(), id) assert.Error(t, err) @@ -271,7 +272,7 @@ func Test_router_ServiceStatus(t *testing.T) { mocks.pcclient.On("ServiceStatus", mock.Anything, id, service).Return(expected, nil) withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) status, err := client.ServiceStatus(context.Background(), id, service) assert.NoError(t, err) @@ -293,7 +294,7 @@ func Test_router_ServiceStatus(t *testing.T) { mocks.pcclient.On("ServiceStatus", mock.Anything, id, service).Return(nil, errors.New("ded")) withServer(t, paddr, mocks.pclient, mocks.qclient, nil, func(host string) { cert := testutil.Certificate(t, caddr, testutil.CertificateOptionMocks(mocks.qclient)) - client, err := NewClient(mocks.qclient, paddr, cert.Cert) + client, err := NewClient(context.Background(), mocks.qclient, paddr, cert.Cert) assert.NoError(t, err) status, err := client.ServiceStatus(context.Background(), id, service) assert.Nil(t, status) diff --git a/gateway/rest/router_test.go b/gateway/rest/router_test.go index f714b4fe..5e9cc8f2 100644 --- a/gateway/rest/router_test.go +++ b/gateway/rest/router_test.go @@ -2,6 +2,7 @@ package rest import ( "bytes" + "context" "crypto/tls" "encoding/json" "io" @@ -23,9 +24,9 @@ import ( dtypes "github.com/akash-network/akash-api/go/node/deployment/v1beta3" mtypes "github.com/akash-network/akash-api/go/node/market/v1beta4" types "github.com/akash-network/akash-api/go/node/market/v1beta4" + "github.com/akash-network/akash-api/go/testutil" "github.com/akash-network/node/sdl" - "github.com/akash-network/node/testutil" "github.com/akash-network/provider" kubeclienterrors "github.com/akash-network/provider/cluster/kube/errors" @@ -108,7 +109,7 @@ func runRouterTest(t *testing.T, authClient bool, fn func(*routerTest)) { mf.host, err = url.Parse(host) require.NoError(t, err) - gclient, err := NewClient(mocks.qclient, mf.paddr, certs) + gclient, err := NewClient(context.Background(), mocks.qclient, mf.paddr, certs) require.NoError(t, err) require.NotNil(t, gclient) @@ -145,11 +146,12 @@ func testCertHelper(t *testing.T, test *routerTest) { req.Header.Set("Content-Type", contentTypeJSON) - _, err = test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + _, err = rCl.hclient.Do(req) require.Error(t, err) // return error message looks like // Put "https://127.0.0.1:58536/deployment/652/manifest": tls: unable to verify certificate: x509: cannot validate certificate for 127.0.0.1 because it doesn't contain any IP SANs - require.Regexp(t, "^(Put|Get) (\".*\": )tls: unable to verify certificate: .*$", err.Error()) + require.Regexp(t, `^(Put|Get) (".*": )tls: unable to verify certificate: \(.*\)$`, err.Error()) } func TestRouteNotActiveClientCert(t *testing.T) { @@ -177,7 +179,7 @@ func TestRouteNotActiveClientCert(t *testing.T) { mf.host, err = url.Parse(host) require.NoError(t, err) - gclient, err := NewClient(mocks.qclient, mf.paddr, mf.ccert.Cert) + gclient, err := NewClient(context.Background(), mocks.qclient, mf.paddr, mf.ccert.Cert) require.NoError(t, err) require.NotNil(t, gclient) @@ -213,7 +215,7 @@ func TestRouteExpiredClientCert(t *testing.T) { mf.host, err = url.Parse(host) require.NoError(t, err) - gclient, err := NewClient(mocks.qclient, mf.paddr, mf.ccert.Cert) + gclient, err := NewClient(context.Background(), mocks.qclient, mf.paddr, mf.ccert.Cert) require.NoError(t, err) require.NotNil(t, gclient) @@ -252,7 +254,7 @@ func TestRouteNotActiveServerCert(t *testing.T) { mf.host, err = url.Parse(host) require.NoError(t, err) - gclient, err := NewClient(mocks.qclient, mf.paddr, mf.ccert.Cert) + gclient, err := NewClient(context.Background(), mocks.qclient, mf.paddr, mf.ccert.Cert) require.NoError(t, err) require.NotNil(t, gclient) @@ -292,7 +294,7 @@ func TestRouteExpiredServerCert(t *testing.T) { mf.host, err = url.Parse(host) require.NoError(t, err) - gclient, err := NewClient(mocks.qclient, mf.paddr, mf.ccert.Cert) + gclient, err := NewClient(context.Background(), mocks.qclient, mf.paddr, mf.ccert.Cert) require.NoError(t, err) require.NotNil(t, gclient) @@ -312,7 +314,9 @@ func TestRouteDoesNotExist(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) + require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) }) @@ -356,7 +360,8 @@ func TestRouteVersionOK(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) var data versionInfo @@ -389,7 +394,8 @@ func TestRouteStatusOK(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) data := make(map[string]interface{}) @@ -413,7 +419,8 @@ func TestRouteStatusFails(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -443,7 +450,8 @@ func TestRouteValidateOK(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) data := make(map[string]interface{}) @@ -473,7 +481,8 @@ func TestRouteValidateUnauthorized(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusUnauthorized, resp.StatusCode) }) @@ -494,7 +503,8 @@ func TestRouteValidateFails(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -515,7 +525,8 @@ func TestRouteValidateFailsEmptyBody(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -555,7 +566,8 @@ func TestRoutePutManifestOK(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -596,7 +608,8 @@ func TestRoutePutInvalidManifest(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) @@ -689,7 +702,8 @@ func TestRouteLeaseStatusOk(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -737,7 +751,8 @@ func TestRouteLeaseNotInKubernetes(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) }) @@ -759,7 +774,8 @@ func TestRouteLeaseStatusErr(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -809,7 +825,8 @@ func TestRouteServiceStatusOK(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -849,7 +866,8 @@ func TestRouteServiceStatusNoDeployment(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) @@ -900,7 +918,8 @@ func TestRouteServiceStatusKubernetesNotFound(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) @@ -939,7 +958,8 @@ func TestRouteServiceStatusError(t *testing.T) { req.Header.Set("Content-Type", contentTypeJSON) - resp, err := test.gwclient.hclient.Do(req) + rCl := test.gwclient.newReqClient(context.Background()) + resp, err := rCl.hclient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) diff --git a/gateway/utils/utils.go b/gateway/utils/utils.go index e030e4c6..3920767c 100644 --- a/gateway/utils/utils.go +++ b/gateway/utils/utils.go @@ -4,13 +4,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "time" - - "github.com/pkg/errors" - - sdk "github.com/cosmos/cosmos-sdk/types" ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3" + atls "github.com/akash-network/akash-api/go/util/tls" ) func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery ctypes.QueryClient) (*tls.Config, error) { @@ -23,62 +19,20 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty MinVersion: tls.VersionTLS13, VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error { if len(certificates) > 0 { - if len(certificates) != 1 { - return errors.Errorf("tls: invalid certificate chain") - } - - cert, err := x509.ParseCertificate(certificates[0]) - if err != nil { - return errors.Wrap(err, "tls: failed to parse certificate") - } + peerCerts := make([]*x509.Certificate, 0, len(certificates)) - // validation - var owner sdk.Address - if owner, err = sdk.AccAddressFromBech32(cert.Subject.CommonName); err != nil { - return errors.Wrap(err, "tls: invalid certificate's subject common name") - } + for idx := range certificates { + cert, err := x509.ParseCertificate(certificates[idx]) + if err != nil { + return err + } - // 1. CommonName in issuer and Subject must match and be as Bech32 format - if cert.Subject.CommonName != cert.Issuer.CommonName { - return errors.Wrap(err, "tls: invalid certificate's issuer common name") + peerCerts = append(peerCerts, cert) } - // 2. serial number must be in - if cert.SerialNumber == nil { - return errors.Wrap(err, "tls: invalid certificate serial number") - } - - // 3. look up certificate on chain - var resp *ctypes.QueryCertificatesResponse - resp, err = cquery.Certificates( - ctx, - &ctypes.QueryCertificatesRequest{ - Filter: ctypes.CertificateFilter{ - Owner: owner.String(), - Serial: cert.SerialNumber.String(), - State: "valid", - }, - }, - ) + _, _, err := atls.ValidatePeerCertificates(ctx, cquery, peerCerts, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) if err != nil { - return errors.Wrap(err, "tls: unable to fetch certificate from chain") - } - if (len(resp.Certificates) != 1) || !resp.Certificates[0].Certificate.IsState(ctypes.CertificateValid) { - return errors.New("tls: attempt to use non-existing or revoked certificate") - } - - clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(cert) - - opts := x509.VerifyOptions{ - Roots: clientCertPool, - CurrentTime: time.Now(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - MaxConstraintComparisions: 0, - } - - if _, err = cert.Verify(opts); err != nil { - return errors.Wrap(err, "tls: unable to verify certificate") + return err } } return nil diff --git a/go.mod b/go.mod index 05d1995a..1d548b89 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/akash-network/provider go 1.21 require ( - github.com/akash-network/akash-api v0.0.65 - github.com/akash-network/node v0.33.0-rc0 + github.com/akash-network/akash-api v0.0.66 + github.com/akash-network/node v0.34.0 github.com/avast/retry-go/v4 v4.5.0 github.com/blang/semver/v4 v4.0.0 github.com/boz/go-lifecycle v0.1.1 diff --git a/go.sum b/go.sum index 6ae184dc..0b2f9144 100644 --- a/go.sum +++ b/go.sum @@ -197,16 +197,16 @@ github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/akash-network/akash-api v0.0.65 h1:Jadkbu9rvE5UIrd2t0FRbEGMeco/KUEzD/8w6zqxPlI= -github.com/akash-network/akash-api v0.0.65/go.mod h1:pNr61L4+0sheol7ZK0HjgK3rxpIAbYBGq1w1oH4B0+M= +github.com/akash-network/akash-api v0.0.66 h1:HGYbjLmnKj7hNIO2V7f6CiHJfZJzeOBCIV45gRo/AbY= +github.com/akash-network/akash-api v0.0.66/go.mod h1:PdOQGTCX3kLBoKHdbPF9pe5+vSLANaMJbgA04UE+OqY= github.com/akash-network/cometbft v0.34.27-akash h1:V1dApDOr8Ee7BJzYyQ7Z9VBtrAul4+baMeA6C49dje0= github.com/akash-network/cometbft v0.34.27-akash/go.mod h1:BcCbhKv7ieM0KEddnYXvQZR+pZykTKReJJYf7YC7qhw= github.com/akash-network/ledger-go v0.14.3 h1:LCEFkTfgGA2xFMN2CtiKvXKE7dh0QSM77PJHCpSkaAo= github.com/akash-network/ledger-go v0.14.3/go.mod h1:NfsjfFvno9Kaq6mfpsKz4sqjnAVVEsVsnBJfKB4ueAs= github.com/akash-network/ledger-go/cosmos v0.14.4 h1:h3WiXmoKKs9wkj1LHcJ12cLjXXg6nG1fp+UQ5+wu/+o= github.com/akash-network/ledger-go/cosmos v0.14.4/go.mod h1:SjAfheQTE4rWk0ir+wjbOWxwj8nc8E4AZ08NdsvYG24= -github.com/akash-network/node v0.33.0-rc0 h1:RQuIRDLvu0kugCxwlxFr9X3B/JKYc5SChBSy/vNNZJE= -github.com/akash-network/node v0.33.0-rc0/go.mod h1:EnqNTPmvkKK0CHO1SqyF5ozAPJXpgmyFpBGak+KcPDY= +github.com/akash-network/node v0.34.0 h1:qBLEJlMDs7hSt1skomdPAtTXIXiAAmezYob8V+gG5Ks= +github.com/akash-network/node v0.34.0/go.mod h1:EnqNTPmvkKK0CHO1SqyF5ozAPJXpgmyFpBGak+KcPDY= github.com/alecthomas/participle/v2 v2.0.0-alpha7 h1:cK4vjj0VSgb3lN1nuKA5F7dw+1s1pWBe5bx7nNCnN+c= github.com/alecthomas/participle/v2 v2.0.0-alpha7/go.mod h1:NumScqsC42o9x+dGj8/YqsIfhrIQjFEOFovxotbBirA= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=