Blob Blame History Raw
diff --git a/_ssl.c b/_ssl.c
index 5fb6e28..1332547 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -226,6 +226,19 @@ typedef struct {
     enum py_ssl_server_or_client socket_type;
 } PySSLSocket;
 
+typedef struct {
+    PyObject_HEAD
+    PySocketSockObject *Socket;         /* Socket on which we're layered */
+    SSL_CTX*            ctx;
+    SSL*                ssl;
+    X509*               peer_cert;
+    char                server[X509_NAME_MAXLEN];
+    char                issuer[X509_NAME_MAXLEN];
+    int                 shutdown_seen_zero;
+
+} PySSLObject;
+
+static PyTypeObject PySSL_Type;
 static PyTypeObject PySSLContext_Type;
 static PyTypeObject PySSLSocket_Type;
 
@@ -527,6 +540,203 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
     return self;
 }
 
+static PySSLObject *
+newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
+               enum py_ssl_server_or_client socket_type,
+               enum py_ssl_cert_requirements certreq,
+               enum py_ssl_version proto_version,
+               char *cacerts_file, char *ciphers)
+{
+    PySSLObject *self;
+    char *errstr = NULL;
+    int ret;
+    int verification_mode;
+    long options;
+
+    self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
+    if (self == NULL)
+        return NULL;
+    memset(self->server, '\0', sizeof(char) * X509_NAME_MAXLEN);
+    memset(self->issuer, '\0', sizeof(char) * X509_NAME_MAXLEN);
+    self->peer_cert = NULL;
+    self->ssl = NULL;
+    self->ctx = NULL;
+    self->Socket = NULL;
+    self->shutdown_seen_zero = 0;
+
+    /* Make sure the SSL error state is initialized */
+    (void) ERR_get_state();
+    ERR_clear_error();
+
+    if ((key_file && !cert_file) || (!key_file && cert_file)) {
+        errstr = ERRSTR("Both the key & certificate files "
+                        "must be specified");
+        goto fail;
+    }
+
+    if ((socket_type == PY_SSL_SERVER) &&
+        ((key_file == NULL) || (cert_file == NULL))) {
+        errstr = ERRSTR("Both the key & certificate files "
+                        "must be specified for server-side operation");
+        goto fail;
+    }
+
+    PySSL_BEGIN_ALLOW_THREADS
+    if (proto_version == PY_SSL_VERSION_TLS1)
+        self->ctx = SSL_CTX_new(TLSv1_method()); /* Set up context */
+    else if (proto_version == PY_SSL_VERSION_SSL3)
+        self->ctx = SSL_CTX_new(SSLv3_method()); /* Set up context */
+#ifndef OPENSSL_NO_SSL2
+    else if (proto_version == PY_SSL_VERSION_SSL2)
+        self->ctx = SSL_CTX_new(SSLv2_method()); /* Set up context */
+#endif
+    else if (proto_version == PY_SSL_VERSION_SSL23)
+        self->ctx = SSL_CTX_new(SSLv23_method()); /* Set up context */
+    PySSL_END_ALLOW_THREADS
+
+    if (self->ctx == NULL) {
+        errstr = ERRSTR("Invalid SSL protocol variant specified.");
+        goto fail;
+    }
+
+    if (ciphers != NULL) {
+        ret = SSL_CTX_set_cipher_list(self->ctx, ciphers);
+        if (ret == 0) {
+            errstr = ERRSTR("No cipher can be selected.");
+            goto fail;
+        }
+    }
+
+    if (certreq != PY_SSL_CERT_NONE) {
+        if (cacerts_file == NULL) {
+            errstr = ERRSTR("No root certificates specified for "
+                            "verification of other-side certificates.");
+            goto fail;
+        } else {
+            PySSL_BEGIN_ALLOW_THREADS
+            ret = SSL_CTX_load_verify_locations(self->ctx,
+                                                cacerts_file,
+                                                NULL);
+            PySSL_END_ALLOW_THREADS
+            if (ret != 1) {
+                _setSSLError(NULL, 0, __FILE__, __LINE__);
+                goto fail;
+            }
+        }
+    }
+    if (key_file) {
+        PySSL_BEGIN_ALLOW_THREADS
+        ret = SSL_CTX_use_PrivateKey_file(self->ctx, key_file,
+                                          SSL_FILETYPE_PEM);
+        PySSL_END_ALLOW_THREADS
+        if (ret != 1) {
+            _setSSLError(NULL, ret, __FILE__, __LINE__);
+            goto fail;
+        }
+
+        PySSL_BEGIN_ALLOW_THREADS
+        ret = SSL_CTX_use_certificate_chain_file(self->ctx,
+                                                 cert_file);
+        PySSL_END_ALLOW_THREADS
+        if (ret != 1) {
+            /*
+            fprintf(stderr, "ret is %d, errcode is %lu, %lu, with file \"%s\"\n",
+                ret, ERR_peek_error(), ERR_peek_last_error(), cert_file);
+                */
+            if (ERR_peek_last_error() != 0) {
+                _setSSLError(NULL, ret, __FILE__, __LINE__);
+                goto fail;
+            }
+        }
+    }
+
+    /* ssl compatibility */
+    options = SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
+    if (proto_version != PY_SSL_VERSION_SSL2)
+        options |= SSL_OP_NO_SSLv2;
+    SSL_CTX_set_options(self->ctx, options);
+
+    verification_mode = SSL_VERIFY_NONE;
+    if (certreq == PY_SSL_CERT_OPTIONAL)
+        verification_mode = SSL_VERIFY_PEER;
+    else if (certreq == PY_SSL_CERT_REQUIRED)
+        verification_mode = (SSL_VERIFY_PEER |
+                             SSL_VERIFY_FAIL_IF_NO_PEER_CERT);
+    SSL_CTX_set_verify(self->ctx, verification_mode,
+                       NULL); /* set verify lvl */
+
+    PySSL_BEGIN_ALLOW_THREADS
+    self->ssl = SSL_new(self->ctx); /* New ssl struct */
+    PySSL_END_ALLOW_THREADS
+    SSL_set_fd(self->ssl, Sock->sock_fd);       /* Set the socket for SSL */
+#ifdef SSL_MODE_AUTO_RETRY
+    SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY);
+#endif
+
+    /* If the socket is in non-blocking mode or timeout mode, set the BIO
+     * to non-blocking mode (blocking is the default)
+     */
+    if (Sock->sock_timeout >= 0.0) {
+        /* Set both the read and write BIO's to non-blocking mode */
+        BIO_set_nbio(SSL_get_rbio(self->ssl), 1);
+        BIO_set_nbio(SSL_get_wbio(self->ssl), 1);
+    }
+
+    PySSL_BEGIN_ALLOW_THREADS
+    if (socket_type == PY_SSL_CLIENT)
+        SSL_set_connect_state(self->ssl);
+    else
+        SSL_set_accept_state(self->ssl);
+    PySSL_END_ALLOW_THREADS
+
+    self->Socket = Sock;
+    Py_INCREF(self->Socket);
+    return self;
+ fail:
+    if (errstr)
+        PyErr_SetString(PySSLErrorObject, errstr);
+    Py_DECREF(self);
+    return NULL;
+}
+
+static PyObject *
+PySSL_sslwrap(PyObject *self, PyObject *args)
+{
+    PySocketSockObject *Sock;
+    int server_side = 0;
+    int verification_mode = PY_SSL_CERT_NONE;
+    int protocol = PY_SSL_VERSION_SSL23;
+    char *key_file = NULL;
+    char *cert_file = NULL;
+    char *cacerts_file = NULL;
+    char *ciphers = NULL;
+
+    if (!PyArg_ParseTuple(args, "O!i|zziizz:sslwrap",
+                          PySocketModule.Sock_Type,
+                          &Sock,
+                          &server_side,
+                          &key_file, &cert_file,
+                          &verification_mode, &protocol,
+                          &cacerts_file, &ciphers))
+        return NULL;
+
+    /*
+    fprintf(stderr,
+        "server_side is %d, keyfile %p, certfile %p, verify_mode %d, "
+        "protocol %d, certs %p\n",
+        server_side, key_file, cert_file, verification_mode,
+        protocol, cacerts_file);
+     */
+
+    return (PyObject *) newPySSLObject(Sock, key_file, cert_file,
+                                       server_side, verification_mode,
+                                       protocol, cacerts_file,
+                                       ciphers);
+}
+
+PyDoc_STRVAR(ssl_doc,
+"sslwrap(socket, server_side, [keyfile, certfile, certs_mode, protocol,\n"
+"                              cacertsfile, ciphers]) -> sslobject");
 
 /* SSL object methods */
 
