Skip to content

Commit

Permalink
cert-v2: backwards compatibility trickery for ipv6 (#1245)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian authored Oct 11, 2024
1 parent 92ce06e commit a670be3
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 47 deletions.
4 changes: 4 additions & 0 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ type CachedCertificate struct {
signerFingerprint string
}

func (cc *CachedCertificate) String() string {
return cc.Certificate.String()
}

// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
func UnmarshalCertificate(b []byte) (Certificate, error) {
//TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers?
Expand Down
13 changes: 6 additions & 7 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nebula
import (
"crypto/rand"
"encoding/json"
"fmt"
"sync"
"sync/atomic"

Expand All @@ -26,8 +27,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}

func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern noise.HandshakePattern) *ConnectionState {
crt := cs.GetDefaultCertificate()
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
Expand All @@ -39,8 +39,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
dhFunc = noiseutil.DHP256
}
default:
l.Errorf("invalid curve: %s", crt.Curve())
return nil
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}

var ncs noise.CipherSuite
Expand All @@ -53,7 +52,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}

b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)

hs, err := noise.NewHandshakeState(noise.Config{
Expand All @@ -67,7 +66,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
PresharedKeyPlacement: 0,
})
if err != nil {
return nil
return nil, fmt.Errorf("NewConnectionState: %s", err)
}

// The queue and ready params prevent a counter race that would happen when
Expand All @@ -81,7 +80,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)

return ci
return ci, nil
}

func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
Expand Down
1 change: 1 addition & 0 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
Expand Down
77 changes: 69 additions & 8 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,55 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}

// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
ci := NewConnectionState(f.l, cs, true, noise.HandshakeIX)
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}

crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}

crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}

ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci

hs := &NebulaHandshake{
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: cs.getDefaultHandshakeBytes(),
CertVersion: uint32(cs.defaultVersion),
Cert: crtHs,
CertVersion: uint32(v),
},
}

hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
Expand All @@ -63,22 +96,39 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {

func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState()
ci := NewConnectionState(f.l, cs, false, noise.HandshakeIX)
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}

ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)

msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
}

hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
}

Expand All @@ -98,7 +148,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
//TODO: anywhere we are logging remoteCert needs to be remoteCert.Certificate OR we make a pass through func on CachedCertificate
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Expand Down Expand Up @@ -183,6 +232,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet

hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}

hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
Expand Down
13 changes: 12 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -327,7 +328,17 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}

fw, err := NewFirewallFromConfig(f.l, f.pki.getDefaultCertificate(), c)
cs := f.pki.getCertState()
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}

if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}

fw, err := NewFirewallFromConfig(f.l, certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
Expand Down
79 changes: 55 additions & 24 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,42 +738,73 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
return
}

// Send a query to the lighthouses and hope for the best next time
v := lh.ifce.GetCertState().defaultVersion
msg := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{},
}

if v == 1 {
if !addr.Is4() {
lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol")
return
var v1Query, v2Query []byte
var err error
var v cert.Version
queried := 0
lighthouses := lh.GetLighthouses()

for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
} else {
v = lh.ifce.GetCertState().defaultVersion
}
b := addr.As4()
msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])

} else if v == 2 {
msg.Details.VpnAddr = netAddrToProtoAddr(addr)
if v == cert.Version1 {
if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
Error("Can't query lighthouse for v6 address using a v1 protocol")
continue
}

} else {
panic("unsupported version")
}
if v1Query == nil {
b := addr.As4()
msg.Details.VpnAddr = nil
msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])

query, err := msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("vpnAddr", addr).Error("Failed to marshal lighthouse query payload")
return
}
v1Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v1 query payload")
continue
}
}

lighthouses := lh.GetLighthouses()
lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out)
queried++

for n := range lighthouses {
//TODO: there is a slight possibility this lighthouse is using a v2 protocol even if our default is v1
// We could facilitate the move to v2 by marshalling a v2 query
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
} else if v == cert.Version2 {
if v2Query == nil {
msg.Details.OldVpnAddr = 0
msg.Details.VpnAddr = netAddrToProtoAddr(addr)

v2Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v2 query payload")
continue
}
}

lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out)
queried++

} else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
continue
}
}

lh.metricTx(NebulaMeta_HostQuery, int64(queried))
}

func (lh *LightHouse) StartUpdateWorker() {
Expand Down
12 changes: 11 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/netip"
"time"

"github.com/slackhq/nebula/cert"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
Expand Down Expand Up @@ -60,7 +62,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}

certificate := pki.getDefaultCertificate()
cs := pki.getCertState()
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}

if certificate == nil {
panic("No certificates available to configure the firewall")
}
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
Expand Down
Loading

0 comments on commit a670be3

Please sign in to comment.