Skip to content

Commit

Permalink
Add ExUnit.TmpDir.tmp_dir/0
Browse files Browse the repository at this point in the history
Inspired by golang/go#35998.
  • Loading branch information
wojtekmach committed May 15, 2020
1 parent 42af53f commit 90ea4e0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
2 changes: 2 additions & 0 deletions lib/ex_unit/lib/ex_unit/case.ex
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ defmodule ExUnit.Case do

test = %ExUnit.Test{name: name, case: mod, tags: tags, module: mod}
Module.put_attribute(mod, :ex_unit_tests, test)
Module.put_attribute(mod, :ex_unit_test, test)
Module.put_attribute(mod, :ex_unit_tmp_dirs, [])

for attribute <- Module.get_attribute(mod, :ex_unit_registered_test_attributes) do
Module.delete_attribute(mod, attribute)
Expand Down
33 changes: 33 additions & 0 deletions lib/ex_unit/lib/ex_unit/tmp_dir.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
defmodule ExUnit.TmpDir do
@doc """
Returns a path to a temporary directory for the current test.
The directory is lazily created on the first access and is automatically
removed once the test finishes (regardless if the test suceeds or fails).
"""
defmacro tmp_dir do
test = Module.get_attribute(__CALLER__.module, :ex_unit_test)

unless test do
raise "cannot invoke tmp_dir/0 outside of a test. Please make sure you have invoked " <>
"\"use ExUnit.Case\" in the current module"
end

path = Path.join([System.tmp_dir(), "#{inspect(test.module)}_#{test.name}"])
tmp_dirs = Module.get_attribute(__CALLER__.module, :ex_unit_tmp_dirs)

if test.name in tmp_dirs do
quote do
unquote(path)
end
else
Module.put_attribute(__CALLER__.module, :ex_unit_tmp_dirs, [test.name | tmp_dirs])

quote bind_quoted: [path: path] do
File.mkdir_p!(path)
on_exit(fn -> File.rm_rf!(path) end)
path
end
end
end
end
7 changes: 3 additions & 4 deletions lib/mix/test/mix/shell_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Code.require_file("../test_helper.exs", __DIR__)

defmodule Mix.ShellTest do
use MixTest.Case
import ExUnit.TmpDir

defp capture_io(fun) do
fun |> ExUnit.CaptureIO.capture_io() |> String.replace("\r\n", "\n")
Expand All @@ -20,12 +21,10 @@ defmodule Mix.ShellTest do

test "with :cd" do
Mix.shell(Mix.Shell.IO)
tmp_dir = System.tmp_dir()
File.mkdir_p!(tmp_dir)
{pwd, 0} = System.cmd("pwd", [], cd: tmp_dir)
{pwd, 0} = System.cmd("pwd", [], cd: tmp_dir())

assert ExUnit.CaptureIO.capture_io(fn ->
Mix.shell().cmd("pwd", cd: tmp_dir)
Mix.shell().cmd("pwd", cd: tmp_dir())
end) == pwd
after
Mix.shell(Mix.Shell.Process)
Expand Down

0 comments on commit 90ea4e0

Please sign in to comment.