Skip to content

Commit 2e0607a

Browse files
authored
Merge pull request #1819 from weaviate/max_tokens_aws
Add max token to AWS generative module
2 parents 6a7d42a + 05bdc63 commit 2e0607a

File tree

11 files changed

+278
-266
lines changed

11 files changed

+278
-266
lines changed

test/collection/test_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2060,7 +2060,7 @@ def test_config_with_named_vectors(
20602060
(
20612061
[
20622062
Configure.Vectors.text2vec_aws(
2063-
name="test", region="us-east-1", source_properties=["prop"]
2063+
name="test", region="us-east-1", source_properties=["prop"], model="model"
20642064
)
20652065
],
20662066
{
@@ -2071,6 +2071,7 @@ def test_config_with_named_vectors(
20712071
"vectorizeClassName": True,
20722072
"region": "us-east-1",
20732073
"service": "bedrock",
2074+
"model": "model",
20742075
}
20752076
},
20762077
"vectorIndexType": "hnsw",

weaviate/collections/classes/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ class _GenerativeAWSConfig(_GenerativeProvider):
464464
service: str
465465
model: Optional[str]
466466
endpoint: Optional[str]
467+
maxTokens: Optional[int]
467468

468469

469470
class _GenerativeAnthropicConfig(_GenerativeProvider):
@@ -891,6 +892,7 @@ def aws(
891892
region: str = "", # cant have a non-default value after a default value, but we cant change the order for BC
892893
endpoint: Optional[str] = None,
893894
service: Union[AWSService, str] = "bedrock",
895+
max_tokens: Optional[int] = None,
894896
) -> _GenerativeProvider:
895897
"""Create a `_GenerativeAWSConfig` object for use when performing AI generation using the `generative-aws` module.
896898
@@ -899,15 +901,13 @@ def aws(
899901
900902
Args:
901903
model: The model to use, REQUIRED for service "bedrock".
904+
max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default.
902905
region: The AWS region to run the model from, REQUIRED.
903906
endpoint: The model to use, REQUIRED for service "sagemaker".
904907
service: The AWS service to use, options are "bedrock" and "sagemaker".
905908
"""
906909
return _GenerativeAWSConfig(
907-
model=model,
908-
region=region,
909-
service=service,
910-
endpoint=endpoint,
910+
model=model, region=region, service=service, endpoint=endpoint, maxTokens=max_tokens
911911
)
912912

913913
@staticmethod

weaviate/collections/classes/config_vectors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def text2vec_aws(
652652
name: Optional[str] = None,
653653
quantizer: Optional[_QuantizerConfigCreate] = None,
654654
endpoint: Optional[str] = None,
655-
model: Optional[Union[AWSModel, str]] = None,
655+
model: Optional[Union[AWSModel, str]],
656656
region: str,
657657
service: Union[AWSService, str] = "bedrock",
658658
source_properties: Optional[List[str]] = None,
@@ -668,7 +668,7 @@ def text2vec_aws(
668668
name: The name of the vector.
669669
quantizer: The quantizer to use for the vector index. If not provided, no quantization will be applied.
670670
endpoint: The endpoint to use. Defaults to `None`, which uses the server-defined default.
671-
model: The model to use.
671+
model: The model to use, REQUIRED.
672672
region: The AWS region to run the model from, REQUIRED.
673673
service: The AWS service to use. Defaults to `bedrock`.
674674
source_properties: Which properties should be included when vectorizing. By default all text properties are included.

weaviate/collections/classes/generative.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class _GenerativeAWS(_GenerativeConfigRuntime):
103103
generative: Union[GenerativeSearches, _EnumLikeStr] = Field(
104104
default=GenerativeSearches.AWS, frozen=True, exclude=True
105105
)
106+
max_tokens: Optional[int]
106107
model: Optional[str]
107108
region: Optional[str]
108109
endpoint: Optional[AnyHttpUrl]
@@ -122,6 +123,7 @@ def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> generative_pb2.Gene
122123
target_model=self.target_model,
123124
target_variant=self.target_variant,
124125
temperature=self.temperature,
126+
max_tokens=self.max_tokens,
125127
images=_to_text_array(opts.images),
126128
image_properties=_to_text_array(opts.image_properties),
127129
),
@@ -474,6 +476,7 @@ def anyscale(
474476
def aws(
475477
*,
476478
endpoint: Optional[str] = None,
479+
max_tokens: Optional[int] = None,
477480
model: Optional[str] = None,
478481
region: Optional[str] = None,
479482
service: Optional[Union[AWSService, str]] = None,
@@ -488,6 +491,7 @@ def aws(
488491
489492
Args:
490493
endpoint: The endpoint to use when requesting the generation. Defaults to `None`, which uses the server-defined default
494+
max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
491495
model: The model to use. Defaults to `None`, which uses the server-defined default
492496
region: The AWS region to run the model from. Defaults to `None`, which uses the server-defined default
493497
service: The AWS service to use. Defaults to `None`, which uses the server-defined default
@@ -497,6 +501,7 @@ def aws(
497501
"""
498502
return _GenerativeAWS(
499503
model=model,
504+
max_tokens=max_tokens,
500505
region=region,
501506
service=service,
502507
endpoint=AnyUrl(endpoint) if endpoint is not None else None,

weaviate/collections/classes/internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __init__(
306306
def to_grpc(self, server_version: _ServerVersion) -> generative_pb2.GenerativeSearch:
307307
if server_version.is_lower_than(1, 27, 14):
308308
if self.generative_provider is not None:
309-
raise WeaviateUnsupportedFeatureError("Dynamic RAG", str(server_version), "1.30.0")
309+
raise WeaviateUnsupportedFeatureError("Dynamic RAG", str(server_version), "1.27.14")
310310

311311
if isinstance(self.single, _SinglePrompt):
312312
single_prompt: Optional[str] = self.single.prompt

weaviate/proto/v1/v4216/v1/generative_pb2.py

Lines changed: 84 additions & 84 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

weaviate/proto/v1/v4216/v1/generative_pb2.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class GenerativeAnyscale(_message.Message):
105105
def __init__(self, base_url: _Optional[str] = ..., model: _Optional[str] = ..., temperature: _Optional[float] = ...) -> None: ...
106106

107107
class GenerativeAWS(_message.Message):
108-
__slots__ = ["model", "temperature", "service", "region", "endpoint", "target_model", "target_variant", "images", "image_properties"]
108+
__slots__ = ["model", "temperature", "service", "region", "endpoint", "target_model", "target_variant", "images", "image_properties", "max_tokens"]
109109
MODEL_FIELD_NUMBER: _ClassVar[int]
110110
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
111111
SERVICE_FIELD_NUMBER: _ClassVar[int]
@@ -115,6 +115,7 @@ class GenerativeAWS(_message.Message):
115115
TARGET_VARIANT_FIELD_NUMBER: _ClassVar[int]
116116
IMAGES_FIELD_NUMBER: _ClassVar[int]
117117
IMAGE_PROPERTIES_FIELD_NUMBER: _ClassVar[int]
118+
MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
118119
model: str
119120
temperature: float
120121
service: str
@@ -124,7 +125,8 @@ class GenerativeAWS(_message.Message):
124125
target_variant: str
125126
images: _base_pb2.TextArray
126127
image_properties: _base_pb2.TextArray
127-
def __init__(self, model: _Optional[str] = ..., temperature: _Optional[float] = ..., service: _Optional[str] = ..., region: _Optional[str] = ..., endpoint: _Optional[str] = ..., target_model: _Optional[str] = ..., target_variant: _Optional[str] = ..., images: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., image_properties: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ...) -> None: ...
128+
max_tokens: int
129+
def __init__(self, model: _Optional[str] = ..., temperature: _Optional[float] = ..., service: _Optional[str] = ..., region: _Optional[str] = ..., endpoint: _Optional[str] = ..., target_model: _Optional[str] = ..., target_variant: _Optional[str] = ..., images: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., image_properties: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., max_tokens: _Optional[int] = ...) -> None: ...
128130

129131
class GenerativeCohere(_message.Message):
130132
__slots__ = ["base_url", "frequency_penalty", "max_tokens", "model", "k", "p", "presence_penalty", "stop_sequences", "temperature"]

weaviate/proto/v1/v5261/v1/generative_pb2.py

Lines changed: 84 additions & 84 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

weaviate/proto/v1/v5261/v1/generative_pb2.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class GenerativeAnyscale(_message.Message):
105105
def __init__(self, base_url: _Optional[str] = ..., model: _Optional[str] = ..., temperature: _Optional[float] = ...) -> None: ...
106106

107107
class GenerativeAWS(_message.Message):
108-
__slots__ = ("model", "temperature", "service", "region", "endpoint", "target_model", "target_variant", "images", "image_properties")
108+
__slots__ = ("model", "temperature", "service", "region", "endpoint", "target_model", "target_variant", "images", "image_properties", "max_tokens")
109109
MODEL_FIELD_NUMBER: _ClassVar[int]
110110
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
111111
SERVICE_FIELD_NUMBER: _ClassVar[int]
@@ -115,6 +115,7 @@ class GenerativeAWS(_message.Message):
115115
TARGET_VARIANT_FIELD_NUMBER: _ClassVar[int]
116116
IMAGES_FIELD_NUMBER: _ClassVar[int]
117117
IMAGE_PROPERTIES_FIELD_NUMBER: _ClassVar[int]
118+
MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
118119
model: str
119120
temperature: float
120121
service: str
@@ -124,7 +125,8 @@ class GenerativeAWS(_message.Message):
124125
target_variant: str
125126
images: _base_pb2.TextArray
126127
image_properties: _base_pb2.TextArray
127-
def __init__(self, model: _Optional[str] = ..., temperature: _Optional[float] = ..., service: _Optional[str] = ..., region: _Optional[str] = ..., endpoint: _Optional[str] = ..., target_model: _Optional[str] = ..., target_variant: _Optional[str] = ..., images: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., image_properties: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ...) -> None: ...
128+
max_tokens: int
129+
def __init__(self, model: _Optional[str] = ..., temperature: _Optional[float] = ..., service: _Optional[str] = ..., region: _Optional[str] = ..., endpoint: _Optional[str] = ..., target_model: _Optional[str] = ..., target_variant: _Optional[str] = ..., images: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., image_properties: _Optional[_Union[_base_pb2.TextArray, _Mapping]] = ..., max_tokens: _Optional[int] = ...) -> None: ...
128130

129131
class GenerativeCohere(_message.Message):
130132
__slots__ = ("base_url", "frequency_penalty", "max_tokens", "model", "k", "p", "presence_penalty", "stop_sequences", "temperature")

weaviate/proto/v1/v6300/v1/generative_pb2.py

Lines changed: 84 additions & 84 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)