Logo Search packages:      
Sourcecode: r-cran-multicore version File versions  Download package

forknt.c

/* Implementation of COW fork() using NTDLL API for Windows systems

   (C)Copyright 2009 Simon Urbanek <simon.urbanek@r-project.org>

   This code is partially based on the book
   "Windows NT/2000 Native API Reference" by Gary Nebbett
   (Sams Publishing, 2000, ISBN 1-57870-199-6)

 */

#ifdef WIN32

#include <windows.h>
#include <setjmp.h>
#include <stdlib.h>

/* winternl.h is not part of MinGW so we have to declare whatever is needed */

#pragma mark ntdll API types

/* Native-API status code (signed 32-bit, as LONG). */
typedef LONG NTSTATUS;

/* One entry of the handle table that ZwQuerySystemInformation() returns for
   SystemHandleInformation. set_inherit_all() (INHERIT_ALL builds) compares
   ProcessId with the current process id and passes Handle to
   SetHandleInformation().
   NOTE(review): the field layout must match what the kernel writes; it appears
   to target 32-bit Windows - confirm before reusing elsewhere. */
typedef struct _SYSTEM_HANDLE_INFORMATION {
      ULONG ProcessId;
      UCHAR ObjectTypeNumber;
      UCHAR Flags;
      USHORT Handle;
      PVOID Object;
      ACCESS_MASK GrantedAccess;
} SYSTEM_HANDLE_INFORMATION, *PSYSTEM_HANDLE_INFORMATION;

/* Object attributes passed to ZwCreateProcess()/ZwCreateThread().
   fork() only sets Length and leaves every other member zero. */
typedef struct _OBJECT_ATTRIBUTES {
      ULONG Length;
      HANDLE RootDirectory;
      PVOID /* really PUNICODE_STRING */  ObjectName;
      ULONG Attributes;
      PVOID SecurityDescriptor;       /* type SECURITY_DESCRIPTOR */
      PVOID SecurityQualityOfService; /* type SECURITY_QUALITY_OF_SERVICE */
} OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES;

/* Information classes for ZwQueryVirtualMemory(). The enumerators are numbered
   implicitly from 0, so their order is significant. Only
   MemoryBasicInformation (0) is used, by fork(). */
typedef enum _MEMORY_INFORMATION_{
      MemoryBasicInformation,
      MemoryWorkingSetList,
      MemorySectionName,
      MemoryBasicVlmInformation
} MEMORY_INFORMATION_CLASS;

/* Process/thread id pair filled in by ZwCreateThread(). fork() returns
   UniqueProcess (converted to int) as the child's pid. */
typedef struct _CLIENT_ID {
      HANDLE UniqueProcess;
      HANDLE UniqueThread;
} CLIENT_ID, *PCLIENT_ID;

/* Stack description passed to ZwCreateThread(). fork() leaves the Fixed*
   members zero and fills the Expandable* members from the
   MEMORY_BASIC_INFORMATION of the region that contains the current stack
   pointer. */
typedef struct _USER_STACK {
      PVOID FixedStackBase;
      PVOID FixedStackLimit;
      PVOID ExpandableStackBase;
      PVOID ExpandableStackLimit;
      PVOID ExpandableStackBottom;
} USER_STACK, *PUSER_STACK;

/* Kernel scalar types needed by THREAD_BASIC_INFORMATION below. */
typedef LONG KPRIORITY;
typedef ULONG_PTR KAFFINITY;
typedef KAFFINITY *PKAFFINITY;

/* Result of ZwQueryInformationThread(..., ThreadBasicInformation, ...).
   fork() only reads TebBaseAddress (of the current thread and of the new
   child thread). */
typedef struct _THREAD_BASIC_INFORMATION {
      NTSTATUS                ExitStatus;
      PVOID                   TebBaseAddress;
      CLIENT_ID               ClientId;
      KAFFINITY               AffinityMask;
      KPRIORITY               Priority;
      KPRIORITY               BasePriority;
} THREAD_BASIC_INFORMATION, *PTHREAD_BASIC_INFORMATION;

/* Information classes for ZwQueryInformationThread(). The enumerators are
   numbered implicitly from 0, so their order must not change. Only
   ThreadBasicInformation (0) is used, by fork(). */
