package main
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRequestModifierPlugin checks that the request-modifier plugin's PreHook
// returns a request whose single user message still contains the original
// text and also contains the configured prefix and suffix prompts.
func TestRequestModifierPlugin(t *testing.T) {
	tests := []struct {
		name           string
		config         RequestModifierConfig
		inputRequest   *schemas.BifrostRequest
		expectedPrefix string
		expectedSuffix string
	}{
		{
			name: "adds prefix and suffix to user message",
			config: RequestModifierConfig{
				PrefixPrompt: "Please be concise:",
				SuffixPrompt: "Respond in one sentence.",
			},
			inputRequest: &schemas.BifrostRequest{
				Provider: schemas.OpenAI,
				Model:    "gpt-4o-mini",
				Input: schemas.RequestInput{
					ChatCompletionInput: &[]schemas.BifrostMessage{
						{
							Role: schemas.ModelChatMessageRoleUser,
							Content: schemas.MessageContent{
								ContentStr: stringPtr("What is AI?"),
							},
						},
					},
				},
			},
			expectedPrefix: "Please be concise:",
			expectedSuffix: "Respond in one sentence.",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			plugin := NewRequestModifierPlugin(tt.config)
			ctx := context.Background()

			result, shortCircuit, err := plugin.PreHook(&ctx, tt.inputRequest)
			require.NoError(t, err)
			assert.Nil(t, shortCircuit)

			// Everything below dereferences pointers taken from result, so a
			// nil must stop the subtest here (require) rather than be recorded
			// and then panic on the next line (assert).
			require.NotNil(t, result)
			require.NotNil(t, result.Input.ChatCompletionInput)

			messages := *result.Input.ChatCompletionInput
			require.Len(t, messages, 1)
			require.NotNil(t, messages[0].Content.ContentStr)

			content := *messages[0].Content.ContentStr
			assert.Contains(t, content, tt.expectedPrefix)
			assert.Contains(t, content, tt.expectedSuffix)
			assert.Contains(t, content, "What is AI?")
		})
	}
}
// TestAuthenticationPlugin exercises the authentication plugin's PreHook with
// a valid, an invalid, and an empty API key placed in the context.
//
// Rejections are expected to arrive as a short-circuit carrying an error code
// and AllowFallbacks set to false; PreHook's own error return is expected to
// stay nil in every case. On success the context passed by pointer is expected
// to carry the user ID that validKeys maps the key to.
func TestAuthenticationPlugin(t *testing.T) {
	validKeys := map[string]string{
		"test-key-1": "user-1",
		"test-key-2": "user-2",
	}
	plugin := NewAuthenticationPlugin(validKeys)

	tests := []struct {
		name        string
		apiKey      string
		expectError bool
		errorCode   string
	}{
		{
			name:        "valid API key",
			apiKey:      "test-key-1",
			expectError: false,
		},
		{
			name:        "invalid API key",
			apiKey:      "invalid-key",
			expectError: true,
			errorCode:   "invalid_api_key",
		},
		{
			name:        "missing API key",
			apiKey:      "",
			expectError: true,
			errorCode:   "missing_api_key",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE(review): plain string context keys are flagged by staticcheck
			// (SA1029). They are kept because the plugin under test presumably
			// reads these exact keys — confirm before switching to a typed key.
			ctx := context.WithValue(context.Background(), "api_key", tt.apiKey)
			req := &schemas.BifrostRequest{
				Provider: schemas.OpenAI,
				Model:    "gpt-4o-mini",
			}

			result, shortCircuit, err := plugin.PreHook(&ctx, req)
			require.NoError(t, err) // Plugin errors are returned via shortCircuit

			if tt.expectError {
				assert.Nil(t, result)

				// require guards every pointer that is dereferenced below so a
				// nil yields a test failure instead of a panic.
				require.NotNil(t, shortCircuit)
				require.NotNil(t, shortCircuit.Error)
				if tt.errorCode != "" {
					require.NotNil(t, shortCircuit.Error.Error.Code)
					assert.Equal(t, tt.errorCode, *shortCircuit.Error.Error.Code)
				}
				require.NotNil(t, shortCircuit.AllowFallbacks)
				assert.False(t, *shortCircuit.AllowFallbacks)
				return
			}

			assert.NotNil(t, result)
			assert.Nil(t, shortCircuit)

			// The expected user is derived from the key under test rather than
			// hard-coded, so further valid-key cases assert the right user.
			assert.Equal(t, validKeys[tt.apiKey], ctx.Value("user_id"))
		})
	}
}