
Merge pull request #7 from maekawataiki/dev
Fix Rinna / Kendra / Project names
ysekiy authored Oct 7, 2023
2 parents 623faad + 95d908f commit 11caf5e
Showing 10 changed files with 42 additions and 151 deletions.
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
@@ -1,4 +1,4 @@
  # Code of Conduct
  This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
  For details, see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or
- contact opensource-codeofconduct@amazon.com with your comments.
+ contact opensource-codeofconduct@amazon.com with your comments.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -16,4 +16,4 @@
  please refer to [SECURITY](SECURITY.md).

  ## Licensing
- For the project's license, see the [LICENSE](LICENSE.txt) file.
+ For the project's license, see the [LICENSE](LICENSE.txt) file.
2 changes: 1 addition & 1 deletion LICENSE
@@ -11,4 +11,4 @@ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
  COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
- CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2 changes: 1 addition & 1 deletion amplify/.config/project-config.json
@@ -1,5 +1,5 @@
  {
- "projectName": "jpragsampleamplify",
+ "projectName": "rag",
  "version": "3.1",
  "frontend": "javascript",
  "javascript": {
118 changes: 25 additions & 93 deletions amplify/backend/api/fargate/fargate-cloudformation-template.json
@@ -156,47 +156,17 @@
"",
[
{
"Fn::Select": [
4,
{
"Fn::Split": [
":",
{
"Fn::GetAtt": [
"langchainRepository",
"Arn"
]
}
]
}
]
"Ref": "AWS::AccountId"
},
".dkr.ecr.",
{
"Fn::Select": [
3,
{
"Fn::Split": [
":",
{
"Fn::GetAtt": [
"langchainRepository",
"Arn"
]
}
]
}
]
"Ref": "AWS::Region"
},
".",
{
"Ref": "AWS::URLSuffix"
},
"/",
{
"Ref": "langchainRepository"
},
":latest"
"/amplify-rag-dev-102319-api-fargate-langchain:latest"
]
]
},
@@ -299,9 +269,23 @@
  ],
  "Effect": "Allow",
  "Resource": {
- "Fn::GetAtt": [
- "langchainRepository",
- "Arn"
+ "Fn::Join": [
+ "",
+ [
+ "arn:",
+ {
+ "Ref": "AWS::Partition"
+ },
+ ":ecr:",
+ {
+ "Ref": "AWS::Region"
+ },
+ ":",
+ {
+ "Ref": "AWS::AccountId"
+ },
+ ":repository/amplify-rag-dev-102319-api-fargate-langchain"
+ ]
  ]
  }
  },
@@ -391,27 +375,6 @@
"UpdateReplacePolicy": "Delete",
"DeletionPolicy": "Delete"
},
"langchainRepository": {
"Type": "AWS::ECR::Repository",
"Properties": {
"LifecyclePolicy": {
"LifecyclePolicyText": "{\"rules\":[{\"rulePriority\":10,\"selection\":{\"tagStatus\":\"tagged\",\"tagPrefixList\":[\"latest\"],\"countType\":\"imageCountMoreThan\",\"countNumber\":1},\"action\":{\"type\":\"expire\"}},{\"rulePriority\":100,\"selection\":{\"tagStatus\":\"any\",\"countType\":\"sinceImagePushed\",\"countNumber\":7,\"countUnit\":\"days\"},\"action\":{\"type\":\"expire\"}}]}"
},
"RepositoryName": {
"Fn::Join": [
"",
[
{
"Ref": "rootStackName"
},
"-api-fargate-langchain"
]
]
}
},
"UpdateReplacePolicy": "Retain",
"DeletionPolicy": "Retain"
},
"ServiceSG": {
"Type": "AWS::EC2::SecurityGroup",
"Properties": {
@@ -444,7 +407,7 @@
"Cluster": {
"Ref": "NetworkStackClusterName"
},
"DesiredCount": 0,
"DesiredCount": 1,
"LaunchType": "FARGATE",
"NetworkConfiguration": {
"AwsvpcConfiguration": {
@@ -980,47 +943,17 @@
  },
  "\"},{\"name\":\"langchain_REPOSITORY_URI\",\"type\":\"PLAINTEXT\",\"value\":\"",
  {
- "Fn::Select": [
- 4,
- {
- "Fn::Split": [
- ":",
- {
- "Fn::GetAtt": [
- "langchainRepository",
- "Arn"
- ]
- }
- ]
- }
- ]
+ "Ref": "AWS::AccountId"
  },
  ".dkr.ecr.",
  {
- "Fn::Select": [
- 3,
- {
- "Fn::Split": [
- ":",
- {
- "Fn::GetAtt": [
- "langchainRepository",
- "Arn"
- ]
- }
- ]
- }
- ]
+ "Ref": "AWS::Region"
  },
  ".",
  {
  "Ref": "AWS::URLSuffix"
  },
- "/",
- {
- "Ref": "langchainRepository"
- },
- "\"}]"
+ "/amplify-rag-dev-102319-api-fargate-langchain\"}]"
  ]
  ]
  }
@@ -2646,6 +2579,5 @@
"states": "states.us-west-2.amazonaws.com"
}
}
},
"Description": "{\"createdOn\":\"Mac\",\"createdBy\":\"Amplify\",\"createdWith\":\"12.1.1\",\"stackType\":\"api-ElasticContainer\",\"metadata\":{}}"
}
}
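Note on this template change: the langchainRepository AWS::ECR::Repository resource is removed, and both the container image URI and the IAM policy's repository ARN are now assembled from the AWS::AccountId, AWS::Region, AWS::Partition, and AWS::URLSuffix pseudo parameters plus the fixed repository name amplify-rag-dev-102319-api-fargate-langchain; DesiredCount also moves from 0 to 1 so the Fargate service actually runs a task. A minimal sketch of the image URI the Fn::Join now resolves to (the account ID and region below are placeholders, not values from this commit):

# Placeholder values standing in for the CloudFormation pseudo parameters.
account_id = "123456789012"   # AWS::AccountId
region = "ap-northeast-1"     # AWS::Region
url_suffix = "amazonaws.com"  # AWS::URLSuffix
image_uri = f"{account_id}.dkr.ecr.{region}.{url_suffix}/amplify-rag-dev-102319-api-fargate-langchain:latest"
print(image_uri)
# -> 123456789012.dkr.ecr.ap-northeast-1.amazonaws.com/amplify-rag-dev-102319-api-fargate-langchain:latest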
2 changes: 1 addition & 1 deletion amplify/backend/api/fargate/parameters.json
@@ -1,3 +1,3 @@
  {
- "ParamZipPath": "amplify-builds/fargate-73386155644967766f52-build.zip"
+ "ParamZipPath": "amplify-builds/fargate-6572486d4e6155524f79-build.zip"
  }
11 changes: 4 additions & 7 deletions amplify/backend/api/fargate/src/langchain/app/chain/rinna.py
@@ -20,17 +20,17 @@ class RinnaContentHandler(LLMContentHandler):
  def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
  input_str = json.dumps(
  {
- "instruction": "",
- "input": prompt.replace("\n", "<NL>"),
- **model_kwargs,
+ "inputs": prompt.replace("\n", "<NL>"),
+ "parameters": model_kwargs
  }
  )
  print("prompt: ", prompt)
  return input_str.encode("utf-8")

  def transform_output(self, output: bytes) -> str:
  response_json = json.loads(output.read().decode("utf-8"))
- return response_json.replace("<NL>", "\n")
+ response = response_json[0]['generated_text']
+ return response.replace("<NL>", "\n")


  def build_rinna_chain(endpoint_name: str, aws_region: str) -> LLMChain:
@@ -52,9 +52,6 @@ def build_rinna_chain(endpoint_name: str, aws_region: str) -> LLMChain:
"max_new_tokens": 256,
"temperature": 0.3,
"do_sample": True,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 3,
},
content_handler=content_handler,
)
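The updated RinnaContentHandler sends an inputs/parameters request and reads generated_text from the first element of the response, which lines up with the request/response contract of the JumpStart-packaged Rinna endpoint that llm/deploy_llm.sh now creates (the old handler targeted the hand-packaged aws-ml-jp inference script). A minimal sketch of that round trip, with placeholder text standing in for real model output:

import json

# Request body the new transform_input produces: newlines become <NL>,
# generation settings travel under "parameters".
request = json.dumps({
    "inputs": "こんにちは。<NL>調子はどうですか？",
    "parameters": {"max_new_tokens": 256, "temperature": 0.3, "do_sample": True},
})

# Response body shape the new transform_output expects: a list whose first
# element carries generated_text (placeholder reply, not real model output).
response_body = b'[{"generated_text": "example reply<NL>second line"}]'
text = json.loads(response_body)[0]["generated_text"].replace("<NL>", "\n")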
1 change: 1 addition & 0 deletions kendra/kendra-docs-index.yaml
@@ -122,6 +122,7 @@ Resources:
  Name: 'KendraDocsDS'
  RoleArn: !GetAtt KendraDSRole.Arn
  Type: 'WEBCRAWLER'
+ LanguageCode: ja

  DataSourceSyncLambdaRole:
  Type: AWS::IAM::Role
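The added LanguageCode: ja makes the web-crawler data source index the crawled pages as Japanese. As a minimal sketch (the index ID and region are placeholders, not values from this repository), a query against such an index can filter on Kendra's built-in _language_code attribute so the Japanese documents come back:

import boto3

# Query the index for Japanese documents; _language_code mirrors LanguageCode: ja.
kendra = boto3.client("kendra", region_name="ap-northeast-1")
response = kendra.query(
    IndexId="00000000-0000-0000-0000-000000000000",  # placeholder index ID
    QueryText="Amazon Kendra とは",
    AttributeFilter={
        "EqualsTo": {"Key": "_language_code", "Value": {"StringValue": "ja"}}
    },
)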
6 changes: 0 additions & 6 deletions llm/delete_llm.sh
@@ -7,18 +7,12 @@ pip3 install sagemaker
  python3 <<EOF
  import sagemaker
  from sagemaker.predictor import Predictor
- from sagemaker.serializers import JSONSerializer
- from sagemaker.deserializers import JSONDeserializer
  sess = sagemaker.Session()
  endpoint_name = "Rinna-Inference"
  predictor = Predictor(
  endpoint_name=endpoint_name,
  sagemaker_session=sess,
- serializer=JSONSerializer(),
- deserializer=JSONDeserializer()
  )
  predictor.delete_model()
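The serializer and deserializer are only needed when invoking an endpoint, so a bare Predictor handle is enough for cleanup; that is all this change removes. A minimal sketch of the trimmed script (the final endpoint deletion is assumed to sit in the truncated part below):

import sagemaker
from sagemaker.predictor import Predictor

# A bare Predictor is sufficient to delete the deployed resources.
sess = sagemaker.Session()
predictor = Predictor(endpoint_name="Rinna-Inference", sagemaker_session=sess)
predictor.delete_model()
predictor.delete_endpoint()  # assumed: the truncated tail of the script also removes the endpoint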
47 changes: 7 additions & 40 deletions llm/deploy_llm.sh
@@ -3,46 +3,13 @@
  # Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  # Licensed under the MIT-0 License (https://github.com/aws/mit-0)

- git clone https://github.com/aws-samples/aws-ml-jp.git /tmp/aws-ml-jp
- cd /tmp/aws-ml-jp/tasks/generative-ai/text-to-text/fine-tuning/instruction-tuning/Transformers/scripts
- tar -czvf ../package.tar.gz *
- cd ..
- pip3 install sagemaker
+ pip3 install "sagemaker>=2.168.0"
  python3 <<EOF
- import sagemaker, boto3, json
- from sagemaker.pytorch.model import PyTorchModel
- from sagemaker.serializers import JSONSerializer
+ from sagemaker.jumpstart.model import JumpStartModel
- role = sagemaker.get_execution_role()
- sess = sagemaker.Session()
- bucket = sess.default_bucket()
- region = sess._region_name
- model_path = sess.upload_data('package.tar.gz', bucket=bucket, key_prefix=f"Rinna-Inference")
- model_path
- endpoint_name = "Rinna-Inference"
- huggingface_model = PyTorchModel(
- model_data=model_path,
- framework_version="2.0",
- py_version='py310',
- role=role,
- name=endpoint_name,
- env={
- "model_params": json.dumps({
- "base_model": "rinna/japanese-gpt-neox-3.6b-instruction-ppo",
- "peft": False,
- "prompt_template": "rinna",
- }),
- "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600"
- }
- )
- # deploy model to SageMaker Inference
- predictor = huggingface_model.deploy(
- initial_instance_count=1,
- instance_type='ml.g4dn.xlarge',
- endpoint_name=endpoint_name,
- serializer=JSONSerializer()
+ model = JumpStartModel(
+ model_id="huggingface-llm-rinna-3-6b-instruction-ppo-bf16",
+ model_version="v1.2.0"
  )
- EOF
+ predictor = model.deploy(endpoint_name="Rinna-Inference")
+ EOF
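Deployment now uses the SageMaker JumpStart packaging of rinna/japanese-gpt-neox-3.6b-instruction-ppo instead of hand-packaging the aws-ml-jp inference code, which is consistent with the rinna.py handler change above. A minimal sketch of calling the resulting Rinna-Inference endpoint directly (the region is a placeholder; the response shape follows what the updated handler expects):

import json
import boto3

# Invoke the JumpStart-deployed endpoint with the same payload the app sends.
runtime = boto3.client("sagemaker-runtime", region_name="ap-northeast-1")
body = json.dumps({
    "inputs": "日本で一番高い山は？",
    "parameters": {"max_new_tokens": 256, "temperature": 0.3, "do_sample": True},
})
resp = runtime.invoke_endpoint(
    EndpointName="Rinna-Inference",
    ContentType="application/json",
    Body=body,
)
print(json.loads(resp["Body"].read())[0]["generated_text"])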
