Commit

fix KATE
swiesend committed Feb 12, 2018
1 parent 2be54d2 commit 344972e
Showing 2 changed files with 88 additions and 76 deletions.
137 changes: 78 additions & 59 deletions src/KATE.jl
@@ -6,39 +6,6 @@ using Flux.Tracker: data
using StatsBase: wsample
using Base.Iterators: partition

function non_empty_string(s)
s != ""
end

function count_words(text::Array{String})
unique_words = unique(text)
word_counts = Dict{String,Int64}(map(word -> (word, 0), unique_words))
map(word -> word_counts[word] += 1, text)
word_counts
end

function normalize_log(text::Array{String}, word_counts::Dict{String,Int64})
V = length(word_counts)

function nl(word::String, wc::Dict{String,Int64}, V::Int)
n = wc[word]
logwc = log(1+n)
logwc/V*logwc
end

map(word -> nl(word,word_counts,V), text)
end

function transform_text_to_input(src::String, limit::Int=1000)
text_raw = readstring(src)
text = map(s -> String(s),
filter(non_empty_string, split(text_raw, r"[^\wäÄöÖüÜ&]+")))
text = text[1:limit]
word_counts = count_words(text)
input = normalize_log(text, word_counts)
return input
end

struct KCompetetive{F,S,T}
σ::F
W::S
@@ -60,57 +27,56 @@ treelike(KCompetetive)
function (a::KCompetetive)(x)
W, b, σ, α = a.W, a.b, a.σ, a.α

# @show x
# @show typeof(x), length(x)

ps = Vector{Float64}()
ns = Vector{Float64}()
ps = Vector{Tuple{Int64,Float64}}()
ns = Vector{Tuple{Int64,Float64}}()

for (i, activation) = enumerate(data(x))
if activation >= 0
push!(ps, activation)
push!(ps, (i, activation))
else
push!(ns, activation)
push!(ns, (i, activation))
end
end
ps = sort(ps)
ns = sort(ns,rev=true)
ps = sort(ps, by=last)
ns = sort(ns, by=last, rev=true)
P = length(ps)
N = length(ns)
k = last(size(a.W))
# @show ps, ns
# @show typeof(ps), typeof(ns)
# @show P, N, k, k==P+N, size(W)
k = first(size(a.W))

z = data(x)

p = P-Int(k/2)
if p > 0
Epos = sum(ps[1:p])
Epos = sum(map(last, ps[1:p]))
# @show Epos
for i = p+1:P
ps[i] += α * Epos
# positive winners
z[first(ps[i])] += α * Epos
end
for i = 1:p
ps[i] = 0.0
# positive losers
z[first(ps[i])] = 0.0
end
end

n = N-Int(k/2)
if n > 0
Eneg = sum(ns[1:n])
Eneg = sum(map(last, ns[1:n]))
# @show Eneg
for i = n+1:N
ns[i] += α * Eneg
# negative winners
z[first(ns[i])] += α * Eneg
end
for i = 1:n
ns[i] = 0.0
# negative losers
z[first(ns[i])] = 0.0
end
end

z = vcat(reverse(ps),ns) #needs to be param?!
# z = param(rand(k))
# z = x
# @show z
# @show typeof(z), length(z), p, n

σ.(W*z .+ b)
# result = σ.(W*x .+ b)
# @show result
# result
σ.(W*x .+ b)
end

function Base.show(io::IO, l::KCompetetive)
@@ -119,4 +85,57 @@ function Base.show(io::IO, l::KCompetetive)
print(io, ")")
end

function non_empty_string(s)
s != ""
end

function count_words(text::Array{String})
unique_words = unique(text)
word_counts = Dict{String,Int64}(map(word -> (word, 0), unique_words))
map(word -> word_counts[word] += 1, text)
word_counts
end

function nl(word::String, wc::Dict{String,Int64}, V::Int)
n = wc[word]
logwc = log(1+n)
logwc/V*logwc
end

function normalize_log(text::Array{String}, word_counts::Dict{String,Int64})
V = length(word_counts)
map(word -> nl(word,word_counts,V), text)
end

function transform_text_to_input(src::String, limit::Int=1000)
text_raw = readstring(src)
text = map(s -> String(s),
filter(non_empty_string, split(text_raw, r"[^\wäÄöÖüÜ&]+")))
text = text[1:limit]
word_counts = count_words(text)
input = normalize_log(text, word_counts)
return input, word_counts
end

function get_similar_words(model, query_id, vocab; topn=10)
# @show vocab
W = model[1].W.data
W = W/norm(W)
query = W[query_id]
# @show query
score = query*(W')
# @show score
# @show size(score), typeof(score)
# @show score[:,1]
# @show size(score[:,1]), typeof(score[:,1])
vidx = sort(score[:,1])[1:topn]
# @show vidx
# weights = unitmatrix(weights) # normalize
# query = weights[query_id]
# score = query.dot(weights.T)
# vidx = score.argsort()[::-1][:topn]

return [vocab[idx] for idx in vidx]
end

end # module KATE
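
For reference, a minimal, framework-free sketch of the k-competition step that the reworked KCompetetive forward pass performs: keep the k/2 strongest positive and the k/2 strongest negative activations, add α times the summed energy of the dropped units to the winners, and zero the losers. The function name, the plain-Vector interface, and the placeholder α default are illustrative assumptions, not part of this commit.

# Hypothetical helper, not in the repository: applies k-competition to a dense
# activation vector z; α scales how much loser energy the winners absorb.
function k_competition(z::Vector{Float64}, k::Int; α::Float64 = 1.0)
    pos = [(i, v) for (i, v) in enumerate(z) if v >= 0]
    neg = [(i, v) for (i, v) in enumerate(z) if v < 0]
    sort!(pos, by = last)               # weakest positives first
    sort!(neg, by = last, rev = true)   # weakest negatives first
    out = copy(z)
    p = length(pos) - k ÷ 2             # number of positive losers
    if p > 0
        Epos = sum(last.(pos[1:p]))     # energy of the positive losers
        for (i, _) in pos[p+1:end]; out[i] += α * Epos; end   # boost the winners
        for (i, _) in pos[1:p];     out[i] = 0.0;       end   # silence the losers
    end
    n = length(neg) - k ÷ 2             # number of negative losers
    if n > 0
        Eneg = sum(last.(neg[1:n]))     # energy of the negative losers
        for (i, _) in neg[n+1:end]; out[i] += α * Eneg; end   # boost the winners
        for (i, _) in neg[1:n];     out[i] = 0.0;       end   # silence the losers
    end
    return out
end

In the committed layer the equivalent updates are written directly into data(x), so the closing σ.(W*x .+ b) appears to act on the competed activations rather than the raw ones.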
27 changes: 10 additions & 17 deletions test/test_KATE.jl
@@ -12,8 +12,8 @@ isfile("input.txt") ||
download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
"input.txt")

input = KATE.transform_text_to_input("input.txt", 1000)
# @show typeof(input), typeof(wc)
input, wc = KATE.transform_text_to_input("input.txt", 1000)
vocabulary = collect(keys(wc))

# Xs, Ys = deepcopy(input), deepcopy(input)
# Xs, Ys = deepcopy([input]), deepcopy([input])
@@ -24,11 +24,12 @@ Xs, Ys = [deepcopy(input)], [deepcopy(input)]

N = length(input)

k = 64

m = Chain(
Dense(N, 128, tanh),
KCompetetive(128, 128, tanh),
Dense(128, N, sigmoid),
softmax)
KCompetetive(N, k, tanh),
Dense(k, N, sigmoid)
)

function loss(xs, ys)
# @show typeof(xs), typeof(ys)
@@ -43,21 +44,13 @@ end

opt = Flux.ADADelta(params(m))

@progress for i = 1:5000
@progress for i = 1:10
info("Epoch $i")
Flux.train!(
loss,
zip(Xs, Ys),
opt,
cb = evaluation_callback)
cb = throttle(evaluation_callback, 30))
end

# function sample(m, alphabet, len; temp = 1)
# buf = IOBuffer()
# c = rand(alphabet)
# for i = 1:len
# write(buf, c)
# c = wsample(alphabet, m(onehot(c, alphabet)).data)
# end
# return String(take!(buf))
# end
# @show KATE.get_similar_words(m, 1, vocabulary)

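As a point of comparison, here is a hedged sketch of the cosine-similarity lookup that get_similar_words seems to be working toward, following the commented-out Python-style pseudocode in the diff (normalize the weights, score the query against every word, take the top-n). The function name, the column-per-word weight layout, and the Float64 element type are assumptions made for illustration only.

# Illustrative only: W is assumed to be a k×N weight matrix with one column per
# vocabulary entry, in the same order as `vocab`.
function similar_words(W::Matrix{Float64}, query_id::Int, vocab::Vector{String}; topn::Int = 10)
    unit(v) = v / sqrt(sum(abs2, v))                       # scale a column to unit length
    cols = hcat([unit(W[:, j]) for j in 1:size(W, 2)]...)  # unit column per word
    scores = vec(cols' * cols[:, query_id])                # cosine similarity to the query word
    order = sortperm(scores, rev = true)                   # most similar first (query itself included)
    return [vocab[i] for i in order[1:topn]]
end

Whether the first layer's weight columns actually line up one-to-one with vocabulary entries depends on how the inputs are encoded, so that correspondence is an assumption here rather than something the test guarantees.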