this repo has no description
at trunk 201 lines 5.9 kB view raw
1// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) 2#include "capi-testing.h" 3 4#include <cstring> 5#include <memory> 6 7#include "Python.h" 8#include "gtest/gtest.h" 9 10namespace py { 11namespace testing { 12 13Borrowed borrow(PyObject* obj) { return Borrowed(obj); } 14 15void collectGarbage() { 16 PyRun_SimpleString(R"( 17try: 18 from _builtins import _gc 19 _gc() 20except: 21 pass 22)"); 23} 24 25PyObject* mainModuleGet(const char* name) { 26 return moduleGet("__main__", name); 27} 28 29PyObject* moduleGet(const char* module, const char* name) { 30 PyObject* mods = PyImport_GetModuleDict(); 31 PyObject* module_name = PyUnicode_FromString(module); 32 PyObject* mod = PyDict_GetItem(mods, module_name); 33 if (mod == nullptr) return nullptr; 34 Py_DECREF(module_name); 35 return PyObject_GetAttrString(mod, name); 36} 37 38int moduleSet(const char* module, const char* name, PyObject* value) { 39 PyObject* mods = PyImport_GetModuleDict(); 40 PyObject* module_name = PyUnicode_FromString(module); 41 PyObject* mod = PyDict_GetItem(mods, module_name); 42 if (mod == nullptr && strcmp(module, "__main__") == 0) { 43 // create __main__ if not yet available 44 PyRun_SimpleString(""); 45 mod = PyDict_GetItem(mods, module_name); 46 } 47 if (mod == nullptr) return -1; 48 Py_DECREF(module_name); 49 50 PyObject* name_obj = PyUnicode_FromString(name); 51 int ret = PyObject_SetAttr(mod, name_obj, value); 52 Py_DECREF(name_obj); 53 return ret; 54} 55 56PyObject* importGetModule(PyObject* name) { 57 PyObject* modules_dict = PyImport_GetModuleDict(); 58 PyObject* module = PyDict_GetItem(modules_dict, name); 59 Py_XINCREF(module); // Return a new reference 60 return module; 61} 62 63template <typename T> 64static ::testing::AssertionResult failNullObj(const T& expected, 65 const char* delim) { 66 PyObjectPtr exception(PyErr_Occurred()); 67 Py_INCREF(exception); 68 if (exception != nullptr) { 69 PyErr_Clear(); 70 PyObjectPtr exception_repr(PyObject_Repr(exception)); 71 if (exception_repr != nullptr) { 72 const char* exception_cstr = PyUnicode_AsUTF8(exception_repr); 73 if (exception_cstr != nullptr) { 74 return ::testing::AssertionFailure() 75 << "pending exception: " << exception_cstr; 76 } 77 } 78 } 79 return ::testing::AssertionFailure() 80 << "nullptr is not equal to " << delim << expected << delim; 81} 82 83template <typename T> 84static ::testing::AssertionResult failBadValue(PyObject* obj, const T& expected, 85 const char* delim) { 86 PyObjectPtr repr_str(PyObject_Repr(obj)); 87 const char* repr_cstr = nullptr; 88 if (repr_str != nullptr) { 89 repr_cstr = PyUnicode_AsUTF8(repr_str); 90 } 91 repr_cstr = repr_cstr == nullptr ? "NULL" : repr_cstr; 92 return ::testing::AssertionFailure() 93 << repr_cstr << " is not equal to " << delim << expected << delim; 94} 95 96::testing::AssertionResult isBytesEqualsCStr(PyObject* obj, const char* c_str) { 97 if (obj == nullptr) return failNullObj(c_str, "'"); 98 99 if (!PyBytes_Check(obj) || std::strcmp(PyBytes_AsString(obj), c_str) != 0) { 100 return failBadValue(obj, c_str, "'"); 101 } 102 return ::testing::AssertionSuccess(); 103} 104 105::testing::AssertionResult isLongEqualsLong(PyObject* obj, long value) { 106 if (obj == nullptr) return failNullObj(value, ""); 107 108 if (PyLong_Check(obj)) { 109 long longval = PyLong_AsLong(obj); 110 if (longval == -1 && PyErr_Occurred() != nullptr) { 111 PyErr_Clear(); 112 } else if (longval == value) { 113 return ::testing::AssertionSuccess(); 114 } 115 } 116 117 return failBadValue(obj, value, ""); 118} 119 120::testing::AssertionResult isUnicodeEqualsCStr(PyObject* obj, 121 const char* c_str) { 122 if (obj == nullptr) return failNullObj(c_str, "'"); 123 124 if (!PyUnicode_Check(obj)) { 125 return failBadValue(obj, c_str, "'"); 126 } 127 PyObjectPtr expected(PyUnicode_FromString(c_str)); 128 if (PyUnicode_Compare(obj, expected) != 0) { 129 return failBadValue(obj, c_str, "'"); 130 } 131 return ::testing::AssertionSuccess(); 132} 133 134CaptureStdStreams::CaptureStdStreams() { 135 ::testing::internal::CaptureStdout(); 136 ::testing::internal::CaptureStderr(); 137} 138 139CaptureStdStreams::~CaptureStdStreams() { 140 // Print any unread buffers to their respective streams to assist in 141 // debugging. 142 if (!restored_stdout_) std::cout << out(); 143 if (!restored_stderr_) std::cerr << err(); 144} 145 146std::string CaptureStdStreams::out() { 147 assert(!restored_stdout_); 148 PyObject *exc, *value, *tb; 149 PyErr_Fetch(&exc, &value, &tb); 150 PyRun_SimpleString(R"( 151import sys 152if hasattr(sys, "stdout") and hasattr(sys.stdout, "flush"): 153 sys.stdout.flush() 154)"); 155 PyErr_Restore(exc, value, tb); 156 restored_stdout_ = true; 157 return ::testing::internal::GetCapturedStdout(); 158} 159 160std::string CaptureStdStreams::err() { 161 assert(!restored_stderr_); 162 PyObject *exc, *value, *tb; 163 PyErr_Fetch(&exc, &value, &tb); 164 PyRun_SimpleString(R"( 165import sys 166if hasattr(sys, "stderr") and hasattr(sys.stderr, "flush"): 167 sys.stderr.flush() 168)"); 169 PyErr_Restore(exc, value, tb); 170 restored_stderr_ = true; 171 return ::testing::internal::GetCapturedStderr(); 172} 173 174TempDirectory::TempDirectory() : TempDirectory("PYRO_TEST") {} 175 176TempDirectory::TempDirectory(const char* prefix) { 177 const char* tmpdir = std::getenv("TMPDIR"); 178 if (tmpdir == nullptr) { 179 tmpdir = "/tmp/"; 180 } 181 const char* format = "%s%s.XXXXXXXX"; 182 int length = std::snprintf(nullptr, 0, format, tmpdir, prefix); 183 184 std::unique_ptr<char[]> buffer(new char[length]); 185 std::snprintf(buffer.get(), length, format, tmpdir, prefix); 186 char* result(::mkdtemp(buffer.get())); 187 assert(result != nullptr); 188 path_ = result; 189 assert(!path_.empty()); 190 if (path_.back() != '/') path_ += "/"; 191} 192 193TempDirectory::~TempDirectory() { 194 std::string cleanup = "rm -rf " + path_; 195 int result = system(cleanup.c_str()); 196 (void)result; 197 assert(result == 0); 198} 199 200} // namespace testing 201} // namespace py