typedef enum _THREAD_INFORMATION_CLASS {
      ThreadBasicInformation,
      ThreadTimes,
      ThreadPriority,
      ThreadBasePriority,
      ThreadAffinityMask,
      ThreadImpersonationToken,
      ThreadDescriptorTableEntry,
      ThreadEnableAlignmentFaultFixup,
      ThreadEventPair,
      ThreadQuerySetWin32StartAddress,
      ThreadZeroTlsCell,
      ThreadPerformanceCount,
      ThreadAmILastThread,
      ThreadIdealProcessor,
      ThreadPriorityBoost,
      ThreadSetTlsArrayAddress,
      ThreadIsIoPending,
      ThreadHideFromDebugger
} THREAD_INFORMATION_CLASS, *PTHREAD_INFORMATION_CLASS;

/* Only the one system information class this file needs (value 0x10),
   used by set_inherit_all() to fetch the system handle table. */
typedef enum _SYSTEM_INFORMATION_CLASS { SystemHandleInformation = 0x10 } SYSTEM_INFORMATION_CLASS;

#pragma mark ntdll API - function entry points

/* Prototypes of the ntdll.dll exports used in this file. They are called
   through function pointers that init_NTAPI() resolves at run time.
   The byte counts and lengths of the virtual-memory calls are SIZE_T
   (pointer-sized), as in the native API. On 32-bit Windows SIZE_T has the
   same width as ULONG, so the x86 calling convention is unchanged. */
typedef NTSTATUS (NTAPI *ZwWriteVirtualMemory_t)(IN  HANDLE   ProcessHandle,
                                                 IN  PVOID    BaseAddress,
                                                 IN  PVOID    Buffer,
                                                 IN  SIZE_T   NumberOfBytesToWrite,
                                                 OUT PSIZE_T  NumberOfBytesWritten OPTIONAL);

typedef NTSTATUS (NTAPI *ZwCreateProcess_t)(OUT PHANDLE            ProcessHandle,
                                            IN  ACCESS_MASK        DesiredAccess,
                                            IN  POBJECT_ATTRIBUTES ObjectAttributes,
                                            IN  HANDLE             InheriteFromProcessHandle,
                                            IN  BOOLEAN            InheritHandles,
                                            IN  HANDLE             SectionHandle    OPTIONAL,
                                            IN  HANDLE             DebugPort        OPTIONAL,
                                            IN  HANDLE             ExceptionPort    OPTIONAL);

/* NTAPI (like the other native calls); it is the same __stdcall convention
   as WINAPI. */
typedef NTSTATUS (NTAPI *ZwQuerySystemInformation_t)(IN  SYSTEM_INFORMATION_CLASS SystemInformationClass,
                                                     OUT PVOID                    SystemInformation,
                                                     IN  ULONG                    SystemInformationLength,
                                                     OUT PULONG                   ReturnLength OPTIONAL);

typedef NTSTATUS (NTAPI *ZwQueryVirtualMemory_t)(IN  HANDLE                   ProcessHandle,
                                                 IN  PVOID                    BaseAddress,
                                                 IN  MEMORY_INFORMATION_CLASS MemoryInformationClass,
                                                 OUT PVOID                    MemoryInformation,
                                                 IN  SIZE_T                   MemoryInformationLength,
                                                 OUT PSIZE_T                  ReturnLength OPTIONAL);

typedef NTSTATUS (NTAPI *ZwGetContextThread_t)(IN HANDLE ThreadHandle, OUT PCONTEXT Context);

typedef NTSTATUS (NTAPI *ZwCreateThread_t)(OUT PHANDLE            ThreadHandle,
                                           IN  ACCESS_MASK        DesiredAccess,
                                           IN  POBJECT_ATTRIBUTES ObjectAttributes,
                                           IN  HANDLE             ProcessHandle,
                                           OUT PCLIENT_ID         ClientId,
                                           IN  PCONTEXT           ThreadContext,
                                           IN  PUSER_STACK        UserStack,
                                           IN  BOOLEAN            CreateSuspended);

