State and Dependencies
State
The TaskiqState is a global variable where you can keep values you want to use later. For example, you may want to open a database connection pool at the broker's startup. This can be achieved by adding event handlers.
You can use one of these events:
- WORKER_STARTUP
- CLIENT_STARTUP
- WORKER_SHUTDOWN
- CLIENT_SHUTDOWN
Worker events are called when you start listening to broker messages using taskiq. Client events are called when you call the startup method of your broker from your own code.
This is an example of code using event handlers (shown here with the `Annotated` syntax for dependencies):

```python
import asyncio
from typing import Annotated, Optional

from redis.asyncio import ConnectionPool, Redis  # type: ignore
from taskiq_aio_pika import AioPikaBroker
from taskiq_redis import RedisAsyncResultBackend

from taskiq import Context, TaskiqDepends, TaskiqEvents, TaskiqState

# To run this example, please install:
# * taskiq
# * taskiq-redis
# * taskiq-aio-pika

broker = AioPikaBroker(
    "amqp://localhost",
).with_result_backend(RedisAsyncResultBackend("redis://localhost"))


@broker.on_event(TaskiqEvents.WORKER_STARTUP)
async def startup(state: TaskiqState) -> None:
    # Here we store the connection pool on startup for later use.
    state.redis = ConnectionPool.from_url("redis://localhost/1")


@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def shutdown(state: TaskiqState) -> None:
    # Here we close our pool on the shutdown event.
    await state.redis.disconnect()


@broker.task
async def get_val(
    key: str,
    context: Annotated[Context, TaskiqDepends()],
) -> Optional[str]:
    # Now we can use our pool.
    redis = Redis(connection_pool=context.state.redis, decode_responses=True)
    return await redis.get(key)


@broker.task
async def set_val(
    key: str,
    value: str,
    context: Annotated[Context, TaskiqDepends()],
) -> None:
    # Now we can use our pool to set a value.
    await Redis(connection_pool=context.state.redis).set(key, value)


async def main() -> None:
    await broker.startup()
    set_task = await set_val.kiq("key", "value")
    set_result = await set_task.wait_result(with_logs=True)
    if set_result.is_err:
        print(set_result.log)
        raise ValueError("Cannot set value in redis. See logs.")
    get_task = await get_val.kiq("key")
    get_res = await get_task.wait_result()
    print(f"Got redis value: {get_res.return_value}")
    await broker.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```
The same example can be written with default-value syntax instead of `Annotated`; only the task signatures differ:

```python
@broker.task
async def get_val(key: str, context: Context = TaskiqDepends()) -> Optional[str]:
    # Now we can use our pool.
    redis = Redis(connection_pool=context.state.redis, decode_responses=True)
    return await redis.get(key)


@broker.task
async def set_val(key: str, value: str, context: Context = TaskiqDepends()) -> None:
    # Now we can use our pool to set a value.
    await Redis(connection_pool=context.state.redis).set(key, value)
```
Cool tip!
If you want to add handlers programmatically, you can use the `broker.add_event_handler` function.
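For instance, the `startup` handler from the example above could be registered without the decorator. This is a minimal sketch, assuming the `(event, handler)` argument order:

```python
# Register the same startup handler without using the decorator form.
broker.add_event_handler(TaskiqEvents.WORKER_STARTUP, startup)
```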
As you can see in this example, the worker initializes the Redis connection pool at startup, and tasks access the state through the context.
Dependencies
Using the context directly works, but you won't get code completion that way. That's why we suggest you try taskiq dependencies. The implementation is very similar to FastAPI's dependencies: you can use classes, functions, and generators as dependencies.
We use the taskiq-dependencies package to provide this mechanism, and you can easily integrate it into your own project.
How dependencies are useful
Dependencies give you better autocompletion and reduce the amount of code you write. Since the state is generic, we cannot guess the types of its fields; dependencies, in contrast, are annotated with type hints and can therefore be autocompleted.
Let's assume that you've stored a Redis connection pool in the state as in the example above.
```python
@broker.on_event(TaskiqEvents.WORKER_STARTUP)
async def startup(state: TaskiqState) -> None:
    # Here we store the connection pool on startup for later use.
    state.redis = ConnectionPool.from_url("redis://localhost/1")
```
You can access this variable through the current execution context directly, like this:

Annotated syntax:

```python
from typing import Annotated


@broker.task
async def my_task(context: Annotated[Context, TaskiqDepends()]) -> None:
    async with Redis(connection_pool=context.state.redis, decode_responses=True) as redis:
        await redis.set('key', 'value')
```

Default-value syntax:

```python
@broker.task
async def my_task(context: Context = TaskiqDepends()) -> None:
    async with Redis(connection_pool=context.state.redis, decode_responses=True) as redis:
        await redis.set('key', 'value')
```
If you hit TAB after typing `context.state.`, your IDE won't give you any auto-completion. But we can create a dependency function to add it:
Annotated syntax:

```python
from typing import Annotated


def redis_dep(context: Annotated[Context, TaskiqDepends()]) -> Redis:
    return Redis(connection_pool=context.state.redis, decode_responses=True)


@broker.task
async def my_task(redis: Annotated[Redis, TaskiqDepends(redis_dep)]) -> None:
    await redis.set('key', 'value')
```

Default-value syntax:

```python
def redis_dep(context: Context = TaskiqDepends()) -> Redis:
    return Redis(connection_pool=context.state.redis, decode_responses=True)


@broker.task
async def my_task(redis: Redis = TaskiqDepends(redis_dep)) -> None:
    await redis.set('key', 'value')
```
Now the injected `redis` object is autocompleted. Of course, state fields themselves still cannot be autocompleted, even inside dependencies, but the untyped access is confined to one place, so you won't make typos while writing tasks.
How do dependencies work
We build a graph of dependencies at startup. If a function parameter has a `TaskiqDepends` default value (or an `Annotated` hint containing one), that parameter is treated as a dependency. Dependencies can depend on other dependencies, and their evaluation is cached so that shared dependencies are not evaluated multiple times.
For example:
Annotated syntax:

```python
import random
from typing import Annotated

from taskiq import TaskiqDepends


def common_dep() -> int:
    # For example, it returns 8.
    return random.randint(1, 10)


def dep1(cd: Annotated[int, TaskiqDepends(common_dep)]) -> int:
    # This function will return 9.
    return cd + 1


def dep2(cd: Annotated[int, TaskiqDepends(common_dep)]) -> int:
    # This function will return 10.
    return cd + 2


def my_task(
    d1: Annotated[int, TaskiqDepends(dep1)],
    d2: Annotated[int, TaskiqDepends(dep2)],
) -> int:
    # This function will return 19.
    return d1 + d2
```

Default-value syntax:

```python
import random

from taskiq import TaskiqDepends


def common_dep() -> int:
    # For example, it returns 8.
    return random.randint(1, 10)


def dep1(cd: int = TaskiqDepends(common_dep)) -> int:
    # This function will return 9.
    return cd + 1


def dep2(cd: int = TaskiqDepends(common_dep)) -> int:
    # This function will return 10.
    return cd + 2


def my_task(
    d1: int = TaskiqDepends(dep1),
    d2: int = TaskiqDepends(dep2),
) -> int:
    # This function will return 19.
    return d1 + d2
```
In this code, the dependency `common_dep` is evaluated only once, and `dep1` and `dep2` receive the same value. You can control this behavior by passing the `use_cache=False` parameter to your dependency; it forces the dependency to reevaluate all of its subdependencies. In the next example we cannot predict the result, since `dep2` doesn't use the cached value of `common_dep`:
Annotated syntax:

```python
import random
from typing import Annotated

from taskiq import TaskiqDepends


def common_dep() -> int:
    return random.randint(1, 10)


def dep1(cd: Annotated[int, TaskiqDepends(common_dep)]) -> int:
    return cd + 1


def dep2(cd: Annotated[int, TaskiqDepends(common_dep, use_cache=False)]) -> int:
    return cd + 2


def my_task(
    d1: Annotated[int, TaskiqDepends(dep1)],
    d2: Annotated[int, TaskiqDepends(dep2)],
) -> int:
    return d1 + d2
```

Default-value syntax:

```python
import random

from taskiq import TaskiqDepends


def common_dep() -> int:
    return random.randint(1, 10)


def dep1(cd: int = TaskiqDepends(common_dep)) -> int:
    return cd + 1


def dep2(cd: int = TaskiqDepends(common_dep, use_cache=False)) -> int:
    return cd + 2


def my_task(
    d1: int = TaskiqDepends(dep1),
    d2: int = TaskiqDepends(dep2),
) -> int:
    return d1 + d2
```
When caching is used, the dependency graph for `my_task` contains a single `common_dep` node shared by `dep1` and `dep2`. When `dep2` doesn't use the cached value of `common_dep`, the graph contains two separate `common_dep` nodes, one evaluated for each branch.
Class as a dependency
You can use classes as dependencies, and those classes can use other dependencies too.
Let's see an example:
Annotated syntax:

```python
from typing import Annotated

from taskiq import TaskiqDepends


async def db_connection() -> str:
    return "let's pretend this is a connection"


class MyDAO:
    def __init__(self, db_conn: Annotated[str, TaskiqDepends(db_connection)]) -> None:
        self.db_conn = db_conn

    def get_users(self) -> str:
        return self.db_conn.upper()


def my_task(dao: Annotated[MyDAO, TaskiqDepends()]) -> None:
    print(dao.get_users())
```

Default-value syntax:

```python
from taskiq import TaskiqDepends


async def db_connection() -> str:
    return "let's pretend this is a connection"


class MyDAO:
    def __init__(self, db_conn: str = TaskiqDepends(db_connection)) -> None:
        self.db_conn = db_conn

    def get_users(self) -> str:
        return self.db_conn.upper()


def my_task(dao: MyDAO = TaskiqDepends()) -> None:
    print(dao.get_users())
```
As you can see, the dependency for the `my_task` function is declared with a bare `TaskiqDepends()`. You can omit the dependency class there because it's already declared in the parameter's type hint. This shortcut works only for classes, not for dependency functions.
The dependencies of the class itself are declared in its constructor.
Generator dependencies
Generator dependencies are used to perform setup before task execution and teardown after it.
Annotated syntax:

```python
from typing import Annotated, Generator

from taskiq import TaskiqDepends


def dependency() -> Generator[str, None, None]:
    print("Startup")

    yield "value"

    print("Shutdown")


async def my_task(dep: Annotated[str, TaskiqDepends(dependency)]) -> None:
    print(dep.upper())
```

Default-value syntax:

```python
from typing import Generator

from taskiq import TaskiqDepends


def dependency() -> Generator[str, None, None]:
    print("Startup")

    yield "value"

    print("Shutdown")


async def my_task(dep: str = TaskiqDepends(dependency)) -> None:
    print(dep.upper())
```
In this example, the code before the `yield` runs before the task executes, and the code after it runs once the task completes.
If you want to do something asynchronously, convert this function into an asynchronous generator, like this:
Annotated syntax:

```python
import asyncio
from typing import Annotated, AsyncGenerator

from taskiq import TaskiqDepends


async def dependency() -> AsyncGenerator[str, None]:
    print("Startup")
    await asyncio.sleep(0.1)

    yield "value"

    await asyncio.sleep(0.1)
    print("Shutdown")


async def my_task(dep: Annotated[str, TaskiqDepends(dependency)]) -> None:
    print(dep.upper())
```

Default-value syntax:

```python
import asyncio
from typing import AsyncGenerator

from taskiq import TaskiqDepends


async def dependency() -> AsyncGenerator[str, None]:
    print("Startup")
    await asyncio.sleep(0.1)

    yield "value"

    await asyncio.sleep(0.1)
    print("Shutdown")


async def my_task(dep: str = TaskiqDepends(dependency)) -> None:
    print(dep.upper())
```
Exception handling
Generator dependencies can handle exceptions that happen in tasks. This feature is handy if you want your system to be more atomic: for example, when you open a database transaction in your dependency and want to commit it only if the task completes successfully.
Annotated syntax:

```python
from typing import Annotated, AsyncGenerator

from taskiq import TaskiqDepends

# DBDriver, Transaction, and get_driver are placeholders
# for your actual database library.


async def get_transaction(
    db_driver: Annotated[DBDriver, TaskiqDepends(get_driver)],
) -> AsyncGenerator[Transaction, None]:
    trans = db_driver.begin_transaction()
    try:
        # Here we give the transaction to our dependant function.
        yield trans
    # If an exception was raised in the dependant function,
    # we roll back our transaction.
    except Exception:
        await trans.rollback()
        return
    # Here we commit if everything is fine.
    await trans.commit()
```

Default-value syntax:

```python
from typing import AsyncGenerator

from taskiq import TaskiqDepends

# DBDriver, Transaction, and get_driver are placeholders
# for your actual database library.


async def get_transaction(
    db_driver: DBDriver = TaskiqDepends(get_driver),
) -> AsyncGenerator[Transaction, None]:
    trans = db_driver.begin_transaction()
    try:
        # Here we give the transaction to our dependant function.
        yield trans
    # If an exception was raised in the dependant function,
    # we roll back our transaction.
    except Exception:
        await trans.rollback()
        return
    # Here we commit if everything is fine.
    await trans.commit()
```
If you don't want exceptions to be propagated into dependencies, you can add the `--no-propagate-errors` option to the `worker` command:

```bash
taskiq worker my_file:broker --no-propagate-errors
```

In this case, no exception will ever be propagated to any dependency.
Generics
Taskiq supports generic dependencies. You can create a class that is generic over another class, and taskiq will resolve the generic parameters based on type annotations.
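Here is a minimal sketch of how this can look. The `Connection` and `Repository` classes are hypothetical, and the pattern assumes that a bare `TaskiqDepends()` on a `TypeVar` parameter is resolved from the concrete annotation on the task:

```python
from typing import Generic, TypeVar

from taskiq import TaskiqDepends

T = TypeVar("T")


class Connection:
    """A hypothetical resource to inject."""


class Repository(Generic[T]):
    # T is resolved from the concrete annotation on the task parameter,
    # so `resource` receives an instance of Connection below.
    def __init__(self, resource: T = TaskiqDepends()) -> None:
        self.resource = resource


async def my_task(repo: Repository[Connection] = TaskiqDepends()) -> None:
    print(repo.resource)
```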
Default dependencies
By default taskiq has only two dependencies:
- Context from taskiq.context.Context
- TaskiqState from taskiq.state.TaskiqState
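For example, the raw TaskiqState can be injected directly, without going through the Context; this sketch reuses the `redis` field stored by the startup handler earlier in this page:

```python
from taskiq import TaskiqDepends, TaskiqState


async def my_task(state: TaskiqState = TaskiqDepends()) -> None:
    # TaskiqState is a default first-level dependency,
    # so it can be injected by its type hint alone.
    print(state.redis)
```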
Adding first-level dependencies
You can expand the default list of available dependencies for your application. Taskiq can register new first-level dependencies through brokers: the AsyncBroker interface has a function called add_dependency_context that adds more default dependencies to taskiq. This may be useful for libraries that want to provide new dependencies to their users.
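A minimal sketch of what this might look like, assuming add_dependency_context accepts a mapping from types to instances; the `Config` class here is hypothetical:

```python
from taskiq import InMemoryBroker, TaskiqDepends


class Config:
    """A hypothetical configuration object provided by a library."""

    redis_url: str = "redis://localhost/1"


broker = InMemoryBroker()
# Make Config available as a first-level dependency in every task.
broker.add_dependency_context({Config: Config()})


@broker.task
async def my_task(config: Config = TaskiqDepends()) -> None:
    print(config.redis_url)
```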