fal.distributed package

Submodules

fal.distributed.utils module

class fal.distributed.utils.KeepAliveTimer(func, timeout, start=False, *args, **kwargs)

Bases: object

Call a function after a certain amount of time to keep the worker alive.

cancel()

Cancel the timer.

Return type:

None

reset()

Reset the timer.

Return type:

None

start()

Start the timer.

Return type:

None

timer: Optional[Timer]
fal.distributed.utils.distributed_deserialize(serialized)

Deserializes a JSON string to an object.

Parameters:

serialized (Union[bytes, str]) – The serialized JSON string.

Return type:

Any

Returns:

The deserialized object.

fal.distributed.utils.distributed_serialize(obj, is_final=False, image_format='jpeg')

Serializes an object to a JSON string.

Parameters:

obj (Any) – The object to serialize.

Return type:

bytes

Returns:

The serialized JSON string.

fal.distributed.utils.encode_text_event(obj, is_final=False, image_format='jpeg')

Encodes a text response as a JSON string.

Parameters:
  • obj (Any) – The text response to encode.

  • is_final (bool) – Whether this is the final response.

Return type:

bytes

Returns:

The encoded JSON string.
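
A minimal sketch of what such an encoder can look like. The field names (`"text"`, `"is_final"`) are assumptions for illustration; the actual wire format used by fal is not specified here:

```python
import json


def encode_text_event(obj, is_final=False):
    # Hypothetical sketch: wrap the text payload together with the
    # is_final flag and encode the result as UTF-8 JSON bytes.
    return json.dumps({"text": obj, "is_final": is_final}).encode("utf-8")
```
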

fal.distributed.utils.format_deserialized_data(data)

Formats the deserialized data for further processing.

Parameters:

data (Any) – The data to format.

Return type:

Any

Returns:

The formatted data.

fal.distributed.utils.format_for_serialization(response, image_format='jpeg', is_final=False, as_data_urls=False)

Formats the response for serialization. Most importantly, it encodes images to base64 and returns the image format and size.

Parameters:
  • response (Any) – The response to format.

  • is_final (bool) – Whether this is the final response.

Return type:

Any

Returns:

The formatted response.
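
The base64 step mentioned above can be sketched with the standard library. `encode_image_bytes` is a hypothetical helper name, not part of the fal API; it only illustrates base64 encoding and the data-URL form implied by the `as_data_urls` parameter:

```python
import base64


def encode_image_bytes(data, image_format="jpeg", as_data_url=False):
    # Encode raw image bytes to a base64 string; optionally wrap it
    # in a data URL carrying the image format.
    encoded = base64.b64encode(data).decode("ascii")
    if as_data_url:
        return f"data:image/{image_format};base64,{encoded}"
    return encoded
```
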

fal.distributed.utils.has_type_name(maybe_type, type_name)

Checks if the given object has a type name that matches the provided type name. This is used to avoid importing torch or other libraries unnecessarily.

Parameters:
  • maybe_type (Any) – The object to check.

  • type_name (str) – The type name to match against.

Return type:

bool

Returns:

True if the object’s type name matches, False otherwise.
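
One way such a check can be implemented without importing the heavy library is to compare class names as strings along the object's MRO. This is a sketch of the documented idea; whether fal matches against the module-qualified name or the bare class name is an assumption:

```python
def has_type_name(maybe_type, type_name):
    # Compare fully qualified class names ("module.QualName") along the
    # MRO as plain strings, so torch/numpy/PIL never need to be imported.
    return any(
        f"{t.__module__}.{t.__qualname__}" == type_name
        for t in type(maybe_type).__mro__
    )
```
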

fal.distributed.utils.is_numpy_array(obj)

Checks if the given object is a NumPy array without importing numpy.

Return type:

bool

fal.distributed.utils.is_pil_image(obj)

Checks if the given object is a PIL Image without importing PIL.

Return type:

bool

fal.distributed.utils.is_torch_tensor(obj)

Checks if the given object is a PyTorch tensor without importing torch.

Return type:

bool

fal.distributed.utils.launch_distributed_processes(func, world_size=1, master_addr='127.0.0.1', master_port=29500, timeout=1800, cwd=None, *args, **kwargs)

Launches a distributed process group using torch.multiprocessing.spawn. This function is designed to be called from the main process and will spawn multiple worker processes for distributed training or inference.

Parameters:
  • func (Callable) – The function to run in each worker process.

  • world_size (int) – The total number of processes to spawn.

  • master_addr (str) – The address of the master node.

  • master_port (int) – The port on which the master node will listen.

Return type:

mp.ProcessContext

Returns:

The process context for the spawned processes.

fal.distributed.utils.wrap_distributed_worker(rank, func, world_size, master_addr, master_port, timeout, cwd, args, kwargs)

Worker function for distributed training or inference.

This function is called by each worker process spawned by torch.multiprocessing.spawn.

Parameters:
  • func (Callable) – The function to run in each worker process.

  • world_size (int) – The total number of processes.

  • rank (int) – The rank of the current process.

  • master_addr (str) – The address of the master node.

  • master_port (int) – The port on which the master node will listen.

Return type:

None

fal.distributed.worker module

class fal.distributed.worker.DistributedRunner(worker_cls=<class 'fal.distributed.worker.DistributedWorker'>, world_size=1, master_addr='127.0.0.1', master_port=29500, worker_addr='127.0.0.1', worker_port=54923, timeout=86400, keepalive_payload={}, keepalive_interval=None, cwd=None, set_device=None)

Bases: object

A class to launch and manage distributed workers.

close_zmq_socket()

Closes the ZeroMQ socket.

Return type:

None

context: Optional[mp.ProcessContext]
ensure_alive()

Ensures that the distributed worker processes are alive. If the processes are not alive, it raises an error.

Return type:

None

gather_errors()

Gathers errors from the distributed worker processes.

This method should be called to collect any errors that occurred during execution.

Return type:

list[Exception]

Returns:

A list of exceptions raised by the worker processes.

get_zmq_socket()

Returns a ZeroMQ socket.

Return type:

Socket[Any]

Returns:

A ZeroMQ socket.

async invoke(payload={}, timeout=None)

Invokes the distributed worker with the given payload.

Parameters:
  • payload (dict[str, Any]) – The payload to send to the worker.

  • timeout (Optional[int]) – The timeout for the overall operation.

Return type:

Any

Returns:

The result from the worker.

is_alive()

Check if the distributed worker processes are alive.

Return type:

bool

Returns:

True if the distributed processes are alive, False otherwise.

keepalive(timeout=60.0)

Sends the keepalive payload to the worker.

Return type:

None

keepalive_timer: Optional[KeepAliveTimer]
maybe_cancel_keepalive()

Cancels the keepalive timer if it is set.

Return type:

None

maybe_reset_keepalive()

Resets the keepalive timer if it is set.

Return type:

None

maybe_start_keepalive()

Starts the keepalive timer if it is set.

Return type:

None

run(**kwargs)

The main function to run the distributed worker.

This function is called by each worker process spawned by torch.multiprocessing.spawn. This method must be synchronous.

Parameters:

kwargs (Any) – The arguments to pass to the worker.

Return type:

None

async start(timeout=1800, **kwargs)

Starts the distributed worker processes.

Parameters:

timeout (int) – The timeout for the distributed processes.

Return type:

None

async stop(timeout=10)

Stops the distributed worker processes.

Parameters:

timeout (int) – The timeout for the distributed processes to stop.

Return type:

None

async stream(payload={}, timeout=None, streaming_timeout=None, as_text_events=False)

Streams the result from the distributed worker.

Parameters:
  • payload (dict[str, Any]) – The payload to send to the worker.

  • timeout (Optional[int]) – The timeout for the overall operation.

  • streaming_timeout (Optional[int]) – The timeout in between streamed results.

  • as_text_events (bool) – Whether to yield results as text events.

Return type:

AsyncIterator[Any]

Returns:

An async iterator that yields the result from the worker.

terminate(timeout=10)

Terminates the distributed worker processes. This method should be called to clean up the worker processes.

Return type:

None

zmq_socket: Optional[Socket[Any]]
class fal.distributed.worker.DistributedWorker(rank=0, world_size=1)

Bases: object

A base class for distributed workers.

add_streaming_error(error)

Add an error to the queue.

Parameters:

error (Exception) – The error to add to the queue.

Return type:

None

add_streaming_result(result, image_format='jpeg', as_text_event=False)

Add a streaming result to the queue.

Parameters:

result (Any) – The result to add to the queue.

Return type:

None

property device: torch.device

The device for the current worker.

initialize(**kwargs)

Initialize the worker.

Return type:

None

loop: AbstractEventLoop
queue: Queue[bytes]
rank_print(message, debug=False)

Print a message with the rank of the current worker.

Parameters:
  • message (str) – The message to print.

  • debug (bool) – Whether to print the message as a debug message.

Return type:

None

run_in_worker(func, *args, **kwargs)

Run a function in the worker.

Return type:

Future[Any]

property running: bool

Whether the event loop is running.

setup(**kwargs)

Override this method to set up the worker. This method is called once per worker.

Return type:

None

shutdown(timeout=None)

Shut down the event loop.

Parameters:

timeout (Union[int, float, None]) – The timeout for the shutdown.

Return type:

None

submit(coro)

Submit a coroutine to the event loop.

Parameters:

coro (Coroutine[Any, Any, Any]) – The coroutine to submit to the event loop.

Return type:

Future[Any]

Returns:

A future that will resolve to the result of the coroutine.
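
Given the documented attributes (`loop: AbstractEventLoop`, `thread: Thread`) and the `Future[Any]` return type, `submit` appears to follow the standard event-loop-in-a-thread pattern built on `asyncio.run_coroutine_threadsafe`. The following is a self-contained sketch of that pattern; `LoopThread` is an illustrative name, not a fal class:

```python
import asyncio
import threading


class LoopThread:
    """Sketch: run an event loop in a background thread and submit coroutines to it."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread.start()

    def submit(self, coro):
        # Schedule the coroutine on the background loop and return a
        # concurrent.futures.Future that resolves to its result.
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

    def shutdown(self):
        # Stop the loop from the caller's thread and wait for the thread to exit.
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        self.loop.close()
```
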

teardown()

Override this method to tear down the worker. This method is called once per worker.

Return type:

None

thread: Thread

Module contents

class fal.distributed.DistributedRunner

Re-exported from fal.distributed.worker. See the DistributedRunner entry above for the full API.

class fal.distributed.DistributedWorker

Re-exported from fal.distributed.worker. See the DistributedWorker entry above for the full API.