[ros-diffs] [weiden] 13984: Thomas Weidenmueller <w3seek@reactos.com>

Alex Ionescu ionucu at videotron.ca
Sat Mar 12 23:28:25 CET 2005


weiden at svn.reactos.com wrote:

>Thomas Weidenmueller <w3seek at reactos.com>
>- Fix various security structures and constants
>- Add code to capture quality of service structures and ACLs
>- Secure buffer access in NtQueryInformationToken, NtSetInformationToken, NtNotifyChangeDirectoryFile and NtQueryDirectoryFile
>
>Modified: trunk/reactos/include/ddk/setypes.h
>Modified: trunk/reactos/include/ntos/security.h
>Modified: trunk/reactos/lib/rtl/sid.c
>Modified: trunk/reactos/ntoskrnl/include/internal/ob.h
>Modified: trunk/reactos/ntoskrnl/include/internal/se.h
>Modified: trunk/reactos/ntoskrnl/io/dir.c
>Modified: trunk/reactos/ntoskrnl/ob/object.c
>Modified: trunk/reactos/ntoskrnl/se/acl.c
>Modified: trunk/reactos/ntoskrnl/se/luid.c
>Modified: trunk/reactos/ntoskrnl/se/sd.c
>Modified: trunk/reactos/ntoskrnl/se/sid.c
>Modified: trunk/reactos/ntoskrnl/se/token.c
>  
>
> ------------------------------------------------------------------------

What didn't make it into the ros-diffs mail:

Index: ntoskrnl/se/sid.c
===================================================================
--- ntoskrnl/se/sid.c    (.../trunk/reactos)    (revision 13937)
+++ ntoskrnl/se/sid.c    (.../branches/alex_devel_branch/reactos)    (revision 13942)
@@ -466,4 +466,135 @@
   return(TRUE);
 }
 
