Reactos
1/*
2 * PROJECT: ReactOS kernel-mode tests
3 * LICENSE: LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
4 * PURPOSE: Kernel-Mode Test Suite loader application
5 * COPYRIGHT: Copyright 2011-2018 Thomas Faber <thomas.faber@reactos.org>
6 */
7
8#define KMT_DEFINE_TEST_FUNCTIONS
9#include <kmt_test.h>
10
11#include "kmtest.h"
12#include <kmt_public.h>
13
14#include <assert.h>
15#include <stdio.h>
16#include <stdlib.h>
17
18#define SERVICE_NAME L"Kmtest"
19#define SERVICE_PATH L"kmtest_drv.sys"
20#define SERVICE_DESCRIPTION L"ReactOS Kernel-Mode Test Suite Driver"
21
22#define RESULTBUFFER_SIZE (1024 * 1024)
23
24typedef enum
25{
26 KMT_DO_NOTHING,
27 KMT_LIST_TESTS,
28 KMT_LIST_ALL_TESTS,
29 KMT_RUN_TEST,
30} KMT_OPERATION;
31
32HANDLE KmtestHandle;
33SC_HANDLE KmtestServiceHandle;
34PCSTR ErrorFileAndLine = "No error";
35
36static void OutputError(IN DWORD Error);
37static DWORD ListTests(IN BOOLEAN IncludeHidden);
38static PKMT_TESTFUNC FindTest(IN PCSTR TestName);
39static DWORD OutputResult(IN PCSTR TestName);
40static DWORD RunTest(IN PCSTR TestName);
41int __cdecl main(int ArgCount, char **Arguments);
42
43/**
44 * @name OutputError
45 *
46 * Output an error message to the console.
47 *
48 * @param Error
49 * Win32 error code
50 */
51static
52void
53OutputError(
54 IN DWORD Error)
55{
56 PSTR Message;
57 if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_ALLOCATE_BUFFER,
58 NULL, Error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&Message, 0, NULL))
59 {
60 fprintf(stderr, "%s: Could not retrieve error message (error 0x%08lx). Original error: 0x%08lx\n",
61 ErrorFileAndLine, GetLastError(), Error);
62 return;
63 }
64
65 fprintf(stderr, "%s: error 0x%08lx: %s\n", ErrorFileAndLine, Error, Message);
66
67 LocalFree(Message);
68}
69
70/**
71 * @name CompareTestNames
72 *
73 * strcmp that skips a leading '-' on either string if present
74 *
75 * @param Str1
76 * @param Str2
77 * @return see strcmp
78 */
79static
80INT
81CompareTestNames(
82 IN PCSTR Str1,
83 IN PCSTR Str2)
84{
85 if (*Str1 == '-')
86 ++Str1;
87 if (*Str2 == '-')
88 ++Str2;
89 while (*Str1 && *Str1 == *Str2)
90 {
91 ++Str1;
92 ++Str2;
93 }
94 return *Str1 - *Str2;
95}
96
97/**
98 * @name ListTests
99 *
100 * Output the list of tests to the console.
101 * The list will comprise tests as listed by the driver
102 * in addition to user-mode tests in TestList.
103 *
104 * @param IncludeHidden
105 * TRUE to include "hidden" tests prefixed with a '-'
106 *
107 * @return Win32 error code
108 */
109static
110DWORD
111ListTests(
112 IN BOOLEAN IncludeHidden)
113{
114 DWORD Error = ERROR_SUCCESS;
115 PSTR Buffer = NULL;
116 DWORD BufferSize = 1024;
117 DWORD BytesRead = BufferSize;
118 PCSTR TestName;
119 PCKMT_TEST TestEntry = TestList;
120 PCSTR NextTestName;
121
122 puts("Valid test names:");
123
124 // get test list from driver
125 while (TRUE)
126 {
127 Buffer = HeapAlloc(GetProcessHeap(), 0, BufferSize);
128 if (!Buffer)
129 error_goto(Error, cleanup);
130
131 if (!DeviceIoControl(KmtestHandle, IOCTL_KMTEST_GET_TESTS, NULL, 0, Buffer, BufferSize, &BytesRead, NULL))
132 error_goto(Error, cleanup);
133 if (BytesRead < BufferSize)
134 break;
135
136 HeapFree(GetProcessHeap(), 0, Buffer);
137 BufferSize *= 2;
138 }
139
140 // output test list plus user-mode tests
141 TestName = Buffer;
142 while (TestEntry->TestName || *TestName)
143 {
144 if (!TestEntry->TestName)
145 {
146 NextTestName = TestName;
147 TestName += strlen(TestName) + 1;
148 }
149 else if (!*TestName)
150 {
151 NextTestName = TestEntry->TestName;
152 ++TestEntry;
153 }
154 else
155 {
156 INT Result = CompareTestNames(TestEntry->TestName, TestName);
157
158 if (Result == 0)
159 {
160 NextTestName = TestEntry->TestName;
161 TestName += strlen(TestName) + 1;
162 ++TestEntry;
163 }
164 else if (Result < 0)
165 {
166 NextTestName = TestEntry->TestName;
167 ++TestEntry;
168 }
169 else
170 {
171 NextTestName = TestName;
172 TestName += strlen(TestName) + 1;
173 }
174 }
175
176 if (IncludeHidden && NextTestName[0] == '-')
177 ++NextTestName;
178
179 if (NextTestName[0] != '-')
180 printf(" %s\n", NextTestName);
181 }
182
183cleanup:
184 if (Buffer)
185 HeapFree(GetProcessHeap(), 0, Buffer);
186
187 return Error;
188}
189
190/**
191 * @name FindTest
192 *
193 * Find a test in TestList by name.
194 *
195 * @param TestName
196 * Name of the test to look for. Case sensitive
197 *
198 * @return pointer to test function, or NULL if not found
199 */
200static
201PKMT_TESTFUNC
202FindTest(
203 IN PCSTR TestName)
204{
205 PCKMT_TEST TestEntry = TestList;
206
207 for (TestEntry = TestList; TestEntry->TestName; ++TestEntry)
208 {
209 PCSTR TestEntryName = TestEntry->TestName;
210
211 // skip leading '-' if present
212 if (*TestEntryName == '-')
213 ++TestEntryName;
214
215 if (!lstrcmpA(TestEntryName, TestName))
216 break;
217 }
218
219 return TestEntry->TestFunction;
220}
221
222/**
223 * @name OutputResult
224 *
225 * Output the test results in ResultBuffer to the console.
226 *
227 * @param TestName
228 * Name of the test whose result is to be printed
229 *
230 * @return Win32 error code
231 */
232static
233DWORD
234OutputResult(
235 IN PCSTR TestName)
236{
237 DWORD Error = ERROR_SUCCESS;
238 DWORD BytesWritten;
239 DWORD LogBufferLength;
240 DWORD Offset = 0;
241 /* A console window can't handle a single
242 * huge block of data, so split it up */
243 const DWORD BlockSize = 8 * 1024;
244
245 KmtFinishTest(TestName);
246
247 LogBufferLength = ResultBuffer->LogBufferLength;
248 for (Offset = 0; Offset < LogBufferLength; Offset += BlockSize)
249 {
250 DWORD Length = min(LogBufferLength - Offset, BlockSize);
251 if (!WriteFile(GetStdHandle(STD_OUTPUT_HANDLE), ResultBuffer->LogBuffer + Offset, Length, &BytesWritten, NULL))
252 error(Error);
253 }
254
255 return Error;
256}
257
258/**
259 * @name RunTest
260 *
261 * Run the named test and output its results.
262 *
263 * @param TestName
264 * Name of the test to run. Case sensitive
265 *
266 * @return Win32 error code
267 */
268static
269DWORD
270RunTest(
271 IN PCSTR TestName)
272{
273 DWORD Error = ERROR_SUCCESS;
274 PKMT_TESTFUNC TestFunction;
275 DWORD BytesRead;
276
277 assert(TestName != NULL);
278
279 if (!ResultBuffer)
280 {
281 ResultBuffer = KmtAllocateResultBuffer(RESULTBUFFER_SIZE);
282 if (!ResultBuffer)
283 error_goto(Error, cleanup);
284 if (!DeviceIoControl(KmtestHandle, IOCTL_KMTEST_SET_RESULTBUFFER, ResultBuffer, RESULTBUFFER_SIZE, NULL, 0, &BytesRead, NULL))
285 error_goto(Error, cleanup);
286 }
287
288 // check test list
289 TestFunction = FindTest(TestName);
290
291 if (TestFunction)
292 {
293 TestFunction();
294 goto cleanup;
295 }
296
297 // not found in user-mode test list, call driver
298 Error = KmtRunKernelTest(TestName);
299
300cleanup:
301 if (!Error)
302 Error = OutputResult(TestName);
303
304 return Error;
305}
306
307/**
308 * @name main
309 *
310 * Program entry point
311 *
312 * @param ArgCount
313 * @param Arguments
314 *
315 * @return EXIT_SUCCESS on success, EXIT_FAILURE on failure
316 */
317int
318main(
319 int ArgCount,
320 char **Arguments)
321{
322 INT Status = EXIT_SUCCESS;
323 DWORD Error = ERROR_SUCCESS;
324 PCSTR AppName = "kmtest.exe";
325 PCSTR TestName = NULL;
326 KMT_OPERATION Operation = KMT_DO_NOTHING;
327 BOOLEAN ShowHidden = FALSE;
328
329 Error = KmtServiceInit();
330 if (Error)
331 goto cleanup;
332
333 if (ArgCount >= 1)
334 AppName = Arguments[0];
335
336 if (ArgCount <= 1)
337 {
338 printf("Usage: %s <test_name> - run the specified test (creates/starts the driver(s) as appropriate)\n", AppName);
339 printf(" %s --list - list available tests\n", AppName);
340 printf(" %s --list-all - list available tests, including hidden\n", AppName);
341 printf(" %s <create|delete|start|stop> - manage the kmtest driver\n\n", AppName);
342 Operation = KMT_LIST_TESTS;
343 }
344 else
345 {
346 TestName = Arguments[1];
347 if (!lstrcmpA(TestName, "create"))
348 Error = KmtCreateService(SERVICE_NAME, SERVICE_PATH, SERVICE_DESCRIPTION, &KmtestServiceHandle);
349 else if (!lstrcmpA(TestName, "delete"))
350 Error = KmtDeleteService(SERVICE_NAME, &KmtestServiceHandle);
351 else if (!lstrcmpA(TestName, "start"))
352 Error = KmtStartService(SERVICE_NAME, &KmtestServiceHandle);
353 else if (!lstrcmpA(TestName, "stop"))
354 Error = KmtStopService(SERVICE_NAME, &KmtestServiceHandle);
355
356 else if (!lstrcmpA(TestName, "--list"))
357 Operation = KMT_LIST_TESTS;
358 else if (!lstrcmpA(TestName, "--list-all"))
359 Operation = KMT_LIST_ALL_TESTS;
360 else
361 Operation = KMT_RUN_TEST;
362 }
363
364 if (Operation)
365 {
366 Error = KmtCreateAndStartService(SERVICE_NAME, SERVICE_PATH, SERVICE_DESCRIPTION, &KmtestServiceHandle, FALSE);
367 if (Error)
368 goto cleanup;
369
370 KmtestHandle = CreateFile(KMTEST_DEVICE_PATH, GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, 0, NULL);
371 if (KmtestHandle == INVALID_HANDLE_VALUE)
372 error_goto(Error, cleanup);
373
374 switch (Operation)
375 {
376 case KMT_LIST_ALL_TESTS:
377 ShowHidden = TRUE;
378 /* fall through */
379 case KMT_LIST_TESTS:
380 Error = ListTests(ShowHidden);
381 break;
382 case KMT_RUN_TEST:
383 Error = RunTest(TestName);
384 break;
385 default:
386 assert(FALSE);
387 }
388 }
389
390cleanup:
391 if (KmtestHandle)
392 CloseHandle(KmtestHandle);
393
394 if (ResultBuffer)
395 KmtFreeResultBuffer(ResultBuffer);
396
397 KmtCloseService(&KmtestServiceHandle);
398
399 if (Error)
400 KmtServiceCleanup(TRUE);
401 else
402 Error = KmtServiceCleanup(FALSE);
403
404 if (Error)
405 {
406 OutputError(Error);
407
408 Status = EXIT_FAILURE;
409 }
410
411 return Status;
412}