vishalmishra434 / rpms / openssh

Forked from rpms/openssh a month ago
Clone
Blob Blame History Raw
diff --color -ruNp a/gss-genr.c b/gss-genr.c
--- a/gss-genr.c	2024-05-16 15:49:43.999411060 +0200
+++ b/gss-genr.c	2024-06-26 12:17:55.586856954 +0200
@@ -346,6 +346,7 @@ ssh_gssapi_build_ctx(Gssctxt **ctx)
 	(*ctx)->creds = GSS_C_NO_CREDENTIAL;
 	(*ctx)->client = GSS_C_NO_NAME;
 	(*ctx)->client_creds = GSS_C_NO_CREDENTIAL;
+	(*ctx)->first = 1;
 }
 
 /* Delete our context, providing it has been built correctly */
@@ -371,6 +372,12 @@ ssh_gssapi_delete_ctx(Gssctxt **ctx)
 		gss_release_name(&ms, &(*ctx)->client);
 	if ((*ctx)->client_creds != GSS_C_NO_CREDENTIAL)
 		gss_release_cred(&ms, &(*ctx)->client_creds);
+	sshbuf_free((*ctx)->shared_secret);
+	sshbuf_free((*ctx)->server_pubkey);
+	sshbuf_free((*ctx)->server_host_key_blob);
+	sshbuf_free((*ctx)->server_blob);
+	explicit_bzero((*ctx)->hash, sizeof((*ctx)->hash));
+        BN_clear_free((*ctx)->dh_client_pub);
 
 	free(*ctx);
 	*ctx = NULL;
diff --color -ruNp a/kexgssc.c b/kexgssc.c
--- a/kexgssc.c	2024-05-16 15:49:43.820407648 +0200
+++ b/kexgssc.c	2024-07-02 16:26:25.628746744 +0200
@@ -47,566 +47,658 @@
 
 #include "ssh-gss.h"
 
