Skip to content

Commit

Permalink
Merge pull request #342 from malmaud/sess_close
Browse files Browse the repository at this point in the history
Define `Base.close(::Session)`
  • Loading branch information
malmaud authored Oct 24, 2017
2 parents 5d40570 + f25ea1e commit 9a2b5a1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
18 changes: 16 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,7 @@ mutable struct Session
this = new(ptr, graph)
check_status(status)
finalizer(this, self->begin
status = Status()
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), self.ptr, status.ptr)
close(self)
end)
return this
end
Expand All @@ -571,6 +570,21 @@ mutable struct Session
end
end

"""
close(sess::Session)
Closes the TensorFlow session, freeing the associated computational resources.
"""
function Base.close(sess::Session)
if sess.ptr != C_NULL
status = Status()
@tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), sess.ptr, status.ptr)
check_status(status)
sess.ptr = C_NULL
end
return nothing
end


mutable struct Buffer
ptr::Ptr{Void}
Expand Down
14 changes: 13 additions & 1 deletion src/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,18 @@ function build_input(tensor_map::Dict)
input_tensors, input_values
end

struct ClosedSessionError <: Exception
end

function Base.show(io::IO, err::ClosedSessionError)
print(io, "An operation was attempted on a closed TensorFlow session.")
end

function run(sess::Session, inputs, input_values, outputs, targets)
#Low level run, without size checking, and type conversion etc.

if sess.ptr == C_NULL
throw(ClosedSessionError())
end
status = Status()
output_values = fill(C_NULL, length(outputs))
input_tensors = [RawTensor(x) for x in input_values]
Expand Down Expand Up @@ -184,6 +193,9 @@ end


"""
run(sess::Session, output, input_dict::Dict)
Compute the result of one of more operations in the computation graph.
"""
function run(sess::Session, output, input_dict)
Expand Down
9 changes: 9 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ end
end
end

@testset "Session closing" begin
session = tf.Session(Graph())
x = constant(1)
@test run(session, x) == 1
close(session)
close(session) # Test that we can safely call `close` twice on the same session
@test_throws tf.ClosedSessionError run(session, x)
end

@testset "get_operations" begin
let
graph = Graph()
Expand Down

0 comments on commit 9a2b5a1

Please sign in to comment.