diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java new file mode 100644 index 00000000000..582debd4cde --- /dev/null +++ b/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java @@ -0,0 +1,200 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.apache.zeppelin.python; + +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.scheduler.Scheduler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.file.Paths; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Helps run python interpreter on a docker container + */ +public class PythonDockerInterpreter extends Interpreter { + Logger logger = LoggerFactory.getLogger(PythonDockerInterpreter.class); + Pattern activatePattern = Pattern.compile("activate\\s*(.*)"); + Pattern deactivatePattern = Pattern.compile("deactivate"); + Pattern helpPattern = Pattern.compile("help"); + private File zeppelinHome; + + public PythonDockerInterpreter(Properties property) { + super(property); + } + + @Override + public void open() { + if (System.getenv("ZEPPELIN_HOME") != null) { + zeppelinHome = new File(System.getenv("ZEPPELIN_HOME")); + } else { + zeppelinHome = Paths.get("..").toAbsolutePath().toFile(); + } + } + + @Override + public void close() { + + } + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) { + File pythonScript = new File(getPythonInterpreter().getScriptPath()); + InterpreterOutput out = context.out; + + Matcher activateMatcher = activatePattern.matcher(st); + Matcher deactivateMatcher = deactivatePattern.matcher(st); + Matcher helpMatcher = helpPattern.matcher(st); + + if (st == null || st.isEmpty() || helpMatcher.matches()) { + printUsage(out); + return new InterpreterResult(InterpreterResult.Code.SUCCESS); + } else if (activateMatcher.matches()) { + String image = activateMatcher.group(1); + pull(out, image); + + // mount pythonscript dir + String mountPythonScript = "-v " + + pythonScript.getParentFile().getAbsolutePath() + + ":/_zeppelin_tmp "; + + // mount zeppelin dir + String mountPy4j = "-v " + + zeppelinHome.getAbsolutePath() + + ":/_zeppelin "; + + // set PYTHONPATH + String 
pythonPath = ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PY4JPATH + ":" + + ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PYTHON_LIBS; + + setPythonCommand("docker run -i --rm " + + mountPythonScript + + mountPy4j + + "-e PYTHONPATH=\"" + pythonPath + "\" " + + image + + " python /_zeppelin_tmp/" + pythonScript.getName()); + restartPythonProcess(); + out.clear(); + return new InterpreterResult(InterpreterResult.Code.SUCCESS, "\"" + image + "\" activated"); + } else if (deactivateMatcher.matches()) { + setPythonCommand(null); + restartPythonProcess(); + return new InterpreterResult(InterpreterResult.Code.SUCCESS, "Deactivated"); + } else { + return new InterpreterResult(InterpreterResult.Code.ERROR, "Not supported command: " + st); + } + } + + + public void setPythonCommand(String cmd) { + PythonInterpreter python = getPythonInterpreter(); + python.setPythonCommand(cmd); + } + + private void printUsage(InterpreterOutput out) { + try { + out.setType(InterpreterResult.Type.HTML); + out.writeResource("output_templates/docker_usage.html"); + } catch (IOException e) { + logger.error("Can't print usage", e); + } + } + + @Override + public void cancel(InterpreterContext context) { + + } + + @Override + public FormType getFormType() { + return FormType.NONE; + } + + @Override + public int getProgress(InterpreterContext context) { + return 0; + } + + /** + * Use python interpreter's scheduler. 
+ * This ensures that %python.docker and %python paragraphs run sequentially. + */ + @Override + public Scheduler getScheduler() { + PythonInterpreter pythonInterpreter = getPythonInterpreter(); + if (pythonInterpreter != null) { + return pythonInterpreter.getScheduler(); + } else { + return null; + } + } + + private void restartPythonProcess() { + PythonInterpreter python = getPythonInterpreter(); + python.close(); + python.open(); + } + + protected PythonInterpreter getPythonInterpreter() { + LazyOpenInterpreter lazy = null; + PythonInterpreter python = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + python = (PythonInterpreter) p; + + if (lazy != null) { + lazy.open(); + } + return python; + } + + public boolean pull(InterpreterOutput out, String image) { + int exit = 0; + try { + exit = runCommand(out, "docker", "pull", image); + } catch (IOException | InterruptedException e) { + logger.error(e.getMessage(), e); + throw new InterpreterException(e); + } + return exit == 0; + } + + protected int runCommand(InterpreterOutput out, String... command) + throws IOException, InterruptedException { + ProcessBuilder builder = new ProcessBuilder(command); + builder.redirectErrorStream(true); + Process process = builder.start(); + InputStream stdout = process.getInputStream(); + BufferedReader br = new BufferedReader(new InputStreamReader(stdout)); + String line; + while ((line = br.readLine()) != null) { + out.write(line + "\n"); + } + int r = process.waitFor(); // Let the process finish. 
+ return r; + } +} diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java index 70942f57c13..f8255681c23 100644 --- a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java +++ b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java @@ -27,9 +27,7 @@ import java.io.OutputStreamWriter; import java.io.PipedInputStream; import java.io.PipedOutputStream; -import java.net.ServerSocket; -import java.net.URISyntaxException; -import java.net.URL; +import java.net.*; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Collection; @@ -59,6 +57,7 @@ import org.slf4j.LoggerFactory; import py4j.GatewayServer; +import py4j.commands.Command; /** * Python interpreter for Zeppelin. @@ -78,7 +77,7 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl private String py4jLibPath; private String pythonLibPath; - private String pythonCommand = DEFAULT_ZEPPELIN_PYTHON; + private String pythonCommand; private GatewayServer gatewayServer; private DefaultExecutor executor; @@ -95,11 +94,10 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl Integer statementSetNotifier = new Integer(0); - public PythonInterpreter(Properties property) { super(property); try { - File scriptFile = File.createTempFile("zeppelin_python-", ".py"); + File scriptFile = File.createTempFile("zeppelin_python-", ".py", new File("/tmp")); scriptPath = scriptFile.getAbsolutePath(); } catch (IOException e) { throw new InterpreterException(e); @@ -128,6 +126,10 @@ private void createPythonScript() { logger.info("File {} created", scriptPath); } + public String getScriptPath() { + return scriptPath; + } + private void copyFile(File out, String sourceFile) { ClassLoader classLoader = getClass().getClassLoader(); try { @@ -141,7 +143,7 @@ private void copyFile(File out, String sourceFile) { } } - private void 
createGatewayServerAndStartScript() { + private void createGatewayServerAndStartScript() throws UnknownHostException { createPythonScript(); if (System.getenv("ZEPPELIN_HOME") != null) { py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PY4JPATH; @@ -153,13 +155,28 @@ private void createGatewayServerAndStartScript() { } port = findRandomOpenPortOnAllLocalInterfaces(); - gatewayServer = new GatewayServer(this, port); + gatewayServer = new GatewayServer(this, + port, + GatewayServer.DEFAULT_PYTHON_PORT, + InetAddress.getByName("0.0.0.0"), + InetAddress.getByName("0.0.0.0"), + GatewayServer.DEFAULT_CONNECT_TIMEOUT, + GatewayServer.DEFAULT_READ_TIMEOUT, + (List) null); + gatewayServer.start(); // Run python shell - CommandLine cmd = CommandLine.parse(getPythonCommand()); - cmd.addArgument(scriptPath, false); + String pythonCmd = getPythonCommand(); + CommandLine cmd = CommandLine.parse(pythonCmd); + + if (!pythonCmd.endsWith(".py")) { + // PythonDockerInterpreter set pythoncmd with script + cmd.addArgument(getScriptPath(), false); + } cmd.addArgument(Integer.toString(port), false); + cmd.addArgument(getLocalIp(), false); + executor = new DefaultExecutor(); outputStream = new InterpreterOutputStream(logger); PipedOutputStream ps = new PipedOutputStream(); @@ -185,6 +202,7 @@ private void createGatewayServerAndStartScript() { py4jLibPath + File.pathSeparator + pythonLibPath); } + logger.info("cmd = {}", cmd.toString()); executor.execute(cmd, env, this); pythonscriptRunning = true; } catch (IOException e) { @@ -207,7 +225,11 @@ public void open() { registerHook(HookType.POST_EXEC_DEV, "z._displayhook()"); } // Add matplotlib display hook - createGatewayServerAndStartScript(); + try { + createGatewayServerAndStartScript(); + } catch (UnknownHostException e) { + throw new InterpreterException(e); + } } @Override @@ -244,25 +266,18 @@ public void close() { */ public class PythonInterpretRequest { public String statements; - public String jobGroup; - 
public PythonInterpretRequest(String statements, String jobGroup) { + public PythonInterpretRequest(String statements) { this.statements = statements; - this.jobGroup = jobGroup; } public String statements() { return statements; } - - public String jobGroup() { - return jobGroup; - } } public PythonInterpretRequest getStatements() { synchronized (statementSetNotifier) { - while (pythonInterpretRequest == null && pythonscriptRunning && pythonScriptInitialized) { try { statementSetNotifier.wait(1000); @@ -350,7 +365,7 @@ public InterpreterResult interpret(String cmd, InterpreterContext contextInterpr return new InterpreterResult(Code.ERROR, errorMessage); } - pythonInterpretRequest = new PythonInterpretRequest(cmd, null); + pythonInterpretRequest = new PythonInterpretRequest(cmd); statementOutput = null; synchronized (statementSetNotifier) { @@ -420,16 +435,17 @@ public List completion(String buf, int cursor) { return null; } - public void setPythonPath(String pythonPath) { - this.pythonPath = pythonPath; - } - public void setPythonCommand(String cmd) { + logger.info("Set Python Command : {}", cmd); pythonCommand = cmd; } public String getPythonCommand() { - return pythonCommand; + if (pythonCommand == null) { + return DEFAULT_ZEPPELIN_PYTHON; + } else { + return pythonCommand; + } } private Job getRunningJob(String paragraphId) { @@ -462,8 +478,14 @@ public GUI getGui() { return context.getGui(); } - public Integer getPy4jPort() { - return port; + String getLocalIp() { + try { + return Inet4Address.getLocalHost().getHostAddress(); + } catch (UnknownHostException e) { + logger.error("can't get local IP", e); + } + // fall back to loopback address + return "127.0.0.1"; } private int findRandomOpenPortOnAllLocalInterfaces() { diff --git a/python/src/main/resources/interpreter-setting.json b/python/src/main/resources/interpreter-setting.json index 36f7ad07f65..bc4d4ec5280 100644 --- a/python/src/main/resources/interpreter-setting.json +++ 
b/python/src/main/resources/interpreter-setting.json @@ -43,5 +43,16 @@ "language": "sh", "editOnDblClick": false } + }, + { + "group": "python", + "name": "docker", + "className": "org.apache.zeppelin.python.PythonDockerInterpreter", + "properties": { + }, + "editor":{ + "language": "sh", + "editOnDblClick": false + } } ] diff --git a/python/src/main/resources/python/zeppelin_python.py b/python/src/main/resources/python/zeppelin_python.py index 24202c0825e..0a36cbafe94 100644 --- a/python/src/main/resources/python/zeppelin_python.py +++ b/python/src/main/resources/python/zeppelin_python.py @@ -18,7 +18,7 @@ import os, sys, getopt, traceback, json, re from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.protocol import Py4JJavaError +from py4j.protocol import Py4JJavaError, Py4JNetworkError import warnings import ast import traceback @@ -175,11 +175,11 @@ def handler_stop_signals(sig, frame): signal.signal(signal.SIGINT, handler_stop_signals) -output = Logger() -sys.stdout = output -sys.stderr = output +host = "127.0.0.1" +if len(sys.argv) >= 3: + host = sys.argv[2] -client = GatewayClient(port=int(sys.argv[1])) +client = GatewayClient(address=host, port=int(sys.argv[1])) #gateway = JavaGateway(client, auto_convert = True) gateway = JavaGateway(client) @@ -190,11 +190,17 @@ def handler_stop_signals(sig, frame): z = PyZeppelinContext() z._setup_matplotlib() +output = Logger() +sys.stdout = output +#sys.stderr = output + while True : req = intp.getStatements() + if req == None: + break + try: stmts = req.statements().split("\n") - jobGroup = req.jobGroup() final_code = [] # Get post-execute hooks @@ -227,7 +233,6 @@ def handler_stop_signals(sig, frame): if final_code: # use exec mode to compile the statements except the last statement, # so that the last statement's evaluation will be printed to stdout - #sc.setJobGroup(jobGroup, "Zeppelin") code = compile('\n'.join(final_code), '', 'exec', ast.PyCF_ONLY_AST, 1) to_run_hooks = [] @@ -262,6 
+267,9 @@ def handler_stop_signals(sig, frame): if innerErrorStart > -1: excInnerError = excInnerError[innerErrorStart:] intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) + except Py4JNetworkError: + # lost connection from gateway server. exit + sys.exit(1) except: intp.setStatementsFinished(traceback.format_exc(), True) diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java new file mode 100644 index 00000000000..566b5e0b35a --- /dev/null +++ b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.apache.zeppelin.python; + +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.*; + +public class PythonDockerInterpreterTest { + private PythonDockerInterpreter docker; + private PythonInterpreter python; + + @Before + public void setUp() { + docker = spy(new PythonDockerInterpreter(new Properties())); + python = mock(PythonInterpreter.class); + + InterpreterGroup group = new InterpreterGroup(); + group.put("note", Arrays.asList(python, docker)); + python.setInterpreterGroup(group); + docker.setInterpreterGroup(group); + + doReturn(true).when(docker).pull(any(InterpreterOutput.class), anyString()); + doReturn(python).when(docker).getPythonInterpreter(); + doReturn("/scriptpath/zeppelin_python.py").when(python).getScriptPath(); + + docker.open(); + } + + @Test + public void testActivateEnv() { + InterpreterContext context = getInterpreterContext(); + docker.interpret("activate env", context); + verify(python, times(1)).open(); + verify(python, times(1)).close(); + verify(docker, times(1)).pull(any(InterpreterOutput.class), anyString()); + verify(python).setPythonCommand(Mockito.matches("docker run -i --rm -v.*")); + } + + @Test + public void testDeactivate() { + InterpreterContext context = getInterpreterContext(); + docker.interpret("deactivate", context); + verify(python, times(1)).open(); + verify(python, times(1)).close(); + verify(python).setPythonCommand(null); + } + + private 
InterpreterContext getInterpreterContext() { + return new InterpreterContext( + "noteId", + "paragraphId", + "replName", + "paragraphTitle", + "paragraphText", + new AuthenticationInfo(), + new HashMap(), + new GUI(), + null, + null, + null, + new InterpreterOutput(null)); + } +} diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java index affc7c41c86..b5cd680d8da 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java @@ -95,7 +95,7 @@ public void afterTest() throws IOException { @Test public void testInterpret() throws InterruptedException, IOException { - InterpreterResult result = pythonInterpreter.interpret("print \"hi\"", context); + InterpreterResult result = pythonInterpreter.interpret("print (\"hi\")", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); }