Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
encode str in python3
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita authored Jul 18, 2017
1 parent a9cf230 commit 6417a28
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions pynvrtc/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# SOFTWARE.


import sys
from ctypes import (
POINTER,
c_int,
Expand All @@ -34,6 +35,8 @@
from platform import system


is_python2 = sys.version_info.major == 2

# NVRTC status codes
NVRTC_SUCCESS = 0
NVRTC_ERROR_OUT_OF_MEMORY = 1
Expand All @@ -45,6 +48,18 @@
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7


def encode_str(s):
if is_python2:
return s
return s.encode("utf-8")


def encode_str_list(str_list):
if is_python2:
return str_list
return list(map(encode_str, str_list))


class NVRTCException(Exception):
"""
Exception wrapper for NVRTC error codes.
Expand Down Expand Up @@ -178,11 +193,11 @@ def nvrtcCreateProgram(self, src, name, headers, include_names):
"""
res = c_void_p()
headers_array = (c_char_p * len(headers))()
headers_array[:] = headers
headers_array[:] = encode_str_list(headers)
include_names_array = (c_char_p * len(include_names))()
include_names_array[:] = include_names
include_names_array[:] = encode_str_list(include_names)
code = self._lib.nvrtcCreateProgram(byref(res),
c_char_p(src), c_char_p(name),
c_char_p(encode_str(src)), c_char_p(encode_str(name)),
len(headers),
headers_array, include_names_array)
self._throw_on_error(code)
Expand All @@ -202,7 +217,7 @@ def nvrtcCompileProgram(self, prog, options):
array. See the NVRTC API documentation for accepted options.
"""
options_array = (c_char_p * len(options))()
options_array[:] = options
options_array[:] = encode_str_list(options)
code = self._lib.nvrtcCompileProgram(prog, len(options), options_array)
self._throw_on_error(code)
return
Expand Down Expand Up @@ -243,7 +258,7 @@ def nvrtcAddNameExpression(self, prog, name_expression):
function template instantiation.
"""
code = self._lib.nvrtcAddNameExpression(prog,
c_char_p(name_expression))
c_char_p(encode_str(name_expression)))
self._throw_on_error(code)
return

Expand All @@ -254,7 +269,7 @@ def nvrtcGetLoweredName(self, prog, name_expression):
"""
lowered_name = c_char_p()
code = self._lib.nvrtcGetLoweredName(prog,
c_char_p(name_expression),
c_char_p(encode_str(name_expression)),
byref(lowered_name))
self._throw_on_error(code)
return lowered_name.value.decode('utf-8')
Expand Down

0 comments on commit 6417a28

Please sign in to comment.