Skip to content

Commit

Permalink
Fix Normalize_TRT plugin segfault
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
  • Loading branch information
samurdhikaru authored and kevinch-nv committed Jul 22, 2022
1 parent 7818985 commit d90e0d1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
26 changes: 20 additions & 6 deletions plugin/normalizePlugin/normalizePlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ const char* NORMALIZE_PLUGIN_NAME{"Normalize_TRT"};
PluginFieldCollection NormalizePluginCreator::mFC{};
std::vector<PluginField> NormalizePluginCreator::mPluginAttributes;

Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
Normalize::Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
: acrossSpatial(acrossSpatial)
, channelShared(channelShared)
, eps(eps)
Expand All @@ -44,11 +44,13 @@ Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial,
PLUGIN_VALIDATE(nbWeights == 1);
PLUGIN_VALIDATE(weights[0].count >= 1);
mWeights = copyToDevice(weights[0].values, weights[0].count);
mScalarScale = static_cast<float const*>(weights[0].values)[0];
}

Normalize::Normalize(
const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
: acrossSpatial(acrossSpatial)
Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
: mScalarScale(scalarScale)
, acrossSpatial(acrossSpatial)
, channelShared(channelShared)
, eps(eps)
, C(C)
Expand All @@ -74,6 +76,7 @@ Normalize::Normalize(const void* buffer, size_t length)

mNbWeights = read<int>(d);
int count = read<int>(d);
std::memcpy(&mScalarScale, d, sizeof(float));
mWeights = deserializeToDevice(d, count);
PLUGIN_VALIDATE(d == a + length);
}
Expand Down Expand Up @@ -111,8 +114,19 @@ int Normalize::enqueue(
{
const void* inputData = inputs[0];
void* outputData = outputs[0];
pluginStatus_t status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
static_cast<const float*>(mWeights.values), inputData, outputData, workspace);

pluginStatus_t status;

if(acrossSpatial && channelShared) // Since cublasPointerMode_t is CUBLAS_POINTER_MODE_HOST, scale should be on the host
{
status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
&mScalarScale, inputData, outputData, workspace);
}
else // No risk of device pointers being passed to cublas as alpha or beta
{
status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
static_cast<float const*>(mWeights.values), inputData, outputData, workspace);
}

return status;
}
Expand Down Expand Up @@ -254,7 +268,7 @@ IPluginV2Ext* Normalize::clone() const noexcept
try
{
// Create a new instance
IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, acrossSpatial, channelShared, eps, C, H, W);
IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, mScalarScale, acrossSpatial, channelShared, eps, C, H, W);

// Set the namespace
plugin->setPluginNamespace(mPluginNamespace.c_str());
Expand Down
7 changes: 4 additions & 3 deletions plugin/normalizePlugin/normalizePlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ namespace plugin
class Normalize : public IPluginV2Ext
{
public:
Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps);
Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps);

Normalize(
const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W);
Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W);

Normalize(const void* buffer, size_t length);

Expand Down Expand Up @@ -93,8 +93,9 @@ class Normalize : public IPluginV2Ext

cublasHandle_t mCublas;

Weights mWeights{};
Weights mWeights{}; // mWeights.values is on the device
int mNbWeights{};
float mScalarScale{}; // keep track of scale on the host (for when channelShared is true)
bool acrossSpatial{};
bool channelShared{};
float eps{};
Expand Down

0 comments on commit d90e0d1

Please sign in to comment.