Skip to content

Commit

Permalink
Working setup
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Aug 14, 2024
1 parent 77fa511 commit 8b6d0f2
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 87 deletions.
136 changes: 109 additions & 27 deletions src/ert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@
ValidationStatus,
)

from opentelemetry._logs import set_logger_provider
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider, SpanLimits
from opentelemetry.trace import Status, StatusCode
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
ConsoleSpanExporter,
)
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, ConsoleLogExporter
from opentelemetry.exporter.otlp.proto.http._log_exporter import (
OTLPLogExporter,
)
from azure.monitor.opentelemetry.exporter import AzureMonitorTraceExporter


def run_ert_storage(args: Namespace, _: Optional[ErtPluginManager] = None) -> None:
with StorageService.start_server(
Expand Down Expand Up @@ -641,12 +660,52 @@ def log_process_usage() -> None:


def main() -> None:
ThreadingInstrumentor().instrument()
warnings.filterwarnings("ignore", category=DeprecationWarning)
locale.setlocale(locale.LC_NUMERIC, "C")

# Have ErtThread re-raise uncaught exceptions on main thread
set_signal_handler()

# Service name is required for most backends
resource = Resource(attributes={SERVICE_NAME: "ert"})

logger_provider = LoggerProvider(resource=resource)
set_logger_provider(logger_provider)

# add the batch processors to the trace provider
otpl_handler = LoggingHandler(level=logging.DEBUG, logger_provider=logger_provider)
logger_provider.add_log_record_processor(
BatchLogRecordProcessor(
OTLPLogExporter(endpoint="http://127.0.0.1:4318/v1/logs")
)
)
logger_provider.add_log_record_processor(
BatchLogRecordProcessor(ConsoleLogExporter())
)

connection_string = os.getenv("AZ_CON_STRING")

traceProvider = TracerProvider(
resource=resource, span_limits=SpanLimits(max_events=128 * 16)
)
processor = BatchSpanProcessor(
OTLPSpanExporter(endpoint="http://127.0.0.1:4318/v1/traces")
)
traceProvider.add_span_processor(processor)
traceProvider.add_span_processor(
BatchSpanProcessor(
AzureMonitorTraceExporter(connection_string=connection_string)
)
)

# console_processor = BatchSpanProcessor(ConsoleSpanExporter())
# traceProvider.add_span_processor(console_processor)

LoggingInstrumentor(set_logging_format=True, tracer_provider=traceProvider)
trace.set_tracer_provider(traceProvider)
# Sets the global default tracer provider

args = ert_parser(None, sys.argv[1:])

log_dir = os.path.abspath(args.logdir)
Expand All @@ -671,33 +730,56 @@ def main() -> None:
handler.setLevel(logging.INFO)
root_logger.addHandler(handler)

FeatureScheduler.set_value(args)
try:
with ErtPluginContext(logger=logging.getLogger()) as context:
logger.info(f"Running ert with {args}")
args.func(args, context.plugin_manager)
except (ErtCliError, ErtTimeoutError) as err:
logger.exception(str(err))
sys.exit(str(err))
except ConfigValidationError as err:
err_msg = err.cli_message()
logger.exception(err_msg)
sys.exit(err_msg)
except BaseException as err:
logger.exception(f'ERT crashed unexpectedly with "{err}"')

logfiles = set() # Use set to avoid duplicates...
for loghandler in logging.getLogger().handlers:
if isinstance(loghandler, logging.FileHandler):
logfiles.add(loghandler.baseFilename)

msg = f'ERT crashed unexpectedly with "{err}".\nSee logfile(s) for details:'
msg += "\n " + "\n ".join(logfiles)

sys.exit(msg)
finally:
log_process_usage()
os.environ.pop("ERT_LOG_DIR")
logger = logging.getLogger()
logger.addHandler(otpl_handler)
# Creates a tracer from the global tracer provider
tracer = trace.get_tracer("ert.main")
with tracer.start_as_current_span("ert.application.start") as span:
FeatureScheduler.set_value(args)
try:
with ErtPluginContext(logger=logger) as context:
span.add_event(
"log",
{"log.severity": "info", "log.message": f"Running ert with {args}"},
)
args.func(args, context.plugin_manager)
except (ErtCliError, ErtTimeoutError) as err:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(err)
span.add_event(
"log", {"log.severity": "exception", "log.message": str(err)}
)
sys.exit(str(err))
except ConfigValidationError as err:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(err)
err_msg = err.cli_message()
span.add_event("log", {"log.severity": "exception", "log.message": err_msg})
sys.exit(err_msg)
except BaseException as err:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(err)
span.add_event(
"log",
{
"log.severity": "exception",
"log.message": f'ERT crashed unexpectedly with "{err}"',
},
)

logfiles = set() # Use set to avoid duplicates...
for loghandler in logging.getLogger().handlers:
if isinstance(loghandler, logging.FileHandler):
logfiles.add(loghandler.baseFilename)

msg = f'ERT crashed unexpectedly with "{err}".\nSee logfile(s) for details:'
msg += "\n " + "\n ".join(logfiles)

sys.exit(msg)
finally:
log_process_usage()
os.environ.pop("ERT_LOG_DIR")
ThreadingInstrumentor().uninstrument()


if __name__ == "__main__":
Expand Down
114 changes: 65 additions & 49 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
from ert.runpaths import Runpaths
from ert.storage import Storage

from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

from .event import (
RunModelDataEvent,
RunModelErrorEvent,
Expand Down Expand Up @@ -316,29 +319,40 @@ def _clean_env_context(self) -> None:
def start_simulations_thread(
self, evaluator_server_config: EvaluatorServerConfig
) -> None:
try:
self.start_time = int(time.time())
with captured_logs(self._error_messages):
self._set_default_env_context()
self._initial_realizations_mask = (
self._simulation_arguments.active_realizations
)
run_context = self.run_experiment(
evaluator_server_config=evaluator_server_config,
)
self._completed_realizations_mask = run_context.mask
except ErtRunError as e:
self._completed_realizations_mask = []
self._failed = True
self._exception = e
self._simulationEnded()
except UserWarning as e:
self._exception = e
self._simulationEnded()
except Exception as e:
self._failed = True
self._exception = e
self._simulationEnded()
tracer = trace.get_tracer("ert.main")
with tracer.start_as_current_span("ert.run_model.start") as span:
try:
span.add_event("log", {"log.severity": "info", "log.message": f"Starting simulation thread {self.__class__.__name__}"})
self.start_time = int(time.time())
with captured_logs(self._error_messages):
self._set_default_env_context()
self._initial_realizations_mask = (
self._simulation_arguments.active_realizations
)
run_context = self.run_experiment(
evaluator_server_config=evaluator_server_config,
)
self._completed_realizations_mask = run_context.mask
except ErtRunError as e:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(e)
span.add_event("log", {"log.severity": "exception", "log.message": f'Simulation ended with error "{e}"'})
self._completed_realizations_mask = []
self._failed = True
self._exception = e
self._simulationEnded()
except UserWarning as e:
span.record_exception(e)
span.add_event("log", {"log.severity": "exception", "log.message": f'Simulation ended with warning "{e}"'})
self._exception = e
self._simulationEnded()
except Exception as e:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(e)
span.add_event("log", {"log.severity": "exception", "log.message": f'Simulation ended with error "{e}"'})
self._failed = True
self._exception = e
self._simulationEnded()

def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig
Expand Down Expand Up @@ -541,37 +555,39 @@ async def run_monitor(self, ee_config: EvaluatorServerConfig) -> bool:
def run_ensemble_evaluator(
self, run_context: RunContext, ee_config: EvaluatorServerConfig
) -> List[int]:
if not self._end_queue.empty():
event_logger.debug("Run model canceled - pre evaluation")
self._end_queue.get()
return []
ensemble = self._build_ensemble(run_context)
evaluator = EnsembleEvaluator(
ensemble,
ee_config,
run_context.iteration,
)
evaluator.start_running()
tracer = trace.get_tracer("ert.main")
with tracer.start_as_current_span("ert.run_model.run_ensemble") as span:
if not self._end_queue.empty():
event_logger.debug("Run model canceled - pre evaluation")
self._end_queue.get()
return []
ensemble = self._build_ensemble(run_context)
evaluator = EnsembleEvaluator(
ensemble,
ee_config,
run_context.iteration,
)
evaluator.start_running()

if not get_running_loop().run_until_complete(self.run_monitor(ee_config)):
return []
if not get_running_loop().run_until_complete(self.run_monitor(ee_config)):
return []

event_logger.debug(
"observed that model was finished, waiting tasks completion..."
)
# The model has finished, we indicate this by sending a DONE
event_logger.debug("tasks complete")
event_logger.debug(
"observed that model was finished, waiting tasks completion..."
)
# The model has finished, we indicate this by sending a DONE
event_logger.debug("tasks complete")

evaluator.join()
if not self._end_queue.empty():
event_logger.debug("Run model canceled - post evaluation")
self._end_queue.get()
return []
evaluator.join()
if not self._end_queue.empty():
event_logger.debug("Run model canceled - post evaluation")
self._end_queue.get()
return []

run_context.ensemble.unify_parameters()
run_context.ensemble.unify_responses()
run_context.ensemble.unify_parameters()
run_context.ensemble.unify_responses()

return evaluator.get_successful_realizations()
return evaluator.get_successful_realizations()

def _build_ensemble(
self,
Expand Down
21 changes: 10 additions & 11 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

_TERMINATE_TIMEOUT = 10.0

logger = logging.getLogger(__name__)

from opentelemetry import trace

class LocalDriver(Driver):
def __init__(self) -> None:
super().__init__()
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
self._sent_finished_events: Set[int] = set()
self._current_span = trace.get_current_span()
print(self._current_span)

async def submit(
self,
Expand All @@ -41,27 +42,25 @@ async def submit(
async def kill(self, iens: int) -> None:
try:
self._tasks[iens].cancel()
logger.info(f"Killing realization {iens}")
self._current_span.add_event("log", {"log.severity": "info", "log.message": f"Killing realization {iens}"})
with contextlib.suppress(asyncio.CancelledError):
await self._tasks[iens]
del self._tasks[iens]
await self._dispatch_finished_event(iens, signal.SIGTERM + SIGNAL_OFFSET)

except KeyError:
logger.info(f"Realization {iens} is already killed")
self._current_span.add_event("log", {"log.severity": "info", "log.message": f"Realization {iens} is already killed"})
return
except Exception as err:
logger.error(f"Killing realization {iens} failed with error {err}")
self._current_span.add_event("log", {"log.severity": "error", "log.message": f"Killing realization {iens} failed with error {err}"})
raise err

async def finish(self) -> None:
await asyncio.gather(*self._tasks.values())
logger.info("All realization tasks finished")
self._current_span.add_event("log", {"log.severity": "info", "log.message": "All realization tasks finished"})

async def _run(self, iens: int, executable: str, /, *args: str) -> None:
logger.debug(
f"Submitting realization {iens} as command '{executable} {' '.join(args)}'"
)
self._current_span.add_event("log", {"log.severity": "debug", "log.message": f"Submitting realization {iens} as command '{executable} {' '.join(args)}'"})
try:
proc = await self._init(
iens,
Expand All @@ -72,7 +71,7 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:
# /bin/sh uses returncode 127 for FileNotFound, so copy that
# behaviour.
msg = f"Realization {iens} failed with {err}"
logger.error(msg)
self._current_span.add_event("log", {"log.severity": "error", "log.message": msg})
self._job_error_message_by_iens[iens] = msg
await self._dispatch_finished_event(iens, 127)
return
Expand All @@ -82,7 +81,7 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:
returncode = 0
try:
returncode = await self._wait(proc)
logger.info(f"Realization {iens} finished with {returncode=}")
self._current_span.add_event("log", {"log.severity": "info", "log.message": f"Realization {iens} finished with {returncode=}"})
except asyncio.CancelledError:
returncode = await self._kill(proc)
finally:
Expand Down

0 comments on commit 8b6d0f2

Please sign in to comment.