Skip to content

Commit

Permalink
Add get_engine and deprecate engine_from_environment (quantumlib#3151)
Browse files Browse the repository at this point in the history
quantumlib#2767 introduced `get_engine_sampler`.  This adds `get_engine` via a similar pattern, and in particular uses the same shared environment variable for projects, GOOGLE_PROJECT_ID.

Also
* deprecates `engine_from_environment` which used a different environment variable and a different nomenclature.  It also fixes a bug in this methods id.
* corrects a typo in the doc of `get_engine_sampler`
* Gives the a sensible small string for an Engine object.
  • Loading branch information
dabacon authored Jul 22, 2020
1 parent 8a0b054 commit 66bd758
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 33 deletions.
1 change: 1 addition & 0 deletions cirq/google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
EngineTimeSlot,
ProtoVersion,
QuantumEngineSampler,
get_engine,
get_engine_sampler,
)

Expand Down
1 change: 1 addition & 0 deletions cirq/google/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cirq.google.engine.engine import (
Engine,
get_engine,
ProtoVersion,
)

Expand Down
45 changes: 40 additions & 5 deletions cirq/google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import datetime
import enum
import os
import random
import string
from typing import Dict, List, Optional, Sequence, TypeVar, Union, TYPE_CHECKING
Expand Down Expand Up @@ -108,7 +109,6 @@ def copy(self) -> 'EngineContext':
def _value_equality_values_(self):
return self.proto_version, self.client


