fix: flatten core model params #168

Open · wants to merge 5 commits into master
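This PR flattens the core model parameters (temperature, max tokens, stream) into top-level `$ai_*` event properties via a new `extract_core_model_params` helper, alongside the existing nested `$ai_model_parameters`. The helper is wired into the Anthropic and OpenAI wrappers (including their streaming capture paths) and into the LangChain callback handler's success and error events.

For illustration, captured event properties after this change look roughly like this (values taken from the tests below):

```python
{
    "$ai_model_parameters": {"temperature": 0.5, "max_tokens": 100, "stream": False},  # existing nested form
    "$ai_temperature": 0.5,  # new flattened properties
    "$ai_max_tokens": 100,
    "$ai_stream": False,
}
```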
9 changes: 8 additions & 1 deletion posthog/ai/anthropic/anthropic.py
@@ -8,7 +8,13 @@
import uuid
from typing import Any, Dict, Optional

from posthog.ai.utils import call_llm_and_track_usage, get_model_params, merge_system_prompt, with_privacy_mode
from posthog.ai.utils import (
call_llm_and_track_usage,
extract_core_model_params,
get_model_params,
merge_system_prompt,
with_privacy_mode,
)
from posthog.client import Client as PostHogClient


@@ -187,6 +193,7 @@ def _capture_streaming_event(
"$ai_latency": latency,
"$ai_trace_id": posthog_trace_id,
"$ai_base_url": str(self._client.base_url),
**extract_core_model_params(kwargs, "anthropic"),
**(posthog_properties or {}),
}

4 changes: 3 additions & 1 deletion posthog/ai/langchain/callbacks.py
@@ -23,7 +23,7 @@
from langchain_core.outputs import ChatGeneration, LLMResult
from pydantic import BaseModel

from posthog.ai.utils import get_model_params, with_privacy_mode
from posthog.ai.utils import extract_core_model_params, get_model_params, with_privacy_mode
from posthog.client import Client

log = logging.getLogger("posthog")
@@ -178,6 +178,7 @@ def on_llm_end(
"$ai_latency": latency,
"$ai_trace_id": trace_id,
"$ai_base_url": run.get("base_url"),
**extract_core_model_params(run.get("model_params"), run.get("provider")),
**self._properties,
}
if self._distinct_id is None:
@@ -224,6 +225,7 @@ def on_llm_error(
"$ai_latency": latency,
"$ai_trace_id": trace_id,
"$ai_base_url": run.get("base_url"),
**extract_core_model_params(run.get("model_params"), run.get("provider")),
**self._properties,
}
if self._distinct_id is None:
3 changes: 2 additions & 1 deletion posthog/ai/openai/openai.py
@@ -8,7 +8,7 @@
except ImportError:
raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'")

from posthog.ai.utils import call_llm_and_track_usage, get_model_params, with_privacy_mode
from posthog.ai.utils import call_llm_and_track_usage, extract_core_model_params, get_model_params, with_privacy_mode
from posthog.client import Client as PostHogClient


@@ -167,6 +167,7 @@ def _capture_streaming_event(
"$ai_latency": latency,
"$ai_trace_id": posthog_trace_id,
"$ai_base_url": str(self._client.base_url),
**extract_core_model_params(kwargs, "openai"),
**posthog_properties,
}

31 changes: 31 additions & 0 deletions posthog/ai/utils.py
@@ -29,6 +29,35 @@ def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]:
return model_params


def extract_core_model_params(kwargs: Dict[str, Any], provider: str) -> Dict[str, Any]:
"""
Extracts core model parameters from the kwargs dictionary.
"""
output = {}
if not kwargs:  # model_params may be None, e.g. when a run captured no params
    return output
if provider == "anthropic":
if "temperature" in kwargs:
output["$ai_temperature"] = kwargs.get("temperature")
if "max_tokens" in kwargs:
output["$ai_max_tokens"] = kwargs.get("max_tokens")
if "stream" in kwargs:
output["$ai_stream"] = kwargs.get("stream")
elif provider == "openai":
if "temperature" in kwargs:
output["$ai_temperature"] = kwargs.get("temperature")
if "max_completion_tokens" in kwargs:
output["$ai_max_tokens"] = kwargs.get("max_completion_tokens")
if "stream" in kwargs:
output["$ai_stream"] = kwargs.get("stream")
else:  # unknown provider: fall back to the generic parameter names
if "temperature" in kwargs:
output["$ai_temperature"] = kwargs.get("temperature")
if "max_tokens" in kwargs:
output["$ai_max_tokens"] = kwargs.get("max_completion_tokens")
if "stream" in kwargs:
output["$ai_stream"] = kwargs.get("stream")
return output


def get_usage(response, provider: str) -> Dict[str, Any]:
if provider == "anthropic":
return {
@@ -148,6 +177,7 @@ def call_llm_and_track_usage(
"$ai_latency": latency,
"$ai_trace_id": posthog_trace_id,
"$ai_base_url": str(base_url),
**extract_core_model_params(kwargs, provider),
**(posthog_properties or {}),
}

@@ -218,6 +248,7 @@ async def call_llm_and_track_usage_async(
"$ai_latency": latency,
"$ai_trace_id": posthog_trace_id,
"$ai_base_url": str(base_url),
**extract_core_model_params(kwargs, provider),
**(posthog_properties or {}),
}

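For reference, a quick sketch of the mapping the new helper performs (illustrative calls, not part of the diff; Anthropic's `max_tokens` and OpenAI's `max_completion_tokens` both land in the same `$ai_max_tokens` property):

```python
from posthog.ai.utils import extract_core_model_params

# Anthropic-style kwargs: max_tokens maps onto $ai_max_tokens
extract_core_model_params({"temperature": 0.5, "max_tokens": 100, "stream": False}, "anthropic")
# -> {"$ai_temperature": 0.5, "$ai_max_tokens": 100, "$ai_stream": False}

# OpenAI-style kwargs: max_completion_tokens maps onto $ai_max_tokens
extract_core_model_params({"temperature": 0.5, "max_completion_tokens": 100}, "openai")
# -> {"$ai_temperature": 0.5, "$ai_max_tokens": 100}

# Any other provider string falls back to the generic max_tokens name
extract_core_model_params({"temperature": 0.7, "max_tokens": 50}, "mistral")
# -> {"$ai_temperature": 0.7, "$ai_max_tokens": 50}
```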
23 changes: 22 additions & 1 deletion posthog/test/ai/anthropic/test_anthropic.py
@@ -296,7 +296,6 @@ def test_streaming_system_prompt(mock_client, mock_anthropic_stream):

call_args = mock_client.capture.call_args[1]
props = call_args["properties"]

assert props["$ai_input"] == [{"role": "system", "content": "Foo"}, {"role": "user", "content": "Bar"}]


@@ -325,3 +324,25 @@ async def test_async_streaming_system_prompt(mock_client, mock_anthropic_stream)
{"role": "system", "content": "You must always answer with 'Bar'."},
{"role": "user", "content": "Foo"},
]


def test_core_model_params(mock_client, mock_anthropic_response):
with patch("anthropic.resources.Messages.create", return_value=mock_anthropic_response):
client = Anthropic(api_key="test-key", posthog_client=mock_client)
response = client.messages.create(
model="claude-3-opus-20240229",
temperature=0.5,
max_tokens=100,
stream=False,
messages=[{"role": "user", "content": "Hello"}],
posthog_distinct_id="test-id",
posthog_properties={"foo": "bar"},
)

assert response == mock_anthropic_response
props = mock_client.capture.call_args[1]["properties"]
assert props["$ai_model_parameters"] == {"temperature": 0.5, "max_tokens": 100, "stream": False}
assert props["$ai_temperature"] == 0.5
assert props["$ai_max_tokens"] == 100
assert props["$ai_stream"] == False
assert props["foo"] == "bar"
22 changes: 22 additions & 0 deletions posthog/test/ai/langchain/test_callbacks.py
@@ -771,3 +771,25 @@ def test_tool_calls(mock_client):
}
]
assert "additional_kwargs" not in call["properties"]["$ai_output_choices"][0]


@pytest.mark.skipif(not OPENAI_API_KEY, reason="OPENAI_API_KEY is not set")
def test_core_model_params(mock_client):
prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain = prompt | ChatOpenAI(
api_key=OPENAI_API_KEY,
model="gpt-4",
temperature=0.5,
max_tokens=100,
stream=False,
)
callbacks = CallbackHandler(mock_client, properties={"foo": "bar"})
chain.invoke({}, config={"callbacks": [callbacks]})

assert mock_client.capture.call_count == 1
call = mock_client.capture.call_args[1]
assert call["properties"]["$ai_model_parameters"] == {"temperature": 0.5, "max_tokens": 100, "stream": False}
assert call["properties"]["$ai_temperature"] == 0.5
assert call["properties"]["$ai_max_tokens"] == 100
assert call["properties"]["$ai_stream"] == False
assert call["properties"]["foo"] == "bar"
25 changes: 25 additions & 0 deletions posthog/test/ai/openai/test_openai.py
@@ -173,3 +173,28 @@ def test_privacy_mode_global(mock_client, mock_openai_response):
props = call_args["properties"]
assert props["$ai_input"] is None
assert props["$ai_output_choices"] is None


def test_core_model_params(mock_client, mock_openai_response):
with patch("openai.resources.chat.completions.Completions.create", return_value=mock_openai_response):
client = OpenAI(api_key="test-key", posthog_client=mock_client)
response = client.chat.completions.create(
model="gpt-4",
temperature=0.5,
max_completion_tokens=100,
stream=False,
messages=[{"role": "user", "content": "Hello"}],
posthog_distinct_id="test-id",
posthog_properties={"foo": "bar"},
)

assert response == mock_openai_response
assert mock_client.capture.call_count == 1

call_args = mock_client.capture.call_args[1]
props = call_args["properties"]
assert props["$ai_model_parameters"] == {"temperature": 0.5, "max_completion_tokens": 100, "stream": False}
assert props["$ai_temperature"] == 0.5
assert props["$ai_max_tokens"] == 100
assert props["$ai_stream"] == False
assert props["foo"] == "bar"