diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d9c946f41..06b4379ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,7 @@ # Copyright 2024 Google LLC # SPDX-License-Identifier: Apache-2.0 -name: "JS: Run Tests" +name: "JS: Run Tests and Build" on: push: @@ -38,3 +38,6 @@ jobs: - name: Run tests run: cd js && pnpm test + + - name: Build + run: cd js && pnpm build diff --git a/captainhook.json b/captainhook.json index ba275bffe..d6be284cf 100644 --- a/captainhook.json +++ b/captainhook.json @@ -43,7 +43,16 @@ "run": "uv run --directory python mypy ." }, { - "run": "scripts/run_tests" + "run": "scripts/run_python_tests" + }, + { + "run": "scripts/run_js_tests" + }, + { + "run": "scripts/run_go_tests" + }, + { + "run": "pnpm -C js build" }, { "run": "uv run mkdocs build" @@ -80,7 +89,16 @@ "run": "uv run --directory python mypy ." }, { - "run": "scripts/run_tests" + "run": "scripts/run_python_tests" + }, + { + "run": "scripts/run_js_tests" + }, + { + "run": "scripts/run_go_tests" + }, + { + "run": "pnpm -C js build" }, { "run": "uv run mkdocs build" diff --git a/js/src/parse.test.ts b/js/src/parse.test.ts index e8aa912cd..38e27f90c 100644 --- a/js/src/parse.test.ts +++ b/js/src/parse.test.ts @@ -11,20 +11,27 @@ import { convertNamespacedEntryToNestedObject, extractFrontmatterAndBody, insertHistory, + messageSourcesToMessages, + messagesHaveHistory, parseDocument, + parseMediaPart, + parsePart, + parseSectionPart, + parseTextPart, splitByMediaAndSectionMarkers, splitByRegex, splitByRoleAndHistoryMarkers, - toParts, + toMessages, transformMessagesToHistory, } from './parse'; -import type { Message } from './types'; +import type { MessageSource } from './parse'; +import type { DataArgument, Message } from './types'; describe('ROLE_AND_HISTORY_MARKER_REGEX', () => { describe('valid patterns', () => { const validPatterns = [ '<<>>', - '<<>>', + '<<>>', '<<>>', '<<>>', '<<>>', @@ -42,7 +49,7 @@ describe('ROLE_AND_HISTORY_MARKER_REGEX', () => { describe('invalid patterns', () => { const invalidPatterns = [ '<<>>', // uppercase not allowed - '<<>>', // numbers not allowed + '<<>>', // numbers not allowed '<<>>', // needs at least one letter '<<>>', // missing role value '<<>>', // history should be exact @@ -62,7 +69,7 @@ describe('ROLE_AND_HISTORY_MARKER_REGEX', () => { it('should match multiple occurrences in a string', () => { const text = ` <<>> Hello - <<>> Hi there + <<>> Hi there <<>> <<>> How are you? `; @@ -107,9 +114,9 @@ describe('splitByRoleAndHistoryMarkers', () => { }); it('splits a string with a single marker correctly', () => { - const input = 'Hello <<>> world'; + const input = 'Hello <<>> world'; const output = splitByRoleAndHistoryMarkers(input); - expect(output).toEqual(['Hello ', '<< { @@ -232,8 +239,9 @@ describe('extractFrontmatterAndBody', () => { expect(body).toBe('This is the body.'); }); - it('should not extract frontmatter when there is no frontmatter', () => { - // The frontmatter is not optional. + it('should match as empty frontmatter and body when there is no frontmatter', () => { + // Both the frontmatter and the body match as empty when there is no + // frontmatter. const source = 'No frontmatter here.'; const { frontmatter, body } = extractFrontmatterAndBody(source); expect(frontmatter).toBe(''); @@ -241,7 +249,7 @@ describe('extractFrontmatterAndBody', () => { }); }); -describe('splitIntoParts', () => { +describe('splitByMediaAndSectionMarkers', () => { it('should return entire string in an array if there are no markers', () => { const source = 'This is a test string.'; const parts = splitByMediaAndSectionMarkers(source); @@ -260,29 +268,65 @@ describe('splitIntoParts', () => { '!', ]); }); +}); + +describe('splitByRegex', () => { + it('should split string by regex and filter empty/whitespace pieces', () => { + const source = ' one , , two , three '; + const result = splitByRegex(source, /,/g); + expect(result).toEqual([' one ', ' two ', ' three ']); + }); + + it('should handle string with no matches', () => { + const source = 'no matches here'; + const result = splitByRegex(source, /,/g); + expect(result).toEqual(['no matches here']); + }); - it('should remove parts that are only whitespace', () => { - const source = ' <<>> '; - const result = toParts(source); - expect(result).toEqual([{ media: { url: undefined } }]); + it('should return empty array for empty string', () => { + const result = splitByRegex('', /,/g); + expect(result).toEqual([]); }); }); describe('transformMessagesToHistory', () => { - it('should add history purpose to messages without metadata', () => { - const messages: Message[] = [{ content: 'Hello' }, { content: 'World' }]; + it('should add history metadata to messages', () => { + const messages: Message[] = [ + { role: 'user', content: [{ text: 'Hello' }] }, + { role: 'model', content: [{ text: 'Hi there' }] }, + ]; + const result = transformMessagesToHistory(messages); + + expect(result).toHaveLength(2); expect(result).toEqual([ - { content: 'Hello', metadata: { purpose: 'history' } }, - { content: 'World', metadata: { purpose: 'history' } }, + { + role: 'user', + content: [{ text: 'Hello' }], + metadata: { purpose: 'history' }, + }, + { + role: 'model', + content: [{ text: 'Hi there' }], + metadata: { purpose: 'history' }, + }, ]); }); it('should preserve existing metadata while adding history purpose', () => { - const messages = [{ content: 'Test', metadata: { foo: 'bar' } }]; + const messages: Message[] = [ + { role: 'user', content: [{ text: 'Hello' }], metadata: { foo: 'bar' } }, + ]; + const result = transformMessagesToHistory(messages); + + expect(result).toHaveLength(1); expect(result).toEqual([ - { content: 'Test', metadata: { foo: 'bar', purpose: 'history' } }, + { + role: 'user', + content: [{ text: 'Hello' }], + metadata: { foo: 'bar', purpose: 'history' }, + }, ]); }); @@ -292,118 +336,448 @@ describe('transformMessagesToHistory', () => { }); }); -describe('splitByRegex', () => { - it('should split string by regex and filter empty/whitespace pieces', () => { - const source = ' one , , two , three '; - const result = splitByRegex(source, /,/g); - expect(result).toEqual([' one ', ' two ', ' three ']); +describe('messagesHaveHistory', () => { + it('should return true if messages have history metadata', () => { + const messages: Message[] = [ + { + role: 'user', + content: [{ text: 'Hello' }], + metadata: { purpose: 'history' }, + }, + ]; + + const result = messagesHaveHistory(messages); + + expect(result).toBe(true); }); - it('should handle string with no matches', () => { - const source = 'no matches here'; - const result = splitByRegex(source, /,/g); - expect(result).toEqual(['no matches here']); + it('should return false if messages do not have history metadata', () => { + const messages: Message[] = [ + { role: 'user', content: [{ text: 'Hello' }] }, + ]; + + const result = messagesHaveHistory(messages); + + expect(result).toBe(false); + }); +}); + +describe('messageSourcesToMessages', () => { + it('should handle empty array', () => { + const messageSources: MessageSource[] = []; + const expected: Message[] = []; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); }); - it('should return empty array for empty string', () => { - const result = splitByRegex('', /,/g); - expect(result).toEqual([]); + it('should convert a single message source', () => { + const messageSources: MessageSource[] = [{ role: 'user', source: 'Hello' }]; + const expected: Message[] = [ + { role: 'user', content: [{ text: 'Hello' }] }, + ]; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); + }); + + it('should handle message source with content', () => { + const messageSources: MessageSource[] = [ + { role: 'user', content: [{ text: 'Existing content' }] }, + ]; + const expected: Message[] = [ + { role: 'user', content: [{ text: 'Existing content' }] }, + ]; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); + }); + + it('should handle message source with metadata', () => { + const messageSources: MessageSource[] = [ + { role: 'user', source: 'Hello', metadata: { foo: 'bar' } }, + ]; + const expected: Message[] = [ + { + role: 'user', + content: [{ text: 'Hello' }], + metadata: { foo: 'bar' }, + }, + ]; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); + }); + + it('should filter out message sources with empty source and content', () => { + const messageSources: MessageSource[] = [ + { role: 'user', source: '' }, + { role: 'model', source: ' ' }, + { role: 'user', source: 'Hello' }, + ]; + const expected: Message[] = [ + { role: 'model', content: [] }, + { role: 'user', content: [{ text: 'Hello' }] }, + ]; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); + }); + + it('should handle multiple message sources', () => { + const messageSources: MessageSource[] = [ + { role: 'user', source: 'Hello' }, + { role: 'model', source: 'Hi there!' }, + { role: 'user', source: 'How are you?' }, + ]; + const expected: Message[] = [ + { role: 'user', content: [{ text: 'Hello' }] }, + { role: 'model', content: [{ text: 'Hi there!' }] }, + { role: 'user', content: [{ text: 'How are you?' }] }, + ]; + expect(messageSourcesToMessages(messageSources)).toEqual(expected); + }); +}); + +describe('toMessages', () => { + it('should handle a simple string with no markers', () => { + const renderedString = 'Hello world'; + const result = toMessages(renderedString); + + expect(result).toHaveLength(1); + expect(result[0].role).toBe('user'); + expect(result[0].content).toEqual([{ text: 'Hello world' }]); + }); + + it('should handle a string with a single role marker', () => { + const renderedString = '<<>>Hello world'; + const result = toMessages(renderedString); + + expect(result).toHaveLength(1); + expect(result[0].role).toBe('model'); + expect(result[0].content).toEqual([{ text: 'Hello world' }]); + }); + + it('should handle a string with multiple role markers', () => { + const renderedString = + '<<>>System instructions\n' + + '<<>>User query\n' + + '<<>>Model response'; + const result = toMessages(renderedString); + + expect(result).toHaveLength(3); + + expect(result[0].role).toBe('system'); + expect(result[0].content).toEqual([{ text: 'System instructions\n' }]); + + expect(result[1].role).toBe('user'); + expect(result[1].content).toEqual([{ text: 'User query\n' }]); + + expect(result[2].role).toBe('model'); + expect(result[2].content).toEqual([{ text: 'Model response' }]); + }); + + it('should update the role of an empty message instead of creating a new one', () => { + const renderedString = + '<<>><<>>Response'; + const result = toMessages(renderedString); + + // Should only have one message since the first role marker doesn't have content + expect(result).toHaveLength(1); + expect(result[0].role).toBe('model'); + expect(result[0].content).toEqual([{ text: 'Response' }]); + }); + + it('should handle history markers and add metadata', () => { + const renderedString = + '<<>>Query<<>>Follow-up'; + const historyMessages: Message[] = [ + { role: 'user', content: [{ text: 'Previous question' }] }, + { role: 'model', content: [{ text: 'Previous answer' }] }, + ]; + + const data: DataArgument = { messages: historyMessages }; + const result = toMessages(renderedString, data); + + expect(result).toHaveLength(4); + + // First message is the user query + expect(result[0].role).toBe('user'); + expect(result[0].content).toEqual([{ text: 'Query' }]); + + // Next two messages should be history messages with appropriate metadata + expect(result[1].role).toBe('user'); + expect(result[1].content).toEqual([{ text: 'Previous question' }]); + expect(result[1].metadata).toEqual({ purpose: 'history' }); + + expect(result[2].role).toBe('model'); + expect(result[2].content).toEqual([{ text: 'Previous answer' }]); + expect(result[2].metadata).toEqual({ purpose: 'history' }); + + // Last message is the follow-up + expect(result[3].role).toBe('model'); + expect(result[3].content).toEqual([{ text: 'Follow-up' }]); + }); + + it('should handle empty history gracefully', () => { + const renderedString = + '<<>>Query<<>>Follow-up'; + const result = toMessages(renderedString, { messages: [] }); + + expect(result).toHaveLength(2); + expect(result[0].role).toBe('user'); + expect(result[0].content).toEqual([{ text: 'Query' }]); + expect(result[1].role).toBe('model'); + expect(result[1].content).toEqual([{ text: 'Follow-up' }]); + }); + + it('should handle undefined data gracefully', () => { + const renderedString = + '<<>>Query<<>>Follow-up'; + const result = toMessages(renderedString, undefined); + + expect(result).toHaveLength(2); + expect(result[0].role).toBe('user'); + expect(result[0].content).toEqual([{ text: 'Query' }]); + expect(result[1].role).toBe('model'); + expect(result[1].content).toEqual([{ text: 'Follow-up' }]); + }); + + it('should filter out empty messages', () => { + const renderedString = + '<<>> ' + + '<<>> ' + + '<<>>Response'; + const result = toMessages(renderedString); + + expect(result).toHaveLength(1); + expect(result[0].role).toBe('model'); + expect(result[0].content).toEqual([{ text: 'Response' }]); + }); + + it('should handle multiple history markers by treating each as a separate insertion point', () => { + const renderedString = + '<<>>First<<>>Second'; + const historyMessages: Message[] = [ + { role: 'user', content: [{ text: 'Previous' }] }, + ]; + + const data: DataArgument = { messages: historyMessages }; + const result = toMessages(renderedString, data); + + expect(result).toHaveLength(4); + + expect(result[0].metadata).toEqual({ purpose: 'history' }); + expect(result[1].content).toEqual([{ text: 'First' }]); + expect(result[2].metadata).toEqual({ purpose: 'history' }); + expect(result[3].content).toEqual([{ text: 'Second' }]); + }); + + it('should support complex interleaving of role and history markers', () => { + const renderedString = + '<<>>Instructions\n' + + '<<>>Initial Query\n' + + '<<>>\n' + + '<<>>Follow-up Question\n' + + '<<>>Final Response'; + + const historyMessages: Message[] = [ + { role: 'user', content: [{ text: 'Previous question' }] }, + { role: 'model', content: [{ text: 'Previous answer' }] }, + ]; + + const data: DataArgument = { messages: historyMessages }; + const result = toMessages(renderedString, data); + + expect(result).toHaveLength(6); + + expect(result[0].role).toBe('system'); + expect(result[0].content).toEqual([{ text: 'Instructions\n' }]); + + expect(result[1].role).toBe('user'); + expect(result[1].content).toEqual([{ text: 'Initial Query\n' }]); + + expect(result[2].role).toBe('user'); + expect(result[2].metadata).toEqual({ purpose: 'history' }); + + expect(result[3].role).toBe('model'); + expect(result[3].metadata).toEqual({ purpose: 'history' }); + + expect(result[4].role).toBe('user'); + expect(result[4].content).toEqual([{ text: 'Follow-up Question\n' }]); + + expect(result[5].role).toBe('model'); + expect(result[5].content).toEqual([{ text: 'Final Response' }]); + }); + + it('should handle an empty input string', () => { + const result = toMessages(''); + expect(result).toHaveLength(0); + }); + + it('should properly call insertHistory with data.messages', () => { + const renderedString = '<<>>Question'; + const historyMessages: Message[] = [ + { role: 'user', content: [{ text: 'Previous' }] }, + ]; + + const data: DataArgument = { messages: historyMessages }; + const result = toMessages(renderedString, data); + + // The resulting messages should have the history message inserted + // before the user message by the insertHistory function + expect(result).toHaveLength(2); + expect(result[0].role).toBe('user'); + expect(result[0].content).toEqual([{ text: 'Previous' }]); + expect(result[0].metadata).toBeUndefined(); // insertHistory shouldn't add history metadata + + expect(result[1].role).toBe('user'); + expect(result[1].content).toEqual([{ text: 'Question' }]); }); }); describe('insertHistory', () => { - it('should insert history messages at the correct position', () => { + it('should return original messages if history is undefined', () => { const messages: Message[] = [ - { role: 'user', content: 'first' }, - { role: 'model', content: 'second', metadata: { purpose: 'history' } }, - { role: 'user', content: 'third' }, - ]; - const history: Message[] = [ - { role: 'user', content: 'past1' }, - { role: 'assistant', content: 'past2' }, + { role: 'user', content: [{ text: 'Hello' }] }, ]; - const result = insertHistory(messages, history); - // Since there's already a history marker, the original messages should be - // returned unchanged. + + const result = insertHistory(messages, []); + expect(result).toEqual(messages); }); - it('should handle empty history', () => { - const messages = [ - { role: 'user', content: 'first' }, - { role: 'user', content: 'second' }, + it('should return original messages if history purpose already exists', () => { + const messages: Message[] = [ + { + role: 'user', + content: [{ text: 'Hello' }], + metadata: { purpose: 'history' }, + }, ]; - const result = insertHistory(messages); + + const history: Message[] = [ + { + role: 'model', + content: [{ text: 'Previous' }], + metadata: { purpose: 'history' }, + }, + ]; + + const result = insertHistory(messages, history); + expect(result).toEqual(messages); }); - it('should append history if no history marker and no trailing user message', () => { - const messages = [ - { role: 'user', content: 'first' }, - { role: 'assistant', content: 'second' }, + it('should insert history before the last user message', () => { + const messages: Message[] = [ + { role: 'system', content: [{ text: 'System prompt' }] }, + { role: 'user', content: [{ text: 'Current question' }] }, ]; - const history = [ - { role: 'user', content: 'past1' }, - { role: 'assistant', content: 'past2' }, + + const history: Message[] = [ + { + role: 'model', + content: [{ text: 'Previous' }], + metadata: { purpose: 'history' }, + }, ]; + const result = insertHistory(messages, history); - expect(result).toEqual([...messages, ...history]); + + expect(result).toHaveLength(3); + expect(result).toEqual([ + { role: 'system', content: [{ text: 'System prompt' }] }, + { + role: 'model', + content: [{ text: 'Previous' }], + metadata: { purpose: 'history' }, + }, + { role: 'user', content: [{ text: 'Current question' }] }, + ]); }); - it('should insert history before last user message if no history marker', () => { - const messages = [ - { role: 'user', content: 'first' }, - { role: 'assistant', content: 'second' }, - { role: 'user', content: 'third' }, + it('should append history at the end if no user message is last', () => { + const messages: Message[] = [ + { role: 'system', content: [{ text: 'System prompt' }] }, + { role: 'model', content: [{ text: 'Model message' }] }, ]; - const history = [ - { role: 'user', content: 'past1' }, - { role: 'assistant', content: 'past2' }, + const history: Message[] = [ + { + role: 'model', + content: [{ text: 'Previous' }], + metadata: { purpose: 'history' }, + }, ]; + const result = insertHistory(messages, history); + + expect(result).toHaveLength(3); expect(result).toEqual([ - { role: 'user', content: 'first' }, - { role: 'assistant', content: 'second' }, - ...history, - { role: 'user', content: 'third' }, + { role: 'system', content: [{ text: 'System prompt' }] }, + { role: 'model', content: [{ text: 'Model message' }] }, + { + role: 'model', + content: [{ text: 'Previous' }], + metadata: { purpose: 'history' }, + }, ]); }); }); -describe('toParts', () => { - it('should convert text content to parts', () => { +describe('parsePart', () => { + it('should parse a media part', () => { + const source = '<<>> https://example.com/image.jpg'; + const result = parsePart(source); + expect(result).toEqual({ media: { url: 'https://example.com/image.jpg' } }); + }); + + it('should parse a section piece', () => { + const source = '<<>> code'; + const result = parsePart(source); + expect(result).toEqual({ metadata: { purpose: 'code', pending: true } }); + }); + + it('should parse a text piece', () => { const source = 'Hello World'; - const result = toParts(source); - expect(result).toEqual([{ text: 'Hello World' }]); + const result = parsePart(source); + expect(result).toEqual({ text: 'Hello World' }); }); +}); - it('should handle media markers', () => { +describe('parseMediaPart', () => { + it('should parse a media part', () => { const source = '<<>> https://example.com/image.jpg'; - const result = toParts(source); - expect(result).toEqual([ - { media: { url: undefined } }, - { text: ' https://example.com/image.jpg' }, - ]); + const result = parseMediaPart(source); + expect(result).toEqual({ media: { url: 'https://example.com/image.jpg' } }); + }); + + it('should parse a media piece with content type', () => { + const source = + '<<>> https://example.com/image.jpg image/jpeg'; + const result = parseMediaPart(source); + expect(result).toEqual({ + media: { + url: 'https://example.com/image.jpg', + contentType: 'image/jpeg', + }, + }); }); - it('should handle section markers', () => { + it('should throw an error if the media piece is invalid', () => { + const source = 'https://example.com/image.jpg'; + expect(() => parseMediaPart(source)).toThrow(); + }); +}); + +describe('parseSectionPart', () => { + it('should parse a section part', () => { const source = '<<>> code'; - const result = toParts(source); - expect(result).toEqual([ - { metadata: { purpose: undefined, pending: true } }, - { text: ' code' }, - ]); + const result = parseSectionPart(source); + expect(result).toEqual({ metadata: { purpose: 'code', pending: true } }); }); - it('should handle mixed content', () => { - const source = - 'Text before <<>> https://example.com/image.jpg Text after'; - const result = toParts(source); - expect(result).toEqual([ - { text: 'Text before ' }, - { media: { url: undefined } }, - { text: ' https://example.com/image.jpg Text after' }, - ]); + it('should throw an error if the section piece is invalid', () => { + const source = 'https://example.com/image.jpg'; + expect(() => parseSectionPart(source)).toThrow(); + }); +}); + +describe('parseTextPart', () => { + it('should parse a text part', () => { + const source = 'Hello World'; + const result = parseTextPart(source); + expect(result).toEqual({ text: 'Hello World' }); }); }); diff --git a/js/src/parse.ts b/js/src/parse.ts index 3f37c901c..619123a82 100644 --- a/js/src/parse.ts +++ b/js/src/parse.ts @@ -10,9 +10,43 @@ import type { Message, ParsedPrompt, Part, + PendingPart, PromptMetadata, + Role, + TextPart, } from './types'; +/** + * A message source is a message with a source string and optional content and + * metadata. + */ +export type MessageSource = { + role: Role; + source?: string; + content?: Message['content']; + metadata?: Record; +}; + +/** + * Prefixes for the role markers in the template. + */ +export const ROLE_MARKER_PREFIX = '<<>> and * <<>> markers in the template. @@ -66,6 +80,26 @@ export const ROLE_AND_HISTORY_MARKER_REGEX = export const MEDIA_AND_SECTION_MARKER_REGEX = /(<<>>/g; +/** + * List of reserved keywords that are handled specially in the metadata. + * These keys are processed differently from extension metadata. + */ +const RESERVED_METADATA_KEYWORDS: (keyof PromptMetadata)[] = [ + // NOTE: KEEP SORTED + 'config', + 'description', + 'ext', + 'input', + 'model', + 'name', + 'output', + 'raw', + 'toolDefs', + 'tools', + 'variant', + 'version', +]; + /** * Default metadata structure with empty extension and configuration objects. */ @@ -135,7 +169,8 @@ export function convertNamespacedEntryToNestedObject( * Extracts the YAML frontmatter and body from a document. * * @param source The source document containing frontmatter and template - * @returns An object containing the frontmatter and body + * @returns An object containing the frontmatter and body If the pattern does + * not match, both the values returned will be empty. */ export function extractFrontmatterAndBody(source: string) { const match = source.match(FRONTMATTER_AND_BODY_REGEX); @@ -183,6 +218,44 @@ export function parseDocument>( return { ...BASE_METADATA, template: source }; } +/** + * Processes an array of message sources into an array of messages. + * + * @param messageSources Array of message sources + * @returns Array of structured messages + */ +export function messageSourcesToMessages( + messageSources: MessageSource[] +): Message[] { + return messageSources + .filter((ms) => ms.content || ms.source) + .map((m) => { + const out: Message = { + role: m.role as Role, + content: m.content || toParts(m.source || ''), + }; + if (m.metadata) { + out.metadata = m.metadata; + } + return out; + }); +} + +/** + * Transforms an array of messages by adding history metadata to each message. + * + * @param messages Array of messages to transform + * @returns Array of messages with history metadata added + */ +export function transformMessagesToHistory( + messages: Array +): Array { + return messages.map((m) => ({ + ...m, + metadata: { ...m.metadata, purpose: 'history' }, + })); +} + /** * Converts a rendered template string into an array of messages. Processes * role markers and history placeholders to structure the conversation. @@ -192,73 +265,65 @@ export function parseDocument>( * @param data Optional data containing message history * @return Array of structured messages */ -export function toMessages>( +export function toMessages>( renderedString: string, data?: DataArgument ): Message[] { - let currentMessage: { role: string; source: string } = { - role: 'user', - source: '', - }; - const messageSources: { - role: string; - source?: string; - content?: Message['content']; - metadata?: Record; - }[] = [currentMessage]; + let currentMessage: MessageSource = { role: 'user', source: '' }; + const messageSources: MessageSource[] = [currentMessage]; for (const piece of splitByRoleAndHistoryMarkers(renderedString)) { - if (piece.startsWith('<< ms.content || ms.source) - .map((m) => { - const out: Message = { - role: m.role as Message['role'], - content: m.content || toParts(m.source!), - }; - if (m.metadata) out.metadata = m.metadata; - return out; - }); - + const messages: Message[] = messageSourcesToMessages(messageSources); return insertHistory(messages, data?.messages); } /** - * Transforms an array of messages by adding history metadata to each message. + * Checks if the messages have history metadata. * - * @param messages Array of messages to transform - * @returns Array of messages with history metadata added + * @param messages The messages to check + * @return True if the messages have history metadata, false otherwise */ -export function transformMessagesToHistory( - messages: Array<{ metadata?: Record }> -): Array<{ metadata: Record }> { - return messages.map((m) => ({ - ...m, - metadata: { ...m.metadata, purpose: 'history' }, - })); +export function messagesHaveHistory(messages: Message[]): boolean { + return messages.some((m) => m.metadata?.purpose === 'history'); } /** - * Inserts historical messages into the conversation at the appropriate - * position. + * Inserts historical messages into the conversation at appropriate positions. + * + * The history is inserted at: + * - Before the last user message if there is a user message. + * - The end of the conversation if there is no history or no user message. + * + * The history is not inserted: + * - If it already exists in the messages. + * - If there is no user message. * * @param messages Current array of messages * @param history Historical messages to insert @@ -268,11 +333,19 @@ export function insertHistory( messages: Message[], history: Message[] = [] ): Message[] { - if (!history || messages.find((m) => m.metadata?.purpose === 'history')) + // If we have no history or find an existing instance of history, return the + // original messages unmodified. + if (!history || messagesHaveHistory(messages)) { return messages; - if (messages.at(-1)?.role === 'user') { - return [...messages.slice(0, -1)!, ...history!, messages.at(-1)!]; } + + // If the last message is a user message, insert the history before it. + const lastMessage = messages.at(-1); + if (lastMessage?.role === 'user') { + const messagesWithoutLast = messages.slice(0, -1); + return [...messagesWithoutLast, ...history, lastMessage]; + } + // Otherwise, append the history to the end of the messages. return [...messages, ...history]; } @@ -284,26 +357,62 @@ export function insertHistory( * @return Array of structured parts (text, media, or metadata) */ export function toParts(source: string): Part[] { - const parts: Part[] = []; - const pieces = splitByMediaAndSectionMarkers(source); - for (let i = 0; i < pieces.length; i++) { - const piece = pieces[i]; - if (piece.startsWith('<<>> and +# <<>> markers in the template. +# +# Examples of matching patterns: +# - <<>> +# - <<>> +# - <<>> +# - <<>> +# +# Note: Only lowercase letters are allowed after 'role:'. +ROLE_AND_HISTORY_MARKER_REGEX = re.compile( + r'(<<>>' +) + +# Regular expression to match <<>> and +# <<>> markers in the template. +# +# Examples of matching patterns: +# - <<>> +# - <<>> +MEDIA_AND_SECTION_MARKER_REGEX = re.compile( + r'(<<>>' +) + +# List of reserved keywords that are handled specially in the metadata. +# These keys are processed differently from extension metadata. +RESERVED_METADATA_KEYWORDS = [ + # NOTE: KEEP SORTED + 'config', + 'description', + 'ext', + 'input', + 'model', + 'name', + 'output', + 'raw', + 'toolDefs', + 'tools', + 'variant', + 'version', +] + +# Default metadata structure with empty extension and configuration objects. +BASE_METADATA: PromptMetadata[Any] = PromptMetadata( + ext={}, + metadata={}, + config={}, +) + + +def split_by_regex(source: str, regex: re.Pattern[str]) -> list[str]: + """Splits a string by a regular expression while filtering out + empty/whitespace-only pieces. + + Args: + source: The source string to split into parts. + regex: The regular expression to use for splitting. + + Returns: + An array of non-empty string pieces. + """ + return [s for s in regex.split(source) if s.strip()] + + +def split_by_role_and_history_markers(rendered_string: str) -> list[str]: + """Splits a rendered template string into pieces based on role and history + markers while filtering out empty/whitespace-only pieces. + + Args: + rendered_string: The template string to split. + + Returns: + Array of non-empty string pieces. + """ + return split_by_regex(rendered_string, ROLE_AND_HISTORY_MARKER_REGEX) + + +def split_by_media_and_section_markers(source: str) -> list[str]: + """Split the source into pieces based on media and section markers while + filtering out empty/whitespace-only pieces. + + Args: + source: The source string to split into parts + + Returns: + An array of string parts + """ + return split_by_regex(source, MEDIA_AND_SECTION_MARKER_REGEX) + + +def convert_namespaced_entry_to_nested_object( + key: str, + value: Any, + obj: dict[str, dict[str, Any]] | None = None, +) -> dict[str, dict[str, Any]]: + """Processes a namespaced key-value pair into a nested object structure. + For example, 'foo.bar': 'value' becomes { foo: { bar: 'value' } } + + Args: + key: The dotted namespace key (e.g., 'foo.bar') + value: The value to assign + obj: The object to add the namespaced value to + + Returns: + The updated target object + """ + if obj is None: + obj = {} + + last_dot_index = key.rindex('.') + ns = key[:last_dot_index] + field = key[last_dot_index + 1 :] + obj.setdefault(ns, {}) + obj[ns][field] = value + return obj + + +def extract_frontmatter_and_body(source: str) -> tuple[str, str]: + """Extracts the YAML frontmatter and body from a document. + + Args: + source: The source document containing frontmatter and template + + Returns: + A tuple containing the frontmatter and body If the pattern does not + match, both the values returned will be empty. + """ + match = FRONTMATTER_AND_BODY_REGEX.match(source) + if match: + frontmatter, body = match.groups() + return frontmatter, body + return '', '' + + +# def parse_document(source: str) -> ParsedPrompt[T]: +# """Parses a .dotprompt document. +# +# The frontmatter YAML contains metadata and configuration for the prompt. +# +# Args: +# source: The source document containing frontmatter and template +# +# Returns: +# Parsed prompt with metadata and template content +# """ +# # TODO: Implement this +# pass + + +def to_messages( + rendered_string: str, + data: DataArgument[Any] | None = None, +) -> list[Message]: + """ + Converts a rendered template string into an array of messages. Processes + role markers and history placeholders to structure the conversation. + + Args: + rendered_string: The rendered template string to convert + data: Optional data containing message history + + Returns: + List of structured messages + """ + current_message = MessageSource(role=Role.USER, source='') + message_sources = [current_message] + + for piece in split_by_role_and_history_markers(rendered_string): + if piece.startswith(ROLE_MARKER_PREFIX): + role = piece[len(ROLE_MARKER_PREFIX) :] + + if current_message.source and current_message.source.strip(): + # If the current message has content, create a new message + current_message = MessageSource(role=Role(role), source='') + message_sources.append(current_message) + else: + # Otherwise, update the role of the current message + current_message.role = Role(role) + + elif piece.startswith(HISTORY_MARKER_PREFIX): + # Add the history messages to the message sources + msgs: list[Message] = [] + if data and data.messages: + msgs = data.messages + history_messages = transform_messages_to_history(msgs) + if history_messages: + message_sources.extend( + [ + MessageSource( + role=msg.role, + content=msg.content, + metadata=msg.metadata, + ) + for msg in history_messages + ] + ) + + # Add a new message source for the model + current_message = MessageSource(role=Role.MODEL, source='') + message_sources.append(current_message) + + else: + # Otherwise, add the piece to the current message source + current_message.source = (current_message.source or '') + piece + + messages = message_sources_to_messages(message_sources) + return insert_history(messages, data.messages if data else None) + + +def message_sources_to_messages( + message_sources: list[MessageSource], +) -> list[Message]: + """ + Processes an array of message sources into an array of messages. + + Args: + message_sources: List of message sources + + Returns: + List of structured messages + """ + messages: list[Message] = [] + for m in message_sources: + if m.content or m.source: + message = Message( + role=m.role, + content=m.content + if m.content is not None + else to_parts(m.source or ''), + ) + + if m.metadata: + message.metadata = m.metadata + + messages.append(message) + + return messages + + +def transform_messages_to_history( + messages: list[Message], +) -> list[Message]: + """Adds history metadata to an array of messages. + + Args: + messages: Array of messages to transform + + Returns: + Array of messages with history metadata added + """ + return [ + Message( + role=message.role, + content=message.content, + metadata={**(message.metadata or {}), 'purpose': 'history'}, + ) + for message in messages + ] + + +def messages_have_history(messages: list[Message]) -> bool: + """Checks if the messages have history metadata. + + Args: + messages: The messages to check + + Returns: + True if the messages have history metadata, False otherwise + """ + return any( + msg.metadata and msg.metadata.get('purpose') == 'history' + for msg in messages + ) + + +def insert_history( + messages: list[Message], + history: list[Message] | None = None, +) -> list[Message]: + """Inserts historical messages into the conversation. + + The history is inserted at: + - The end of the conversation if there is no history or no user message. + - Before the last user message if there is a user message. + + The history is not inserted: + - If it already exists in the messages. + - If there is no user message. + + Args: + messages: Current array of messages + history: Historical messages to insert + + Returns: + Messages with history inserted + """ + # If we have no history or find an existing instance of history, return the + # original messages unmodified. + if not history or messages_have_history(messages): + return messages + + last_message = messages[-1] + if last_message.role == 'user': + # If the last message is a user message, insert the history before it. + messages = messages[:-1] + messages.extend(history) + messages.append(last_message) + else: + # Otherwise, append the history to the end of the messages. + messages.extend(history) + return messages + + +def to_parts(source: str) -> list[Part]: + """Converts a source string into an array of parts. + + Also processes media and section markers. + + Args: + source: The source string to convert into parts + + Returns: + Array of structured parts (text, media, or metadata) + """ + return [ + parse_part(piece) + for piece in split_by_media_and_section_markers(source) + ] + + +def parse_part(piece: str) -> Part: + """Parses a part from a piece of rendered template. + + Args: + piece: The piece to parse + + Returns: + Part, PendingPart, TextPart, or MediaPart + """ + if piece.startswith(MEDIA_MARKER_PREFIX): + return parse_media_part(piece) + elif piece.startswith(SECTION_MARKER_PREFIX): + return parse_section_part(piece) + else: + return parse_text_part(piece) + + +def parse_media_part(piece: str) -> MediaPart: + """Parses a media part from a piece of rendered template. + + Args: + piece: The piece to parse + + Returns: + Media part + + Raises: + ValueError: If the media piece is invalid + """ + if not piece.startswith(MEDIA_MARKER_PREFIX): + raise ValueError(f'Invalid media piece: {piece}') + + fields = piece.split(' ') + n = len(fields) + if n == 3: + _, url, content_type = fields + elif n == 2: + _, url = fields + content_type = None + else: + raise ValueError(f'Invalid media piece: {piece}') + + part = MediaPart(media=dict(url=url)) + if content_type and content_type.strip(): + part.media['contentType'] = content_type + return part + + +def parse_section_part(piece: str) -> PendingPart: + """Parses a section part from a piece of rendered template. + + Args: + piece: The piece to parse + + Returns: + Section part + + Raises: + ValueError: If the section piece is invalid + """ + if not piece.startswith(SECTION_MARKER_PREFIX): + raise ValueError(f'Invalid section piece: {piece}') + + fields = piece.split(' ') + if len(fields) == 2: + section_type = fields[1] + else: + raise ValueError(f'Invalid section piece: {piece}') + return PendingPart(metadata=dict(purpose=section_type, pending=True)) + + +def parse_text_part(piece: str) -> TextPart: + """Parses a text part from a piece of rendered template. + + Args: + piece: The piece to parse + + Returns: + Text part + """ + return TextPart(text=piece) diff --git a/python/dotpromptz/src/dotpromptz/parse_test.py b/python/dotpromptz/src/dotpromptz/parse_test.py new file mode 100644 index 000000000..9f433728b --- /dev/null +++ b/python/dotpromptz/src/dotpromptz/parse_test.py @@ -0,0 +1,650 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for parse module.""" + +import re +import unittest + +import pytest + +from dotpromptz.parse import ( + FRONTMATTER_AND_BODY_REGEX, + MEDIA_AND_SECTION_MARKER_REGEX, + ROLE_AND_HISTORY_MARKER_REGEX, + MessageSource, + convert_namespaced_entry_to_nested_object, + extract_frontmatter_and_body, + insert_history, + message_sources_to_messages, + messages_have_history, + parse_media_part, + parse_part, + parse_section_part, + parse_text_part, + split_by_media_and_section_markers, + split_by_regex, + split_by_role_and_history_markers, + transform_messages_to_history, +) +from dotpromptz.typing import ( + MediaPart, + Message, + Part, + PendingPart, + Role, + TextPart, +) + + +class TestSplitByMediaAndSectionMarkers(unittest.TestCase): + def test_split_by_media_and_section_markers(self) -> None: + """Test splitting by media and section markers.""" + input_str = '<<>> https://example.com/image.jpg' + output = split_by_media_and_section_markers(input_str) + assert output == [ + '<< None: + """Test multiple markers in a string.""" + input_str = '<<>> https://example.com/image.jpg' + output = split_by_media_and_section_markers(input_str) + assert output == [ + '<< None: + """Test no markers in a string.""" + input_str = 'Hello World' + output = split_by_media_and_section_markers(input_str) + assert output == ['Hello World'] + + +def test_role_and_history_marker_regex_valid_patterns() -> None: + """Test valid patterns for role and history markers.""" + valid_patterns = [ + '<<>>', + '<<>>', + '<<>>', + '<<>>', + '<<>>', + '<<>>', + '<<>>', + ] + + for pattern in valid_patterns: + assert ROLE_AND_HISTORY_MARKER_REGEX.search(pattern) is not None + + +def test_role_and_history_marker_regex_invalid_patterns() -> None: + """Test invalid patterns for role and history markers.""" + invalid_patterns = [ + '<<>>', # uppercase not allowed + '<<>>', # numbers not allowed + '<<>>', # needs at least one letter + '<<>>', # missing role value + '<<>>', # history should be exact + '<<>>', # history must be lowercase + 'dotprompt:role:user', # missing brackets + '<<>>', # incomplete opening + ] + + for pattern in invalid_patterns: + assert ROLE_AND_HISTORY_MARKER_REGEX.search(pattern) is None + + +def test_role_and_history_marker_regex_multiple_matches() -> None: + """Test multiple matches in a string.""" + text = """ + <<>> Hello + <<>> Hi there + <<>> + <<>> How are you? + """ + + matches = ROLE_AND_HISTORY_MARKER_REGEX.findall(text) + assert len(matches) == 4 + + +def test_media_and_section_marker_regex_valid_patterns() -> None: + """Test valid patterns for media and section markers.""" + valid_patterns = [ + '<<>>', + '<<>>', + ] + + for pattern in valid_patterns: + assert MEDIA_AND_SECTION_MARKER_REGEX.search(pattern) is not None + + +def test_media_and_section_marker_regex_multiple_matches() -> None: + """Test multiple matches in a string.""" + text = """ + <<>> https://example.com/image.jpg + <<>> Section 1 + <<>> https://example.com/video.mp4 + <<>> Section 2 + """ + + matches = MEDIA_AND_SECTION_MARKER_REGEX.findall(text) + assert len(matches) == 4 + + +class TestSplitByRoleAndHistoryMarkers(unittest.TestCase): + def test_no_markers(self) -> None: + """Test splitting when no markers are present.""" + input_str = 'Hello World' + output = split_by_role_and_history_markers(input_str) + assert output == ['Hello World'] + + def test_single_marker(self) -> None: + """Test splitting with a single marker.""" + input_str = 'Hello <<>> world' + output = split_by_role_and_history_markers(input_str) + assert output == ['Hello ', '<< None: + """Test splitting with a single marker.""" + input_str = 'Hello <<>> world' + output = split_by_role_and_history_markers(input_str) + assert output == ['Hello ', '<< None: + """Test filtering empty and whitespace-only pieces.""" + input_str = ' <<>> ' + output = split_by_role_and_history_markers(input_str) + assert output == ['<< None: + """Test adjacent markers.""" + input_str = '<<>><<>>' + output = split_by_role_and_history_markers(input_str) + assert output == ['<< None: + """Test no split on markers with uppercase letters (invalid format).""" + input_str = '<<>>' + output = split_by_role_and_history_markers(input_str) + assert output == ['<<>>'] + + def test_split_by_role_and_history_markers_multiple_markers(self) -> None: + """Test string with multiple markers interleaved with text.""" + input_str = ( + 'Start <<>> middle <<>> end' + ) + output = split_by_role_and_history_markers(input_str) + assert output == [ + 'Start ', + '<< None: + """Test creating nested object structure from namespaced key.""" + result = convert_namespaced_entry_to_nested_object('foo.bar', 'hello') + self.assertEqual( + result, + { + 'foo': { + 'bar': 'hello', + }, + }, + ) + + def test_adding_to_existing_namespace(self) -> None: + """Test adding to existing namespace.""" + existing = { + 'foo': { + 'bar': 'hello', + }, + } + result = convert_namespaced_entry_to_nested_object( + 'foo.baz', 'world', existing + ) + self.assertEqual( + result, + { + 'foo': { + 'bar': 'hello', + 'baz': 'world', + }, + }, + ) + + def test_handling_multiple_namespaces(self) -> None: + """Test handling multiple namespaces.""" + result = convert_namespaced_entry_to_nested_object('foo.bar', 'hello') + final_result = convert_namespaced_entry_to_nested_object( + 'baz.qux', 'world', result + ) + self.assertEqual( + final_result, + { + 'foo': { + 'bar': 'hello', + }, + 'baz': { + 'qux': 'world', + }, + }, + ) + + +@pytest.mark.parametrize( + 'source,expected_frontmatter,expected_body', + [ + ( + '---\nfoo: bar\n---\nThis is the body.', + 'foo: bar', + 'This is the body.', + ), # Test document with frontmatter and body + ( + '---\n\n---\nBody only.', + '', + 'Body only.', + ), # Test document with empty frontmatter + ( + '---\nfoo: bar\n---\n', + 'foo: bar', + '', + ), # Test document with empty body + ( + '---\nfoo: bar\nbaz: qux\n---\nThis is the body.', + 'foo: bar\nbaz: qux', + 'This is the body.', + ), # Test document with multiline frontmatter + ( + 'Just a body.', + None, + None, + ), # Test document with no frontmatter markers + ( + '---\nfoo: bar\nThis is the body.', + None, + None, + ), # Test document with incomplete frontmatter markers + ( + '---\nfoo: bar\n---\nThis is the body.\n---\nExtra section.', + 'foo: bar', + 'This is the body.\n---\nExtra section.', + ), # Test document with extra frontmatter markers + ], +) +def test_frontmatter_and_body_regex( + source: str, + expected_frontmatter: str | None, + expected_body: str | None, +) -> None: + """Test frontmatter and body regex.""" + match = FRONTMATTER_AND_BODY_REGEX.match(source) + + if expected_frontmatter is None: + assert match is None + else: + assert match is not None + frontmatter, body = match.groups() + assert frontmatter == expected_frontmatter + assert body == expected_body + + +class TestExtractFrontmatterAndBody(unittest.TestCase): + """Test extracting frontmatter and body from a string.""" + + def test_should_extract_frontmatter_and_body(self) -> None: + """Test extracting frontmatter and body when both are present.""" + input_str = '---\nfoo: bar\n---\nThis is the body.' + frontmatter, body = extract_frontmatter_and_body(input_str) + assert frontmatter == 'foo: bar' + assert body == 'This is the body.' + + def test_extract_frontmatter_and_body_no_frontmatter(self) -> None: + """Test extracting body when no frontmatter is present. + + Both the frontmatter and body are empty strings, when there + is no frontmatter marker. + """ + + input_str = 'Hello World' + frontmatter, body = extract_frontmatter_and_body(input_str) + assert frontmatter == '' + assert body == '' + + +def test_split_by_regex() -> None: + """Test splitting by regex and filtering empty/whitespace pieces.""" + source = ' one , , two , three ' + result = split_by_regex(source, re.compile(r',')) + assert result == [' one ', ' two ', ' three '] + + +class TestTransformMessagesToHistory(unittest.TestCase): + def test_add_history_metadata_to_messages(self) -> None: + messages: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Hello')]), + Message(role=Role.MODEL, content=[TextPart(text='Hi there')]), + ] + + result = transform_messages_to_history(messages) + + assert len(result) == 2 + assert result == [ + Message( + role=Role.USER, + content=[TextPart(text='Hello')], + metadata={'purpose': 'history'}, + ), + Message( + role=Role.MODEL, + content=[TextPart(text='Hi there')], + metadata={'purpose': 'history'}, + ), + ] + + def test_preserve_existing_metadata_while_adding_history_purpose( + self, + ) -> None: + messages: list[Message] = [ + Message( + role=Role.USER, + content=[TextPart(text='Hello')], + metadata={'foo': 'bar'}, + ) + ] + + result = transform_messages_to_history(messages) + + assert len(result) == 1 + assert result == [ + Message( + role=Role.USER, + content=[TextPart(text='Hello')], + metadata={'foo': 'bar', 'purpose': 'history'}, + ) + ] + + def test_handle_empty_array(self) -> None: + result = transform_messages_to_history([]) + assert result == [] + + +class TestMessageSourcesToMessages(unittest.TestCase): + def test_should_handle_empty_array(self) -> None: + message_sources: list[MessageSource] = [] + expected: list[Message] = [] + assert message_sources_to_messages(message_sources) == expected + + def test_should_convert_a_single_message_source(self) -> None: + message_sources: list[MessageSource] = [ + MessageSource(role=Role.USER, source='Hello') + ] + expected: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Hello')]) + ] + assert message_sources_to_messages(message_sources) == expected + + def test_should_handle_message_source_with_content(self) -> None: + message_sources: list[MessageSource] = [ + MessageSource( + role=Role.USER, content=[TextPart(text='Existing content')] + ) + ] + expected: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Existing content')]) + ] + assert message_sources_to_messages(message_sources) == expected + + def test_should_handle_message_source_with_metadata(self) -> None: + message_sources: list[MessageSource] = [ + MessageSource( + role=Role.USER, + content=[TextPart(text='Existing content')], + metadata={'foo': 'bar'}, + ) + ] + expected: list[Message] = [ + Message( + role=Role.USER, + content=[TextPart(text='Existing content')], + metadata={'foo': 'bar'}, + ) + ] + assert message_sources_to_messages(message_sources) == expected + + def test_should_filter_out_message_sources_with_empty_source_and_content( + self, + ) -> None: + message_sources: list[MessageSource] = [ + MessageSource(role=Role.USER, source=''), + MessageSource(role=Role.MODEL, source=' '), + MessageSource(role=Role.USER, source='Hello'), + ] + expected: list[Message] = [ + Message(role=Role.MODEL, content=[]), + Message(role=Role.USER, content=[TextPart(text='Hello')]), + ] + assert message_sources_to_messages(message_sources) == expected + + def test_should_handle_multiple_message_sources(self) -> None: + message_sources: list[MessageSource] = [ + MessageSource(role=Role.USER, source='Hello'), + MessageSource(role=Role.MODEL, source='Hi there'), + MessageSource(role=Role.USER, source='How are you?'), + ] + expected: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Hello')]), + Message(role=Role.MODEL, content=[TextPart(text='Hi there')]), + Message(role=Role.USER, content=[TextPart(text='How are you?')]), + ] + assert message_sources_to_messages(message_sources) == expected + + +class TestMessagesHaveHistory(unittest.TestCase): + def test_should_return_true_if_messages_have_history_metadata(self) -> None: + messages: list[Message] = [ + Message( + role=Role.USER, + content=[TextPart(text='Hello')], + metadata={'purpose': 'history'}, + ) + ] + + result = messages_have_history(messages) + + self.assertTrue(result) + + def test_should_return_false_if_messages_do_not_have_history_metadata( + self, + ) -> None: + messages: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Hello')]) + ] + + result = messages_have_history(messages) + + self.assertFalse(result) + + +class TestInsertHistory(unittest.TestCase): + def test_should_return_original_messages_if_history_is_undefined( + self, + ) -> None: + messages: list[Message] = [ + Message(role=Role.USER, content=[TextPart(text='Hello')]) + ] + + result = insert_history(messages, []) + + assert result == messages + + def test_should_return_original_messages_if_history_purpose_already_exists( + self, + ) -> None: + messages: list[Message] = [ + Message( + role=Role.USER, + content=[TextPart(text='Hello')], + metadata={'purpose': 'history'}, + ) + ] + + history: list[Message] = [ + Message( + role=Role.MODEL, + content=[TextPart(text='Previous')], + metadata={'purpose': 'history'}, + ) + ] + + result = insert_history(messages, history) + + assert result == messages + + def test_should_insert_history_before_the_last_user_message(self) -> None: + messages: list[Message] = [ + Message(role=Role.SYSTEM, content=[TextPart(text='System prompt')]), + Message( + role=Role.USER, content=[TextPart(text='Current question')] + ), + ] + + history: list[Message] = [ + Message( + role=Role.MODEL, + content=[TextPart(text='Previous')], + metadata={'purpose': 'history'}, + ) + ] + + result = insert_history(messages, history) + + assert len(result) == 3 + assert result == [ + Message(role=Role.SYSTEM, content=[TextPart(text='System prompt')]), + Message( + role=Role.MODEL, + content=[TextPart(text='Previous')], + metadata={'purpose': 'history'}, + ), + Message( + role=Role.USER, content=[TextPart(text='Current question')] + ), + ] + + def test_should_append_history_at_the_end_if_no_user_message_is_last( + self, + ) -> None: + messages: list[Message] = [ + Message(role=Role.SYSTEM, content=[TextPart(text='System prompt')]), + Message(role=Role.MODEL, content=[TextPart(text='Model message')]), + ] + history: list[Message] = [ + Message( + role=Role.MODEL, + content=[TextPart(text='Previous')], + metadata={'purpose': 'history'}, + ) + ] + + result = insert_history(messages, history) + + assert len(result) == 3 + assert result == [ + Message(role=Role.SYSTEM, content=[TextPart(text='System prompt')]), + Message(role=Role.MODEL, content=[TextPart(text='Model message')]), + Message( + role=Role.MODEL, + content=[TextPart(text='Previous')], + metadata={'purpose': 'history'}, + ), + ] + + +@pytest.mark.parametrize( + 'piece,expected', + [ + ( + 'Hello World', + TextPart(text='Hello World'), + ), + ( + '<<>> https://example.com/image.jpg', + MediaPart( + media=dict( + url='https://example.com/image.jpg', + ) + ), + ), + ( + '<<>> https://example.com/image.jpg image/jpeg', + MediaPart( + media={ + 'url': 'https://example.com/image.jpg', + 'contentType': 'image/jpeg', + }, + ), + ), + ( + 'https://example.com/image.jpg', + TextPart(text='https://example.com/image.jpg'), + ), + ( + '<<>> code', + PendingPart(metadata=dict(purpose='code', pending=True)), + ), + ( + 'Text before <<>> https://example.com/image.jpg Text after', + TextPart( + text='Text before <<>> https://example.com/image.jpg Text after' + ), + ), + ], +) +def test_parse_part(piece: str, expected: Part) -> None: + """Test parsing pieces.""" + result = parse_part(piece) + assert result == expected + + +def test_parse_media_piece() -> None: + """Test parsing media pieces.""" + piece = '<<>> https://example.com/image.jpg' + result = parse_media_part(piece) + assert result == MediaPart(media={'url': 'https://example.com/image.jpg'}) + + +def test_parse_media_piece_invalid() -> None: + """Test parsing invalid media pieces.""" + piece = 'https://example.com/image.jpg' + with pytest.raises(ValueError): + parse_media_part(piece) + + +def test_parse_section_piece() -> None: + """Test parsing section pieces.""" + piece = '<<>> code' + result = parse_section_part(piece) + assert result == PendingPart(metadata={'purpose': 'code', 'pending': True}) + + +def test_parse_section_piece_invalid() -> None: + """Test parsing invalid section pieces.""" + piece = 'https://example.com/image.jpg' + with pytest.raises(ValueError): + parse_section_part(piece) + + +def test_parse_text_piece() -> None: + """Test parsing text pieces.""" + piece = 'Hello World' + result = parse_text_part(piece) + assert result == TextPart(text='Hello World')