class Engine:
"""Runs programs via the Quantum Engine API.
Expand All @@ -133,8 +133,8 @@ def __init__(
proto_version: Optional[ProtoVersion] = None,
service_args: Optional[Dict] = None,
verbose: Optional[bool] = None,
context: Optional[EngineContext] = None,
timeout: Optional[int] = None,
context: Optional[EngineContext] = None,
) -> None:
"""Supports creating and running programs against the Quantum Engine.
Expand All @@ -143,7 +143,6 @@ def __init__(
API interactions will be attributed to this project and any
resources created will be owned by the project. See
https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects
context: Engine configuration and context to use.
proto_version: The version of cirq protos to use. If None, then
ProtoVersion.V2 will be used.
service_args: A dictionary of arguments that can be used to
Expand All @@ -152,11 +151,13 @@ def __init__(
true.
timeout: Timeout for polling for results, in seconds. Default is
to never timeout.
context: Engine configuration and context to use. For most users
this should never be specified.
"""
if context and (proto_version or service_args or verbose):
raise ValueError(
'either provide context or proto_version, service_args'
' and verbose')
'Either provide context or proto_version, service_args'
' and verbose.')

self.project_id = project_id
if not context:
Expand All @@ -166,6 +167,9 @@ def __init__(
timeout=timeout)
self.context = context

def __str__(self) -> str:
return f'Engine(project_id={self.project_id!r})'

def run(
self,
program: 'cirq.Circuit',
Expand Down Expand Up @@ -518,3 +522,34 @@ def sampler(self, processor_id: Union[str, List[str]],
return engine_sampler.QuantumEngineSampler(engine=self,
processor_id=processor_id,
gate_set=gate_set)


def get_engine(project_id: Optional[str] = None) -> Engine:
"""Get an Engine instance assuming some sensible defaults.
This uses the environment variable GOOGLE_CLOUD_PROJECT for the Engine
project_id, unless set explicitly. By using an environment variable,
you can avoid hard-coding the project_id in shared code.
If the environment variables are set, but incorrect, an authentication
failure will occur when attempting to run jobs on the engine.
Args:
project_id: If set overrides the project id obtained from the
environment variable `GOOGLE_CLOUD_PROJECT`.
Returns:
The Engine instance.
Raises:
EnvironmentError: If the environment variable GOOGLE_CLOUD_PROJECT is
not set.
"""
env_project_id = 'GOOGLE_CLOUD_PROJECT'
if not project_id:
project_id = os.environ.get(env_project_id)
if not project_id:
raise EnvironmentError(
f'Environment variable {env_project_id} is not set.')

return Engine(project_id=project_id)
19 changes: 11 additions & 8 deletions cirq/google/engine/engine_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, TYPE_CHECKING, Union, Optional, cast

from cirq import work, circuits
Expand Down Expand Up @@ -72,7 +71,7 @@ def get_engine_sampler(processor_id: str, gate_set_name: str,
-> 'cirq.google.QuantumEngineSampler':
"""Get an EngineSampler assuming some sensible defaults.
This uses the environment variable GOOGLE_GLOUD_PROJECT for the Engine
This uses the environment variable GOOGLE_CLOUD_PROJECT for the Engine
project_id, unless set explicitly.
Args:
Expand All @@ -84,16 +83,20 @@ def get_engine_sampler(processor_id: str, gate_set_name: str,
this defaults to the environment variable GOOGLE_CLOUD_PROJECT.
By using an environment variable, you can avoid hard-coding
personal project IDs in shared code.
Returns:
A `QuantumEngineSampler` instance.
Raises:
ValueError: If the supplied gate set is not a supported gate set name.
EnvironmentError: If no project_id is specified and the environment
variable GOOGLE_CLOUD_PROJECT is not set.
"""
try:
gate_set = gate_sets.NAMED_GATESETS[gate_set_name]
except KeyError:
raise ValueError(f"Please use one of the following gateset names: "
f"{sorted(gate_sets.NAMED_GATESETS.keys())}")

if project_id is None:
project_id = os.environ['GOOGLE_CLOUD_PROJECT']

return engine.Engine(project_id=project_id,
proto_version=engine.ProtoVersion.V2) \
.sampler(processor_id=processor_id, gate_set=gate_set)
return engine.get_engine(project_id).sampler(processor_id=processor_id,
gate_set=gate_set)
33 changes: 31 additions & 2 deletions cirq/google/engine/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Tests for engine."""
import os
from unittest import mock
import numpy as np
import pytest
Expand Down Expand Up @@ -258,8 +259,11 @@ def test_create_engine(client):
with pytest.raises(
ValueError,
match='provide context or proto_version, service_args and verbose'):
cg.Engine('proj', mock.Mock(), cg.engine.engine.ProtoVersion.V2,
{'args': 'test'}, True)
cg.Engine('proj',
proto_version=cg.engine.engine.ProtoVersion.V2,
service_args={'args': 'test'},
verbose=True,
context=mock.Mock())

assert cg.Engine(
'proj',
Expand All @@ -271,6 +275,14 @@ def test_create_engine(client):
assert client.called_with({'args': 'test'}, True)


def test_engine_str():
engine = cg.Engine('proj',
proto_version=cg.engine.engine.ProtoVersion.V2,
service_args={'args': 'test'},
verbose=True)
assert str(engine) == 'Engine(project_id=\'proj\')'


def setup_run_circuit_with_result_(client, result):
client().create_program.return_value = (
'prog', qtypes.QuantumProgram(name='projects/proj/programs/prog'))
Expand Down Expand Up @@ -692,3 +704,20 @@ def test_sampler(client):
assert results[i].params.param_dict == {'a': v}
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
assert client().create_program.call_args[0][0] == 'proj'


@mock.patch('cirq.google.engine.client.quantum.QuantumEngineServiceClient')
def test_get_engine(build):
# Default project id present.
with mock.patch.dict(os.environ, {
'GOOGLE_CLOUD_PROJECT': 'project!',
},
clear=True):
eng = cirq.google.get_engine()
assert eng.project_id == 'project!'

# Nothing present.
with mock.patch.dict(os.environ, {}, clear=True):
with pytest.raises(EnvironmentError, match='GOOGLE_CLOUD_PROJECT'):
_ = cirq.google.get_engine()
_ = cirq.google.get_engine('project!')
18 changes: 12 additions & 6 deletions cirq/google/engine/env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,27 @@

import os

from cirq.google.engine.engine import Engine
from typing import TYPE_CHECKING

from cirq.google import engine
from cirq._compat import deprecated

if TYPE_CHECKING:
import cirq

ENV_PROJECT_ID = 'CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID'


def engine_from_environment() -> Engine:
@deprecated(deadline='v0.10.0', fix='Use cirq.get_engine instead.')
def engine_from_environment() -> 'cirq.google.Engine':
"""Returns an Engine instance configured using environment variables.
If the environment variables are set, but incorrect, an authentication
failure will occur when attempting to run jobs on the engine.
Required Environment Variables:
QUANTUM_ENGINE_PROJECT: The name of a google cloud project, with the
quantum engine enabled, that you have access to.
CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID: The name of a google cloud
project, with the quantum engine enabled, that you have access to.
Raises:
EnvironmentError: The environment variables are not set.
Expand All @@ -38,5 +45,4 @@ def engine_from_environment() -> Engine:
if not project_id:
raise EnvironmentError(
'Environment variable {} is not set.'.format(ENV_PROJECT_ID))

return Engine(project_id=project_id)
return engine.get_engine(project_id)
31 changes: 20 additions & 11 deletions cirq/google/engine/env_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,32 @@

import os
from unittest import mock
import warnings
import pytest

import cirq


@mock.patch('cirq.google.engine.client.quantum.QuantumEngineServiceClient')
def test_engine_from_environment(build):
# Default project id present.
with mock.patch.dict(os.environ, {
'CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID': 'project!',
},
clear=True):
eng = cirq.google.engine_from_environment()
assert eng.project_id == 'project!'
with warnings.catch_warnings():
warnings.simplefilter('ignore')
# Default project id present.
env_dict = {'CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID': 'project!'}
with mock.patch.dict(os.environ, env_dict, clear=True):
eng = cirq.google.engine_from_environment()
assert eng.project_id == 'project!'

# Nothing present.
with mock.patch.dict(os.environ, {}, clear=True):
with pytest.raises(EnvironmentError,
match='CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID'):
# Nothing present.
with mock.patch.dict(os.environ, {}, clear=True):
with pytest.raises(EnvironmentError,
match='CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID'):
_ = cirq.google.engine_from_environment()


def test_deprecation():
with cirq.testing.assert_logs('engine_from_environment', 'get_engine',
'deprecated'):
env_dict = {'CIRQ_QUANTUM_ENGINE_DEFAULT_PROJECT_ID': 'project!'}
with mock.patch.dict(os.environ, env_dict, clear=True):
_ = cirq.google.engine_from_environment()
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ Functionality specific to quantum hardware and services from Google.
cirq.google.SYC
cirq.google.SYC_GATESET
cirq.google.XMON
cirq.google.engine_from_environment
cirq.google.get_engine
cirq.google.get_engine_sampler
cirq.google.is_native_xmon_gate
cirq.google.is_native_xmon_op
Expand Down Expand Up @@ -728,3 +728,4 @@ These objects and methods will be removed in a future version of the library.
cirq.WaveFunctionSimulatorState
cirq.WaveFunctionStepResult
cirq.WaveFunctionTrialResult
cirq.google.engine_from_environment

0 comments on commit 66bd758

Please sign in to comment.