-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhandler.py
More file actions
50 lines (39 loc) · 1.46 KB
/
handler.py
File metadata and controls
50 lines (39 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import re
import tempfile

import boto3
import torch
from botocore import UNSIGNED
from botocore.client import Config

from model import IrisNet
# Iris species names; a prediction is the entry at the index of the model's
# highest-scoring output.
labels = [
    "setosa",
    "versicolor",
    "virginica",
]
class Handler:
    """Serve Iris-species predictions from an ``IrisNet`` model stored in S3.

    The model's saved ``state_dict`` is fetched once when the handler is
    constructed; each call to ``handle_post`` classifies a batch of samples.
    """

    # Matches "s3://<bucket>/<key>": the bucket name cannot contain "/", the
    # key may (e.g. "s3://my-bucket/models/iris.pth").
    _S3_URI = re.compile(r"s3://([^/]+)/(.+)")

    # Request fields, in the column order fed to the model.
    _FEATURES = ("sepal_length", "sepal_width", "petal_length", "petal_width")

    def __init__(self, config):
        """Download the model weights from S3 and prepare the model for inference.

        Args:
            config: Mapping whose ``"model"`` entry is an ``s3://bucket/key``
                URI pointing at a saved ``IrisNet`` state dict.

        Raises:
            KeyError: if ``config`` has no ``"model"`` entry.
            ValueError: if ``config["model"]`` is not an ``s3://bucket/key`` URI.
        """
        bucket, key = self._parse_s3_uri(config["model"])

        # Unsigned (anonymous) requests: no AWS credentials are used, so the
        # object has to allow anonymous reads.
        s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))

        # Download into a private temporary directory instead of a fixed
        # "/tmp/model.pth": concurrent workers on the same host cannot clobber
        # each other's file, and the download is deleted once it is loaded.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pth")
            s3.download_file(bucket, key, model_path)
            # map_location="cpu" lets weights saved from a GPU load on a
            # CPU-only host; load_state_dict copies them into the CPU model.
            # NOTE(review): torch.load unpickles the file, so a tampered object
            # in the bucket could execute code. On torch >= 1.13 consider
            # passing weights_only=True.
            state_dict = torch.load(model_path, map_location="cpu")

        model = IrisNet()
        model.load_state_dict(state_dict)
        # Switch dropout / batch-norm layers (if IrisNet has any) to inference behaviour.
        model.eval()
        self.model = model

    @classmethod
    def _parse_s3_uri(cls, uri):
        """Split ``s3://bucket/key`` into a ``(bucket, key)`` tuple.

        Raises:
            ValueError: if *uri* is not of that form.
        """
        match = cls._S3_URI.fullmatch(uri)
        if match is None:
            raise ValueError(f"expected an s3://bucket/key URI, got {uri!r}")
        return match.group(1), match.group(2)

    def handle_post(self, payload):
        """Classify a batch of Iris samples.

        Args:
            payload: Iterable of mappings, each holding numeric
                ``sepal_length``, ``sepal_width``, ``petal_length`` and
                ``petal_width`` entries.

        Returns:
            List of label strings taken from ``labels``, in the same order as
            ``payload``; an empty list when ``payload`` is empty.

        Raises:
            KeyError: if a sample lacks one of the four feature fields.
        """
        # Materialise first so generator payloads are handled like lists.
        rows = [[sample[name] for name in self._FEATURES] for sample in payload]
        if not rows:
            return []

        # One forward pass over the whole batch instead of one per sample;
        # no_grad skips building an autograd graph that inference never uses.
        # NOTE(review): assumes IrisNet maps (N, 4) inputs to (N, len(labels))
        # logits row by row, as the previous per-sample (1, 4) calls implied.
        with torch.no_grad():
            output = self.model(torch.tensor(rows, dtype=torch.float32))

        return [labels[index] for index in torch.argmax(output, dim=1).tolist()]