Custom chat models
This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a wrapper for a model that isn't directly supported in LangChain.
There are a few required methods that a chat model needs to implement after extending the SimpleChatModel class:

- A `_call` method that takes in a list of messages and call options (which includes things like `stop` sequences), and returns a string.
- A `_llmType` method that returns a string. Used for logging purposes only.

You can also implement the following optional method:

- A `_streamResponseChunks` method that returns an `AsyncGenerator` and yields `ChatGenerationChunk`s. This allows the LLM to support streaming outputs.
Let's implement a very simple custom chat model that just echoes back the first `n` characters of the input.
import {
SimpleChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
export interface CustomChatModelInput extends BaseChatModelParams {
n: number;
}
export class CustomChatModel extends SimpleChatModel {
n: number;
constructor(fields: CustomChatModelInput) {
super(fields);
this.n = fields.n;
}
_llmType() {
return "custom";
}
async _call(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<string> {
if (!messages.length) {
throw new Error("No messages provided.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
return messages[0].content.slice(0, this.n);
}
async *_streamResponseChunks(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (!messages.length) {
throw new Error("No messages provided.");
}
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
for (const letter of messages[0].content.slice(0, this.n)) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: letter,
}),
text: letter,
});
// Trigger the appropriate callback for new chunks
await runManager?.handleLLMNewToken(letter);
}
}
}
We can now use this as any other chat model:
const chatModel = new CustomChatModel({ n: 4 });
await chatModel.invoke([["human", "I am an LLM"]]);
AIMessage {
content: 'I am',
additional_kwargs: {}
}
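Because it implements the standard chat model interface, the custom model also composes with other runnables. As a quick sketch (the prompt text and input variable are just illustrative), you can pipe a prompt template into it:

import { ChatPromptTemplate } from "@langchain/core/prompts";

const prompt = ChatPromptTemplate.fromMessages([["human", "{text}"]]);

// Piping the prompt into the custom model yields a runnable chain.
const chain = prompt.pipe(new CustomChatModel({ n: 4 }));

await chain.invoke({ text: "I am an LLM" });
// AIMessage { content: 'I am', ... }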
And support streaming:
const stream = await chatModel.stream([["human", "I am an LLM"]]);
for await (const chunk of stream) {
console.log(chunk);
}
AIMessageChunk {
content: 'I',
additional_kwargs: {}
}
AIMessageChunk {
content: ' ',
additional_kwargs: {}
}
AIMessageChunk {
content: 'a',
additional_kwargs: {}
}
AIMessageChunk {
content: 'm',
additional_kwargs: {}
}
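Since the yielded chunks are `AIMessageChunk`s, they can also be merged back into a single message with `concat()`. A minimal sketch:

import { AIMessageChunk } from "@langchain/core/messages";

let aggregate: AIMessageChunk | undefined;
for await (const chunk of await chatModel.stream([["human", "I am an LLM"]])) {
  // Merge each streamed chunk into the running aggregate.
  aggregate = aggregate === undefined ? chunk : aggregate.concat(chunk);
}

console.log(aggregate?.content);
// "I am"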
Richer outputs
If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the `BaseChatModel` class and implement the lower-level `_generate` method. It also takes a list of `BaseMessage`s as input, but requires you to construct and return a `ChatResult` containing one or more `ChatGeneration` objects, which permits returning additional metadata.
Here's an example:
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
BaseChatModel,
BaseChatModelCallOptions,
BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
export interface AdvancedCustomChatModelOptions
extends BaseChatModelCallOptions {}
export interface AdvancedCustomChatModelParams extends BaseChatModelParams {
n: number;
}
export class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
n: number;
static lc_name(): string {
return "AdvancedCustomChatModel";
}
constructor(fields: AdvancedCustomChatModelParams) {
super(fields);
this.n = fields.n;
}
async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (!messages.length) {
throw new Error("No messages provided.");
}
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
const content = messages[0].content.slice(0, this.n);
const tokenUsage = {
usedTokens: this.n,
};
return {
generations: [{ message: new AIMessage({ content }), text: content }],
llmOutput: { tokenUsage },
};
}
_llmType(): string {
return "advanced_custom_chat_model";
}
}
This will pass the additional returned information along in callback events and in the `streamEvents` method:
const chatModel = new AdvancedCustomChatModel({ n: 4 });
const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
version: "v1",
});
for await (const event of eventStream) {
if (event.event === "on_llm_end") {
console.log(JSON.stringify(event, null, 2));
}
}
{
"event": "on_llm_end",
"name": "AdvancedCustomChatModel",
"run_id": "b500b98d-bee5-4805-9b92-532a491f5c70",
"tags": [],
"metadata": {},
"data": {
"output": {
"generations": [
[
{
"message": {
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"messages",
"AIMessage"
],
"kwargs": {
"content": "I am",
"additional_kwargs": {}
}
},
"text": "I am"
}
]
],
"llmOutput": {
"tokenUsage": {
"usedTokens": 4
}
}
}
}
}
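The same `llmOutput` is also surfaced through callbacks. As a sketch, a `handleLLMEnd` handler passed at invocation time can read the token usage returned by `_generate`:

await chatModel.invoke([["human", "I am an LLM"]], {
  callbacks: [
    {
      handleLLMEnd(output) {
        // `output.llmOutput` carries whatever `_generate` returned alongside the generations.
        console.log(output.llmOutput?.tokenUsage);
      },
    },
  ],
});
// { usedTokens: 4 }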
Tracing (advanced)
If you are implementing a custom chat model and want to use it with a tracing service like LangSmith,
you can automatically log params used for a given invocation by implementing the `invocationParams()` method on the model.
This method is purely optional, but anything it returns will be logged as metadata for the trace.
Here's one pattern you might use:
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

export interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

export interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
}

export class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      temperature: this.temperature,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    // `someAPIRequest` is a placeholder for a call to your model's actual API.
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "custom_chat_model";
  }
}
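Since `tools` is declared as a call option, it can be supplied per invocation or pre-bound with `bind()`; either way `invocationParams()` picks it up, so it appears as trace metadata. A sketch, assuming `someAPIRequest` above has been wired to a real endpoint and using an illustrative tool shape:

const tracedModel = new CustomChatModel({ temperature: 0.7 }).bind({
  tools: [{ name: "calculator", description: "Evaluates math expressions" }],
});

// The bound `tools` arrive as call options, so `invocationParams()` returns
// them and they are logged as metadata on the trace.
await tracedModel.invoke([["human", "What is 2 + 2?"]]);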