Replicate Provider API Reference

The Replicate `Provider` class implements the `AIProvider` interface for the Replicate service. It provides methods for text generation and chat using models hosted on Replicate.

Class Definition

Bases: AIProvider

Replicate-specific implementation of the AIProvider abstract base class.

This class provides methods to interact with Replicate's AI models for text generation and chat functionality.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `client` | `ReplicateClientProtocol` | The Replicate client used for making API calls. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `api_key` | `str` | The API key for authenticating with Replicate. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `ImportError` | If the Replicate package is not installed. |

Examples:

Initialize the Replicate provider:

```python
provider = Provider(api_key="your-replicate-api-key")
```
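A fuller quick start, shown as a minimal sketch. The import path is inferred from the source file location (`clientai/replicate/provider.py`) and may differ in your installation:

```python
# Minimal sketch; the import path below is an assumption based on the
# source location clientai/replicate/provider.py.
from clientai.replicate.provider import Provider

provider = Provider(api_key="your-replicate-api-key")

# Generate text with a model hosted on Replicate (see generate_text below).
text = provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
)
print(text)
```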

Source code in `clientai/replicate/provider.py`:

````python
class Provider(AIProvider):
    """
    Replicate-specific implementation of the AIProvider abstract base class.

    This class provides methods to interact with Replicate's AI models for
    text generation and chat functionality.

    Attributes:
        client: The Replicate client used for making API calls.

    Args:
        api_key: The API key for authenticating with Replicate.

    Raises:
        ImportError: If the Replicate package is not installed.

    Examples:
        Initialize the Replicate provider:
        ```python
        provider = Provider(api_key="your-replicate-api-key")
        ```
    """

    def __init__(self, api_key: str):
        if not REPLICATE_INSTALLED or Client is None:
            raise ImportError(
                "The replicate package is not installed. "
                "Please install it with 'pip install clientai[replicate]'."
            )
        self.client: ReplicateClientProtocol = Client(api_token=api_key)

    def _process_output(self, output: Any) -> str:
        """
        Process the output from Replicate API into a string format.

        Args:
            output: The raw output from Replicate API.

        Returns:
            str: The processed output as a string.
        """
        if isinstance(output, List):
            return "".join(str(item) for item in output)
        elif isinstance(output, str):
            return output
        else:
            return str(output)

    def _wait_for_prediction(
        self, prediction_id: str, max_wait_time: int = 300
    ) -> ReplicatePredictionProtocol:
        """
        Wait for a prediction to complete or fail.

        Args:
            prediction_id: The ID of the prediction to wait for.
            max_wait_time: Maximum time to wait in seconds. Defaults to 300.

        Returns:
            ReplicatePredictionProtocol: The completed prediction.

        Raises:
            TimeoutError: If the prediction doesn't complete within
                          the max_wait_time.
            APIError: If the prediction fails.
        """
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            prediction = self.client.predictions.get(prediction_id)
            if prediction.status == "succeeded":
                return prediction
            elif prediction.status == "failed":
                raise self._map_exception_to_clientai_error(
                    Exception(f"Prediction failed: {prediction.error}")
                )
            time.sleep(1)

        raise self._map_exception_to_clientai_error(
            Exception("Prediction timed out"), status_code=408
        )

    def _stream_response(
        self,
        prediction: ReplicatePredictionProtocol,
        return_full_response: bool,
    ) -> Iterator[Union[str, ReplicateStreamResponse]]:
        """
        Stream the response from a prediction.

        Args:
            prediction: The prediction to stream.
            return_full_response: If True, yield full response objects.

        Yields:
            Union[str, ReplicateStreamResponse]: Processed output or
                                                 full response objects.
        """
        metadata = cast(ReplicateStreamResponse, prediction.__dict__.copy())
        for event in prediction.stream():
            if return_full_response:
                metadata["output"] = self._process_output(event)
                yield metadata
            else:
                yield self._process_output(event)

    def _map_exception_to_clientai_error(
        self, e: Exception, status_code: Optional[int] = None
    ) -> ClientAIError:
        """
        Maps a Replicate exception to the appropriate ClientAI exception.

        Args:
            e (Exception): The exception caught during the API call.
            status_code (int, optional): The HTTP status code, if available.

        Returns:
            ClientAIError: An instance of the appropriate ClientAI exception.
        """
        error_message = str(e)
        status_code = status_code or getattr(e, "status_code", None)

        if (
            "authentication" in error_message.lower()
            or "unauthorized" in error_message.lower()
        ):
            return AuthenticationError(
                error_message, status_code, original_error=e
            )
        elif "rate limit" in error_message.lower():
            return RateLimitError(error_message, status_code, original_error=e)
        elif "not found" in error_message.lower():
            return ModelError(error_message, status_code, original_error=e)
        elif "invalid" in error_message.lower():
            return InvalidRequestError(
                error_message, status_code, original_error=e
            )
        elif "timeout" in error_message.lower() or status_code == 408:
            return TimeoutError(error_message, status_code, original_error=e)
        elif status_code == 400:
            return InvalidRequestError(
                error_message, status_code, original_error=e
            )
        else:
            return APIError(error_message, status_code, original_error=e)

    def generate_text(
        self,
        prompt: str,
        model: str,
        return_full_response: bool = False,
        stream: bool = False,
        **kwargs: Any,
    ) -> ReplicateGenericResponse:
        """
        Generate text based on a given prompt
        using a specified Replicate model.

        Args:
            prompt: The input prompt for text generation.
            model: The name or identifier of the Replicate model to use.
            return_full_response: If True, return the full response object.
                If False, return only the generated text. Defaults to False.
            stream: If True, return an iterator for streaming responses.
                Defaults to False.
            **kwargs: Additional keyword arguments
                      to pass to the Replicate API.

        Returns:
            ReplicateGenericResponse: The generated text, full response object,
            or an iterator for streaming responses.

        Examples:
            Generate text (text only):
            ```python
            response = provider.generate_text(
                "Explain quantum computing",
                model="meta/llama-2-70b-chat:latest",
            )
            print(response)
            ```

            Generate text (full response):
            ```python
            response = provider.generate_text(
                "Explain quantum computing",
                model="meta/llama-2-70b-chat:latest",
                return_full_response=True
            )
            print(response["output"])
            ```

            Generate text (streaming):
            ```python
            for chunk in provider.generate_text(
                "Explain quantum computing",
                model="meta/llama-2-70b-chat:latest",
                stream=True
            ):
                print(chunk, end="", flush=True)
            ```
        """
        try:
            prediction = self.client.predictions.create(
                model=model, input={"prompt": prompt}, stream=stream, **kwargs
            )

            if stream:
                return self._stream_response(prediction, return_full_response)
            else:
                completed_prediction = self._wait_for_prediction(prediction.id)
                if return_full_response:
                    response = cast(
                        ReplicateResponse, completed_prediction.__dict__.copy()
                    )
                    response["output"] = self._process_output(
                        completed_prediction.output
                    )
                    return response
                else:
                    return self._process_output(completed_prediction.output)

        except Exception as e:
            raise self._map_exception_to_clientai_error(e)

    def chat(
        self,
        messages: List[Message],
        model: str,
        return_full_response: bool = False,
        stream: bool = False,
        **kwargs: Any,
    ) -> ReplicateGenericResponse:
        """
        Engage in a chat conversation using a specified Replicate model.

        Args:
            messages: A list of message dictionaries, each containing
                      'role' and 'content'.
            model: The name or identifier of the Replicate model to use.
            return_full_response: If True, return the full response object.
                If False, return only the generated text. Defaults to False.
            stream: If True, return an iterator for streaming responses.
                Defaults to False.
            **kwargs: Additional keyword arguments
                      to pass to the Replicate API.

        Returns:
            ReplicateGenericResponse: The chat response, full response object,
            or an iterator for streaming responses.

        Examples:
            Chat (message content only):
            ```python
            messages = [
                {"role": "user", "content": "What is the capital of France?"},
                {"role": "assistant", "content": "The capital is Paris."},
                {"role": "user", "content": "What is its population?"}
            ]
            response = provider.chat(
                messages,
                model="meta/llama-2-70b-chat:latest",
            )
            print(response)
            ```

            Chat (full response):
            ```python
            response = provider.chat(
                messages,
                model="meta/llama-2-70b-chat:latest",
                return_full_response=True
            )
            print(response["output"])
            ```

            Chat (streaming):
            ```python
            for chunk in provider.chat(
                messages,
                model="meta/llama-2-70b-chat:latest",
                stream=True
            ):
                print(chunk, end="", flush=True)
            ```
        """
        try:
            prompt = "\n".join(
                [f"{m['role']}: {m['content']}" for m in messages]
            )
            prompt += "\nassistant: "

            prediction = self.client.predictions.create(
                model=model, input={"prompt": prompt}, stream=stream, **kwargs
            )

            if stream:
                return self._stream_response(prediction, return_full_response)
            else:
                completed_prediction = self._wait_for_prediction(prediction.id)
                if return_full_response:
                    response = cast(
                        ReplicateResponse, completed_prediction.__dict__.copy()
                    )
                    response["output"] = self._process_output(
                        completed_prediction.output
                    )
                    return response
                else:
                    return self._process_output(completed_prediction.output)

        except Exception as e:
            raise self._map_exception_to_clientai_error(e)
````
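Every failure path in the source above goes through `_map_exception_to_clientai_error`, so callers can handle provider failures uniformly through the ClientAI error hierarchy. A minimal sketch; the module path `clientai.exceptions` is an assumption:

```python
# Sketch: clientai.exceptions is an assumed module path for the error
# classes the provider raises (see _map_exception_to_clientai_error).
from clientai.exceptions import ClientAIError, RateLimitError, TimeoutError

try:
    text = provider.generate_text(
        "Explain quantum computing",
        model="meta/llama-2-70b-chat:latest",
    )
except RateLimitError:
    pass  # back off and retry later
except TimeoutError:
    pass  # prediction exceeded max_wait_time (default 300 seconds)
except ClientAIError:
    raise  # any other mapped provider error (APIError, ModelError, ...)
```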

`chat(messages, model, return_full_response=False, stream=False, **kwargs)`

Engage in a chat conversation using a specified Replicate model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `messages` | `List[Message]` | A list of message dictionaries, each containing 'role' and 'content'. | *required* |
| `model` | `str` | The name or identifier of the Replicate model to use. | *required* |
| `return_full_response` | `bool` | If True, return the full response object. If False, return only the generated text. | `False` |
| `stream` | `bool` | If True, return an iterator for streaming responses. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the Replicate API. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `ReplicateGenericResponse` | The chat response, full response object, or an iterator for streaming responses. |

Examples:

Chat (message content only):

```python
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital is Paris."},
    {"role": "user", "content": "What is its population?"}
]
response = provider.chat(
    messages,
    model="meta/llama-2-70b-chat:latest",
)
print(response)
```

Chat (full response):

```python
response = provider.chat(
    messages,
    model="meta/llama-2-70b-chat:latest",
    return_full_response=True
)
print(response["output"])
```

Chat (streaming):

```python
for chunk in provider.chat(
    messages,
    model="meta/llama-2-70b-chat:latest",
    stream=True
):
    print(chunk, end="", flush=True)
```
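Note that `chat` does not call a structured chat endpoint: as the source below shows, it flattens the message list into a single prompt of `role: content` lines with a trailing `assistant: ` cue. The same transformation in isolation:

```python
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital is Paris."},
    {"role": "user", "content": "What is its population?"},
]

# The exact prompt string chat() sends to the Replicate API:
prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
prompt += "\nassistant: "
print(prompt)
# user: What is the capital of France?
# assistant: The capital is Paris.
# user: What is its population?
# assistant:
```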

Source code in `clientai/replicate/provider.py`:

````python
def chat(
    self,
    messages: List[Message],
    model: str,
    return_full_response: bool = False,
    stream: bool = False,
    **kwargs: Any,
) -> ReplicateGenericResponse:
    """
    Engage in a chat conversation using a specified Replicate model.

    Args:
        messages: A list of message dictionaries, each containing
                  'role' and 'content'.
        model: The name or identifier of the Replicate model to use.
        return_full_response: If True, return the full response object.
            If False, return only the generated text. Defaults to False.
        stream: If True, return an iterator for streaming responses.
            Defaults to False.
        **kwargs: Additional keyword arguments
                  to pass to the Replicate API.

    Returns:
        ReplicateGenericResponse: The chat response, full response object,
        or an iterator for streaming responses.

    Examples:
        Chat (message content only):
        ```python
        messages = [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital is Paris."},
            {"role": "user", "content": "What is its population?"}
        ]
        response = provider.chat(
            messages,
            model="meta/llama-2-70b-chat:latest",
        )
        print(response)
        ```

        Chat (full response):
        ```python
        response = provider.chat(
            messages,
            model="meta/llama-2-70b-chat:latest",
            return_full_response=True
        )
        print(response["output"])
        ```

        Chat (streaming):
        ```python
        for chunk in provider.chat(
            messages,
            model="meta/llama-2-70b-chat:latest",
            stream=True
        ):
            print(chunk, end="", flush=True)
        ```
    """
    try:
        prompt = "\n".join(
            [f"{m['role']}: {m['content']}" for m in messages]
        )
        prompt += "\nassistant: "

        prediction = self.client.predictions.create(
            model=model, input={"prompt": prompt}, stream=stream, **kwargs
        )

        if stream:
            return self._stream_response(prediction, return_full_response)
        else:
            completed_prediction = self._wait_for_prediction(prediction.id)
            if return_full_response:
                response = cast(
                    ReplicateResponse, completed_prediction.__dict__.copy()
                )
                response["output"] = self._process_output(
                    completed_prediction.output
                )
                return response
            else:
                return self._process_output(completed_prediction.output)

    except Exception as e:
        raise self._map_exception_to_clientai_error(e)
````
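`chat` is stateless, so multi-turn conversations are maintained by the caller: append each reply to the message list before the next call. A minimal sketch using non-streaming, content-only responses:

```python
# With return_full_response=False and stream=False, chat() returns a str.
messages = [{"role": "user", "content": "What is the capital of France?"}]

reply = provider.chat(messages, model="meta/llama-2-70b-chat:latest")
messages.append({"role": "assistant", "content": reply})

messages.append({"role": "user", "content": "What is its population?"})
print(provider.chat(messages, model="meta/llama-2-70b-chat:latest"))
```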

`generate_text(prompt, model, return_full_response=False, stream=False, **kwargs)`

Generate text based on a given prompt using a specified Replicate model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompt` | `str` | The input prompt for text generation. | *required* |
| `model` | `str` | The name or identifier of the Replicate model to use. | *required* |
| `return_full_response` | `bool` | If True, return the full response object. If False, return only the generated text. | `False` |
| `stream` | `bool` | If True, return an iterator for streaming responses. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the Replicate API. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `ReplicateGenericResponse` | The generated text, full response object, or an iterator for streaming responses. |

Examples:

Generate text (text only):

```python
response = provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
)
print(response)
```

Generate text (full response):

```python
response = provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
    return_full_response=True
)
print(response["output"])
```

Generate text (streaming):

```python
for chunk in provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
    stream=True
):
    print(chunk, end="", flush=True)
```
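With `stream=True` and `return_full_response=False`, each chunk is a plain string, so the full text can be reassembled with a simple join:

```python
# Each streamed chunk is already a processed string in this mode.
chunks = provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
    stream=True,
)
full_text = "".join(chunks)
```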

Source code in `clientai/replicate/provider.py`:

````python
def generate_text(
    self,
    prompt: str,
    model: str,
    return_full_response: bool = False,
    stream: bool = False,
    **kwargs: Any,
) -> ReplicateGenericResponse:
    """
    Generate text based on a given prompt
    using a specified Replicate model.

    Args:
        prompt: The input prompt for text generation.
        model: The name or identifier of the Replicate model to use.
        return_full_response: If True, return the full response object.
            If False, return only the generated text. Defaults to False.
        stream: If True, return an iterator for streaming responses.
            Defaults to False.
        **kwargs: Additional keyword arguments
                  to pass to the Replicate API.

    Returns:
        ReplicateGenericResponse: The generated text, full response object,
        or an iterator for streaming responses.

    Examples:
        Generate text (text only):
        ```python
        response = provider.generate_text(
            "Explain quantum computing",
            model="meta/llama-2-70b-chat:latest",
        )
        print(response)
        ```

        Generate text (full response):
        ```python
        response = provider.generate_text(
            "Explain quantum computing",
            model="meta/llama-2-70b-chat:latest",
            return_full_response=True
        )
        print(response["output"])
        ```

        Generate text (streaming):
        ```python
        for chunk in provider.generate_text(
            "Explain quantum computing",
            model="meta/llama-2-70b-chat:latest",
            stream=True
        ):
            print(chunk, end="", flush=True)
        ```
    """
    try:
        prediction = self.client.predictions.create(
            model=model, input={"prompt": prompt}, stream=stream, **kwargs
        )

        if stream:
            return self._stream_response(prediction, return_full_response)
        else:
            completed_prediction = self._wait_for_prediction(prediction.id)
            if return_full_response:
                response = cast(
                    ReplicateResponse, completed_prediction.__dict__.copy()
                )
                response["output"] = self._process_output(
                    completed_prediction.output
                )
                return response
            else:
                return self._process_output(completed_prediction.output)

    except Exception as e:
        raise self._map_exception_to_clientai_error(e)
````
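When `return_full_response=True`, the result is a copy of the completed prediction's `__dict__` with `"output"` replaced by the processed string, so prediction metadata rides along. A sketch; the `"id"` and `"status"` keys are assumed to carry over from Replicate's prediction object, as suggested by their use in `_wait_for_prediction`:

```python
response = provider.generate_text(
    "Explain quantum computing",
    model="meta/llama-2-70b-chat:latest",
    return_full_response=True,
)

print(response["output"])  # processed text, set by the provider
# Assumed metadata fields copied over from the prediction's __dict__:
print(response.get("id"), response.get("status"))
```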