Skip to content

Commit

Permalink
add support for psk2
Browse files Browse the repository at this point in the history
In psk2, the responder doesn't know which preshared key to use until
after they read the first message. To support this, allow setting psk
for reponder after initialization with SetPresharedKey method.
  • Loading branch information
nsmith5 authored and titanous committed Dec 5, 2023
1 parent d803f5c commit acf4844
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 6 deletions.
76 changes: 76 additions & 0 deletions noise_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package noise

import (
"bytes"
"encoding/hex"
"math"
"testing"
Expand Down Expand Up @@ -249,6 +250,81 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
c.Assert(string(res), Equals, "worri")
}

func (NoiseSuite) Test_IXpsk2_Roundtrip(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1

staticI, err := cs.GenerateKeypair(rngI)
if err != nil {
c.Fatal(err)
}
staticR, err := cs.GenerateKeypair(rngR)
if err != nil {
c.Fatal(err)
}

psk := []byte("00000000000000000000000000000000")

hsI, _ := NewHandshakeState(Config{
CipherSuite: cs,
Random: rngI,
Pattern: HandshakeIX,
PresharedKeyPlacement: 2,
PresharedKey: psk,
Initiator: true,
StaticKeypair: staticI,
})
hsR, _ := NewHandshakeState(Config{
CipherSuite: cs,
Random: rngR,
Pattern: HandshakeIX,
PresharedKeyPlacement: 2,
StaticKeypair: staticR,
})

// -> e, s
msg, _, _, _ := hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 96)

res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)

if !bytes.Equal(hsR.PeerStatic(), staticI.Public) {
c.Error("wrong public key from peer")
}

// Look up psk from peer static public key

// responder should know psk now and set it from the
// initiators preshared key
if err = hsR.SetPresharedKey(psk); err != nil {
c.Fatal(err)
}
// <- e, dhee, dhse, s, dhes, psk
msg, csR0, csR1, _ := hsR.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 96)
res, csI0, csI1, err := hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)

// transport I -> R
msg, err = csI0.Encrypt(nil, nil, []byte("foo"))
c.Assert(err, IsNil)
res, err = csR0.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "foo")

// transport R -> I
msg, err = csR1.Encrypt(nil, nil, []byte("bar"))
c.Assert(err, IsNil)
res, err = csI1.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "bar")
}

func (NoiseSuite) Test_NNpsk0_Roundtrip(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
rngI := new(RandomInc)
Expand Down
32 changes: 26 additions & 6 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ type HandshakeState struct {
rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key
psk []byte // preshared key, maybe zero length
willPsk bool // indicates if preshared key will be used (even if not yet set)
messagePatterns [][]MessagePattern
shouldWrite bool
initiator bool
Expand Down Expand Up @@ -299,7 +300,6 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
s: c.StaticKeypair,
e: c.EphemeralKeypair,
rs: c.PeerStatic,
psk: c.PresharedKey,
messagePatterns: c.Pattern.Messages,
shouldWrite: c.Initiator,
initiator: c.Initiator,
Expand All @@ -313,11 +313,18 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
copy(hs.re, c.PeerEphemeral)
}
hs.ss.cs = c.CipherSuite

pskModifier := ""
if len(hs.psk) > 0 {
if len(hs.psk) != 32 {
return nil, errors.New("noise: specification mandates 256-bit preshared keys")
// NB: for psk{0,1} we must have preshared key set in configuration as its needed in the first
// message. For psk{2+} we may not know the correct psk yet so it might not be set.
if len(c.PresharedKey) > 0 || c.PresharedKeyPlacement >= 2 {
hs.willPsk = true
if len(c.PresharedKey) > 0 {
if err := hs.SetPresharedKey(c.PresharedKey); err != nil {
return nil, err
}
}

pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
if c.PresharedKeyPlacement == 0 {
Expand All @@ -326,6 +333,7 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK)
}
}

hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
hs.ss.MixHash(c.Prologue)
for _, m := range c.Pattern.InitiatorPreMessages {
Expand Down Expand Up @@ -383,7 +391,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
if s.willPsk {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
Expand Down Expand Up @@ -435,6 +443,9 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
}
s.ss.MixKey(dh)
case MessagePatternPSK:
if len(s.psk) == 0 {
return nil, nil, nil, errors.New("noise: cannot send psk message without psk set")
}
s.ss.MixKeyAndHash(s.psk)
}
}
Expand All @@ -456,6 +467,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
var ErrShortMessage = errors.New("noise: message is too short")

func (s *HandshakeState) SetPresharedKey(psk []byte) error {
if len(psk) != 32 {
return errors.New("noise: specification mandates 256-bit preshared keys")
}
s.psk = make([]byte, 32)
copy(s.psk, psk)
return nil
}

// ReadMessage processes a received handshake message and appends the payload,
// if any to out. If the handshake is completed by the call, two CipherStates
// will be returned, one is used for encryption of messages to the remote peer,
Expand Down Expand Up @@ -491,7 +511,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.re = s.re[:s.ss.cs.DHLen()]
copy(s.re, message)
s.ss.MixHash(s.re)
if len(s.psk) > 0 {
if s.willPsk {
s.ss.MixKey(s.re)
}
case MessagePatternS:
Expand Down

0 comments on commit acf4844

Please sign in to comment.