+/*
+ * SepCaptureSid
+ *
+ * Captures a SID supplied by a caller of the given AccessMode.
+ *
+ * AccessMode != KernelMode: InputSid is untrusted; it is probed and copied
+ *   under SEH into a new allocation from PoolType.
+ * AccessMode == KernelMode: InputSid is copied into PoolType only when
+ *   CaptureIfKernel is TRUE; otherwise InputSid itself is returned.
+ *
+ * On success *CapturedSid receives the SID to use; on failure it is left
+ * untouched. Returns STATUS_SUCCESS, STATUS_INSUFFICIENT_RESOURCES, or the
+ * exception code raised while probing or copying the caller's buffer.
+ * Release the result with SepReleaseSid using the same AccessMode and
+ * CaptureIfKernel.
+ */
+NTSTATUS
+SepCaptureSid(IN PSID InputSid,
+              IN KPROCESSOR_MODE AccessMode,
+              IN POOL_TYPE PoolType,
+              IN BOOLEAN CaptureIfKernel,
+              OUT PSID *CapturedSid)
+{
+  ULONG SidSize = 0;
+  UCHAR SubAuthorityCount = 0;
+  PISID NewSid, Sid = (PISID)InputSid;
+  NTSTATUS Status = STATUS_SUCCESS;
+
+  PAGED_CODE();
+
+  if(AccessMode != KernelMode)
+  {
+    _SEH_TRY
+    {
+      /* probe the fixed-size header before reading SubAuthorityCount */
+      ProbeForRead(Sid,
+                   sizeof(*Sid) - sizeof(Sid->SubAuthority),
+                   sizeof(UCHAR));
+      /* read the count from user memory exactly once and size everything on it */
+      SubAuthorityCount = Sid->SubAuthorityCount;
+      SidSize = RtlLengthRequiredSid(SubAuthorityCount);
+      ProbeForRead(Sid,
+                   SidSize,
+                   sizeof(UCHAR));
+    }
+    _SEH_HANDLE
+    {
+      Status = _SEH_GetExceptionCode();
+    }
+    _SEH_END;
+
+    if(NT_SUCCESS(Status))
+    {
+      /* allocate a SID and copy it */
+      NewSid = ExAllocatePool(PoolType,
+                              SidSize);
+      if(NewSid != NULL)
+      {
+        _SEH_TRY
+        {
+          RtlCopyMemory(NewSid,
+                        Sid,
+                        SidSize);
+
+          /* The copy fetched SubAuthorityCount from user memory a second
+             time and the caller may have changed it meanwhile. Restore the
+             value SidSize was computed from, so the captured SID never
+             claims more sub-authorities than the buffer holds. */
+          NewSid->SubAuthorityCount = SubAuthorityCount;
+
+          *CapturedSid = NewSid;
+        }
+        _SEH_HANDLE
+        {
+          ExFreePool(NewSid);
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+      }
+      else
+      {
+        Status = STATUS_INSUFFICIENT_RESOURCES;
+      }
+    }
+  }
+  else if(!CaptureIfKernel)
+  {
+    /* kernel-mode caller that did not ask for a copy: use its SID as is */
+    *CapturedSid = InputSid;
+    return STATUS_SUCCESS;
+  }
+  else
+  {
+    /* kernel-mode caller: copy without probing */
+    SidSize = RtlLengthRequiredSid(Sid->SubAuthorityCount);
+
+    /* allocate a SID and copy it */
+    NewSid = ExAllocatePool(PoolType,
+                            SidSize);
+    if(NewSid != NULL)
+    {
+      RtlCopyMemory(NewSid,
+                    Sid,
+                    SidSize);
+
+      *CapturedSid = NewSid;
+    }
+    else
+    {
+      Status = STATUS_INSUFFICIENT_RESOURCES;
+    }
+  }
+
+  return Status;
+}
+
+/* Frees a SID obtained from SepCaptureSid (pass the same AccessMode and
+   CaptureIfKernel); does nothing when the caller's own SID was returned. */
+VOID
+SepReleaseSid(IN PSID CapturedSid,
+              IN KPROCESSOR_MODE AccessMode,
+              IN BOOLEAN CaptureIfKernel)
+{
+  PAGED_CODE();
+  /* same condition under which SepCaptureSid allocates a copy */
+  if(CapturedSid != NULL && (AccessMode != KernelMode || CaptureIfKernel))
+  {
+    ExFreePool(CapturedSid);
+  }
+}
+
 /* EOF */
Index: ntoskrnl/se/token.c
===================================================================
--- ntoskrnl/se/token.c    (.../trunk/reactos)    (revision 13937)
+++ ntoskrnl/se/token.c    (.../branches/alex_devel_branch/reactos)    (revision 13942)
@@ -18,12 +18,54 @@
 /* GLOBALS *******************************************************************/
 
 POBJECT_TYPE SepTokenObjectType = NULL;
+ERESOURCE SepTokenLock; /* single lock shared by all tokens: initialized in SepInitializeTokenImplementation and stored in AccessToken->TokenLock when a token is created */
 
 static GENERIC_MAPPING SepTokenMapping = {TOKEN_READ,
                       TOKEN_WRITE,
                       TOKEN_EXECUTE,
                       TOKEN_ALL_ACCESS};
 
+static const INFORMATION_CLASS_INFO SeTokenInformationClass[] = {
+
+    /* Class 0 is not a token information class; placeholder so each entry sits at its class value */
+    ICI_SQ_SAME( 0, 0, 0),
+
+    /* TokenUser -- NOTE(review): this class and TokenGroups, TokenPrivileges, TokenSource and TokenStatistics carry ICIF_SET; confirm NtSetInformationToken is meant to accept them */
+    ICI_SQ_SAME( sizeof(TOKEN_USER),                   sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenGroups */
+    ICI_SQ_SAME( sizeof(TOKEN_GROUPS),                 sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenPrivileges */
+    ICI_SQ_SAME( sizeof(TOKEN_PRIVILEGES),             sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenOwner */
+    ICI_SQ_SAME( sizeof(TOKEN_OWNER),                  sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenPrimaryGroup */
+    ICI_SQ_SAME( sizeof(TOKEN_PRIMARY_GROUP),          sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenDefaultDacl */
+    ICI_SQ_SAME( sizeof(TOKEN_DEFAULT_DACL),           sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenSource */
+    ICI_SQ_SAME( sizeof(TOKEN_SOURCE),                 sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenType */
+    ICI_SQ_SAME( sizeof(TOKEN_TYPE),                   sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenImpersonationLevel */
+    ICI_SQ_SAME( sizeof(SECURITY_IMPERSONATION_LEVEL), sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenStatistics */
+    ICI_SQ_SAME( sizeof(TOKEN_STATISTICS),             sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE | ICIF_SET | ICIF_SET_SIZE_VARIABLE ),
+    /* TokenRestrictedSids */
+    ICI_SQ_SAME( sizeof(TOKEN_GROUPS),                 sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenSessionId */
+    ICI_SQ_SAME( sizeof(ULONG),                        sizeof(ULONG), ICIF_QUERY | ICIF_SET ),
+    /* TokenGroupsAndPrivileges */
+    ICI_SQ_SAME( sizeof(TOKEN_GROUPS_AND_PRIVILEGES),  sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenSessionReference */
+    ICI_SQ_SAME( /* FIXME */0,                         sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenSandBoxInert */
+    ICI_SQ_SAME( sizeof(ULONG),                        sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenAuditPolicy */
+    ICI_SQ_SAME( /* FIXME */0,                         sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+    /* TokenOrigin */
+    ICI_SQ_SAME( sizeof(TOKEN_ORIGIN),                 sizeof(ULONG), ICIF_QUERY | ICIF_QUERY_SIZE_VARIABLE ),
+};
+
 /* FUNCTIONS *****************************************************************/
 
 VOID SepFreeProxyData(PVOID ProxyData)
@@ -140,6 +182,8 @@
   PVOID EndMem;
   PTOKEN AccessToken;
   NTSTATUS Status;
+ 
+  PAGED_CODE();
 
   Status = ObCreateObject(PreviousMode,
               SepTokenObjectType,
@@ -170,6 +214,8 @@
       return(Status);
     }
 
+  AccessToken->TokenLock = &SepTokenLock;
+
   AccessToken->TokenInUse = 0;
   AccessToken->TokenType  = TokenType;
   AccessToken->ImpersonationLevel = Level;
@@ -189,7 +235,7 @@
     uLength += RtlLengthSid(Token->UserAndGroups[i].Sid);
 
   AccessToken->UserAndGroups =
-    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                            uLength,
                            TAG('T', 'O', 'K', 'u'));
 
@@ -216,7 +262,7 @@
 
       uLength = AccessToken->PrivilegeCount * sizeof(LUID_AND_ATTRIBUTES);
       AccessToken->Privileges =
-    (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+    (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                             uLength,
                             TAG('T', 'O', 'K', 'p'));
 
@@ -231,7 +277,7 @@
       if ( Token->DefaultDacl )
     {
       AccessToken->DefaultDacl =
-        (PACL) ExAllocatePoolWithTag(NonPagedPool,
+        (PACL) ExAllocatePoolWithTag(PagedPool,
                      Token->DefaultDacl->AclSize,
                      TAG('T', 'O', 'K', 'd'));
       memcpy(AccessToken->DefaultDacl,
@@ -534,6 +580,8 @@
 VOID INIT_FUNCTION
 SepInitializeTokenImplementation(VOID)
 {
+  ExInitializeResource(&SepTokenLock);
+
   SepTokenObjectType = ExAllocatePool(NonPagedPool, sizeof(OBJECT_TYPE));
 
   SepTokenObjectType->Tag = TAG('T', 'O', 'K', 'T');
@@ -555,8 +603,7 @@
   SepTokenObjectType->Create = NULL;
   SepTokenObjectType->DuplicationNotify = NULL;
 
-  RtlpCreateUnicodeString(&SepTokenObjectType->TypeName,
-          L"Token", NonPagedPool);
+  RtlInitUnicodeString(&SepTokenObjectType->TypeName, L"Token");
   ObpCreateTypeObject (SepTokenObjectType);
 }
 
@@ -571,266 +618,456 @@
             IN ULONG TokenInformationLength,
             OUT PULONG ReturnLength)
 {
-  NTSTATUS Status, LengthStatus;
-  PVOID UnusedInfo;
-  PVOID EndMem;
+  union
+  {
+    PVOID Ptr;
+    ULONG Ulong;
+  } Unused;
   PTOKEN Token;
-  ULONG Length;
-  PTOKEN_GROUPS PtrTokenGroups;
-  PTOKEN_DEFAULT_DACL PtrDefaultDacl;
-  PTOKEN_STATISTICS PtrTokenStatistics;
+  ULONG RequiredLength;
+  KPROCESSOR_MODE PreviousMode;
+  NTSTATUS Status = STATUS_SUCCESS;
  
   PAGED_CODE();
+ 
+  PreviousMode = ExGetPreviousMode();
+ 
+  /* Check buffers and class validity */
+  DefaultQueryInfoBufferCheck(TokenInformationClass,
+                              SeTokenInformationClass,
+                              TokenInformation,
+                              TokenInformationLength,
+                              ReturnLength,
+                              PreviousMode,
+                              &Status);
 
+  if(!NT_SUCCESS(Status))
+  {
+    DPRINT("NtQueryInformationToken() failed, Status: 0x%x\n", Status);
+    return Status;
+  }
+
   Status = ObReferenceObjectByHandle(TokenHandle,
                      (TokenInformationClass == TokenSource) ? 
TOKEN_QUERY_SOURCE : TOKEN_QUERY,
                      SepTokenObjectType,
-                     UserMode,
+                     PreviousMode,
                      (PVOID*)&Token,
                      NULL);
-  if (!NT_SUCCESS(Status))
+  if (NT_SUCCESS(Status))
+  {
+    switch (TokenInformationClass)
     {
-      return(Status);
-    }
+      case TokenUser:
+      {
+        PTOKEN_USER tu = (PTOKEN_USER)TokenInformation;
+       
+        DPRINT("NtQueryInformationToken(TokenUser)\n");
+        RequiredLength = sizeof(TOKEN_USER) +
+                         RtlLengthSid(Token->UserAndGroups[0].Sid);
 
-  switch (TokenInformationClass)
-    {
-      case TokenUser:
-    DPRINT("NtQueryInformationToken(TokenUser)\n");
-    Length = RtlLengthSidAndAttributes(1, Token->UserAndGroups);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = RtlCopySidAndAttributesArray(1,
-                          Token->UserAndGroups,
-                          TokenInformationLength,
-                          TokenInformation,
-                          (char*)TokenInformation + 8,
-                          &UnusedInfo,
-                          &Length);
-        if (NT_SUCCESS(Status))
-          {
-        Length = TokenInformationLength - Length;
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-          }
-      }
-    break;
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            Status = RtlCopySidAndAttributesArray(1,
+                                                  &Token->UserAndGroups[0],
+                                                  RequiredLength - 
sizeof(TOKEN_USER),
+                                                  &tu->User,
+                                                  (PSID)(tu + 1),
+                                                  &Unused.Ptr,
+                                                  &Unused.Ulong);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+         
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+       
+        break;
+      }
    
       case TokenGroups:
-    DPRINT("NtQueryInformationToken(TokenGroups)\n");
-    Length = RtlLengthSidAndAttributes(Token->UserAndGroupCount - 1, 
&Token->UserAndGroups[1]) + sizeof(ULONG);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        EndMem = (char*)TokenInformation + Token->UserAndGroupCount * 
sizeof(SID_AND_ATTRIBUTES);
-        PtrTokenGroups = (PTOKEN_GROUPS)TokenInformation;
-        PtrTokenGroups->GroupCount = Token->UserAndGroupCount - 1;
-        Status = RtlCopySidAndAttributesArray(Token->UserAndGroupCount - 1,
-                          &Token->UserAndGroups[1],
-                          TokenInformationLength,
-                          PtrTokenGroups->Groups,
-                          EndMem,
-                          &UnusedInfo,
-                          &Length);
-        if (NT_SUCCESS(Status))
-          {
-        Length = TokenInformationLength - Length;
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-          }
-      }
-    break;
+      {
+        PTOKEN_GROUPS tg = (PTOKEN_GROUPS)TokenInformation;
+       
+        DPRINT("NtQueryInformationToken(TokenGroups)\n");
+        RequiredLength = sizeof(tg->GroupCount) +
+                         
RtlLengthSidAndAttributes(Token->UserAndGroupCount - 1, 
&Token->UserAndGroups[1]);
 
-      case TokenPrivileges:
-    DPRINT("NtQueryInformationToken(TokenPrivileges)\n");
-    Length = sizeof(ULONG) + Token->PrivilegeCount * 
sizeof(LUID_AND_ATTRIBUTES);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        ULONG i;
-        TOKEN_PRIVILEGES* pPriv = (TOKEN_PRIVILEGES*)TokenInformation;
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            ULONG SidLen = RequiredLength - sizeof(tg->GroupCount) -
+                           ((Token->UserAndGroupCount - 1) * 
sizeof(SID_AND_ATTRIBUTES));
+            PSID_AND_ATTRIBUTES Sid = 
(PSID_AND_ATTRIBUTES)((ULONG_PTR)TokenInformation + sizeof(tg->GroupCount) +
+                                                            
((Token->UserAndGroupCount - 1) * sizeof(SID_AND_ATTRIBUTES)));
 
-        pPriv->PrivilegeCount = Token->PrivilegeCount;
-        for (i = 0; i < Token->PrivilegeCount; i++)
-          {
-        RtlCopyLuid(&pPriv->Privileges[i].Luid, 
&Token->Privileges[i].Luid);
-        pPriv->Privileges[i].Attributes = Token->Privileges[i].Attributes;
-          }
-        Status = STATUS_SUCCESS;
-      }
+            tg->GroupCount = Token->UserAndGroupCount - 1;
+            Status = 
RtlCopySidAndAttributesArray(Token->UserAndGroupCount - 1,
+                                                  &Token->UserAndGroups[1],
+                                                  SidLen,
+                                                  &tg->Groups[0],
+                                                  (PSID)Sid,
+                                                  &Unused.Ptr,
+                                                  &Unused.Ulong);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+       
     break;
+      }
+     
+      case TokenPrivileges:
+      {
+        PTOKEN_PRIVILEGES tp = (PTOKEN_PRIVILEGES)TokenInformation;
+       
+        DPRINT("NtQueryInformationToken(TokenPrivileges)\n");
+        RequiredLength = sizeof(tp->PrivilegeCount) +
+                         (Token->PrivilegeCount * 
sizeof(LUID_AND_ATTRIBUTES));
 
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            tp->PrivilegeCount = Token->PrivilegeCount;
+            RtlCopyLuidAndAttributesArray(Token->PrivilegeCount,
+                                          Token->Privileges,
+                                          &tp->Privileges[0]);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+         
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+       
+        break;
+      }
+     
       case TokenOwner:
-    DPRINT("NtQueryInformationToken(TokenOwner)\n");
-    Length = 
RtlLengthSid(Token->UserAndGroups[Token->DefaultOwnerIndex].Sid) + 
sizeof(TOKEN_OWNER);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        ((PTOKEN_OWNER)TokenInformation)->Owner =
-          (PSID)(((PTOKEN_OWNER)TokenInformation) + 1);
-        RtlCopySid(TokenInformationLength - sizeof(TOKEN_OWNER),
-               ((PTOKEN_OWNER)TokenInformation)->Owner,
-               Token->UserAndGroups[Token->DefaultOwnerIndex].Sid);
-        Status = STATUS_SUCCESS;
-      }
-    break;
+      {
+        ULONG SidLen;
+        PTOKEN_OWNER to = (PTOKEN_OWNER)TokenInformation;
+       
+        DPRINT("NtQueryInformationToken(TokenOwner)\n");
+        SidLen = 
RtlLengthSid(Token->UserAndGroups[Token->DefaultOwnerIndex].Sid);
+        RequiredLength = sizeof(TOKEN_OWNER) + SidLen;
 
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            to->Owner = (PSID)(to + 1);
+            Status = RtlCopySid(SidLen,
+                                to->Owner,
+                                
Token->UserAndGroups[Token->DefaultOwnerIndex].Sid);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+         
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+       
+        break;
+      }
+     
       case TokenPrimaryGroup:
-    DPRINT("NtQueryInformationToken(TokenPrimaryGroup),"
-           "Token->PrimaryGroup = 0x%08x\n", Token->PrimaryGroup);
-    Length = RtlLengthSid(Token->PrimaryGroup) + 
sizeof(TOKEN_PRIMARY_GROUP);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        ((PTOKEN_PRIMARY_GROUP)TokenInformation)->PrimaryGroup =
-          (PSID)(((PTOKEN_PRIMARY_GROUP)TokenInformation) + 1);
-        RtlCopySid(TokenInformationLength - sizeof(TOKEN_PRIMARY_GROUP),
-               ((PTOKEN_PRIMARY_GROUP)TokenInformation)->PrimaryGroup,
-               Token->PrimaryGroup);
-        Status = STATUS_SUCCESS;
-      }
-    break;
+      {
+        ULONG SidLen;
+        PTOKEN_PRIMARY_GROUP tpg = (PTOKEN_PRIMARY_GROUP)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenPrimaryGroup)\n");
+        SidLen = RtlLengthSid(Token->PrimaryGroup);
+        RequiredLength = sizeof(TOKEN_PRIMARY_GROUP) + SidLen;
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            tpg->PrimaryGroup = (PSID)(tpg + 1);
+            Status = RtlCopySid(SidLen,
+                                tpg->PrimaryGroup,
+                                Token->PrimaryGroup);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenDefaultDacl:
-    DPRINT("NtQueryInformationToken(TokenDefaultDacl)\n");
-    PtrDefaultDacl = (PTOKEN_DEFAULT_DACL) TokenInformation;
-    Length = (Token->DefaultDacl ? Token->DefaultDacl->AclSize : 0) + 
sizeof(TOKEN_DEFAULT_DACL);
-    if (TokenInformationLength < Length)
-      {
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else if (!Token->DefaultDacl)
-      {
-        PtrDefaultDacl->DefaultDacl = 0;
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-      }
-    else
-      {
-        PtrDefaultDacl->DefaultDacl = (PACL) (PtrDefaultDacl + 1);
-        memmove(PtrDefaultDacl->DefaultDacl,
-            Token->DefaultDacl,
-            Token->DefaultDacl->AclSize);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-      }
-    break;
+      {
+        PTOKEN_DEFAULT_DACL tdd = (PTOKEN_DEFAULT_DACL)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenDefaultDacl)\n");
+        RequiredLength = sizeof(TOKEN_DEFAULT_DACL);
+       
+        if(Token->DefaultDacl != NULL)
+        {
+          RequiredLength += Token->DefaultDacl->AclSize;
+        }
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            if(Token->DefaultDacl != NULL)
+            {
+              tdd->DefaultDacl = (PACL)(tdd + 1);
+              RtlCopyMemory(tdd->DefaultDacl,
+                            Token->DefaultDacl,
+                            Token->DefaultDacl->AclSize);
+            }
+            else
+            {
+              tdd->DefaultDacl = NULL;
+            }
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenSource:
-    DPRINT("NtQueryInformationToken(TokenSource)\n");
-    if (TokenInformationLength < sizeof(TOKEN_SOURCE))
-      {
-        Length = sizeof(TOKEN_SOURCE);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = MmCopyToCaller(TokenInformation, &Token->TokenSource, 
sizeof(TOKEN_SOURCE));
-      }
-    break;
+      {
+        PTOKEN_SOURCE ts = (PTOKEN_SOURCE)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenSource)\n");
+        RequiredLength = sizeof(TOKEN_SOURCE);
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            *ts = Token->TokenSource;
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenType:
-    DPRINT("NtQueryInformationToken(TokenType)\n");
-    if (TokenInformationLength < sizeof(TOKEN_TYPE))
-      {
-        Length = sizeof(TOKEN_TYPE);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = MmCopyToCaller(TokenInformation, &Token->TokenType, 
sizeof(TOKEN_TYPE));
-      }
-    break;
+      {
+        PTOKEN_TYPE tt = (PTOKEN_TYPE)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenType)\n");
+        RequiredLength = sizeof(TOKEN_TYPE);
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            *tt = Token->TokenType;
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenImpersonationLevel:
-    DPRINT("NtQueryInformationToken(TokenImpersonationLevel)\n");
-    if (TokenInformationLength < sizeof(SECURITY_IMPERSONATION_LEVEL))
-      {
-        Length = sizeof(SECURITY_IMPERSONATION_LEVEL);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = MmCopyToCaller(TokenInformation, 
&Token->ImpersonationLevel, sizeof(SECURITY_IMPERSONATION_LEVEL));
-      }
-    break;
+      {
+        PSECURITY_IMPERSONATION_LEVEL sil = 
(PSECURITY_IMPERSONATION_LEVEL)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenImpersonationLevel)\n");
+        RequiredLength = sizeof(SECURITY_IMPERSONATION_LEVEL);
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            *sil = Token->ImpersonationLevel;
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenStatistics:
-    DPRINT("NtQueryInformationToken(TokenStatistics)\n");
-    if (TokenInformationLength < sizeof(TOKEN_STATISTICS))
-      {
-        Length = sizeof(TOKEN_STATISTICS);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        PtrTokenStatistics = (PTOKEN_STATISTICS)TokenInformation;
-        PtrTokenStatistics->TokenId = Token->TokenId;
-        PtrTokenStatistics->AuthenticationId = Token->AuthenticationId;
-        PtrTokenStatistics->ExpirationTime = Token->ExpirationTime;
-        PtrTokenStatistics->TokenType = Token->TokenType;
-        PtrTokenStatistics->ImpersonationLevel = Token->ImpersonationLevel;
-        PtrTokenStatistics->DynamicCharged = Token->DynamicCharged;
-        PtrTokenStatistics->DynamicAvailable = Token->DynamicAvailable;
-        PtrTokenStatistics->GroupCount = Token->UserAndGroupCount - 1;
-        PtrTokenStatistics->PrivilegeCount = Token->PrivilegeCount;
-        PtrTokenStatistics->ModifiedId = Token->ModifiedId;
+      {
+        PTOKEN_STATISTICS ts = (PTOKEN_STATISTICS)TokenInformation;
 
-        Status = STATUS_SUCCESS;
-      }
-    break;
+        DPRINT("NtQueryInformationToken(TokenStatistics)\n");
+        RequiredLength = sizeof(TOKEN_STATISTICS);
 
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            ts->TokenId = Token->TokenId;
+            ts->AuthenticationId = Token->AuthenticationId;
+            ts->ExpirationTime = Token->ExpirationTime;
+            ts->TokenType = Token->TokenType;
+            ts->ImpersonationLevel = Token->ImpersonationLevel;
+            ts->DynamicCharged = Token->DynamicCharged;
+            ts->DynamicAvailable = Token->DynamicAvailable;
+            ts->GroupCount = Token->UserAndGroupCount - 1;
+            ts->PrivilegeCount = Token->PrivilegeCount;
+            ts->ModifiedId = Token->ModifiedId;
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+     
       case TokenOrigin:
-    DPRINT("NtQueryInformationToken(TokenOrigin)\n");
-    if (TokenInformationLength < sizeof(TOKEN_ORIGIN))
-      {
-        Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = 
MmCopyToCaller(&((PTOKEN_ORIGIN)TokenInformation)->OriginatingLogonSession,
-                                &Token->AuthenticationId, sizeof(LUID));
-      }
-    Length = sizeof(TOKEN_ORIGIN);
-    LengthStatus = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-    if (NT_SUCCESS(Status))
-      {
-        Status = LengthStatus;
-      }
-    break;
+      {
+        PTOKEN_ORIGIN to = (PTOKEN_ORIGIN)TokenInformation;
 
+        DPRINT("NtQueryInformationToken(TokenOrigin)\n");
+        RequiredLength = sizeof(TOKEN_ORIGIN);
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            RtlCopyLuid(&to->OriginatingLogonSession,
+                        &Token->AuthenticationId);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        break;
+      }
+
       case TokenGroupsAndPrivileges:
     DPRINT1("NtQueryInformationToken(TokenGroupsAndPrivileges) not 
implemented\n");
     Status = STATUS_NOT_IMPLEMENTED;
@@ -847,27 +1084,44 @@
     break;
 
       case TokenSessionId:
-    DPRINT("NtQueryInformationToken(TokenSessionId)\n");
-    if (TokenInformationLength < sizeof(ULONG))
-      {
-        Length = sizeof(ULONG);
-        Status = MmCopyToCaller(ReturnLength, &Length, sizeof(ULONG));
-        if (NT_SUCCESS(Status))
-          Status = STATUS_BUFFER_TOO_SMALL;
-      }
-    else
-      {
-        Status = MmCopyToCaller(TokenInformation, &Token->SessionId, 
sizeof(ULONG));
-      }
-    break;
+      {
+        ULONG SessionId = 0;
 
+        DPRINT("NtQueryInformationToken(TokenSessionId)\n");
+       
+        Status = SeQuerySessionIdToken(Token,
+                                       &SessionId);
+
+        if(NT_SUCCESS(Status))
+        {
+          _SEH_TRY
+          {
+            /* buffer size was already verified, no need to check here 
again */
+            *(PULONG)TokenInformation = SessionId;
+
+            if(ReturnLength != NULL)
+            {
+              *ReturnLength = sizeof(ULONG);
+            }
+          }
+          _SEH_HANDLE
+          {
+            Status = _SEH_GetExceptionCode();
+          }
+          _SEH_END;
+        }
+       
+        break;
+      }
+
       default:
-    DPRINT1("NtQueryInformationToken(%d) invalid parameter\n");
-    Status = STATUS_INVALID_PARAMETER;
+    DPRINT1("NtQueryInformationToken(%d) invalid information class\n", 
TokenInformationClass);
+    Status = STATUS_INVALID_INFO_CLASS;
     break;
     }
 
-  ObDereferenceObject(Token);
+    ObDereferenceObject(Token);
+  }
 
   return(Status);
 }
@@ -888,7 +1142,7 @@
 }
 
 /*
- * @unimplemented
+ * @implemented
  */
 NTSTATUS
 STDCALL
@@ -897,14 +1151,14 @@
     IN PULONG pSessionId
     )
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+  *pSessionId = ((PTOKEN)Token)->SessionId;
+  return STATUS_SUCCESS;
 }
 
 /*
  * NtSetTokenInformation: Partly implemented.
  * Unimplemented:
- *  TokenOrigin, TokenDefaultDacl, TokenSessionId
+ *  TokenOrigin, TokenDefaultDacl
  */
 
 NTSTATUS STDCALL
@@ -913,123 +1167,229 @@
               OUT PVOID TokenInformation,
               IN ULONG TokenInformationLength)
 {
-  NTSTATUS Status;
   PTOKEN Token;
-  TOKEN_OWNER TokenOwnerSet = { 0 };
-  TOKEN_PRIMARY_GROUP TokenPrimaryGroupSet = { 0 };
-  DWORD NeededAccess = 0;
+  KPROCESSOR_MODE PreviousMode;
+  ULONG NeededAccess = TOKEN_ADJUST_DEFAULT;
+  NTSTATUS Status = STATUS_SUCCESS;
  
   PAGED_CODE();
+ 
+  PreviousMode = ExGetPreviousMode();
+ 
+  DefaultSetInfoBufferCheck(TokenInformationClass,
+                            SeTokenInformationClass,
+                            TokenInformation,
+                            TokenInformationLength,
+                            PreviousMode,
+                            &Status);
 
-  switch (TokenInformationClass)
-    {
-    case TokenOwner:
-    case TokenPrimaryGroup:
-      NeededAccess = TOKEN_ADJUST_DEFAULT;
-      break;
+  if(!NT_SUCCESS(Status))
+  {
+    /* Invalid buffers */
+    DPRINT("NtSetInformationToken() failed, Status: 0x%x\n", Status);
+    return Status;
+  }
+ 
+  if(TokenInformationClass == TokenSessionId)
+  {
+    NeededAccess |= TOKEN_ADJUST_SESSIONID;
+  }
 
-    case TokenDefaultDacl:
-      if (TokenInformationLength < sizeof(TOKEN_DEFAULT_DACL))
-        return STATUS_BUFFER_TOO_SMALL;
-      NeededAccess = TOKEN_ADJUST_DEFAULT;
-      break;
-
-    default:
-      DPRINT1("NtSetInformationToken: lying about success (stub) - %x\n", TokenInformationClass);
-      return STATUS_SUCCESS; 
-
-    }
-
   Status = ObReferenceObjectByHandle(TokenHandle,
                      NeededAccess,
                      SepTokenObjectType,
-                     UserMode,
+                     PreviousMode,
                      (PVOID*)&Token,
                      NULL);
-  if (!NT_SUCCESS(Status))
+  if (NT_SUCCESS(Status))
+  {
+    switch (TokenInformationClass)
     {
-      return(Status);
-    }
-
-  switch (TokenInformationClass)
-    {
-    case TokenOwner:
-      MmCopyFromCaller( &TokenOwnerSet, TokenInformation,
-            min(sizeof(TokenOwnerSet),TokenInformationLength) );
-      RtlCopySid(TokenInformationLength - sizeof(TOKEN_OWNER),
-         Token->UserAndGroups[Token->DefaultOwnerIndex].Sid,
-         TokenOwnerSet.Owner);
-      Status = STATUS_SUCCESS;
-      DPRINT("NtSetInformationToken(TokenOwner)\n");
-      break;
+      case TokenOwner:
+      {
+        if(TokenInformationLength >= sizeof(TOKEN_OWNER))
+        {
+          PTOKEN_OWNER to = (PTOKEN_OWNER)TokenInformation;
+          PSID InputSid = NULL;
+         
+          _SEH_TRY
+          {
+            InputSid = to->Owner;
+          }
+          _SEH_HANDLE
+          {
+            Status = _SEH_GetExceptionCode();
+          }
+          _SEH_END;
+         
+          if(NT_SUCCESS(Status))
+          {
+            PSID CapturedSid;
+           
+            Status = SepCaptureSid(InputSid,
+                                   PreviousMode,
+                                   PagedPool,
+                                   FALSE,
+                                   &CapturedSid);
+            if(NT_SUCCESS(Status))
+            {
+              RtlCopySid(RtlLengthSid(CapturedSid),
+                         Token->UserAndGroups[Token->DefaultOwnerIndex].Sid,
+                         CapturedSid);
+              SepReleaseSid(CapturedSid,
+                            PreviousMode,
+                            FALSE);
+            }
+          }
+        }
+        else
+        {
+          Status = STATUS_INFO_LENGTH_MISMATCH;
+        }
+        break;
+      }
      
-    case TokenPrimaryGroup:
-      MmCopyFromCaller( &TokenPrimaryGroupSet, TokenInformation,
-            min(sizeof(TokenPrimaryGroupSet),
-                TokenInformationLength) );
-      RtlCopySid(TokenInformationLength - sizeof(TOKEN_PRIMARY_GROUP),
-         Token->PrimaryGroup,
-         TokenPrimaryGroupSet.PrimaryGroup);
-      Status = STATUS_SUCCESS;
-      DPRINT("NtSetInformationToken(TokenPrimaryGroup),"
-         "Token->PrimaryGroup = 0x%08x\n", Token->PrimaryGroup);
-      break;
-
-    case TokenDefaultDacl:
+      case TokenPrimaryGroup:
       {
-        TOKEN_DEFAULT_DACL TokenDefaultDacl = { 0 };
-        ACL OldAcl;
-        PACL NewAcl;
+        if(TokenInformationLength >= sizeof(TOKEN_PRIMARY_GROUP))
+        {
+          PTOKEN_PRIMARY_GROUP tpg = (PTOKEN_PRIMARY_GROUP)TokenInformation;
+          PSID InputSid = NULL;
 
-        Status = MmCopyFromCaller( &TokenDefaultDacl, TokenInformation,
-                                   sizeof(TOKEN_DEFAULT_DACL) );
-        if (!NT_SUCCESS(Status))
+          _SEH_TRY
           {
-            Status = STATUS_INVALID_PARAMETER;
-            break;
+            InputSid = tpg->PrimaryGroup;
           }
+          _SEH_HANDLE
+          {
+            Status = _SEH_GetExceptionCode();
+          }
+          _SEH_END;
 
-        Status = MmCopyFromCaller( &OldAcl, TokenDefaultDacl.DefaultDacl,
-                                   sizeof(ACL) );
-        if (!NT_SUCCESS(Status))
+          if(NT_SUCCESS(Status))
           {
-            Status = STATUS_INVALID_PARAMETER;
-            break;
+            PSID CapturedSid;
+
+            Status = SepCaptureSid(InputSid,
+                                   PreviousMode,
+                                   PagedPool,
+                                   FALSE,
+                                   &CapturedSid);
+            if(NT_SUCCESS(Status))
+            {
+              RtlCopySid(RtlLengthSid(CapturedSid),
+                         Token->PrimaryGroup,
+                         CapturedSid);
+              SepReleaseSid(CapturedSid,
+                            PreviousMode,
+                            FALSE);
+            }
           }
+        }
+        else
+        {
+          Status = STATUS_INFO_LENGTH_MISMATCH;
+        }
+        break;
+      }
+     
+      case TokenDefaultDacl:
+      {
+        if(TokenInformationLength >= sizeof(TOKEN_DEFAULT_DACL))
+        {
+          PTOKEN_DEFAULT_DACL tdd = (PTOKEN_DEFAULT_DACL)TokenInformation;
+          PACL InputAcl = NULL;
 
-        NewAcl = ExAllocatePool(NonPagedPool, sizeof(ACL));
-        if (NewAcl == NULL)
+          _SEH_TRY
           {
-            Status = STATUS_INSUFFICIENT_RESOURCES;
-            break;
+            InputAcl = tdd->DefaultDacl;
           }
+          _SEH_HANDLE
+          {
+            Status = _SEH_GetExceptionCode();
+          }
+          _SEH_END;
 
-        Status = MmCopyFromCaller( NewAcl, TokenDefaultDacl.DefaultDacl,
-                                   OldAcl.AclSize );
-        if (!NT_SUCCESS(Status))
+          if(NT_SUCCESS(Status))
           {
-            Status = STATUS_INVALID_PARAMETER;
-            ExFreePool(NewAcl);
-            break;
+            if(InputAcl != NULL)
+            {
+              PACL CapturedAcl;
+
+              /* capture and copy the dacl */
+              Status = SepCaptureAcl(InputAcl,
+                                     PreviousMode,
+                                     PagedPool,
+                                     TRUE,
+                                     &CapturedAcl);
+              if(NT_SUCCESS(Status))
+              {
+                /* free the previous dacl if present */
+                if(Token->DefaultDacl != NULL)
+                {
+                  ExFreePool(Token->DefaultDacl);
+                }
+               
+                /* set the new dacl */
+                Token->DefaultDacl = CapturedAcl;
+              }
+            }
+            else
+            {
+              /* clear and free the default dacl if present */
+              if(Token->DefaultDacl != NULL)
+              {
+                ExFreePool(Token->DefaultDacl);
+                Token->DefaultDacl = NULL;
+              }
+            }
           }
+        }
+        else
+        {
+          Status = STATUS_INFO_LENGTH_MISMATCH;
+        }
+        break;
+      }
+     
+      case TokenSessionId:
+      {
+        ULONG SessionId = 0;
 
-        if (Token->DefaultDacl)
+        _SEH_TRY
+        {
+          /* buffer size was already verified, no need to check here again */
+          SessionId = *(PULONG)TokenInformation;
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
+        if(NT_SUCCESS(Status))
+        {
+          if(!SeSinglePrivilegeCheck(SeTcbPrivilege,
+                                     PreviousMode))
           {
-            ExFreePool(Token->DefaultDacl);
+            Status = STATUS_PRIVILEGE_NOT_HELD;
+            break;
           }
 
-        Token->DefaultDacl = NewAcl;
+          Token->SessionId = SessionId;
+        }
+        break;
+      }
 
-        Status = STATUS_SUCCESS;
+      default:
+      {
+        Status = STATUS_NOT_IMPLEMENTED;
         break;
       }
-
-    default:
-      Status = STATUS_NOT_IMPLEMENTED;
-      break;
     }
 
-  ObDereferenceObject(Token);
+    ObDereferenceObject(Token);
+  }
 
   return(Status);
 }
@@ -1045,16 +1405,18 @@
  */
 NTSTATUS STDCALL
 NtDuplicateToken(IN HANDLE ExistingTokenHandle,
-         IN ACCESS_MASK DesiredAccess,
-       IN POBJECT_ATTRIBUTES ObjectAttributes OPTIONAL /*is it really optional?*/,
-       IN BOOLEAN EffectiveOnly,
-         IN TOKEN_TYPE TokenType,
-         OUT PHANDLE NewTokenHandle)
+                 IN ACCESS_MASK DesiredAccess,
+                 IN POBJECT_ATTRIBUTES ObjectAttributes  OPTIONAL,
+                 IN BOOLEAN EffectiveOnly,
+                 IN TOKEN_TYPE TokenType,
+                 OUT PHANDLE NewTokenHandle)
 {
   KPROCESSOR_MODE PreviousMode;
   HANDLE hToken;
   PTOKEN Token;
   PTOKEN NewToken;
+  PSECURITY_QUALITY_OF_SERVICE CapturedSecurityQualityOfService;
+  BOOLEAN QoSPresent;
   NTSTATUS Status = STATUS_SUCCESS;
  
   PAGED_CODE();
@@ -1081,57 +1443,66 @@
     }
   }
  
+  Status = SepCaptureSecurityQualityOfService(ObjectAttributes,
+                                              PreviousMode,
+                                              PagedPool,
+                                              FALSE,
+                                              &CapturedSecurityQualityOfService,
+                                              &QoSPresent);
+  if(!NT_SUCCESS(Status))
+  {
+    DPRINT1("NtDuplicateToken() failed to capture QoS! Status: 0x%x\n", Status);
+    return Status;
+  }
+ 
   Status = ObReferenceObjectByHandle(ExistingTokenHandle,
                      TOKEN_DUPLICATE,
                      SepTokenObjectType,
                      PreviousMode,
                      (PVOID*)&Token,
                      NULL);
-  if (!NT_SUCCESS(Status))
-    {
-      DPRINT1("Failed to reference token (Status %lx)\n", Status);
-      return Status;
-    }
+  if (NT_SUCCESS(Status))
+  {
+    Status = SepDuplicateToken(Token,
+                               ObjectAttributes,
+                               EffectiveOnly,
+                               TokenType,
+                               (QoSPresent ? CapturedSecurityQualityOfService->ImpersonationLevel : SecurityAnonymous),
+                     PreviousMode,
+                     &NewToken);
 
-  Status = SepDuplicateToken(Token,
-                 ObjectAttributes,
-                 EffectiveOnly,
-                 TokenType,
-              ObjectAttributes->SecurityQualityOfService ?
-                  ((PSECURITY_QUALITY_OF_SERVICE)(ObjectAttributes->SecurityQualityOfService))->ImpersonationLevel :
-                  0 /*SecurityAnonymous*/,
-                 PreviousMode,
-                 &NewToken);
+    ObDereferenceObject(Token);
 
-  ObDereferenceObject(Token);
-
-  if (!NT_SUCCESS(Status))
+    if (NT_SUCCESS(Status))
     {
-      DPRINT1("Failed to duplicate token (Status %lx)\n", Status);
-      return Status;
-    }
+      Status = ObInsertObject((PVOID)NewToken,
+                  NULL,
+                  DesiredAccess,
+                  0,
+                  NULL,
+                  &hToken);
 
-  Status = ObInsertObject((PVOID)NewToken,
-              NULL,
-              DesiredAccess,
-              0,
-              NULL,
-              &hToken);
+      ObDereferenceObject(NewToken);
 
-  ObDereferenceObject(NewToken);
-
-  if (NT_SUCCESS(Status))
-    {
-      _SEH_TRY
+      if (NT_SUCCESS(Status))
       {
-        *NewTokenHandle = hToken;
+        _SEH_TRY
+        {
+          *NewTokenHandle = hToken;
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
       }
-      _SEH_HANDLE
-      {
-        Status = _SEH_GetExceptionCode();
-      }
-      _SEH_END;
     }
+  }
+ 
+  /* free the captured structure */
+  SepReleaseSecurityQualityOfService(CapturedSecurityQualityOfService,
+                                     PreviousMode,
+                                     FALSE);
 
   return Status;
 }
@@ -1408,6 +1779,8 @@
   NTSTATUS Status;
   ULONG uSize;
   ULONG i;
+ 
+  PAGED_CODE();
 
   ULONG uLocalSystemLength = RtlLengthSid(SeLocalSystemSid);
   ULONG uWorldLength       = RtlLengthSid(SeWorldSid);
@@ -1456,6 +1829,8 @@
       return Status;
     }
 
+  AccessToken->TokenLock = &SepTokenLock;
+
   AccessToken->TokenType = TokenPrimary;
   AccessToken->ImpersonationLevel = SecurityDelegation;
   AccessToken->TokenSource.SourceIdentifier.LowPart = 0;
@@ -1706,6 +2081,8 @@
       return(Status);
     }
 
+  AccessToken->TokenLock = &SepTokenLock;
+
   RtlCopyLuid(&AccessToken->TokenSource.SourceIdentifier,
           &TokenSource->SourceIdentifier);
   memcpy(AccessToken->TokenSource.SourceName,
@@ -1738,7 +2115,7 @@
     uLength += RtlLengthSid(TokenGroups->Groups[i].Sid);
 
   AccessToken->UserAndGroups =
-    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                            uLength,
                            TAG('T', 'O', 'K', 'u'));
 
@@ -1774,7 +2151,7 @@
     {
       uLength = TokenPrivileges->PrivilegeCount * sizeof(LUID_AND_ATTRIBUTES);
       AccessToken->Privileges =
-    (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+    (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                             uLength,
                             TAG('T', 'O', 'K', 'p'));
 
@@ -1791,7 +2168,7 @@
   if (NT_SUCCESS(Status))
     {
       AccessToken->DefaultDacl =
-    (PACL) ExAllocatePoolWithTag(NonPagedPool,
+    (PACL) ExAllocatePoolWithTag(PagedPool,
                      TokenDefaultDacl->DefaultDacl->AclSize,
                      TAG('T', 'O', 'K', 'd'));
       memcpy(AccessToken->DefaultDacl,
Index: ntoskrnl/se/sd.c
===================================================================
--- ntoskrnl/se/sd.c    (.../trunk/reactos)    (revision 13937)
+++ ntoskrnl/se/sd.c    (.../branches/alex_devel_branch/reactos)    (revision 13942)
@@ -108,6 +108,174 @@
   return TRUE;
 }
 
+
+NTSTATUS
+SepCaptureSecurityQualityOfService(IN POBJECT_ATTRIBUTES ObjectAttributes  OPTIONAL,
+                                   IN KPROCESSOR_MODE AccessMode,
+                                   IN POOL_TYPE PoolType,
+                                   IN BOOLEAN CaptureIfKernel,
+                                   OUT PSECURITY_QUALITY_OF_SERVICE *CapturedSecurityQualityOfService,
+                                   OUT PBOOLEAN Present)
+{
+  PSECURITY_QUALITY_OF_SERVICE CapturedQos;
+  NTSTATUS Status = STATUS_SUCCESS;
+ 
+  PAGED_CODE();
+
+  ASSERT(CapturedSecurityQualityOfService);
+  ASSERT(Present);
+
+  if(ObjectAttributes != NULL)
+  {
+    if(AccessMode != KernelMode)
+    {
+      SECURITY_QUALITY_OF_SERVICE SafeQos;
+
+      _SEH_TRY
+      {
+        ProbeForRead(ObjectAttributes,
+                     sizeof(ObjectAttributes),
+                     sizeof(ULONG));
+        if(ObjectAttributes->Length == sizeof(OBJECT_ATTRIBUTES))
+        {
+          if(ObjectAttributes->SecurityQualityOfService != NULL)
+          {
+            ProbeForRead(ObjectAttributes->SecurityQualityOfService,
+                         sizeof(SECURITY_QUALITY_OF_SERVICE),
+                         sizeof(ULONG));
+
+            if(((PSECURITY_QUALITY_OF_SERVICE)ObjectAttributes->SecurityQualityOfService)->Length ==
+               sizeof(SECURITY_QUALITY_OF_SERVICE))
+            {
+              /* don't allocate memory here because ExAllocate should bugcheck
+                 the system if it's buggy, SEH would catch that! So make a local
+                 copy of the qos structure.*/
+              RtlCopyMemory(&SafeQos,
+                            ObjectAttributes->SecurityQualityOfService,
+                            sizeof(SECURITY_QUALITY_OF_SERVICE));
+              *Present = TRUE;
+            }
+            else
+            {
+              Status = STATUS_INVALID_PARAMETER;
+            }
+          }
+          else
+          {
+            *CapturedSecurityQualityOfService = NULL;
+            *Present = FALSE;
+          }
+        }
+        else
+        {
+          Status = STATUS_INVALID_PARAMETER;
+        }
+      }
+      _SEH_HANDLE
+      {
+        Status = _SEH_GetExceptionCode();
+      }
+      _SEH_END;
+
+      if(NT_SUCCESS(Status))
+      {
+        if(*Present)
+        {
+          CapturedQos = ExAllocatePool(PoolType,
+                                       sizeof(SECURITY_QUALITY_OF_SERVICE));
+          if(CapturedQos != NULL)
+          {
+            RtlCopyMemory(CapturedQos,
+                          &SafeQos,
+                          sizeof(SECURITY_QUALITY_OF_SERVICE));
+            *CapturedSecurityQualityOfService = CapturedQos;
+          }
+          else
+          {
+            Status = STATUS_INSUFFICIENT_RESOURCES;
+          }
+        }
+        else
+        {
+          *CapturedSecurityQualityOfService = NULL;
+        }
+      }
+    }
+    else
+    {
+      if(ObjectAttributes->Length == sizeof(OBJECT_ATTRIBUTES))
+      {
+        if(CaptureIfKernel)
+        {
+          if(ObjectAttributes->SecurityQualityOfService != NULL)
+          {
+            if(((PSECURITY_QUALITY_OF_SERVICE)ObjectAttributes->SecurityQualityOfService)->Length ==
+               sizeof(SECURITY_QUALITY_OF_SERVICE))
+            {
+              CapturedQos = ExAllocatePool(PoolType,
+                                           sizeof(SECURITY_QUALITY_OF_SERVICE));
+              if(CapturedQos != NULL)
+              {
+                RtlCopyMemory(CapturedQos,
+                              ObjectAttributes->SecurityQualityOfService,
+                              sizeof(SECURITY_QUALITY_OF_SERVICE));
+                *CapturedSecurityQualityOfService = CapturedQos;
+                *Present = TRUE;
+              }
+              else
+              {
+                Status = STATUS_INSUFFICIENT_RESOURCES;
+              }
+            }
+            else
+            {
+              Status = STATUS_INVALID_PARAMETER;
+            }
+          }
+          else
+          {
+            *CapturedSecurityQualityOfService = NULL;
+            *Present = FALSE;
+          }
+        }
+        else
+        {
+          *CapturedSecurityQualityOfService = (PSECURITY_QUALITY_OF_SERVICE)ObjectAttributes->SecurityQualityOfService;
+          *Present = (ObjectAttributes->SecurityQualityOfService != NULL);
+        }
+      }
+      else
+      {
+        Status = STATUS_INVALID_PARAMETER;
+      }
+    }
+  }
+  else
+  {
+    *CapturedSecurityQualityOfService = NULL;
+    *Present = FALSE;
+  }
+
+  return Status;
+}
+
+
+VOID
+SepReleaseSecurityQualityOfService(IN PSECURITY_QUALITY_OF_SERVICE CapturedSecurityQualityOfService  OPTIONAL,
+                                   IN KPROCESSOR_MODE AccessMode,
+                                   IN BOOLEAN CaptureIfKernel)
+{
+  PAGED_CODE();
+ 
+  if(CapturedSecurityQualityOfService != NULL &&
+     (AccessMode == UserMode ||
+      (AccessMode == KernelMode && CaptureIfKernel)))
+  {
+    ExFreePool(CapturedSecurityQualityOfService);
+  }
+}
+
+
 /*
  * @implemented
  */
@@ -129,6 +297,8 @@
   ULONG DescriptorSize = 0;
   NTSTATUS Status = STATUS_SUCCESS;
  
+  PAGED_CODE();
+ 
   if(OriginalSecurityDescriptor != NULL)
   {
     if(CurrentMode != KernelMode)
@@ -144,39 +314,40 @@
                      DescriptorSize,
                      sizeof(ULONG));
 
-        if(OriginalSecurityDescriptor->Revision != SECURITY_DESCRIPTOR_REVISION1)
+        if(OriginalSecurityDescriptor->Revision == SECURITY_DESCRIPTOR_REVISION1)
         {
-          Status = STATUS_UNKNOWN_REVISION;
-          _SEH_LEAVE;
-        }
-       
-        /* make a copy on the stack */
-        DescriptorCopy.Revision = OriginalSecurityDescriptor->Revision;
-        DescriptorCopy.Sbz1 = OriginalSecurityDescriptor->Sbz1;
-        DescriptorCopy.Control = OriginalSecurityDescriptor->Control;
-        DescriptorSize = ((DescriptorCopy.Control & SE_SELF_RELATIVE) ?
-                          sizeof(SECURITY_DESCRIPTOR_RELATIVE) : sizeof(SECURITY_DESCRIPTOR));
+          /* make a copy on the stack */
+          DescriptorCopy.Revision = OriginalSecurityDescriptor->Revision;
+          DescriptorCopy.Sbz1 = OriginalSecurityDescriptor->Sbz1;
+          DescriptorCopy.Control = OriginalSecurityDescriptor->Control;
+          DescriptorSize = ((DescriptorCopy.Control & SE_SELF_RELATIVE) ?
+                            sizeof(SECURITY_DESCRIPTOR_RELATIVE) : sizeof(SECURITY_DESCRIPTOR));
 
-        /* probe and copy the entire security descriptor structure. The SIDs
-           and ACLs will be probed and copied later though */
-        ProbeForRead(OriginalSecurityDescriptor,
-                     DescriptorSize,
-                     sizeof(ULONG));
-        if(DescriptorCopy.Control & SE_SELF_RELATIVE)
-        {
-          PSECURITY_DESCRIPTOR_RELATIVE RelSD = (PSECURITY_DESCRIPTOR_RELATIVE)OriginalSecurityDescriptor;
-         
-          DescriptorCopy.Owner = (PSID)RelSD->Owner;
-          DescriptorCopy.Group = (PSID)RelSD->Group;
-          DescriptorCopy.Sacl = (PACL)RelSD->Sacl;
-          DescriptorCopy.Dacl = (PACL)RelSD->Dacl;
+          /* probe and copy the entire security descriptor structure. The SIDs
+             and ACLs will be probed and copied later though */
+          ProbeForRead(OriginalSecurityDescriptor,
+                       DescriptorSize,
+                       sizeof(ULONG));
+          if(DescriptorCopy.Control & SE_SELF_RELATIVE)
+          {
+            PSECURITY_DESCRIPTOR_RELATIVE RelSD = (PSECURITY_DESCRIPTOR_RELATIVE)OriginalSecurityDescriptor;
+
+            DescriptorCopy.Owner = (PSID)RelSD->Owner;
+            DescriptorCopy.Group = (PSID)RelSD->Group;
+            DescriptorCopy.Sacl = (PACL)RelSD->Sacl;
+            DescriptorCopy.Dacl = (PACL)RelSD->Dacl;
+          }
+          else
+          {
+            DescriptorCopy.Owner = OriginalSecurityDescriptor->Owner;
+            DescriptorCopy.Group = OriginalSecurityDescriptor->Group;
+            DescriptorCopy.Sacl = OriginalSecurityDescriptor->Sacl;
+            DescriptorCopy.Dacl = OriginalSecurityDescriptor->Dacl;
+          }
         }
         else
         {
-          DescriptorCopy.Owner = OriginalSecurityDescriptor->Owner;
-          DescriptorCopy.Group = OriginalSecurityDescriptor->Group;
-          DescriptorCopy.Sacl = OriginalSecurityDescriptor->Sacl;
-          DescriptorCopy.Dacl = OriginalSecurityDescriptor->Dacl;
+          Status = STATUS_UNKNOWN_REVISION;
         }
       }
       _SEH_HANDLE
@@ -572,6 +743,8 @@
     IN BOOLEAN CaptureIfKernelMode
     )
 {
+  PAGED_CODE();
+ 
   /* WARNING! You need to call this function with the same value for CurrentMode
               and CaptureIfKernelMode that you previously passed to
               SeCaptureSecurityDescriptor() in order to avoid memory 
leaks! */




More information about the Ros-diffs mailing list