[ros-diffs] [ion] 24672: - Implement NtSecureConnectPort so that clients can connect to SMSS. Does not yet support actual secure connections (with a SID) and will fail those requests. Also doesn't support memory-mapped LPC yet.

ion at svn.reactos.org ion at svn.reactos.org
Mon Oct 30 15:46:57 CET 2006


Author: ion
Date: Mon Oct 30 17:46:56 2006
New Revision: 24672

URL: http://svn.reactos.org/svn/reactos?rev=24672&view=rev
Log:
- Implement NtSecureConnectPort so that clients can connect to SMSS. Does not yet support actual secure connections (with a SID) and will fail those requests. Also doesn't support memory-mapped LPC yet.

Modified:
    trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c

Modified: trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c
URL: http://svn.reactos.org/svn/reactos/trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c?rev=24672&r1=24671&r2=24672&view=diff
==============================================================================
--- trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c (original)
+++ trunk/reactos/ntoskrnl/lpc/ntlpc/connect.c Mon Oct 30 17:46:56 2006
@@ -14,6 +14,52 @@
 #include <internal/debug.h>
 
 /* PRIVATE FUNCTIONS *********************************************************/
+
+PVOID
+NTAPI
+LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
+               IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage,
+               IN PETHREAD CurrentThread)
+{
+    PVOID SectionToMap;
+
+    /* Detach this thread's pending connection reply under the global LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* If the thread is still linked into a port's reply chain, unlink it */
+    if (!IsListEmpty(&CurrentThread->LpcReplyChain))
+    {
+        /* Re-initialize after unlinking so later IsListEmpty() checks see it empty */
+        RemoveEntryList(&CurrentThread->LpcReplyChain);
+        InitializeListHead(&CurrentThread->LpcReplyChain);
+    }
+
+    /* Check whether a reply message is still attached to the thread */
+    if (CurrentThread->LpcReplyMessage)
+    {
+        /* Hand the message to the caller, who is then responsible for freeing it */
+        *Message = CurrentThread->LpcReplyMessage;
+
+        /* Clear the thread's received-message ID and reply-message pointer */
+        CurrentThread->LpcReceivedMessageId = 0;
+        CurrentThread->LpcReplyMessage = NULL;
+
+        /* Connection data follows the header; move its section pointer to the caller */
+        *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(*Message + 1);
+        SectionToMap = (*ConnectMessage)->SectionToMap;
+        (*ConnectMessage)->SectionToMap = NULL;
+    }
+    else
+    {
+        /* No message: *Message and the section are NULL; *ConnectMessage is untouched */
+        *Message = NULL;
+        SectionToMap = NULL;
+    }
+
+    /* Release the lock; the caller now owns the returned section, if any */
+    KeReleaseGuardedMutex(&LpcpLock);
+    return SectionToMap;
+}
 
 /* PUBLIC FUNCTIONS **********************************************************/
 
@@ -32,8 +78,389 @@
                     IN OUT PVOID ConnectionInformation OPTIONAL,
                     IN OUT PULONG ConnectionInformationLength OPTIONAL)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    ULONG ConnectionInfoLength = 0;