-int
-kexgss_client(struct ssh *ssh)
+static int input_kexgss_hostkey(int, u_int32_t, struct ssh *);
+static int input_kexgss_continue(int, u_int32_t, struct ssh *);
+static int input_kexgss_complete(int, u_int32_t, struct ssh *);
+static int input_kexgss_error(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_group(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_continue(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_complete(int, u_int32_t, struct ssh *);
+
+static int
+kexgss_final(struct ssh *ssh)
 {
 	struct kex *kex = ssh->kex;
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER,
-	    recv_tok = GSS_C_EMPTY_BUFFER,
-	    gssbuf, msg_tok = GSS_C_EMPTY_BUFFER, *token_ptr;
-	Gssctxt *ctxt;
-	OM_uint32 maj_status, min_status, ret_flags;
-	struct sshbuf *server_blob = NULL;
-	struct sshbuf *shared_secret = NULL;
-	struct sshbuf *server_host_key_blob = NULL;
+	Gssctxt *gss = kex->gss;
 	struct sshbuf *empty = NULL;
-	u_char *msg;
-	int type = 0;
-	int first = 1;
+	struct sshbuf *shared_secret = NULL;
 	u_char hash[SSH_DIGEST_MAX_LENGTH];
 	size_t hashlen;
-	u_char c;
 	int r;
 
-	/* Initialise our GSSAPI world */
-	ssh_gssapi_build_ctx(&ctxt);
-	if (ssh_gssapi_id_kex(ctxt, kex->name, kex->kex_type)
-	    == GSS_C_NO_OID)
-		fatal("Couldn't identify host exchange");
-
-	if (ssh_gssapi_import_name(ctxt, kex->gss_host))
-		fatal("Couldn't import hostname");
-
-	if (kex->gss_client &&
-	    ssh_gssapi_client_identity(ctxt, kex->gss_client))
-		fatal("Couldn't acquire client credentials");
-
-	/* Step 1 */
-	switch (kex->kex_type) {
-	case KEX_GSS_GRP1_SHA1:
-	case KEX_GSS_GRP14_SHA1:
-	case KEX_GSS_GRP14_SHA256:
-	case KEX_GSS_GRP16_SHA512:
-		r = kex_dh_keypair(kex);
-		break;
-	case KEX_GSS_NISTP256_SHA256:
-		r = kex_ecdh_keypair(kex);
-		break;
-	case KEX_GSS_C25519_SHA256:
-		r = kex_c25519_keypair(kex);
-		break;
-	default:
-		fatal_f("Unexpected KEX type %d", kex->kex_type);
-	}
-	if (r != 0) {
-		ssh_gssapi_delete_ctx(&ctxt);
-		return r;
-	}
-
-	token_ptr = GSS_C_NO_BUFFER;
-
-	do {
-		debug("Calling gss_init_sec_context");
-
-		maj_status = ssh_gssapi_init_ctx(ctxt,
-		    kex->gss_deleg_creds, token_ptr, &send_tok,
-		    &ret_flags);
-
-		if (GSS_ERROR(maj_status)) {
-			/* XXX Useles code: Missing send? */
-			if (send_tok.length != 0) {
-				if ((r = sshpkt_start(ssh,
-				        SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-				    (r = sshpkt_put_string(ssh, send_tok.value,
-				        send_tok.length)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-			}
-			fatal("gss_init_context failed");
-		}
-
-		/* If we've got an old receive buffer get rid of it */
-		if (token_ptr != GSS_C_NO_BUFFER)
-			gss_release_buffer(&min_status, &recv_tok);
-
-		if (maj_status == GSS_S_COMPLETE) {
-			/* If mutual state flag is not true, kex fails */
-			if (!(ret_flags & GSS_C_MUTUAL_FLAG))
-				fatal("Mutual authentication failed");
-
-			/* If integ avail flag is not true kex fails */
-			if (!(ret_flags & GSS_C_INTEG_FLAG))
-				fatal("Integrity check failed");
-		}
-
-		/*
-		 * If we have data to send, then the last message that we
-		 * received cannot have been a 'complete'.
-		 */
-		if (send_tok.length != 0) {
-			if (first) {
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
-				    (r = sshpkt_put_string(ssh, send_tok.value,
-				        send_tok.length)) != 0 ||
-				    (r = sshpkt_put_stringb(ssh, kex->client_pub)) != 0)
-					fatal("failed to construct packet: %s", ssh_err(r));
-				first = 0;
-			} else {
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-				    (r = sshpkt_put_string(ssh, send_tok.value,
-				        send_tok.length)) != 0)
-					fatal("failed to construct packet: %s", ssh_err(r));
-			}
-			if ((r = sshpkt_send(ssh)) != 0)
-				fatal("failed to send packet: %s", ssh_err(r));
-			gss_release_buffer(&min_status, &send_tok);
-
-			/* If we've sent them data, they should reply */
-			do {
-				type = ssh_packet_read(ssh);
-				if (type == SSH2_MSG_KEXGSS_HOSTKEY) {
-					u_char *tmp = NULL;
-					size_t tmp_len = 0;
-
-					debug("Received KEXGSS_HOSTKEY");
-					if (server_host_key_blob)
-						fatal("Server host key received more than once");
-					if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
-						fatal("Failed to read server host key: %s", ssh_err(r));
-					if ((server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
-						fatal("sshbuf_from failed");
-				}
-			} while (type == SSH2_MSG_KEXGSS_HOSTKEY);
-
-			switch (type) {
-			case SSH2_MSG_KEXGSS_CONTINUE:
-				debug("Received GSSAPI_CONTINUE");
-				if (maj_status == GSS_S_COMPLETE)
-					fatal("GSSAPI Continue received from server when complete");
-				if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-				        &recv_tok)) != 0 ||
-				    (r = sshpkt_get_end(ssh)) != 0)
-					fatal("Failed to read token: %s", ssh_err(r));
-				break;
-			case SSH2_MSG_KEXGSS_COMPLETE:
-				debug("Received GSSAPI_COMPLETE");
-				if (msg_tok.value != NULL)
-				        fatal("Received GSSAPI_COMPLETE twice?");
-				if ((r = sshpkt_getb_froms(ssh, &server_blob)) != 0 ||
-				    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-				        &msg_tok)) != 0)
-					fatal("Failed to read message: %s", ssh_err(r));
-
-				/* Is there a token included? */
-				if ((r = sshpkt_get_u8(ssh, &c)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-				if (c) {
-					if ((r = ssh_gssapi_sshpkt_get_buffer_desc(
-					    ssh, &recv_tok)) != 0)
-						fatal("Failed to read token: %s", ssh_err(r));
-					/* If we're already complete - protocol error */
-					if (maj_status == GSS_S_COMPLETE)
-						sshpkt_disconnect(ssh, "Protocol error: received token when complete");
-				} else {
-					/* No token included */
-					if (maj_status != GSS_S_COMPLETE)
-						sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
-				}
-				if ((r = sshpkt_get_end(ssh)) != 0) {
-					fatal("Expecting end of packet.");
-				}
-				break;
-			case SSH2_MSG_KEXGSS_ERROR:
-				debug("Received Error");
-				if ((r = sshpkt_get_u32(ssh, &maj_status)) != 0 ||
-				    (r = sshpkt_get_u32(ssh, &min_status)) != 0 ||
-				    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
-				    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
-				    (r = sshpkt_get_end(ssh)) != 0)
-					fatal("sshpkt_get failed: %s", ssh_err(r));
-				fatal("GSSAPI Error: \n%.400s", msg);
-			default:
-				sshpkt_disconnect(ssh, "Protocol error: didn't expect packet type %d",
-				    type);
-			}
-			token_ptr = &recv_tok;
-		} else {
-			/* No data, and not complete */
-			if (maj_status != GSS_S_COMPLETE)
-				fatal("Not complete, and no token output");
-		}
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
-
 	/*
 	 * We _must_ have received a COMPLETE message in reply from the
 	 * server, which will have set server_blob and msg_tok
 	 */
 
-	if (type != SSH2_MSG_KEXGSS_COMPLETE)
-		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
-
 	/* compute shared secret */
 	switch (kex->kex_type) {
 	case KEX_GSS_GRP1_SHA1:
 	case KEX_GSS_GRP14_SHA1:
 	case KEX_GSS_GRP14_SHA256:
 	case KEX_GSS_GRP16_SHA512:
-		r = kex_dh_dec(kex, server_blob, &shared_secret);
+		r = kex_dh_dec(kex, gss->server_blob, &shared_secret);
 		break;
 	case KEX_GSS_C25519_SHA256:
-		if (sshbuf_ptr(server_blob)[sshbuf_len(server_blob)] & 0x80)
+		if (sshbuf_ptr(gss->server_blob)[sshbuf_len(gss->server_blob)] & 0x80)
 			fatal("The received key has MSB of last octet set!");
-		r = kex_c25519_dec(kex, server_blob, &shared_secret);
+		r = kex_c25519_dec(kex, gss->server_blob, &shared_secret);
 		break;
 	case KEX_GSS_NISTP256_SHA256:
-		if (sshbuf_len(server_blob) != 65)
-			fatal("The received NIST-P256 key did not match"
-			    "expected length (expected 65, got %zu)", sshbuf_len(server_blob));
+		if (sshbuf_len(gss->server_blob) != 65)
+			fatal("The received NIST-P256 key did not match "
+			      "expected length (expected 65, got %zu)",
+			      sshbuf_len(gss->server_blob));
 
-		if (sshbuf_ptr(server_blob)[0] != POINT_CONVERSION_UNCOMPRESSED)
+		if (sshbuf_ptr(gss->server_blob)[0] != POINT_CONVERSION_UNCOMPRESSED)
 			fatal("The received NIST-P256 key does not have first octet 0x04");
 
-		r = kex_ecdh_dec(kex, server_blob, &shared_secret);
+		r = kex_ecdh_dec(kex, gss->server_blob, &shared_secret);
 		break;
 	default:
 		r = SSH_ERR_INVALID_ARGUMENT;
 		break;
 	}
-	if (r != 0)
+	if (r != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		goto out;
+	}
 
 	if ((empty = sshbuf_new()) == NULL) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
 
 	hashlen = sizeof(hash);
-	if ((r = kex_gen_hash(
-	    kex->hash_alg,
-	    kex->client_version,
-	    kex->server_version,
-	    kex->my,
-	    kex->peer,
-	    (server_host_key_blob ? server_host_key_blob : empty),
-	    kex->client_pub,
-	    server_blob,
-	    shared_secret,
-	    hash, &hashlen)) != 0)
+	r = kex_gen_hash(kex->hash_alg, kex->client_version,
+			 kex->server_version, kex->my, kex->peer,
+			 (gss->server_host_key_blob ? gss->server_host_key_blob : empty),
+			 kex->client_pub, gss->server_blob, shared_secret,
+			 hash, &hashlen);
+	sshbuf_free(empty);
+	if (r != 0)
 		fatal_f("Unexpected KEX type %d", kex->kex_type);
 
-	gssbuf.value = hash;
-	gssbuf.length = hashlen;
+	gss->buf.value = hash;
+	gss->buf.length = hashlen;
 
 	/* Verify that the hash matches the MIC we just got. */
-	if (GSS_ERROR(ssh_gssapi_checkmic(ctxt, &gssbuf, &msg_tok)))
+	if (GSS_ERROR(ssh_gssapi_checkmic(gss, &gss->buf, &gss->msg_tok)))
 		sshpkt_disconnect(ssh, "Hash's MIC didn't verify");
 
-	gss_release_buffer(&min_status, &msg_tok);
+	gss_release_buffer(&gss->minor, &gss->msg_tok);
 
 	if (kex->gss_deleg_creds)
-		ssh_gssapi_credentials_updated(ctxt);
+		ssh_gssapi_credentials_updated(gss);
 
 	if (gss_kex_context == NULL)
-		gss_kex_context = ctxt;
+		gss_kex_context = gss;
 	else
-		ssh_gssapi_delete_ctx(&ctxt);
+		ssh_gssapi_delete_ctx(&kex->gss);
 
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
 		r = kex_send_newkeys(ssh);
 
+	if (kex->gss != NULL) {
+		sshbuf_free(gss->server_host_key_blob);
+		gss->server_host_key_blob = NULL;
+		sshbuf_free(gss->server_blob);
+		gss->server_blob = NULL;
+	}
 out:
-	explicit_bzero(hash, sizeof(hash));
 	explicit_bzero(kex->c25519_client_key, sizeof(kex->c25519_client_key));
-	sshbuf_free(empty);
-	sshbuf_free(server_host_key_blob);
-	sshbuf_free(server_blob);
+	explicit_bzero(hash, sizeof(hash));
 	sshbuf_free(shared_secret);
 	sshbuf_free(kex->client_pub);
 	kex->client_pub = NULL;
 	return r;
 }
 
+static int
+kexgss_init_ctx(struct ssh *ssh,
+		gss_buffer_desc *token_ptr)
+{
+	struct kex *kex = ssh->kex;
+	Gssctxt *gss = kex->gss;
+	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
+	OM_uint32 ret_flags;
+	int r;
+
+	debug("Calling gss_init_sec_context");
+
+	gss->major = ssh_gssapi_init_ctx(gss, kex->gss_deleg_creds,
+					 token_ptr, &send_tok, &ret_flags);
+
+	if (GSS_ERROR(gss->major)) {
+		/* XXX Useless code: Missing send? */
+		if (send_tok.length != 0) {
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+				fatal("sshpkt failed: %s", ssh_err(r));
+		}
+		fatal("gss_init_context failed");
+	}
+
+	/* If we've got an old receive buffer get rid of it */
+	if (token_ptr != GSS_C_NO_BUFFER)
+		gss_release_buffer(&gss->minor, token_ptr);
+
+	if (gss->major == GSS_S_COMPLETE) {
+		/* If mutual state flag is not true, kex fails */
+		if (!(ret_flags & GSS_C_MUTUAL_FLAG))
+			fatal("Mutual authentication failed");
+
+		/* If integ avail flag is not true kex fails */
+		if (!(ret_flags & GSS_C_INTEG_FLAG))
+			fatal("Integrity check failed");
+	}
+
+	/*
+	 * If we have data to send, then the last message that we
+	 * received cannot have been a 'complete'.
+	 */
+	if (send_tok.length != 0) {
+		if (gss->first) {
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
+			    (r = sshpkt_put_stringb(ssh, kex->client_pub)) != 0)
+				fatal("failed to construct packet: %s", ssh_err(r));
+			gss->first = 0;
+		} else {
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+				fatal("failed to construct packet: %s", ssh_err(r));
+		}
+		if ((r = sshpkt_send(ssh)) != 0)
+			fatal("failed to send packet: %s", ssh_err(r));
+		gss_release_buffer(&gss->minor, &send_tok);
+
+		/* If we've sent them data, they should reply */
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, &input_kexgss_hostkey);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgss_continue);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, &input_kexgss_complete);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, &input_kexgss_error);
+		return 0;
+	}
+	/* No data, and not complete */
+	if (gss->major != GSS_S_COMPLETE)
+		fatal("Not complete, and no token output");
+
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
+		return kexgss_init_ctx(ssh, token_ptr);
+
+	return kexgss_final(ssh);
+}
+
 int
