This repository has no description.
1// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
#include "capi-testing.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>

#include "Python.h"
#include "gtest/gtest.h"
9
10namespace py {
11namespace testing {
12
13Borrowed borrow(PyObject* obj) { return Borrowed(obj); }
14
15void collectGarbage() {
16 PyRun_SimpleString(R"(
17try:
18 from _builtins import _gc
19 _gc()
20except:
21 pass
22)");
23}
24
25PyObject* mainModuleGet(const char* name) {
26 return moduleGet("__main__", name);
27}
28
29PyObject* moduleGet(const char* module, const char* name) {
30 PyObject* mods = PyImport_GetModuleDict();
31 PyObject* module_name = PyUnicode_FromString(module);
32 PyObject* mod = PyDict_GetItem(mods, module_name);
33 if (mod == nullptr) return nullptr;
34 Py_DECREF(module_name);
35 return PyObject_GetAttrString(mod, name);
36}
37
38int moduleSet(const char* module, const char* name, PyObject* value) {
39 PyObject* mods = PyImport_GetModuleDict();
40 PyObject* module_name = PyUnicode_FromString(module);
41 PyObject* mod = PyDict_GetItem(mods, module_name);
42 if (mod == nullptr && strcmp(module, "__main__") == 0) {
43 // create __main__ if not yet available
44 PyRun_SimpleString("");
45 mod = PyDict_GetItem(mods, module_name);
46 }
47 if (mod == nullptr) return -1;
48 Py_DECREF(module_name);
49
50 PyObject* name_obj = PyUnicode_FromString(name);
51 int ret = PyObject_SetAttr(mod, name_obj, value);
52 Py_DECREF(name_obj);
53 return ret;
54}
55
56PyObject* importGetModule(PyObject* name) {
57 PyObject* modules_dict = PyImport_GetModuleDict();
58 PyObject* module = PyDict_GetItem(modules_dict, name);
59 Py_XINCREF(module); // Return a new reference
60 return module;
61}
62
63template <typename T>
64static ::testing::AssertionResult failNullObj(const T& expected,
65 const char* delim) {
66 PyObjectPtr exception(PyErr_Occurred());
67 Py_INCREF(exception);
68 if (exception != nullptr) {
69 PyErr_Clear();
70 PyObjectPtr exception_repr(PyObject_Repr(exception));
71 if (exception_repr != nullptr) {
72 const char* exception_cstr = PyUnicode_AsUTF8(exception_repr);
73 if (exception_cstr != nullptr) {
74 return ::testing::AssertionFailure()
75 << "pending exception: " << exception_cstr;
76 }
77 }
78 }
79 return ::testing::AssertionFailure()
80 << "nullptr is not equal to " << delim << expected << delim;
81}
82
83template <typename T>
84static ::testing::AssertionResult failBadValue(PyObject* obj, const T& expected,
85 const char* delim) {
86 PyObjectPtr repr_str(PyObject_Repr(obj));
87 const char* repr_cstr = nullptr;
88 if (repr_str != nullptr) {
89 repr_cstr = PyUnicode_AsUTF8(repr_str);
90 }
91 repr_cstr = repr_cstr == nullptr ? "NULL" : repr_cstr;
92 return ::testing::AssertionFailure()
93 << repr_cstr << " is not equal to " << delim << expected << delim;
94}
95
96::testing::AssertionResult isBytesEqualsCStr(PyObject* obj, const char* c_str) {
97 if (obj == nullptr) return failNullObj(c_str, "'");
98
99 if (!PyBytes_Check(obj) || std::strcmp(PyBytes_AsString(obj), c_str) != 0) {
100 return failBadValue(obj, c_str, "'");
101 }
102 return ::testing::AssertionSuccess();
103}
104
105::testing::AssertionResult isLongEqualsLong(PyObject* obj, long value) {
106 if (obj == nullptr) return failNullObj(value, "");
107
108 if (PyLong_Check(obj)) {
109 long longval = PyLong_AsLong(obj);
110 if (longval == -1 && PyErr_Occurred() != nullptr) {
111 PyErr_Clear();
112 } else if (longval == value) {
113 return ::testing::AssertionSuccess();
114 }
115 }
116
117 return failBadValue(obj, value, "");
118}
119
120::testing::AssertionResult isUnicodeEqualsCStr(PyObject* obj,
121 const char* c_str) {
122 if (obj == nullptr) return failNullObj(c_str, "'");
123
124 if (!PyUnicode_Check(obj)) {
125 return failBadValue(obj, c_str, "'");
126 }
127 PyObjectPtr expected(PyUnicode_FromString(c_str));
128 if (PyUnicode_Compare(obj, expected) != 0) {
129 return failBadValue(obj, c_str, "'");
130 }
131 return ::testing::AssertionSuccess();
132}
133
134CaptureStdStreams::CaptureStdStreams() {
135 ::testing::internal::CaptureStdout();
136 ::testing::internal::CaptureStderr();
137}
138
139CaptureStdStreams::~CaptureStdStreams() {
140 // Print any unread buffers to their respective streams to assist in
141 // debugging.
142 if (!restored_stdout_) std::cout << out();
143 if (!restored_stderr_) std::cerr << err();
144}
145
146std::string CaptureStdStreams::out() {
147 assert(!restored_stdout_);
148 PyObject *exc, *value, *tb;
149 PyErr_Fetch(&exc, &value, &tb);
150 PyRun_SimpleString(R"(
151import sys
152if hasattr(sys, "stdout") and hasattr(sys.stdout, "flush"):
153 sys.stdout.flush()
154)");
155 PyErr_Restore(exc, value, tb);
156 restored_stdout_ = true;
157 return ::testing::internal::GetCapturedStdout();
158}
159
160std::string CaptureStdStreams::err() {
161 assert(!restored_stderr_);
162 PyObject *exc, *value, *tb;
163 PyErr_Fetch(&exc, &value, &tb);
164 PyRun_SimpleString(R"(
165import sys
166if hasattr(sys, "stderr") and hasattr(sys.stderr, "flush"):
167 sys.stderr.flush()
168)");
169 PyErr_Restore(exc, value, tb);
170 restored_stderr_ = true;
171 return ::testing::internal::GetCapturedStderr();
172}
173
174TempDirectory::TempDirectory() : TempDirectory("PYRO_TEST") {}
175
176TempDirectory::TempDirectory(const char* prefix) {
177 const char* tmpdir = std::getenv("TMPDIR");
178 if (tmpdir == nullptr) {
179 tmpdir = "/tmp/";
180 }
181 const char* format = "%s%s.XXXXXXXX";
182 int length = std::snprintf(nullptr, 0, format, tmpdir, prefix);
183
184 std::unique_ptr<char[]> buffer(new char[length]);
185 std::snprintf(buffer.get(), length, format, tmpdir, prefix);
186 char* result(::mkdtemp(buffer.get()));
187 assert(result != nullptr);
188 path_ = result;
189 assert(!path_.empty());
190 if (path_.back() != '/') path_ += "/";
191}
192
193TempDirectory::~TempDirectory() {
194 std::string cleanup = "rm -rf " + path_;
195 int result = system(cleanup.c_str());
196 (void)result;
197 assert(result == 0);
198}
199
200} // namespace testing
201} // namespace py