This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
R package: problem with the "predict" function: array shape mismatch #6919
Comments
It seems the Windows package needs an update. On my Linux machine:
> library(mxnet)
> p = 32
> n_label = 10
> n = 89
> X = matrix(runif(p*n,0,1),p,n)
> y = matrix(round(rpois(n_label*n,10)),n_label,n)
> data <- mx.symbol.Variable("data")
> label=mx.symbol.Variable('label')
> fc <- mx.symbol.FullyConnected(data, num_hidden=n_label,name="fc")
> vecto.symb = mx.symbol.MakeLoss(data= mx.symbol.exp(fc) - label * fc , name="poisson")
> devices = mx.gpu(0)
> model = mx.model.FeedForward.create(symbol = vecto.symb,ctx = devices, X=X,y=y,num.round=5,array.layout="colmajor",learning.rate=0.01, optimizer="sgd",initializer=mx.init.normal(0.03),array.batch.size=20)
Start training with 1 devices
> p = predict(model,X,array.layout="colmajor")
> p
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
[1,] 7.816670 8.157823 10.474072 11.39288 7.779306 6.970769 9.895404 12.01964 7.290720 8.944089
[2,] 7.453233 8.445386 10.373491 12.59584 7.349095 6.664856 10.164314 11.51193 6.594102 8.864347
[3,] 8.928808 10.708864 10.548546 13.22053 7.846766 7.146811 10.764274 13.85162 7.298235 9.293221
[4,] 8.230606 9.123090 11.217604 13.23091 8.138767 7.019768 10.248020 14.35250 7.356753 8.135243
[5,] 8.071321 8.288631 10.508270 12.12168 7.196826 6.733716 9.084327 12.06030 6.710617 8.396719
[6,] 8.482371 10.366955 9.545642 11.35876 7.762624 6.653240 11.191854 13.40291 7.470777 8.287533
[7,] 7.665237 9.516909 10.175713 12.69094 7.598200 6.970988 11.616826 13.16461 6.503124 9.161448
[8,] 7.701565 9.458061 9.218641 12.51689 7.938196 6.988401 11.182857 12.04601 7.121068 8.645683
[9,] 7.566068 8.341774 8.861707 10.38127 6.559480 6.436257 9.217819 10.99246 6.040694 7.019975
[10,] 8.520127 9.488525 10.543437 12.88380 7.868061 7.011839 10.414857 12.94717 7.134251 9.127591
[,11] [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20]
[1,] 9.168724 8.523341 7.678204 7.943961 7.256366 10.397227 9.454224 9.628304 10.630404 9.333558
[2,] 8.994869 10.060082 8.047057 7.978253 8.053446 9.684225 8.919069 9.212090 10.225703 9.098025
[3,] 9.813600 10.093615 8.200279 7.718523 7.769364 11.378102 9.855852 10.003072 10.997955 9.812004
[4,] 9.079152 11.014524 7.525928 8.309593 6.809573 12.059962 10.649134 9.013678 12.408504 9.326276
[5,] 7.977436 9.912253 6.456532 7.237380 6.385749 10.745084 8.633645 8.501617 10.114634 7.791702
[6,] 9.058272 8.981851 7.984339 7.783956 7.703894 10.208138 10.439919 11.094374 11.401168 8.967443
[7,] 9.031200 10.548182 9.062616 8.302197 7.647370 10.454777 10.687029 9.221517 13.639665 8.928701
[8,] 10.083397 9.335425 8.125834 9.245922 7.748134 10.652124 9.383847 9.822195 10.822229 8.678563
[9,] 8.763423 8.120615 6.653493 7.085348 6.808999 8.905265 8.267180 8.320374 9.960668 8.852212
[10,] 9.612265 9.996041 8.390809 8.186895 7.708246 10.205514 9.550326 9.781266 12.611280 9.213140
[,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28] [,29] [,30]
[1,] 11.308876 9.288301 9.569324 12.46846 10.172219 8.790600 9.483993 13.28197 10.465868 11.33658
[2,] 11.253362 8.632038 9.748217 11.87590 10.855169 9.407741 10.046068 11.42039 9.945236 11.29840
[3,] 12.477365 9.152213 11.524874 15.02225 11.458111 9.633025 9.971084 14.89848 10.365815 12.15181
[4,] 11.579143 10.519217 11.190501 14.75211 12.171769 9.657764 11.083182 14.42005 11.121315 11.32189
[5,] 10.637225 9.505926 9.946135 12.60800 11.126405 8.921020 9.440014 13.36149 9.041626 10.35690
[6,] 12.291706 8.437302 10.696656 14.43274 10.379085 9.411674 9.331091 12.90256 9.731915 11.89710
[7,] 12.905643 9.070361 11.840013 14.30707 11.611399 9.670755 9.901644 14.70455 9.644423 11.26303
[8,] 11.293889 8.842366 10.220294 14.22478 10.802110 8.745304 10.874222 14.66447 9.923064 11.58790
[9,] 9.354656 7.657266 8.565167 11.63647 9.593769 8.502037 9.264493 11.46476 9.335666 10.89800
[10,] 11.988859 9.438472 10.797145 14.80075 11.518509 9.250444 9.859594 13.98825 10.614679 11.72729
[,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38] [,39] [,40]
[1,] 9.866580 10.49420 8.642398 8.790415 9.733511 11.99423 6.756829 7.628892 9.427886 7.902532
[2,] 8.441378 10.91738 7.268434 9.354983 8.620921 11.54677 6.560183 7.412773 9.592722 7.433565
[3,] 11.526748 12.11366 9.255115 10.099893 11.629895 13.73594 7.003594 8.172505 11.263753 7.457634
[4,] 9.954642 11.29973 9.896523 10.200948 9.685892 14.23606 7.303020 8.677130 10.437410 7.933517
[5,] 9.637197 10.60180 8.510577 9.413569 8.995958 12.14189 6.585538 8.386744 9.185835 7.152886
[6,] 11.354940 10.81478 9.096291 9.356621 10.409036 13.30373 6.882611 7.748830 10.529068 6.938926
[7,] 10.151248 11.24501 9.504297 10.656692 9.294201 14.29821 8.126706 7.991077 10.253646 7.316355
[8,] 9.457038 12.77455 9.239510 9.746669 11.279565 14.08356 6.649528 9.087983 9.574911 8.240922
[9,] 8.786914 9.41922 7.200623 8.183215 8.312158 11.07285 5.736485 6.726950 8.709372 6.553879
[10,] 9.437313 11.99361 9.123729 9.999371 9.689675 12.84750 7.461349 8.532560 10.222152 7.424592
[,41] [,42] [,43] [,44] [,45] [,46] [,47] [,48] [,49] [,50]
[1,] 8.649339 5.314861 12.621219 5.514723 10.003050 9.503062 11.746906 11.38585 11.742462 8.583298
[2,] 7.704547 5.672434 10.236336 5.900691 9.757183 8.363594 11.409889 10.86170 10.496987 7.399814
[3,] 10.187313 5.949660 12.824229 6.548615 11.043896 11.143600 12.783177 13.57207 12.047352 8.578787
[4,] 8.473114 5.517111 12.646370 6.752975 11.089304 10.094254 12.687654 14.06325 12.750614 8.981051
[5,] 8.205828 5.278526 11.942309 6.155215 10.293086 9.945828 11.597883 11.71625 10.327518 7.684636
[6,] 9.521443 6.392505 12.312798 6.588591 11.302396 9.797333 11.525074 12.75276 12.482216 8.209350
[7,] 8.503182 6.182213 13.231131 6.001590 11.974772 10.321801 14.011294 11.64822 13.432488 7.991267
[8,] 9.458187 6.004597 12.362864 5.971149 11.299578 8.827434 12.157065 12.39539 11.267076 7.882904
[9,] 7.755996 5.199047 9.454911 5.931271 9.107282 7.602388 9.795255 11.51299 9.897947 7.584461
[10,] 9.022410 6.131191 11.992719 6.484079 11.273059 9.595859 12.916764 12.10417 11.582853 8.536198
[,51] [,52] [,53] [,54] [,55] [,56] [,57] [,58] [,59] [,60]
[1,] 9.244983 9.026376 11.60768 7.445289 12.20319 9.007307 8.202882 11.69678 11.291792 11.53786
[2,] 8.126652 8.328223 13.41882 7.544335 10.47258 9.024238 9.460758 11.10555 10.212744 10.30471
[3,] 8.480664 10.294396 13.88592 8.045505 12.07821 9.582493 9.304089 11.37732 12.087280 13.55252
[4,] 9.313348 9.124362 13.88364 9.203846 12.22964 9.436004 8.649411 12.50475 11.735001 12.62620
[5,] 7.924311 8.885804 12.02092 7.150859 10.57416 8.598110 8.200793 10.52620 10.053449 10.77155
[6,] 8.096569 9.666099 12.58955 7.703288 12.29899 8.817454 8.296067 10.80231 10.658768 11.92008
[7,] 8.842400 9.423360 13.69213 9.047062 11.73560 9.024308 8.826710 11.14274 12.560871 12.56917
[8,] 9.414310 9.070774 12.86080 7.420550 12.20953 10.125348 9.079483 12.66835 11.415252 12.35455
[9,] 7.392401 7.960190 11.09324 6.392342 10.30098 8.214180 8.085503 10.24684 9.695411 10.35095
[10,] 9.260792 9.340340 14.07300 8.074622 12.85251 9.981878 9.257211 11.91849 11.715885 12.43994
[,61] [,62] [,63] [,64] [,65] [,66] [,67] [,68] [,69] [,70]
[1,] 8.964652 10.141632 9.096560 8.861149 10.476036 9.345021 10.259480 9.010168 11.57214 6.659705
[2,] 8.534002 10.257945 7.253962 8.705607 9.481013 7.980782 10.571997 8.232688 10.29385 6.539469
[3,] 10.711432 10.380169 8.970151 9.116873 11.060877 11.656600 11.390860 11.598317 12.59681 7.431988
[4,] 10.769454 10.100412 10.068944 10.034698 11.652020 10.033549 11.318120 10.702707 12.96376 7.485595
[5,] 9.410174 9.412863 9.099507 9.225079 10.146431 10.289325 11.106381 10.103276 11.16685 6.908048
[6,] 9.705215 9.786076 8.753164 8.416007 11.158924 9.879974 11.187751 10.781698 11.69304 7.059194
[7,] 9.968441 10.295626 9.260633 9.638412 10.761553 9.065764 11.900574 9.745747 11.93590 6.707314
[8,] 9.288058 9.857211 8.992968 9.769420 10.898926 9.228750 10.567677 8.946614 11.19650 7.743548
[9,] 8.392508 8.878320 7.380243 8.299747 8.773889 8.152227 8.889181 8.789065 10.46103 6.591502
[10,] 10.565228 10.766788 9.772991 9.537443 11.391335 9.922419 11.529438 10.707608 12.25396 7.643187
[,71] [,72] [,73] [,74] [,75] [,76] [,77] [,78] [,79] [,80]
[1,] 8.496582 5.932070 8.736745 7.127186 10.683019 11.51356 8.878646 7.619807 11.38468 11.048352
[2,] 7.593750 6.164192 8.259268 6.310826 9.306476 11.35056 9.147801 7.353465 10.92469 9.857203
[3,] 9.020813 6.697681 10.578875 8.397291 11.469223 13.95665 9.171689 8.823219 13.22135 12.855165
[4,] 8.484274 6.673910 9.801679 6.904182 11.306722 12.91221 9.277265 9.201324 12.87133 11.594430
[5,] 8.706759 6.278548 8.878038 6.869167 10.205997 11.35329 8.014498 8.201095 11.34289 10.734923
[6,] 9.373224 6.283204 10.551237 7.767508 10.409093 13.09805 9.084136 7.589930 11.67227 11.563037
[7,] 8.381074 6.922416 10.277338 7.866474 11.223840 13.72707 10.068846 8.340858 10.85645 10.862444
[8,] 8.100209 6.363280 9.032307 7.740890 11.322384 11.65092 8.911162 7.813134 12.39743 11.608423
[9,] 6.830784 5.390416 7.995050 6.114217 9.249428 10.38231 7.801753 7.293544 10.63287 9.385077
[10,] 8.761749 6.808009 9.693430 7.363374 10.770193 12.13806 9.655453 8.921905 11.41752 11.697877
[,81] [,82] [,83] [,84] [,85] [,86] [,87] [,88] [,89]
[1,] 8.319080 9.214132 7.316243 8.342993 8.825691 10.382159 6.941805 8.265296 11.088474
[2,] 7.492801 7.521801 7.128818 8.230494 7.547138 11.223130 6.734413 8.538315 9.608241
[3,] 8.821959 9.503194 7.671940 9.203848 10.735788 13.862556 7.180660 8.137114 9.824462
[4,] 8.719627 9.177809 8.077137 9.009045 10.390970 12.364767 7.818877 8.203750 10.473236
[5,] 7.458200 8.372662 7.605329 8.418344 9.782778 11.312627 6.777676 7.380335 9.437925
[6,] 9.505312 9.496328 7.079411 8.963075 11.295750 12.375880 7.207112 7.754936 9.533134
[7,] 9.581206 9.374516 7.761055 9.984800 9.554749 12.038147 6.906099 8.259149 10.723992
[8,] 8.620052 9.202593 7.301843 9.875078 9.660048 12.118854 7.423788 7.825188 11.017103
[9,] 7.345890 7.959776 6.273556 7.370767 8.206044 9.588604 6.565289 7.717051 8.245582
[10,] 8.573742 8.522071 7.391146 9.527450 9.333231 11.119854 7.414830 8.066483 10.641735
@thirdwing, so you reckon it should work if I redo the installation (https://gist.github.com/thirdwing/89aa9bfc588ade138496e6932072152c) with the latest Windows build (by downloading the latest version of
It should work if you use the latest code. Alternatively, wait a little while; I will update the Windows package.
The prebuilt packages have been updated.
Environment info
Operating System: Windows 10
Package used: R
MXNet version: 0.10.1 (as reported by R). I installed the GPU version a month ago by following this tutorial: https://gist.github.com/thirdwing/89aa9bfc588ade138496e6932072152c, and it worked fine until now.
R sessionInfo():
R version 3.3.2 (2016-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)
Error Message:
[18:18:32] d:\program files (x86)\jenkins\workspace\mxnet\mxnet\dmlc-core\include\dmlc./logging.h:304: [18:18:32] d:\program files (x86)\jenkins\workspace\mxnet\mxnet\src\operator\tensor../elemwise_op_common.h:33: Check failed: assign(&dattr, (*vec)[i]) Incompatible attr in node _mul11 at 1-th input: expected (89,), got (89,10)
Error in symbol$infer.shape(list(...)) :
Error in operator _mul11: [18:18:32] d:\program files (x86)\jenkins\workspace\mxnet\mxnet\src\operator\tensor../elemwise_op_common.h:33: Check failed: assign(&dattr, (*vec)[i]) Incompatible attr in node _mul11 at 1-th input: expected (89,), got (89,10)
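One possible reading of this message (an interpretation, not confirmed anywhere in the thread): node _mul11 is the elementwise product label * fc inside the loss. An elementwise product needs both inputs to have the same shape, and fc has one column per label (n_label = 10) for each of the 89 rows. At predict time no y is supplied, and the shape for label appears to have been inferred as a 1-D vector instead:

```latex
\begin{aligned}
\texttt{fc} &\in \mathbb{R}^{B \times K}, \qquad B = n = 89,\quad K = \texttt{n\_label} = 10,\\
\texttt{label} * \texttt{fc} \;&\text{requires}\; \operatorname{shape}(\texttt{label}) = (B, K) = (89, 10),\\
\text{shape inferred for } \texttt{label} \text{ in predict} &= (B,) = (89,) \neq (89, 10).
\end{aligned}
```

This is consistent with the maintainer's later reply that the problem goes away with the latest code and the updated prebuilt packages.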
Minimum reproducible example
# This example is a multi-label regression with a generalized linear model and a Poisson likelihood
# Simulate random data
p = 32
n_label = 10
n = 89
X = matrix(runif(p*n,0,1),p,n)
y = matrix(round(rpois(n_label*n,10)),n_label,n)
# model building
data <- mx.symbol.Variable("data")
label=mx.symbol.Variable('label')
# we create a parametrized linear combination of input variables in fc
fc <- mx.symbol.FullyConnected(data, num_hidden=n_label,name="fc")
# As the loss, we use the Poisson negative log-likelihood with a log link (mean = exp(fc)), dropping the constant log(y!) term
vecto.symb = mx.symbol.MakeLoss(data= mx.symbol.exp(fc) - label * fc , name="poisson")
devices = mx.gpu(0)
model = mx.model.FeedForward.create(symbol = vecto.symb,ctx = devices, X=X,y=y,num.round=5,array.layout="colmajor",learning.rate=0.01, optimizer="sgd",initializer=mx.init.normal(0.03),array.batch.size=20)
p = predict(model,X,array.layout="colmajor")
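For reference, the MakeLoss expression above is the Poisson negative log-likelihood under a log link, with the term that does not depend on the model dropped. Writing η for one element of fc and y for the matching element of label (notation introduced here only for illustration):

```latex
\begin{aligned}
p(y \mid \lambda) &= \frac{\lambda^{y} e^{-\lambda}}{y!}, \qquad \lambda = e^{\eta},\\
-\log p(y \mid \eta) &= e^{\eta} - y\,\eta + \log(y!),\\
\frac{\partial}{\partial \eta}\bigl(e^{\eta} - y\,\eta\bigr) &= e^{\eta} - y.
\end{aligned}
```

Because log(y!) does not depend on η, dropping it leaves the gradient unchanged. The expression is evaluated elementwise, so label must have the same shape as fc: one count per output unit.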
What have you tried to solve it?
The error appears when calling the predict function. The same code works fine on another machine (same Windows 10, same R version) that has the standard CPU build of mxnet (installed with install.packages). So the error might be an installation problem with the GPU version, or it might be related to #113?