diff --git a/src/core.jl b/src/core.jl index 2d3bb771..a7c5926e 100644 --- a/src/core.jl +++ b/src/core.jl @@ -554,8 +554,7 @@ mutable struct Session this = new(ptr, graph) check_status(status) finalizer(this, self->begin - status = Status() - @tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), self.ptr, status.ptr) + close(self) end) return this end @@ -571,6 +570,21 @@ mutable struct Session end end +""" + close(sess::Session) + +Closes the TensorFlow session, freeing the associated computational resources. +""" +function Base.close(sess::Session) + if sess.ptr != C_NULL + status = Status() + @tfcall(:TF_DeleteSession, Void, (Ptr{Void}, Ptr{Void}), sess.ptr, status.ptr) + check_status(status) + sess.ptr = C_NULL + end + return nothing +end + mutable struct Buffer ptr::Ptr{Void} diff --git a/src/run.jl b/src/run.jl index c82f56b3..6d679367 100644 --- a/src/run.jl +++ b/src/run.jl @@ -77,9 +77,18 @@ function build_input(tensor_map::Dict) input_tensors, input_values end +struct ClosedSessionError <: Exception +end + +function Base.show(io::IO, err::ClosedSessionError) + print(io, "An operation was attempted on a closed TensorFlow session.") +end + function run(sess::Session, inputs, input_values, outputs, targets) #Low level run, without size checking, and type conversion etc. - + if sess.ptr == C_NULL + throw(ClosedSessionError()) + end status = Status() output_values = fill(C_NULL, length(outputs)) input_tensors = [RawTensor(x) for x in input_values] @@ -184,6 +193,9 @@ end """ + run(sess::Session, output, input_dict::Dict) + + Compute the result of one of more operations in the computation graph. """ function run(sess::Session, output, input_dict) diff --git a/test/core.jl b/test/core.jl index 8bf438b8..ce885c93 100644 --- a/test/core.jl +++ b/test/core.jl @@ -26,6 +26,15 @@ end end end +@testset "Session closing" begin + session = tf.Session(Graph()) + x = constant(1) + @test run(session, x) == 1 + close(session) + close(session) # Test that we can safely call `close` twice on the same session + @test_throws tf.ClosedSessionError run(session, x) +end + @testset "get_operations" begin let graph = Graph()