-kexgssgex_client(struct ssh *ssh)
+kexgss_client(struct ssh *ssh)
 {
 	struct kex *kex = ssh->kex;
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER,
-	    recv_tok = GSS_C_EMPTY_BUFFER, gssbuf,
-            msg_tok = GSS_C_EMPTY_BUFFER, *token_ptr;
-	Gssctxt *ctxt;
-	OM_uint32 maj_status, min_status, ret_flags;
-	struct sshbuf *shared_secret = NULL;
-	BIGNUM *p = NULL;
-	BIGNUM *g = NULL;
-	struct sshbuf *buf = NULL;
-	struct sshbuf *server_host_key_blob = NULL;
-	struct sshbuf *server_blob = NULL;
-	BIGNUM *dh_server_pub = NULL;
-	u_char *msg;
-	int type = 0;
-	int first = 1;
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
-	size_t hashlen;
-	const BIGNUM *pub_key, *dh_p, *dh_g;
-	int nbits = 0, min = DH_GRP_MIN, max = DH_GRP_MAX;
-	struct sshbuf *empty = NULL;
-	u_char c;
 	int r;
 
 	/* Initialise our GSSAPI world */
-	ssh_gssapi_build_ctx(&ctxt);
-	if (ssh_gssapi_id_kex(ctxt, kex->name, kex->kex_type)
-	    == GSS_C_NO_OID)
+	ssh_gssapi_build_ctx(&kex->gss);
+	if (ssh_gssapi_id_kex(kex->gss, kex->name, kex->kex_type) == GSS_C_NO_OID)
 		fatal("Couldn't identify host exchange");
 
-	if (ssh_gssapi_import_name(ctxt, kex->gss_host))
+	if (ssh_gssapi_import_name(kex->gss, kex->gss_host))
 		fatal("Couldn't import hostname");
 
 	if (kex->gss_client &&
-	    ssh_gssapi_client_identity(ctxt, kex->gss_client))
+	    ssh_gssapi_client_identity(kex->gss, kex->gss_client))
 		fatal("Couldn't acquire client credentials");
 
-	debug("Doing group exchange");
-	nbits = dh_estimate(kex->dh_need * 8);
+	/* Step 1 */
+	switch (kex->kex_type) {
+	case KEX_GSS_GRP1_SHA1:
+	case KEX_GSS_GRP14_SHA1:
+	case KEX_GSS_GRP14_SHA256:
+	case KEX_GSS_GRP16_SHA512:
+		r = kex_dh_keypair(kex);
+		break;
+	case KEX_GSS_NISTP256_SHA256:
+		r = kex_ecdh_keypair(kex);
+		break;
+	case KEX_GSS_C25519_SHA256:
+		r = kex_c25519_keypair(kex);
+		break;
+	default:
+		fatal_f("Unexpected KEX type %d", kex->kex_type);
+	}
+	if (r != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
+		return r;
+	}
+	return kexgss_init_ctx(ssh, GSS_C_NO_BUFFER);
+}
 