typedef NTSTATUS (NTAPI *ZwResumeThread_t)(IN HANDLE ThreadHandle, OUT PULONG SuspendCount OPTIONAL);

typedef NTSTATUS (NTAPI *ZwClose_t)(IN HANDLE ObjectHandle);

typedef NTSTATUS (NTAPI *ZwQueryInformationThread_t)(IN  HANDLE                   ThreadHandle,
                                                     IN  THREAD_INFORMATION_CLASS ThreadInformationClass,
                                                     OUT PVOID                    ThreadInformation,
                                                     IN  ULONG                    ThreadInformationLength,
                                                     OUT PULONG                   ReturnLength OPTIONAL);

/* ntdll entry points, resolved once by init_NTAPI() on the first fork().
   fork() treats a non-NULL ZwCreateProcess as "already initialized". */
static ZwCreateProcess_t          ZwCreateProcess;
static ZwGetContextThread_t       ZwGetContextThread;
static ZwQueryVirtualMemory_t     ZwQueryVirtualMemory;
static ZwCreateThread_t           ZwCreateThread;
static ZwQueryInformationThread_t ZwQueryInformationThread;
static ZwWriteVirtualMemory_t     ZwWriteVirtualMemory;
static ZwResumeThread_t           ZwResumeThread;
static ZwClose_t                  ZwClose;
/* only called by set_inherit_all() (INHERIT_ALL builds) */
static ZwQuerySystemInformation_t ZwQuerySystemInformation;

/* macro definitions */

/* Pseudo-handle values meaning "the current process" (-1) and "the current
   thread" (-2). fork() passes them to ZwCreateProcess, ZwGetContextThread,
   ZwQueryVirtualMemory and ZwQueryInformationThread. */
#define NtCurrentProcess() ((HANDLE)-1)
#define NtCurrentThread() ((HANDLE) -2)
/* this file uses the Nt* names; the Zw* aliases below are only for completeness */
#define ZwCurrentProcess() NtCurrentProcess()     
#define ZwCurrentThread() NtCurrentThread()

/* Status codes compared against explicitly. set_inherit_all() treats
   STATUS_INFO_LENGTH_MISMATCH as "buffer too small" and retries with a larger
   buffer. */
#define STATUS_INFO_LENGTH_MISMATCH      ((NTSTATUS)0xC0000004L)
#define STATUS_SUCCESS ((NTSTATUS)0x00000000L)

#pragma mark -- helper functions --

#ifdef INHERIT_ALL
/* set all handles belonging to this process as inheritable */
static void set_inherit_all()
{
      ULONG n = 0x1000;
      PULONG p = (PULONG) calloc(n, sizeof(ULONG));

      /* some guesswork to allocate a structure that will fit it all */
      while (ZwQuerySystemInformation(SystemHandleInformation, p, n * sizeof(ULONG), 0) == STATUS_INFO_LENGTH_MISMATCH) {
            free(p);
            n *= 2;
            p = (PULONG) calloc(n, sizeof(ULONG));
      }
      
      /* p points to an ULONG with the count, the entries follow (hence p[0] is the size and p[1] is where the first entry starts */
      PSYSTEM_HANDLE_INFORMATION h = (PSYSTEM_HANDLE_INFORMATION)(p + 1);

      ULONG pid = GetCurrentProcessId();
      ULONG i = 0, count = *p;

      while (i < count) {
            if (h[i].ProcessId == pid)
                  SetHandleInformation((HANDLE)(ULONG) h[i].Handle, HANDLE_FLAG_INHERIT, HANDLE_FLAG_INHERIT);
            i++;
      }
      free(p);
}
#endif

/* setjmp env for the jump back into the fork() function. It is set by the
   parent in fork(); the child process sees a copy of it in its own memory. */
static jmp_buf jenv;

/* Entry point of the child process's only thread. fork() copies the calling
   thread's context and points its Eip here. The stack pointer is the one
   captured from the parent thread, so this does not run on a fresh call frame.
   It longjmp()s to the setjmp(jenv) at the top of fork(), which then returns
   0 in the child.
   NOTE(review): keep this function minimal because of that unusual stack
   situation. */
static int child_entry(void) {
      longjmp(jenv, 1);
      return 0; /* not reached: longjmp() does not return */
}

