2020TASK_SUCCESS_STATE = "SUCCESS"
2121TASK_FAILURE_STATE = "FAILURE"
2222
23+ # Echoes fields in EndpointResponse class
24+ ALLOWED_ENDPOINT_RESPONSE_FIELDS = {"status" , "result_url" , "result" , "traceback" , "status_code" }
25+
2326
2427@dataclass_json (undefined = Undefined .EXCLUDE )
2528@dataclass
@@ -189,6 +192,7 @@ def __init__(
189192result_url : Optional [str ] = None ,
190193result : Optional [str ] = None ,
191194traceback : Optional [str ] = None ,
195+ status_code : Optional [int ] = None ,
192196 ):
193197"""
194198 Parameters:
@@ -210,12 +214,15 @@ def __init__(
210214
211215 traceback: The stack trace if the inference endpoint raised an error. Can be used for debugging
212216
217+ status_code: The underlying status code of the response, given from the inference endpoint itself.
218+
213219 """
214220self .client = client
215221self .status = status
216222self .result_url = result_url
217223self .result = result
218224self .traceback = traceback
225+ self .status_code = status_code
219226
220227def __str__ (self ) -> str :
221228return (
@@ -271,6 +278,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
271278result_url = async_response .get ("result" ,{}).get ("result_url" , None ),
272279result = async_response .get ("result" ,{}).get ("result" , None ),
273280traceback = None ,
281+ status_code = async_response .get ("status_code" , None ),
274282 )
275283elif status == "FAILURE" :
276284return EndpointResponse (
@@ -279,6 +287,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
279287result_url = None ,
280288result = None ,
281289traceback = async_response .get ("traceback" , None ),
290+ status_code = async_response .get ("status_code" , None ),
282291 )
283292else :
284293raise ValueError (f"Unrecognized status: { async_response ['status' ]} " )
@@ -312,6 +321,7 @@ def __next__(self):
312321result_url = result .get ("result_url" , None ),
313322result = result .get ("result" , None ),
314323traceback = data .get ("traceback" ),
324+ status_code = data .get ("status_code" , None ),
315325 )
316326
317327
@@ -397,7 +407,10 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
397407args = request .args ,
398408return_pickled = request .return_pickled ,
399409 )
400- raw_response = {k : v for k , v in raw_response .items () if v is not None }
410+
411+ raw_response = {
412+ k : v for k , v in raw_response .items () if v is not None and k in ALLOWED_ENDPOINT_RESPONSE_FIELDS
413+ }
401414return EndpointResponse (client = self .client , ** raw_response )
402415
403416
@@ -632,6 +645,7 @@ def single_request(inner_url, inner_task_id):
632645result_url = raw_response .get ("result_url" , None ),
633646result = raw_response .get ("result" , None ),
634647traceback = raw_response .get ("traceback" , None ),
648+ status_code = raw_response .get ("status_code" , None ),
635649 )
636650self .responses [url ] = response_object
637651
0 commit comments