1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
|
class RealtimeClient(RealtimeEventHandler):
    """Client that couples a ``RealtimeAPI`` transport with a ``RealtimeConversation``.

    The client forwards server events to the conversation object, re-dispatches
    the results under higher-level event names, keeps the session configuration
    and a registry of callable tools, and buffers locally appended input audio.

    Events dispatched by this class (through ``self.dispatch``, which is
    presumably provided by ``RealtimeEventHandler`` -- confirm):

    * ``realtime.event`` -- every ``client.*`` / ``server.*`` event, wrapped
      with a timestamp and its source.
    * ``conversation.updated`` -- whenever the conversation yields an item.
    * ``conversation.interrupted`` -- when the server reports speech started.
    * ``conversation.item.appended`` -- when the server reports an item created.
    * ``conversation.item.completed`` -- when an item has status ``"completed"``.
    """

    def __init__(self, system_prompt: str):
        """Build the default configuration and wire up the API event handlers.

        Args:
            system_prompt: Text sent as the session ``instructions``.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.default_session_config = {
            "modalities": ["text", "audio"],
            "instructions": self.system_prompt,
            "voice": "shimmer",
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "input_audio_transcription": {"model": "whisper-1"},
            "turn_detection": {"type": "server_vad"},
            "tools": [],
            "tool_choice": "auto",
            "temperature": 0.8,
            "max_response_output_tokens": 4096,
        }
        self.session_config = {}
        self.transcription_models = [{"model": "whisper-1"}]
        self.default_server_vad_config = {
            "type": "server_vad",
            "threshold": 0.5,
            "prefix_padding_ms": 300,
            "silence_duration_ms": 200,
        }
        self.realtime = RealtimeAPI()
        self.conversation = RealtimeConversation()
        self._reset_config()
        self._add_api_event_handlers()

    def _reset_config(self):
        """Restore the per-session state to its defaults.

        Clears the tool registry and the local input-audio buffer, marks the
        session as not created, and replaces ``session_config`` with a fresh
        copy of ``default_session_config``.

        Returns:
            ``True`` always.
        """
        # Function-scope import keeps this block self-contained.
        from copy import deepcopy

        self.session_created = False
        self.tools = {}
        # Deep copy: the defaults contain nested dicts and a list, and a shallow
        # copy would let edits to the live config leak back into the defaults.
        self.session_config = deepcopy(self.default_session_config)
        self.input_audio_buffer = bytearray()
        return True

    def _add_api_event_handlers(self):
        """Register this client's handlers on the ``RealtimeAPI`` instance."""
        # Registration order is kept exactly as listed, in case the dispatcher
        # invokes handlers in the order they were added.
        handler_table = (
            ("client.*", self._log_event),
            ("server.*", self._log_event),
            ("server.session.created", self._on_session_created),
            ("server.response.created", self._process_event),
            ("server.response.output_item.added", self._process_event),
            ("server.response.content_part.added", self._process_event),
            ("server.input_audio_buffer.speech_started", self._on_speech_started),
            ("server.input_audio_buffer.speech_stopped", self._on_speech_stopped),
            ("server.conversation.item.created", self._on_item_created),
            ("server.conversation.item.truncated", self._process_event),
            ("server.conversation.item.deleted", self._process_event),
            ("server.conversation.item.input_audio_transcription.completed", self._process_event),
            ("server.response.audio_transcript.delta", self._process_event),
            ("server.response.audio.delta", self._process_event),
            ("server.response.text.delta", self._process_event),
            ("server.response.function_call_arguments.delta", self._process_event),
            ("server.response.output_item.done", self._on_output_item_done),
        )
        for event_name, handler in handler_table:
            self.realtime.on(event_name, handler)

    def _log_event(self, event):
        """Re-dispatch a raw API event as ``realtime.event`` with time and source.

        Args:
            event: Event mapping; its ``"type"`` string decides the source
                (``"client"`` when it starts with ``"client."``, else ``"server"``).
        """
        source = "client" if event["type"].startswith("client.") else "server"
        realtime_event = {
            # NOTE(review): utcnow() is presumably datetime.datetime.utcnow, which
            # gives a naive timestamp (no UTC offset in the ISO string) -- confirm.
            "time": datetime.utcnow().isoformat(),
            "source": source,
            "event": event,
        }
        self.dispatch("realtime.event", realtime_event)

    def _on_session_created(self, event):
        """Record that the server has confirmed the session."""
        self.session_created = True

    def _process_event(self, event, *args):
        """Feed an event to the conversation and announce any resulting item.

        Args:
            event: Server event to process.
            *args: Extra positional arguments passed through to
                ``conversation.process_event``.

        Returns:
            The ``(item, delta)`` pair returned by the conversation.
        """
        item, delta = self.conversation.process_event(event, *args)
        if item:
            self.dispatch("conversation.updated", {"item": item, "delta": delta})
        return item, delta

    def _on_speech_started(self, event):
        """Process the event, then signal that the conversation was interrupted."""
        self._process_event(event)
        self.dispatch("conversation.interrupted", event)

    def _on_speech_stopped(self, event):
        """Process the event together with the locally buffered input audio."""
        self._process_event(event, self.input_audio_buffer)

    def _on_item_created(self, event):
        """Announce a newly created item and, if already complete, its completion."""
        item, _delta = self._process_event(event)
        # Dispatched unconditionally, so listeners may receive {"item": None}
        # when the conversation produced no item for this event.
        self.dispatch("conversation.item.appended", {"item": item})
        if item and item["status"] == "completed":
            self.dispatch("conversation.item.completed", {"item": item})

    async def _on_output_item_done(self, event):
        """Announce a completed item and run its tool call, if it carries one."""
        item, _delta = self._process_event(event)
        if not item:
            return
        if item["status"] == "completed":
            self.dispatch("conversation.item.completed", {"item": item})
        tool = item.get("formatted", {}).get("tool")
        if tool:
            await self._call_tool(tool)

    async def _send_tool_output(self, call_id, payload):
        """Send ``payload`` (JSON-encoded) as the output of function call ``call_id``."""
        await self.realtime.send("conversation.item.create", {
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": json.dumps(payload),
            }
        })

    async def _call_tool(self, tool):
        """Invoke a registered tool and send its result (or error) to the server.

        Any exception raised while decoding the arguments, looking up the tool,
        running the handler, or sending the result is logged and reported to
        the server as ``{"error": "<message>"}``. A new response is requested
        afterwards in either case.

        Args:
            tool: Mapping with ``"name"``, ``"arguments"`` (a JSON string) and
                ``"call_id"``.
        """
        # Function-scope import keeps this block self-contained.
        import inspect

        try:
            logger.debug("Tool call arguments: %s", tool["arguments"])
            json_arguments = json.loads(tool["arguments"])
            tool_config = self.tools.get(tool["name"])
            if not tool_config:
                raise Exception(f'Tool "{tool["name"]}" has not been added')
            result = tool_config["handler"](**json_arguments)
            # add_tool() only requires a callable, so accept plain functions as
            # well as coroutine functions.
            if inspect.isawaitable(result):
                result = await result
            await self._send_tool_output(tool["call_id"], result)
        except Exception as e:
            logger.error(traceback.format_exc())
            await self._send_tool_output(tool["call_id"], {"error": str(e)})
        await self.create_response()

    def is_connected(self):
        """Return whether the underlying ``RealtimeAPI`` reports a connection."""
        return self.realtime.is_connected()

    def _disconnect_from_sync_context(self):
        """Run ``disconnect()`` on behalf of synchronous code (see ``reset``)."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop is not None:
            # We cannot block inside a running loop: schedule the disconnect and
            # keep a reference so the task is not garbage-collected early.
            self._pending_disconnect = loop.create_task(self.disconnect())
            return
        try:
            # NOTE(review): assumes the transport can be closed from a fresh
            # event loop -- confirm against RealtimeAPI.disconnect.
            asyncio.run(self.disconnect())
        except Exception:
            # Resetting local state should still succeed if the close fails.
            logger.error(traceback.format_exc())

    def reset(self):
        """Disconnect, drop all handlers, tools and config, then re-register.

        ``disconnect`` is a coroutine function, so it cannot simply be called
        here. The conversation is cleared immediately; if a connection is open,
        the disconnect is scheduled on the running event loop, or run to
        completion when no loop is running. When scheduled, the connection is
        closed only once the loop next gets control.

        Returns:
            ``True`` always.
        """
        self.conversation.clear()
        if self.realtime.is_connected():
            self._disconnect_from_sync_context()
        self.realtime.clear_event_handlers()
        self._reset_config()
        self._add_api_event_handlers()
        return True

    async def connect(self):
        """Open the connection and push the current session configuration.

        Returns:
            ``True`` on success.

        Raises:
            Exception: If the client is already connected.
        """
        if self.is_connected():
            raise Exception("Already connected, use .disconnect() first")
        await self.realtime.connect()
        await self.update_session()
        return True

    async def wait_for_session_created(self):
        """Poll until the server has confirmed the session.

        There is no timeout: this keeps polling until ``session_created`` is set.

        Returns:
            ``True`` once the session exists.

        Raises:
            Exception: If the client is not connected when called.
        """
        if not self.is_connected():
            raise Exception("Not connected, use .connect() first")
        while not self.session_created:
            await asyncio.sleep(0.001)
        return True

    async def disconnect(self):
        """Clear the conversation and close the connection if one is open."""
        self.session_created = False
        self.conversation.clear()
        if self.realtime.is_connected():
            await self.realtime.disconnect()

    def get_turn_detection_type(self):
        """Return ``session_config["turn_detection"]["type"]``, or ``None`` if unset."""
        return self.session_config.get("turn_detection", {}).get("type")

    async def add_tool(self, definition, handler):
        """Register a tool and push the updated session to the server.

        Args:
            definition: Tool definition mapping; must contain a non-empty ``"name"``.
            handler: Callable invoked with the decoded arguments as keyword
                arguments when the tool is called.

        Returns:
            The stored ``{"definition": ..., "handler": ...}`` entry.

        Raises:
            Exception: If the name is missing, already registered, or the
                handler is not callable.
        """
        if not definition.get("name"):
            raise Exception("Missing tool name in definition")
        name = definition["name"]
        if name in self.tools:
            raise Exception(f'Tool "{name}" already added. Please use .remove_tool("{name}") before trying to add again.')
        if not callable(handler):
            raise Exception(f'Tool "{name}" handler must be a function')
        entry = {"definition": definition, "handler": handler}
        self.tools[name] = entry
        await self.update_session()
        return entry

    def remove_tool(self, name):
        """Remove a registered tool from the local registry.

        This does not send a session update; the server learns of the removal
        the next time ``update_session`` runs.

        Returns:
            ``True`` on success.

        Raises:
            Exception: If no tool with that name is registered.
        """
        if name not in self.tools:
            raise Exception(f'Tool "{name}" does not exist, can not be removed.')
        del self.tools[name]
        return True

    async def delete_item(self, id):
        """Ask the server to delete the conversation item ``id``. Returns ``True``."""
        await self.realtime.send("conversation.item.delete", {"item_id": id})
        return True

    async def update_session(self, **kwargs):
        """Merge ``kwargs`` into the session config and send it if connected.

        The ``tools`` entry that is sent combines the tools listed in the
        session config with those registered through ``add_tool``, each tagged
        with ``"type": "function"``.

        Returns:
            ``True`` always.
        """
        self.session_config.update(kwargs)
        configured = self.session_config.get("tools", [])
        registered = (entry["definition"] for entry in self.tools.values())
        use_tools = [
            {**definition, "type": "function"}
            for definition in (*configured, *registered)
        ]
        session = {**self.session_config, "tools": use_tools}
        if self.realtime.is_connected():
            await self.realtime.send("session.update", {"session": session})
        return True

    async def create_conversation_item(self, item):
        """Send ``item`` to the server as a new conversation item."""
        await self.realtime.send("conversation.item.create", {"item": item})

    async def send_user_message_content(self, content=None):
        """Send a user message and request a response.

        Args:
            content: Iterable of content-part mappings. Parts of type
                ``"input_audio"`` whose ``"audio"`` is ``bytes``/``bytearray``
                are base64-encoded before sending. The caller's mappings are
                not modified. ``None`` means an empty message.

        Returns:
            ``True`` always.
        """
        parts = []
        for part in content or []:
            if part["type"] == "input_audio" and isinstance(part["audio"], (bytes, bytearray)):
                # Encode on a copy so the caller keeps its raw audio bytes.
                part = {**part, "audio": array_buffer_to_base64(part["audio"])}
            parts.append(part)
        await self.realtime.send("conversation.item.create", {
            "item": {
                "type": "message",
                "role": "user",
                "content": parts,
            }
        })
        await self.create_response()
        return True

    async def append_input_audio(self, array_buffer):
        """Send audio to the server's input buffer and mirror it locally.

        Empty input is ignored.

        Returns:
            ``True`` always.
        """
        if len(array_buffer) > 0:
            await self.realtime.send("input_audio_buffer.append", {
                "audio": array_buffer_to_base64(np.array(array_buffer)),
            })
            self.input_audio_buffer.extend(array_buffer)
        return True

    async def create_response(self):
        """Request a response, committing buffered audio first when needed.

        The local buffer is committed, queued on the conversation and reset
        only when no turn-detection type is configured and the buffer is
        non-empty.

        Returns:
            ``True`` always.
        """
        manual_turns = self.get_turn_detection_type() is None
        if manual_turns and len(self.input_audio_buffer) > 0:
            await self.realtime.send("input_audio_buffer.commit")
            self.conversation.queue_input_audio(self.input_audio_buffer)
            self.input_audio_buffer = bytearray()
        await self.realtime.send("response.create")
        return True

    async def cancel_response(self, id=None, sample_count=0):
        """Cancel the in-progress response, optionally truncating an item's audio.

        Args:
            id: Item to truncate. When falsy, only the cancel is sent.
            sample_count: Number of audio samples to keep; converted to
                milliseconds using ``conversation.default_frequency``.

        Returns:
            ``{"item": None}`` when no id was given, else ``{"item": item}``.

        Raises:
            Exception: If the item is unknown, is not an assistant message, or
                has no audio content. In the last case the cancel has already
                been sent.
        """
        if not id:
            await self.realtime.send("response.cancel")
            return {"item": None}

        item = self.conversation.get_item(id)
        if not item:
            raise Exception(f'Could not find item "{id}"')
        if item["type"] != "message":
            raise Exception('Can only cancelResponse messages with type "message"')
        if item["role"] != "assistant":
            raise Exception('Can only cancelResponse messages with role "assistant"')

        await self.realtime.send("response.cancel")
        audio_index = next(
            (index for index, part in enumerate(item["content"]) if part["type"] == "audio"),
            -1,
        )
        if audio_index == -1:
            raise Exception("Could not find audio on item to cancel")
        audio_end_ms = int((sample_count / self.conversation.default_frequency) * 1000)
        await self.realtime.send("conversation.item.truncate", {
            "item_id": id,
            "content_index": audio_index,
            "audio_end_ms": audio_end_ms,
        })
        return {"item": item}

    async def wait_for_next_item(self):
        """Wait for the next ``conversation.item.appended`` event; return its item."""
        event = await self.wait_for_next("conversation.item.appended")
        return {"item": event["item"]}

    async def wait_for_next_completed_item(self):
        """Wait for the next ``conversation.item.completed`` event; return its item."""
        event = await self.wait_for_next("conversation.item.completed")
        return {"item": event["item"]}
|