From 484d6004f75fb23f26a2e110f67f0d06c577cbd4 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Sat, 27 Sep 2014 00:18:54 -0700 Subject: [PATCH] crypto: abstracted Key and added Equals. --- crypto/key.go | 29 +++++++++++++++++++++------ crypto/key_test.go | 42 ++++++++++++++++++++++++++++++++++++++- crypto/rsa.go | 10 ++++++++++ crypto/spipe/handshake.go | 12 +---------- 4 files changed, 75 insertions(+), 18 deletions(-) diff --git a/crypto/key.go b/crypto/key.go index 38b3b0ebd..f0a35c698 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -23,7 +23,17 @@ const ( RSA = iota ) +type Key interface { + // Bytes returns a serialized, storeable representation of this key + Bytes() ([]byte, error) + + // Equals checks whether two PubKeys are the same + Equals(Key) bool +} + type PrivKey interface { + Key + // Cryptographically sign the given bytes Sign([]byte) ([]byte, error) @@ -32,17 +42,13 @@ type PrivKey interface { // Generate a secret string of bytes GenSecret() []byte - - // Bytes returns a serialized, storeable representation of this key - Bytes() ([]byte, error) } type PubKey interface { + Key + // Verify that 'sig' is the signed hash of 'data' Verify(data []byte, sig []byte) (bool, error) - - // Bytes returns a serialized, storeable representation of this key - Bytes() ([]byte, error) } // Given a public key, generates the shared key. @@ -229,3 +235,14 @@ func UnmarshalPrivateKey(data []byte) (PrivKey, error) { return nil, ErrBadKeyType } } + +// KeyEqual checks whether two +func KeyEqual(k1, k2 Key) bool { + if k1 == k2 { + return true + } + + b1, err1 := k1.Bytes() + b2, err2 := k2.Bytes() + return bytes.Equal(b1, b2) && err1 == err2 +} diff --git a/crypto/key_test.go b/crypto/key_test.go index c002c5819..13c94215e 100644 --- a/crypto/key_test.go +++ b/crypto/key_test.go @@ -3,12 +3,14 @@ package crypto import "testing" func TestRsaKeys(t *testing.T) { - sk, _, err := GenerateKeyPair(RSA, 512) + sk, pk, err := GenerateKeyPair(RSA, 512) if err != nil { t.Fatal(err) } testKeySignature(t, sk) testKeyEncoding(t, sk) + testKeyEquals(t, sk) + testKeyEquals(t, pk) } func testKeySignature(t *testing.T, sk PrivKey) { @@ -52,3 +54,41 @@ func testKeyEncoding(t *testing.T, sk PrivKey) { t.Fatal(err) } } + +func testKeyEquals(t *testing.T, k Key) { + kb, err := k.Bytes() + if err != nil { + t.Fatal(err) + } + + if !KeyEqual(k, k) { + t.Fatal("Key not equal to itself.") + } + + if !KeyEqual(k, testkey(kb)) { + t.Fatal("Key not equal to key with same bytes.") + } + + sk, pk, err := GenerateKeyPair(RSA, 512) + if err != nil { + t.Fatal(err) + } + + if KeyEqual(k, sk) { + t.Fatal("Keys should not equal.") + } + + if KeyEqual(k, pk) { + t.Fatal("Keys should not equal.") + } +} + +type testkey []byte + +func (pk testkey) Bytes() ([]byte, error) { + return pk, nil +} + +func (pk testkey) Equals(k Key) bool { + return KeyEqual(pk, k) +} diff --git a/crypto/rsa.go b/crypto/rsa.go index 513b868d1..e582b59c2 100644 --- a/crypto/rsa.go +++ b/crypto/rsa.go @@ -41,6 +41,11 @@ func (pk *RsaPublicKey) Bytes() ([]byte, error) { return proto.Marshal(pbmes) } +// Equals checks whether this key is equal to another +func (pk *RsaPublicKey) Equals(k Key) bool { + return KeyEqual(pk, k) +} + func (sk *RsaPrivateKey) GenSecret() []byte { buf := make([]byte, 16) rand.Read(buf) @@ -65,6 +70,11 @@ func (sk *RsaPrivateKey) Bytes() ([]byte, error) { return proto.Marshal(pbmes) } +// Equals checks whether this key is equal to another +func (sk *RsaPrivateKey) Equals(k Key) bool { + return KeyEqual(sk, k) +} + func UnmarshalRsaPrivateKey(b []byte) (*RsaPrivateKey, error) { sk, err := x509.ParsePKCS1PrivateKey(b) if err != nil { diff --git a/crypto/spipe/handshake.go b/crypto/spipe/handshake.go index 18c1eeec4..f617c75b3 100644 --- a/crypto/spipe/handshake.go +++ b/crypto/spipe/handshake.go @@ -379,17 +379,7 @@ func getOrConstructPeer(peers peer.Peerstore, rpk ci.PubKey) (*peer.Peer, error) // did have pubkey, let's verify it's really the same. // this shouldn't ever happen, given we hashed, etc, but it could mean // expected code (or protocol) invariants violated. - - lb, err1 := npeer.PubKey.Bytes() - if err1 != nil { - return nil, err1 - } - rb, err2 := rpk.Bytes() - if err2 != nil { - return nil, err2 - } - - if !bytes.Equal(lb, rb) { + if !npeer.PubKey.Equals(rpk) { return nil, fmt.Errorf("WARNING: PubKey mismatch: %v", npeer.ID.Pretty()) } return npeer, nil