-	kex->min = DH_GRP_MIN;
-	kex->max = DH_GRP_MAX;
-	kex->nbits = nbits;
-	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUPREQ)) != 0 ||
-	    (r = sshpkt_put_u32(ssh, min)) != 0 ||
-	    (r = sshpkt_put_u32(ssh, nbits)) != 0 ||
-	    (r = sshpkt_put_u32(ssh, max)) != 0 ||
-	    (r = sshpkt_send(ssh)) != 0)
-		fatal("Failed to construct a packet: %s", ssh_err(r));
+static int
+input_kexgss_hostkey(int type,
+		     u_int32_t seq,
+		     struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	u_char *tmp = NULL;
+	size_t tmp_len = 0;
+	int r;
+
+	debug("Received KEXGSS_HOSTKEY");
+	if (gss->server_host_key_blob)
+		fatal("Server host key received more than once");
+	if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
+		fatal("Failed to read server host key: %s", ssh_err(r));
+	if ((gss->server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
+		fatal("sshbuf_from failed");
+	return 0;
+}
 
-	if ((r = ssh_packet_read_expect(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0)
-		fatal("Error: %s", ssh_err(r));
+static int
+input_kexgss_continue(int type,
+		      u_int32_t seq,
+		      struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
+	int r;
 
-	if ((r = sshpkt_get_bignum2(ssh, &p)) != 0 ||
-	    (r = sshpkt_get_bignum2(ssh, &g)) != 0 ||
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
+
+	debug("Received GSSAPI_CONTINUE");
+	if (gss->major == GSS_S_COMPLETE)
+		fatal("GSSAPI Continue received from server when complete");
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
 	    (r = sshpkt_get_end(ssh)) != 0)
-		fatal("shpkt_get_bignum2 failed: %s", ssh_err(r));
+		fatal("Failed to read token: %s", ssh_err(r));
+	if  (!(gss->major & GSS_S_CONTINUE_NEEDED))
+		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
+	return kexgss_init_ctx(ssh, &recv_tok);
+}
 
-	if (BN_num_bits(p) < min || BN_num_bits(p) > max)
-		fatal("GSSGRP_GEX group out of range: %d !< %d !< %d",
-		    min, BN_num_bits(p), max);
+static int
+input_kexgss_complete(int type,
+		      u_int32_t seq,
+		      struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
+	u_char c;
+	int r;
 
-	if ((kex->dh = dh_new_group(g, p)) == NULL)
-		fatal("dn_new_group() failed");
-	p = g = NULL; /* belong to kex->dh now */
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
+
+	debug("Received GSSAPI_COMPLETE");
+	if (gss->msg_tok.value != NULL)
+	        fatal("Received GSSAPI_COMPLETE twice?");
+	if ((r = sshpkt_getb_froms(ssh, &gss->server_blob)) != 0 ||
+	    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &gss->msg_tok)) != 0)
+		fatal("Failed to read message: %s", ssh_err(r));
+
+	/* Is there a token included? */
+	if ((r = sshpkt_get_u8(ssh, &c)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+	if (c) {
+		if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0)
+			fatal("Failed to read token: %s", ssh_err(r));
+		/* If we're already complete - protocol error */
+		if (gss->major == GSS_S_COMPLETE)
+			sshpkt_disconnect(ssh, "Protocol error: received token when complete");
+	} else {
+		if (gss->major != GSS_S_COMPLETE)
+			sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
+	}
+	if ((r = sshpkt_get_end(ssh)) != 0)
+		fatal("Expecting end of packet.");
 
-	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0)
-		goto out;
-	DH_get0_key(kex->dh, &pub_key, NULL);
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
+		return kexgss_init_ctx(ssh, &recv_tok);
 
-	token_ptr = GSS_C_NO_BUFFER;
+	return kexgss_final(ssh);
+}
 
-	do {
-		/* Step 2 - call GSS_Init_sec_context() */
-		debug("Calling gss_init_sec_context");
-
-		maj_status = ssh_gssapi_init_ctx(ctxt,
-		    kex->gss_deleg_creds, token_ptr, &send_tok,
-		    &ret_flags);
-
-		if (GSS_ERROR(maj_status)) {
-			/* XXX Useles code: Missing send? */
-			if (send_tok.length != 0) {
-				if ((r = sshpkt_start(ssh,
-				        SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-				    (r = sshpkt_put_string(ssh, send_tok.value,
-				        send_tok.length)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-			}
-			fatal("gss_init_context failed");
-		}
+static int
+input_kexgss_error(int type,
+		   u_int32_t seq,
+		   struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	u_char *msg;
+	int r;
 
-		/* If we've got an old receive buffer get rid of it */
-		if (token_ptr != GSS_C_NO_BUFFER)
-			gss_release_buffer(&min_status, &recv_tok);
-
-		if (maj_status == GSS_S_COMPLETE) {
-			/* If mutual state flag is not true, kex fails */
-			if (!(ret_flags & GSS_C_MUTUAL_FLAG))
-				fatal("Mutual authentication failed");
-
-			/* If integ avail flag is not true kex fails */
-			if (!(ret_flags & GSS_C_INTEG_FLAG))
-				fatal("Integrity check failed");
-		}
+	debug("Received Error");
+	if ((r = sshpkt_get_u32(ssh, &gss->major)) != 0 ||
+	    (r = sshpkt_get_u32(ssh, &gss->minor)) != 0 ||
+	    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
+	    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt_get failed: %s", ssh_err(r));
+	fatal("GSSAPI Error: \n%.400s", msg);
+	return 0;
+}
 
-		/*
-		 * If we have data to send, then the last message that we
-		 * received cannot have been a 'complete'.
-		 */
-		if (send_tok.length != 0) {
-			if (first) {
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
-				    (r = sshpkt_put_string(ssh, send_tok.value,
-				        send_tok.length)) != 0 ||
-				    (r = sshpkt_put_bignum2(ssh, pub_key)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-				first = 0;
-			} else {
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-				    (r = sshpkt_put_string(ssh,send_tok.value,
-				        send_tok.length)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-			}
-			if ((r = sshpkt_send(ssh)) != 0)
-				fatal("sshpkt_send failed: %s", ssh_err(r));
-			gss_release_buffer(&min_status, &send_tok);
-
-			/* If we've sent them data, they should reply */
-			do {
-				type = ssh_packet_read(ssh);
-				if (type == SSH2_MSG_KEXGSS_HOSTKEY) {
-					u_char *tmp = NULL;
-					size_t tmp_len = 0;
-
-					debug("Received KEXGSS_HOSTKEY");
-					if (server_host_key_blob)
-						fatal("Server host key received more than once");
-					if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
-						fatal("sshpkt failed: %s", ssh_err(r));
-					if ((server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
-						fatal("sshbuf_from failed");
-				}
-			} while (type == SSH2_MSG_KEXGSS_HOSTKEY);
-
-			switch (type) {
-			case SSH2_MSG_KEXGSS_CONTINUE:
-				debug("Received GSSAPI_CONTINUE");
-				if (maj_status == GSS_S_COMPLETE)
-					fatal("GSSAPI Continue received from server when complete");
-				if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-				        &recv_tok)) != 0 ||
-				    (r = sshpkt_get_end(ssh)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-				break;
-			case SSH2_MSG_KEXGSS_COMPLETE:
-				debug("Received GSSAPI_COMPLETE");
-				if (msg_tok.value != NULL)
-				        fatal("Received GSSAPI_COMPLETE twice?");
-				if ((r = sshpkt_getb_froms(ssh, &server_blob)) != 0 ||
-				    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-				        &msg_tok)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-
-				/* Is there a token included? */
-				if ((r = sshpkt_get_u8(ssh, &c)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-				if (c) {
-					if ((r = ssh_gssapi_sshpkt_get_buffer_desc(
-					        ssh, &recv_tok)) != 0 ||
-					    (r = sshpkt_get_end(ssh)) != 0)
-						fatal("sshpkt failed: %s", ssh_err(r));
-					/* If we're already complete - protocol error */
-					if (maj_status == GSS_S_COMPLETE)
-						sshpkt_disconnect(ssh, "Protocol error: received token when complete");
-				} else {
-					/* No token included */
-					if (maj_status != GSS_S_COMPLETE)
-						sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
-				}
-				break;
-			case SSH2_MSG_KEXGSS_ERROR:
-				debug("Received Error");
-				if ((r = sshpkt_get_u32(ssh, &maj_status)) != 0 ||
-				    (r = sshpkt_get_u32(ssh, &min_status)) != 0 ||
-				    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
-				    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
-				    (r = sshpkt_get_end(ssh)) != 0)
-					fatal("sshpkt failed: %s", ssh_err(r));
-				fatal("GSSAPI Error: \n%.400s", msg);
-			default:
-				sshpkt_disconnect(ssh, "Protocol error: didn't expect packet type %d",
-				    type);
-			}
-			token_ptr = &recv_tok;
-		} else {
-			/* No data, and not complete */
-			if (maj_status != GSS_S_COMPLETE)
-				fatal("Not complete, and no token output");
-		}
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
+/*******************************************************/
+/******************** KEXGSSGEX ************************/
+/*******************************************************/
+
+int
+kexgssgex_client(struct ssh *ssh)
+{
+	struct kex *kex = ssh->kex;
+	int r;
+
+	/* Initialise our GSSAPI world */
+	ssh_gssapi_build_ctx(&kex->gss);
+	if (ssh_gssapi_id_kex(kex->gss, kex->name, kex->kex_type) == GSS_C_NO_OID)
+		fatal("Couldn't identify host exchange");
+
+	if (ssh_gssapi_import_name(kex->gss, kex->gss_host))
+		fatal("Couldn't import hostname");
+
+	if (kex->gss_client &&
+	    ssh_gssapi_client_identity(kex->gss, kex->gss_client))
+		fatal("Couldn't acquire client credentials");
+
+	debug("Doing group exchange");
+	kex->min = DH_GRP_MIN;
+	kex->max = DH_GRP_MAX;
+	kex->nbits = dh_estimate(kex->dh_need * 8);
+
+	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUPREQ)) != 0 ||
+	    (r = sshpkt_put_u32(ssh, kex->min)) != 0 ||
+	    (r = sshpkt_put_u32(ssh, kex->nbits)) != 0 ||
+	    (r = sshpkt_put_u32(ssh, kex->max)) != 0 ||
+	    (r = sshpkt_send(ssh)) != 0)
+		fatal("Failed to construct a packet: %s", ssh_err(r));
+
+	debug("Wait SSH2_MSG_KEXGSS_GROUP");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUP, &input_kexgssgex_group);
+	return 0;
+}
+
+static int
+kexgssgex_final(struct ssh *ssh)
+{
+	struct kex *kex = ssh->kex;
+	Gssctxt *gss = kex->gss;
+	struct sshbuf *buf = NULL;
+	struct sshbuf *empty = NULL;
+	struct sshbuf *shared_secret = NULL;
+	BIGNUM *dh_server_pub = NULL;
+	const BIGNUM *pub_key, *dh_p, *dh_g;
+	u_char hash[SSH_DIGEST_MAX_LENGTH];
+	size_t hashlen;
+	int r = SSH_ERR_INTERNAL_ERROR;
 
 	/*
 	 * We _must_ have received a COMPLETE message in reply from the
-	 * server, which will have set dh_server_pub and msg_tok
+	 * server, which will have set server_blob and msg_tok
 	 */
 
-	if (type != SSH2_MSG_KEXGSS_COMPLETE)
-		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
-
 	/* 7. C verifies that the key Q_S is valid */
 	/* 8. C computes shared secret */
 	if ((buf = sshbuf_new()) == NULL ||
-	    (r = sshbuf_put_stringb(buf, server_blob)) != 0 ||
-	    (r = sshbuf_get_bignum2(buf, &dh_server_pub)) != 0)
+	    (r = sshbuf_put_stringb(buf, gss->server_blob)) != 0 ||
+	    (r = sshbuf_get_bignum2(buf, &dh_server_pub)) != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		goto out;
+	}
 	sshbuf_free(buf);
 	buf = NULL;
 
 	if ((shared_secret = sshbuf_new()) == NULL) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
 
-	if ((r = kex_dh_compute_key(kex, dh_server_pub, shared_secret)) != 0)
+	if ((r = kex_dh_compute_key(kex, dh_server_pub, shared_secret)) != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		goto out;
+	}
+
 	if ((empty = sshbuf_new()) == NULL) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
 
+	DH_get0_key(kex->dh, &pub_key, NULL);
 	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
 	hashlen = sizeof(hash);
-	if ((r = kexgex_hash(
-	    kex->hash_alg,
-	    kex->client_version,
-	    kex->server_version,
-	    kex->my,
-	    kex->peer,
-	    (server_host_key_blob ? server_host_key_blob : empty),
- 	    kex->min, kex->nbits, kex->max,
-	    dh_p, dh_g,
-	    pub_key,
-	    dh_server_pub,
-	    sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
-	    hash, &hashlen)) != 0)
+	r = kexgex_hash(kex->hash_alg, kex->client_version,
+			kex->server_version, kex->my, kex->peer,
+			(gss->server_host_key_blob ? gss->server_host_key_blob : empty),
+			kex->min, kex->nbits, kex->max, dh_p, dh_g, pub_key,
+			dh_server_pub, sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
+			hash, &hashlen);
+	sshbuf_free(empty);
+	if (r != 0)
 		fatal("Failed to calculate hash: %s", ssh_err(r));
 
-	gssbuf.value = hash;
-	gssbuf.length = hashlen;
+	gss->buf.value = hash;
+	gss->buf.length = hashlen;
 
 	/* Verify that the hash matches the MIC we just got. */
-	if (GSS_ERROR(ssh_gssapi_checkmic(ctxt, &gssbuf, &msg_tok)))
+	if (GSS_ERROR(ssh_gssapi_checkmic(gss, &gss->buf, &gss->msg_tok)))
 		sshpkt_disconnect(ssh, "Hash's MIC didn't verify");
 
-	gss_release_buffer(&min_status, &msg_tok);
+	gss_release_buffer(&gss->minor, &gss->msg_tok);
 
 	if (kex->gss_deleg_creds)
-		ssh_gssapi_credentials_updated(ctxt);
+		ssh_gssapi_credentials_updated(gss);
 
 	if (gss_kex_context == NULL)
-		gss_kex_context = ctxt;
+		gss_kex_context = gss;
 	else
-		ssh_gssapi_delete_ctx(&ctxt);
+		ssh_gssapi_delete_ctx(&kex->gss);
 
 	/* Finally derive the keys and send them */
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
 		r = kex_send_newkeys(ssh);
+
+	if (kex->gss != NULL) {
+		sshbuf_free(gss->server_host_key_blob);
+		gss->server_host_key_blob = NULL;
+		sshbuf_free(gss->server_blob);
+		gss->server_blob = NULL;
+	}
 out:
-	sshbuf_free(buf);
-	sshbuf_free(server_blob);
-	sshbuf_free(empty);
 	explicit_bzero(hash, sizeof(hash));
 	DH_free(kex->dh);
 	kex->dh = NULL;
 	BN_clear_free(dh_server_pub);
 	sshbuf_free(shared_secret);
-	sshbuf_free(server_host_key_blob);
 	return r;
 }
 
+static int
+kexgssgex_init_ctx(struct ssh *ssh,
+		   gss_buffer_desc *token_ptr)
+{
+	struct kex *kex = ssh->kex;
+	Gssctxt *gss = kex->gss;
+	const BIGNUM *pub_key;
+	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
+	OM_uint32 ret_flags;
+	int r;
+
+	/* Step 2 - call GSS_Init_sec_context() */
+	debug("Calling gss_init_sec_context");
+
+	gss->major = ssh_gssapi_init_ctx(gss, kex->gss_deleg_creds,
+					 token_ptr, &send_tok, &ret_flags);
+
+	if (GSS_ERROR(gss->major)) {
+		/* XXX Useless code: Missing send? */
+		if (send_tok.length != 0) {
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+				fatal("sshpkt failed: %s", ssh_err(r));
+		}
+		fatal("gss_init_context failed");
+	}
+
+	/* If we've got an old receive buffer get rid of it */
+	if (token_ptr != GSS_C_NO_BUFFER)
+		gss_release_buffer(&gss->minor, token_ptr);
+
+	if (gss->major == GSS_S_COMPLETE) {
+		/* If mutual state flag is not true, kex fails */
+		if (!(ret_flags & GSS_C_MUTUAL_FLAG))
+			fatal("Mutual authentication failed");
+
+		/* If integ avail flag is not true kex fails */
+		if (!(ret_flags & GSS_C_INTEG_FLAG))
+			fatal("Integrity check failed");
+	}
+
+	/*
+	 * If we have data to send, then the last message that we
+	 * received cannot have been a 'complete'.
+	 */
+	if (send_tok.length != 0) {
+		if (gss->first) {
+	                DH_get0_key(kex->dh, &pub_key, NULL);
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
+			    (r = sshpkt_put_bignum2(ssh, pub_key)) != 0)
+				fatal("failed to construct packet: %s", ssh_err(r));
+			gss->first = 0;
+		} else {
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+				fatal("failed to construct packet: %s", ssh_err(r));
+		}
+		if ((r = sshpkt_send(ssh)) != 0)
+			fatal("failed to send packet: %s", ssh_err(r));
+		gss_release_buffer(&gss->minor, &send_tok);
+
+		/* If we've sent them data, they should reply */
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, &input_kexgss_hostkey);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgssgex_continue);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, &input_kexgssgex_complete);
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, &input_kexgss_error);
+		return 0;
+	}
+	/* No data, and not complete */
+	if (gss->major != GSS_S_COMPLETE)
+		fatal("Not complete, and no token output");
+
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
+		return kexgssgex_init_ctx(ssh, token_ptr);
+
+	return kexgssgex_final(ssh);
+}
+
+static int
+input_kexgssgex_group(int type,
+		      u_int32_t seq,
+		      struct ssh *ssh)
+{
+	struct kex *kex = ssh->kex;
+	BIGNUM *p = NULL;
+	BIGNUM *g = NULL;
+	int r;
+
+	debug("Received SSH2_MSG_KEXGSS_GROUP");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUP, NULL);
+
+	if ((r = sshpkt_get_bignum2(ssh, &p)) != 0 ||
+	    (r = sshpkt_get_bignum2(ssh, &g)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("shpkt_get_bignum2 failed: %s", ssh_err(r));
+
+	if (BN_num_bits(p) < kex->min || BN_num_bits(p) > kex->max)
+		fatal("GSSGRP_GEX group out of range: %d !< %d !< %d",
+		    kex->min, BN_num_bits(p), kex->max);
+
+	if ((kex->dh = dh_new_group(g, p)) == NULL)
+		fatal("dn_new_group() failed");
+	p = g = NULL; /* belong to kex->dh now */
+
+	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
+		DH_free(kex->dh);
+		kex->dh = NULL;
+		return r;
+	}
+
+	return kexgssgex_init_ctx(ssh, GSS_C_NO_BUFFER);
+}
+
+static int
+input_kexgssgex_continue(int type,
+			 u_int32_t seq,
+			 struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
+	int r;
+
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
+
+	debug("Received GSSAPI_CONTINUE");
+	if (gss->major == GSS_S_COMPLETE)
+		fatal("GSSAPI Continue received from server when complete");
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("Failed to read token: %s", ssh_err(r));
+	if  (!(gss->major & GSS_S_CONTINUE_NEEDED))
+		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
+	return kexgssgex_init_ctx(ssh, &recv_tok);
+}
+
+static int
+input_kexgssgex_complete(int type,
+		      u_int32_t seq,
+		      struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
+	u_char c;
+	int r;
+
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
+
+	debug("Received GSSAPI_COMPLETE");
+	if (gss->msg_tok.value != NULL)
+	        fatal("Received GSSAPI_COMPLETE twice?");
+	if ((r = sshpkt_getb_froms(ssh, &gss->server_blob)) != 0 ||
+	    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &gss->msg_tok)) != 0)
+		fatal("Failed to read message: %s", ssh_err(r));
+
+	/* Is there a token included? */
+	if ((r = sshpkt_get_u8(ssh, &c)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+	if (c) {
+		if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0)
+			fatal("Failed to read token: %s", ssh_err(r));
+		/* If we're already complete - protocol error */
+		if (gss->major == GSS_S_COMPLETE)
+			sshpkt_disconnect(ssh, "Protocol error: received token when complete");
+	} else {
+		if (gss->major != GSS_S_COMPLETE)
+			sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
+	}
+	if ((r = sshpkt_get_end(ssh)) != 0)
+		fatal("Expecting end of packet.");
+
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
+		return kexgssgex_init_ctx(ssh, &recv_tok);
+
+	return kexgssgex_final(ssh);
+}
+
 #endif /* defined(GSSAPI) && defined(WITH_OPENSSL) */
diff --color -ruNp a/kexgsss.c b/kexgsss.c
--- a/kexgsss.c	2024-05-16 15:49:43.820407648 +0200
+++ b/kexgsss.c	2024-07-02 16:29:05.744790839 +0200
@@ -50,33 +50,18 @@
 
 extern ServerOptions options;
 
+static int input_kexgss_init(int, u_int32_t, struct ssh *);
+static int input_kexgss_continue(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_groupreq(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_init(int, u_int32_t, struct ssh *);
+static int input_kexgssgex_continue(int, u_int32_t, struct ssh *);
+
 int
 kexgss_server(struct ssh *ssh)
 {
 	struct kex *kex = ssh->kex;
-	OM_uint32 maj_status, min_status;
-
-	/*
-	 * Some GSSAPI implementations use the input value of ret_flags (an
-	 * output variable) as a means of triggering mechanism specific
-	 * features. Initializing it to zero avoids inadvertently
-	 * activating this non-standard behaviour.
-	 */
-
-	OM_uint32 ret_flags = 0;
-	gss_buffer_desc gssbuf = {0, NULL}, recv_tok, msg_tok;
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
-	Gssctxt *ctxt = NULL;
-	struct sshbuf *shared_secret = NULL;
-	struct sshbuf *client_pubkey = NULL;
-	struct sshbuf *server_pubkey = NULL;
-	struct sshbuf *empty = sshbuf_new();
-	int type = 0;
 	gss_OID oid;
 	char *mechs;
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
-	size_t hashlen;
-	int r;
 
 	/* Initialise GSSAPI */
 
@@ -92,135 +77,91 @@ kexgss_server(struct ssh *ssh)
 	debug2_f("Identifying %s", kex->name);
 	oid = ssh_gssapi_id_kex(NULL, kex->name, kex->kex_type);
 	if (oid == GSS_C_NO_OID)
-	   fatal("Unknown gssapi mechanism");
+		fatal("Unknown gssapi mechanism");
 
 	debug2_f("Acquiring credentials");
 
-	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&ctxt, oid)))
+	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&kex->gss, oid)))
 		fatal("Unable to acquire credentials for the server");
 