+    PLPCP_PORT_OBJECT Port, ClientPort;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    NTSTATUS Status = STATUS_SUCCESS;
+    HANDLE Handle;
+    PVOID SectionToMap;
+    PLPCP_MESSAGE Message;
+    PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    PETHREAD Thread = PsGetCurrentThread();
+    ULONG PortMessageLength;
+    PAGED_CODE();
+    LPCTRACE(LPC_CONNECT_DEBUG,
+             "Name: %wZ. Qos: %p. Views: %p/%p\n",
+             PortName,
+             Qos,
+             ClientView,
+             ServerView);
+
+    /* Validate client view */
+    if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
+    {
+        /* Fail */
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Validate server view */
+    if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
+    {
+        /* Fail */
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Check if caller sent connection information length */
+    if (ConnectionInformationLength)
+    {
+        /* Retrieve the input length */
+        ConnectionInfoLength = *ConnectionInformationLength;
+    }
+
+    /* Get the port */
+    Status = ObReferenceObjectByName(PortName,
+                                     0,
+                                     NULL,
+                                     PORT_ALL_ACCESS,
+                                     LpcPortObjectType,
+                                     PreviousMode,
+                                     NULL,
+                                     (PVOID *)&Port);
+    if (!NT_SUCCESS(Status)) return Status;
+
+    /* This has to be a connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
+    {
+        /* It isn't, so fail */
+        ObDereferenceObject(Port);
+        return STATUS_INVALID_PORT_HANDLE;
+    }
+
+    /* Check if we have a SID */
+    if (ServerSid)
+    {
+        /* FIXME: TODO */
+        UNIMPLEMENTED;
+        return STATUS_NOT_IMPLEMENTED;
+    }
+
+    /* Create the client port */
+    Status = ObCreateObject(PreviousMode,
+                            LpcPortObjectType,
+                            NULL,
+                            PreviousMode,
+                            NULL,
+                            sizeof(LPCP_PORT_OBJECT),
+                            0,
+                            0,
+                            (PVOID *)&ClientPort);
+    if (!NT_SUCCESS(Status))
+    {
+        /* Failed, dereference the server port and return */
+        ObDereferenceObject(Port);
+        return Status;
+    }
+
+    /* Setup the client port */
+    RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
+    ClientPort->Flags = LPCP_CLIENT_PORT;
+    ClientPort->ConnectionPort = Port;
+    ClientPort->MaxMessageLength = Port->MaxMessageLength;
+    ClientPort->SecurityQos = *Qos;
+    InitializeListHead(&ClientPort->LpcReplyChainHead);
+    InitializeListHead(&ClientPort->LpcDataInfoChainHead);
+
+    /* Check if we have dynamic security */
+    if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
+    {
+        /* Remember that */
+        ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
+    }
+    else
+    {
+        /* Create our own client security */
+        Status = SeCreateClientSecurity(Thread,
+                                        Qos,
+                                        FALSE,
+                                        &ClientPort->StaticSecurity);
+        if (!NT_SUCCESS(Status))
+        {
+            /* Security failed, dereference and return */
+            ObDereferenceObject(ClientPort);
+            return Status;
+        }
+    }
+
+    /* Initialize the port queue */
+    Status = LpcpInitializePortQueue(ClientPort);
+    if (!NT_SUCCESS(Status))
+    {
+        /* Failed */
+        ObDereferenceObject(ClientPort);
+        return Status;
+    }
+
+    /* Check if we have a client view */
+    if (ClientView)
+    {
+        /* FIXME: TODO */
+        UNIMPLEMENTED;
+        return STATUS_NOT_IMPLEMENTED;
+    }
+    else
+    {
+        /* No section */
+        SectionToMap = NULL;
+    }
+
+    /* Normalize connection information */
+    if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
+    {
+        /* Use the port's maximum allowed value */
+        ConnectionInfoLength = Port->MaxConnectionInfoLength;
+    }
+
+    /* Allocate a message from the port zone while holding the lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+    Message = ExAllocateFromPagedLookasideList(&LpcpMessagesLookaside);
+    if (!Message)
+    {
+        /* Fail if we couldn't allocate a message */
+        KeReleaseGuardedMutex(&LpcpLock);
+        if (SectionToMap) ObDereferenceObject(SectionToMap);
+        ObDereferenceObject(ClientPort);
+        return STATUS_NO_MEMORY;
+    }
+
+    /* Initialize it */
+    InitializeListHead(&Message->Entry);
+    Message->RepliedToThread = NULL;
+    Message->Request.u2.ZeroInit = 0;
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+
+    /* Set pointer to the connection message and fill in the CID */
+    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+    Message->Request.ClientId = Thread->Cid;
+
+    /* Check if we have a client view */
+    if (ClientView)
+    {
+        /* FIXME: TODO */
+        UNIMPLEMENTED;
+        return STATUS_NOT_IMPLEMENTED;
+    }
+    else
+    {
+        /* Set the size to 0 and clear the connect message */
+        Message->Request.ClientViewSize = 0;
+        RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
+    }
+
+    /* Set the section and client port. Port is NULL for now */
+    ConnectMessage->ClientPort = NULL;
+    ConnectMessage->SectionToMap = SectionToMap;
+
+    /* Set the data for the connection request message */
+    Message->Request.u1.s1.DataLength = sizeof(LPCP_CONNECTION_MESSAGE) +
+                                        ConnectionInfoLength;
+    Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
+                                         Message->Request.u1.s1.DataLength;
+    Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;
+
+    /* Check if we have connection information */
+    if (ConnectionInformation)
+    {
+        /* Copy it in */
+        RtlMoveMemory(ConnectMessage + 1,
+                      ConnectionInformation,
+                      ConnectionInfoLength);
+    }
+
+    /* Acquire the port lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Check if someone already deleted the port name */
+    if (Port->Flags & LPCP_NAME_DELETED)
+    {
+        /* Fail the request */
+        KeReleaseGuardedMutex(&LpcpLock);
+        Status = STATUS_OBJECT_NAME_NOT_FOUND;
+        goto Cleanup;
+    }
+
+    /* Associate no thread yet */
+    Message->RepliedToThread = NULL;
+
+    /* Generate the Message ID and set it */
+    Message->Request.MessageId =  LpcpNextMessageId++;
+    if (!LpcpNextMessageId) LpcpNextMessageId = 1;
+    Thread->LpcReplyMessageId = Message->Request.MessageId;
+
+    /* Insert the message into the queue and thread chain */
+    InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
+    InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
+    Thread->LpcReplyMessage = Message;
+
+    /* Now we can finally reference the client port and link it*/
+    ObReferenceObject(ClientPort);
+    ConnectMessage->ClientPort = ClientPort;
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+    LPCTRACE(LPC_CONNECT_DEBUG,
+             "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
+             Message,
+             ConnectMessage,
+             Port,
+             ClientPort,
+             Status);
+
+    /* If this is a waitable port, set the event */
+    if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
+                                                     1,
+                                                     FALSE);
+
+    /* Release the queue semaphore */
+    LpcpCompleteWait(Port->MsgQueue.Semaphore);
+
+    /* Now wait for a reply */
+    LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
+
+    /* Check if our wait ended in success */
+    if (Status != STATUS_SUCCESS) goto Cleanup;
+
+    /* Free the connection message */
+    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+
+    /* Check if we got a message back */
+    if (Message)
+    {
+        /* Check for new return length */
+        if ((Message->Request.u1.s1.DataLength -
+             sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
+        {
+            /* Set new normalized connection length */
+            ConnectionInfoLength = Message->Request.u1.s1.DataLength -
+                                   sizeof(LPCP_CONNECTION_MESSAGE);
+        }
+
+        /* Check if we had connection information */
+        if (ConnectionInformation)
+        {
+            /* Check if we had a length pointer */
+            if (ConnectionInformationLength)
+            {
+                /* Return the length */
+                *ConnectionInformationLength = ConnectionInfoLength;
+            }
+
+            /* Return the connection information */
+            RtlMoveMemory(ConnectionInformation,
+                          ConnectMessage + 1,
+                          ConnectionInfoLength );
+        }
+
+        /* Make sure we had a connected port */
+        if (ClientPort->ConnectedPort)
+        {
+            /* Get the message length before the port might get killed */
+            PortMessageLength = Port->MaxMessageLength;
+
+            /* Insert the client port */
+            Status = ObInsertObject(ClientPort,
+                                    NULL,
+                                    PORT_ALL_ACCESS,
+                                    0,
+                                    (PVOID *)NULL,
+                                    &Handle);
+            if (NT_SUCCESS(Status))
+            {
+                /* Return the handle */
+                *PortHandle = Handle;
+                LPCTRACE(LPC_CONNECT_DEBUG,
+                         "Handle: %lx. Length: %lx\n",
+                         Handle,
+                         PortMessageLength);
+
+                /* Check if maximum length was requested */
+                if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
+
+                /* Check if we had a client view */
+                if (ClientView)
+                {
+                    /* Copy it back */
+                    RtlMoveMemory(ClientView,
+                                  &ConnectMessage->ClientView,
+                                  sizeof(PORT_VIEW));
+                }
+
+                /* Check if we had a server view */
+                if (ServerView)
+                {
+                    /* Copy it back */
+                    RtlMoveMemory(ServerView,
+                                  &ConnectMessage->ServerView,
+                                  sizeof(REMOTE_PORT_VIEW));
+                }
+            }
+        }
+        else
+        {
+            /* No connection port, we failed */
+            if (SectionToMap) ObDereferenceObject(SectionToMap);
+
+            /* Check if it's because the name got deleted */
+            if (Port->Flags & LPCP_NAME_DELETED)
+            {
+                /* Set the correct status */
+                Status = STATUS_OBJECT_NAME_NOT_FOUND;
+            }
+            else
+            {
+                /* Otherwise, the caller refused us */
+                Status = STATUS_PORT_CONNECTION_REFUSED;
+            }
+
+            /* Kill the port */
+            ObDereferenceObject(ClientPort);
+        }
+
+        /* Free the message */
+        LpcpFreeToPortZone(Message, FALSE);
+        return Status;
+    }
+
+    /* No reply message, fail */
+    if (SectionToMap) ObDereferenceObject(SectionToMap);
+    ObDereferenceObject(ClientPort);
+    return STATUS_PORT_CONNECTION_REFUSED;
+
+Cleanup:
+    /* We failed, free the message */
+    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+
+    /* Check if the semaphore got signaled */
+    if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
+    {
+        /* Wait on it */
+        KeWaitForSingleObject(&Thread->LpcReplySemaphore,
+                              KernelMode,
+                              Executive,
+                              FALSE,
+                              NULL);
+    }
+
+    /* Check if we had a message and free it */
+    if (Message) LpcpFreeToPortZone(Message, FALSE);
+
+    /* Dereference other objects */
+    if (SectionToMap) ObDereferenceObject(SectionToMap);
+    ObDereferenceObject(ClientPort);
+
+    /* Return status */
+    return Status;
 }
 
 /*




More information about the Ros-diffs mailing list