Skip to content

Commit

Permalink
bug fixes & doc update
Browse files Browse the repository at this point in the history
  • Loading branch information
kice committed May 1, 2020
1 parent 46a6b6f commit 26a7b79
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 137 deletions.
29 changes: 14 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,20 @@ if not hasattr(core, 'mx'):
# Your code goes here
```

Python will try to help use load all require dlls (like, MXNet and CUDA). If you delete `core.std.LoadPlugin`, it will still work for vsedit but not work under vspipe.
Due to Vapoursynth DLL loading method, by `import mxnet`, Python will try to help load all require dlls (like, MXNet and CUDA). If you delete `core.std.LoadPlugin`, it will still work for vsedit but not work under vspipe.

Usage
=====

mx.Predict(clip clip, string symbol, string param[, float scale=1.0, int patch_w=0, int patch_h=0, int output_w=128, int output_h=block_w, int frame_w=3, int frame_h=True, int step_w=0, int step_h=0, int outstep_w=0, int outstep_h=0, int padding=0, int border_type=1, int ctx=0, int dev_id=0])
mx.Predict(clip clip, string symbol, string param[, int scale=1, int patch_w=0, int patch_h=0, int output_w=128, int output_h=block_w, int frame_w=3, int frame_h=True, int step_w=0, int step_h=0, int outstep_w=0, int outstep_h=0, int padding=0, int border_type=1, int ctx=0, int dev_id=0])

* clip: Clip to process. Only planar format is float sample type of 32 bit depth is supported. RGB and GRAY is supported. YUV is not correctly supported.
* clip: Clip to process. Only planar format is float32 or int8 supported. RGB and GRAY is supported, YUV is not correctly supported.

* symbol: MXNet symbol json file. If the plugin cannot read the file, it will try to read it from `plugins64\mxnet-symbol\`. You can find more MXNet model [here](https://github.com/WolframRhodium/Super-Resolution-Zoo).
* symbol: MXNet symbol json file. If the plugin cannot read the file, it will try to read it from `plugins64\mxnet-symbol\`. You can find more MXNet models [here](https://github.com/WolframRhodium/Super-Resolution-Zoo).

* param: The same as `symbol`, but for model parameters data.

* scale: Set output shape and final frame shape to twice of patch and input clip. It will be ignore if you manully set corresponding parameters.
* scale: Set output shape and final frame shape form the shape of patch and input clip. It will be ignore if you manully set corresponding parameters. default: `1`

* patch_w: The horizontal block size for dividing the image during processing. Smaller value results in lower VRAM usage, while larger value may not necessarily give faster speed. The optimal value may vary according to different graphics card and image size. If patch_h is larger than clip's width, it will clamp to clip's width. default: clip's width.

Expand Down Expand Up @@ -96,7 +96,7 @@ Usage
* input_name: Set input name. Most MXNet model use `data` as input name. defalut: `data`.

* ctx: Specifies which type of device to use. If GPU was chosen, cuDNN will be used by defalut.
* 1 = CPU
* 1 = CPU (default)
* 2 = GPU

* dev_id: Which device to use. Starting with 0.
Expand All @@ -120,7 +120,7 @@ sr2x = core.mx.Predict(src, symbol='Some2x-symbol.json', param='Some2x-0000.para
# run Waifu2x 2x upconv model with pre-padding, patch size=400x300 on second GPU, output size is 1920x1080
waifu2x = core.mx.Predict(clip, symbol=r'noise0_scale2.0x_model-symbol.json',
param=r'noise0_scale2.0x_model-0000.params',
patch_w=patch_w, patch_h=patch_h,
patch_w=patch_w, patch_h=patch_h,
output_w=patch_w*2, output_h=patch_h*2,
frame_w=1920, frame_h=1080,
step_w=patch_w, step_h=patch_h,
Expand Down Expand Up @@ -236,16 +236,15 @@ block_h = src.height
scale = 2
# Waifu2x need to set pad=7, other model dose not have to set padding
pad = 0
# Waifu2x symbol file should comes with padding
def process(clip, gpu):
return core.mx.Predict(clip, symbol=symbol, param=param,
patch_w = block_w + pad*2, patch_h = block_h + pad*2,
patch_w = block_w, patch_h = block_h,
output_w = block_w*scale, output_h = block_h*scale,
frame_w = clip.width*scale, frame_h = clip.height*scale,
step_w = block_w, step_h = block_h,
padding = pad, ctx = 2, dev_id = gpu)
frame_w = clip.width*scale, frame_h = clip.height*scale,
step_w = block_w, step_h = block_h,
ctx = 2, dev_id = gpu)
queue_size = 3
gpus = 2
Expand All @@ -272,7 +271,7 @@ Limitation

5. MXNet will take some time for cudnn auto tuning for convolution layers every time. set MXNET_CUDNN_AUTOTUNE_DEFAULT=0 to disable it. More info [here](https://mxnet.incubator.apache.org/faq/env_var.html).

6. Please remember that during feeding the first frame, MXNet will allocate very large VRAM block, you might get **Out of Memory** error. Please reduce the patch size to solve it.
6. MXNet will allocate VRAM when feeding the first frame, you might get **Out of Memory** error. Reducing the patch size may solve it.

7. You might need to restart the program (e.g. vsedit) after you changing the input model file.

Expand All @@ -281,4 +280,4 @@ Compilation

There are some code to bypass Vapoursynth plugin loading system, which only works on Windows. You can remove that part and replace all MXNet function calls with normal calls will work on other system. All the header you need is here [`MXNet C predict API`](https://github.com/apache/incubator-mxnet/tree/master/include/mxnet)

On Windows, the plugins uses `LoadLibrary` to dynamically load MXNet, no need for MXNet header to compile.
On Windows, the plugins uses `LoadLibrary` to dynamically load MXNet, no need for MXNet header to compile.
178 changes: 66 additions & 112 deletions vs_mxnet/vsMXNet.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <fstream>
#include <string>
#include <algorithm>
#include <vector>

#include <VapourSynth/VapourSynth.h>
#include <VapourSynth/VSHelper.h>
Expand All @@ -17,11 +18,6 @@
#endif
#endif

#define DEFER_1(x, y) x##y
#define DEFER_2(x, y) DEFER_1(x, y)
#define DEFER_0(x) DEFER_2(x, __COUNTER__)
#define defer(expr) auto DEFER_0(_defered_option) = deferer([&](){expr;})

// no int8 and uint16
inline int VSFormatToMXDtype(const VSFormat *format)
{
Expand Down Expand Up @@ -52,27 +48,13 @@ inline int VSFormatToMXDtype(const VSFormat *format)
return -1;
}

template <typename Function>
struct doDefer
{
Function f;
doDefer(Function f) : f(f) {}
~doDefer() { f(); }
};

template <typename Function>
doDefer<Function> deferer(Function f)
{
return doDefer<Function>(f);
}

struct mxnetData
{
VSNodeRef *node;
VSVideoInfo vi;
int patch_w, patch_h;
int step_w, step_h;
float scale;
int scale;
int output_w, output_h;
int outstep_w, outstep_h;
int frame_w, frame_h;
Expand All @@ -81,49 +63,23 @@ struct mxnetData
PredictorHandle hPred;
};

class BufferFile
std::vector<char> ReadFile(const std::string &file_path)
{
public:
std::string file_path_;
size_t length_;
char* buffer_;

explicit BufferFile(std::string file_path)
:file_path_(file_path)
{
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
length_ = 0;
buffer_ = NULL;
return;
}

ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);

buffer_ = new char[sizeof(char) * length_];
ifs.read(buffer_, length_);
ifs.close();
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
return std::vector<char>();
}

size_t GetLength()
{
return length_;
}
char* GetBuffer()
{
return buffer_;
}
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);

~BufferFile()
{
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};
std::vector<char> buf(length);
ifs.read(buf.data(), length);
ifs.close();

return buf;
}

MXNet mx("libmxnet.dll");

Expand Down Expand Up @@ -170,53 +126,56 @@ static int process(const VSFrameRef *src, VSFrameRef *dst, mxnetData * VS_RESTRI
return 3;
}

int ch = d->vi.format->numPlanes;
int width = vsapi->getFrameWidth(src, 0);
int height = vsapi->getFrameHeight(src, 0);
const int ch = d->vi.format->numPlanes;
const int width = vsapi->getFrameWidth(src, 0);
const int height = vsapi->getFrameHeight(src, 0);

uint8_t **srcp = new uint8_t *[ch];
int *srcStride = new int[ch];
defer(delete[] srcp; delete[] srcStride;);
std::vector<const uint8_t *> srcp(ch);
std::vector<int> srcStride(ch);

std::vector<uint8_t *> dstp(ch);
std::vector<int> dstStride(ch);

for (int plane = 0; plane < ch; ++plane) {
auto _srcStride = vsapi->getStride(src, plane);
auto _srcp = vsapi->getReadPtr(src, plane);
srcp[plane] = vsapi->getReadPtr(src, plane);
dstp[plane] = vsapi->getWritePtr(dst, plane);

srcp[plane] = (uint8_t *)_srcp;
srcStride[plane] = _srcStride;
srcStride[plane] = vsapi->getStride(src, plane);
dstStride[plane] = vsapi->getStride(dst, plane);
}

int patchSize = d->patch_w * d->patch_h * d->in_elem;
int outputSize = d->output_w * d->output_h * d->out_elem;
int in_rowSize = d->patch_w * d->in_elem;
int out_rowSize = d->output_w * d->out_elem;
const int patch_size = d->patch_w * d->patch_h * d->in_elem;
const int output_size = d->output_w * d->output_h * d->out_elem;
const int in_stride = d->patch_w * d->in_elem;
const int out_stride = d->output_w * d->out_elem;

int x = 0, y = 0;
while (true) {
int sy = std::min(y * d->step_h, height - d->patch_h);
int ey = std::min(y * d->step_h + d->patch_h, height);
auto sy = std::min(y * d->step_h, height - d->patch_h);
auto ey = std::min(y * d->step_h + d->patch_h, height);

while (true) {
int sx = std::min(x * d->step_w, width - d->patch_w);
int ex = std::min(x * d->step_w + d->patch_w, width);
auto sx = std::min(x * d->step_w, width - d->patch_w);
auto ex = std::min(x * d->step_w + d->patch_w, width);

for (int plane = 0; plane < ch; ++plane) {
auto _srcp = srcp[plane] + sx + srcStride[plane] * sy;
auto buf = (uint8_t *)d->srcBuffer + patchSize * plane;
vs_bitblt(buf, in_rowSize, _srcp, srcStride[plane], in_rowSize, d->patch_h);
auto stride = srcStride[plane];
auto _srcp = srcp[plane] + sx * d->in_elem + sy * stride;
auto buf = (uint8_t *)d->srcBuffer + patch_size * plane;
vs_bitblt(buf, in_stride, _srcp, stride, in_stride, d->patch_h);
}

if (auto err = mxForward(d)) return err;

for (int plane = 0; plane < ch; ++plane) {
int dstoff_x = std::min(d->frame_w - d->output_w, x * d->outstep_w);
int dstoff_y = std::min(d->frame_h - d->output_h, y * d->outstep_h);

auto stride = vsapi->getStride(dst, plane);
auto dstp = vsapi->getWritePtr(dst, plane) + dstoff_x * d->out_elem + dstoff_y * stride;
auto dstoff_x = std::min(d->frame_w - d->output_w, x * d->outstep_w);
auto dstoff_y = std::min(d->frame_h - d->output_h, y * d->outstep_h);

auto outbuf = (uint8_t *)d->dstBuffer + outputSize * plane;
vs_bitblt(dstp, stride, outbuf, out_rowSize, out_rowSize, d->output_h);
for (int plane = 0; plane < ch; ++plane) {
auto stride = dstStride[plane];
auto _dstp = dstp[plane] + dstoff_x * d->out_elem + dstoff_y * stride;
auto outbuf = (uint8_t *)d->dstBuffer + output_size * plane;
vs_bitblt(_dstp, stride, outbuf, out_stride, out_stride, d->output_h);
}

if (ex == width) break;
Expand All @@ -233,7 +192,7 @@ static int process(const VSFrameRef *src, VSFrameRef *dst, mxnetData * VS_RESTRI

static const VSFrameRef *VS_CC mxGetFrame(int n, int activationReason, void **instanceData, void **frameData, VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi)
{
mxnetData *d = (mxnetData *)* instanceData;
mxnetData *d = static_cast<mxnetData *>(*instanceData);

if (activationReason == arInitial) {
vsapi->requestFrameFilter(n, d->node, frameCtx);
Expand Down Expand Up @@ -269,7 +228,7 @@ static const VSFrameRef *VS_CC mxGetFrame(int n, int activationReason, void **in

static void VS_CC mxFree(void *instanceData, VSCore *core, const VSAPI *vsapi)
{
mxnetData *d = (mxnetData *)instanceData;
mxnetData *d = static_cast<mxnetData *>(instanceData);
vsapi->freeNode(d->node);

mx.MXPredFree(d->hPred);
Expand All @@ -282,7 +241,7 @@ static void VS_CC mxFree(void *instanceData, VSCore *core, const VSAPI *vsapi)

static void VS_CC mxInit(VSMap *in, VSMap *out, void **instanceData, VSNode *node, VSCore *core, const VSAPI *vsapi)
{
mxnetData * d = static_cast<mxnetData *>(*instanceData);
mxnetData *d = static_cast<mxnetData *>(*instanceData);
vsapi->setVideoInfo(&d->vi, 1, node);
}

Expand Down Expand Up @@ -341,9 +300,9 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
d.step_h = height;

// Scale
d.scale = static_cast<float>(vsapi->propGetFloat(in, "scale", 0, &err));
d.scale = int64ToIntS(vsapi->propGetInt(in, "scale", 0, &err));
if (err)
d.scale = 1.0;
d.scale = 1;

// Forward output size
d.output_w = int64ToIntS(vsapi->propGetInt(in, "output_w", 0, &err));
Expand Down Expand Up @@ -435,32 +394,28 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
if (dev_id < 0)
throw std::string{ "device id must be greater than or equal to 0" };

d.srcBuffer = vs_aligned_malloc(d.patch_w * d.patch_h * ch * d.in_elem, 512);
d.dstBuffer = vs_aligned_malloc(d.output_w * d.output_h * ch * d.out_elem, 512);
d.srcBuffer = vs_aligned_malloc(d.patch_w * d.patch_h * ch * d.in_elem, 128);
d.dstBuffer = vs_aligned_malloc(d.output_w * d.output_h * ch * d.out_elem, 128);
if (!d.srcBuffer || !d.dstBuffer)
throw std::string{ "malloc failure (buffer)" };

const std::string pluginPath{ vsapi->getPluginPath(vsapi->getPluginById("vs.kice.mxnet", core)) };
std::string dataPath{ pluginPath.substr(0, pluginPath.find_last_of('/')) };
const std::string pluginPath = vsapi->getPluginPath(vsapi->getPluginById("vs.kice.mxnet", core));
const std::string dataPath = pluginPath.substr(0, pluginPath.find_last_of('/'));

BufferFile *json_data = new BufferFile(symbol);
if (json_data->GetLength() == 0) {
delete json_data;
json_data = new BufferFile(dataPath + "/mxnet-symbol/" + symbol);
auto json_data = ReadFile(symbol);
if (json_data.empty()) {
json_data = ReadFile(dataPath + "/mxnet-symbol/" + symbol);
}

BufferFile *param_data = new BufferFile(param);
if (param_data->GetLength() == 0) {
delete param_data;
param_data = new BufferFile(dataPath + "/mxnet-symbol/" + param);
auto param_data = ReadFile(param);
if (param_data.empty()) {
param_data = ReadFile(dataPath + "/mxnet-symbol/" + param);
}

defer([&](...) { delete json_data; delete param_data; });

if (json_data->GetLength() == 0 || param_data->GetLength() == 0)
if (json_data.empty() || param_data.empty())
throw std::string{ "Cannot open symbol json file or param data file" };

d.hPred = 0;
d.hPred = nullptr;

// Parameters
int dev_type = ctx == 0 ? 1 : 2;
Expand Down Expand Up @@ -497,9 +452,8 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *

// Create Predictor
if (mx.MXPredCreateEx(
(const char*)json_data->GetBuffer(),
(const char*)param_data->GetBuffer(),
static_cast<int>(param_data->GetLength()),
json_data.data(), param_data.data(),
static_cast<int>(param_data.size()),
dev_type, dev_id,
num_input_nodes,
input_keys, input_shape_indptr, input_shape_data,
Expand Down Expand Up @@ -530,7 +484,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegiste
"param:data;"
"patch_w:int:opt;"
"patch_h:int:opt;"
"scale:float:opt;"
"scale:int:opt;"
"output_w:int:opt;"
"output_h:int:opt;"
"frame_w:int:opt;"
Expand Down
Loading

0 comments on commit 26a7b79

Please sign in to comment.