-	do {
-		debug("Wait SSH2_MSG_KEXGSS_INIT");
-		type = ssh_packet_read(ssh);
-		switch(type) {
-		case SSH2_MSG_KEXGSS_INIT:
-			if (gssbuf.value != NULL)
-				fatal("Received KEXGSS_INIT after initialising");
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-			        &recv_tok)) != 0 ||
-			    (r = sshpkt_getb_froms(ssh, &client_pubkey)) != 0 ||
-			    (r = sshpkt_get_end(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
+	ssh_gssapi_build_ctx(&kex->gss);
+	if (kex->gss == NULL)
+		fatal("Unable to allocate memory for gss context");
+
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, &input_kexgss_init);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgss_continue);
+	debug("Wait SSH2_MSG_KEXGSS_INIT");
+	return 0;
+}
 
-			switch (kex->kex_type) {
-			case KEX_GSS_GRP1_SHA1:
-			case KEX_GSS_GRP14_SHA1:
-			case KEX_GSS_GRP14_SHA256:
-			case KEX_GSS_GRP16_SHA512:
-				r = kex_dh_enc(kex, client_pubkey, &server_pubkey,
-				    &shared_secret);
-				break;
-			case KEX_GSS_NISTP256_SHA256:
-				r = kex_ecdh_enc(kex, client_pubkey, &server_pubkey,
-				    &shared_secret);
-				break;
-			case KEX_GSS_C25519_SHA256:
-				r = kex_c25519_enc(kex, client_pubkey, &server_pubkey,
-				    &shared_secret);
-				break;
-			default:
-				fatal_f("Unexpected KEX type %d", kex->kex_type);
-			}
-			if (r != 0)
-				goto out;
-
-			/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
-
-			/* Calculate the hash early so we can free the
-			* client_pubkey, which has reference to the parent
-			* buffer state->incoming_packet
-			*/
-			hashlen = sizeof(hash);
-			if ((r = kex_gen_hash(
-			    kex->hash_alg,
-			    kex->client_version,
-			    kex->server_version,
-			    kex->peer,
-			    kex->my,
-			    empty,
-			    client_pubkey,
-			    server_pubkey,
-			    shared_secret,
-			    hash, &hashlen)) != 0)
-				goto out;
-
-			gssbuf.value = hash;
-			gssbuf.length = hashlen;
-
-			sshbuf_free(client_pubkey);
-			client_pubkey = NULL;
-
-			break;
-		case SSH2_MSG_KEXGSS_CONTINUE:
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-			        &recv_tok)) != 0 ||
-			    (r = sshpkt_get_end(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
-			break;
-		default:
-			sshpkt_disconnect(ssh,
-			    "Protocol error: didn't expect packet type %d",
-			    type);
-		}
+static inline void
+kexgss_accept_ctx(struct ssh *ssh,
+		  gss_buffer_desc *recv_tok,
+		  gss_buffer_desc *send_tok,
+		  OM_uint32 *ret_flags)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	int r;
 
-		maj_status = mm_ssh_gssapi_accept_ctx(ctxt, &recv_tok,
-		    &send_tok, &ret_flags);
+	gss->major = mm_ssh_gssapi_accept_ctx(gss, recv_tok, send_tok, ret_flags);
+	gss_release_buffer(&gss->minor, recv_tok);
 
-		gss_release_buffer(&min_status, &recv_tok);
+	if (gss->major != GSS_S_COMPLETE && send_tok->length == 0)
+		fatal("Zero length token output when incomplete");
 
-		if (maj_status != GSS_S_COMPLETE && send_tok.length == 0)
-			fatal("Zero length token output when incomplete");
+	if (gss->buf.value == NULL)
+		fatal("No client public key");
 
-		if (gssbuf.value == NULL)
-			fatal("No client public key");
+	if (gss->major & GSS_S_CONTINUE_NEEDED) {
+		debug("Sending GSSAPI_CONTINUE");
+		if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
+		    (r = sshpkt_send(ssh)) != 0)
+			fatal("sshpkt failed: %s", ssh_err(r));
+		gss_release_buffer(&gss->minor, send_tok);
+	}
+}
 
