From a670be39e8033fecc3bf153b4307d2bd68209ccc Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 11 Oct 2024 16:36:58 -0400 Subject: [PATCH] cert-v2: backwards compatibility trickery for ipv6 (#1245) --- cert/cert.go | 4 +++ connection_state.go | 13 ++++---- control.go | 1 + handshake_ix.go | 77 ++++++++++++++++++++++++++++++++++++++----- interface.go | 13 +++++++- lighthouse.go | 79 +++++++++++++++++++++++++++++++-------------- main.go | 12 ++++++- pki.go | 14 ++++---- ssh.go | 1 + 9 files changed, 167 insertions(+), 47 deletions(-) diff --git a/cert/cert.go b/cert/cert.go index 6904ab6a2..496e12e02 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -109,6 +109,10 @@ type CachedCertificate struct { signerFingerprint string } +func (cc *CachedCertificate) String() string { + return cc.Certificate.String() +} + // UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. func UnmarshalCertificate(b []byte) (Certificate, error) { //TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers? diff --git a/connection_state.go b/connection_state.go index a13164d5c..faee443de 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,6 +3,7 @@ package nebula import ( "crypto/rand" "encoding/json" + "fmt" "sync" "sync/atomic" @@ -26,8 +27,7 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern noise.HandshakePattern) *ConnectionState { - crt := cs.GetDefaultCertificate() +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: @@ -39,8 +39,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern dhFunc = noiseutil.DHP256 } default: - l.Errorf("invalid curve: %s", crt.Curve()) - return nil + return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } var ncs noise.CipherSuite @@ -53,7 +52,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ @@ -67,7 +66,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern PresharedKeyPlacement: 0, }) if err != nil { - return nil + return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when @@ -81,7 +80,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { diff --git a/control.go b/control.go index 75bdbd771..866e9db25 100644 --- a/control.go +++ b/control.go @@ -134,6 +134,7 @@ func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) if found { //TODO: we might have 2 certs.... + //TODO: this should return our latest version cert return c.f.pki.getDefaultCertificate().Copy() } hi := c.f.hostMap.QueryVpnAddr(vpnIp) diff --git a/handshake_ix.go b/handshake_ix.go index df853ab1a..5a551a205 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -23,22 +23,55 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return false } + // If we're connecting to a v6 address we must use a v2 cert cs := f.pki.getCertState() - ci := NewConnectionState(f.l, cs, true, noise.HandshakeIX) + v := cs.defaultVersion + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + + crt := cs.getCertificate(v) + if crt == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate is available") + return false + } + + crtHs := cs.getHandshakeBytes(v) + if crtHs == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate handshake bytes is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Failed to create connection state") + return false + } hh.hostinfo.ConnectionState = ci hs := &NebulaHandshake{ Details: &NebulaHandshakeDetails{ InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: cs.getDefaultHandshakeBytes(), - CertVersion: uint32(cs.defaultVersion), + Cert: crtHs, + CertVersion: uint32(v), }, } hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -63,14 +96,30 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { cs := f.pki.getCertState() - ci := NewConnectionState(f.l, cs, false, noise.HandshakeIX) + crt := cs.GetDefaultCertificate() + if crt == nil { + f.l.WithField("udpAddr", addr). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", cs.defaultVersion). + Error("Unable to handshake with host because no certificate is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + return + } + // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to call noise.ReadMessage") return } @@ -78,7 +127,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed unmarshal handshake message") return } @@ -98,7 +148,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if remoteCert.Certificate.Version() != ci.myCert.Version() { // We started off using the wrong certificate version, lets see if we can match the version that was sent to us rc := cs.getCertificate(remoteCert.Certificate.Version()) - //TODO: anywhere we are logging remoteCert needs to be remoteCert.Certificate OR we make a pass through func on CachedCertificate if rc == nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). @@ -183,6 +232,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + if hs.Details.Cert == nil { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVersion", ci.myCert.Version()). + Error("Unable to handshake with host because no certificate handshake bytes is available") + return + } + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) diff --git a/interface.go b/interface.go index 378c35450..b4903f103 100644 --- a/interface.go +++ b/interface.go @@ -14,6 +14,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -327,7 +328,17 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getDefaultCertificate(), c) + cs := f.pki.getCertState() + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + + fw, err := NewFirewallFromConfig(f.l, certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return diff --git a/lighthouse.go b/lighthouse.go index 1beaa331c..6825b8252 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -738,42 +738,73 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { return } - // Send a query to the lighthouses and hope for the best next time - v := lh.ifce.GetCertState().defaultVersion msg := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{}, } - if v == 1 { - if !addr.Is4() { - lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol") - return + var v1Query, v2Query []byte + var err error + var v cert.Version + queried := 0 + lighthouses := lh.GetLighthouses() + + for lhVpnAddr := range lighthouses { + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion } - b := addr.As4() - msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) - } else if v == 2 { - msg.Details.VpnAddr = netAddrToProtoAddr(addr) + if v == cert.Version1 { + if !addr.Is4() { + lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). + Error("Can't query lighthouse for v6 address using a v1 protocol") + continue + } - } else { - panic("unsupported version") - } + if v1Query == nil { + b := addr.As4() + msg.Details.VpnAddr = nil + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) - query, err := msg.Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnAddr", addr).Error("Failed to marshal lighthouse query payload") - return - } + v1Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v1 query payload") + continue + } + } - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + queried++ - for n := range lighthouses { - //TODO: there is a slight possibility this lighthouse is using a v2 protocol even if our default is v1 - // We could facilitate the move to v2 by marshalling a v2 query - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + } else if v == cert.Version2 { + if v2Query == nil { + msg.Details.OldVpnAddr = 0 + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + v2Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v2 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + queried++ + + } else { + lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + continue + } } + + lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { diff --git a/main.go b/main.go index 6aea39a0f..894305ede 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,8 @@ import ( "net/netip" "time" + "github.com/slackhq/nebula/cert" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" @@ -60,7 +62,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.getDefaultCertificate() + cs := pki.getCertState() + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificates available to configure the firewall") + } fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) diff --git a/pki.go b/pki.go index 779a1598a..69c60863b 100644 --- a/pki.go +++ b/pki.go @@ -70,10 +70,12 @@ func (p *PKI) getCertState() *CertState { return p.cs.Load() } +// TODO: We should remove this func (p *PKI) getDefaultCertificate() cert.Certificate { return p.cs.Load().GetDefaultCertificate() } +// TODO: We should remove this func (p *PKI) getCertificate(v cert.Version) cert.Certificate { return p.cs.Load().getCertificate(v) } @@ -209,10 +211,6 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate { return c } -func (cs *CertState) getDefaultHandshakeBytes() []byte { - return cs.getHandshakeBytes(cs.defaultVersion) -} - func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { switch v { case cert.Version1: @@ -224,15 +222,17 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { return nil } +// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. +// Callers must check if the return []byte is nil. func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { switch v { case cert.Version1: return cs.v1HandshakeBytes case cert.Version2: return cs.v2HandshakeBytes + default: + return nil } - - panic("No handshake bytes found") } func (cs *CertState) String() string { @@ -370,6 +370,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) } + //TODO: make sure v2 has v1s address + cs.defaultVersion = dv } diff --git a/ssh.go b/ssh.go index a04b28b23..5244e4a4a 100644 --- a/ssh.go +++ b/ssh.go @@ -785,6 +785,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } + //TODO: This should return both certs cert := ifce.pki.getDefaultCertificate() if len(a) > 0 { vpnIp, err := netip.ParseAddr(a[0])