Skip to content

Commit 9e1ef50

Browse files
Add status code to endpoint response (#172)
* add status code * filter out extraneous fields explicitly * black
1 parent 4f468aa commit 9e1ef50

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

‎launch/model_endpoint.py‎

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
TASK_SUCCESS_STATE="SUCCESS"
2121
TASK_FAILURE_STATE="FAILURE"
2222

23+
# Echoes fields in EndpointResponse class
24+
ALLOWED_ENDPOINT_RESPONSE_FIELDS={"status", "result_url", "result", "traceback", "status_code"}
25+
2326

2427
@dataclass_json(undefined=Undefined.EXCLUDE)
2528
@dataclass
@@ -189,6 +192,7 @@ def __init__(
189192
result_url: Optional[str] =None,
190193
result: Optional[str] =None,
191194
traceback: Optional[str] =None,
195+
status_code: Optional[int] =None,
192196
):
193197
"""
194198
Parameters:
@@ -210,12 +214,15 @@ def __init__(
210214
211215
traceback: The stack trace if the inference endpoint raised an error. Can be used for debugging
212216
217+
status_code: The underlying status code of the response, given from the inference endpoint itself.
218+
213219
"""
214220
self.client=client
215221
self.status=status
216222
self.result_url=result_url
217223
self.result=result
218224
self.traceback=traceback
225+
self.status_code=status_code
219226

220227
def__str__(self) ->str:
221228
return (
@@ -271,6 +278,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
271278
result_url=async_response.get("result",{}).get("result_url", None),
272279
result=async_response.get("result",{}).get("result", None),
273280
traceback=None,
281+
status_code=async_response.get("status_code", None),
274282
)
275283
elifstatus=="FAILURE":
276284
returnEndpointResponse(
@@ -279,6 +287,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
279287
result_url=None,
280288
result=None,
281289
traceback=async_response.get("traceback", None),
290+
status_code=async_response.get("status_code", None),
282291
)
283292
else:
284293
raiseValueError(f"Unrecognized status: {async_response['status']}")
@@ -312,6 +321,7 @@ def __next__(self):
312321
result_url=result.get("result_url", None),
313322
result=result.get("result", None),
314323
traceback=data.get("traceback"),
324+
status_code=data.get("status_code", None),
315325
)
316326

317327

@@ -397,7 +407,10 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
397407
args=request.args,
398408
return_pickled=request.return_pickled,
399409
)
400-
raw_response={k: vfork, vinraw_response.items() ifvisnotNone}
410+
411+
raw_response={
412+
k: vfork, vinraw_response.items() ifvisnotNoneandkinALLOWED_ENDPOINT_RESPONSE_FIELDS
413+
}
401414
returnEndpointResponse(client=self.client, **raw_response)
402415

403416

@@ -632,6 +645,7 @@ def single_request(inner_url, inner_task_id):
632645
result_url=raw_response.get("result_url", None),
633646
result=raw_response.get("result", None),
634647
traceback=raw_response.get("traceback", None),
648+
status_code=raw_response.get("status_code", None),
635649
)
636650
self.responses[url] =response_object
637651

0 commit comments

Comments
(0)