-		if (maj_status & GSS_S_CONTINUE_NEEDED) {
-			debug("Sending GSSAPI_CONTINUE");
-			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
-			    (r = sshpkt_send(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
-			gss_release_buffer(&min_status, &send_tok);
-		}
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
+static inline int
+kexgss_final(struct ssh *ssh,
+	     gss_buffer_desc *send_tok,
+	     OM_uint32 *ret_flags)
+{
+	struct kex *kex = ssh->kex;
+	Gssctxt *gss = kex->gss;
+	gss_buffer_desc msg_tok;
+	int r;
+
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
 
-	if (GSS_ERROR(maj_status)) {
-		if (send_tok.length > 0) {
+	if (GSS_ERROR(gss->major)) {
+		if (send_tok->length > 0) {
 			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
 			    (r = sshpkt_send(ssh)) != 0)
 				fatal("sshpkt failed: %s", ssh_err(r));
 		}
 		fatal("accept_ctx died");
 	}
 
-	if (!(ret_flags & GSS_C_MUTUAL_FLAG))
+	if (!(*ret_flags & GSS_C_MUTUAL_FLAG))
 		fatal("Mutual Authentication flag wasn't set");
 
-	if (!(ret_flags & GSS_C_INTEG_FLAG))
+	if (!(*ret_flags & GSS_C_INTEG_FLAG))
 		fatal("Integrity flag wasn't set");
 
-	if (GSS_ERROR(mm_ssh_gssapi_sign(ctxt, &gssbuf, &msg_tok)))
+	if (GSS_ERROR(mm_ssh_gssapi_sign(gss, &gss->buf, &msg_tok)))
 		fatal("Couldn't get MIC");
 
 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_COMPLETE)) != 0 ||
-	    (r = sshpkt_put_stringb(ssh, server_pubkey)) != 0 ||
+	    (r = sshpkt_put_stringb(ssh, gss->server_pubkey)) != 0 ||
 	    (r = sshpkt_put_string(ssh, msg_tok.value, msg_tok.length)) != 0)
 		fatal("sshpkt failed: %s", ssh_err(r));
 
-	if (send_tok.length != 0) {
+	if (send_tok->length != 0) {
 		if ((r = sshpkt_put_u8(ssh, 1)) != 0 || /* true */
-		    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0)
 			fatal("sshpkt failed: %s", ssh_err(r));
 	} else {
 		if ((r = sshpkt_put_u8(ssh, 0)) != 0) /* false */
@@ -229,59 +170,139 @@ kexgss_server(struct ssh *ssh)
 	if ((r = sshpkt_send(ssh)) != 0)
 		fatal("sshpkt_send failed: %s", ssh_err(r));
 
-	gss_release_buffer(&min_status, &send_tok);
-	gss_release_buffer(&min_status, &msg_tok);
+	gss_release_buffer(&gss->minor, send_tok);
+	gss_release_buffer(&gss->minor, &msg_tok);
 
 	if (gss_kex_context == NULL)
-		gss_kex_context = ctxt;
+		gss_kex_context = gss;
 	else
-		ssh_gssapi_delete_ctx(&ctxt);
+		ssh_gssapi_delete_ctx(&kex->gss);
 
-	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
+	if ((r = kex_derive_keys(ssh, gss->hash, gss->hashlen, gss->shared_secret)) == 0)
 		r = kex_send_newkeys(ssh);
 
 	/* If this was a rekey, then save out any delegated credentials we
 	 * just exchanged.  */
 	if (options.gss_store_rekey)
 		ssh_gssapi_rekey_creds();
-out:
-	sshbuf_free(empty);
-	explicit_bzero(hash, sizeof(hash));
-	sshbuf_free(shared_secret);
-	sshbuf_free(client_pubkey);
-	sshbuf_free(server_pubkey);
+
+	if (kex->gss != NULL) {
+		explicit_bzero(gss->hash, sizeof(gss->hash));
+		sshbuf_free(gss->shared_secret);
+		gss->shared_secret = NULL;
+		sshbuf_free(gss->server_pubkey);
+		gss->server_pubkey = NULL;
+	}
 	return r;
 }
 
-int
-kexgssgex_server(struct ssh *ssh)
+static int
+input_kexgss_init(int type,
+		  u_int32_t seq,
+		  struct ssh *ssh)
 {
 	struct kex *kex = ssh->kex;
-	OM_uint32 maj_status, min_status;
+	Gssctxt *gss = kex->gss;
+	struct sshbuf *empty;
+	struct sshbuf *client_pubkey = NULL;
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
+	OM_uint32 ret_flags = 0;
+	int r;
+
+	debug("SSH2_MSG_KEXGSS_INIT received");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
 
-	/*
-	 * Some GSSAPI implementations use the input value of ret_flags (an
-	 * output variable) as a means of triggering mechanism specific
-	 * features. Initializing it to zero avoids inadvertently
-	 * activating this non-standard behaviour.
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
+	    (r = sshpkt_getb_froms(ssh, &client_pubkey)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	switch (kex->kex_type) {
+	case KEX_GSS_GRP1_SHA1:
+	case KEX_GSS_GRP14_SHA1:
+	case KEX_GSS_GRP14_SHA256:
+	case KEX_GSS_GRP16_SHA512:
+		r = kex_dh_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
+		break;
+	case KEX_GSS_NISTP256_SHA256:
+		r = kex_ecdh_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
+		break;
+	case KEX_GSS_C25519_SHA256:
+		r = kex_c25519_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
+		break;
+	default:
+		fatal_f("Unexpected KEX type %d", kex->kex_type);
+	}
+	if (r != 0) {
+		sshbuf_free(client_pubkey);
+                ssh_gssapi_delete_ctx(&kex->gss);
+		return r;
+	}
+
+	/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
+
+	if ((empty = sshbuf_new()) == NULL) {
+		sshbuf_free(client_pubkey);
+		ssh_gssapi_delete_ctx(&kex->gss);
+		return SSH_ERR_ALLOC_FAIL;
+	}
+
+	/* Calculate the hash early so we can free the
+	 * client_pubkey, which has reference to the parent
+	 * buffer state->incoming_packet
 	 */
+	gss->hashlen = sizeof(gss->hash);
+	r = kex_gen_hash(kex->hash_alg, kex->client_version, kex->server_version,
+			 kex->peer, kex->my, empty, client_pubkey, gss->server_pubkey,
+			 gss->shared_secret, gss->hash, &gss->hashlen);
+	sshbuf_free(empty);
+	sshbuf_free(client_pubkey);
+	if (r != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
+		return r;
+	}
+
+	gss->buf.value = gss->hash;
+	gss->buf.length = gss->hashlen;
+
+	kexgss_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
+		return 0;
 
+	return kexgss_final(ssh, &send_tok, &ret_flags);
+}
+
+static int
+input_kexgss_continue(int type,
+		      u_int32_t seq,
+		      struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
 	OM_uint32 ret_flags = 0;
-	gss_buffer_desc gssbuf, recv_tok, msg_tok;
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
-	Gssctxt *ctxt = NULL;
-	struct sshbuf *shared_secret = NULL;
-	int type = 0;
+	int r;
+
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	kexgss_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
+		return 0;
+
+	return kexgss_final(ssh, &send_tok, &ret_flags);
+}
+
+/*******************************************************/
+/******************** KEXGSSGEX ************************/
+/*******************************************************/
+
+int
+kexgssgex_server(struct ssh *ssh)
+{
+	struct kex *kex = ssh->kex;
 	gss_OID oid;
 	char *mechs;
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
-	size_t hashlen;
-	BIGNUM *dh_client_pub = NULL;
-	const BIGNUM *pub_key, *dh_p, *dh_g;
-	int min = -1, max = -1, nbits = -1;
-	int cmin = -1, cmax = -1; /* client proposal */
-	struct sshbuf *empty = sshbuf_new();
-	int r;
 
 	/* Initialise GSSAPI */
 
@@ -289,153 +310,125 @@ kexgssgex_server(struct ssh *ssh)
 	 * in the GSSAPI code are no longer available. This kludges them back
 	 * into life
 	 */
-	if (!ssh_gssapi_oid_table_ok())
-		if ((mechs = ssh_gssapi_server_mechanisms()))
-			free(mechs);
+	if (!ssh_gssapi_oid_table_ok()) {
+		mechs = ssh_gssapi_server_mechanisms();
+		free(mechs);
+	}
 
 	debug2_f("Identifying %s", kex->name);
 	oid = ssh_gssapi_id_kex(NULL, kex->name, kex->kex_type);
 	if (oid == GSS_C_NO_OID)
-	   fatal("Unknown gssapi mechanism");
+		fatal("Unknown gssapi mechanism");
 
 	debug2_f("Acquiring credentials");
 
-	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&ctxt, oid)))
+	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&kex->gss, oid)))
 		fatal("Unable to acquire credentials for the server");
 