/* initialize NTDLL entry points */
static int init_NTAPI(void) {
      HANDLE ntdll = GetModuleHandle("ntdll");
      if (ntdll == NULL) return -1;
      ZwCreateProcess = (ZwCreateProcess_t) GetProcAddress(ntdll, "ZwCreateProcess");
      ZwQuerySystemInformation = (ZwQuerySystemInformation_t) GetProcAddress(ntdll, "ZwQuerySystemInformation");
      ZwQueryVirtualMemory = (ZwQueryVirtualMemory_t) GetProcAddress(ntdll, "ZwQueryVirtualMemory");
      ZwCreateThread = (ZwCreateThread_t) GetProcAddress(ntdll, "ZwCreateThread");
      ZwGetContextThread = (ZwGetContextThread_t) GetProcAddress(ntdll, "ZwGetContextThread");
      ZwResumeThread = (ZwResumeThread_t) GetProcAddress(ntdll, "ZwResumeThread");
      ZwQueryInformationThread = (ZwQueryInformationThread_t) GetProcAddress(ntdll, "ZwQueryInformationThread");
      ZwWriteVirtualMemory = (ZwWriteVirtualMemory_t) GetProcAddress(ntdll, "ZwWriteVirtualMemory");
      ZwClose = (ZwClose_t) GetProcAddress(ntdll, "ZwClose");
      /* in theory we chould check all of them - but I guess that would be a waste of time ... */
      return (!ZwCreateProcess) ? -1 : 0;
}

#pragma mark -- fork() --

int fork(void) {
      if (setjmp(jenv) != 0) return 0; /* return as a child */

      /* check whether the entry points are initilized and get them if necessary */
      if (!ZwCreateProcess && init_NTAPI()) return -1;

#ifdef INHERIT_ALL
      /* make sure all handles are inheritable */
      set_inherit_all();
#endif

      HANDLE hProcess = 0, hThread = 0;
      OBJECT_ATTRIBUTES oa = { sizeof(oa) };

      /* create forked process */
      ZwCreateProcess(&hProcess, PROCESS_ALL_ACCESS, &oa, NtCurrentProcess(), TRUE, 0, 0, 0);

      CONTEXT context = {CONTEXT_FULL | CONTEXT_DEBUG_REGISTERS | CONTEXT_FLOATING_POINT};

      /* set the Eip for the child process to our child function */
      ZwGetContextThread(NtCurrentThread(), &context);
      context.Eip = (ULONG)child_entry;

      MEMORY_BASIC_INFORMATION mbi;
      ZwQueryVirtualMemory(NtCurrentProcess(), (PVOID)context.Esp, MemoryBasicInformation, &mbi, sizeof mbi, 0);

      USER_STACK stack = {0, 0, (PCHAR)mbi.BaseAddress + mbi.RegionSize, mbi.BaseAddress, mbi.AllocationBase};
      CLIENT_ID cid;

      /* create thread using the modified context and stack */
      ZwCreateThread(&hThread, THREAD_ALL_ACCESS, &oa, hProcess, &cid, &context, &stack, TRUE);

      /* copy exception table */
      THREAD_BASIC_INFORMATION tbi;
      ZwQueryInformationThread(NtCurrentThread(), ThreadBasicInformation, &tbi, sizeof tbi, 0);
      PNT_TIB tib = (PNT_TIB)tbi.TebBaseAddress;
      ZwQueryInformationThread(hThread, ThreadBasicInformation, &tbi, sizeof tbi, 0);
      ZwWriteVirtualMemory(hProcess, tbi.TebBaseAddress, &tib->ExceptionList, sizeof tib->ExceptionList, 0);

      /* start (resume really) the child */
      ZwResumeThread(hThread, 0);

      /* clean up */
      ZwClose(hThread);
      ZwClose(hProcess);

      /* exit with child's pid */
      return (int)cid.UniqueProcess;
}


/* Dear Emacs, please be nice and use
   Local Variables:
   mode:c
   tab-width: 4
   c-basic-offset:4
   End:
*/

#else
/* unix has fork() already */
#include <unistd.h>
#endif

Generated by  Doxygen 1.6.0   Back to index