Skip to content

Commit 69c5783

Browse files
Added arg checking to functions in dpctl_sycl_usm_interface.cpp
1 parent 8fed78b commit 69c5783

File tree

1 file changed

+73
-6
lines changed

1 file changed

+73
-6
lines changed

dpctl-capi/source/dpctl_sycl_usm_interface.cpp

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,46 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(void, DPCTLSyclUSMRef)
4444
__dpctl_give DPCTLSyclUSMRef
4545
DPCTLmalloc_shared(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef)
4646
{
47-
auto Q = unwrap(QRef);
48-
auto Ptr = malloc_shared(size, *Q);
49-
return wrap(Ptr);
47+
if (!QRef) {
48+
std::cerr << "Input QRef is nullptr\n";
49+
return nullptr;
50+
}
51+
try {
52+
auto Q = unwrap(QRef);
53+
auto Ptr = malloc_shared(size, *Q);
54+
return wrap(Ptr);
55+
} catch (feature_not_supported const &fns) {
56+
std::cerr << fns.what() << '\n';
57+
return nullptr;
58+
}
5059
}
5160

5261
__dpctl_give DPCTLSyclUSMRef
5362
DPCTLaligned_alloc_shared(size_t alignment,
5463
size_t size,
5564
__dpctl_keep const DPCTLSyclQueueRef QRef)
5665
{
57-
auto Q = unwrap(QRef);
58-
auto Ptr = aligned_alloc_shared(alignment, size, *Q);
59-
return wrap(Ptr);
66+
if (!QRef) {
67+
std::cerr << "Input QRef is nullptr\n";
68+
return nullptr;
69+
}
70+
try {
71+
auto Q = unwrap(QRef);
72+
auto Ptr = aligned_alloc_shared(alignment, size, *Q);
73+
return wrap(Ptr);
74+
} catch (feature_not_supported const &fns) {
75+
std::cerr << fns.what() << '\n';
76+
return nullptr;
77+
}
6078
}
6179

6280
__dpctl_give DPCTLSyclUSMRef
6381
DPCTLmalloc_host(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef)
6482
{
83+
if (!QRef) {
84+
std::cerr << "Input QRef is nullptr\n";
85+
return nullptr;
86+
}
6587
auto Q = unwrap(QRef);
6688
auto Ptr = malloc_host(size, *Q);
6789
return wrap(Ptr);
@@ -72,6 +94,10 @@ DPCTLaligned_alloc_host(size_t alignment,
7294
size_t size,
7395
__dpctl_keep const DPCTLSyclQueueRef QRef)
7496
{
97+
if (!QRef) {
98+
std::cerr << "Input QRef is nullptr\n";
99+
return nullptr;
100+
}
75101
auto Q = unwrap(QRef);
76102
auto Ptr = aligned_alloc_host(alignment, size, *Q);
77103
return wrap(Ptr);
@@ -80,6 +106,10 @@ DPCTLaligned_alloc_host(size_t alignment,
80106
__dpctl_give DPCTLSyclUSMRef
81107
DPCTLmalloc_device(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef)
82108
{
109+
if (!QRef) {
110+
std::cerr << "Input QRef is nullptr\n";
111+
return nullptr;
112+
}
83113
auto Q = unwrap(QRef);
84114
auto Ptr = malloc_device(size, *Q);
85115
return wrap(Ptr);
@@ -90,6 +120,10 @@ DPCTLaligned_alloc_device(size_t alignment,
90120
size_t size,
91121
__dpctl_keep const DPCTLSyclQueueRef QRef)
92122
{
123+
if (!QRef) {
124+
std::cerr << "Input QRef is nullptr\n";
125+
return nullptr;
126+
}
93127
auto Q = unwrap(QRef);
94128
auto Ptr = aligned_alloc_device(alignment, size, *Q);
95129
return wrap(Ptr);
@@ -98,6 +132,14 @@ DPCTLaligned_alloc_device(size_t alignment,
98132
void DPCTLfree_with_queue(__dpctl_take DPCTLSyclUSMRef MRef,
99133
__dpctl_keep const DPCTLSyclQueueRef QRef)
100134
{
135+
if (!QRef) {
136+
std::cerr << "Input QRef is nullptr\n";
137+
return;
138+
}
139+
if (!MRef) {
140+
std::cerr << "Input MRef is nullptr, nothing to free\n";
141+
return;
142+
}
101143
auto Ptr = unwrap(MRef);
102144
auto Q = unwrap(QRef);
103145
free(Ptr, *Q);
@@ -106,6 +148,14 @@ void DPCTLfree_with_queue(__dpctl_take DPCTLSyclUSMRef MRef,
106148
void DPCTLfree_with_context(__dpctl_take DPCTLSyclUSMRef MRef,
107149
__dpctl_keep const DPCTLSyclContextRef CRef)
108150
{
151+
if (!CRef) {
152+
std::cerr << "Input CRef is nullptr\n";
153+
return;
154+
}
155+
if (!MRef) {
156+
std::cerr << "Input MRef is nullptr, nothing to free\n";
157+
return;
158+
}
109159
auto Ptr = unwrap(MRef);
110160
auto C = unwrap(CRef);
111161
free(Ptr, *C);
@@ -114,6 +164,14 @@ void DPCTLfree_with_context(__dpctl_take DPCTLSyclUSMRef MRef,
114164
const char *DPCTLUSM_GetPointerType(__dpctl_keep const DPCTLSyclUSMRef MRef,
115165
__dpctl_keep const DPCTLSyclContextRef CRef)
116166
{
167+
if (!CRef) {
168+
std::cerr << "Input CRef is nullptr\n";
169+
return "unknown";
170+
}
171+
if (!MRef) {
172+
std::cerr << "Input MRef is nullptr\n";
173+
return "unknown";
174+
}
117175
auto Ptr = unwrap(MRef);
118176
auto C = unwrap(CRef);
119177

@@ -134,6 +192,15 @@ DPCTLSyclDeviceRef
134192
DPCTLUSM_GetPointerDevice(__dpctl_keep const DPCTLSyclUSMRef MRef,
135193
__dpctl_keep const DPCTLSyclContextRef CRef)
136194
{
195+
if (!CRef) {
196+
std::cerr << "Input CRef is nullptr\n";
197+
return nullptr;
198+
}
199+
if (!MRef) {
200+
std::cerr << "Input MRef is nullptr\n";
201+
return nullptr;
202+
}
203+
137204
auto Ptr = unwrap(MRef);
138205
auto C = unwrap(CRef);
139206

0 commit comments

Comments
 (0)