-	/* 5. S generates an ephemeral key pair (do the allocations early) */
-	debug("Doing group exchange");
-	ssh_packet_read_expect(ssh, SSH2_MSG_KEXGSS_GROUPREQ);
-	/* store client proposal to provide valid signature */
-	if ((r = sshpkt_get_u32(ssh, &cmin)) != 0 ||
-	    (r = sshpkt_get_u32(ssh, &nbits)) != 0 ||
-	    (r = sshpkt_get_u32(ssh, &cmax)) != 0 ||
-	    (r = sshpkt_get_end(ssh)) != 0)
-		fatal("sshpkt failed: %s", ssh_err(r));
-	kex->nbits = nbits;
-	kex->min = cmin;
-	kex->max = cmax;
-	min = MAX(DH_GRP_MIN, cmin);
-	max = MIN(DH_GRP_MAX, cmax);
-	nbits = MAXIMUM(DH_GRP_MIN, nbits);
-	nbits = MINIMUM(DH_GRP_MAX, nbits);
-	if (max < min || nbits < min || max < nbits)
-		fatal("GSS_GEX, bad parameters: %d !< %d !< %d",
-		    min, nbits, max);
-	kex->dh = mm_choose_dh(min, nbits, max);
-	if (kex->dh == NULL) {
-		sshpkt_disconnect(ssh, "Protocol error: no matching group found");
-		fatal("Protocol error: no matching group found");
-	}
+	ssh_gssapi_build_ctx(&kex->gss);
+	if (kex->gss == NULL)
+		fatal("Unable to allocate memory for gss context");
 
-	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
-	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0 ||
-	    (r = sshpkt_put_bignum2(ssh, dh_p)) != 0 ||
-	    (r = sshpkt_put_bignum2(ssh, dh_g)) != 0 ||
-	    (r = sshpkt_send(ssh)) != 0)
-		fatal("sshpkt failed: %s", ssh_err(r));
-
-	if ((r = ssh_packet_write_wait(ssh)) != 0)
-		fatal("ssh_packet_write_wait: %s", ssh_err(r));
-
-	/* Compute our exchange value in parallel with the client */
-	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0)
-		goto out;
+	debug("Doing group exchange");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUPREQ, &input_kexgssgex_groupreq);
+	return 0;
+}
 
-	do {
-		debug("Wait SSH2_MSG_GSSAPI_INIT");
-		type = ssh_packet_read(ssh);
-		switch(type) {
-		case SSH2_MSG_KEXGSS_INIT:
-			if (dh_client_pub != NULL)
-				fatal("Received KEXGSS_INIT after initialising");
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-			        &recv_tok)) != 0 ||
-			    (r = sshpkt_get_bignum2(ssh, &dh_client_pub)) != 0 ||
-			    (r = sshpkt_get_end(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
+static inline void
+kexgssgex_accept_ctx(struct ssh *ssh,
+		     gss_buffer_desc *recv_tok,
+		     gss_buffer_desc *send_tok,
+		     OM_uint32 *ret_flags)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	int r;
 
-			/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
-			break;
-		case SSH2_MSG_KEXGSS_CONTINUE:
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
-			        &recv_tok)) != 0 ||
-			    (r = sshpkt_get_end(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
-			break;
-		default:
-			sshpkt_disconnect(ssh,
-			    "Protocol error: didn't expect packet type %d",
-			    type);
-		}
+	gss->major = mm_ssh_gssapi_accept_ctx(gss, recv_tok, send_tok, ret_flags);
+	gss_release_buffer(&gss->minor, recv_tok);
 
-		maj_status = mm_ssh_gssapi_accept_ctx(ctxt, &recv_tok,
-		    &send_tok, &ret_flags);
+	if (gss->major != GSS_S_COMPLETE && send_tok->length == 0)
+		fatal("Zero length token output when incomplete");
 
-		gss_release_buffer(&min_status, &recv_tok);
+	if (gss->dh_client_pub == NULL)
+		fatal("No client public key");
 
-		if (maj_status != GSS_S_COMPLETE && send_tok.length == 0)
-			fatal("Zero length token output when incomplete");
+	if (gss->major & GSS_S_CONTINUE_NEEDED) {
+		debug("Sending GSSAPI_CONTINUE");
+		if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
+		    (r = sshpkt_send(ssh)) != 0)
+			fatal("sshpkt failed: %s", ssh_err(r));
+		gss_release_buffer(&gss->minor, send_tok);
+	}
+}
 
-		if (dh_client_pub == NULL)
-			fatal("No client public key");
+static inline int
+kexgssgex_final(struct ssh *ssh,
+		gss_buffer_desc *send_tok,
+		OM_uint32 *ret_flags)
+{
+	struct kex *kex = ssh->kex;
+	Gssctxt *gss = kex->gss;
+	gss_buffer_desc msg_tok;
+	u_char hash[SSH_DIGEST_MAX_LENGTH];
+	size_t hashlen;
+	const BIGNUM *pub_key, *dh_p, *dh_g;
+	struct sshbuf *shared_secret = NULL;
+	struct sshbuf *empty = NULL;
+	int r;
 
-		if (maj_status & GSS_S_CONTINUE_NEEDED) {
-			debug("Sending GSSAPI_CONTINUE");
-			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
-			    (r = sshpkt_send(ssh)) != 0)
-				fatal("sshpkt failed: %s", ssh_err(r));
-			gss_release_buffer(&min_status, &send_tok);
-		}
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
 
-	if (GSS_ERROR(maj_status)) {
-		if (send_tok.length > 0) {
+	if (GSS_ERROR(gss->major)) {
+		if (send_tok->length > 0) {
 			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
+			    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
 			    (r = sshpkt_send(ssh)) != 0)
 				fatal("sshpkt failed: %s", ssh_err(r));
 		}
 		fatal("accept_ctx died");
 	}
 
-	if (!(ret_flags & GSS_C_MUTUAL_FLAG))
+	if (!(*ret_flags & GSS_C_MUTUAL_FLAG))
 		fatal("Mutual Authentication flag wasn't set");
 
-	if (!(ret_flags & GSS_C_INTEG_FLAG))
+	if (!(*ret_flags & GSS_C_INTEG_FLAG))
 		fatal("Integrity flag wasn't set");
 
 	/* calculate shared secret */
-	if ((shared_secret = sshbuf_new()) == NULL) {
+	shared_secret = sshbuf_new();
+	if (shared_secret == NULL) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
-	if ((r = kex_dh_compute_key(kex, dh_client_pub, shared_secret)) != 0)
+	if ((r = kex_dh_compute_key(kex, gss->dh_client_pub, shared_secret)) != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
 		goto out;
+	}
+
+	if ((empty = sshbuf_new()) == NULL) {
+		ssh_gssapi_delete_ctx(&kex->gss);
+		r = SSH_ERR_ALLOC_FAIL;
+		goto out;
+	}
 
 	DH_get0_key(kex->dh, &pub_key, NULL);
 	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
 	hashlen = sizeof(hash);
-	if ((r = kexgex_hash(
-	    kex->hash_alg,
-	    kex->client_version,
-	    kex->server_version,
-	    kex->peer,
-	    kex->my,
-	    empty,
-	    cmin, nbits, cmax,
-	    dh_p, dh_g,
-	    dh_client_pub,
-	    pub_key,
-	    sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
-	    hash, &hashlen)) != 0)
+	r = kexgex_hash(kex->hash_alg, kex->client_version, kex->server_version,
+			kex->peer, kex->my, empty, kex->min, kex->nbits, kex->max, dh_p, dh_g,
+			gss->dh_client_pub, pub_key, sshbuf_ptr(shared_secret),
+			sshbuf_len(shared_secret), hash, &hashlen);
+	sshbuf_free(empty);
+	if (r != 0)
 		fatal("kexgex_hash failed: %s", ssh_err(r));
 
-	gssbuf.value = hash;
-	gssbuf.length = hashlen;
+	gss->buf.value = hash;
+	gss->buf.length = hashlen;
 
-	if (GSS_ERROR(mm_ssh_gssapi_sign(ctxt, &gssbuf, &msg_tok)))
+	if (GSS_ERROR(mm_ssh_gssapi_sign(gss, &gss->buf, &msg_tok)))
 		fatal("Couldn't get MIC");
 
 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_COMPLETE)) != 0 ||
