From f0a0683fa6a3f696c4bc5ba88c128bc781c54895 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Mon, 11 Oct 2021 12:19:40 +0200 Subject: [PATCH] Refactored RPC gateway parser Utilize wStream instead of custom binary parsing code, add proper input validation. Reported by Sunglin from the Knownsec 404 team & 0103 sec team (cherry picked from commit f0b44da67c09488178000725ff9f2729ccfdf9fe) --- libfreerdp/core/gateway/ncacn_http.c | 6 +- libfreerdp/core/gateway/rdg.c | 77 +- libfreerdp/core/gateway/rpc.c | 134 +- libfreerdp/core/gateway/rpc.h | 18 +- libfreerdp/core/gateway/rpc_bind.c | 253 +-- libfreerdp/core/gateway/rpc_bind.h | 2 +- libfreerdp/core/gateway/rpc_client.c | 265 +-- libfreerdp/core/gateway/rpc_client.h | 3 +- libfreerdp/core/gateway/rpc_fault.c | 48 +- libfreerdp/core/gateway/rts.c | 2085 ++++++++++++++++++----- libfreerdp/core/gateway/rts.h | 29 +- libfreerdp/core/gateway/rts_signature.c | 127 +- libfreerdp/core/gateway/rts_signature.h | 8 +- libfreerdp/core/gateway/tsg.c | 305 ++-- libfreerdp/core/surface.c | 8 +- 15 files changed, 2375 insertions(+), 993 deletions(-) diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index cffb378..f288a0f 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -121,6 +121,7 @@ BOOL rpc_ncacn_http_recv_in_channel_response(RpcChannel* inChannel, HttpResponse if (ntlmTokenData && ntlmTokenLength) return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength); + free(ntlmTokenData); return TRUE; } @@ -274,5 +275,6 @@ BOOL rpc_ncacn_http_recv_out_channel_response(RpcChannel* outChannel, HttpRespon if (ntlmTokenData && ntlmTokenLength) return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength); + free(ntlmTokenData); return TRUE; } diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index 107e396..d333880 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -1159,6 +1159,9 @@ static BOOL rdg_tunnel_connect(rdpRdg* rdg) if (!status) { + assert(rdg); + assert(rdg->context); + assert(rdg->context->rdp); rdg->context->rdp->transport->layer = TRANSPORT_LAYER_CLOSED; return FALSE; } @@ -1190,6 +1193,9 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback) if (!status) { + assert(rdg); + assert(rdg->context); + assert(rdg->context->rdp); rdg->context->rdp->transport->layer = TRANSPORT_LAYER_CLOSED; return FALSE; } @@ -1535,10 +1541,10 @@ static int rdg_bio_gets(BIO* bio, char* str, int size) return -2; } -static long rdg_bio_ctrl(BIO* bio, int cmd, long arg1, void* arg2) +static long rdg_bio_ctrl(BIO* in_bio, int cmd, long arg1, void* arg2) { long status = -1; - rdpRdg* rdg = (rdpRdg*)BIO_get_data(bio); + rdpRdg* rdg = (rdpRdg*)BIO_get_data(in_bio); rdpTls* tlsOut = rdg->tlsOut; rdpTls* tlsIn = rdg->tlsIn; diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 218a407..4ba52f3 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -24,6 +24,7 @@ #endif #include +#include #include #include #include @@ -46,6 +47,7 @@ #include "rpc_client.h" #include "rpc.h" +#include "rts.h" #define TAG FREERDP_TAG("core.gateway.rpc") @@ -88,8 +90,10 @@ static const char* PTYPE_STRINGS[] = { "PTYPE_REQUEST", "PTYPE_PING", * */ -void rpc_pdu_header_print(rpcconn_hdr_t* header) +void rpc_pdu_header_print(const rpcconn_hdr_t* header) { + assert(header); + WLog_INFO(TAG, "rpc_vers: %" PRIu8 "", header->common.rpc_vers); WLog_INFO(TAG, "rpc_vers_minor: %" PRIu8 "", header->common.rpc_vers_minor); @@ -139,26 +143,30 @@ void rpc_pdu_header_print(rpcconn_hdr_t* header) } } -void rpc_pdu_header_init(rdpRpc* rpc, rpcconn_common_hdr_t* header) +rpcconn_common_hdr_t rpc_pdu_header_init(const rdpRpc* rpc) { - header->rpc_vers = rpc->rpc_vers; - header->rpc_vers_minor = rpc->rpc_vers_minor; - header->packed_drep[0] = rpc->packed_drep[0]; - header->packed_drep[1] = rpc->packed_drep[1]; - header->packed_drep[2] = rpc->packed_drep[2]; - header->packed_drep[3] = rpc->packed_drep[3]; + rpcconn_common_hdr_t header = { 0 }; + assert(rpc); + + header.rpc_vers = rpc->rpc_vers; + header.rpc_vers_minor = rpc->rpc_vers_minor; + header.packed_drep[0] = rpc->packed_drep[0]; + header.packed_drep[1] = rpc->packed_drep[1]; + header.packed_drep[2] = rpc->packed_drep[2]; + header.packed_drep[3] = rpc->packed_drep[3]; + return header; } -UINT32 rpc_offset_align(UINT32* offset, UINT32 alignment) +size_t rpc_offset_align(size_t* offset, size_t alignment) { - UINT32 pad; + size_t pad; pad = *offset; *offset = (*offset + alignment - 1) & ~(alignment - 1); pad = *offset - pad; return pad; } -UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad) +size_t rpc_offset_pad(size_t* offset, size_t pad) { *offset += pad; return pad; @@ -239,64 +247,67 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad) * */ -BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* length) +BOOL rpc_get_stub_data_info(const rpcconn_hdr_t* header, size_t* poffset, size_t* length) { - UINT32 alloc_hint = 0; - rpcconn_hdr_t* header; + size_t used = 0; + size_t offset = 0; + BOOL rc = FALSE; UINT32 frag_length; UINT32 auth_length; - UINT32 auth_pad_length; + UINT32 auth_pad_length = 0; UINT32 sec_trailer_offset; - rpc_sec_trailer* sec_trailer; - *offset = RPC_COMMON_FIELDS_LENGTH; - header = ((rpcconn_hdr_t*)buffer); + const rpc_sec_trailer* sec_trailer = NULL; + + assert(header); + assert(poffset); + assert(length); + + offset = RPC_COMMON_FIELDS_LENGTH; switch (header->common.ptype) { case PTYPE_RESPONSE: - *offset += 8; - rpc_offset_align(offset, 8); - alloc_hint = header->response.alloc_hint; + offset += 8; + rpc_offset_align(&offset, 8); + sec_trailer = &header->response.auth_verifier; break; case PTYPE_REQUEST: - *offset += 4; - rpc_offset_align(offset, 8); - alloc_hint = header->request.alloc_hint; + offset += 4; + rpc_offset_align(&offset, 8); + sec_trailer = &header->request.auth_verifier; break; case PTYPE_RTS: - *offset += 4; + offset += 4; break; default: WLog_ERR(TAG, "Unknown PTYPE: 0x%02" PRIX8 "", header->common.ptype); - return FALSE; + goto fail; } - if (!length) - return TRUE; + frag_length = header->common.frag_length; + auth_length = header->common.auth_length; + + if (poffset) + *poffset = offset; - if (header->common.ptype == PTYPE_REQUEST) + /* The fragment must be larger than the authentication trailer */ + used = offset + auth_length + 8ull; + if (sec_trailer) { - UINT32 sec_trailer_offset; - sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; - *length = sec_trailer_offset - *offset; - return TRUE; + auth_pad_length = sec_trailer->auth_pad_length; + used += sec_trailer->auth_pad_length; } - frag_length = header->common.frag_length; - auth_length = header->common.auth_length; + if (frag_length < used) + goto fail; + + if (!length) + return TRUE; + sec_trailer_offset = frag_length - auth_length - 8; - sec_trailer = (rpc_sec_trailer*)&buffer[sec_trailer_offset]; - auth_pad_length = sec_trailer->auth_pad_length; -#if 0 - WLog_DBG(TAG, - "sec_trailer: type: %"PRIu8" level: %"PRIu8" pad_length: %"PRIu8" reserved: %"PRIu8" context_id: %"PRIu32"", - sec_trailer->auth_type, sec_trailer->auth_level, - sec_trailer->auth_pad_length, sec_trailer->auth_reserved, - sec_trailer->auth_context_id); -#endif /** * According to [MS-RPCE], auth_pad_length is the number of padding @@ -310,18 +321,21 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l auth_length, (frag_length - (sec_trailer_offset + 8))); } - *length = frag_length - auth_length - 24 - 8 - auth_pad_length; - return TRUE; + *length = sec_trailer_offset - auth_pad_length - offset; + + rc = TRUE; +fail: + return rc; } SSIZE_T rpc_channel_read(RpcChannel* channel, wStream* s, size_t length) { int status; - if (!channel) + if (!channel || (length > INT32_MAX)) return -1; - status = BIO_read(channel->tls->bio, Stream_Pointer(s), length); + status = BIO_read(channel->tls->bio, Stream_Pointer(s), (INT32)length); if (status > 0) { @@ -340,10 +354,10 @@ SSIZE_T rpc_channel_write(RpcChannel* channel, const BYTE* data, size_t length) { int status; - if (!channel) + if (!channel || (length > INT32_MAX)) return -1; - status = tls_write_all(channel->tls, data, length); + status = tls_write_all(channel->tls, data, (INT32)length); return status; } @@ -629,7 +643,7 @@ static void rpc_virtual_connection_free(RpcVirtualConnection* connection) free(connection); } -static BOOL rpc_channel_tls_connect(RpcChannel* channel, int timeout) +static BOOL rpc_channel_tls_connect(RpcChannel* channel, UINT32 timeout) { int sockfd; rdpTls* tls; @@ -719,7 +733,7 @@ static BOOL rpc_channel_tls_connect(RpcChannel* channel, int timeout) return TRUE; } -static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) +static int rpc_in_channel_connect(RpcInChannel* inChannel, UINT32 timeout) { rdpContext* context; @@ -814,7 +828,7 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) return 1; } -BOOL rpc_connect(rdpRpc* rpc, int timeout) +BOOL rpc_connect(rdpRpc* rpc, UINT32 timeout) { RpcInChannel* inChannel; RpcOutChannel* outChannel; @@ -840,7 +854,15 @@ BOOL rpc_connect(rdpRpc* rpc, int timeout) rdpRpc* rpc_new(rdpTransport* transport) { - rdpRpc* rpc = (rdpRpc*)calloc(1, sizeof(rdpRpc)); + rdpContext* context; + rdpRpc* rpc; + + assert(transport); + + context = transport->context; + assert(context); + + rpc = (rdpRpc*)calloc(1, sizeof(rdpRpc)); if (!rpc) return NULL; @@ -848,7 +870,7 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->State = RPC_CLIENT_STATE_INITIAL; rpc->transport = transport; rpc->settings = transport->settings; - rpc->context = transport->context; + rpc->context = context; rpc->SendSeqNum = 0; rpc->ntlm = ntlm_new(); @@ -873,7 +895,7 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CurrentKeepAliveInterval = rpc->KeepAliveInterval; rpc->CurrentKeepAliveTime = 0; rpc->CallId = 2; - rpc->client = rpc_client_new(rpc->context, rpc->max_recv_frag); + rpc->client = rpc_client_new(context, rpc->max_recv_frag); if (!rpc->client) goto out_free; diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index 3ca18a7..28f7f30 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -73,7 +73,6 @@ typedef struct _RPC_PDU #include "../tcp.h" #include "../transport.h" -#include "rts.h" #include "http.h" #include "ntlm.h" @@ -522,7 +521,8 @@ typedef struct rpcconn_common_hdr_t header; } rpcconn_shutdown_hdr_t; -typedef union { +typedef union +{ rpcconn_common_hdr_t common; rpcconn_alter_context_hdr_t alter_context; rpcconn_alter_context_response_hdr_t alter_context_response; @@ -764,14 +764,14 @@ struct rdp_rpc RpcVirtualConnection* VirtualConnection; }; -FREERDP_LOCAL void rpc_pdu_header_print(rpcconn_hdr_t* header); -FREERDP_LOCAL void rpc_pdu_header_init(rdpRpc* rpc, rpcconn_common_hdr_t* header); +FREERDP_LOCAL void rpc_pdu_header_print(const rpcconn_hdr_t* header); +FREERDP_LOCAL rpcconn_common_hdr_t rpc_pdu_header_init(const rdpRpc* rpc); -FREERDP_LOCAL UINT32 rpc_offset_align(UINT32* offset, UINT32 alignment); -FREERDP_LOCAL UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad); +FREERDP_LOCAL size_t rpc_offset_align(size_t* offset, size_t alignment); +FREERDP_LOCAL size_t rpc_offset_pad(size_t* offset, size_t pad); -FREERDP_LOCAL BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, - UINT32* length); +FREERDP_LOCAL BOOL rpc_get_stub_data_info(const rpcconn_hdr_t* header, size_t* offset, + size_t* length); FREERDP_LOCAL SSIZE_T rpc_channel_write(RpcChannel* channel, const BYTE* data, size_t length); @@ -791,7 +791,7 @@ FREERDP_LOCAL BOOL rpc_virtual_connection_transition_to_state(rdpRpc* rpc, RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state); -FREERDP_LOCAL BOOL rpc_connect(rdpRpc* rpc, int timeout); +FREERDP_LOCAL BOOL rpc_connect(rdpRpc* rpc, UINT32 timeout); FREERDP_LOCAL rdpRpc* rpc_new(rdpTransport* transport); FREERDP_LOCAL void rpc_free(rdpRpc* rpc); diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c index 98ed9a9..bdf54d5 100644 --- a/libfreerdp/core/gateway/rpc_bind.c +++ b/libfreerdp/core/gateway/rpc_bind.c @@ -22,11 +22,14 @@ #endif #include +#include #include #include "rpc_client.h" +#include "rts.h" + #include "rpc_bind.h" #define TAG FREERDP_TAG("core.gateway.rpc") @@ -106,18 +109,32 @@ int rpc_send_bind_pdu(rdpRpc* rpc) { BOOL continueNeeded = FALSE; int status = -1; - BYTE* buffer = NULL; + wStream* buffer = NULL; UINT32 offset; - UINT32 length; RpcClientCall* clientCall; p_cont_elem_t* p_cont_elem; - rpcconn_bind_hdr_t* bind_pdu = NULL; + rpcconn_bind_hdr_t bind_pdu = { 0 }; BOOL promptPassword = FALSE; - rdpSettings* settings = rpc->settings; - freerdp* instance = (freerdp*)settings->instance; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcInChannel* inChannel = connection->DefaultInChannel; + rdpSettings* settings; + freerdp* instance; + RpcVirtualConnection* connection; + RpcInChannel* inChannel; const SecBuffer* sbuffer = NULL; + + assert(rpc); + + settings = rpc->settings; + assert(settings); + + instance = (freerdp*)settings->instance; + assert(instance); + + connection = rpc->VirtualConnection; + + assert(connection); + + inChannel = connection->DefaultInChannel; + WLog_DBG(TAG, "Sending Bind PDU"); ntlm_free(rpc->ntlm); rpc->ntlm = ntlm_new(); @@ -180,36 +197,31 @@ int rpc_send_bind_pdu(rdpRpc* rpc) if (!continueNeeded) goto fail; - bind_pdu = (rpcconn_bind_hdr_t*)calloc(1, sizeof(rpcconn_bind_hdr_t)); - - if (!bind_pdu) - goto fail; - sbuffer = ntlm_client_get_output_buffer(rpc->ntlm); if (!sbuffer) goto fail; - rpc_pdu_header_init(rpc, &bind_pdu->header); - bind_pdu->header.auth_length = (UINT16)sbuffer->cbBuffer; - bind_pdu->auth_verifier.auth_value = sbuffer->pvBuffer; - bind_pdu->header.ptype = PTYPE_BIND; - bind_pdu->header.pfc_flags = + bind_pdu.header = rpc_pdu_header_init(rpc); + bind_pdu.header.auth_length = (UINT16)sbuffer->cbBuffer; + bind_pdu.auth_verifier.auth_value = sbuffer->pvBuffer; + bind_pdu.header.ptype = PTYPE_BIND; + bind_pdu.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_SUPPORT_HEADER_SIGN | PFC_CONC_MPX; - bind_pdu->header.call_id = 2; - bind_pdu->max_xmit_frag = rpc->max_xmit_frag; - bind_pdu->max_recv_frag = rpc->max_recv_frag; - bind_pdu->assoc_group_id = 0; - bind_pdu->p_context_elem.n_context_elem = 2; - bind_pdu->p_context_elem.reserved = 0; - bind_pdu->p_context_elem.reserved2 = 0; - bind_pdu->p_context_elem.p_cont_elem = - calloc(bind_pdu->p_context_elem.n_context_elem, sizeof(p_cont_elem_t)); - - if (!bind_pdu->p_context_elem.p_cont_elem) + bind_pdu.header.call_id = 2; + bind_pdu.max_xmit_frag = rpc->max_xmit_frag; + bind_pdu.max_recv_frag = rpc->max_recv_frag; + bind_pdu.assoc_group_id = 0; + bind_pdu.p_context_elem.n_context_elem = 2; + bind_pdu.p_context_elem.reserved = 0; + bind_pdu.p_context_elem.reserved2 = 0; + bind_pdu.p_context_elem.p_cont_elem = + calloc(bind_pdu.p_context_elem.n_context_elem, sizeof(p_cont_elem_t)); + + if (!bind_pdu.p_context_elem.p_cont_elem) goto fail; - p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0]; + p_cont_elem = &bind_pdu.p_context_elem.p_cont_elem[0]; p_cont_elem->p_cont_id = 0; p_cont_elem->n_transfer_syn = 1; p_cont_elem->reserved = 0; @@ -222,7 +234,7 @@ int rpc_send_bind_pdu(rdpRpc* rpc) CopyMemory(&(p_cont_elem->transfer_syntaxes[0].if_uuid), &NDR_UUID, sizeof(p_uuid_t)); p_cont_elem->transfer_syntaxes[0].if_version = NDR_SYNTAX_IF_VERSION; - p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[1]; + p_cont_elem = &bind_pdu.p_context_elem.p_cont_elem[1]; p_cont_elem->p_cont_id = 1; p_cont_elem->n_transfer_syn = 1; p_cont_elem->reserved = 0; @@ -236,32 +248,23 @@ int rpc_send_bind_pdu(rdpRpc* rpc) CopyMemory(&(p_cont_elem->transfer_syntaxes[0].if_uuid), &BTFN_UUID, sizeof(p_uuid_t)); p_cont_elem->transfer_syntaxes[0].if_version = BTFN_SYNTAX_IF_VERSION; offset = 116; - bind_pdu->auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4); - bind_pdu->auth_verifier.auth_type = RPC_C_AUTHN_WINNT; - bind_pdu->auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY; - bind_pdu->auth_verifier.auth_reserved = 0x00; - bind_pdu->auth_verifier.auth_context_id = 0x00000000; - offset += (8 + bind_pdu->header.auth_length); - bind_pdu->header.frag_length = offset; - buffer = (BYTE*)malloc(bind_pdu->header.frag_length); + + bind_pdu.auth_verifier.auth_type = RPC_C_AUTHN_WINNT; + bind_pdu.auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY; + bind_pdu.auth_verifier.auth_reserved = 0x00; + bind_pdu.auth_verifier.auth_context_id = 0x00000000; + offset += (8 + bind_pdu.header.auth_length); + bind_pdu.header.frag_length = offset; + + buffer = Stream_New(NULL, bind_pdu.header.frag_length); if (!buffer) goto fail; - CopyMemory(buffer, bind_pdu, 24); - CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4); - CopyMemory(&buffer[28], &bind_pdu->p_context_elem.p_cont_elem[0], 24); - CopyMemory(&buffer[52], bind_pdu->p_context_elem.p_cont_elem[0].transfer_syntaxes, 20); - CopyMemory(&buffer[72], &bind_pdu->p_context_elem.p_cont_elem[1], 24); - CopyMemory(&buffer[96], bind_pdu->p_context_elem.p_cont_elem[1].transfer_syntaxes, 20); - offset = 116; - rpc_offset_pad(&offset, bind_pdu->auth_verifier.auth_pad_length); - CopyMemory(&buffer[offset], &bind_pdu->auth_verifier.auth_type, 8); - CopyMemory(&buffer[offset + 8], bind_pdu->auth_verifier.auth_value, - bind_pdu->header.auth_length); - offset += (8 + bind_pdu->header.auth_length); - length = bind_pdu->header.frag_length; - clientCall = rpc_client_call_new(bind_pdu->header.call_id, 0); + if (!rts_write_pdu_bind(buffer, &bind_pdu)) + goto fail; + + clientCall = rpc_client_call_new(bind_pdu.header.call_id, 0); if (!clientCall) goto fail; @@ -272,22 +275,19 @@ int rpc_send_bind_pdu(rdpRpc* rpc) goto fail; } - status = rpc_in_channel_send_pdu(inChannel, buffer, length); + Stream_SealLength(buffer); + status = rpc_in_channel_send_pdu(inChannel, Stream_Buffer(buffer), Stream_Length(buffer)); fail: - if (bind_pdu) + if (bind_pdu.p_context_elem.p_cont_elem) { - if (bind_pdu->p_context_elem.p_cont_elem) - { - free(bind_pdu->p_context_elem.p_cont_elem[0].transfer_syntaxes); - free(bind_pdu->p_context_elem.p_cont_elem[1].transfer_syntaxes); - } - - free(bind_pdu->p_context_elem.p_cont_elem); + free(bind_pdu.p_context_elem.p_cont_elem[0].transfer_syntaxes); + free(bind_pdu.p_context_elem.p_cont_elem[1].transfer_syntaxes); } - free(bind_pdu); - free(buffer); + free(bind_pdu.p_context_elem.p_cont_elem); + + Stream_Free(buffer, TRUE); return (status > 0) ? 1 : -1; } @@ -317,31 +317,47 @@ fail: * example. */ -int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +BOOL rpc_recv_bind_ack_pdu(rdpRpc* rpc, wStream* s) { + BOOL rc = FALSE; BOOL continueNeeded = FALSE; - BYTE* auth_data; - rpcconn_hdr_t* header; - header = (rpcconn_hdr_t*)buffer; + const BYTE* auth_data; + size_t pos, end; + rpcconn_hdr_t header = { 0 }; + + assert(rpc); + assert(rpc->ntlm); + assert(s); + + pos = Stream_GetPosition(s); + if (!rts_read_pdu_header(s, &header)) + goto fail; + WLog_DBG(TAG, "Receiving BindAck PDU"); - if (!rpc || !rpc->ntlm) - return -1; + rpc->max_recv_frag = header.bind_ack.max_xmit_frag; + rpc->max_xmit_frag = header.bind_ack.max_recv_frag; - rpc->max_recv_frag = header->bind_ack.max_xmit_frag; - rpc->max_xmit_frag = header->bind_ack.max_recv_frag; - auth_data = buffer + (header->common.frag_length - header->common.auth_length); + /* Get the correct offset in the input data and pass that on as input buffer. + * rts_read_pdu_header did already do consistency checks */ + end = Stream_GetPosition(s); + Stream_SetPosition(s, pos + header.common.frag_length - header.common.auth_length); + auth_data = Stream_Pointer(s); + Stream_SetPosition(s, end); - if (!ntlm_client_set_input_buffer(rpc->ntlm, TRUE, auth_data, header->common.auth_length)) - return -1; + if (!ntlm_client_set_input_buffer(rpc->ntlm, TRUE, auth_data, header.common.auth_length)) + goto fail; if (!ntlm_authenticate(rpc->ntlm, &continueNeeded)) - return -1; + goto fail; if (continueNeeded) - return -1; + goto fail; - return (int)length; + rc = TRUE; +fail: + rts_free_pdu_header(&header, FALSE); + return rc; } /** @@ -354,68 +370,63 @@ int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc) { int status = -1; - BYTE* buffer; - UINT32 offset; - UINT32 length; + wStream* buffer; + size_t offset; const SecBuffer* sbuffer; RpcClientCall* clientCall; - rpcconn_rpc_auth_3_hdr_t* auth_3_pdu; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcInChannel* inChannel = connection->DefaultInChannel; - WLog_DBG(TAG, "Sending RpcAuth3 PDU"); - auth_3_pdu = (rpcconn_rpc_auth_3_hdr_t*)calloc(1, sizeof(rpcconn_rpc_auth_3_hdr_t)); + rpcconn_rpc_auth_3_hdr_t auth_3_pdu = { 0 }; + RpcVirtualConnection* connection; + RpcInChannel* inChannel; - if (!auth_3_pdu) - return -1; + assert(rpc); + + connection = rpc->VirtualConnection; + assert(connection); + + inChannel = connection->DefaultInChannel; + assert(inChannel); + + WLog_DBG(TAG, "Sending RpcAuth3 PDU"); sbuffer = ntlm_client_get_output_buffer(rpc->ntlm); if (!sbuffer) - { - free(auth_3_pdu); return -1; - } - rpc_pdu_header_init(rpc, &auth_3_pdu->header); - auth_3_pdu->header.auth_length = (UINT16)sbuffer->cbBuffer; - auth_3_pdu->auth_verifier.auth_value = sbuffer->pvBuffer; - auth_3_pdu->header.ptype = PTYPE_RPC_AUTH_3; - auth_3_pdu->header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_CONC_MPX; - auth_3_pdu->header.call_id = 2; - auth_3_pdu->max_xmit_frag = rpc->max_xmit_frag; - auth_3_pdu->max_recv_frag = rpc->max_recv_frag; + auth_3_pdu.header = rpc_pdu_header_init(rpc); + auth_3_pdu.header.auth_length = (UINT16)sbuffer->cbBuffer; + auth_3_pdu.auth_verifier.auth_value = sbuffer->pvBuffer; + auth_3_pdu.header.ptype = PTYPE_RPC_AUTH_3; + auth_3_pdu.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_CONC_MPX; + auth_3_pdu.header.call_id = 2; + auth_3_pdu.max_xmit_frag = rpc->max_xmit_frag; + auth_3_pdu.max_recv_frag = rpc->max_recv_frag; offset = 20; - auth_3_pdu->auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4); - auth_3_pdu->auth_verifier.auth_type = RPC_C_AUTHN_WINNT; - auth_3_pdu->auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY; - auth_3_pdu->auth_verifier.auth_reserved = 0x00; - auth_3_pdu->auth_verifier.auth_context_id = 0x00000000; - offset += (8 + auth_3_pdu->header.auth_length); - auth_3_pdu->header.frag_length = offset; - buffer = (BYTE*)malloc(auth_3_pdu->header.frag_length); + auth_3_pdu.auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4); + auth_3_pdu.auth_verifier.auth_type = RPC_C_AUTHN_WINNT; + auth_3_pdu.auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY; + auth_3_pdu.auth_verifier.auth_reserved = 0x00; + auth_3_pdu.auth_verifier.auth_context_id = 0x00000000; + offset += (8 + auth_3_pdu.header.auth_length); + auth_3_pdu.header.frag_length = offset; + + buffer = Stream_New(NULL, auth_3_pdu.header.frag_length); if (!buffer) - { - free(auth_3_pdu); return -1; - } - CopyMemory(buffer, auth_3_pdu, 20); - offset = 20; - rpc_offset_pad(&offset, auth_3_pdu->auth_verifier.auth_pad_length); - CopyMemory(&buffer[offset], &auth_3_pdu->auth_verifier.auth_type, 8); - CopyMemory(&buffer[offset + 8], auth_3_pdu->auth_verifier.auth_value, - auth_3_pdu->header.auth_length); - offset += (8 + auth_3_pdu->header.auth_length); - length = auth_3_pdu->header.frag_length; - clientCall = rpc_client_call_new(auth_3_pdu->header.call_id, 0); + if (!rts_write_pdu_auth3(buffer, &auth_3_pdu)) + goto fail; + + clientCall = rpc_client_call_new(auth_3_pdu.header.call_id, 0); if (ArrayList_Add(rpc->client->ClientCallList, clientCall) >= 0) { - status = rpc_in_channel_send_pdu(inChannel, buffer, length); + Stream_SealLength(buffer); + status = rpc_in_channel_send_pdu(inChannel, Stream_Buffer(buffer), Stream_Length(buffer)); } - free(auth_3_pdu); - free(buffer); +fail: + Stream_Free(buffer, TRUE); return (status > 0) ? 1 : -1; } diff --git a/libfreerdp/core/gateway/rpc_bind.h b/libfreerdp/core/gateway/rpc_bind.h index 759555f..69758e5 100644 --- a/libfreerdp/core/gateway/rpc_bind.h +++ b/libfreerdp/core/gateway/rpc_bind.h @@ -35,7 +35,7 @@ FREERDP_LOCAL extern const p_uuid_t BTFN_UUID; #define BTFN_SYNTAX_IF_VERSION 0x00000001 FREERDP_LOCAL int rpc_send_bind_pdu(rdpRpc* rpc); -FREERDP_LOCAL int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length); +FREERDP_LOCAL BOOL rpc_recv_bind_ack_pdu(rdpRpc* rpc, wStream* s); FREERDP_LOCAL int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc); #endif /* FREERDP_LIB_CORE_GATEWAY_RPC_BIND_H */ diff --git a/libfreerdp/core/gateway/rpc_fault.c b/libfreerdp/core/gateway/rpc_fault.c index 7259f04..c4cb086 100644 --- a/libfreerdp/core/gateway/rpc_fault.c +++ b/libfreerdp/core/gateway/rpc_fault.c @@ -133,10 +133,7 @@ static const RPC_FAULT_CODE RPC_FAULT_CODES[] = { CAT_GATEWAY) DEFINE_RPC_FAULT_CODE( RPC_S_INVALID_OBJECT, - CAT_GATEWAY){ - 0, - NULL, - NULL } + CAT_GATEWAY) }; static const RPC_FAULT_CODE RPC_TSG_FAULT_CODES[] = { @@ -222,9 +219,7 @@ static const RPC_FAULT_CODE RPC_TSG_FAULT_CODES[] = { DEFINE_RPC_FAULT_CODE( HRESULT_CODE( RPC_S_CALL_CANCELLED), - CAT_GATEWAY){ - 0, NULL, - NULL } + CAT_GATEWAY) }; /** @@ -377,22 +372,22 @@ const char* rpc_error_to_string(UINT32 code) size_t index; static char buffer[1024]; - for (index = 0; RPC_FAULT_CODES[index].name != NULL; index++) + for (index = 0; index < ARRAYSIZE(RPC_FAULT_CODES); index++) { - if (RPC_FAULT_CODES[index].code == code) + const RPC_FAULT_CODE* const current = &RPC_FAULT_CODES[index]; + if (current->code == code) { - sprintf_s(buffer, ARRAYSIZE(buffer), "%s [0x%08" PRIX32 "]", - RPC_FAULT_CODES[index].name, code); + sprintf_s(buffer, ARRAYSIZE(buffer), "%s", current->name); goto out; } } - for (index = 0; RPC_TSG_FAULT_CODES[index].name != NULL; index++) + for (index = 0; index < ARRAYSIZE(RPC_TSG_FAULT_CODES); index++) { - if (RPC_TSG_FAULT_CODES[index].code == code) + const RPC_FAULT_CODE* const current = &RPC_TSG_FAULT_CODES[index]; + if (current->code == code) { - sprintf_s(buffer, ARRAYSIZE(buffer), "%s [0x%08" PRIX32 "]", - RPC_TSG_FAULT_CODES[index].name, code); + sprintf_s(buffer, ARRAYSIZE(buffer), "%s", current->name); goto out; } } @@ -406,16 +401,18 @@ const char* rpc_error_to_category(UINT32 code) { size_t index; - for (index = 0; RPC_FAULT_CODES[index].category != NULL; index++) + for (index = 0; index < ARRAYSIZE(RPC_FAULT_CODES); index++) { - if (RPC_FAULT_CODES[index].code == code) - return RPC_FAULT_CODES[index].category; + const RPC_FAULT_CODE* const current = &RPC_FAULT_CODES[index]; + if (current->code == code) + return current->category; } - for (index = 0; RPC_TSG_FAULT_CODES[index].category != NULL; index++) + for (index = 0; index < ARRAYSIZE(RPC_TSG_FAULT_CODES); index++) { - if (RPC_TSG_FAULT_CODES[index].code == code) - return RPC_TSG_FAULT_CODES[index].category; + const RPC_FAULT_CODE* const current = &RPC_TSG_FAULT_CODES[index]; + if (current->code == code) + return current->category; } return "UNKNOWN"; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index d4432cb..50ec8e5 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -35,6 +36,8 @@ #include "rpc_bind.h" #include "rpc_fault.h" #include "rpc_client.h" +#include "rts_signature.h" + #include "../rdp.h" #include "../proxy.h" @@ -99,7 +102,7 @@ static int rpc_client_receive_pipe_write(RpcClient* client, const BYTE* buffer, int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length) { int index = 0; - int status = 0; + size_t status = 0; int nchunks = 0; DataChunk chunks[2]; @@ -122,7 +125,10 @@ int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length) ResetEvent(client->PipeEvent); LeaveCriticalSection(&(client->PipeLock)); - return status; + + if (status > INT_MAX) + return -1; + return (int)status; } static int rpc_client_transition_to_state(rdpRpc* rpc, RPC_CLIENT_STATE state) @@ -173,8 +179,15 @@ static int rpc_client_transition_to_state(rdpRpc* rpc, RPC_CLIENT_STATE state) static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) { int status = -1; - rpcconn_rts_hdr_t* rts; - rdpTsg* tsg = rpc->transport->tsg; + rdpTsg* tsg; + + assert(rpc); + assert(pdu); + + Stream_SealLength(pdu->s); + Stream_SetPosition(pdu->s, 0); + + tsg = rpc->transport->tsg; if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED) { @@ -187,17 +200,13 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) break; case VIRTUAL_CONNECTION_STATE_WAIT_A3W: - rts = (rpcconn_rts_hdr_t*)Stream_Buffer(pdu->s); - - if (!rts_match_pdu_signature(&RTS_PDU_CONN_A3_SIGNATURE, rts)) + if (!rts_match_pdu_signature(&RTS_PDU_CONN_A3_SIGNATURE, pdu->s, NULL)) { WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/A3"); return -1; } - status = rts_recv_CONN_A3_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); - - if (status < 0) + if (!rts_recv_CONN_A3_pdu(rpc, pdu->s)) { WLog_ERR(TAG, "rts_recv_CONN_A3_pdu failure"); return -1; @@ -209,17 +218,13 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) break; case VIRTUAL_CONNECTION_STATE_WAIT_C2: - rts = (rpcconn_rts_hdr_t*)Stream_Buffer(pdu->s); - - if (!rts_match_pdu_signature(&RTS_PDU_CONN_C2_SIGNATURE, rts)) + if (!rts_match_pdu_signature(&RTS_PDU_CONN_C2_SIGNATURE, pdu->s, NULL)) { WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/C2"); return -1; } - status = rts_recv_CONN_C2_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); - - if (status < 0) + if (!rts_recv_CONN_C2_pdu(rpc, pdu->s)) { WLog_ERR(TAG, "rts_recv_CONN_C2_pdu failure"); return -1; @@ -252,7 +257,7 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) { if (pdu->Type == PTYPE_BIND_ACK) { - if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0) + if (!rpc_recv_bind_ack_pdu(rpc, pdu->s)) { WLog_ERR(TAG, "rpc_recv_bind_ack_pdu failure"); return -1; @@ -301,89 +306,117 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment) { - BYTE* buffer; + int rc = -1; RPC_PDU* pdu; - UINT32 StubOffset; - UINT32 StubLength; + size_t StubOffset; + size_t StubLength; RpcClientCall* call; - rpcconn_hdr_t* header; + rpcconn_hdr_t header = { 0 }; + + assert(rpc); + assert(rpc->client); + assert(fragment); + pdu = rpc->client->pdu; - buffer = (BYTE*)Stream_Buffer(fragment); - header = (rpcconn_hdr_t*)Stream_Buffer(fragment); + assert(pdu); + + Stream_SealLength(fragment); + Stream_SetPosition(fragment, 0); + + if (!rts_read_pdu_header(fragment, &header)) + goto fail; - if (header->common.ptype == PTYPE_RESPONSE) + if (header.common.ptype == PTYPE_RESPONSE) { - rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length; + rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header.common.frag_length; rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -= - header->common.frag_length; + header.common.frag_length; if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < (rpc->ReceiveWindow / 2)) { - if (rts_send_flow_control_ack_pdu(rpc) < 0) - return -1; + if (!rts_send_flow_control_ack_pdu(rpc)) + goto fail; } - if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) + if (!rpc_get_stub_data_info(&header, &StubOffset, &StubLength)) { WLog_ERR(TAG, "expected stub"); - return -1; + goto fail; } if (StubLength == 4) { - if ((header->common.call_id == rpc->PipeCallId) && - (header->common.pfc_flags & PFC_LAST_FRAG)) + if ((header.common.call_id == rpc->PipeCallId) && + (header.common.pfc_flags & PFC_LAST_FRAG)) { /* End of TsProxySetupReceivePipe */ TerminateEventArgs e; - rpc->result = *((UINT32*)&buffer[StubOffset]); - freerdp_abort_connect(rpc->context->instance); - tsg_set_state(rpc->transport->tsg, TSG_STATE_TUNNEL_CLOSE_PENDING); + rdpContext* context = rpc->transport->context; + rdpTsg* tsg = rpc->transport->tsg; + + assert(context); + + if (Stream_Length(fragment) < StubOffset + 4) + goto fail; + Stream_SetPosition(fragment, StubOffset); + Stream_Read_UINT32(fragment, rpc->result); + + freerdp_abort_connect(context->instance); + tsg_set_state(tsg, TSG_STATE_TUNNEL_CLOSE_PENDING); EventArgsInit(&e, "freerdp"); e.code = 0; - PubSub_OnTerminate(rpc->context->pubSub, rpc->context, &e); - return 0; + PubSub_OnTerminate(context->pubSub, context, &e); + rc = 0; + goto success; } - if (header->common.call_id != rpc->PipeCallId) + if (header.common.call_id != rpc->PipeCallId) { /* Ignoring non-TsProxySetupReceivePipe Response */ - return 0; + rc = 0; + goto success; } } if (rpc->StubFragCount == 0) - rpc->StubCallId = header->common.call_id; + rpc->StubCallId = header.common.call_id; - if (rpc->StubCallId != header->common.call_id) + if (rpc->StubCallId != header.common.call_id) { WLog_ERR(TAG, "invalid call_id: actual: %" PRIu32 ", expected: %" PRIu32 ", frag_count: %" PRIu32 "", - rpc->StubCallId, header->common.call_id, rpc->StubFragCount); + rpc->StubCallId, header.common.call_id, rpc->StubFragCount); } call = rpc_client_call_find_by_id(rpc->client, rpc->StubCallId); if (!call) - return -1; + goto fail; if (call->OpNum != TsProxySetupReceivePipeOpnum) { - if (!Stream_EnsureCapacity(pdu->s, header->response.alloc_hint)) - return -1; + const rpcconn_response_hdr_t* response = + (const rpcconn_response_hdr_t*)&header.response; + if (!Stream_EnsureCapacity(pdu->s, response->alloc_hint)) + goto fail; + + if (Stream_Length(fragment) < StubOffset + StubLength) + goto fail; - Stream_Write(pdu->s, &buffer[StubOffset], StubLength); + Stream_SetPosition(fragment, StubOffset); + Stream_Write(pdu->s, Stream_Pointer(fragment), StubLength); rpc->StubFragCount++; - if (header->response.alloc_hint == StubLength) + if (response->alloc_hint == StubLength) { pdu->Flags = RPC_PDU_FLAG_STUB; pdu->Type = PTYPE_RESPONSE; pdu->CallId = rpc->StubCallId; - Stream_SealLength(pdu->s); - rpc_client_recv_pdu(rpc, pdu); + + if (rpc_client_recv_pdu(rpc, pdu) < 0) + goto fail; rpc_pdu_reset(pdu); rpc->StubFragCount = 0; rpc->StubCallId = 0; @@ -391,75 +424,84 @@ static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment) } else { - rpc_client_receive_pipe_write(rpc->client, &buffer[StubOffset], (size_t)StubLength); + const rpcconn_response_hdr_t* response = &header.response; + if (Stream_Length(fragment) < StubOffset + StubLength) + goto fail; + Stream_SetPosition(fragment, StubOffset); + rpc_client_receive_pipe_write(rpc->client, Stream_Pointer(fragment), + (size_t)StubLength); rpc->StubFragCount++; - if (header->response.alloc_hint == StubLength) + if (response->alloc_hint == StubLength) { rpc->StubFragCount = 0; rpc->StubCallId = 0; } } - return 1; + goto success; } - else if (header->common.ptype == PTYPE_RTS) + else if (header.common.ptype == PTYPE_RTS) { if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED) { pdu->Flags = 0; - pdu->Type = header->common.ptype; - pdu->CallId = header->common.call_id; + pdu->Type = header.common.ptype; + pdu->CallId = header.common.call_id; if (!Stream_EnsureCapacity(pdu->s, Stream_Length(fragment))) - return -1; + goto fail; - Stream_Write(pdu->s, buffer, Stream_Length(fragment)); - Stream_SealLength(pdu->s); + Stream_Write(pdu->s, Stream_Buffer(fragment), Stream_Length(fragment)); if (rpc_client_recv_pdu(rpc, pdu) < 0) - return -1; + goto fail; rpc_pdu_reset(pdu); } else { - if (rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length) < 0) - return -1; + if (!rts_recv_out_of_sequence_pdu(rpc, fragment, &header)) + goto fail; } - return 1; + goto success; } - else if (header->common.ptype == PTYPE_BIND_ACK) + else if (header.common.ptype == PTYPE_BIND_ACK) { pdu->Flags = 0; - pdu->Type = header->common.ptype; - pdu->CallId = header->common.call_id; + pdu->Type = header.common.ptype; + pdu->CallId = header.common.call_id; if (!Stream_EnsureCapacity(pdu->s, Stream_Length(fragment))) - return -1; + goto fail; - Stream_Write(pdu->s, buffer, Stream_Length(fragment)); - Stream_SealLength(pdu->s); + Stream_Write(pdu->s, Stream_Buffer(fragment), Stream_Length(fragment)); if (rpc_client_recv_pdu(rpc, pdu) < 0) - return -1; + goto fail; rpc_pdu_reset(pdu); - return 1; + goto success; } - else if (header->common.ptype == PTYPE_FAULT) + else if (header.common.ptype == PTYPE_FAULT) { - rpc_recv_fault_pdu(header->fault.status); - return -1; + const rpcconn_fault_hdr_t* fault = (const rpcconn_fault_hdr_t*)&header.fault; + rpc_recv_fault_pdu(fault->status); + goto fail; } else { - WLog_ERR(TAG, "unexpected RPC PDU type 0x%02" PRIX8 "", header->common.ptype); - return -1; + WLog_ERR(TAG, "unexpected RPC PDU type 0x%02" PRIX8 "", header.common.ptype); + goto fail; } - return 1; +success: + rc = (rc < 0) ? 1 : 0; /* In case of default error return change to 1, otherwise we already set + the return code */ +fail: + rts_free_pdu_header(&header, FALSE); + return rc; } static int rpc_client_default_out_channel_recv(rdpRpc* rpc) @@ -509,7 +551,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) /* Send CONN/A1 PDU over OUT channel */ - if (rts_send_CONN_A1_pdu(rpc) < 0) + if (!rts_send_CONN_A1_pdu(rpc)) { http_response_free(response); WLog_ERR(TAG, "rpc_send_CONN_A1_pdu error!"); @@ -549,7 +591,8 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) if (statusCode == HTTP_STATUS_DENIED) { - freerdp_set_last_error_if_not(rpc->context, FREERDP_ERROR_AUTHENTICATION_FAILED); + rdpContext* context = rpc->context; + freerdp_set_last_error_if_not(context, FREERDP_ERROR_AUTHENTICATION_FAILED); } http_response_free(response); @@ -563,12 +606,13 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) } else { - wStream* fragment; - rpcconn_common_hdr_t* header; - fragment = rpc->client->ReceiveFragment; + wStream* fragment = rpc->client->ReceiveFragment; while (1) { + size_t pos; + rpcconn_common_hdr_t header = { 0 }; + while (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH) { status = rpc_channel_read(&outChannel->common, fragment, @@ -581,22 +625,27 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) return 0; } - header = (rpcconn_common_hdr_t*)Stream_Buffer(fragment); + pos = Stream_GetPosition(fragment); + Stream_SetPosition(fragment, 0); + + /* Ignore errors, the PDU might not be complete. */ + rts_read_common_pdu_header(fragment, &header); + Stream_SetPosition(fragment, pos); - if (header->frag_length > rpc->max_recv_frag) + if (header.frag_length > rpc->max_recv_frag) { WLog_ERR(TAG, "rpc_client_recv: invalid fragment size: %" PRIu16 " (max: %" PRIu16 ")", - header->frag_length, rpc->max_recv_frag); + header.frag_length, rpc->max_recv_frag); winpr_HexDump(TAG, WLOG_ERROR, Stream_Buffer(fragment), Stream_GetPosition(fragment)); return -1; } - while (Stream_GetPosition(fragment) < header->frag_length) + while (Stream_GetPosition(fragment) < header.frag_length) { status = rpc_channel_read(&outChannel->common, fragment, - header->frag_length - Stream_GetPosition(fragment)); + header.frag_length - Stream_GetPosition(fragment)); if (status < 0) { @@ -604,14 +653,12 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) return -1; } - if (Stream_GetPosition(fragment) < header->frag_length) + if (Stream_GetPosition(fragment) < header.frag_length) return 0; } { /* complete fragment received */ - Stream_SealLength(fragment); - Stream_SetPosition(fragment, 0); status = rpc_client_recv_fragment(rpc, fragment); if (status < 0) @@ -663,10 +710,10 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc) if (rpc_ncacn_http_send_out_channel_request(&nextOutChannel->common, TRUE)) { rpc_ncacn_http_ntlm_uninit(&nextOutChannel->common); - status = rts_send_OUT_R1_A3_pdu(rpc); - if (status >= 0) + if (rts_send_OUT_R1_A3_pdu(rpc)) { + status = 1; rpc_out_channel_transition_to_state( nextOutChannel, CLIENT_OUT_CHANNEL_STATE_OPENED_A6W); } @@ -687,11 +734,14 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc) break; + case CLIENT_OUT_CHANNEL_STATE_INITIAL: + case CLIENT_OUT_CHANNEL_STATE_CONNECTED: + case CLIENT_OUT_CHANNEL_STATE_NEGOTIATED: default: WLog_ERR(TAG, "rpc_client_nondefault_out_channel_recv: Unexpected message %08" PRIx32, nextOutChannel->State); - return -1; + status = -1; } http_response_free(response); @@ -769,7 +819,7 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) /* Send CONN/B1 PDU over IN channel */ - if (rts_send_CONN_B1_pdu(rpc) < 0) + if (!rts_send_CONN_B1_pdu(rpc)) { WLog_ERR(TAG, "rpc_send_CONN_B1_pdu error!"); http_response_free(response); @@ -810,8 +860,8 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT32 CallId) { - int index; - int count; + size_t index; + size_t count; RpcClientCall* clientCall = NULL; if (!client) @@ -856,18 +906,23 @@ static void rpc_array_client_call_free(void* call) rpc_client_call_free((RpcClientCall*)call); } -int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length) +int rpc_in_channel_send_pdu(RpcInChannel* inChannel, const BYTE* buffer, size_t length) { - int status; + SSIZE_T status; RpcClientCall* clientCall; - rpcconn_common_hdr_t* header; + wStream s; + rpcconn_common_hdr_t header = { 0 }; + status = rpc_channel_write(&inChannel->common, buffer, length); if (status <= 0) return -1; - header = (rpcconn_common_hdr_t*)buffer; - clientCall = rpc_client_call_find_by_id(inChannel->common.client, header->call_id); + Stream_StaticInit(&s, buffer, length); + if (!rts_read_common_pdu_header(&s, &header)) + return -1; + + clientCall = rpc_client_call_find_by_id(inChannel->common.client, header.call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; /* @@ -877,7 +932,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length * variables specified by this abstract data model. */ - if (header->ptype == PTYPE_REQUEST) + if (header.ptype == PTYPE_REQUEST) { inChannel->BytesSent += status; inChannel->SenderAvailableWindow -= status; @@ -888,7 +943,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) { - UINT32 offset; + size_t offset; BYTE* buffer = NULL; UINT32 stub_data_pad; SecBuffer Buffers[2] = { 0 }; @@ -936,7 +991,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) if (size < 0) goto fail; - rpc_pdu_header_init(rpc, &request_pdu.header); + request_pdu.header = rpc_pdu_header_init(rpc); request_pdu.header.ptype = PTYPE_REQUEST; request_pdu.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG; request_pdu.header.auth_length = (UINT16)size; @@ -1025,7 +1080,7 @@ static BOOL rpc_client_resolve_gateway(rdpSettings* settings, char** host, UINT1 const char* peerHostname = settings->GatewayHostname; const char* proxyUsername = settings->ProxyUsername; const char* proxyPassword = settings->ProxyPassword; - *port = settings->GatewayPort; + *port = (UINT16)settings->GatewayPort; *isProxy = proxy_prepare(settings, &peerHostname, port, &proxyUsername, &proxyPassword); result = freerdp_tcp_resolve_host(peerHostname, *port, 0); diff --git a/libfreerdp/core/gateway/rpc_client.h b/libfreerdp/core/gateway/rpc_client.h index af0b8ce..7b509de 100644 --- a/libfreerdp/core/gateway/rpc_client.h +++ b/libfreerdp/core/gateway/rpc_client.h @@ -31,7 +31,8 @@ FREERDP_LOCAL RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT3 FREERDP_LOCAL RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum); FREERDP_LOCAL void rpc_client_call_free(RpcClientCall* client_call); -FREERDP_LOCAL int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length); +FREERDP_LOCAL int rpc_in_channel_send_pdu(RpcInChannel* inChannel, const BYTE* buffer, + size_t length); FREERDP_LOCAL int rpc_client_in_channel_recv(rdpRpc* rpc); FREERDP_LOCAL int rpc_client_out_channel_recv(rdpRpc* rpc); diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index c003b1e..ccbe0be 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -21,6 +21,7 @@ #include "config.h" #endif +#include #include #include #include @@ -29,6 +30,7 @@ #include "ncacn_http.h" #include "rpc_client.h" +#include "rts_signature.h" #include "rts.h" @@ -67,545 +69,1643 @@ * */ +static const char* rts_pdu_ptype_to_string(UINT32 ptype) +{ + switch (ptype) + { + case PTYPE_REQUEST: + return "PTYPE_REQUEST"; + case PTYPE_PING: + return "PTYPE_PING"; + case PTYPE_RESPONSE: + return "PTYPE_RESPONSE"; + case PTYPE_FAULT: + return "PTYPE_FAULT"; + case PTYPE_WORKING: + return "PTYPE_WORKING"; + case PTYPE_NOCALL: + return "PTYPE_NOCALL"; + case PTYPE_REJECT: + return "PTYPE_REJECT"; + case PTYPE_ACK: + return "PTYPE_ACK"; + case PTYPE_CL_CANCEL: + return "PTYPE_CL_CANCEL"; + case PTYPE_FACK: + return "PTYPE_FACK"; + case PTYPE_CANCEL_ACK: + return "PTYPE_CANCEL_ACK"; + case PTYPE_BIND: + return "PTYPE_BIND"; + case PTYPE_BIND_ACK: + return "PTYPE_BIND_ACK"; + case PTYPE_BIND_NAK: + return "PTYPE_BIND_NAK"; + case PTYPE_ALTER_CONTEXT: + return "PTYPE_ALTER_CONTEXT"; + case PTYPE_ALTER_CONTEXT_RESP: + return "PTYPE_ALTER_CONTEXT_RESP"; + case PTYPE_RPC_AUTH_3: + return "PTYPE_RPC_AUTH_3"; + case PTYPE_SHUTDOWN: + return "PTYPE_SHUTDOWN"; + case PTYPE_CO_CANCEL: + return "PTYPE_CO_CANCEL"; + case PTYPE_ORPHANED: + return "PTYPE_ORPHANED"; + case PTYPE_RTS: + return "PTYPE_RTS"; + default: + return "UNKNOWN"; + } +} static rpcconn_rts_hdr_t rts_pdu_header_init(void) { - rpcconn_rts_hdr_t header = { 0 }; - header.header.rpc_vers = 5; - header.header.rpc_vers_minor = 0; - header.header.ptype = PTYPE_RTS; - header.header.packed_drep[0] = 0x10; - header.header.packed_drep[1] = 0x00; - header.header.packed_drep[2] = 0x00; - header.header.packed_drep[3] = 0x00; - header.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG; - header.header.auth_length = 0; - header.header.call_id = 0; + rpcconn_rts_hdr_t header = { 0 }; + header.header.rpc_vers = 5; + header.header.rpc_vers_minor = 0; + header.header.ptype = PTYPE_RTS; + header.header.packed_drep[0] = 0x10; + header.header.packed_drep[1] = 0x00; + header.header.packed_drep[2] = 0x00; + header.header.packed_drep[3] = 0x00; + header.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG; + header.header.auth_length = 0; + header.header.call_id = 0; + + return header; +} + +static BOOL rts_align_stream(wStream* s, size_t alignment) +{ + size_t pos, pad; + + assert(s); + assert(alignment > 0); + + pos = Stream_GetPosition(s); + pad = rpc_offset_align(&pos, alignment); + return Stream_SafeSeek(s, pad); +} + +static char* sdup(const void* src, size_t length) +{ + char* dst; + assert(src || (length == 0)); + if (length == 0) + return NULL; + + dst = calloc(length + 1, sizeof(char)); + if (!dst) + return NULL; + memcpy(dst, src, length); + return dst; +} + +static BOOL rts_write_common_pdu_header(wStream* s, const rpcconn_common_hdr_t* header) +{ + assert(s); + assert(header); + if (!Stream_EnsureRemainingCapacity(s, sizeof(rpcconn_common_hdr_t))) + return FALSE; + + Stream_Write_UINT8(s, header->rpc_vers); + Stream_Write_UINT8(s, header->rpc_vers_minor); + Stream_Write_UINT8(s, header->ptype); + Stream_Write_UINT8(s, header->pfc_flags); + Stream_Write(s, header->packed_drep, ARRAYSIZE(header->packed_drep)); + Stream_Write_UINT16(s, header->frag_length); + Stream_Write_UINT16(s, header->auth_length); + Stream_Write_UINT32(s, header->call_id); + return TRUE; +} + +BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header) +{ + size_t left; + assert(s); + assert(header); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_common_hdr_t)) + return FALSE; + + Stream_Read_UINT8(s, header->rpc_vers); + Stream_Read_UINT8(s, header->rpc_vers_minor); + Stream_Read_UINT8(s, header->ptype); + Stream_Read_UINT8(s, header->pfc_flags); + Stream_Read(s, header->packed_drep, ARRAYSIZE(header->packed_drep)); + Stream_Read_UINT16(s, header->frag_length); + Stream_Read_UINT16(s, header->auth_length); + Stream_Read_UINT32(s, header->call_id); + + if (header->frag_length < sizeof(rpcconn_common_hdr_t)) + return FALSE; + + left = Stream_GetRemainingLength(s); + if (left < header->frag_length - sizeof(rpcconn_common_hdr_t)) + return FALSE; + + return TRUE; +} + +static BOOL rts_read_auth_verifier_no_checks(wStream* s, auth_verifier_co_t* auth, + const rpcconn_common_hdr_t* header, size_t* startPos) +{ + assert(s); + assert(auth); + assert(header); + + assert(header->frag_length > header->auth_length); + + if (startPos) + *startPos = Stream_GetPosition(s); + + /* Read the auth verifier and check padding matches frag_length */ + { + const size_t expected = header->frag_length - header->auth_length - 8; + + Stream_SetPosition(s, expected); + if (Stream_GetRemainingLength(s) < sizeof(auth_verifier_co_t)) + return FALSE; + + Stream_Read_UINT8(s, auth->auth_type); + Stream_Read_UINT8(s, auth->auth_level); + Stream_Read_UINT8(s, auth->auth_pad_length); + Stream_Read_UINT8(s, auth->auth_reserved); + Stream_Read_UINT32(s, auth->auth_context_id); + } + + if (header->auth_length != 0) + { + const void* ptr = Stream_Pointer(s); + if (!Stream_SafeSeek(s, header->auth_length)) + return FALSE; + auth->auth_value = (BYTE*)sdup(ptr, header->auth_length); + if (auth->auth_value == NULL) + return FALSE; + } + + return TRUE; +} + +static BOOL rts_read_auth_verifier(wStream* s, auth_verifier_co_t* auth, + const rpcconn_common_hdr_t* header) +{ + size_t pos; + assert(s); + assert(auth); + assert(header); + + if (!rts_read_auth_verifier_no_checks(s, auth, header, &pos)) + return FALSE; + + { + const size_t expected = header->frag_length - header->auth_length - 8; + assert(pos + auth->auth_pad_length == expected); + } + + return TRUE; +} + +static BOOL rts_read_auth_verifier_with_stub(wStream* s, auth_verifier_co_t* auth, + rpcconn_common_hdr_t* header) +{ + size_t pos; + size_t alloc_hint = 0; + BYTE** ptr = NULL; + + if (!rts_read_auth_verifier_no_checks(s, auth, header, &pos)) + return FALSE; + + switch (header->ptype) + { + case PTYPE_FAULT: + { + rpcconn_fault_hdr_t* hdr = (rpcconn_fault_hdr_t*)header; + alloc_hint = hdr->alloc_hint; + ptr = &hdr->stub_data; + } + break; + case PTYPE_RESPONSE: + { + rpcconn_response_hdr_t* hdr = (rpcconn_response_hdr_t*)header; + alloc_hint = hdr->alloc_hint; + ptr = &hdr->stub_data; + } + break; + case PTYPE_REQUEST: + { + rpcconn_request_hdr_t* hdr = (rpcconn_request_hdr_t*)header; + alloc_hint = hdr->alloc_hint; + ptr = &hdr->stub_data; + } + break; + default: + return FALSE; + } + + if (alloc_hint > 0) + { + const size_t size = + header->frag_length - header->auth_length - 8 - auth->auth_pad_length - pos; + const void* src = Stream_Buffer(s) + pos; + + *ptr = (BYTE*)sdup(src, size); + if (!*ptr) + return FALSE; + } + + return TRUE; +} + +static void rts_free_auth_verifier(auth_verifier_co_t* auth) +{ + if (!auth) + return; + free(auth->auth_value); +} + +static BOOL rts_write_auth_verifier(wStream* s, const auth_verifier_co_t* auth, + const rpcconn_common_hdr_t* header) +{ + size_t pos; + UINT8 auth_pad_length = 0; + + assert(s); + assert(auth); + assert(header); + + /* Align start to a multiple of 4 */ + pos = Stream_GetPosition(s); + if ((pos % 4) != 0) + { + auth_pad_length = 4 - (pos % 4); + if (!Stream_EnsureRemainingCapacity(s, auth_pad_length)) + return FALSE; + Stream_Zero(s, auth_pad_length); + } + + assert(header->frag_length + 8ull > header->auth_length); + { + size_t pos = Stream_GetPosition(s); + size_t expected = header->frag_length - header->auth_length - 8; + + assert(pos == expected); + } + + if (!Stream_EnsureRemainingCapacity(s, sizeof(auth_verifier_co_t))) + return FALSE; + + Stream_Write_UINT8(s, auth->auth_type); + Stream_Write_UINT8(s, auth->auth_level); + Stream_Write_UINT8(s, auth_pad_length); + Stream_Write_UINT8(s, 0); /* auth->auth_reserved */ + Stream_Write_UINT32(s, auth->auth_context_id); + + if (!Stream_EnsureRemainingCapacity(s, header->auth_length)) + return FALSE; + Stream_Write(s, auth->auth_value, header->auth_length); + return TRUE; +} + +static BOOL rts_read_version(wStream* s, p_rt_version_t* version) +{ + assert(s); + assert(version); + + if (Stream_GetRemainingLength(s) < 2 * sizeof(UINT8)) + return FALSE; + Stream_Read_UINT8(s, version->major); + Stream_Read_UINT8(s, version->minor); + return TRUE; +} + +void rts_free_supported_versions(p_rt_versions_supported_t* versions) +{ + if (!versions) + return; + free(versions->p_protocols); + versions->p_protocols = NULL; +} + +static BOOL rts_read_supported_versions(wStream* s, p_rt_versions_supported_t* versions) +{ + BYTE x; + + assert(s); + assert(versions); + + if (Stream_GetRemainingLength(s) < sizeof(UINT8)) + return FALSE; + + Stream_Read_UINT8(s, versions->n_protocols); /* count */ + + if (versions->n_protocols > 0) + { + versions->p_protocols = calloc(versions->n_protocols, sizeof(p_rt_version_t)); + if (!versions->p_protocols) + return FALSE; + } + for (x = 0; x < versions->n_protocols; x++) + { + p_rt_version_t* version = &versions->p_protocols[x]; + if (!rts_read_version(s, version)) /* size_is(n_protocols) */ + { + rts_free_supported_versions(versions); + return FALSE; + } + } + + return TRUE; +} + +static BOOL rts_read_port_any(wStream* s, port_any_t* port) +{ + const void* ptr; + + assert(s); + assert(port); + + if (Stream_GetRemainingLength(s) < sizeof(UINT16)) + return FALSE; + + Stream_Read_UINT16(s, port->length); + if (port->length == 0) + return TRUE; + + ptr = Stream_Pointer(s); + if (!Stream_SafeSeek(s, port->length)) + return FALSE; + port->port_spec = sdup(ptr, port->length); + return port->port_spec != NULL; +} + +static void rts_free_port_any(port_any_t* port) +{ + if (!port) + return; + free(port->port_spec); +} + +static BOOL rts_read_uuid(wStream* s, p_uuid_t* uuid) +{ + assert(s); + assert(uuid); + + if (Stream_GetRemainingLength(s) < sizeof(p_uuid_t)) + return FALSE; + + Stream_Read_UINT32(s, uuid->time_low); + Stream_Read_UINT16(s, uuid->time_mid); + Stream_Read_UINT16(s, uuid->time_hi_and_version); + Stream_Read_UINT8(s, uuid->clock_seq_hi_and_reserved); + Stream_Read_UINT8(s, uuid->clock_seq_low); + Stream_Read(s, uuid->node, ARRAYSIZE(uuid->node)); + return TRUE; +} + +static BOOL rts_write_uuid(wStream* s, const p_uuid_t* uuid) +{ + assert(s); + assert(uuid); + + if (!Stream_EnsureRemainingCapacity(s, sizeof(p_uuid_t))) + return FALSE; + + Stream_Write_UINT32(s, uuid->time_low); + Stream_Write_UINT16(s, uuid->time_mid); + Stream_Write_UINT16(s, uuid->time_hi_and_version); + Stream_Write_UINT8(s, uuid->clock_seq_hi_and_reserved); + Stream_Write_UINT8(s, uuid->clock_seq_low); + Stream_Write(s, uuid->node, ARRAYSIZE(uuid->node)); + return TRUE; +} + +static p_syntax_id_t* rts_syntax_id_new(size_t count) +{ + return calloc(count, sizeof(p_syntax_id_t)); +} + +static void rts_syntax_id_free(p_syntax_id_t* ptr) +{ + free(ptr); +} + +static BOOL rts_read_syntax_id(wStream* s, p_syntax_id_t* syntax_id) +{ + assert(s); + assert(syntax_id); + + if (!rts_read_uuid(s, &syntax_id->if_uuid)) + return FALSE; + + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + + Stream_Read_UINT32(s, syntax_id->if_version); + return TRUE; +} + +static BOOL rts_write_syntax_id(wStream* s, const p_syntax_id_t* syntax_id) +{ + assert(s); + assert(syntax_id); + + if (!rts_write_uuid(s, &syntax_id->if_uuid)) + return FALSE; + + if (!Stream_EnsureRemainingCapacity(s, 4)) + return FALSE; + + Stream_Write_UINT32(s, syntax_id->if_version); + return TRUE; +} + +p_cont_elem_t* rts_context_elem_new(size_t count) +{ + p_cont_elem_t* ctx = calloc(count, sizeof(p_cont_elem_t)); + return ctx; +} + +void rts_context_elem_free(p_cont_elem_t* ptr) +{ + if (!ptr) + return; + rts_syntax_id_free(ptr->transfer_syntaxes); + free(ptr); +} + +static BOOL rts_read_context_elem(wStream* s, p_cont_elem_t* element) +{ + BYTE x; + assert(s); + assert(element); + + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + + Stream_Read_UINT16(s, element->p_cont_id); + Stream_Read_UINT8(s, element->n_transfer_syn); /* number of items */ + Stream_Read_UINT8(s, element->reserved); /* alignment pad, m.b.z. */ + + if (!rts_read_syntax_id(s, &element->abstract_syntax)) /* transfer syntax list */ + return FALSE; + + if (element->n_transfer_syn > 0) + { + element->transfer_syntaxes = rts_syntax_id_new(element->n_transfer_syn); + if (!element->transfer_syntaxes) + return FALSE; + for (x = 0; x < element->n_transfer_syn; x++) + { + p_syntax_id_t* syn = &element->transfer_syntaxes[x]; + if (!rts_read_syntax_id(s, syn)) /* size_is(n_transfer_syn) */ + return FALSE; + } + } + + return TRUE; +} + +static BOOL rts_write_context_elem(wStream* s, const p_cont_elem_t* element) +{ + BYTE x; + assert(s); + assert(element); + + if (!Stream_EnsureRemainingCapacity(s, 4)) + return FALSE; + Stream_Write_UINT16(s, element->p_cont_id); + Stream_Write_UINT8(s, element->n_transfer_syn); /* number of items */ + Stream_Write_UINT8(s, element->reserved); /* alignment pad, m.b.z. */ + if (!rts_write_syntax_id(s, &element->abstract_syntax)) /* transfer syntax list */ + return FALSE; + + for (x = 0; x < element->n_transfer_syn; x++) + { + const p_syntax_id_t* syn = &element->transfer_syntaxes[x]; + if (!rts_write_syntax_id(s, syn)) /* size_is(n_transfer_syn) */ + return FALSE; + } + + return TRUE; +} + +static BOOL rts_read_context_list(wStream* s, p_cont_list_t* list) +{ + BYTE x; + + assert(s); + assert(list); + + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + Stream_Read_UINT8(s, list->n_context_elem); /* number of items */ + Stream_Read_UINT8(s, list->reserved); /* alignment pad, m.b.z. */ + Stream_Read_UINT16(s, list->reserved2); /* alignment pad, m.b.z. */ + + if (list->n_context_elem > 0) + { + list->p_cont_elem = rts_context_elem_new(list->n_context_elem); + if (!list->p_cont_elem) + return FALSE; + for (x = 0; x < list->n_context_elem; x++) + { + p_cont_elem_t* element = &list->p_cont_elem[x]; + if (!rts_read_context_elem(s, element)) + return FALSE; + } + } + return TRUE; +} + +static void rts_free_context_list(p_cont_list_t* list) +{ + if (!list) + return; + rts_context_elem_free(list->p_cont_elem); +} + +static BOOL rts_write_context_list(wStream* s, const p_cont_list_t* list) +{ + BYTE x; + + assert(s); + assert(list); + + if (!Stream_EnsureRemainingCapacity(s, 4)) + return FALSE; + Stream_Write_UINT8(s, list->n_context_elem); /* number of items */ + Stream_Write_UINT8(s, 0); /* alignment pad, m.b.z. */ + Stream_Write_UINT16(s, 0); /* alignment pad, m.b.z. */ + + for (x = 0; x < list->n_context_elem; x++) + { + const p_cont_elem_t* element = &list->p_cont_elem[x]; + if (!rts_write_context_elem(s, element)) + return FALSE; + } + return TRUE; +} + +static p_result_t* rts_result_new(size_t count) +{ + return calloc(count, sizeof(p_result_t)); +} + +static void rts_result_free(p_result_t* results) +{ + if (!results) + return; + free(results); +} + +static BOOL rts_read_result(wStream* s, p_result_t* result) +{ + assert(s); + assert(result); + + if (Stream_GetRemainingLength(s) < 2) + return FALSE; + Stream_Read_UINT16(s, result->result); + Stream_Read_UINT16(s, result->reason); + + return rts_read_syntax_id(s, &result->transfer_syntax); +} + +static void rts_free_result(p_result_t* result) +{ + if (!result) + return; +} + +static BOOL rts_read_result_list(wStream* s, p_result_list_t* list) +{ + BYTE x; + + assert(s); + assert(list); + + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + Stream_Read_UINT8(s, list->n_results); /* count */ + Stream_Read_UINT8(s, list->reserved); /* alignment pad, m.b.z. */ + Stream_Read_UINT16(s, list->reserved2); /* alignment pad, m.b.z. */ + + if (list->n_results > 0) + { + list->p_results = rts_result_new(list->n_results); + if (!list->p_results) + return FALSE; + + for (x = 0; x < list->n_results; x++) + { + p_result_t* result = &list->p_results[x]; /* size_is(n_results) */ + if (!rts_read_result(s, result)) + return FALSE; + } + } + + return TRUE; +} + +static void rts_free_result_list(p_result_list_t* list) +{ + BYTE x; + + if (!list) + return; + for (x = 0; x < list->n_results; x++) + { + p_result_t* result = &list->p_results[x]; + rts_free_result(result); + } + rts_result_free(list->p_results); +} + +static void rts_free_pdu_alter_context(rpcconn_alter_context_hdr_t* ctx) +{ + if (!ctx) + return; + + rts_free_context_list(&ctx->p_context_elem); + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_alter_context(wStream* s, rpcconn_alter_context_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_alter_context_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + + Stream_Read_UINT16(s, ctx->max_xmit_frag); + Stream_Read_UINT16(s, ctx->max_recv_frag); + Stream_Read_UINT32(s, ctx->assoc_group_id); + + if (!rts_read_context_list(s, &ctx->p_context_elem)) + return FALSE; + + if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header)) + return FALSE; + + return TRUE; +} + +static BOOL rts_read_pdu_alter_context_response(wStream* s, + rpcconn_alter_context_response_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_alter_context_response_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT16(s, ctx->max_xmit_frag); + Stream_Read_UINT16(s, ctx->max_recv_frag); + Stream_Read_UINT32(s, ctx->assoc_group_id); + + if (!rts_read_port_any(s, &ctx->sec_addr)) + return FALSE; + + if (!rts_align_stream(s, 4)) + return FALSE; + + if (!rts_read_result_list(s, &ctx->p_result_list)) + return FALSE; + + if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header)) + return FALSE; + + return TRUE; +} + +static void rts_free_pdu_alter_context_response(rpcconn_alter_context_response_hdr_t* ctx) +{ + if (!ctx) + return; + + rts_free_port_any(&ctx->sec_addr); + rts_free_result_list(&ctx->p_result_list); + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_bind(wStream* s, rpcconn_bind_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_bind_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT16(s, ctx->max_xmit_frag); + Stream_Read_UINT16(s, ctx->max_recv_frag); + Stream_Read_UINT32(s, ctx->assoc_group_id); + + if (!rts_read_context_list(s, &ctx->p_context_elem)) + return FALSE; + + if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header)) + return FALSE; + + return TRUE; +} + +static void rts_free_pdu_bind(rpcconn_bind_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_context_list(&ctx->p_context_elem); + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_bind_ack(wStream* s, rpcconn_bind_ack_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_bind_ack_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT16(s, ctx->max_xmit_frag); + Stream_Read_UINT16(s, ctx->max_recv_frag); + Stream_Read_UINT32(s, ctx->assoc_group_id); + + if (!rts_read_port_any(s, &ctx->sec_addr)) + return FALSE; + + if (!rts_align_stream(s, 4)) + return FALSE; + + if (!rts_read_result_list(s, &ctx->p_result_list)) + return FALSE; + + return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_bind_ack(rpcconn_bind_ack_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_port_any(&ctx->sec_addr); + rts_free_result_list(&ctx->p_result_list); + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_bind_nak(wStream* s, rpcconn_bind_nak_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_bind_nak_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT16(s, ctx->provider_reject_reason); + return rts_read_supported_versions(s, &ctx->versions); +} + +static void rts_free_pdu_bind_nak(rpcconn_bind_nak_hdr_t* ctx) +{ + if (!ctx) + return; + + rts_free_supported_versions(&ctx->versions); +} + +static BOOL rts_read_pdu_auth3(wStream* s, rpcconn_rpc_auth_3_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_rpc_auth_3_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT16(s, ctx->max_xmit_frag); + Stream_Read_UINT16(s, ctx->max_recv_frag); + + return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_auth3(rpcconn_rpc_auth_3_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_fault(wStream* s, rpcconn_fault_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_fault_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT32(s, ctx->alloc_hint); + Stream_Read_UINT16(s, ctx->p_cont_id); + Stream_Read_UINT8(s, ctx->cancel_count); + Stream_Read_UINT8(s, ctx->reserved); + Stream_Read_UINT32(s, ctx->status); + + return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_fault(rpcconn_fault_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_cancel_ack(wStream* s, rpcconn_cancel_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_cancel_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_cancel_ack(rpcconn_cancel_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_orphaned(wStream* s, rpcconn_orphaned_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_orphaned_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_orphaned(rpcconn_orphaned_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_request(wStream* s, rpcconn_request_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_request_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT32(s, ctx->alloc_hint); + Stream_Read_UINT16(s, ctx->p_cont_id); + Stream_Read_UINT16(s, ctx->opnum); + if (!rts_read_uuid(s, &ctx->object)) + return FALSE; + + return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_request(rpcconn_request_hdr_t* ctx) +{ + if (!ctx) + return; + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_response(wStream* s, rpcconn_response_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < + sizeof(rpcconn_response_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + Stream_Read_UINT32(s, ctx->alloc_hint); + Stream_Read_UINT16(s, ctx->p_cont_id); + Stream_Read_UINT8(s, ctx->cancel_count); + Stream_Read_UINT8(s, ctx->reserved); + + if (!rts_align_stream(s, 8)) + return FALSE; + + return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header); +} + +static void rts_free_pdu_response(rpcconn_response_hdr_t* ctx) +{ + if (!ctx) + return; + free(ctx->stub_data); + rts_free_auth_verifier(&ctx->auth_verifier); +} + +static BOOL rts_read_pdu_rts(wStream* s, rpcconn_rts_hdr_t* ctx) +{ + assert(s); + assert(ctx); + + if (Stream_GetRemainingLength(s) < sizeof(rpcconn_rts_hdr_t) - sizeof(rpcconn_common_hdr_t)) + return FALSE; + + Stream_Read_UINT16(s, ctx->Flags); + Stream_Read_UINT16(s, ctx->NumberOfCommands); + return TRUE; +} + +static void rts_free_pdu_rts(rpcconn_rts_hdr_t* ctx) +{ + WINPR_UNUSED(ctx); +} + +void rts_free_pdu_header(rpcconn_hdr_t* header, BOOL allocated) +{ + if (!header) + return; + + switch (header->common.ptype) + { + case PTYPE_ALTER_CONTEXT: + rts_free_pdu_alter_context(&header->alter_context); + break; + case PTYPE_ALTER_CONTEXT_RESP: + rts_free_pdu_alter_context_response(&header->alter_context_response); + break; + case PTYPE_BIND: + rts_free_pdu_bind(&header->bind); + break; + case PTYPE_BIND_ACK: + rts_free_pdu_bind_ack(&header->bind_ack); + break; + case PTYPE_BIND_NAK: + rts_free_pdu_bind_nak(&header->bind_nak); + break; + case PTYPE_RPC_AUTH_3: + rts_free_pdu_auth3(&header->rpc_auth_3); + break; + case PTYPE_CANCEL_ACK: + rts_free_pdu_cancel_ack(&header->cancel); + break; + case PTYPE_FAULT: + rts_free_pdu_fault(&header->fault); + break; + case PTYPE_ORPHANED: + rts_free_pdu_orphaned(&header->orphaned); + break; + case PTYPE_REQUEST: + rts_free_pdu_request(&header->request); + break; + case PTYPE_RESPONSE: + rts_free_pdu_response(&header->response); + break; + case PTYPE_RTS: + rts_free_pdu_rts(&header->rts); + break; + /* No extra fields */ + case PTYPE_SHUTDOWN: + break; + + /* not handled */ + case PTYPE_PING: + case PTYPE_WORKING: + case PTYPE_NOCALL: + case PTYPE_REJECT: + case PTYPE_ACK: + case PTYPE_CL_CANCEL: + case PTYPE_FACK: + case PTYPE_CO_CANCEL: + default: + break; + } + + if (allocated) + free(header); +} + +BOOL rts_read_pdu_header(wStream* s, rpcconn_hdr_t* header) +{ + BOOL rc = FALSE; + assert(s); + assert(header); + + if (!rts_read_common_pdu_header(s, &header->common)) + return FALSE; + + WLog_DBG(TAG, "Reading PDU type %s", rts_pdu_ptype_to_string(header->common.ptype)); + fflush(stdout); + switch (header->common.ptype) + { + case PTYPE_ALTER_CONTEXT: + rc = rts_read_pdu_alter_context(s, &header->alter_context); + break; + case PTYPE_ALTER_CONTEXT_RESP: + rc = rts_read_pdu_alter_context_response(s, &header->alter_context_response); + break; + case PTYPE_BIND: + rc = rts_read_pdu_bind(s, &header->bind); + break; + case PTYPE_BIND_ACK: + rc = rts_read_pdu_bind_ack(s, &header->bind_ack); + break; + case PTYPE_BIND_NAK: + rc = rts_read_pdu_bind_nak(s, &header->bind_nak); + break; + case PTYPE_RPC_AUTH_3: + rc = rts_read_pdu_auth3(s, &header->rpc_auth_3); + break; + case PTYPE_CANCEL_ACK: + rc = rts_read_pdu_cancel_ack(s, &header->cancel); + break; + case PTYPE_FAULT: + rc = rts_read_pdu_fault(s, &header->fault); + break; + case PTYPE_ORPHANED: + rc = rts_read_pdu_orphaned(s, &header->orphaned); + break; + case PTYPE_REQUEST: + rc = rts_read_pdu_request(s, &header->request); + break; + case PTYPE_RESPONSE: + rc = rts_read_pdu_response(s, &header->response); + break; + case PTYPE_RTS: + rc = rts_read_pdu_rts(s, &header->rts); + break; + case PTYPE_SHUTDOWN: + rc = TRUE; /* No extra fields */ + break; + + /* not handled */ + case PTYPE_PING: + case PTYPE_WORKING: + case PTYPE_NOCALL: + case PTYPE_REJECT: + case PTYPE_ACK: + case PTYPE_CL_CANCEL: + case PTYPE_FACK: + case PTYPE_CO_CANCEL: + default: + break; + } + + return rc; +} + +static BOOL rts_write_pdu_header(wStream* s, const rpcconn_rts_hdr_t* header) +{ + assert(s); + assert(header); + if (!Stream_EnsureRemainingCapacity(s, sizeof(rpcconn_rts_hdr_t))) + return FALSE; - return header; + if (!rts_write_common_pdu_header(s, &header->header)) + return FALSE; + + Stream_Write_UINT16(s, header->Flags); + Stream_Write_UINT16(s, header->NumberOfCommands); + return TRUE; } -static int rts_receive_window_size_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length, +static int rts_receive_window_size_command_read(rdpRpc* rpc, wStream* buffer, UINT32* ReceiveWindowSize) { + UINT32 val; + + assert(rpc); + assert(buffer); + + if (Stream_GetRemainingLength(buffer) < 4) + return -1; + Stream_Read_UINT32(buffer, val); if (ReceiveWindowSize) - *ReceiveWindowSize = *((UINT32*)&buffer[0]); /* ReceiveWindowSize (4 bytes) */ + *ReceiveWindowSize = val; /* ReceiveWindowSize (4 bytes) */ return 4; } -static int rts_receive_window_size_command_write(BYTE* buffer, UINT32 ReceiveWindowSize) +static BOOL rts_receive_window_size_command_write(wStream* s, UINT32 ReceiveWindowSize) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_RECEIVE_WINDOW_SIZE; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = ReceiveWindowSize; /* ReceiveWindowSize (4 bytes) */ - } + assert(s); + + if (!Stream_EnsureRemainingCapacity(s, 2 * sizeof(UINT32))) + return FALSE; - return 8; + Stream_Write_UINT32(s, RTS_CMD_RECEIVE_WINDOW_SIZE); /* CommandType (4 bytes) */ + Stream_Write_UINT32(s, ReceiveWindowSize); /* ReceiveWindowSize (4 bytes) */ + + return TRUE; } -static int rts_flow_control_ack_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length, - UINT32* BytesReceived, UINT32* AvailableWindow, - BYTE* ChannelCookie) +static int rts_flow_control_ack_command_read(rdpRpc* rpc, wStream* buffer, UINT32* BytesReceived, + UINT32* AvailableWindow, BYTE* ChannelCookie) { + UINT32 val; + assert(rpc); + assert(buffer); + /* Ack (24 bytes) */ + if (Stream_GetRemainingLength(buffer) < 24) + return -1; + + Stream_Read_UINT32(buffer, val); if (BytesReceived) - *BytesReceived = *((UINT32*)&buffer[0]); /* BytesReceived (4 bytes) */ + *BytesReceived = val; /* BytesReceived (4 bytes) */ + Stream_Read_UINT32(buffer, val); if (AvailableWindow) - *AvailableWindow = *((UINT32*)&buffer[4]); /* AvailableWindow (4 bytes) */ + *AvailableWindow = val; /* AvailableWindow (4 bytes) */ if (ChannelCookie) - CopyMemory(ChannelCookie, &buffer[8], 16); /* ChannelCookie (16 bytes) */ - + Stream_Read(buffer, ChannelCookie, 16); /* ChannelCookie (16 bytes) */ + else + Stream_Seek(buffer, 16); return 24; } -static int rts_flow_control_ack_command_write(BYTE* buffer, UINT32 BytesReceived, - UINT32 AvailableWindow, BYTE* ChannelCookie) +static BOOL rts_flow_control_ack_command_write(wStream* s, UINT32 BytesReceived, + UINT32 AvailableWindow, BYTE* ChannelCookie) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_FLOW_CONTROL_ACK; /* CommandType (4 bytes) */ - /* Ack (24 bytes) */ - *((UINT32*)&buffer[4]) = BytesReceived; /* BytesReceived (4 bytes) */ - *((UINT32*)&buffer[8]) = AvailableWindow; /* AvailableWindow (4 bytes) */ - CopyMemory(&buffer[12], ChannelCookie, 16); /* ChannelCookie (16 bytes) */ - } + assert(s); - return 28; -} + if (!Stream_EnsureRemainingCapacity(s, 28)) + return FALSE; -static int rts_connection_timeout_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length, - UINT32* ConnectionTimeout) -{ - if (ConnectionTimeout) - *ConnectionTimeout = *((UINT32*)&buffer[0]); /* ConnectionTimeout (4 bytes) */ + Stream_Write_UINT32(s, RTS_CMD_FLOW_CONTROL_ACK); /* CommandType (4 bytes) */ + Stream_Write_UINT32(s, BytesReceived); /* BytesReceived (4 bytes) */ + Stream_Write_UINT32(s, AvailableWindow); /* AvailableWindow (4 bytes) */ + Stream_Write(s, ChannelCookie, 16); /* ChannelCookie (16 bytes) */ - return 4; + return TRUE; } -static int rts_connection_timeout_command_write(BYTE* buffer, UINT32 ConnectionTimeout) +static BOOL rts_connection_timeout_command_read(rdpRpc* rpc, wStream* buffer, + UINT32* ConnectionTimeout) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_CONNECTION_TIMEOUT; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = ConnectionTimeout; /* ConnectionTimeout (4 bytes) */ - } + UINT32 val; + assert(rpc); + assert(buffer); - return 8; -} + if (Stream_GetRemainingLength(buffer) < 4) + return FALSE; -static int rts_cookie_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* Cookie (16 bytes) */ - return 16; + Stream_Read_UINT32(buffer, val); + if (ConnectionTimeout) + *ConnectionTimeout = val; /* ConnectionTimeout (4 bytes) */ + + return TRUE; } -static int rts_cookie_command_write(BYTE* buffer, BYTE* Cookie) +static BOOL rts_cookie_command_write(wStream* s, const BYTE* Cookie) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_COOKIE; /* CommandType (4 bytes) */ - CopyMemory(&buffer[4], Cookie, 16); /* Cookie (16 bytes) */ - } + assert(s); - return 20; -} + if (!Stream_EnsureRemainingCapacity(s, 20)) + return FALSE; -static int rts_channel_lifetime_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* ChannelLifetime (4 bytes) */ - return 4; + Stream_Write_UINT32(s, RTS_CMD_COOKIE); /* CommandType (4 bytes) */ + Stream_Write(s, Cookie, 16); /* Cookie (16 bytes) */ + + return TRUE; } -static int rts_channel_lifetime_command_write(BYTE* buffer, UINT32 ChannelLifetime) +static BOOL rts_channel_lifetime_command_write(wStream* s, UINT32 ChannelLifetime) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_CHANNEL_LIFETIME; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = ChannelLifetime; /* ChannelLifetime (4 bytes) */ - } + assert(s); - return 8; -} + if (!Stream_EnsureRemainingCapacity(s, 8)) + return FALSE; + Stream_Write_UINT32(s, RTS_CMD_CHANNEL_LIFETIME); /* CommandType (4 bytes) */ + Stream_Write_UINT32(s, ChannelLifetime); /* ChannelLifetime (4 bytes) */ -static int rts_client_keepalive_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* ClientKeepalive (4 bytes) */ - return 4; + return TRUE; } -static int rts_client_keepalive_command_write(BYTE* buffer, UINT32 ClientKeepalive) +static BOOL rts_client_keepalive_command_write(wStream* s, UINT32 ClientKeepalive) { + assert(s); + + if (!Stream_EnsureRemainingCapacity(s, 8)) + return FALSE; /** * An unsigned integer that specifies the keep-alive interval, in milliseconds, * that this connection is configured to use. This value MUST be 0 or in the inclusive * range of 60,000 through 4,294,967,295. If it is 0, it MUST be interpreted as 300,000. */ - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_CLIENT_KEEPALIVE; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = ClientKeepalive; /* ClientKeepalive (4 bytes) */ - } - return 8; -} + Stream_Write_UINT32(s, RTS_CMD_CLIENT_KEEPALIVE); /* CommandType (4 bytes) */ + Stream_Write_UINT32(s, ClientKeepalive); /* ClientKeepalive (4 bytes) */ -static int rts_version_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* Version (4 bytes) */ - return 4; + return TRUE; } -static int rts_version_command_write(BYTE* buffer) +static BOOL rts_version_command_read(rdpRpc* rpc, wStream* buffer) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_VERSION; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = 1; /* Version (4 bytes) */ - } + assert(rpc); + assert(buffer); - return 8; -} + if (!Stream_SafeSeek(buffer, 4)) + return FALSE; -static int rts_empty_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - return 0; + /* Version (4 bytes) */ + return TRUE; } -static int rts_empty_command_write(BYTE* buffer) +static BOOL rts_version_command_write(wStream* buffer) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_EMPTY; /* CommandType (4 bytes) */ - } - - return 4; -} + assert(buffer); -static SSIZE_T rts_padding_command_read(const BYTE* buffer, size_t length) -{ - UINT32 ConformanceCount; - ConformanceCount = *((UINT32*)&buffer[0]); /* ConformanceCount (4 bytes) */ - /* Padding (variable) */ - return ConformanceCount + 4; -} + if (Stream_GetRemainingCapacity(buffer) < 8) + return FALSE; -static int rts_padding_command_write(BYTE* buffer, UINT32 ConformanceCount) -{ - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_PADDING; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = ConformanceCount; /* ConformanceCount (4 bytes) */ - ZeroMemory(&buffer[8], ConformanceCount); /* Padding (variable) */ - } + Stream_Write_UINT32(buffer, RTS_CMD_VERSION); /* CommandType (4 bytes) */ + Stream_Write_UINT32(buffer, 1); /* Version (4 bytes) */ - return 8 + ConformanceCount; + return TRUE; } -static int rts_negative_ance_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static BOOL rts_empty_command_write(wStream* s) { - return 0; -} + assert(s); -static int rts_negative_ance_command_write(BYTE* buffer) -{ - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_NEGATIVE_ANCE; /* CommandType (4 bytes) */ - } + if (!Stream_EnsureRemainingCapacity(s, 8)) + return FALSE; - return 4; -} + Stream_Write_UINT32(s, RTS_CMD_EMPTY); /* CommandType (4 bytes) */ -static int rts_ance_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - return 0; + return TRUE; } -static int rts_ance_command_write(BYTE* buffer) +static BOOL rts_padding_command_read(wStream* s, size_t* length) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_ANCE; /* CommandType (4 bytes) */ - } - - return 4; + UINT32 ConformanceCount; + assert(s); + assert(length); + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + Stream_Read_UINT32(s, ConformanceCount); /* ConformanceCount (4 bytes) */ + *length = ConformanceCount + 4; + return TRUE; } -static SSIZE_T rts_client_address_command_read(const BYTE* buffer, size_t length) +static BOOL rts_client_address_command_read(wStream* s, size_t* length) { UINT32 AddressType; - AddressType = *((UINT32*)&buffer[0]); /* AddressType (4 bytes) */ + + assert(s); + assert(length); + + if (Stream_GetRemainingLength(s) < 4) + return FALSE; + Stream_Read_UINT32(s, AddressType); /* AddressType (4 bytes) */ if (AddressType == 0) { /* ClientAddress (4 bytes) */ /* padding (12 bytes) */ - return 4 + 4 + 12; + *length = 4 + 4 + 12; } else { /* ClientAddress (16 bytes) */ /* padding (12 bytes) */ - return 4 + 16 + 12; + *length = 4 + 16 + 12; } + return TRUE; } -static int rts_client_address_command_write(BYTE* buffer, UINT32 AddressType, BYTE* ClientAddress) +static BOOL rts_association_group_id_command_write(wStream* s, const BYTE* AssociationGroupId) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_CLIENT_ADDRESS; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = AddressType; /* AddressType (4 bytes) */ - } + assert(s); - if (AddressType == 0) - { - if (buffer) - { - CopyMemory(&buffer[8], ClientAddress, 4); /* ClientAddress (4 bytes) */ - ZeroMemory(&buffer[12], 12); /* padding (12 bytes) */ - } - - return 24; - } - else - { - if (buffer) - { - CopyMemory(&buffer[8], ClientAddress, 16); /* ClientAddress (16 bytes) */ - ZeroMemory(&buffer[24], 12); /* padding (12 bytes) */ - } + if (!Stream_EnsureRemainingCapacity(s, 20)) + return FALSE; - return 36; - } -} + Stream_Write_UINT32(s, RTS_CMD_ASSOCIATION_GROUP_ID); /* CommandType (4 bytes) */ + Stream_Write(s, AssociationGroupId, 16); /* AssociationGroupId (16 bytes) */ -static int rts_association_group_id_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* AssociationGroupId (16 bytes) */ - return 16; + return TRUE; } -static int rts_association_group_id_command_write(BYTE* buffer, BYTE* AssociationGroupId) +static int rts_destination_command_read(rdpRpc* rpc, wStream* buffer, UINT32* Destination) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_ASSOCIATION_GROUP_ID; /* CommandType (4 bytes) */ - CopyMemory(&buffer[4], AssociationGroupId, 16); /* AssociationGroupId (16 bytes) */ - } - - return 20; -} + UINT32 val; + assert(rpc); + assert(buffer); -static int rts_destination_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length, - UINT32* Destination) -{ + if (Stream_GetRemainingLength(buffer) < 4) + return -1; + Stream_Read_UINT32(buffer, val); if (Destination) - *Destination = *((UINT32*)&buffer[0]); /* Destination (4 bytes) */ + *Destination = val; /* Destination (4 bytes) */ return 4; } -static int rts_destination_command_write(BYTE* buffer, UINT32 Destination) +static BOOL rts_destination_command_write(wStream* s, UINT32 Destination) { - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_DESTINATION; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = Destination; /* Destination (4 bytes) */ - } - - return 8; -} + assert(s); -static int rts_ping_traffic_sent_notify_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) -{ - /* PingTrafficSent (4 bytes) */ - return 4; -} + if (!Stream_EnsureRemainingCapacity(s, 8)) + return FALSE; -static int rts_ping_traffic_sent_notify_command_write(BYTE* buffer, UINT32 PingTrafficSent) -{ - if (buffer) - { - *((UINT32*)&buffer[0]) = RTS_CMD_PING_TRAFFIC_SENT_NOTIFY; /* CommandType (4 bytes) */ - *((UINT32*)&buffer[4]) = PingTrafficSent; /* PingTrafficSent (4 bytes) */ - } + Stream_Write_UINT32(s, RTS_CMD_DESTINATION); /* CommandType (4 bytes) */ + Stream_Write_UINT32(s, Destination); /* Destination (4 bytes) */ - return 8; + return TRUE; } void rts_generate_cookie(BYTE* cookie) { + assert(cookie); winpr_RAND(cookie, 16); } +static BOOL rts_send_buffer(RpcChannel* channel, wStream* s, size_t frag_length) +{ + BOOL status = FALSE; + SSIZE_T rc; + + assert(channel); + assert(s); + + Stream_SealLength(s); + if (Stream_Length(s) < sizeof(rpcconn_common_hdr_t)) + goto fail; + if (Stream_Length(s) != frag_length) + goto fail; + + rc = rpc_channel_write(channel, Stream_Buffer(s), Stream_Length(s)); + if (rc < 0) + goto fail; + if ((size_t)rc != Stream_Length(s)) + goto fail; + status = TRUE; +fail: + return status; +} + /* CONN/A Sequence */ -int rts_send_CONN_A1_pdu(rdpRpc* rpc) +BOOL rts_send_CONN_A1_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); UINT32 ReceiveWindowSize; BYTE* OUTChannelCookie; BYTE* VirtualConnectionCookie; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcOutChannel* outChannel = connection->DefaultOutChannel; + RpcVirtualConnection* connection; + RpcOutChannel* outChannel; + + assert(rpc); + + connection = rpc->VirtualConnection; + assert(connection); + + outChannel = connection->DefaultOutChannel; + assert(outChannel); header.header.frag_length = 76; header.Flags = RTS_FLAG_NONE; header.NumberOfCommands = 4; + WLog_DBG(TAG, "Sending CONN/A1 RTS PDU"); VirtualConnectionCookie = (BYTE*)&(connection->Cookie); OUTChannelCookie = (BYTE*)&(outChannel->common.Cookie); ReceiveWindowSize = outChannel->ReceiveWindow; - buffer = (BYTE*)malloc(header.header.frag_length); + + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) return -1; - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ - rts_cookie_command_write(&buffer[28], - VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */ - rts_cookie_command_write(&buffer[48], OUTChannelCookie); /* OUTChannelCookie (20 bytes) */ - rts_receive_window_size_command_write(&buffer[68], - ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */ - status = rpc_channel_write(&outChannel->common, buffer, header.header.frag_length); - free(buffer); - return (status > 0) ? 1 : -1; + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + status = rts_version_command_write(buffer); /* Version (8 bytes) */ + if (!status) + goto fail; + status = rts_cookie_command_write( + buffer, VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */ + if (!status) + goto fail; + status = rts_cookie_command_write(buffer, OUTChannelCookie); /* OUTChannelCookie (20 bytes) */ + if (!status) + goto fail; + status = rts_receive_window_size_command_write( + buffer, ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */ + if (!status) + goto fail; + status = rts_send_buffer(&outChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return status; } -int rts_recv_CONN_A3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +BOOL rts_recv_CONN_A3_pdu(rdpRpc* rpc, wStream* buffer) { + BOOL rc; UINT32 ConnectionTimeout; - rts_connection_timeout_command_read(rpc, &buffer[24], length - 24, &ConnectionTimeout); + + if (!Stream_SafeSeek(buffer, 24)) + return FALSE; + + rc = rts_connection_timeout_command_read(rpc, buffer, &ConnectionTimeout); + if (!rc) + return rc; + WLog_DBG(TAG, "Receiving CONN/A3 RTS PDU: ConnectionTimeout: %" PRIu32 "", ConnectionTimeout); + + assert(rpc); + assert(rpc->VirtualConnection); + assert(rpc->VirtualConnection->DefaultInChannel); + rpc->VirtualConnection->DefaultInChannel->PingOriginator.ConnectionTimeout = ConnectionTimeout; - return 1; + return TRUE; } /* CONN/B Sequence */ -int rts_send_CONN_B1_pdu(rdpRpc* rpc) +BOOL rts_send_CONN_B1_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; - UINT32 length; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); BYTE* INChannelCookie; BYTE* AssociationGroupId; BYTE* VirtualConnectionCookie; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcInChannel* inChannel = connection->DefaultInChannel; + RpcVirtualConnection* connection; + RpcInChannel* inChannel; + + assert(rpc); + + connection = rpc->VirtualConnection; + assert(connection); + + inChannel = connection->DefaultInChannel; + assert(inChannel); header.header.frag_length = 104; header.Flags = RTS_FLAG_NONE; header.NumberOfCommands = 6; + WLog_DBG(TAG, "Sending CONN/B1 RTS PDU"); + VirtualConnectionCookie = (BYTE*)&(connection->Cookie); INChannelCookie = (BYTE*)&(inChannel->common.Cookie); AssociationGroupId = (BYTE*)&(connection->AssociationGroupId); - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) - return -1; - - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ - rts_cookie_command_write(&buffer[28], - VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */ - rts_cookie_command_write(&buffer[48], INChannelCookie); /* INChannelCookie (20 bytes) */ - rts_channel_lifetime_command_write(&buffer[68], - rpc->ChannelLifetime); /* ChannelLifetime (8 bytes) */ - rts_client_keepalive_command_write(&buffer[76], - rpc->KeepAliveInterval); /* ClientKeepalive (8 bytes) */ - rts_association_group_id_command_write(&buffer[84], - AssociationGroupId); /* AssociationGroupId (20 bytes) */ - length = header.header.frag_length; - status = rpc_channel_write(&inChannel->common, buffer, length); - free(buffer); - return (status > 0) ? 1 : -1; + goto fail; + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + if (!rts_version_command_write(buffer)) /* Version (8 bytes) */ + goto fail; + if (!rts_cookie_command_write(buffer, + VirtualConnectionCookie)) /* VirtualConnectionCookie (20 bytes) */ + goto fail; + if (!rts_cookie_command_write(buffer, INChannelCookie)) /* INChannelCookie (20 bytes) */ + goto fail; + if (!rts_channel_lifetime_command_write(buffer, + rpc->ChannelLifetime)) /* ChannelLifetime (8 bytes) */ + goto fail; + if (!rts_client_keepalive_command_write(buffer, + rpc->KeepAliveInterval)) /* ClientKeepalive (8 bytes) */ + goto fail; + if (!rts_association_group_id_command_write( + buffer, AssociationGroupId)) /* AssociationGroupId (20 bytes) */ + goto fail; + status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return status; } /* CONN/C Sequence */ -int rts_recv_CONN_C2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +BOOL rts_recv_CONN_C2_pdu(rdpRpc* rpc, wStream* buffer) { - UINT32 offset; + BOOL rc; UINT32 ReceiveWindowSize; UINT32 ConnectionTimeout; - offset = 24; - offset += rts_version_command_read(rpc, &buffer[offset], length - offset) + 4; - offset += rts_receive_window_size_command_read(rpc, &buffer[offset], length - offset, - &ReceiveWindowSize) + - 4; - offset += rts_connection_timeout_command_read(rpc, &buffer[offset], length - offset, - &ConnectionTimeout) + - 4; + + assert(rpc); + assert(buffer); + + if (!Stream_SafeSeek(buffer, 24)) + return FALSE; + + rc = rts_version_command_read(rpc, buffer); + if (rc < 0) + return rc; + rc = rts_receive_window_size_command_read(rpc, buffer, &ReceiveWindowSize); + if (rc < 0) + return rc; + rc = rts_connection_timeout_command_read(rpc, buffer, &ConnectionTimeout); + if (rc < 0) + return rc; WLog_DBG(TAG, "Receiving CONN/C2 RTS PDU: ConnectionTimeout: %" PRIu32 " ReceiveWindowSize: %" PRIu32 "", ConnectionTimeout, ReceiveWindowSize); + + assert(rpc); + assert(rpc->VirtualConnection); + assert(rpc->VirtualConnection->DefaultInChannel); + rpc->VirtualConnection->DefaultInChannel->PingOriginator.ConnectionTimeout = ConnectionTimeout; rpc->VirtualConnection->DefaultInChannel->PeerReceiveWindow = ReceiveWindowSize; - return 1; + return TRUE; } /* Out-of-Sequence PDUs */ -static int rts_send_keep_alive_pdu(rdpRpc* rpc) +BOOL rts_send_flow_control_ack_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; - UINT32 length; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); - RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel; + UINT32 BytesReceived; + UINT32 AvailableWindow; + BYTE* ChannelCookie; + RpcVirtualConnection* connection; + RpcInChannel* inChannel; + RpcOutChannel* outChannel; - header.header.frag_length = 28; - header.Flags = RTS_FLAG_OTHER_CMD; - header.NumberOfCommands = 1; - WLog_DBG(TAG, "Sending Keep-Alive RTS PDU"); - buffer = (BYTE*)malloc(header.header.frag_length); + assert(rpc); - if (!buffer) - return -1; + connection = rpc->VirtualConnection; + assert(connection); - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_client_keepalive_command_write( - &buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ - length = header.header.frag_length; - status = rpc_channel_write(&inChannel->common, buffer, length); - free(buffer); - return (status > 0) ? 1 : -1; -} + inChannel = connection->DefaultInChannel; + assert(inChannel); -int rts_send_flow_control_ack_pdu(rdpRpc* rpc) -{ - int status; - BYTE* buffer; - UINT32 length; - rpcconn_rts_hdr_t header = rts_pdu_header_init(); - UINT32 BytesReceived; - UINT32 AvailableWindow; - BYTE* ChannelCookie; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcInChannel* inChannel = connection->DefaultInChannel; - RpcOutChannel* outChannel = connection->DefaultOutChannel; + outChannel = connection->DefaultOutChannel; + assert(outChannel); header.header.frag_length = 56; header.Flags = RTS_FLAG_OTHER_CMD; header.NumberOfCommands = 2; + WLog_DBG(TAG, "Sending FlowControlAck RTS PDU"); + BytesReceived = outChannel->BytesReceived; AvailableWindow = outChannel->AvailableWindowAdvertised; ChannelCookie = (BYTE*)&(outChannel->common.Cookie); outChannel->ReceiverAvailableWindow = outChannel->AvailableWindowAdvertised; - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) - return -1; + goto fail; + + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + if (!rts_destination_command_write(buffer, FDOutProxy)) /* Destination Command (8 bytes) */ + goto fail; - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */ /* FlowControlAck Command (28 bytes) */ - rts_flow_control_ack_command_write(&buffer[28], BytesReceived, AvailableWindow, ChannelCookie); - length = header.header.frag_length; - status = rpc_channel_write(&inChannel->common, buffer, length); - free(buffer); - return (status > 0) ? 1 : -1; + if (!rts_flow_control_ack_command_write(buffer, BytesReceived, AvailableWindow, ChannelCookie)) + goto fail; + + status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return status; } -static int rts_recv_flow_control_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static int rts_recv_flow_control_ack_pdu(rdpRpc* rpc, wStream* buffer) { - UINT32 offset; + int rc; UINT32 BytesReceived; UINT32 AvailableWindow; - BYTE ChannelCookie[16]; - offset = 24; - offset += - rts_flow_control_ack_command_read(rpc, &buffer[offset], length - offset, &BytesReceived, - &AvailableWindow, (BYTE*)&ChannelCookie) + - 4; + BYTE ChannelCookie[16] = { 0 }; + + rc = rts_flow_control_ack_command_read(rpc, buffer, &BytesReceived, &AvailableWindow, + (BYTE*)&ChannelCookie); + if (rc < 0) + return rc; WLog_ERR(TAG, "Receiving FlowControlAck RTS PDU: BytesReceived: %" PRIu32 " AvailableWindow: %" PRIu32 "", BytesReceived, AvailableWindow); + + assert(rpc->VirtualConnection); + assert(rpc->VirtualConnection->DefaultInChannel); + rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow = AvailableWindow - (rpc->VirtualConnection->DefaultInChannel->BytesSent - BytesReceived); return 1; } -static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, wStream* buffer) { - UINT32 offset; + int rc; UINT32 Destination; UINT32 BytesReceived; UINT32 AvailableWindow; - BYTE ChannelCookie[16]; + BYTE ChannelCookie[16] = { 0 }; /** * When the sender receives a FlowControlAck RTS PDU, it MUST use the following formula to * recalculate its Sender AvailableWindow variable: @@ -622,16 +1722,23 @@ static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buf * structure in the PDU received. * */ - offset = 24; - offset += rts_destination_command_read(rpc, &buffer[offset], length - offset, &Destination) + 4; - offset += - rts_flow_control_ack_command_read(rpc, &buffer[offset], length - offset, &BytesReceived, - &AvailableWindow, (BYTE*)&ChannelCookie) + - 4; + + rc = rts_destination_command_read(rpc, buffer, &Destination); + if (rc < 0) + return rc; + + rc = rts_flow_control_ack_command_read(rpc, buffer, &BytesReceived, &AvailableWindow, + ChannelCookie); + if (rc < 0) + return rc; + WLog_DBG(TAG, "Receiving FlowControlAckWithDestination RTS PDU: BytesReceived: %" PRIu32 " AvailableWindow: %" PRIu32 "", BytesReceived, AvailableWindow); + + assert(rpc->VirtualConnection); + assert(rpc->VirtualConnection->DefaultInChannel); rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow = AvailableWindow - (rpc->VirtualConnection->DefaultInChannel->BytesSent - BytesReceived); return 1; @@ -639,31 +1746,41 @@ static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buf static int rts_send_ping_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; - UINT32 length; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); - RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel; + RpcInChannel* inChannel; + + assert(rpc); + assert(rpc->VirtualConnection); + + inChannel = rpc->VirtualConnection->DefaultInChannel; + assert(inChannel); header.header.frag_length = 20; header.Flags = RTS_FLAG_PING; header.NumberOfCommands = 0; + WLog_DBG(TAG, "Sending Ping RTS PDU"); - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) - return -1; - - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - length = header.header.frag_length; - status = rpc_channel_write(&inChannel->common, buffer, length); - free(buffer); - return (status > 0) ? 1 : -1; + goto fail; + + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return (status) ? 1 : -1; } -SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length) +BOOL rts_command_length(UINT32 CommandType, wStream* s, size_t* length) { - int CommandLength = 0; + size_t padding = 0; + size_t CommandLength = 0; + + assert(s); switch (CommandType) { @@ -700,7 +1817,8 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length break; case RTS_CMD_PADDING: /* variable-size */ - CommandLength = rts_padding_command_read(buffer, length); + if (!rts_padding_command_read(s, &padding)) + return FALSE; break; case RTS_CMD_NEGATIVE_ANCE: @@ -712,7 +1830,8 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length break; case RTS_CMD_CLIENT_ADDRESS: /* variable-size */ - CommandLength = rts_client_address_command_read(buffer, length); + if (!rts_client_address_command_read(s, &CommandLength)) + return FALSE; break; case RTS_CMD_ASSOCIATION_GROUP_ID: @@ -729,118 +1848,176 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length default: WLog_ERR(TAG, "Error: Unknown RTS Command Type: 0x%" PRIx32 "", CommandType); - return -1; + return FALSE; } - return CommandLength; + CommandLength += padding; + if (Stream_GetRemainingLength(s) < CommandLength) + return FALSE; + + if (length) + *length = CommandLength; + return TRUE; } static int rts_send_OUT_R2_A7_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); BYTE* SuccessorChannelCookie; - RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel; - RpcOutChannel* nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel; + RpcInChannel* inChannel; + RpcOutChannel* nextOutChannel; + + assert(rpc); + assert(rpc->VirtualConnection); + + inChannel = rpc->VirtualConnection->DefaultInChannel; + assert(inChannel); + + nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel; + assert(nextOutChannel); header.header.frag_length = 56; header.Flags = RTS_FLAG_OUT_CHANNEL; header.NumberOfCommands = 3; + WLog_DBG(TAG, "Sending OUT_R2/A7 RTS PDU"); + SuccessorChannelCookie = (BYTE*)&(nextOutChannel->common.Cookie); - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) return -1; - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_destination_command_write(&buffer[20], FDServer); /* Destination (8 bytes)*/ - rts_cookie_command_write(&buffer[28], - SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */ - rts_version_command_write(&buffer[48]); /* Version (8 bytes) */ - status = rpc_channel_write(&inChannel->common, buffer, header.header.frag_length); - free(buffer); - return (status > 0) ? 1 : -1; + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + if (!rts_destination_command_write(buffer, FDServer)) /* Destination (8 bytes)*/ + goto fail; + if (!rts_cookie_command_write(buffer, + SuccessorChannelCookie)) /* SuccessorChannelCookie (20 bytes) */ + goto fail; + if (!rts_version_command_write(buffer)) /* Version (8 bytes) */ + goto fail; + status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return (status) ? 1 : -1; } static int rts_send_OUT_R2_C1_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); - RpcOutChannel* nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel; + RpcOutChannel* nextOutChannel; + + assert(rpc); + assert(rpc->VirtualConnection); + + nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel; + assert(nextOutChannel); header.header.frag_length = 24; header.Flags = RTS_FLAG_PING; header.NumberOfCommands = 1; + WLog_DBG(TAG, "Sending OUT_R2/C1 RTS PDU"); - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) return -1; - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_empty_command_write(&buffer[20]); /* Empty command (4 bytes) */ - status = rpc_channel_write(&nextOutChannel->common, buffer, header.header.frag_length); - free(buffer); - return (status > 0) ? 1 : -1; + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + + if (!rts_empty_command_write(buffer)) /* Empty command (4 bytes) */ + goto fail; + status = rts_send_buffer(&nextOutChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return (status) ? 1 : -1; } -int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc) +BOOL rts_send_OUT_R1_A3_pdu(rdpRpc* rpc) { - int status; - BYTE* buffer; + BOOL status = FALSE; + wStream* buffer; rpcconn_rts_hdr_t header = rts_pdu_header_init(); UINT32 ReceiveWindowSize; BYTE* VirtualConnectionCookie; BYTE* PredecessorChannelCookie; BYTE* SuccessorChannelCookie; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcOutChannel* outChannel = connection->DefaultOutChannel; - RpcOutChannel* nextOutChannel = connection->NonDefaultOutChannel; + RpcVirtualConnection* connection; + RpcOutChannel* outChannel; + RpcOutChannel* nextOutChannel; + + assert(rpc); + + connection = rpc->VirtualConnection; + assert(connection); + + outChannel = connection->DefaultOutChannel; + assert(outChannel); + + nextOutChannel = connection->NonDefaultOutChannel; + assert(nextOutChannel); header.header.frag_length = 96; header.Flags = RTS_FLAG_RECYCLE_CHANNEL; header.NumberOfCommands = 5; + WLog_DBG(TAG, "Sending OUT_R1/A3 RTS PDU"); + VirtualConnectionCookie = (BYTE*)&(connection->Cookie); PredecessorChannelCookie = (BYTE*)&(outChannel->common.Cookie); SuccessorChannelCookie = (BYTE*)&(nextOutChannel->common.Cookie); ReceiveWindowSize = outChannel->ReceiveWindow; - buffer = (BYTE*)malloc(header.header.frag_length); + buffer = Stream_New(NULL, header.header.frag_length); if (!buffer) return -1; - CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */ - rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ - rts_cookie_command_write(&buffer[28], - VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */ - rts_cookie_command_write(&buffer[48], - PredecessorChannelCookie); /* PredecessorChannelCookie (20 bytes) */ - rts_cookie_command_write(&buffer[68], - SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */ - rts_receive_window_size_command_write(&buffer[88], - ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */ - status = rpc_channel_write(&nextOutChannel->common, buffer, header.header.frag_length); - free(buffer); - return (status > 0) ? 1 : -1; + if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */ + goto fail; + if (!rts_version_command_write(buffer)) /* Version (8 bytes) */ + goto fail; + if (!rts_cookie_command_write(buffer, + VirtualConnectionCookie)) /* VirtualConnectionCookie (20 bytes) */ + goto fail; + if (!rts_cookie_command_write( + buffer, PredecessorChannelCookie)) /* PredecessorChannelCookie (20 bytes) */ + goto fail; + if (!rts_cookie_command_write(buffer, + SuccessorChannelCookie)) /* SuccessorChannelCookie (20 bytes) */ + goto fail; + if (!rts_receive_window_size_command_write(buffer, + ReceiveWindowSize)) /* ReceiveWindowSize (8 bytes) */ + goto fail; + + status = rts_send_buffer(&nextOutChannel->common, buffer, header.header.frag_length); +fail: + Stream_Free(buffer, TRUE); + return status; } -static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, wStream* buffer) { int status; - UINT32 offset; UINT32 Destination = 0; - RpcVirtualConnection* connection = rpc->VirtualConnection; + RpcVirtualConnection* connection; + assert(rpc); + assert(buffer); + + connection = rpc->VirtualConnection; + assert(connection); + WLog_DBG(TAG, "Receiving OUT R1/A2 RTS PDU"); - offset = 24; - if (length < offset) - return -1; + status = rts_destination_command_read(rpc, buffer, &Destination); + if (status < 0) + return status; - rts_destination_command_read(rpc, &buffer[offset], length - offset, &Destination); connection->NonDefaultOutChannel = rpc_out_channel_new(rpc); if (!connection->NonDefaultOutChannel) @@ -859,10 +2036,17 @@ static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) return 1; } -static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, wStream* buffer) { int status; - RpcVirtualConnection* connection = rpc->VirtualConnection; + RpcVirtualConnection* connection; + + assert(rpc); + assert(buffer); + + connection = rpc->VirtualConnection; + assert(connection); + WLog_DBG(TAG, "Receiving OUT R2/A6 RTS PDU"); status = rts_send_OUT_R2_C1_pdu(rpc); @@ -887,47 +2071,59 @@ static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) return 1; } -static int rts_recv_OUT_R2_B3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static int rts_recv_OUT_R2_B3_pdu(rdpRpc* rpc, wStream* buffer) { - RpcVirtualConnection* connection = rpc->VirtualConnection; + RpcVirtualConnection* connection; + + assert(rpc); + assert(buffer); + + connection = rpc->VirtualConnection; + assert(connection); + WLog_DBG(TAG, "Receiving OUT R2/B3 RTS PDU"); rpc_out_channel_transition_to_state(connection->DefaultOutChannel, CLIENT_OUT_CHANNEL_STATE_RECYCLED); return 1; } -int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) +BOOL rts_recv_out_of_sequence_pdu(rdpRpc* rpc, wStream* buffer, const rpcconn_hdr_t* header) { - int status = -1; + BOOL status = FALSE; UINT32 SignatureId; - rpcconn_rts_hdr_t* rts; - RtsPduSignature signature; + size_t length, total; + RtsPduSignature signature = { 0 }; RpcVirtualConnection* connection; - if (!rpc || !buffer) - return -1; + assert(rpc); + assert(buffer); + assert(header); + + total = Stream_Length(buffer); + length = header->common.frag_length; + if (total < length) + return FALSE; connection = rpc->VirtualConnection; if (!connection) - return -1; - - rts = (rpcconn_rts_hdr_t*)buffer; + return FALSE; - if (!rts_extract_pdu_signature(&signature, rts)) - return -1; + if (!rts_extract_pdu_signature(&signature, buffer, header)) + return FALSE; SignatureId = rts_identify_pdu_signature(&signature, NULL); - if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, buffer, header)) { - status = rts_recv_flow_control_ack_pdu(rpc, buffer, length); + status = rts_recv_flow_control_ack_pdu(rpc, buffer); } - else if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, rts)) + else if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, buffer, + header)) { - status = rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); + status = rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer); } - else if (rts_match_pdu_signature(&RTS_PDU_PING_SIGNATURE, rts)) + else if (rts_match_pdu_signature(&RTS_PDU_PING_SIGNATURE, buffer, header)) { status = rts_send_ping_pdu(rpc); } @@ -935,28 +2131,28 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) { if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED) { - if (rts_match_pdu_signature(&RTS_PDU_OUT_R1_A2_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R1_A2_SIGNATURE, buffer, header)) { - status = rts_recv_OUT_R1_A2_pdu(rpc, buffer, length); + status = rts_recv_OUT_R1_A2_pdu(rpc, buffer); } } else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_A6W) { - if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_A6_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_A6_SIGNATURE, buffer, header)) { - status = rts_recv_OUT_R2_A6_pdu(rpc, buffer, length); + status = rts_recv_OUT_R2_A6_pdu(rpc, buffer); } } else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_B3W) { - if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_B3_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_B3_SIGNATURE, buffer, header)) { - status = rts_recv_OUT_R2_B3_pdu(rpc, buffer, length); + status = rts_recv_OUT_R2_B3_pdu(rpc, buffer); } } } - if (status < 0) + if (!status) { WLog_ERR(TAG, "error parsing RTS PDU with signature id: 0x%08" PRIX32 "", SignatureId); rts_print_pdu_signature(&signature); @@ -964,3 +2160,42 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) return status; } + +BOOL rts_write_pdu_auth3(wStream* s, const rpcconn_rpc_auth_3_hdr_t* auth) +{ + assert(s); + assert(auth); + + if (!rts_write_common_pdu_header(s, &auth->header)) + return FALSE; + + if (!Stream_EnsureRemainingCapacity(s, 2 * sizeof(UINT16))) + return FALSE; + + Stream_Write_UINT16(s, auth->max_xmit_frag); + Stream_Write_UINT16(s, auth->max_recv_frag); + + return rts_write_auth_verifier(s, &auth->auth_verifier, &auth->header); +} + +BOOL rts_write_pdu_bind(wStream* s, const rpcconn_bind_hdr_t* bind) +{ + + assert(s); + assert(bind); + + if (!rts_write_common_pdu_header(s, &bind->header)) + return FALSE; + + if (!Stream_EnsureRemainingCapacity(s, 8)) + return FALSE; + + Stream_Write_UINT16(s, bind->max_xmit_frag); + Stream_Write_UINT16(s, bind->max_recv_frag); + Stream_Write_UINT32(s, bind->assoc_group_id); + + if (!rts_write_context_list(s, &bind->p_context_elem)) + return FALSE; + + return rts_write_auth_verifier(s, &bind->auth_verifier, &bind->header); +} diff --git a/libfreerdp/core/gateway/rts.h b/libfreerdp/core/gateway/rts.h index ccc4cfd..01a66a7 100644 --- a/libfreerdp/core/gateway/rts.h +++ b/libfreerdp/core/gateway/rts.h @@ -24,12 +24,14 @@ #include "config.h" #endif -#include "rpc.h" +#include #include #include #include +#include "rpc.h" + #define RTS_FLAG_NONE 0x0000 #define RTS_FLAG_PING 0x0001 #define RTS_FLAG_OTHER_CMD 0x0002 @@ -79,21 +81,28 @@ FREERDP_LOCAL void rts_generate_cookie(BYTE* cookie); -FREERDP_LOCAL SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length); +FREERDP_LOCAL BOOL rts_write_pdu_auth3(wStream* s, const rpcconn_rpc_auth_3_hdr_t* auth); +FREERDP_LOCAL BOOL rts_write_pdu_bind(wStream* s, const rpcconn_bind_hdr_t* bind); + +FREERDP_LOCAL BOOL rts_read_pdu_header(wStream* s, rpcconn_hdr_t* header); +FREERDP_LOCAL void rts_free_pdu_header(rpcconn_hdr_t* header, BOOL allocated); + +FREERDP_LOCAL BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header); -FREERDP_LOCAL int rts_send_CONN_A1_pdu(rdpRpc* rpc); -FREERDP_LOCAL int rts_recv_CONN_A3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length); +FREERDP_LOCAL BOOL rts_command_length(UINT32 CommandType, wStream* s, size_t* length); -FREERDP_LOCAL int rts_send_CONN_B1_pdu(rdpRpc* rpc); +FREERDP_LOCAL BOOL rts_send_CONN_A1_pdu(rdpRpc* rpc); +FREERDP_LOCAL BOOL rts_recv_CONN_A3_pdu(rdpRpc* rpc, wStream* buffer); -FREERDP_LOCAL int rts_recv_CONN_C2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length); +FREERDP_LOCAL BOOL rts_send_CONN_B1_pdu(rdpRpc* rpc); -FREERDP_LOCAL int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc); +FREERDP_LOCAL BOOL rts_recv_CONN_C2_pdu(rdpRpc* rpc, wStream* buffer); -FREERDP_LOCAL int rts_send_flow_control_ack_pdu(rdpRpc* rpc); +FREERDP_LOCAL BOOL rts_send_OUT_R1_A3_pdu(rdpRpc* rpc); -FREERDP_LOCAL int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length); +FREERDP_LOCAL BOOL rts_send_flow_control_ack_pdu(rdpRpc* rpc); -#include "rts_signature.h" +FREERDP_LOCAL BOOL rts_recv_out_of_sequence_pdu(rdpRpc* rpc, wStream* buffer, + const rpcconn_hdr_t* header); #endif /* FREERDP_LIB_CORE_GATEWAY_RTS_H */ diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c index 4d605f0..b9319d1 100644 --- a/libfreerdp/core/gateway/rts_signature.c +++ b/libfreerdp/core/gateway/rts_signature.c @@ -17,6 +17,9 @@ * limitations under the License. */ +#include +#include + #include #include "rts_signature.h" @@ -276,90 +279,74 @@ static const RTS_PDU_SIGNATURE_ENTRY RTS_PDU_SIGNATURE_TABLE[] = { { RTS_PDU_PING, TRUE, &RTS_PDU_PING_SIGNATURE, "Ping" }, { RTS_PDU_FLOW_CONTROL_ACK, TRUE, &RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, "FlowControlAck" }, { RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION, TRUE, - &RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, "FlowControlAckWithDestination" }, - - { 0, FALSE, NULL, NULL } + &RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, "FlowControlAckWithDestination" } }; -BOOL rts_match_pdu_signature(const RtsPduSignature* signature, const rpcconn_rts_hdr_t* rts) +BOOL rts_match_pdu_signature(const RtsPduSignature* signature, wStream* src, + const rpcconn_hdr_t* header) { - UINT16 i; - int status; - const BYTE* buffer; - UINT32 length; - UINT32 offset; - UINT32 CommandType; - UINT32 CommandLength; - - if (!signature || !rts) - return FALSE; + RtsPduSignature extracted = { 0 }; - if (rts->Flags != signature->Flags) - return FALSE; + assert(signature); + assert(src); - if (rts->NumberOfCommands != signature->NumberOfCommands) + if (!rts_extract_pdu_signature(&extracted, src, header)) return FALSE; - buffer = (const BYTE*)rts; - offset = RTS_PDU_HEADER_LENGTH; - length = rts->header.frag_length - offset; - - for (i = 0; i < rts->NumberOfCommands; i++) - { - CommandType = *((UINT32*)&buffer[offset]); /* CommandType (4 bytes) */ - offset += 4; - - if (CommandType != signature->CommandTypes[i]) - return FALSE; + return memcmp(signature, &extracted, sizeof(extracted)) == 0; +} - status = rts_command_length(CommandType, &buffer[offset], length); +BOOL rts_extract_pdu_signature(RtsPduSignature* signature, wStream* src, + const rpcconn_hdr_t* header) +{ + BOOL rc = FALSE; + UINT16 i; + wStream tmp; + rpcconn_hdr_t rheader = { 0 }; + const rpcconn_rts_hdr_t* rts; - if (status < 0) - return FALSE; + assert(signature); + assert(src); - CommandLength = (UINT32)status; - offset += CommandLength; - length = rts->header.frag_length - offset; + Stream_StaticInit(&tmp, Stream_Pointer(src), Stream_GetRemainingLength(src)); + if (!header) + { + if (!rts_read_pdu_header(&tmp, &rheader)) + goto fail; + header = &rheader; } - - return TRUE; -} - -BOOL rts_extract_pdu_signature(RtsPduSignature* signature, const rpcconn_rts_hdr_t* rts) -{ - int i; - int status; - BYTE* buffer; - UINT32 length; - UINT32 offset; - UINT32 CommandType; - UINT32 CommandLength; - - if (!signature || !rts) - return FALSE; + rts = &header->rts; + if (rts->header.frag_length < sizeof(rpcconn_rts_hdr_t)) + goto fail; signature->Flags = rts->Flags; signature->NumberOfCommands = rts->NumberOfCommands; - buffer = (BYTE*)rts; - offset = RTS_PDU_HEADER_LENGTH; - length = rts->header.frag_length - offset; for (i = 0; i < rts->NumberOfCommands; i++) { - CommandType = *((UINT32*)&buffer[offset]); /* CommandType (4 bytes) */ - offset += 4; - signature->CommandTypes[i] = CommandType; - status = rts_command_length(CommandType, &buffer[offset], length); + UINT32 CommandType; + size_t CommandLength; + + if (Stream_GetRemainingLength(&tmp) < 4) + goto fail; + + Stream_Read_UINT32(&tmp, CommandType); /* CommandType (4 bytes) */ - if (status < 0) - return FALSE; + /* We only need this for comparison against known command types */ + if (i < ARRAYSIZE(signature->CommandTypes)) + signature->CommandTypes[i] = CommandType; - CommandLength = (UINT32)status; - offset += CommandLength; - length = rts->header.frag_length - offset; + if (!rts_command_length(CommandType, &tmp, &CommandLength)) + goto fail; + if (!Stream_SafeSeek(&tmp, CommandLength)) + goto fail; } - return TRUE; + rc = TRUE; +fail: + rts_free_pdu_header(&rheader, FALSE); + Stream_Free(&tmp, FALSE); + return rc; } UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, @@ -367,11 +354,15 @@ UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, { size_t i, j; - for (i = 0; RTS_PDU_SIGNATURE_TABLE[i].SignatureId != 0; i++) + if (entry) + *entry = NULL; + + for (i = 0; i < ARRAYSIZE(RTS_PDU_SIGNATURE_TABLE); i++) { - const RtsPduSignature* pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; + const RTS_PDU_SIGNATURE_ENTRY* current = &RTS_PDU_SIGNATURE_TABLE[i]; + const RtsPduSignature* pSignature = current->Signature; - if (!RTS_PDU_SIGNATURE_TABLE[i].SignatureClient) + if (!current->SignatureClient) continue; if (signature->Flags != pSignature->Flags) @@ -387,9 +378,9 @@ UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, } if (entry) - *entry = &RTS_PDU_SIGNATURE_TABLE[i]; + *entry = current; - return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; + return current->SignatureId; } return 0; diff --git a/libfreerdp/core/gateway/rts_signature.h b/libfreerdp/core/gateway/rts_signature.h index 2c43cdc..31f0e81 100644 --- a/libfreerdp/core/gateway/rts_signature.h +++ b/libfreerdp/core/gateway/rts_signature.h @@ -178,10 +178,10 @@ FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_PING_SIGNATURE; FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE; FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE; -FREERDP_LOCAL BOOL rts_match_pdu_signature(const RtsPduSignature* signature, - const rpcconn_rts_hdr_t* rts); -FREERDP_LOCAL BOOL rts_extract_pdu_signature(RtsPduSignature* signature, - const rpcconn_rts_hdr_t* rts); +FREERDP_LOCAL BOOL rts_match_pdu_signature(const RtsPduSignature* signature, wStream* s, + const rpcconn_hdr_t* header); +FREERDP_LOCAL BOOL rts_extract_pdu_signature(RtsPduSignature* signature, wStream* s, + const rpcconn_hdr_t* header); FREERDP_LOCAL UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, const RTS_PDU_SIGNATURE_ENTRY** entry); FREERDP_LOCAL BOOL rts_print_pdu_signature(const RtsPduSignature* signature); diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index 3376fb0..ee56a38 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -24,7 +24,7 @@ #include "config.h" #endif -#include +#include #include #include #include @@ -221,7 +221,6 @@ struct rdp_tsg UINT32 TunnelId; UINT32 ChannelId; BOOL reauthSequence; - rdpSettings* settings; rdpTransport* transport; UINT64 ReauthTunnelContext; CONTEXT_HANDLE TunnelContext; @@ -310,9 +309,9 @@ static BOOL tsg_print(char** buffer, size_t* len, const char* fmt, ...) static BOOL tsg_packet_header_to_string(char** buffer, size_t* length, const TSG_PACKET_HEADER* header) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(header); + assert(buffer); + assert(length); + assert(header); return tsg_print(buffer, length, "header { ComponentId=0x%04" PRIx16 ", PacketId=0x%04" PRIx16 " }", @@ -322,9 +321,9 @@ static BOOL tsg_packet_header_to_string(char** buffer, size_t* length, static BOOL tsg_type_capability_nap_to_string(char** buffer, size_t* length, const TSG_CAPABILITY_NAP* cur) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(cur); + assert(buffer); + assert(length); + assert(cur); return tsg_print(buffer, length, "%s { capabilities=0x%08" PRIx32 " }", tsg_packet_id_to_string(TSG_CAPABILITY_TYPE_NAP), cur->capabilities); @@ -335,9 +334,9 @@ static BOOL tsg_packet_capabilities_to_string(char** buffer, size_t* length, { UINT32 x; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "capabilities { ")) return FALSE; @@ -363,9 +362,9 @@ static BOOL tsg_packet_capabilities_to_string(char** buffer, size_t* length, static BOOL tsg_packet_versioncaps_to_string(char** buffer, size_t* length, const TSG_PACKET_VERSIONCAPS* caps) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "versioncaps { ")) return FALSE; @@ -391,9 +390,9 @@ static BOOL tsg_packet_versioncaps_to_string(char** buffer, size_t* length, static BOOL tsg_packet_quarconfigrequest_to_string(char** buffer, size_t* length, const TSG_PACKET_QUARCONFIGREQUEST* caps) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "quarconfigrequest { ")) return FALSE; @@ -414,9 +413,9 @@ static BOOL tsg_packet_quarrequest_to_string(char** buffer, size_t* length, char* name = NULL; char* strdata = NULL; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "quarrequest { ")) return FALSE; @@ -426,7 +425,9 @@ static BOOL tsg_packet_quarrequest_to_string(char** buffer, size_t* length, if (caps->nameLength > 0) { - if (ConvertFromUnicode(CP_UTF8, 0, caps->machineName, caps->nameLength, &name, 0, NULL, + if (caps->nameLength > INT_MAX) + return FALSE; + if (ConvertFromUnicode(CP_UTF8, 0, caps->machineName, (int)caps->nameLength, &name, 0, NULL, NULL) < 0) return FALSE; } @@ -454,8 +455,8 @@ static const char* tsg_bool_to_string(BOOL val) static const char* tsg_redirection_flags_to_string(char* buffer, size_t size, const TSG_REDIRECTION_FLAGS* flags) { - WINPR_ASSERT(buffer || (size == 0)); - WINPR_ASSERT(flags); + assert(buffer || (size == 0)); + assert(flags); _snprintf(buffer, size, "enableAllRedirections=%s, disableAllRedirections=%s, driveRedirectionDisabled=%s, " @@ -479,9 +480,9 @@ static BOOL tsg_packet_response_to_string(char** buffer, size_t* length, char* strdata = NULL; char tbuffer[8192] = { 0 }; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "response { ")) return FALSE; @@ -514,9 +515,9 @@ static BOOL tsg_packet_quarenc_response_to_string(char** buffer, size_t* length, size_t size = ARRAYSIZE(tbuffer); char* ptbuffer = tbuffer; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "quarenc_response { ")) return FALSE; @@ -526,8 +527,10 @@ static BOOL tsg_packet_quarenc_response_to_string(char** buffer, size_t* length, if (caps->certChainLen > 0) { - if (ConvertFromUnicode(CP_UTF8, 0, caps->certChainData, caps->certChainLen, &strdata, 0, - NULL, NULL) <= 0) + if (caps->certChainLen > INT_MAX) + return FALSE; + if (ConvertFromUnicode(CP_UTF8, 0, caps->certChainData, (int)caps->certChainLen, &strdata, + 0, NULL, NULL) <= 0) return FALSE; } @@ -549,9 +552,9 @@ static BOOL tsg_packet_quarenc_response_to_string(char** buffer, size_t* length, static BOOL tsg_packet_message_response_to_string(char** buffer, size_t* length, const TSG_PACKET_MSG_RESPONSE* caps) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "msg_response { ")) return FALSE; @@ -567,9 +570,9 @@ static BOOL tsg_packet_message_response_to_string(char** buffer, size_t* length, static BOOL tsg_packet_caps_response_to_string(char** buffer, size_t* length, const TSG_PACKET_CAPS_RESPONSE* caps) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "caps_response { ")) return FALSE; @@ -586,9 +589,9 @@ static BOOL tsg_packet_caps_response_to_string(char** buffer, size_t* length, static BOOL tsg_packet_message_request_to_string(char** buffer, size_t* length, const TSG_PACKET_MSG_REQUEST* caps) { - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "caps_message_request { ")) return FALSE; @@ -603,9 +606,9 @@ static BOOL tsg_packet_auth_to_string(char** buffer, size_t* length, const TSG_P { BOOL rc = FALSE; char* strdata = NULL; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "caps_message_request { ")) return FALSE; @@ -627,14 +630,14 @@ static BOOL tsg_packet_reauth_to_string(char** buffer, size_t* length, const TSG_PACKET_REAUTH* caps) { BOOL rc = FALSE; - WINPR_ASSERT(buffer); - WINPR_ASSERT(length); - WINPR_ASSERT(caps); + assert(buffer); + assert(length); + assert(caps); if (!tsg_print(buffer, length, "caps_message_request { ")) return FALSE; - if (!tsg_print(buffer, length, " tunnelContext=0x%08" PRIx32 ", packetId=%s [0x%08" PRIx32 "]", + if (!tsg_print(buffer, length, " tunnelContext=0x%016" PRIx64 ", packetId=%s [0x%08" PRIx32 "]", caps->tunnelContext, tsg_packet_id_to_string(caps->packetId), caps->packetId)) return FALSE; @@ -793,7 +796,7 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI { wStream* s; rdpTsg* tsg; - int length; + size_t length; const byte* buffer1 = NULL; const byte* buffer2 = NULL; const byte* buffer3 = NULL; @@ -829,7 +832,9 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI totalDataBytes += lengths[2] + 4; } - length = 28 + totalDataBytes; + length = 28ull + totalDataBytes; + if (length > INT_MAX) + return -1; s = Stream_New(NULL, length); if (!s) @@ -865,7 +870,7 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI if (!rpc_client_write_call(tsg->rpc, s, TsProxySendToServerOpnum)) return -1; - return length; + return (int)length; } /** @@ -1018,12 +1023,20 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, UINT32 SwitchValue; UINT32 MessageSwitchValue = 0; UINT32 IsMessagePresent; + rdpContext* context; UINT32 MsgBytes; + TSG_PACKET_STRING_MESSAGE packetStringMessage; PTSG_PACKET_CAPABILITIES tsgCaps = NULL; PTSG_PACKET_VERSIONCAPS versionCaps = NULL; PTSG_PACKET_CAPS_RESPONSE packetCapsResponse = NULL; PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse = NULL; + assert(tsg); + assert(tsg->rpc); + + context = tsg->rpc->context; + assert(context); + if (!pdu) return FALSE; @@ -1170,8 +1183,8 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, if (Stream_GetRemainingLength(pdu->s) < 16) goto fail; - Stream_Seek_UINT32(pdu->s); /* IsDisplayMandatory (4 bytes) */ - Stream_Seek_UINT32(pdu->s); /* IsConsent Mandatory (4 bytes) */ + Stream_Read_INT32(pdu->s, packetStringMessage.isDisplayMandatory); + Stream_Read_INT32(pdu->s, packetStringMessage.isConsentMandatory); Stream_Read_UINT32(pdu->s, MsgBytes); Stream_Read_UINT32(pdu->s, Pointer); @@ -1351,16 +1364,19 @@ fail: static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnelContext) { - UINT32 pad; + size_t pad; wStream* s; size_t count; - UINT32 offset; + size_t offset; rdpRpc* rpc; if (!tsg || !tsg->rpc || !tunnelContext || !tsg->MachineName) return FALSE; count = _wcslen(tsg->MachineName) + 1; + if (count > UINT32_MAX) + return FALSE; + rpc = tsg->rpc; WLog_DBG(TAG, "TsProxyAuthorizeTunnelWriteRequest"); s = Stream_New(NULL, 1024 + count * 2); @@ -1377,13 +1393,13 @@ static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunn Stream_Write_UINT32(s, 0x00020000); /* PacketQuarRequestPtr (4 bytes) */ Stream_Write_UINT32(s, 0x00000000); /* Flags (4 bytes) */ Stream_Write_UINT32(s, 0x00020004); /* MachineNamePtr (4 bytes) */ - Stream_Write_UINT32(s, count); /* NameLength (4 bytes) */ + Stream_Write_UINT32(s, (UINT32)count); /* NameLength (4 bytes) */ Stream_Write_UINT32(s, 0x00020008); /* DataPtr (4 bytes) */ Stream_Write_UINT32(s, 0); /* DataLength (4 bytes) */ /* MachineName */ - Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, (UINT32)count); /* MaxCount (4 bytes) */ Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */ - Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */ + Stream_Write_UINT32(s, (UINT32)count); /* ActualCount (4 bytes) */ Stream_Write_UTF16_String(s, tsg->MachineName, count); /* Array */ /* 4-byte alignment */ offset = Stream_GetPosition(s); @@ -1394,7 +1410,7 @@ static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunn return rpc_client_write_call(rpc, s, TsProxyAuthorizeTunnelOpnum); } -static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) +static BOOL TsProxyAuthorizeTunnelReadResponse(RPC_PDU* pdu) { BOOL rc = FALSE; UINT32 Pointer; @@ -1456,25 +1472,24 @@ static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) Stream_Seek_UINT32(pdu->s); /* Reserved (4 bytes) */ Stream_Read_UINT32(pdu->s, Pointer); /* ResponseDataPtr (4 bytes) */ Stream_Read_UINT32(pdu->s, packetResponse->responseDataLen); /* ResponseDataLength (4 bytes) */ - Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags - .enableAllRedirections); /* EnableAllRedirections (4 bytes) */ - Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags - .disableAllRedirections); /* DisableAllRedirections (4 bytes) */ - Stream_Read_UINT32(pdu->s, - packetResponse->redirectionFlags - .driveRedirectionDisabled); /* DriveRedirectionDisabled (4 bytes) */ - Stream_Read_UINT32(pdu->s, - packetResponse->redirectionFlags - .printerRedirectionDisabled); /* PrinterRedirectionDisabled (4 bytes) */ - Stream_Read_UINT32(pdu->s, - packetResponse->redirectionFlags - .portRedirectionDisabled); /* PortRedirectionDisabled (4 bytes) */ - Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags.reserved); /* Reserved (4 bytes) */ - Stream_Read_UINT32( + Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags + .enableAllRedirections); /* EnableAllRedirections (4 bytes) */ + Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags + .disableAllRedirections); /* DisableAllRedirections (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetResponse->redirectionFlags + .driveRedirectionDisabled); /* DriveRedirectionDisabled (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetResponse->redirectionFlags + .printerRedirectionDisabled); /* PrinterRedirectionDisabled (4 bytes) */ + Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags + .portRedirectionDisabled); /* PortRedirectionDisabled (4 bytes) */ + Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags.reserved); /* Reserved (4 bytes) */ + Stream_Read_INT32( pdu->s, packetResponse->redirectionFlags .clipboardRedirectionDisabled); /* ClipboardRedirectionDisabled (4 bytes) */ - Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags - .pnpRedirectionDisabled); /* PnpRedirectionDisabled (4 bytes) */ + Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags + .pnpRedirectionDisabled); /* PnpRedirectionDisabled (4 bytes) */ Stream_Read_UINT32(pdu->s, SizeValue); /* (4 bytes) */ if (SizeValue != packetResponse->responseDataLen) @@ -1574,11 +1589,18 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) UINT32 Pointer; UINT32 SwitchValue; TSG_PACKET packet; + rdpContext* context; char* messageText = NULL; TSG_PACKET_MSG_RESPONSE packetMsgResponse = { 0 }; TSG_PACKET_STRING_MESSAGE packetStringMessage = { 0 }; TSG_PACKET_REAUTH_MESSAGE packetReauthMessage = { 0 }; + assert(tsg); + assert(tsg->rpc); + + context = tsg->rpc->context; + assert(context); + /* This is an asynchronous response */ if (!pdu) @@ -1628,10 +1650,10 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) WLog_INFO(TAG, "Consent Message: %s", messageText); free(messageText); - if (tsg->rpc && tsg->rpc->context && tsg->rpc->context->instance) + if (context->instance) { - rc = IFCALLRESULT(TRUE, tsg->rpc->context->instance->PresentGatewayMessage, - tsg->rpc->context->instance, SwitchValue, + rc = IFCALLRESULT(TRUE, context->instance->PresentGatewayMessage, + context->instance, SwitchValue, packetStringMessage.isDisplayMandatory != 0, packetStringMessage.isConsentMandatory != 0, packetStringMessage.msgBytes, packetStringMessage.msgBuffer); @@ -1649,10 +1671,10 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) WLog_INFO(TAG, "Service Message: %s", messageText); free(messageText); - if (tsg->rpc && tsg->rpc->context && tsg->rpc->context->instance) + if (context->instance) { - rc = IFCALLRESULT(TRUE, tsg->rpc->context->instance->PresentGatewayMessage, - tsg->rpc->context->instance, SwitchValue, + rc = IFCALLRESULT(TRUE, context->instance->PresentGatewayMessage, + context->instance, SwitchValue, packetStringMessage.isDisplayMandatory != 0, packetStringMessage.isConsentMandatory != 0, packetStringMessage.msgBytes, packetStringMessage.msgBuffer); @@ -1704,6 +1726,8 @@ static BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnel rpc = tsg->rpc; count = _wcslen(tsg->Hostname) + 1; + if (count > UINT32_MAX) + return FALSE; s = Stream_New(NULL, 60 + count * 2); if (!s) @@ -1723,15 +1747,15 @@ static BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnel Stream_Write_UINT16(s, tsg->Port); /* PortNumber (0xD3D = 3389) (2 bytes) */ Stream_Write_UINT32(s, 0x00000001); /* NumResourceNames (4 bytes) */ Stream_Write_UINT32(s, 0x00020004); /* ResourceNamePtr (4 bytes) */ - Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, (UINT32)count); /* MaxCount (4 bytes) */ Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */ - Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */ + Stream_Write_UINT32(s, (UINT32)count); /* ActualCount (4 bytes) */ Stream_Write_UTF16_String(s, tsg->Hostname, count); /* Array */ return rpc_client_write_call(rpc, s, TsProxyCreateChannelOpnum); } -static BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, - CONTEXT_HANDLE* channelContext, UINT32* channelId) +static BOOL TsProxyCreateChannelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* channelContext, + UINT32* channelId) { BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCreateChannelReadResponse"); @@ -1779,7 +1803,7 @@ static BOOL TsProxyCloseChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context return rpc_client_write_call(rpc, s, TsProxyCloseChannelOpnum); } -static BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context) +static BOOL TsProxyCloseChannelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* context) { BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCloseChannelReadResponse"); @@ -1826,7 +1850,7 @@ static BOOL TsProxyCloseTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context) return rpc_client_write_call(rpc, s, TsProxyCloseTunnelOpnum); } -static BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context) +static BOOL TsProxyCloseTunnelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* context) { BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCloseTunnelReadResponse"); @@ -1977,8 +2001,6 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) return FALSE; rpc = tsg->rpc; - Stream_SealLength(pdu->s); - Stream_SetPosition(pdu->s, 0); if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) { @@ -2017,7 +2039,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) CONTEXT_HANDLE* TunnelContext; TunnelContext = (tsg->reauthSequence) ? &tsg->NewTunnelContext : &tsg->TunnelContext; - if (!TsProxyAuthorizeTunnelReadResponse(tsg, pdu)) + if (!TsProxyAuthorizeTunnelReadResponse(pdu)) { WLog_ERR(TAG, "TsProxyAuthorizeTunnelReadResponse failure"); return FALSE; @@ -2066,7 +2088,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { CONTEXT_HANDLE ChannelContext; - if (!TsProxyCreateChannelReadResponse(tsg, pdu, &ChannelContext, &tsg->ChannelId)) + if (!TsProxyCreateChannelReadResponse(pdu, &ChannelContext, &tsg->ChannelId)) { WLog_ERR(TAG, "TsProxyCreateChannelReadResponse failure"); return FALSE; @@ -2139,7 +2161,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { CONTEXT_HANDLE ChannelContext; - if (!TsProxyCloseChannelReadResponse(tsg, pdu, &ChannelContext)) + if (!TsProxyCloseChannelReadResponse(pdu, &ChannelContext)) { WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure"); return FALSE; @@ -2151,7 +2173,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { CONTEXT_HANDLE TunnelContext; - if (!TsProxyCloseTunnelReadResponse(tsg, pdu, &TunnelContext)) + if (!TsProxyCloseTunnelReadResponse(pdu, &TunnelContext)) { WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure"); return FALSE; @@ -2166,7 +2188,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { CONTEXT_HANDLE ChannelContext; - if (!TsProxyCloseChannelReadResponse(tsg, pdu, &ChannelContext)) + if (!TsProxyCloseChannelReadResponse(pdu, &ChannelContext)) { WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure"); return FALSE; @@ -2196,7 +2218,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { CONTEXT_HANDLE TunnelContext; - if (!TsProxyCloseTunnelReadResponse(tsg, pdu, &TunnelContext)) + if (!TsProxyCloseTunnelReadResponse(pdu, &TunnelContext)) { WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure"); return FALSE; @@ -2305,10 +2327,25 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, DWORD timeout) { UINT64 looptimeout = timeout * 1000ULL; DWORD nCount; - HANDLE events[64]; - rdpRpc* rpc = tsg->rpc; - rdpSettings* settings = rpc->settings; - rdpTransport* transport = rpc->transport; + HANDLE events[MAXIMUM_WAIT_OBJECTS] = { 0 }; + rdpRpc* rpc; + rdpContext* context; + rdpSettings* settings; + rdpTransport* transport; + + assert(tsg); + + rpc = tsg->rpc; + assert(rpc); + + transport = rpc->transport; + assert(transport); + + context = tsg->rpc->context; + assert(context); + + settings = context->settings; + tsg->Port = port; tsg->transport = transport; @@ -2327,7 +2364,7 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, DWORD timeout) return FALSE; } - nCount = tsg_get_event_handles(tsg, events, 64); + nCount = tsg_get_event_handles(tsg, events, ARRAYSIZE(events)); if (nCount == 0) return FALSE; @@ -2410,7 +2447,7 @@ BOOL tsg_disconnect(rdpTsg* tsg) * @return < 0 on error; 0 if not enough data is available (non blocking mode); > 0 bytes to read */ -static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) +static int tsg_read(rdpTsg* tsg, BYTE* data, size_t length) { rdpRpc* rpc; int status = 0; @@ -2428,7 +2465,7 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) do { - status = rpc_client_receive_pipe_read(rpc->client, data, (size_t)length); + status = rpc_client_receive_pipe_read(rpc->client, data, length); if (status < 0) return -1; @@ -2478,7 +2515,7 @@ static int tsg_write(rdpTsg* tsg, const BYTE* data, UINT32 length) if (status < 0) return -1; - return length; + return (int)length; } rdpTsg* tsg_new(rdpTransport* transport) @@ -2490,7 +2527,6 @@ rdpTsg* tsg_new(rdpTransport* transport) return NULL; tsg->transport = transport; - tsg->settings = transport->settings; tsg->rpc = rpc_new(tsg->transport); if (!tsg->rpc) @@ -2518,7 +2554,10 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num) int status; rdpTsg* tsg = (rdpTsg*)BIO_get_data(bio); BIO_clear_flags(bio, BIO_FLAGS_WRITE); - status = tsg_write(tsg, (BYTE*)buf, num); + + if (num < 0) + return -1; + status = tsg_write(tsg, (const BYTE*)buf, (UINT32)num); if (status < 0) { @@ -2550,7 +2589,7 @@ static int transport_bio_tsg_read(BIO* bio, char* buf, int size) } BIO_clear_flags(bio, BIO_FLAGS_READ); - status = tsg_read(tsg, (BYTE*)buf, size); + status = tsg_read(tsg, (BYTE*)buf, (size_t)size); if (status < 0) { @@ -2572,17 +2611,22 @@ static int transport_bio_tsg_read(BIO* bio, char* buf, int size) static int transport_bio_tsg_puts(BIO* bio, const char* str) { + WINPR_UNUSED(bio); + WINPR_UNUSED(str); return 1; } static int transport_bio_tsg_gets(BIO* bio, char* str, int size) { + WINPR_UNUSED(bio); + WINPR_UNUSED(str); + WINPR_UNUSED(size); return 1; } static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) { - int status = -1; + long status = -1; rdpTsg* tsg = (rdpTsg*)BIO_get_data(bio); RpcVirtualConnection* connection = tsg->rpc->VirtualConnection; RpcInChannel* inChannel = connection->DefaultInChannel; @@ -2611,27 +2655,27 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) case BIO_C_READ_BLOCKED: { - BIO* bio = outChannel->common.bio; - status = BIO_read_blocked(bio); + BIO* cbio = outChannel->common.bio; + status = BIO_read_blocked(cbio); } break; case BIO_C_WRITE_BLOCKED: { - BIO* bio = inChannel->common.bio; - status = BIO_write_blocked(bio); + BIO* cbio = inChannel->common.bio; + status = BIO_write_blocked(cbio); } break; case BIO_C_WAIT_READ: { int timeout = (int)arg1; - BIO* bio = outChannel->common.bio; + BIO* cbio = outChannel->common.bio; - if (BIO_read_blocked(bio)) - return BIO_wait_read(bio, timeout); - else if (BIO_write_blocked(bio)) - return BIO_wait_write(bio, timeout); + if (BIO_read_blocked(cbio)) + return BIO_wait_read(cbio, timeout); + else if (BIO_write_blocked(cbio)) + return BIO_wait_write(cbio, timeout); else status = 1; } @@ -2640,12 +2684,12 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) case BIO_C_WAIT_WRITE: { int timeout = (int)arg1; - BIO* bio = inChannel->common.bio; + BIO* cbio = inChannel->common.bio; - if (BIO_write_blocked(bio)) - status = BIO_wait_write(bio, timeout); - else if (BIO_read_blocked(bio)) - status = BIO_wait_read(bio, timeout); + if (BIO_write_blocked(cbio)) + status = BIO_wait_write(cbio, timeout); + else if (BIO_read_blocked(cbio)) + status = BIO_wait_read(cbio, timeout); else status = 1; } @@ -2660,6 +2704,7 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) static int transport_bio_tsg_new(BIO* bio) { + assert(bio); BIO_set_init(bio, 1); BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY); return 1; @@ -2667,6 +2712,8 @@ static int transport_bio_tsg_new(BIO* bio) static int transport_bio_tsg_free(BIO* bio) { + assert(bio); + WINPR_UNUSED(bio); return 1; }