@@ -1911,6 +2121,7 @@ static PyGetSetDef ssl_getsetlist[] = {
 };
 
 static PyMethodDef PySSLMethods[] = {
+    {"sslwrap", PySSL_sslwrap, METH_VARARGS, ssl_doc},
     {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
     {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
      PySSL_SSLwrite_doc},
@@ -1969,6 +2180,29 @@ static PyTypeObject PySSLSocket_Type = {
     ssl_getsetlist,                     /*tp_getset*/
 };
 
+static PyObject *PySSL_getattr(PySSLObject *self, char *name)
+{
+    return Py_FindMethod(PySSLMethods, (PyObject *)self, name);
+}
+
+static PyTypeObject PySSL_Type = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    "ssl.SSLContext",                   /*tp_name*/
+    sizeof(PySSLObject),                /*tp_basicsize*/
+    0,                                  /*tp_itemsize*/
+    /* methods */
+    (destructor)PySSL_dealloc,          /*tp_dealloc*/
+    0,                                  /*tp_print*/
+    (getattrfunc)PySSL_getattr,         /*tp_getattr*/
+    0,                                  /*tp_setattr*/
+    0,                                  /*tp_compare*/
+    0,                                  /*tp_repr*/
+    0,                                  /*tp_as_number*/
+    0,                                  /*tp_as_sequence*/
+    0,                                  /*tp_as_mapping*/
+    0,                                  /*tp_hash*/
+};
+
 
 /*
  * _SSLContext objects