@@ -443,24 +436,24 @@ kexgssgex_server(struct ssh *ssh)
 	    (r = sshpkt_put_string(ssh, msg_tok.value, msg_tok.length)) != 0)
 		fatal("sshpkt failed: %s", ssh_err(r));
 
-	if (send_tok.length != 0) {
+	if (send_tok->length != 0) {
 		if ((r = sshpkt_put_u8(ssh, 1)) != 0 || /* true */
-		    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0)
 			fatal("sshpkt failed: %s", ssh_err(r));
 	} else {
 		if ((r = sshpkt_put_u8(ssh, 0)) != 0) /* false */
 			fatal("sshpkt failed: %s", ssh_err(r));
 	}
 	if ((r = sshpkt_send(ssh)) != 0)
-		fatal("sshpkt failed: %s", ssh_err(r));
+		fatal("sshpkt_send failed: %s", ssh_err(r));
 
-	gss_release_buffer(&min_status, &send_tok);
-	gss_release_buffer(&min_status, &msg_tok);
+	gss_release_buffer(&gss->minor, send_tok);
+	gss_release_buffer(&gss->minor, &msg_tok);
 
 	if (gss_kex_context == NULL)
-		gss_kex_context = ctxt;
+		gss_kex_context = gss;
 	else
-		ssh_gssapi_delete_ctx(&ctxt);
+		ssh_gssapi_delete_ctx(&kex->gss);
 
 	/* Finally derive the keys and send them */
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
@@ -470,13 +463,128 @@ kexgssgex_server(struct ssh *ssh)
 	 * just exchanged.  */
 	if (options.gss_store_rekey)
 		ssh_gssapi_rekey_creds();
+
+	if (kex->gss != NULL)
+		BN_clear_free(gss->dh_client_pub);
+
 out:
-	sshbuf_free(empty);
 	explicit_bzero(hash, sizeof(hash));
 	DH_free(kex->dh);
 	kex->dh = NULL;
-	BN_clear_free(dh_client_pub);
 	sshbuf_free(shared_secret);
 	return r;
 }
+
+static int
+input_kexgssgex_groupreq(int type,
+			 u_int32_t seq,
+			 struct ssh *ssh)
+{
+	struct kex *kex = ssh->kex;
+	const BIGNUM *dh_p, *dh_g;
+	int min = -1, max = -1, nbits = -1;
+	int cmin = -1, cmax = -1; /* client proposal */
+	int r;
+
+	/* 5. S generates an ephemeral key pair (do the allocations early) */
+
+	debug("SSH2_MSG_KEXGSS_GROUPREQ received");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUPREQ, NULL);
+
+	/* store client proposal to provide valid signature */
+	if ((r = sshpkt_get_u32(ssh, &cmin)) != 0 ||
+	    (r = sshpkt_get_u32(ssh, &nbits)) != 0 ||
+	    (r = sshpkt_get_u32(ssh, &cmax)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	kex->nbits = nbits;
+	kex->min = cmin;
+	kex->max = cmax;
+	min = MAX(DH_GRP_MIN, cmin);
+	max = MIN(DH_GRP_MAX, cmax);
+	nbits = MAXIMUM(DH_GRP_MIN, nbits);
+	nbits = MINIMUM(DH_GRP_MAX, nbits);
+
+	if (max < min || nbits < min || max < nbits)
+		fatal("GSS_GEX, bad parameters: %d !< %d !< %d", min, nbits, max);
+
+	kex->dh = mm_choose_dh(min, nbits, max);
+	if (kex->dh == NULL) {
+		sshpkt_disconnect(ssh, "Protocol error: no matching group found");
+		fatal("Protocol error: no matching group found");
+	}
+
+	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
+	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0 ||
+	    (r = sshpkt_put_bignum2(ssh, dh_p)) != 0 ||
+	    (r = sshpkt_put_bignum2(ssh, dh_g)) != 0 ||
+	    (r = sshpkt_send(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	if ((r = ssh_packet_write_wait(ssh)) != 0)
+		fatal("ssh_packet_write_wait: %s", ssh_err(r));
+
+	/* Compute our exchange value in parallel with the client */
+	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0) {
+		ssh_gssapi_delete_ctx(&kex->gss);
+		DH_free(kex->dh);
+		kex->dh = NULL;
+		return r;
+	}
+
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, &input_kexgssgex_init);
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgssgex_continue);
+	debug("Wait SSH2_MSG_KEXGSS_INIT");
+	return 0;
+}
+
+static int
+input_kexgssgex_init(int type,
+		     u_int32_t seq,
+		     struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
+	OM_uint32 ret_flags = 0;
+	int r;
+
+	debug("SSH2_MSG_KEXGSS_INIT received");
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
+
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
+	    (r = sshpkt_get_bignum2(ssh, &gss->dh_client_pub)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
+
+	kexgssgex_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
+		return 0;
+
+	return kexgssgex_final(ssh, &send_tok, &ret_flags);
+}
+
+static int
+input_kexgssgex_continue(int type,
+			 u_int32_t seq,
+			 struct ssh *ssh)
+{
+	Gssctxt *gss = ssh->kex->gss;
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
+	OM_uint32 ret_flags = 0;
+	int r;
+
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
+	    (r = sshpkt_get_end(ssh)) != 0)
+		fatal("sshpkt failed: %s", ssh_err(r));
+
+	kexgssgex_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
+		return 0;
+
+	return kexgssgex_final(ssh, &send_tok, &ret_flags);
+}
+
 #endif /* defined(GSSAPI) && defined(WITH_OPENSSL) */
diff --color -ruNp a/kex.h b/kex.h
--- a/kex.h	2024-05-16 15:49:43.986410812 +0200
+++ b/kex.h	2024-06-18 12:19:48.580347469 +0200
@@ -29,6 +29,10 @@
 #include "mac.h"
 #include "crypto_api.h"
 
+#ifdef GSSAPI
+# include "ssh-gss.h" /* Gssctxt */
+#endif
+
 #ifdef WITH_OPENSSL
 # include <openssl/bn.h>
 # include <openssl/dh.h>
@@ -177,6 +181,7 @@ struct kex {
 	int	hash_alg;
 	int	ec_nid;
 #ifdef GSSAPI
+	Gssctxt *gss;
 	int	gss_deleg_creds;
 	int	gss_trust_dns;
 	char    *gss_host;
diff --color -ruNp a/ssh-gss.h b/ssh-gss.h
--- a/ssh-gss.h	2024-05-16 15:49:43.837407972 +0200
+++ b/ssh-gss.h	2024-06-27 14:12:48.659866937 +0200
@@ -88,6 +88,8 @@ extern char **k5users_allowed_cmds;
 	KEX_GSS_GRP14_SHA1_ID "," \
 	KEX_GSS_GEX_SHA1_ID
 
+#include "digest.h" /* SSH_DIGEST_MAX_LENGTH */
+
 typedef struct {
 	char *filename;
 	char *envvar;
@@ -127,6 +129,16 @@ typedef struct {
 	gss_cred_id_t	creds; /* server */
 	gss_name_t	client; /* server */
 	gss_cred_id_t	client_creds; /* both */
+	struct sshbuf *shared_secret; /* both */
+	struct sshbuf *server_pubkey; /* server */
+	struct sshbuf *server_blob; /* client */
+	struct sshbuf *server_host_key_blob; /* client */
+	gss_buffer_desc msg_tok; /* client */
+	gss_buffer_desc buf; /* both */
+	u_char hash[SSH_DIGEST_MAX_LENGTH]; /* both */
+	size_t hashlen; /* both */
+	int first; /* client */
+	BIGNUM *dh_client_pub; /* server (gex) */
 } Gssctxt;
 
 extern ssh_gssapi_mech *supported_mechs[];