diff -up nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc.alert-handler nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc --- nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc 2017-03-14 11:01:42.563689719 +0100 @@ -24,6 +24,8 @@ namespace nss_test { TEST_P(TlsConnectTls13, ZeroRtt) { SetupForZeroRtt(); + client_->SetExpectedAlertSentCount(1); + server_->SetExpectedAlertReceivedCount(1); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); @@ -103,6 +105,8 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRtt EnableAlpn(); SetupForZeroRtt(); EnableAlpn(); + client_->SetExpectedAlertSentCount(1); + server_->SetExpectedAlertReceivedCount(1); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); diff -up nss/gtests/ssl_gtest/ssl_exporter_unittest.cc.alert-handler nss/gtests/ssl_gtest/ssl_exporter_unittest.cc --- nss/gtests/ssl_gtest/ssl_exporter_unittest.cc.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/ssl_exporter_unittest.cc 2017-03-14 11:01:42.563689719 +0100 @@ -90,6 +90,8 @@ int32_t RegularExporterShouldFail(TlsAge TEST_P(TlsConnectTls13, EarlyExporter) { SetupForZeroRtt(); + client_->SetExpectedAlertSentCount(1); + server_->SetExpectedAlertReceivedCount(1); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); diff -up nss/gtests/ssl_gtest/ssl_extension_unittest.cc.alert-handler nss/gtests/ssl_gtest/ssl_extension_unittest.cc --- nss/gtests/ssl_gtest/ssl_extension_unittest.cc.alert-handler 2017-03-14 11:01:42.563689719 +0100 +++ nss/gtests/ssl_gtest/ssl_extension_unittest.cc 2017-03-14 11:06:39.215006989 +0100 @@ -167,27 +167,69 @@ class TlsExtensionTestBase : public TlsC : TlsConnectTestBase(mode, version) {} void ClientHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { + uint8_t desc = kTlsAlertDecodeError) { + SSLAlert alert; + auto alert_recorder = new TlsAlertRecorder(); server_->SetPacketFilter(alert_recorder); if (filter) { client_->SetPacketFilter(filter); } ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + EXPECT_EQ(desc, alert_recorder->description()); + + // verify no alerts received by the server + EXPECT_EQ(0U, server_->alert_received_count()); + + // verify the alert sent by the server + EXPECT_EQ(1U, server_->alert_sent_count()); + EXPECT_TRUE(server_->GetLastAlertSent(&alert)); + EXPECT_EQ(kTlsAlertFatal, alert.level); + EXPECT_EQ(desc, alert.description); + + // verify the alert received by the client + EXPECT_EQ(1U, client_->alert_received_count()); + EXPECT_TRUE(client_->GetLastAlertReceived(&alert)); + EXPECT_EQ(kTlsAlertFatal, alert.level); + EXPECT_EQ(desc, alert.description); + + // verify no alerts sent by the client + EXPECT_EQ(0U, client_->alert_sent_count()); } void ServerHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { + uint8_t desc = kTlsAlertDecodeError) { + SSLAlert alert; + auto alert_recorder = new TlsAlertRecorder(); client_->SetPacketFilter(alert_recorder); if (filter) { server_->SetPacketFilter(filter); } ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + EXPECT_EQ(desc, alert_recorder->description()); + + // verify no alerts received by the client + EXPECT_EQ(0U, client_->alert_received_count()); + + // verify the alert sent by the client + EXPECT_EQ(1U, client_->alert_sent_count()); + EXPECT_TRUE(client_->GetLastAlertSent(&alert)); + EXPECT_EQ(kTlsAlertFatal, alert.level); + EXPECT_EQ(desc, alert.description); + + // verify the alert received by the server + EXPECT_EQ(1U, server_->alert_received_count()); + EXPECT_TRUE(server_->GetLastAlertReceived(&alert)); + EXPECT_EQ(kTlsAlertFatal, alert.level); + EXPECT_EQ(desc, alert.description); + + // verify no alerts sent by the server + EXPECT_EQ(0U, server_->alert_sent_count()); } static void InitSimpleSni(DataBuffer* extension) { diff -up nss/gtests/ssl_gtest/ssl_version_unittest.cc.alert-handler nss/gtests/ssl_gtest/ssl_version_unittest.cc --- nss/gtests/ssl_gtest/ssl_version_unittest.cc.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/ssl_version_unittest.cc 2017-03-14 11:01:42.563689719 +0100 @@ -225,6 +225,7 @@ TEST_F(TlsConnectTest, Tls13RejectsRehan TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { EnsureTlsSetup(); + client_->SetExpectedAlertReceivedCount(1); client_->StartConnect(); server_->StartConnect(); client_->Handshake(); // Send ClientHello. diff -up nss/gtests/ssl_gtest/tls_agent.cc.alert-handler nss/gtests/ssl_gtest/tls_agent.cc --- nss/gtests/ssl_gtest/tls_agent.cc.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/tls_agent.cc 2017-03-14 11:07:22.414890511 +0100 @@ -61,6 +61,12 @@ TlsAgent::TlsAgent(const std::string& na can_falsestart_hook_called_(false), sni_hook_called_(false), auth_certificate_hook_called_(false), + alert_received_count_(0), + expected_alert_received_count_(0), + last_alert_received_({0, 0}), + alert_sent_count_(0), + expected_alert_sent_count_(0), + last_alert_sent_({0, 0}), handshake_callback_called_(false), error_code_(0), send_ctr_(0), @@ -165,6 +171,14 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; + rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; @@ -578,6 +592,11 @@ void TlsAgent::CheckErrorCode(int32_t ex << PORT_ErrorToName(expected) << std::endl; } +void TlsAgent::CheckAlerts() const { + EXPECT_EQ(expected_alert_received_count_, alert_received_count_); + EXPECT_EQ(expected_alert_sent_count_, alert_sent_count_); +} + void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { ASSERT_EQ(0, error_code_); WAIT_(error_code_ != 0, delay); diff -up nss/gtests/ssl_gtest/tls_agent.h.alert-handler nss/gtests/ssl_gtest/tls_agent.h --- nss/gtests/ssl_gtest/tls_agent.h.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/tls_agent.h 2017-03-14 11:01:42.564689693 +0100 @@ -139,6 +139,7 @@ class TlsAgent : public PollTarget { void EnableSrtp(); void CheckSrtp() const; void CheckErrorCode(int32_t expected) const; + void CheckAlerts() const; void WaitForErrorCode(int32_t expected, uint32_t delay) const; // Send data on the socket, encrypting it. void SendData(size_t bytes, size_t blocksize = 1024); @@ -239,6 +240,34 @@ class TlsAgent : public PollTarget { sni_callback_ = sni_callback; } + size_t alert_received_count() const { return alert_received_count_; } + + void SetExpectedAlertReceivedCount(size_t count) { + expected_alert_received_count_ = count; + } + + bool GetLastAlertReceived(SSLAlert* alert) const { + if (!alert_received_count_) { + return false; + } + *alert = last_alert_received_; + return true; + } + + size_t alert_sent_count() const { return alert_sent_count_; } + + void SetExpectedAlertSentCount(size_t count) { + expected_alert_sent_count_ = count; + } + + bool GetLastAlertSent(SSLAlert* alert) const { + if (!alert_sent_count_) { + return false; + } + *alert = last_alert_sent_; + return true; + } + private: const static char* states[]; @@ -320,6 +349,30 @@ class TlsAgent : public PollTarget { return SECSuccess; } + static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + TlsAgent* agent = reinterpret_cast(arg); + + std::cerr << agent->role_str() + << ": Alert received: level=" << static_cast(alert->level) + << " desc=" << static_cast(alert->description) << std::endl; + + ++agent->alert_received_count_; + agent->last_alert_received_ = *alert; + } + + static void AlertSentCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + TlsAgent* agent = reinterpret_cast(arg); + + std::cerr << agent->role_str() + << ": Alert sent: level=" << static_cast(alert->level) + << " desc=" << static_cast(alert->description) << std::endl; + + ++agent->alert_sent_count_; + agent->last_alert_sent_ = *alert; + } + static void HandshakeCallback(PRFileDesc* fd, void* arg) { TlsAgent* agent = reinterpret_cast(arg); agent->handshake_callback_called_ = true; @@ -352,6 +405,12 @@ class TlsAgent : public PollTarget { bool can_falsestart_hook_called_; bool sni_hook_called_; bool auth_certificate_hook_called_; + size_t alert_received_count_; + size_t expected_alert_received_count_; + SSLAlert last_alert_received_; + size_t alert_sent_count_; + size_t expected_alert_sent_count_; + SSLAlert last_alert_sent_; bool handshake_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; diff -up nss/gtests/ssl_gtest/tls_connect.cc.alert-handler nss/gtests/ssl_gtest/tls_connect.cc --- nss/gtests/ssl_gtest/tls_connect.cc.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/gtests/ssl_gtest/tls_connect.cc 2017-03-14 11:01:42.564689693 +0100 @@ -309,6 +309,9 @@ void TlsConnectTestBase::CheckConnected( CheckResumption(expected_resumption_mode_); client_->CheckSecretsDestroyed(); server_->CheckSecretsDestroyed(); + + client_->CheckAlerts(); + server_->CheckAlerts(); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, diff -up nss/lib/ssl/ssl3con.c.alert-handler nss/lib/ssl/ssl3con.c --- nss/lib/ssl/ssl3con.c.alert-handler 2017-03-14 11:01:42.551690030 +0100 +++ nss/lib/ssl/ssl3con.c 2017-03-14 11:03:45.319510356 +0100 @@ -3143,6 +3143,10 @@ SSL3_SendAlert(sslSocket *ss, SSL3AlertL } ssl_ReleaseXmitBufLock(ss); ssl_ReleaseSSL3HandshakeLock(ss); + if (rv == SECSuccess && ss->alertSentCallback) { + SSLAlert alert = { level, desc }; + ss->alertSentCallback(ss->fd, ss->alertSentCallbackArg, &alert); + } return rv; /* error set by ssl3_FlushHandshake or ssl3_SendRecord */ } @@ -3255,6 +3259,11 @@ ssl3_HandleAlert(sslSocket *ss, sslBuffe SSL_TRC(5, ("%d: SSL3[%d] received alert, level = %d, description = %d", SSL_GETPID(), ss->fd, level, desc)); + if (ss->alertReceivedCallback) { + SSLAlert alert = { level, desc }; + ss->alertReceivedCallback(ss->fd, ss->alertReceivedCallbackArg, &alert); + } + switch (desc) { case close_notify: ss->recvdCloseNotify = 1; diff -up nss/lib/ssl/ssl.def.alert-handler nss/lib/ssl/ssl.def --- nss/lib/ssl/ssl.def.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/lib/ssl/ssl.def 2017-03-14 11:01:42.564689693 +0100 @@ -221,3 +221,10 @@ SSL_SignatureSchemePrefGet; ;+ local: ;+*; ;+}; +;+NSS_3.30.0.1 { # Additional symbols for NSS 3.30 release +;+ global: +SSL_AlertReceivedCallback; +SSL_AlertSentCallback; +;+ local: +;+*; +;+}; diff -up nss/lib/ssl/ssl.h.alert-handler nss/lib/ssl/ssl.h --- nss/lib/ssl/ssl.h.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/lib/ssl/ssl.h 2017-03-14 11:01:42.564689693 +0100 @@ -820,6 +820,25 @@ SSL_IMPORT PRFileDesc *SSL_ReconfigFD(PR SSL_IMPORT SECStatus SSL_SetPKCS11PinArg(PRFileDesc *fd, void *a); /* +** These are callbacks for dealing with SSL alerts. + */ + +typedef PRUint8 SSLAlertLevel; +typedef PRUint8 SSLAlertDescription; + +typedef struct { + SSLAlertLevel level; + SSLAlertDescription description; +} SSLAlert; + +typedef void(PR_CALLBACK *SSLAlertCallback)(const PRFileDesc *fd, void *arg, + const SSLAlert *alert); + +SSL_IMPORT SECStatus SSL_AlertReceivedCallback(PRFileDesc *fd, SSLAlertCallback cb, + void *arg); +SSL_IMPORT SECStatus SSL_AlertSentCallback(PRFileDesc *fd, SSLAlertCallback cb, + void *arg); +/* ** This is a callback for dealing with server certs that are not authenticated ** by the client. The client app can decide that it actually likes the ** cert by some external means and restart the connection. diff -up nss/lib/ssl/sslimpl.h.alert-handler nss/lib/ssl/sslimpl.h --- nss/lib/ssl/sslimpl.h.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/lib/ssl/sslimpl.h 2017-03-14 11:01:42.566689641 +0100 @@ -1121,6 +1121,10 @@ struct sslSocketStr { void *getClientAuthDataArg; SSLSNISocketConfig sniSocketConfig; void *sniSocketConfigArg; + SSLAlertCallback alertReceivedCallback; + void *alertReceivedCallbackArg; + SSLAlertCallback alertSentCallback; + void *alertSentCallbackArg; SSLBadCertHandler handleBadCert; void *badCertArg; SSLHandshakeCallback handshakeCallback; diff -up nss/lib/ssl/sslsecur.c.alert-handler nss/lib/ssl/sslsecur.c --- nss/lib/ssl/sslsecur.c.alert-handler 2017-02-17 14:20:06.000000000 +0100 +++ nss/lib/ssl/sslsecur.c 2017-03-14 11:01:42.566689641 +0100 @@ -994,6 +994,42 @@ ssl_SecureWrite(sslSocket *ss, const uns } SECStatus +SSL_AlertReceivedCallback(PRFileDesc *fd, SSLAlertCallback cb, void *arg) +{ + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: unable to find socket in SSL_AlertReceivedCallback", + SSL_GETPID(), fd)); + return SECFailure; + } + + ss->alertReceivedCallback = cb; + ss->alertReceivedCallbackArg = arg; + + return SECSuccess; +} + +SECStatus +SSL_AlertSentCallback(PRFileDesc *fd, SSLAlertCallback cb, void *arg) +{ + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: unable to find socket in SSL_AlertSentCallback", + SSL_GETPID(), fd)); + return SECFailure; + } + + ss->alertSentCallback = cb; + ss->alertSentCallbackArg = arg; + + return SECSuccess; +} + +SECStatus SSL_BadCertHook(PRFileDesc *fd, SSLBadCertHandler f, void *arg) { sslSocket *ss; diff -up nss/lib/ssl/sslsock.c.alert-handler nss/lib/ssl/sslsock.c --- nss/lib/ssl/sslsock.c.alert-handler 2017-03-14 11:01:42.538690367 +0100 +++ nss/lib/ssl/sslsock.c 2017-03-14 11:01:42.566689641 +0100 @@ -330,6 +330,10 @@ ssl_DupSocket(sslSocket *os) ss->getClientAuthDataArg = os->getClientAuthDataArg; ss->sniSocketConfig = os->sniSocketConfig; ss->sniSocketConfigArg = os->sniSocketConfigArg; + ss->alertReceivedCallback = os->alertReceivedCallback; + ss->alertReceivedCallbackArg = os->alertReceivedCallbackArg; + ss->alertSentCallback = os->alertSentCallback; + ss->alertSentCallbackArg = os->alertSentCallbackArg; ss->handleBadCert = os->handleBadCert; ss->badCertArg = os->badCertArg; ss->handshakeCallback = os->handshakeCallback; @@ -2149,6 +2153,14 @@ SSL_ReconfigFD(PRFileDesc *model, PRFile ss->sniSocketConfig = sm->sniSocketConfig; if (sm->sniSocketConfigArg) ss->sniSocketConfigArg = sm->sniSocketConfigArg; + if (ss->alertReceivedCallback) { + ss->alertReceivedCallback = sm->alertReceivedCallback; + ss->alertReceivedCallbackArg = sm->alertReceivedCallbackArg; + } + if (ss->alertSentCallback) { + ss->alertSentCallback = sm->alertSentCallback; + ss->alertSentCallbackArg = sm->alertSentCallbackArg; + } if (sm->handleBadCert) ss->handleBadCert = sm->handleBadCert; if (sm->badCertArg) @@ -3691,6 +3703,10 @@ ssl_NewSocket(PRBool makeLocks, SSLProto ss->sniSocketConfig = NULL; ss->sniSocketConfigArg = NULL; ss->getClientAuthData = NULL; + ss->alertReceivedCallback = NULL; + ss->alertReceivedCallbackArg = NULL; + ss->alertSentCallback = NULL; + ss->alertSentCallbackArg = NULL; ss->handleBadCert = NULL; ss->badCertArg = NULL; ss->pkcs11PinArg = NULL; # HG changeset patch # User Kai Engert # Date 1493741561 -7200 # Tue May 02 18:12:41 2017 +0200 # Node ID 8804a0c65a08ee53096c07cc091536c7cf102b58 # Parent 769f9ae07b103494af809620478e60256a344adc Bug 1360207, Fix incorrect if (ss->...) in SSL_ReconfigFD, Patch contributed by Ian Goldberg, r=ttaubert diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -2152,11 +2152,11 @@ SSL_ReconfigFD(PRFileDesc *model, PRFile ss->sniSocketConfig = sm->sniSocketConfig; if (sm->sniSocketConfigArg) ss->sniSocketConfigArg = sm->sniSocketConfigArg; - if (ss->alertReceivedCallback) { + if (sm->alertReceivedCallback) { ss->alertReceivedCallback = sm->alertReceivedCallback; ss->alertReceivedCallbackArg = sm->alertReceivedCallbackArg; } - if (ss->alertSentCallback) { + if (sm->alertSentCallback) { ss->alertSentCallback = sm->alertSentCallback; ss->alertSentCallbackArg = sm->alertSentCallbackArg; }