From 39ed2a8ebfb275fdaf1c5947e7936bb6b500674a Mon Sep 17 00:00:00 2001 From: kpom Date: Thu, 19 Mar 2026 12:32:09 -0500 Subject: [PATCH] Sonic goes fast --- drivers/sonic/README.md | 87 +++ drivers/sonic/batch.go | 205 ++++++ drivers/sonic/cypher.go | 152 +++++ drivers/sonic/eval.go | 629 +++++++++++++++++ drivers/sonic/execute.go | 1007 +++++++++++++++++++++++++++ drivers/sonic/integration_test.go | 435 ++++++++++++ drivers/sonic/pathfinding.go | 313 +++++++++ drivers/sonic/queries.go | 367 ++++++++++ drivers/sonic/sonic.go | 107 +++ drivers/sonic/sonic_test.go | 1060 +++++++++++++++++++++++++++++ drivers/sonic/transaction.go | 122 ++++ 11 files changed, 4484 insertions(+) create mode 100644 drivers/sonic/README.md create mode 100644 drivers/sonic/batch.go create mode 100644 drivers/sonic/cypher.go create mode 100644 drivers/sonic/eval.go create mode 100644 drivers/sonic/execute.go create mode 100644 drivers/sonic/integration_test.go create mode 100644 drivers/sonic/pathfinding.go create mode 100644 drivers/sonic/queries.go create mode 100644 drivers/sonic/sonic.go create mode 100644 drivers/sonic/sonic_test.go create mode 100644 drivers/sonic/transaction.go diff --git a/drivers/sonic/README.md b/drivers/sonic/README.md new file mode 100644 index 0000000..235bc7f --- /dev/null +++ b/drivers/sonic/README.md @@ -0,0 +1,87 @@ +# Sonic — In-Memory DAWGS Graph Driver + +Sonic is an in-memory graph database driver for DAWGS. It implements the same `graph.Database` interface as the `pg` driver, giving BloodHound users a zero-infrastructure option with no Postgres required. + +`sonic` because it's really fast — no network, no disk, no MVCC overhead. + +## Architecture + +All data lives in Go maps protected by a `sync.RWMutex`. Adjacency indexes (`outEdges`, `inEdges`) map node IDs to their edge IDs for O(1) neighbor lookup. IDs are assigned via `atomic.Uint64`. 
+ +The driver registers itself as `"sonic"` via `dawgs.Register()` in `init()`, following the same pattern as the `pg` driver. + +## Files + +| File | Purpose | +|------|---------| +| `sonic.go` | `Database` struct, constructor, `graph.Database` interface | +| `transaction.go` | `graph.Transaction` — node/edge CRUD, Cypher dispatch | +| `batch.go` | `graph.Batch` — bulk CRUD with upsert support | +| `queries.go` | `graph.NodeQuery` / `graph.RelationshipQuery` — filtering, fetching, shortest paths | +| `eval.go` | Cypher AST filter evaluation, comparison operators, type coercion | +| `pathfinding.go` | BFS shortest-path algorithm with constraint extraction | +| `cypher.go` | Cypher entry point (`executeCypher`), parameter resolution, `sonicResult` — result set iteration, scanning, value mapping | +| `execute.go` | Cypher AST walker (`executor`) — MATCH, WITH, RETURN, WHERE, variable-length paths | + +## What Works + +### Graph Operations (via `graph.*` interfaces) + +- **CRUD**: CreateNode, UpdateNode, DeleteNode, CreateRelationship, CreateRelationshipByIDs, UpdateRelationship, DeleteRelationship +- **Node queries**: Filter, Filterf, First, Count, Fetch, FetchIDs, FetchKinds, Delete, Update, Query +- **Relationship queries**: Filter, Filterf, First, Count, Fetch, FetchIDs, FetchKinds, FetchDirection, FetchTriples, FetchAllShortestPaths, Delete, Update, Query +- **Batch upserts**: UpdateNodeBy, UpdateRelationshipBy (identity-based match/create/update) +- **Schema**: AssertSchema, SetDefaultGraph, FetchKinds + +### Filter Evaluation + +The driver evaluates the Cypher AST that DAWGS query builders produce: + +- **Logical**: Conjunction (AND), Disjunction (OR), Negation (NOT), Parenthetical +- **Comparisons**: `=`, `!=`, `<`, `>`, `<=`, `>=`, `IN`, `CONTAINS`, `STARTS WITH`, `ENDS WITH`, `IS NULL`, `IS NOT NULL` +- **Kind matching**: node kinds, edge kinds, start/end node kinds +- **Functions**: `id()`, `type()`, `toLower()`, `toUpper()`, `labels()`, `keys()` +- **Property resolution**: node/edge properties, start/end node properties via 
`query.EdgeStartSymbol`/`query.EdgeEndSymbol` + +### Cypher Execution + +Raw Cypher strings are parsed and executed via an AST walker: + +- MATCH / OPTIONAL MATCH with node and relationship patterns +- WHERE clause filtering with full expression evaluation +- WITH (scope barriers, projection, aggregation aliases) +- RETURN (*, named projections) +- ORDER BY, LIMIT, SKIP, DISTINCT +- `allShortestPaths()` pattern +- Variable-length relationship patterns (`[*]`, `[*1..3]`) +- Multi-part queries (multiple MATCH/WITH chains) +- Parameter substitution + +### Pathfinding + +BFS shortest-path implementation that: +- Finds **all** equally-short paths between start and end nodes +- Respects edge kind constraints +- Supports multiple start/end nodes simultaneously +- Uses bidirectional parent tracking for path reconstruction + +## What's Not Supported + +- **Cypher write operations**: CREATE, DELETE, SET, REMOVE, MERGE return errors. Use the `graph.Transaction` or `graph.Batch` interfaces for writes. +- **UNWIND, quantifiers, filter expressions** in Cypher +- **Aggregation functions** (count, collect, sum, avg, min, max) — return nil stubs in Cypher evaluation +- **OrderBy, Offset, Limit** on `nodeQuery`/`relQuery` — accepted but no-op +- **Persistence** — data lives only in memory, lost on process exit (by design) + +## Constraints + +- **No persistence** — data is lost when the process exits. By design for the initial version. +- **Coarse locking** — `sync.RWMutex` protects the whole database, not individual operations. Fine for single-user BHE. +- **Non-deterministic ordering** — map iteration means query results may come back in different orders than Postgres. +- **Binding limit** — Cypher execution caps at 100,000 intermediate bindings. +- **Variable-length path depth** — capped at 50 hops with cycle prevention. 
+ +## Tests + +- **Unit tests** (`sonic_test.go`): CRUD, property filters, shortest paths, Cypher queries (kind filtering, negation, multi-part, variable-length paths, anonymous nodes, concurrent access) +- **Integration tests** (`integration_test.go`): node/relationship operations, attack path finding, batch upserts, parallel fetches against a realistic graph topology diff --git a/drivers/sonic/batch.go b/drivers/sonic/batch.go new file mode 100644 index 0000000..162c68b --- /dev/null +++ b/drivers/sonic/batch.go @@ -0,0 +1,205 @@ +package sonic + +import ( + "context" + "fmt" + + "github.com/specterops/dawgs/graph" +) + +type batch struct { + db *Database + ctx context.Context +} + +func (b *batch) WithGraph(graphSchema graph.Graph) graph.Batch { + return b +} + +func (b *batch) CreateNode(node *graph.Node) error { + id := b.db.newID() + node.ID = id + + b.db.mu.Lock() + defer b.db.mu.Unlock() + + b.db.nodes[id] = node + return nil +} + +func (b *batch) DeleteNode(id graph.ID) error { + b.db.mu.Lock() + defer b.db.mu.Unlock() + + for _, edgeID := range b.db.outEdges[id] { + delete(b.db.edges, edgeID) + } + for _, edgeID := range b.db.inEdges[id] { + delete(b.db.edges, edgeID) + } + delete(b.db.outEdges, id) + delete(b.db.inEdges, id) + delete(b.db.nodes, id) + return nil +} + +func (b *batch) Nodes() graph.NodeQuery { + return &nodeQuery{db: b.db} +} + +func (b *batch) Relationships() graph.RelationshipQuery { + return &relQuery{db: b.db} +} + +func (b *batch) UpdateNodeBy(update graph.NodeUpdate) error { + b.db.mu.Lock() + defer b.db.mu.Unlock() + + // Try to find an existing node that matches the identity criteria + for _, existing := range b.db.nodes { + if !existing.Kinds.ContainsOneOf(update.IdentityKind) { + continue + } + + if matchesIdentity(existing.Properties, update.Node.Properties, update.IdentityProperties) { + // Update existing node: merge kinds and properties + for _, kind := range update.Node.Kinds { + if !existing.Kinds.ContainsOneOf(kind) { + 
existing.Kinds = append(existing.Kinds, kind) + } + } + if update.Node.Properties != nil { + for key, val := range update.Node.Properties.Map { + existing.Properties.Set(key, val) + } + } + return nil + } + } + + // No match — create new node (inline to avoid double-lock) + id := b.db.newID() + update.Node.ID = id + b.db.nodes[id] = update.Node + return nil +} + +func (b *batch) CreateRelationship(relationship *graph.Relationship) error { + id := b.db.newID() + relationship.ID = id + + b.db.mu.Lock() + defer b.db.mu.Unlock() + + b.db.edges[id] = relationship + b.db.outEdges[relationship.StartID] = append(b.db.outEdges[relationship.StartID], id) + b.db.inEdges[relationship.EndID] = append(b.db.inEdges[relationship.EndID], id) + return nil +} + +func (b *batch) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { + b.db.mu.Lock() + defer b.db.mu.Unlock() + + if _, ok := b.db.nodes[startNodeID]; !ok { + return fmt.Errorf("start node %d not found", startNodeID) + } + if _, ok := b.db.nodes[endNodeID]; !ok { + return fmt.Errorf("end node %d not found", endNodeID) + } + + id := b.db.newID() + rel := &graph.Relationship{ + ID: id, + StartID: startNodeID, + EndID: endNodeID, + Kind: kind, + Properties: properties, + } + + b.db.edges[id] = rel + b.db.outEdges[startNodeID] = append(b.db.outEdges[startNodeID], id) + b.db.inEdges[endNodeID] = append(b.db.inEdges[endNodeID], id) + return nil +} + +func (b *batch) DeleteRelationship(id graph.ID) error { + b.db.mu.Lock() + defer b.db.mu.Unlock() + + rel, ok := b.db.edges[id] + if !ok { + return fmt.Errorf("relationship %d not found", id) + } + + delete(b.db.edges, id) + + // Clean up adjacency indexes + b.db.outEdges[rel.StartID] = removeID(b.db.outEdges[rel.StartID], id) + b.db.inEdges[rel.EndID] = removeID(b.db.inEdges[rel.EndID], id) + return nil +} + +func (b *batch) UpdateRelationshipBy(update graph.RelationshipUpdate) error { + b.db.mu.Lock() + defer 
b.db.mu.Unlock() + + rel := update.Relationship + + // Try to find an existing relationship that matches + for _, existing := range b.db.edges { + if existing.Kind != rel.Kind { + continue + } + if existing.StartID != rel.StartID || existing.EndID != rel.EndID { + continue + } + if matchesIdentity(existing.Properties, rel.Properties, update.IdentityProperties) { + // Update existing relationship properties + if rel.Properties != nil { + for key, val := range rel.Properties.Map { + existing.Properties.Set(key, val) + } + } + return nil + } + } + + // No match — create new relationship (inline to avoid double-lock) + id := b.db.newID() + rel.ID = id + b.db.edges[id] = rel + b.db.outEdges[rel.StartID] = append(b.db.outEdges[rel.StartID], id) + b.db.inEdges[rel.EndID] = append(b.db.inEdges[rel.EndID], id) + return nil +} + +func (b *batch) Commit() error { + return nil +} + +func matchesIdentity(existing, candidate *graph.Properties, identityKeys []string) bool { + if len(identityKeys) == 0 { + return false + } + for _, key := range identityKeys { + existingVal := existing.Get(key).Any() + candidateVal := candidate.Get(key).Any() + if existingVal == nil || candidateVal == nil { + return false + } + if fmt.Sprint(existingVal) != fmt.Sprint(candidateVal) { + return false + } + } + return true +} + +func removeID(ids []graph.ID, target graph.ID) []graph.ID { + for i, id := range ids { + if id == target { + return append(ids[:i], ids[i+1:]...) + } + } + return ids +} diff --git a/drivers/sonic/cypher.go b/drivers/sonic/cypher.go new file mode 100644 index 0000000..e9528d1 --- /dev/null +++ b/drivers/sonic/cypher.go @@ -0,0 +1,152 @@ +package sonic + +import ( + "fmt" + + cypher "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +// executeCypher takes a parsed Cypher RegularQuery and executes it against the in-memory database. 
+// It uses a walker-based executor that handles arbitrary Cypher clause structures. +func (db *Database) executeCypher(query *cypher.RegularQuery, parameters map[string]any) (*sonicResult, error) { + if query.SingleQuery == nil { + return emptyResult(), nil + } + + ex := newExecutor(db, parameters) + + if err := walk.Cypher(query, ex); err != nil { + return nil, err + } + + if ex.result == nil { + return emptyResult(), nil + } + + return ex.result, nil +} + +// resolveParameters walks the AST and replaces Parameter nodes with their values. +func resolveParameters(where *cypher.Where, parameters map[string]any) { + if parameters == nil || where == nil { + return + } + for i, expr := range where.Expressions { + where.Expressions[i] = resolveParamsInExpr(expr, parameters) + } +} + +func resolveParamsInExpr(expr cypher.Expression, params map[string]any) cypher.Expression { + switch e := expr.(type) { + case *cypher.Comparison: + for _, partial := range e.Partials { + if param, ok := partial.Right.(*cypher.Parameter); ok { + if val, exists := params[param.Symbol]; exists { + param.Value = val + } + } + } + case *cypher.Conjunction: + for i, sub := range e.Expressions { + e.Expressions[i] = resolveParamsInExpr(sub, params) + } + case *cypher.Disjunction: + for i, sub := range e.Expressions { + e.Expressions[i] = resolveParamsInExpr(sub, params) + } + case *cypher.Negation: + e.Expression = resolveParamsInExpr(e.Expression, params) + case *cypher.Parenthetical: + e.Expression = resolveParamsInExpr(e.Expression, params) + case *cypher.FunctionInvocation: + for i, arg := range e.Arguments { + e.Arguments[i] = resolveParamsInExpr(arg, params) + } + case *cypher.Parameter: + if val, exists := params[e.Symbol]; exists { + e.Value = val + } + } + return expr +} + +// --- sonicResult implements graph.Result --- + +type sonicResult struct { + rows [][]any + keys []string + cursor int + current []any + err error +} + +func emptyResult() *sonicResult { + return &sonicResult{} +} + 
+func (r *sonicResult) Next() bool { + if r.cursor >= len(r.rows) { + return false + } + r.current = r.rows[r.cursor] + r.cursor++ + return true +} + +func (r *sonicResult) Keys() []string { + return r.keys +} + +func (r *sonicResult) Values() []any { + return r.current +} + +func (r *sonicResult) Mapper() graph.ValueMapper { + return graph.NewValueMapper(sonicMapValue) +} + +func (r *sonicResult) Scan(targets ...any) error { + if r.current == nil { + return fmt.Errorf("sonic: no current row") + } + + mapper := r.Mapper() + for i, target := range targets { + if i >= len(r.current) { + break + } + mapper.Map(r.current[i], target) + } + return nil +} + +func (r *sonicResult) Error() error { + return r.err +} + +func (r *sonicResult) Close() { +} + +// sonicMapValue maps in-memory graph types directly to targets. +func sonicMapValue(value, target any) bool { + switch t := target.(type) { + case *graph.Node: + if n, ok := value.(*graph.Node); ok { + *t = *n + return true + } + case *graph.Relationship: + if r, ok := value.(*graph.Relationship); ok { + *t = *r + return true + } + case *graph.Path: + if p, ok := value.(*graph.Path); ok { + *t = *p + return true + } + } + return false +} diff --git a/drivers/sonic/eval.go b/drivers/sonic/eval.go new file mode 100644 index 0000000..4a857bb --- /dev/null +++ b/drivers/sonic/eval.go @@ -0,0 +1,629 @@ +package sonic + +import ( + "fmt" + "strings" + + cypher "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" +) + +// evalNodeCriteria evaluates a filter criteria against a node. +func evalNodeCriteria(db *Database, n *graph.Node, criteria graph.Criteria) bool { + return evalCriteria(db, n.ID, n.Kinds, n.Properties, 0, 0, nil, criteria) +} + +// evalRelCriteria evaluates a filter criteria against a relationship. 
+func evalRelCriteria(db *Database, r *graph.Relationship, criteria graph.Criteria) bool { + startNode := db.nodes[r.StartID] + endNode := db.nodes[r.EndID] + return evalCriteria(db, r.ID, nil, r.Properties, r.StartID, r.EndID, &evalContext{ + rel: r, + startNode: startNode, + endNode: endNode, + }, criteria) +} + +type evalContext struct { + rel *graph.Relationship + startNode *graph.Node + endNode *graph.Node +} + +func evalCriteria(db *Database, id graph.ID, kinds graph.Kinds, props *graph.Properties, startID, endID graph.ID, relCtx *evalContext, criteria graph.Criteria) bool { + switch c := criteria.(type) { + case *cypher.Conjunction: + for _, expr := range c.Expressions { + if !evalCriteria(db, id, kinds, props, startID, endID, relCtx, expr) { + return false + } + } + return true + + case *cypher.Disjunction: + for _, expr := range c.Expressions { + if evalCriteria(db, id, kinds, props, startID, endID, relCtx, expr) { + return true + } + } + return false + + case *cypher.Negation: + return !evalCriteria(db, id, kinds, props, startID, endID, relCtx, c.Expression) + + case *cypher.Parenthetical: + return evalCriteria(db, id, kinds, props, startID, endID, relCtx, c.Expression) + + case *cypher.Comparison: + return evalComparison(db, id, kinds, props, startID, endID, relCtx, c) + + case *cypher.KindMatcher: + return evalKindMatcher(id, kinds, relCtx, c) + + default: + return true + } +} + +func evalComparison(db *Database, id graph.ID, kinds graph.Kinds, props *graph.Properties, startID, endID graph.ID, relCtx *evalContext, cmp *cypher.Comparison) bool { + left := resolveValue(db, id, kinds, props, startID, endID, relCtx, cmp.Left) + + if cmp.Partials == nil || len(cmp.Partials) == 0 { + return false + } + + partial := cmp.Partials[0] + right := resolveValue(db, id, kinds, props, startID, endID, relCtx, partial.Right) + + switch partial.Operator { + case cypher.OperatorEquals: + return compareEquals(left, right) + case cypher.OperatorNotEquals: + return 
!compareEquals(left, right) + case cypher.OperatorIn: + return compareIn(left, right) + case cypher.OperatorContains: + return compareContains(left, right) + case cypher.OperatorStartsWith: + return compareStartsWith(left, right) + case cypher.OperatorEndsWith: + return compareEndsWith(left, right) + case cypher.OperatorGreaterThan: + return compareOrdered(left, right) > 0 + case cypher.OperatorGreaterThanOrEqualTo: + return compareOrdered(left, right) >= 0 + case cypher.OperatorLessThan: + return compareOrdered(left, right) < 0 + case cypher.OperatorLessThanOrEqualTo: + return compareOrdered(left, right) <= 0 + case cypher.OperatorIs: + return left == nil + case cypher.OperatorIsNot: + return left != nil + default: + return false + } +} + +func evalKindMatcher(id graph.ID, kinds graph.Kinds, relCtx *evalContext, km *cypher.KindMatcher) bool { + // Determine which kinds to match against based on the reference variable + var targetKinds graph.Kinds + + if v, ok := km.Reference.(*cypher.Variable); ok { + switch v.Symbol { + case query.EdgeStartSymbol: + if relCtx != nil && relCtx.startNode != nil { + targetKinds = relCtx.startNode.Kinds + } + case query.EdgeEndSymbol: + if relCtx != nil && relCtx.endNode != nil { + targetKinds = relCtx.endNode.Kinds + } + default: + targetKinds = kinds + } + } else { + targetKinds = kinds + } + + // For relationship kind matching + if relCtx != nil && relCtx.rel != nil { + if v, ok := km.Reference.(*cypher.Variable); ok && v.Symbol == query.EdgeSymbol { + for _, k := range km.Kinds { + if relCtx.rel.Kind == k { + return true + } + } + return false + } + } + + // Check if the target has all the requested kinds + for _, k := range km.Kinds { + found := false + for _, tk := range targetKinds { + if tk == k { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +func resolveValue(db *Database, id graph.ID, kinds graph.Kinds, props *graph.Properties, startID, endID graph.ID, relCtx *evalContext, expr any) 
any { + switch e := expr.(type) { + case *cypher.Variable: + return nil + + case *cypher.FunctionInvocation: + if e.Name == "id" && len(e.Arguments) > 0 { + if v, ok := e.Arguments[0].(*cypher.Variable); ok { + switch v.Symbol { + case query.NodeSymbol: + return id + case query.EdgeSymbol: + return id + case query.EdgeStartSymbol: + return startID + case query.EdgeEndSymbol: + return endID + } + } + } + if e.Name == "toLower" && len(e.Arguments) > 0 { + inner := resolveValue(db, id, kinds, props, startID, endID, relCtx, e.Arguments[0]) + if s, ok := inner.(string); ok { + return strings.ToLower(s) + } + } + return nil + + case *cypher.PropertyLookup: + return resolveProperty(db, id, props, startID, endID, relCtx, e) + + case *cypher.Parameter: + return e.Value + + case *cypher.Literal: + if e.Null { + return nil + } + return e.Value + + default: + return expr + } +} + +func resolveProperty(db *Database, id graph.ID, props *graph.Properties, startID, endID graph.ID, relCtx *evalContext, lookup *cypher.PropertyLookup) any { + var targetProps *graph.Properties + + if v, ok := lookup.Atom.(*cypher.Variable); ok { + switch v.Symbol { + case query.NodeSymbol: + targetProps = props + case query.EdgeSymbol: + targetProps = props + case query.EdgeStartSymbol: + if relCtx != nil && relCtx.startNode != nil { + targetProps = relCtx.startNode.Properties + } + case query.EdgeEndSymbol: + if relCtx != nil && relCtx.endNode != nil { + targetProps = relCtx.endNode.Properties + } + } + } + + if targetProps == nil { + return nil + } + + return targetProps.Get(lookup.Symbol).Any() +} + +// --- Comparison helpers --- + +func compareEquals(left, right any) bool { + if left == nil && right == nil { + return true + } + if left == nil || right == nil { + return false + } + + // Handle graph.ID comparisons + leftID, leftIsID := toID(left) + rightID, rightIsID := toID(right) + if leftIsID && rightIsID { + return leftID == rightID + } + + return fmt.Sprint(left) == fmt.Sprint(right) +} + 
+func compareIn(left, right any) bool { + switch r := right.(type) { + case []graph.ID: + if lid, ok := toID(left); ok { + for _, id := range r { + if lid == id { + return true + } + } + } + return false + case []int64: + if lid, ok := toInt64(left); ok { + for _, v := range r { + if lid == v { + return true + } + } + } + return false + case []string: + ls := fmt.Sprint(left) + for _, v := range r { + if ls == v { + return true + } + } + return false + case []any: + for _, v := range r { + if compareEquals(left, v) { + return true + } + } + return false + default: + return false + } +} + +func compareContains(left, right any) bool { + ls, lok := left.(string) + rs, rok := right.(string) + if lok && rok { + return strings.Contains(ls, rs) + } + return false +} + +func compareStartsWith(left, right any) bool { + ls, lok := left.(string) + rs, rok := right.(string) + if lok && rok { + return strings.HasPrefix(ls, rs) + } + return false +} + +func compareEndsWith(left, right any) bool { + ls, lok := left.(string) + rs, rok := right.(string) + if lok && rok { + return strings.HasSuffix(ls, rs) + } + return false +} + +func compareOrdered(left, right any) int { + lf, lok := toFloat64(left) + rf, rok := toFloat64(right) + if lok && rok { + if lf < rf { + return -1 + } + if lf > rf { + return 1 + } + return 0 + } + return 0 +} + +// --- Binding-aware expression evaluation --- + +// evalBindingExpr evaluates a Cypher expression against a binding row. +// Returns true/false for boolean expressions (WHERE clauses). 
+func evalBindingExpr(db *Database, b binding, expr cypher.Expression) bool { + switch e := expr.(type) { + case *cypher.Conjunction: + for _, sub := range e.Expressions { + if !evalBindingExpr(db, b, sub) { + return false + } + } + return true + + case *cypher.Disjunction: + for _, sub := range e.Expressions { + if evalBindingExpr(db, b, sub) { + return true + } + } + return false + + case *cypher.Negation: + return !evalBindingExpr(db, b, e.Expression) + + case *cypher.Parenthetical: + return evalBindingExpr(db, b, e.Expression) + + case *cypher.Comparison: + return evalBindingComparison(db, b, e) + + case *cypher.KindMatcher: + return evalBindingKindMatcher(b, e) + + default: + return true + } +} + +// evalBindingComparison evaluates a comparison expression against a binding row. +func evalBindingComparison(db *Database, b binding, cmp *cypher.Comparison) bool { + left := resolveBindingValue(db, b, cmp.Left) + + if len(cmp.Partials) == 0 { + return false + } + + partial := cmp.Partials[0] + right := resolveBindingValue(db, b, partial.Right) + + switch partial.Operator { + case cypher.OperatorEquals: + return compareEquals(left, right) + case cypher.OperatorNotEquals: + return !compareEquals(left, right) + case cypher.OperatorIn: + return compareIn(left, right) + case cypher.OperatorContains: + return compareContains(left, right) + case cypher.OperatorStartsWith: + return compareStartsWith(left, right) + case cypher.OperatorEndsWith: + return compareEndsWith(left, right) + case cypher.OperatorGreaterThan: + return compareOrdered(left, right) > 0 + case cypher.OperatorGreaterThanOrEqualTo: + return compareOrdered(left, right) >= 0 + case cypher.OperatorLessThan: + return compareOrdered(left, right) < 0 + case cypher.OperatorLessThanOrEqualTo: + return compareOrdered(left, right) <= 0 + case cypher.OperatorIs: + return left == nil + case cypher.OperatorIsNot: + return left != nil + default: + return false + } +} + +// evalBindingKindMatcher checks if a bound entity 
matches the specified kinds. +func evalBindingKindMatcher(b binding, km *cypher.KindMatcher) bool { + v, ok := km.Reference.(*cypher.Variable) + if !ok { + return false + } + + entity, exists := b[v.Symbol] + if !exists || entity == nil { + return false + } + + var targetKinds graph.Kinds + switch e := entity.(type) { + case *graph.Node: + targetKinds = e.Kinds + case *graph.Relationship: + // For relationships, check if any matcher kind matches the rel kind + for _, k := range km.Kinds { + if e.Kind == k { + return true + } + } + return false + default: + return false + } + + // For nodes: all specified kinds must be present (AND semantics) + for _, k := range km.Kinds { + if !targetKinds.ContainsOneOf(k) { + return false + } + } + return true +} + +// resolveBindingValue resolves an expression to a concrete value using a binding row. +func resolveBindingValue(db *Database, b binding, expr any) any { + switch e := expr.(type) { + case *cypher.Variable: + return b[e.Symbol] + + case *cypher.PropertyLookup: + return resolveBindingProperty(db, b, e) + + case *cypher.FunctionInvocation: + return resolveBindingFunction(db, b, e) + + case *cypher.Parameter: + return e.Value + + case *cypher.Literal: + if e.Null { + return nil + } + return stripStringQuotes(e.Value) + + case *cypher.Parenthetical: + return resolveBindingValue(db, b, e.Expression) + + default: + return expr + } +} + +// resolveBindingProperty resolves a property lookup against a binding row. 
+func resolveBindingProperty(db *Database, b binding, lookup *cypher.PropertyLookup) any { + atom := resolveBindingValue(db, b, lookup.Atom) + if atom == nil { + return nil + } + + switch e := atom.(type) { + case *graph.Node: + if e.Properties == nil { + return nil + } + return e.Properties.Get(lookup.Symbol).Any() + case *graph.Relationship: + if e.Properties == nil { + return nil + } + return e.Properties.Get(lookup.Symbol).Any() + default: + return nil + } +} + +// resolveBindingFunction evaluates a function call against a binding row. +func resolveBindingFunction(db *Database, b binding, fn *cypher.FunctionInvocation) any { + switch fn.Name { + case "id": + if len(fn.Arguments) == 0 { + return nil + } + arg := resolveBindingValue(db, b, fn.Arguments[0]) + switch e := arg.(type) { + case *graph.Node: + return e.ID + case *graph.Relationship: + return e.ID + } + return nil + + case "type": + if len(fn.Arguments) == 0 { + return nil + } + arg := resolveBindingValue(db, b, fn.Arguments[0]) + if rel, ok := arg.(*graph.Relationship); ok { + return rel.Kind.String() + } + return nil + + case "toLower": + if len(fn.Arguments) == 0 { + return nil + } + arg := resolveBindingValue(db, b, fn.Arguments[0]) + if s, ok := arg.(string); ok { + return strings.ToLower(s) + } + return nil + + case "toUpper": + if len(fn.Arguments) == 0 { + return nil + } + arg := resolveBindingValue(db, b, fn.Arguments[0]) + if s, ok := arg.(string); ok { + return strings.ToUpper(s) + } + return nil + + case "count", "collect", "sum", "avg", "min", "max": + // Aggregation functions are not yet supported in this evaluator + return nil + + case "labels", "keys": + if len(fn.Arguments) == 0 { + return nil + } + arg := resolveBindingValue(db, b, fn.Arguments[0]) + if fn.Name == "labels" { + if node, ok := arg.(*graph.Node); ok { + return node.Kinds.Strings() + } + } + return nil + + default: + return nil + } +} + +// stripStringQuotes removes surrounding single or double quotes from a string value. 
+func stripStringQuotes(v any) any { + s, ok := v.(string) + if !ok { + return v + } + if len(s) >= 2 { + if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"') { + return s[1 : len(s)-1] + } + } + return v +} + +// --- Type coercion helpers --- + +func toID(v any) (graph.ID, bool) { + switch tv := v.(type) { + case graph.ID: + return tv, true + case int64: + return graph.ID(tv), true + case uint64: + return graph.ID(tv), true + case int: + return graph.ID(tv), true + default: + return 0, false + } +} + +func toInt64(v any) (int64, bool) { + switch tv := v.(type) { + case int64: + return tv, true + case int: + return int64(tv), true + case graph.ID: + return int64(tv), true + case float64: + return int64(tv), true + default: + return 0, false + } +} + +func toFloat64(v any) (float64, bool) { + switch tv := v.(type) { + case float64: + return tv, true + case float32: + return float64(tv), true + case int: + return float64(tv), true + case int64: + return float64(tv), true + case graph.ID: + return float64(tv), true + default: + return 0, false + } +} diff --git a/drivers/sonic/execute.go b/drivers/sonic/execute.go new file mode 100644 index 0000000..ceba8c4 --- /dev/null +++ b/drivers/sonic/execute.go @@ -0,0 +1,1007 @@ +package sonic + +import ( + "fmt" + "sort" + "strings" + + cypher "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +const maxBindings = 100000 + +// binding represents a single row in the binding table: variable name → value. +type binding map[string]any + +// copyBinding creates a shallow copy of a binding row. +func copyBinding(b binding) binding { + c := make(binding, len(b)) + for k, v := range b { + c[k] = v + } + return c +} + +// matchState tracks state for a single MATCH clause. 
type matchState struct {
	snapshot []binding // pre-MATCH bindings for OPTIONAL MATCH fallback
	optional bool      // copied from Match.Optional when the clause is entered
	newVars  []string  // variables introduced by this MATCH
}

// projItem holds a parsed projection item.
type projItem struct {
	expr  cypher.Expression // expression to project
	alias string            // NOTE(review): presumably the AS name; set outside this excerpt
}

// returnContext holds RETURN/WITH clause modifiers.
type returnContext struct {
	limit    int
	skip     int
	distinct bool
	orderBy  []*cypher.SortItem
}

// executor walks a Cypher AST and evaluates it against the in-memory database.
// It embeds walk.Visitor and is driven by walk.Cypher through Enter, Visit
// and Exit.
type executor struct {
	walk.Visitor[cypher.SyntaxNode]

	db       *Database
	bindings []binding      // current binding table, one map per row
	params   map[string]any // query parameters, keyed by symbol
	result   *sonicResult   // final result; nil until one has been produced

	// Clause-level state
	matchStack []matchState   // one entry pushed per MATCH entered
	projItems  []projItem     // reset on entering RETURN/WITH
	returnCtx  *returnContext // reset on entering RETURN/WITH
	inReturn   bool           // true between Enter and Exit of RETURN
	inWith     bool           // true between Enter and Exit of WITH

	// Track the current PatternPart for allShortestPaths detection
	currentPatternPart *cypher.PatternPart

	// Track current match for resolveParameters
	currentMatch *cypher.Match

	// Set when allShortestPaths has consumed the WHERE for path constraints
	whereConsumedByPath bool
}

// newExecutor returns an executor for db with the given query parameters.
// The binding table starts with a single empty row.
func newExecutor(db *Database, params map[string]any) *executor {
	return &executor{
		Visitor:  walk.NewVisitor[cypher.SyntaxNode](),
		db:       db,
		bindings: []binding{{}}, // seed with one empty binding row
		params:   params,
	}
}

// Enter is called on first visit to an AST node.
+func (ex *executor) Enter(node cypher.SyntaxNode) { + switch n := node.(type) { + case *cypher.RegularQuery: + // top-level — no action + + case *cypher.SingleQuery: + // no action + + case *cypher.SinglePartQuery: + // no action + + case *cypher.MultiPartQuery: + // no action + + case *cypher.MultiPartQueryPart: + // no action + + case *cypher.ReadingClause: + // no action + + case *cypher.Match: + ex.currentMatch = n + // Resolve parameters in WHERE clause + if n.Where != nil { + resolveParameters(n.Where, ex.params) + } + ms := matchState{ + optional: n.Optional, + } + if n.Optional { + ms.snapshot = make([]binding, len(ex.bindings)) + for i, b := range ex.bindings { + ms.snapshot[i] = copyBinding(b) + } + } + ex.matchStack = append(ex.matchStack, ms) + + case *cypher.PatternPart: + ex.currentPatternPart = n + // If allShortestPaths, consume — we'll handle it entirely on Exit + if n.AllShortestPathsPattern { + ex.Consume() + } + + case *cypher.PatternElement: + // no action — walker descends to NodePattern/RelationshipPattern + + case *cypher.NodePattern: + // no action — expansion happens in Exit + + case *cypher.RelationshipPattern: + // no action — expansion happens in Exit + + case *cypher.Where: + // Consume — we don't want the walker to descend into expression children. + // We evaluate expression trees ourselves in Exit(Where). 
+ ex.Consume() + + case *cypher.Return: + ex.inReturn = true + ex.projItems = nil + ex.returnCtx = &returnContext{} + + case *cypher.With: + ex.inWith = true + ex.projItems = nil + ex.returnCtx = &returnContext{} + + case *cypher.Projection: + // Consume — we handle projection items ourselves + ex.Consume() + ex.parseProjection(n) + + // --- Leaf / expression nodes: no-op in walker --- + case *cypher.Variable, *cypher.Literal, *cypher.Parameter, + *cypher.Comparison, *cypher.Conjunction, *cypher.Disjunction, + *cypher.Negation, *cypher.Parenthetical, *cypher.KindMatcher, + *cypher.FunctionInvocation, *cypher.PropertyLookup, + *cypher.PartialComparison, *cypher.ArithmeticExpression, + *cypher.PartialArithmeticExpression, *cypher.UnaryAddOrSubtractExpression, + *cypher.ProjectionItem, *cypher.Order, *cypher.SortItem, + *cypher.Skip, *cypher.Limit, *cypher.RangeQuantifier, + cypher.Operator, graph.Kinds, cypher.MapLiteral, *cypher.ListLiteral, + *cypher.ExclusiveDisjunction, *cypher.PatternPredicate: + // no-op + + // --- Mutation nodes: error --- + case *cypher.UpdatingClause: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.Create: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.Delete: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.Set: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.SetItem: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.Remove: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.RemoveItem: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.Merge: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + case *cypher.MergeAction: + ex.SetErrorf("sonic: write operations in Cypher not yet supported") + + // --- Unsupported constructs --- + case *cypher.Unwind: + 
ex.SetErrorf("sonic: UNWIND not yet supported") + + case *cypher.Quantifier: + ex.SetErrorf("sonic: quantifier expressions not yet supported") + + case *cypher.FilterExpression: + ex.SetErrorf("sonic: filter expressions not yet supported") + + case *cypher.IDInCollection: + ex.SetErrorf("sonic: ID IN collection not yet supported") + + case *cypher.MapItem: + // no-op (used inside Properties) + + case *cypher.Properties: + // no-op + + default: + ex.SetErrorf("sonic: unsupported cypher construct: %T", node) + } +} + +// Visit is called when returning to a node after processing a child. +func (ex *executor) Visit(node cypher.SyntaxNode) { + // no-op for all nodes +} + +// Exit is called after all children of a node have been processed. +func (ex *executor) Exit(node cypher.SyntaxNode) { + switch n := node.(type) { + case *cypher.NodePattern: + ex.expandNodePattern(n) + + case *cypher.RelationshipPattern: + ex.expandRelationshipPattern(n) + + case *cypher.PatternPart: + if n.AllShortestPathsPattern { + ex.handleAllShortestPaths(n) + } + ex.currentPatternPart = nil + + case *cypher.Where: + ex.filterBindings(n) + + case *cypher.Match: + ex.finalizeMatch() + ex.currentMatch = nil + + case *cypher.With: + ex.finalizeWith() + ex.inWith = false + + case *cypher.Return: + ex.finalizeReturn() + ex.inReturn = false + + case *cypher.SinglePartQuery: + ex.finalizeResult() + + case *cypher.MultiPartQueryPart: + // bindings carry forward — no special action needed + } +} + +// parseProjection extracts projection items and modifiers from a Projection node. 
+func (ex *executor) parseProjection(proj *cypher.Projection) { + if proj == nil { + return + } + + if proj.Distinct { + ex.returnCtx.distinct = true + } + + // Parse items + if proj.All { + // RETURN * — project all bound variables + // We'll handle this in finalizeReturn/finalizeWith + } else { + for _, item := range proj.Items { + if pi, ok := item.(*cypher.ProjectionItem); ok { + alias := "" + if pi.Alias != nil { + alias = pi.Alias.Symbol + } else if v, ok := pi.Expression.(*cypher.Variable); ok { + alias = v.Symbol + } else if pl, ok := pi.Expression.(*cypher.PropertyLookup); ok { + alias = pl.Symbol + } + ex.projItems = append(ex.projItems, projItem{ + expr: pi.Expression, + alias: alias, + }) + } + } + } + + // Parse ORDER BY + if proj.Order != nil { + ex.returnCtx.orderBy = proj.Order.Items + } + + // Parse SKIP + if proj.Skip != nil { + if lit, ok := proj.Skip.Value.(*cypher.Literal); ok { + ex.returnCtx.skip = toInt(lit.Value) + } + } + + // Parse LIMIT + if proj.Limit != nil { + if lit, ok := proj.Limit.Value.(*cypher.Literal); ok { + ex.returnCtx.limit = toInt(lit.Value) + } + } +} + +func toInt(v any) int { + switch tv := v.(type) { + case int: + return tv + case int64: + return int(tv) + case float64: + return int(tv) + default: + return 0 + } +} + +// anonNodeCounter is used to generate unique keys for anonymous node patterns. +var anonNodeCounter int + +// expandNodePattern expands bindings against matching nodes. 
+func (ex *executor) expandNodePattern(np *cypher.NodePattern) { + varName := "" + if np.Variable != nil { + varName = np.Variable.Symbol + } else { + // Anonymous node — use a synthetic binding key so findAnchorNode can find it + anonNodeCounter++ + varName = fmt.Sprintf("__anon_node_%d__", anonNodeCounter) + } + + ex.db.mu.RLock() + defer ex.db.mu.RUnlock() + + var expanded []binding + + for _, row := range ex.bindings { + // If the variable is already bound, just filter — don't expand + if varName != "" { + if existing, ok := row[varName]; ok { + if node, ok := existing.(*graph.Node); ok { + if ex.nodeMatchesPattern(node, np) { + expanded = append(expanded, row) + } + continue + } + } + } + + // If __next_node__ is set (from a preceding relationship expansion), + // bind from it instead of scanning all nodes. + if nextNode, ok := row["__next_node__"]; ok { + if node, ok := nextNode.(*graph.Node); ok { + if ex.nodeMatchesPattern(node, np) { + newRow := copyBinding(row) + delete(newRow, "__next_node__") + if varName != "" { + newRow[varName] = node + } + expanded = append(expanded, newRow) + } + continue + } + } + + // Expand against all nodes + for _, node := range ex.db.nodes { + if !ex.nodeMatchesPattern(node, np) { + continue + } + + newRow := copyBinding(row) + if varName != "" { + newRow[varName] = node + } + expanded = append(expanded, newRow) + + if len(expanded) > maxBindings { + ex.SetErrorf("sonic: binding count exceeded %d — query too broad", maxBindings) + return + } + } + } + + ex.bindings = expanded + + // Track new variable in match state + if varName != "" && len(ex.matchStack) > 0 { + ms := &ex.matchStack[len(ex.matchStack)-1] + ms.newVars = append(ms.newVars, varName) + } +} + +// nodeMatchesPattern checks if a node satisfies a NodePattern's constraints. 
+func (ex *executor) nodeMatchesPattern(node *graph.Node, np *cypher.NodePattern) bool { + if len(np.Kinds) > 0 { + // Node patterns use AND semantics — node must have ALL specified kinds + for _, k := range np.Kinds { + if !node.Kinds.ContainsOneOf(k) { + return false + } + } + } + return true +} + +// expandRelationshipPattern expands bindings by following edges. +func (ex *executor) expandRelationshipPattern(rp *cypher.RelationshipPattern) { + // Variable-length paths: (a)-[r*1..3]->(b) + if rp.Range != nil { + ex.expandVariableLengthPattern(rp) + return + } + + relVar := "" + if rp.Variable != nil { + relVar = rp.Variable.Symbol + } + + ex.db.mu.RLock() + defer ex.db.mu.RUnlock() + + var expanded []binding + + for _, row := range ex.bindings { + // Find the most recently bound node in this row to use as the anchor + anchorNode := ex.findAnchorNode(row) + if anchorNode == nil { + // No anchor — expand from all edges (first element in pattern was a rel?) + continue + } + + // Get candidate edges based on direction + var edgeIDs []graph.ID + switch rp.Direction { + case graph.DirectionOutbound: + edgeIDs = ex.db.outEdges[anchorNode.ID] + case graph.DirectionInbound: + edgeIDs = ex.db.inEdges[anchorNode.ID] + default: // DirectionBoth + edgeIDs = append(edgeIDs, ex.db.outEdges[anchorNode.ID]...) + edgeIDs = append(edgeIDs, ex.db.inEdges[anchorNode.ID]...) 
+ } + + for _, edgeID := range edgeIDs { + edge := ex.db.edges[edgeID] + if edge == nil { + continue + } + + // Check kind constraints (disjunction for relationships) + if len(rp.Kinds) > 0 { + if !rp.Kinds.ContainsOneOf(edge.Kind) { + continue + } + } + + // Determine the other end node + var otherNode *graph.Node + switch rp.Direction { + case graph.DirectionOutbound: + otherNode = ex.db.nodes[edge.EndID] + case graph.DirectionInbound: + otherNode = ex.db.nodes[edge.StartID] + default: // DirectionBoth + if edge.StartID == anchorNode.ID { + otherNode = ex.db.nodes[edge.EndID] + } else { + otherNode = ex.db.nodes[edge.StartID] + } + } + + if otherNode == nil { + continue + } + + newRow := copyBinding(row) + if relVar != "" { + newRow[relVar] = edge + } + // The other end node will be bound by the next NodePattern + // Store it temporarily so the next NodePattern can pick it up + newRow["__next_node__"] = otherNode + expanded = append(expanded, newRow) + + if len(expanded) > maxBindings { + ex.SetErrorf("sonic: binding count exceeded %d — query too broad", maxBindings) + return + } + } + } + + ex.bindings = expanded + + if relVar != "" && len(ex.matchStack) > 0 { + ms := &ex.matchStack[len(ex.matchStack)-1] + ms.newVars = append(ms.newVars, relVar) + } +} + +const maxVarLengthDepth = 50 + +// expandVariableLengthPattern handles relationship patterns with a range like [*1..3]. +// It performs BFS from the anchor node, collecting all paths whose length falls +// within [minHops, maxHops]. 
Each valid path produces a new binding row with: +// - relVar → []*graph.Relationship (the edges traversed) +// - __next_node__ → the terminal node +func (ex *executor) expandVariableLengthPattern(rp *cypher.RelationshipPattern) { + minHops := int64(1) + maxHops := int64(maxVarLengthDepth) + + if rp.Range.StartIndex != nil { + minHops = *rp.Range.StartIndex + } + if rp.Range.EndIndex != nil { + maxHops = *rp.Range.EndIndex + } + if maxHops > maxVarLengthDepth { + maxHops = maxVarLengthDepth + } + if minHops < 0 { + minHops = 0 + } + + relVar := "" + if rp.Variable != nil { + relVar = rp.Variable.Symbol + } + + ex.db.mu.RLock() + defer ex.db.mu.RUnlock() + + var expanded []binding + + for _, row := range ex.bindings { + anchorNode := ex.findAnchorNode(row) + if anchorNode == nil { + continue + } + + // BFS with path tracking. Each entry is a (nodeID, path-of-edges) pair. + type bfsState struct { + nodeID graph.ID + edges []*graph.Relationship + } + + queue := []bfsState{{nodeID: anchorNode.ID}} + // visited tracks the set of nodes in the *current* path to prevent cycles. + // We rebuild this per-path, so we use the queue entries themselves. 
+ + for len(queue) > 0 { + cur := queue[0] + queue = queue[1:] + + depth := int64(len(cur.edges)) + + // If within valid range, emit a binding + if depth >= minHops { + terminalNode := ex.db.nodes[cur.nodeID] + if terminalNode != nil { + newRow := copyBinding(row) + if relVar != "" { + // Bind as slice of relationships + edgeCopy := make([]*graph.Relationship, len(cur.edges)) + copy(edgeCopy, cur.edges) + newRow[relVar] = edgeCopy + } + newRow["__next_node__"] = terminalNode + expanded = append(expanded, newRow) + + if len(expanded) > maxBindings { + ex.SetErrorf("sonic: binding count exceeded %d — query too broad", maxBindings) + return + } + } + } + + // If at max depth, don't expand further + if depth >= maxHops { + continue + } + + // Collect nodes already in this path to prevent cycles + visited := make(map[graph.ID]struct{}, len(cur.edges)+1) + visited[anchorNode.ID] = struct{}{} + for _, e := range cur.edges { + visited[e.StartID] = struct{}{} + visited[e.EndID] = struct{}{} + } + + // Expand neighbors + var edgeIDs []graph.ID + switch rp.Direction { + case graph.DirectionOutbound: + edgeIDs = ex.db.outEdges[cur.nodeID] + case graph.DirectionInbound: + edgeIDs = ex.db.inEdges[cur.nodeID] + default: + edgeIDs = append(edgeIDs, ex.db.outEdges[cur.nodeID]...) + edgeIDs = append(edgeIDs, ex.db.inEdges[cur.nodeID]...) 
+ } + + for _, edgeID := range edgeIDs { + edge := ex.db.edges[edgeID] + if edge == nil { + continue + } + + // Kind filter (disjunction) + if len(rp.Kinds) > 0 && !rp.Kinds.ContainsOneOf(edge.Kind) { + continue + } + + // Determine neighbor + var neighborID graph.ID + switch rp.Direction { + case graph.DirectionOutbound: + neighborID = edge.EndID + case graph.DirectionInbound: + neighborID = edge.StartID + default: + if edge.StartID == cur.nodeID { + neighborID = edge.EndID + } else { + neighborID = edge.StartID + } + } + + // Cycle check — skip if neighbor already in path + if _, inPath := visited[neighborID]; inPath { + continue + } + + newEdges := make([]*graph.Relationship, len(cur.edges)+1) + copy(newEdges, cur.edges) + newEdges[len(cur.edges)] = edge + + queue = append(queue, bfsState{ + nodeID: neighborID, + edges: newEdges, + }) + } + } + } + + ex.bindings = expanded + + if relVar != "" && len(ex.matchStack) > 0 { + ms := &ex.matchStack[len(ex.matchStack)-1] + ms.newVars = append(ms.newVars, relVar) + } +} + +// findAnchorNode returns the most recently bound node in a binding row. +// It looks for the __next_node__ temporary or the last bound *graph.Node. +func (ex *executor) findAnchorNode(row binding) *graph.Node { + // Check for __next_node__ left by a previous relationship expansion + if n, ok := row["__next_node__"]; ok { + if node, ok := n.(*graph.Node); ok { + return node + } + } + + // Find last bound node — iterate over matchStack newVars in reverse + if len(ex.matchStack) > 0 { + ms := &ex.matchStack[len(ex.matchStack)-1] + for i := len(ms.newVars) - 1; i >= 0; i-- { + if v, ok := row[ms.newVars[i]]; ok { + if node, ok := v.(*graph.Node); ok { + return node + } + } + } + } + + // Fallback: find any bound node (last one added) + var lastNode *graph.Node + for _, v := range row { + if node, ok := v.(*graph.Node); ok { + lastNode = node + } + } + return lastNode +} + +// filterBindings applies WHERE expressions to filter binding rows. 
+func (ex *executor) filterBindings(where *cypher.Where) { + if where == nil || len(where.Expressions) == 0 { + return + } + + // If allShortestPaths already consumed the WHERE for path constraints, skip + if ex.whereConsumedByPath { + ex.whereConsumedByPath = false + return + } + + var filtered []binding + for _, row := range ex.bindings { + matched := true + for _, expr := range where.Expressions { + if !evalBindingExpr(ex.db, row, expr) { + matched = false + break + } + } + if matched { + filtered = append(filtered, row) + } + } + ex.bindings = filtered +} + +// finalizeMatch pops the match state and handles OPTIONAL MATCH fallback. +func (ex *executor) finalizeMatch() { + if len(ex.matchStack) == 0 { + return + } + + ms := ex.matchStack[len(ex.matchStack)-1] + ex.matchStack = ex.matchStack[:len(ex.matchStack)-1] + + if ms.optional && len(ex.bindings) == 0 { + // OPTIONAL MATCH: restore snapshot with nil-filled new variables + restored := make([]binding, len(ms.snapshot)) + for i, b := range ms.snapshot { + restored[i] = copyBinding(b) + for _, v := range ms.newVars { + restored[i][v] = nil + } + } + ex.bindings = restored + } + + // Clean up __next_node__ temporaries + for _, row := range ex.bindings { + delete(row, "__next_node__") + } +} + +// finalizeWith projects bindings through a WITH clause (scope barrier). 
+func (ex *executor) finalizeWith() { + if ex.returnCtx == nil { + return + } + + // Apply ORDER BY before projection + if len(ex.returnCtx.orderBy) > 0 { + ex.sortBindings() + } + + // Project bindings + if len(ex.projItems) > 0 { + projected := make([]binding, 0, len(ex.bindings)) + for _, row := range ex.bindings { + newRow := make(binding) + for _, item := range ex.projItems { + val := resolveBindingValue(ex.db, row, item.expr) + newRow[item.alias] = val + } + projected = append(projected, newRow) + } + ex.bindings = projected + } + + // Apply DISTINCT + if ex.returnCtx.distinct { + ex.bindings = deduplicateBindings(ex.bindings, ex.projItemAliases()) + } + + // Apply SKIP + if ex.returnCtx.skip > 0 && ex.returnCtx.skip < len(ex.bindings) { + ex.bindings = ex.bindings[ex.returnCtx.skip:] + } else if ex.returnCtx.skip >= len(ex.bindings) { + ex.bindings = nil + } + + // Apply LIMIT + if ex.returnCtx.limit > 0 && len(ex.bindings) > ex.returnCtx.limit { + ex.bindings = ex.bindings[:ex.returnCtx.limit] + } + + ex.returnCtx = nil + ex.projItems = nil +} + +// finalizeReturn builds the sonicResult from current bindings. 
+func (ex *executor) finalizeReturn() { + if ex.returnCtx == nil { + return + } + + // Apply ORDER BY before projection + if len(ex.returnCtx.orderBy) > 0 { + ex.sortBindings() + } + + // Determine columns + var keys []string + if len(ex.projItems) > 0 { + for _, item := range ex.projItems { + keys = append(keys, item.alias) + } + } else { + // RETURN * — use all bound variables + keys = ex.allBoundVariables() + } + + // Apply DISTINCT + if ex.returnCtx.distinct { + ex.bindings = deduplicateBindings(ex.bindings, keys) + } + + // Apply SKIP + if ex.returnCtx.skip > 0 && ex.returnCtx.skip < len(ex.bindings) { + ex.bindings = ex.bindings[ex.returnCtx.skip:] + } else if ex.returnCtx.skip >= len(ex.bindings) { + ex.bindings = nil + } + + // Apply LIMIT + if ex.returnCtx.limit > 0 && len(ex.bindings) > ex.returnCtx.limit { + ex.bindings = ex.bindings[:ex.returnCtx.limit] + } + + // Build result rows + rows := make([][]any, 0, len(ex.bindings)) + for _, row := range ex.bindings { + vals := make([]any, len(keys)) + for i, k := range keys { + if len(ex.projItems) > 0 { + vals[i] = resolveBindingValue(ex.db, row, ex.projItems[i].expr) + } else { + vals[i] = row[k] + } + } + rows = append(rows, vals) + } + + ex.result = &sonicResult{rows: rows, keys: keys} + ex.returnCtx = nil + ex.projItems = nil +} + +// finalizeResult sets a default empty result if none was built. +func (ex *executor) finalizeResult() { + if ex.result == nil { + ex.result = emptyResult() + } +} + +// handleAllShortestPaths handles allShortestPaths pattern parts. 
+func (ex *executor) handleAllShortestPaths(pp *cypher.PatternPart) { + // Collect WHERE expressions from the current match + var filters []graph.Criteria + if ex.currentMatch != nil && ex.currentMatch.Where != nil { + for _, expr := range ex.currentMatch.Where.Expressions { + filters = append(filters, expr) + } + } + + pc := extractPathConstraints(filters) + + // Extract kind constraints from the relationship pattern + for _, elem := range pp.PatternElements { + if rp, ok := elem.AsRelationshipPattern(); ok && len(rp.Kinds) > 0 { + if pc.edgeKinds == nil { + pc.edgeKinds = make(map[graph.Kind]struct{}) + } + for _, k := range rp.Kinds { + pc.edgeKinds[k] = struct{}{} + } + } + } + + ex.db.mu.RLock() + paths := ex.db.bfsAllShortestPaths(pc) + ex.db.mu.RUnlock() + + // Bind paths to the pattern variable + pathVar := "" + if pp.Variable != nil { + pathVar = pp.Variable.Symbol + } + + var expanded []binding + for _, row := range ex.bindings { + for _, p := range paths { + pathCopy := p + newRow := copyBinding(row) + if pathVar != "" { + newRow[pathVar] = &pathCopy + } + expanded = append(expanded, newRow) + } + } + + ex.bindings = expanded + + if pathVar != "" && len(ex.matchStack) > 0 { + ms := &ex.matchStack[len(ex.matchStack)-1] + ms.newVars = append(ms.newVars, pathVar) + } + + // Mark WHERE as consumed — path constraints already extracted + ex.whereConsumedByPath = true +} + +// sortBindings sorts bindings by ORDER BY items. 
+func (ex *executor) sortBindings() { + if ex.returnCtx == nil || len(ex.returnCtx.orderBy) == 0 { + return + } + + sort.SliceStable(ex.bindings, func(i, j int) bool { + for _, si := range ex.returnCtx.orderBy { + vi := resolveBindingValue(ex.db, ex.bindings[i], si.Expression) + vj := resolveBindingValue(ex.db, ex.bindings[j], si.Expression) + + cmp := compareOrdered(vi, vj) + if cmp == 0 { + // Try string comparison as fallback + si := fmt.Sprint(vi) + sj := fmt.Sprint(vj) + if si == sj { + continue + } + if si < sj { + cmp = -1 + } else { + cmp = 1 + } + } + + if !si.Ascending { + cmp = -cmp + } + if cmp < 0 { + return true + } + if cmp > 0 { + return false + } + } + return false + }) +} + +// allBoundVariables returns all variable names from bindings (for RETURN *). +func (ex *executor) allBoundVariables() []string { + seen := make(map[string]struct{}) + var keys []string + for _, row := range ex.bindings { + for k := range row { + if strings.HasPrefix(k, "__") { + continue + } + if _, ok := seen[k]; !ok { + seen[k] = struct{}{} + keys = append(keys, k) + } + } + } + sort.Strings(keys) + return keys +} + +// projItemAliases returns the alias names from projItems. +func (ex *executor) projItemAliases() []string { + aliases := make([]string, len(ex.projItems)) + for i, item := range ex.projItems { + aliases[i] = item.alias + } + return aliases +} + +// deduplicateBindings removes duplicate binding rows based on the given keys. 
+func deduplicateBindings(bindings []binding, keys []string) []binding { + seen := make(map[string]struct{}) + var result []binding + + for _, row := range bindings { + var parts []string + for _, k := range keys { + parts = append(parts, fmt.Sprintf("%v", row[k])) + } + key := strings.Join(parts, "\x00") + if _, ok := seen[key]; !ok { + seen[key] = struct{}{} + result = append(result, row) + } + } + return result +} diff --git a/drivers/sonic/integration_test.go b/drivers/sonic/integration_test.go new file mode 100644 index 0000000..b1d0616 --- /dev/null +++ b/drivers/sonic/integration_test.go @@ -0,0 +1,435 @@ +package sonic_test + +import ( + "context" + "testing" + + "github.com/specterops/dawgs/drivers/sonic" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/ops" + "github.com/specterops/dawgs/query" +) + +// These tests exercise the sonic driver through the ops package — the same +// code paths BloodHound Enterprise uses for graph operations. + +var ( + User = graph.StringKind("User") + Group = graph.StringKind("Group") + Computer = graph.StringKind("Computer") + Domain = graph.StringKind("Domain") + MemberOf = graph.StringKind("MemberOf") + HasSession = graph.StringKind("HasSession") + AdminTo = graph.StringKind("AdminTo") + GenericAll = graph.StringKind("GenericAll") +) + +func setupTestGraph(t *testing.T) (*sonic.Database, map[string]graph.ID) { + t.Helper() + + db := sonic.NewDatabase() + ctx := context.Background() + ids := make(map[string]graph.ID) + + err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + type nodeSpec struct { + name string + kinds []graph.Kind + } + nodes := []nodeSpec{ + {"jsmith", []graph.Kind{User}}, + {"Domain Admins", []graph.Kind{Group}}, + {"Domain Users", []graph.Kind{Group}}, + {"IT Admins", []graph.Kind{Group}}, + {"DC01", []graph.Kind{Computer}}, + {"WS01", []graph.Kind{Computer}}, + {"corp.local", []graph.Kind{Domain}}, + } + + for _, n := range nodes { + node, err := 
tx.CreateNode(graph.AsProperties(map[string]any{ + "name": n.name, + }), n.kinds...) + if err != nil { + return err + } + ids[n.name] = node.ID + } + + edges := []struct { + from, to string + kind graph.Kind + }{ + {"jsmith", "Domain Users", MemberOf}, + {"jsmith", "IT Admins", MemberOf}, + {"IT Admins", "Domain Admins", MemberOf}, + {"Domain Admins", "corp.local", GenericAll}, + {"jsmith", "WS01", HasSession}, + {"Domain Admins", "DC01", AdminTo}, + {"IT Admins", "WS01", AdminTo}, + } + + for _, e := range edges { + if _, err := tx.CreateRelationshipByIDs(ids[e.from], ids[e.to], e.kind, graph.NewProperties()); err != nil { + return err + } + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + return db, ids +} + +// TestOpsCountNodes tests ops.CountNodes against sonic. +func TestOpsCountNodes(t *testing.T) { + db, _ := setupTestGraph(t) + ctx := context.Background() + + total, err := ops.CountNodes(ctx, db) + if err != nil { + t.Fatal(err) + } + if total != 7 { + t.Errorf("expected 7 nodes, got %d", total) + } + + groups, err := ops.CountNodes(ctx, db, query.KindIn(query.Node(), Group)) + if err != nil { + t.Fatal(err) + } + if groups != 3 { + t.Errorf("expected 3 groups, got %d", groups) + } +} + +// TestOpsFetchNode tests ops.FetchNode (single node by ID). +func TestOpsFetchNode(t *testing.T) { + db, ids := setupTestGraph(t) + ctx := context.Background() + + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + node, err := ops.FetchNode(tx, ids["jsmith"]) + if err != nil { + return err + } + + name, _ := node.Properties.Get("name").String() + if name != "jsmith" { + t.Errorf("expected jsmith, got %s", name) + } + if !node.Kinds.ContainsOneOf(User) { + t.Error("expected node to have User kind") + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestOpsFetchNodeSet tests ops.FetchNodeSet with kind filtering. 
+func TestOpsFetchNodeSet(t *testing.T) { + db, _ := setupTestGraph(t) + ctx := context.Background() + + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + computers, err := ops.FetchNodeSet(tx.Nodes().Filter( + query.KindIn(query.Node(), Computer), + )) + if err != nil { + return err + } + if computers.Len() != 2 { + t.Errorf("expected 2 computers, got %d", computers.Len()) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestOpsFetchRelationships tests fetching relationships with kind filter. +func TestOpsFetchRelationships(t *testing.T) { + db, _ := setupTestGraph(t) + ctx := context.Background() + + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + memberOfRels, err := ops.FetchRelationships( + tx.Relationships().Filter(query.KindIn(query.Relationship(), MemberOf)), + ) + if err != nil { + return err + } + if len(memberOfRels) != 3 { + t.Errorf("expected 3 MemberOf relationships, got %d", len(memberOfRels)) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestOpsFetchNodeRelationships tests ops.FetchNodeRelationships (outbound/inbound from a node). 
+func TestOpsFetchNodeRelationships(t *testing.T) { + db, ids := setupTestGraph(t) + ctx := context.Background() + + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + jsmith, err := ops.FetchNode(tx, ids["jsmith"]) + if err != nil { + return err + } + + // jsmith has 3 outbound edges: MemberOf(Domain Users), MemberOf(IT Admins), HasSession(WS01) + outbound, err := ops.FetchNodeRelationships(tx, jsmith, graph.DirectionOutbound) + if err != nil { + return err + } + if len(outbound) != 3 { + t.Errorf("expected 3 outbound relationships from jsmith, got %d", len(outbound)) + } + + // jsmith has 0 inbound edges + inbound, err := ops.FetchNodeRelationships(tx, jsmith, graph.DirectionInbound) + if err != nil { + return err + } + if len(inbound) != 0 { + t.Errorf("expected 0 inbound relationships to jsmith, got %d", len(inbound)) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestOpsForEachStartEndNode tests ops.ForEachStartNode and ForEachEndNode. +func TestOpsForEachStartEndNode(t *testing.T) { + db, _ := setupTestGraph(t) + ctx := context.Background() + + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + // Get start nodes of all AdminTo edges + startNodes, err := ops.FetchStartNodes( + tx.Relationships().Filter(query.KindIn(query.Relationship(), AdminTo)), + ) + if err != nil { + return err + } + if startNodes.Len() != 2 { + t.Errorf("expected 2 AdminTo start nodes, got %d", startNodes.Len()) + } + + // Get end nodes of all AdminTo edges + endNodes, err := ops.FetchEndNodes( + tx.Relationships().Filter(query.KindIn(query.Relationship(), AdminTo)), + ) + if err != nil { + return err + } + if endNodes.Len() != 2 { + t.Errorf("expected 2 AdminTo end nodes, got %d", endNodes.Len()) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestOpsDeleteNodes tests ops.DeleteNodes. 
+func TestOpsDeleteNodes(t *testing.T) { + db, ids := setupTestGraph(t) + ctx := context.Background() + + err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + return ops.DeleteNodes(tx, ids["WS01"]) + }) + if err != nil { + t.Fatal(err) + } + + count, err := ops.CountNodes(ctx, db) + if err != nil { + t.Fatal(err) + } + if count != 6 { + t.Errorf("expected 6 nodes after delete, got %d", count) + } +} + +// TestOpsDeleteRelationships tests ops.DeleteRelationships. +func TestOpsDeleteRelationships(t *testing.T) { + db, ids := setupTestGraph(t) + ctx := context.Background() + + // Find the HasSession edge from jsmith -> WS01 + var hasSessionID graph.ID + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + return tx.Relationships().Filter( + query.And( + query.KindIn(query.Relationship(), HasSession), + query.Equals(query.StartID(), ids["jsmith"]), + ), + ).Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { + for rel := range cursor.Chan() { + hasSessionID = rel.ID + } + return cursor.Error() + }) + }) + if err != nil { + t.Fatal(err) + } + + // Delete it + err = db.WriteTransaction(ctx, func(tx graph.Transaction) error { + return ops.DeleteRelationships(tx, hasSessionID) + }) + if err != nil { + t.Fatal(err) + } + + // Verify + err = db.ReadTransaction(ctx, func(tx graph.Transaction) error { + count, err := tx.Relationships().Filter( + query.KindIn(query.Relationship(), HasSession), + ).Count() + if err != nil { + return err + } + if count != 0 { + t.Errorf("expected 0 HasSession edges after delete, got %d", count) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestAttackPathFinding tests the core BHE use case: find attack paths through the graph. 
+func TestAttackPathFinding(t *testing.T) { + db, ids := setupTestGraph(t) + ctx := context.Background() + + // Attack path: jsmith -> IT Admins -> Domain Admins -> corp.local (via MemberOf, MemberOf, GenericAll) + err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { + return tx.Relationships(). + Filter(query.InIDs(query.StartID(), ids["jsmith"])). + Filter(query.InIDs(query.EndID(), ids["corp.local"])). + Filter(query.KindIn(query.Relationship(), MemberOf, GenericAll)). + FetchAllShortestPaths(func(cursor graph.Cursor[graph.Path]) error { + var paths []graph.Path + for path := range cursor.Chan() { + paths = append(paths, path) + } + if err := cursor.Error(); err != nil { + return err + } + + if len(paths) != 1 { + t.Fatalf("expected 1 attack path, got %d", len(paths)) + } + + path := paths[0] + if len(path.Nodes) != 4 { + t.Errorf("expected 4 nodes in path, got %d", len(path.Nodes)) + } + if len(path.Edges) != 3 { + t.Errorf("expected 3 edges in path, got %d", len(path.Edges)) + } + + // Verify path: jsmith -> IT Admins -> Domain Admins -> corp.local + if path.Root().ID != ids["jsmith"] { + t.Errorf("expected path to start at jsmith") + } + if path.Terminal().ID != ids["corp.local"] { + t.Errorf("expected path to end at corp.local") + } + + return nil + }) + }) + if err != nil { + t.Fatal(err) + } +} + +// TestBatchOperations tests batch create and upsert operations. 
+func TestBatchOperations(t *testing.T) { + db := sonic.NewDatabase() + ctx := context.Background() + + // Batch create + err := db.BatchOperation(ctx, func(batch graph.Batch) error { + node := &graph.Node{ + Kinds: graph.Kinds{User}, + Properties: graph.AsProperties(map[string]any{"name": "alice", "email": "alice@corp.local"}), + } + if err := batch.CreateNode(node); err != nil { + return err + } + + // Upsert same node — should update, not create duplicate + upsertNode := &graph.Node{ + Kinds: graph.Kinds{User}, + Properties: graph.AsProperties(map[string]any{"name": "alice", "email": "alice@newcorp.local"}), + } + return batch.UpdateNodeBy(graph.NodeUpdate{ + Node: upsertNode, + IdentityKind: User, + IdentityProperties: []string{"name"}, + }) + }) + if err != nil { + t.Fatal(err) + } + + // Should still be 1 node, with updated email + count, err := ops.CountNodes(ctx, db) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Errorf("expected 1 node after upsert, got %d", count) + } + + err = db.ReadTransaction(ctx, func(tx graph.Transaction) error { + node, err := tx.Nodes().First() + if err != nil { + return err + } + email, _ := node.Properties.Get("email").String() + if email != "alice@newcorp.local" { + t.Errorf("expected updated email, got %s", email) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestParallelFetchNodes tests the parallel fetch code path used by BHE. 
+func TestParallelFetchNodes(t *testing.T) {
+	db, _ := setupTestGraph(t)
+	ctx := context.Background()
+
+	// Fetch all User nodes through the parallel fetch helper with 2 workers.
+	nodes, err := ops.ParallelFetchNodes(ctx, db, query.KindIn(query.Node(), User), 2)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if nodes.Len() != 1 {
+		t.Errorf("expected 1 user from parallel fetch, got %d", nodes.Len())
+	}
+}
diff --git a/drivers/sonic/pathfinding.go b/drivers/sonic/pathfinding.go
new file mode 100644
index 0000000..556e5dc
--- /dev/null
+++ b/drivers/sonic/pathfinding.go
@@ -0,0 +1,313 @@
+package sonic
+
+import (
+	cypher "github.com/specterops/dawgs/cypher/models/cypher"
+	"github.com/specterops/dawgs/graph"
+	"github.com/specterops/dawgs/query"
+)
+
+// pathConstraints holds the extracted constraints from query filters for pathfinding.
+type pathConstraints struct {
+	startIDs  map[graph.ID]struct{} // IDs the search may start from
+	endIDs    map[graph.ID]struct{} // IDs that terminate the search
+	edgeKinds map[graph.Kind]struct{} // nil means all kinds allowed
+}
+
+// extractPathConstraints walks the filter criteria to pull out start node IDs,
+// end node IDs, and allowed edge kinds. Constraints found in separate filters
+// are merged into the same sets (IDs and kinds accumulate, they are not
+// intersected).
+func extractPathConstraints(filters []graph.Criteria) pathConstraints {
+	pc := pathConstraints{
+		startIDs: make(map[graph.ID]struct{}),
+		endIDs:   make(map[graph.ID]struct{}),
+	}
+
+	for _, f := range filters {
+		extractFromCriteria(f, &pc)
+	}
+
+	return pc
+}
+
+// extractFromCriteria recursively descends through conjunctions and
+// parentheticals and hands comparisons and kind matchers to their extractors.
+// Every other criteria type (for example a disjunction or negation) is ignored
+// and contributes no constraint.
+func extractFromCriteria(criteria graph.Criteria, pc *pathConstraints) {
+	switch c := criteria.(type) {
+	case *cypher.Conjunction:
+		for _, expr := range c.Expressions {
+			extractFromCriteria(expr, pc)
+		}
+
+	case *cypher.Parenthetical:
+		extractFromCriteria(c.Expression, pc)
+
+	case *cypher.Comparison:
+		extractFromComparison(c, pc)
+
+	case *cypher.KindMatcher:
+		extractFromKindMatcher(c, pc)
+	}
+}
+
+// extractFromComparison records start/end node IDs from comparisons of the form
+// id(<start>) = x, id(<start>) IN [...], id(<end>) = x or id(<end>) IN [...].
+// Only the first partial of the comparison is inspected; anything that does not
+// match this shape is ignored.
+func extractFromComparison(cmp *cypher.Comparison, pc *pathConstraints) {
+	if len(cmp.Partials) == 0 {
+		return
+	}
+
+	partial := cmp.Partials[0]
+
+	// We're looking for patterns like: id(s) IN [...] or id(e) IN [...]
+	if partial.Operator != cypher.OperatorIn && partial.Operator != cypher.OperatorEquals {
+		return
+	}
+
+	// Check if the left side is id(s) or id(e)
+	funcInv, ok := cmp.Left.(*cypher.FunctionInvocation)
+	if !ok || funcInv.Name != "id" || len(funcInv.Arguments) == 0 {
+		return
+	}
+
+	v, ok := funcInv.Arguments[0].(*cypher.Variable)
+	if !ok {
+		return
+	}
+
+	// Pick the set the IDs belong to based on which variable id() was applied to.
+	var targetSet map[graph.ID]struct{}
+	switch v.Symbol {
+	case query.EdgeStartSymbol:
+		targetSet = pc.startIDs
+	case query.EdgeEndSymbol:
+		targetSet = pc.endIDs
+	default:
+		return
+	}
+
+	// Extract the IDs from the right side
+	if partial.Operator == cypher.OperatorEquals {
+		if param, ok := partial.Right.(*cypher.Parameter); ok {
+			if id, ok := toID(param.Value); ok {
+				targetSet[id] = struct{}{}
+			}
+		} else if lit, ok := partial.Right.(*cypher.Literal); ok {
+			if id, ok := toID(lit.Value); ok {
+				targetSet[id] = struct{}{}
+			}
+		}
+		return
+	}
+
+	// OperatorIn — right side is a Parameter or Literal whose Value is a slice of IDs
+	if param, ok := partial.Right.(*cypher.Parameter); ok {
+		extractIDs(param.Value, targetSet)
+	} else if lit, ok := partial.Right.(*cypher.Literal); ok {
+		extractIDs(lit.Value, targetSet)
+	}
+}
+
+// extractIDs adds every ID found in val to target. val may be a []graph.ID,
+// []int64, []uint64 or []any; elements of a []any that toID cannot convert are
+// skipped, and any other type of val is ignored.
+func extractIDs(val any, target map[graph.ID]struct{}) {
+	switch ids := val.(type) {
+	case []graph.ID:
+		for _, id := range ids {
+			target[id] = struct{}{}
+		}
+	case []int64:
+		for _, id := range ids {
+			target[graph.ID(id)] = struct{}{}
+		}
+	case []uint64:
+		for _, id := range ids {
+			target[graph.ID(id)] = struct{}{}
+		}
+	case []any:
+		for _, v := range ids {
+			if id, ok := toID(v); ok {
+				target[id] = struct{}{}
+			}
+		}
+	}
+}
+
+// extractFromKindMatcher adds the matcher's kinds to the allowed edge kinds when
+// the matcher refers to the edge variable. Matchers on other variables are
+// ignored.
+func extractFromKindMatcher(km *cypher.KindMatcher, pc *pathConstraints) {
+	v, ok := km.Reference.(*cypher.Variable)
+	if !ok || v.Symbol != query.EdgeSymbol {
+		return
+	}
+
+	// Allocate lazily: a nil map means "no kind restriction".
+	if pc.edgeKinds == nil {
+		pc.edgeKinds = make(map[graph.Kind]struct{})
+	}
+	for _, k := range km.Kinds {
+		pc.edgeKinds[k] = struct{}{}
+	}
+}
+
+// bfsAllShortestPaths finds all shortest paths from any start node to any end node.
+// It runs a single forward, multi-source BFS over outgoing edges from all start
+// nodes simultaneously, stopping at the depth where an end node is first
+// reached. All paths of that depth are returned.
+//
+// The function reads db.nodes, db.edges and db.outEdges without locking, so the
+// caller must hold db.mu (at least for reading) for the duration of the call.
+//
+// NOTE(review): end nodes are only detected when reached over an edge, so a node
+// that is both a start and an end never yields a zero-length path — confirm this
+// matches the other drivers.
+func (db *Database) bfsAllShortestPaths(pc pathConstraints) []graph.Path {
+	if len(pc.startIDs) == 0 || len(pc.endIDs) == 0 {
+		return nil
+	}
+
+	type bfsEntry struct {
+		nodeID graph.ID
+		depth  int
+	}
+
+	// parents[nodeID] = list of (parent, edge) pairs that reach nodeID at shortest depth
+	parents := make(map[graph.ID][]parentInfo)
+	// shortest depth at which each node was discovered
+	depthOf := make(map[graph.ID]int)
+
+	queue := make([]bfsEntry, 0, 256)
+
+	// Seed BFS with all start nodes
+	for startID := range pc.startIDs {
+		if _, exists := db.nodes[startID]; !exists {
+			continue
+		}
+		queue = append(queue, bfsEntry{nodeID: startID, depth: 0})
+		depthOf[startID] = 0
+	}
+
+	foundDepth := -1 // depth at which we first hit an end node
+	reachedEnds := make(map[graph.ID]struct{})
+
+	for len(queue) > 0 {
+		entry := queue[0]
+		queue = queue[1:]
+
+		// If we already found paths and this entry is deeper, stop
+		if foundDepth >= 0 && entry.depth > foundDepth {
+			break
+		}
+
+		// Expand outgoing edges
+		for _, edgeID := range db.outEdges[entry.nodeID] {
+			// Tolerate adjacency entries whose edge is no longer present.
+			edge := db.edges[edgeID]
+			if edge == nil {
+				continue
+			}
+
+			// Check edge kind constraint
+			if pc.edgeKinds != nil {
+				if _, allowed := pc.edgeKinds[edge.Kind]; !allowed {
+					continue
+				}
+			}
+
+			neighborID := edge.EndID
+			neighborDepth := entry.depth + 1
+
+			// If we've already found paths at a shorter depth, skip deeper exploration
+			if foundDepth >= 0 && neighborDepth > foundDepth {
+				continue
+			}
+
+			prevDepth, visited := depthOf[neighborID]
+			if visited && prevDepth < neighborDepth {
+				// Already reached at a shorter depth, skip
+				continue
+			}
+
+			if !visited || prevDepth == neighborDepth {
+				// First visit or same depth — record this parent
+				if !visited {
+					depthOf[neighborID] = neighborDepth
+					queue = append(queue, bfsEntry{nodeID: neighborID, depth: neighborDepth})
+				}
+				parents[neighborID] = append(parents[neighborID], parentInfo{
+					parentNodeID: entry.nodeID,
+					edgeID:       edgeID,
+				})
+
+				// Check if we reached an end node
+				if _, isEnd := pc.endIDs[neighborID]; isEnd {
+					foundDepth = neighborDepth
+					reachedEnds[neighborID] = struct{}{}
+				}
+			}
+		}
+	}
+
+	if len(reachedEnds) == 0 {
+		return nil
+	}
+
+	// Reconstruct all shortest paths by backtracking from end nodes
+	var paths []graph.Path
+	for endID := range reachedEnds {
+		paths = append(paths, db.reconstructPaths(endID, parents)...)
+	}
+
+	return paths
+}
+
+// reconstructPaths backtracks from endID through the parents map to build all shortest paths.
+// A node with no entry in parents is treated as a start node. Like
+// bfsAllShortestPaths it reads db.nodes and db.edges without locking.
+func (db *Database) reconstructPaths(endID graph.ID, parents map[graph.ID][]parentInfo) []graph.Path {
+	type partial struct {
+		nodeIDs []graph.ID
+		edgeIDs []graph.ID
+	}
+
+	// Start with the end node and work backwards
+	current := []partial{{nodeIDs: []graph.ID{endID}}}
+
+	for {
+		var next []partial
+		allDone := true
+
+		for _, p := range current {
+			headNode := p.nodeIDs[len(p.nodeIDs)-1]
+			pInfos := parents[headNode]
+
+			if len(pInfos) == 0 {
+				// Reached a start node — this partial is complete
+				next = append(next, p)
+				continue
+			}
+
+			allDone = false
+			for _, pi := range pInfos {
+				// Copy before extending so sibling branches never share a backing array.
+				newNodeIDs := make([]graph.ID, len(p.nodeIDs)+1)
+				copy(newNodeIDs, p.nodeIDs)
+				newNodeIDs[len(p.nodeIDs)] = pi.parentNodeID
+
+				newEdgeIDs := make([]graph.ID, len(p.edgeIDs)+1)
+				copy(newEdgeIDs, p.edgeIDs)
+				newEdgeIDs[len(p.edgeIDs)] = pi.edgeID
+
+				next = append(next, partial{nodeIDs: newNodeIDs, edgeIDs: newEdgeIDs})
+			}
+		}
+
+		current = next
+		if allDone {
+			break
+		}
+	}
+
+	// Convert partials to graph.Path (reverse since we built them end-to-start)
+	paths := make([]graph.Path, 0, len(current))
+	for _, p := range current {
+		path := graph.Path{
+			Nodes: make([]*graph.Node, len(p.nodeIDs)),
+			Edges: make([]*graph.Relationship, len(p.edgeIDs)),
+		}
+
+		// Reverse nodes (they were built end→start)
+		for i, nid := range p.nodeIDs {
+			path.Nodes[len(p.nodeIDs)-1-i] = db.nodes[nid]
+		}
+		// Reverse edges
+		for i, eid := range p.edgeIDs {
+			path.Edges[len(p.edgeIDs)-1-i] = db.edges[eid]
+		}
+
+		paths = append(paths, path)
+	}
+
+	return paths
+}
+
+// parentInfo records one way of reaching a node during BFS: the node it was
+// reached from and the edge that was followed.
+type parentInfo struct {
+	parentNodeID graph.ID
+	edgeID       graph.ID
+}
diff --git a/drivers/sonic/queries.go b/drivers/sonic/queries.go
new file mode 100644
index 0000000..ca5e0fc
--- /dev/null
+++ b/drivers/sonic/queries.go
@@ -0,0 +1,367 @@
+package sonic
+
+import (
+	"github.com/specterops/dawgs/graph"
+)
+
+// --- NodeQuery ---
+
+// nodeQuery is the graph.NodeQuery implementation. It accumulates filter
+// criteria and evaluates them against every node in the database on demand.
+type nodeQuery struct {
+	db      *Database
+	filters []graph.Criteria
+}
+
+// Filter adds a criteria that nodes must satisfy; all filters are ANDed together.
+func (q *nodeQuery) Filter(criteria graph.Criteria) graph.NodeQuery {
+	q.filters = append(q.filters, criteria)
+	return q
+}
+
+// Filterf adds the criteria produced by delegate as a filter.
+func (q *nodeQuery) Filterf(delegate graph.CriteriaProvider) graph.NodeQuery {
+	return q.Filter(delegate())
+}
+
+// Query passes the matching nodes to delegate as a result set with a single
+// column "n". finalCriteria are appended to the query's filters first.
+func (q *nodeQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error {
+	q.filters = append(q.filters, finalCriteria...)
+
+	nodes := q.collect()
+	rows := make([][]any, len(nodes))
+	for i, n := range nodes {
+		rows[i] = []any{n}
+	}
+	return delegate(&sonicResult{rows: rows, keys: []string{"n"}})
+}
+
+// Delete removes every matching node together with all edges attached to it.
+func (q *nodeQuery) Delete() error {
+	nodes := q.collect()
+
+	q.db.mu.Lock()
+	defer q.db.mu.Unlock()
+
+	for _, n := range nodes {
+		// Delete attached edges. Each edge is also indexed on the node at its
+		// other end, so unlink it there as well; otherwise that node keeps a
+		// dangling edge ID in its adjacency list forever. Self-loops are skipped
+		// for unlinking because this node's own lists are dropped below (and
+		// must not be mutated while they are being ranged over).
+		for _, edgeID := range q.db.outEdges[n.ID] {
+			if edge := q.db.edges[edgeID]; edge != nil && edge.EndID != n.ID {
+				q.db.inEdges[edge.EndID] = removeID(q.db.inEdges[edge.EndID], edgeID)
+			}
+			delete(q.db.edges, edgeID)
+		}
+		for _, edgeID := range q.db.inEdges[n.ID] {
+			if edge := q.db.edges[edgeID]; edge != nil && edge.StartID != n.ID {
+				q.db.outEdges[edge.StartID] = removeID(q.db.outEdges[edge.StartID], edgeID)
+			}
+			delete(q.db.edges, edgeID)
+		}
+		delete(q.db.outEdges, n.ID)
+		delete(q.db.inEdges, n.ID)
+		delete(q.db.nodes, n.ID)
+	}
+	return nil
+}
+
+// Update applies the set and deleted keys of properties to every matching node.
+// A nil properties value is a no-op.
+func (q *nodeQuery) Update(properties *graph.Properties) error {
+	if properties == nil {
+		return nil
+	}
+
+	nodes := q.collect()
+
+	q.db.mu.Lock()
+	defer q.db.mu.Unlock()
+
+	for _, n := range nodes {
+		for key, val := range properties.Map {
+			n.Properties.Set(key, val)
+		}
+		for key := range properties.Deleted {
+			n.Properties.Delete(key)
+		}
+		q.db.nodes[n.ID] = n
+	}
+	return nil
+}
+
+// OrderBy is currently a no-op; results are returned in map iteration order.
+func (q *nodeQuery) OrderBy(criteria ...graph.Criteria) graph.NodeQuery {
+	// TODO: implement ordering
+	return q
+}
+
+// Offset is currently a no-op.
+func (q *nodeQuery) Offset(skip int) graph.NodeQuery {
+	// TODO: implement offset
+	return q
+}
+
+// Limit is currently a no-op.
+func (q *nodeQuery) Limit(limit int) graph.NodeQuery {
+	// TODO: implement limit
+	return q
+}
+
+// Count returns the number of matching nodes.
+func (q *nodeQuery) Count() (int64, error) {
+	return int64(len(q.collect())), nil
+}
+
+// First returns one matching node, or graph.ErrNoResultsFound when nothing
+// matches. Which node is returned is unspecified when several match, because
+// collect ranges over a map.
+func (q *nodeQuery) First() (*graph.Node, error) {
+	nodes := q.collect()
+	if len(nodes) == 0 {
+		return nil, graph.ErrNoResultsFound
+	}
+	return nodes[0], nil
+}
+
+// Fetch streams the matching nodes to delegate through a cursor. finalCriteria
+// are appended to the query's filters first, mirroring Query.
+func (q *nodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error, finalCriteria ...graph.Criteria) error {
+	q.filters = append(q.filters, finalCriteria...)
+
+	nodes := q.collect()
+	// The channel is fully buffered and closed up front, so the delegate can
+	// drain it without a producer goroutine.
+	ch := make(chan *graph.Node, len(nodes))
+	for _, n := range nodes {
+		ch <- n
+	}
+	close(ch)
+	return delegate(&sliceCursor[*graph.Node]{ch: ch})
+}
+
+// FetchIDs streams the IDs of the matching nodes to delegate.
+func (q *nodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error {
+	nodes := q.collect()
+	ch := make(chan graph.ID, len(nodes))
+	for _, n := range nodes {
+		ch <- n.ID
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.ID]{ch: ch})
+}
+
+// FetchKinds streams the ID and kinds of every matching node to delegate.
+func (q *nodeQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.KindsResult]) error) error {
+	nodes := q.collect()
+	ch := make(chan graph.KindsResult, len(nodes))
+	for _, n := range nodes {
+		ch <- graph.KindsResult{ID: n.ID, Kinds: n.Kinds}
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.KindsResult]{ch: ch})
+}
+
+// collect returns all nodes that satisfy every filter. It holds the database
+// read lock only while scanning; the returned pointers are the stored nodes
+// themselves, not copies.
+func (q *nodeQuery) collect() []*graph.Node {
+	q.db.mu.RLock()
+	defer q.db.mu.RUnlock()
+
+	var results []*graph.Node
+	for _, n := range q.db.nodes {
+		if q.matchNode(n) {
+			results = append(results, n)
+		}
+	}
+	return results
+}
+
+// matchNode reports whether n satisfies every filter; no filters matches everything.
+func (q *nodeQuery) matchNode(n *graph.Node) bool {
+	if len(q.filters) == 0 {
+		return true
+	}
+	for _, f := range q.filters {
+		if !evalNodeCriteria(q.db, n, f) {
+			return false
+		}
+	}
+	return true
+}
+
+// --- RelationshipQuery ---
+
+// relQuery is the graph.RelationshipQuery implementation. It accumulates filter
+// criteria and evaluates them against every edge in the database on demand.
+type relQuery struct {
+	db      *Database
+	filters []graph.Criteria
+}
+
+// Filter adds a criteria that relationships must satisfy; all filters are ANDed together.
+func (q *relQuery) Filter(criteria graph.Criteria) graph.RelationshipQuery {
+	q.filters = append(q.filters, criteria)
+	return q
+}
+
+// Filterf adds the criteria produced by delegate as a filter.
+func (q *relQuery) Filterf(delegate graph.CriteriaProvider) graph.RelationshipQuery {
+	return q.Filter(delegate())
+}
+
+// Update applies the set and deleted keys of properties to every matching
+// relationship. A nil properties value is a no-op.
+func (q *relQuery) Update(properties *graph.Properties) error {
+	if properties == nil {
+		return nil
+	}
+
+	rels := q.collect()
+
+	q.db.mu.Lock()
+	defer q.db.mu.Unlock()
+
+	for _, r := range rels {
+		for key, val := range properties.Map {
+			r.Properties.Set(key, val)
+		}
+		for key := range properties.Deleted {
+			r.Properties.Delete(key)
+		}
+		q.db.edges[r.ID] = r
+	}
+	return nil
+}
+
+// Delete removes every matching relationship and unlinks it from the adjacency
+// indexes of both of its endpoints.
+func (q *relQuery) Delete() error {
+	rels := q.collect()
+
+	q.db.mu.Lock()
+	defer q.db.mu.Unlock()
+
+	for _, r := range rels {
+		delete(q.db.edges, r.ID)
+		q.db.outEdges[r.StartID] = removeID(q.db.outEdges[r.StartID], r.ID)
+		q.db.inEdges[r.EndID] = removeID(q.db.inEdges[r.EndID], r.ID)
+	}
+	return nil
+}
+
+// OrderBy is currently a no-op; results are returned in map iteration order.
+func (q *relQuery) OrderBy(criteria ...graph.Criteria) graph.RelationshipQuery {
+	// TODO: implement ordering
+	return q
+}
+
+// Offset is currently a no-op.
+func (q *relQuery) Offset(skip int) graph.RelationshipQuery {
+	// TODO: implement offset
+	return q
+}
+
+// Limit is currently a no-op.
+func (q *relQuery) Limit(limit int) graph.RelationshipQuery {
+	// TODO: implement limit
+	return q
+}
+
+// Count returns the number of matching relationships.
+func (q *relQuery) Count() (int64, error) {
+	return int64(len(q.collect())), nil
+}
+
+// First returns one matching relationship, or graph.ErrNoResultsFound when
+// nothing matches. Which one is returned is unspecified when several match.
+func (q *relQuery) First() (*graph.Relationship, error) {
+	rels := q.collect()
+	if len(rels) == 0 {
+		return nil, graph.ErrNoResultsFound
+	}
+	return rels[0], nil
+}
+
+// Query passes the matching relationships to delegate as a result set with the
+// columns "s" (start node), "r" (relationship) and "e" (end node).
+// finalCriteria are appended to the query's filters first.
+func (q *relQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error {
+	q.filters = append(q.filters, finalCriteria...)
+
+	rels := q.collect()
+
+	// Endpoint lookups touch db.nodes, so they need the read lock too.
+	q.db.mu.RLock()
+	rows := make([][]any, 0, len(rels))
+	for _, r := range rels {
+		startNode := q.db.nodes[r.StartID]
+		endNode := q.db.nodes[r.EndID]
+		rows = append(rows, []any{startNode, r, endNode})
+	}
+	q.db.mu.RUnlock()
+
+	return delegate(&sonicResult{rows: rows, keys: []string{"s", "r", "e"}})
+}
+
+// Fetch streams the matching relationships to delegate.
+func (q *relQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Relationship]) error) error {
+	rels := q.collect()
+	ch := make(chan *graph.Relationship, len(rels))
+	for _, r := range rels {
+		ch <- r
+	}
+	close(ch)
+	return delegate(&sliceCursor[*graph.Relationship]{ch: ch})
+}
+
+// FetchDirection streams each matching relationship together with the node at
+// the far end for the given direction: the end node for DirectionOutbound, the
+// start node for DirectionInbound. For any other direction Node is nil.
+func (q *relQuery) FetchDirection(direction graph.Direction, delegate func(cursor graph.Cursor[graph.DirectionalResult]) error) error {
+	rels := q.collect()
+	ch := make(chan graph.DirectionalResult, len(rels))
+
+	// Node lookups read db.nodes and must not race with writers. The channel
+	// has room for every result, so sending while holding the lock cannot block.
+	q.db.mu.RLock()
+	for _, r := range rels {
+		var node *graph.Node
+		switch direction {
+		case graph.DirectionOutbound:
+			node = q.db.nodes[r.EndID]
+		case graph.DirectionInbound:
+			node = q.db.nodes[r.StartID]
+		}
+		ch <- graph.DirectionalResult{
+			Direction:    direction,
+			Relationship: r,
+			Node:         node,
+		}
+	}
+	q.db.mu.RUnlock()
+
+	close(ch)
+	return delegate(&sliceCursor[graph.DirectionalResult]{ch: ch})
+}
+
+// FetchIDs streams the IDs of the matching relationships to delegate.
+func (q *relQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error {
+	rels := q.collect()
+	ch := make(chan graph.ID, len(rels))
+	for _, r := range rels {
+		ch <- r.ID
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.ID]{ch: ch})
+}
+
+// FetchTriples streams the (ID, StartID, EndID) triple of every matching
+// relationship to delegate.
+func (q *relQuery) FetchTriples(delegate func(cursor graph.Cursor[graph.RelationshipTripleResult]) error) error {
+	rels := q.collect()
+	ch := make(chan graph.RelationshipTripleResult, len(rels))
+	for _, r := range rels {
+		ch <- graph.RelationshipTripleResult{ID: r.ID, StartID: r.StartID, EndID: r.EndID}
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.RelationshipTripleResult]{ch: ch})
+}
+
+// FetchAllShortestPaths derives start IDs, end IDs and allowed edge kinds from
+// the query's filters and streams every shortest path between them to delegate.
+// The filters are not evaluated per edge here; only the extracted constraints
+// are used.
+func (q *relQuery) FetchAllShortestPaths(delegate func(cursor graph.Cursor[graph.Path]) error) error {
+	pc := extractPathConstraints(q.filters)
+
+	q.db.mu.RLock()
+	paths := q.db.bfsAllShortestPaths(pc)
+	q.db.mu.RUnlock()
+
+	ch := make(chan graph.Path, len(paths))
+	for _, p := range paths {
+		ch <- p
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.Path]{ch: ch})
+}
+
+// FetchKinds streams the triple and kind of every matching relationship to delegate.
+func (q *relQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.RelationshipKindsResult]) error) error {
+	rels := q.collect()
+	ch := make(chan graph.RelationshipKindsResult, len(rels))
+	for _, r := range rels {
+		ch <- graph.RelationshipKindsResult{
+			RelationshipTripleResult: graph.RelationshipTripleResult{ID: r.ID, StartID: r.StartID, EndID: r.EndID},
+			Kind:                     r.Kind,
+		}
+	}
+	close(ch)
+	return delegate(&sliceCursor[graph.RelationshipKindsResult]{ch: ch})
+}
+
+// collect returns all relationships that satisfy every filter. It holds the
+// database read lock only while scanning; the returned pointers are the stored
+// relationships themselves, not copies.
+func (q *relQuery) collect() []*graph.Relationship {
+	q.db.mu.RLock()
+	defer q.db.mu.RUnlock()
+
+	var results []*graph.Relationship
+	for _, r := range q.db.edges {
+		if q.matchRel(r) {
+			results = append(results, r)
+		}
+	}
+	return results
+}
+
+// matchRel reports whether r satisfies every filter; no filters matches everything.
+func (q *relQuery) matchRel(r *graph.Relationship) bool {
+	if len(q.filters) == 0 {
+		return true
+	}
+	for _, f := range q.filters {
+		if !evalRelCriteria(q.db, r, f) {
+			return false
+		}
+	}
+	return true
+}
+
+// --- Cursor ---
+
+// sliceCursor is a graph.Cursor backed by a pre-filled channel. The query
+// methods above hand it a channel that is already buffered with every result
+// and closed.
+type sliceCursor[T any] struct {
+	ch  chan T
+	err error
+}
+
+// Chan returns the channel the results are read from.
+func (c *sliceCursor[T]) Chan() chan T {
+	return c.ch
+}
+
+// Error returns the cursor's error, if any.
+func (c *sliceCursor[T]) Error() error {
+	return c.err
+}
+
+// Close is a no-op; the cursor holds no resources.
+func (c *sliceCursor[T]) Close() {
+}
diff --git a/drivers/sonic/sonic.go b/drivers/sonic/sonic.go
new file mode 100644
index 0000000..532db6c
--- /dev/null
+++ b/drivers/sonic/sonic.go
@@ -0,0 +1,107 @@
+package sonic
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+
+	"github.com/specterops/dawgs"
+	"github.com/specterops/dawgs/graph"
+	"github.com/specterops/dawgs/util/size"
+)
+
+// DriverName is the name under which this driver registers itself with dawgs.
+const DriverName = "sonic"
+
+// init registers the driver. The constructor ignores the supplied config and
+// always returns a fresh, empty database.
+func init() {
+	dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (graph.Database, error) {
+		return NewDatabase(), nil
+	})
+}
+
+// Database is an in-memory graph database.
+type Database struct {
+	// mu is taken by the query types and by Close around access to the four
+	// maps below.
+	mu sync.RWMutex
+	// nextID is the source of new IDs; see newID.
+	nextID atomic.Uint64
+	nodes  map[graph.ID]*graph.Node
+	edges  map[graph.ID]*graph.Relationship
+
+	// Adjacency indexes
+	outEdges map[graph.ID][]graph.ID // nodeID -> edgeIDs where node is start
+	inEdges  map[graph.ID][]graph.ID // nodeID -> edgeIDs where node is end
+
+	schema       graph.Schema
+	defaultGraph graph.Graph
+	kinds        graph.Kinds
+
+	queryMemoryLimit size.Size
+	batchWriteSize   int
+	writeFlushSize   int
+}
+
+// NewDatabase returns an empty database with all storage maps allocated.
+func NewDatabase() *Database {
+	db := new(Database)
+	db.nodes = make(map[graph.ID]*graph.Node)
+	db.edges = make(map[graph.ID]*graph.Relationship)
+	db.outEdges = make(map[graph.ID][]graph.ID)
+	db.inEdges = make(map[graph.ID][]graph.ID)
+	return db
+}
+
+// newID atomically increments the ID counter and returns the new value, so the
+// first ID handed out is 1.
+func (db *Database) newID() graph.ID {
+	next := db.nextID.Add(1)
+	return graph.ID(next)
+}
+
+// SetWriteFlushSize stores the write flush size. Nothing in this file reads it.
+func (db *Database) SetWriteFlushSize(interval int) {
+	db.writeFlushSize = interval
+}
+
+// SetBatchWriteSize stores the batch write size. Nothing in this file reads it.
+func (db *Database) SetBatchWriteSize(interval int) {
+	db.batchWriteSize = interval
+}
+
+// ReadTransaction runs txDelegate with a transaction bound to this database and
+// returns the delegate's error. The options are ignored.
+func (db *Database) ReadTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error {
+	tx := &transaction{db: db, ctx: ctx}
+	return txDelegate(tx)
+}
+
+// WriteTransaction behaves exactly like ReadTransaction: it builds the same kind
+// of transaction and runs txDelegate with it. The options are ignored.
+func (db *Database) WriteTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error {
+	tx := &transaction{db: db, ctx: ctx}
+	return txDelegate(tx)
+}
+
+// BatchOperation runs batchDelegate with a batch bound to this database and
+// returns the delegate's error.
+func (db *Database) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error {
+	b := &batch{db: db, ctx: ctx}
+	return batchDelegate(b)
+}
+
+// AssertSchema records dbSchema and adopts its default graph. It never fails.
+// NOTE(review): these fields are written without holding mu — confirm that is
+// acceptable for concurrent callers.
+func (db *Database) AssertSchema(ctx context.Context, dbSchema graph.Schema) error {
+	db.schema, db.defaultGraph = dbSchema, dbSchema.DefaultGraph
+	return nil
+}
+
+// SetDefaultGraph replaces the default graph. It never fails.
+func (db *Database) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error {
+	db.defaultGraph = graphSchema
+	return nil
+}
+
+// Run accepts a raw query and does nothing with it.
+func (db *Database) Run(ctx context.Context, query string, parameters map[string]any) error {
+	// No-op for in-memory driver — raw queries are not supported.
+	// NOTE(review): callers get a nil error even though nothing ran; confirm
+	// that reporting success here is intended.
+	return nil
+}
+
+// Close discards all nodes, edges and adjacency data by replacing the storage
+// maps with empty ones. The ID counter is not reset.
+func (db *Database) Close(ctx context.Context) error {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	db.nodes = make(map[graph.ID]*graph.Node)
+	db.edges = make(map[graph.ID]*graph.Relationship)
+	db.outEdges = make(map[graph.ID][]graph.ID)
+	db.inEdges = make(map[graph.ID][]graph.ID)
+	return nil
+}
+
+// FetchKinds returns the kinds recorded on the database.
+// NOTE(review): nothing in this file ever assigns db.kinds — verify it is
+// populated elsewhere.
+func (db *Database) FetchKinds(ctx context.Context) (graph.Kinds, error) {
+	return db.kinds, nil
+}
+
+// RefreshKinds is a no-op.
+func (db *Database) RefreshKinds(ctx context.Context) error {
+	return nil
+}
diff --git a/drivers/sonic/sonic_test.go b/drivers/sonic/sonic_test.go
new file mode 100644
index 0000000..5b6cedd
--- /dev/null
+++ b/drivers/sonic/sonic_test.go
@@ -0,0 +1,1060 @@
+package sonic
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/specterops/dawgs/graph"
+	"github.com/specterops/dawgs/query"
+)
+
+// Node and relationship kinds shared by the tests in this file.
+var (
+	User       = graph.StringKind("User")
+	Group      = graph.StringKind("Group")
+	MemberOf   = graph.StringKind("MemberOf")
+	HasSession = graph.StringKind("HasSession")
+	AdminTo    = graph.StringKind("AdminTo")
+)
+
+// TestBasicCRUD verifies node and edge creation and querying.
+func TestBasicCRUD(t *testing.T) { + db := NewDatabase() + ctx := context.Background() + + var nodeA, nodeB *graph.Node + + // Create nodes + err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + var err error + nodeA, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User) + if err != nil { + return err + } + nodeB, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), User, Group) + if err != nil { + return err + } + _, err = tx.CreateRelationshipByIDs(nodeA.ID, nodeB.ID, MemberOf, graph.NewProperties()) + return err + }) + if err != nil { + t.Fatal(err) + } + + // Query nodes + err = db.ReadTransaction(ctx, func(tx graph.Transaction) error { + count, err := tx.Nodes().Count() + if err != nil { + return err + } + if count != 2 { + t.Errorf("expected 2 nodes, got %d", count) + } + + // Filter by kind + count, err = tx.Nodes().Filter(query.KindIn(query.Node(), Group)).Count() + if err != nil { + return err + } + if count != 1 { + t.Errorf("expected 1 Group node, got %d", count) + } + + // Query relationships + count, err = tx.Relationships().Count() + if err != nil { + return err + } + if count != 1 { + t.Errorf("expected 1 relationship, got %d", count) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestPropertyFilter verifies filtering nodes by property values. 
+func TestPropertyFilter(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), User)
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		node, err := tx.Nodes().Filter(query.Equals(query.NodeProperty("name"), "Alice")).First()
+		if err != nil {
+			return err
+		}
+		// Report a missing or non-string property instead of comparing "".
+		name, err := node.Properties.Get("name").String()
+		if err != nil {
+			return err
+		}
+		if name != "Alice" {
+			t.Errorf("expected Alice, got %s", name)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestShortestPath verifies BFS shortest path finding.
+func TestShortestPath(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// Build: A -MemberOf-> B -MemberOf-> C -MemberOf-> D
+	//        A -HasSession-> D  (direct, shorter path — but different kind)
+	var nodeIDs [4]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		names := []string{"A", "B", "C", "D"}
+		for i, name := range names {
+			n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User)
+			if err != nil {
+				return err
+			}
+			nodeIDs[i] = n.ID
+		}
+
+		// Long path: A -> B -> C -> D (all MemberOf)
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[0], nodeIDs[1], MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[1], nodeIDs[2], MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[2], nodeIDs[3], MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+
+		// Short path: A -> D (HasSession)
+		_, err := tx.CreateRelationshipByIDs(nodeIDs[0], nodeIDs[3], HasSession, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Find shortest paths from A to D with all edge kinds
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		return tx.Relationships().
+			Filter(query.InIDs(query.StartID(), nodeIDs[0])).
+			Filter(query.InIDs(query.EndID(), nodeIDs[3])).
+			FetchAllShortestPaths(func(cursor graph.Cursor[graph.Path]) error {
+				var paths []graph.Path
+				for path := range cursor.Chan() {
+					paths = append(paths, path)
+				}
+				if err := cursor.Error(); err != nil {
+					return err
+				}
+
+				if len(paths) != 1 {
+					t.Errorf("expected 1 shortest path, got %d", len(paths))
+					return nil
+				}
+
+				path := paths[0]
+				// Fatal rather than Error: the checks below index Edges[0] and
+				// would panic on an empty path.
+				if len(path.Edges) != 1 {
+					t.Fatalf("expected shortest path with 1 edge, got %d", len(path.Edges))
+				}
+				if path.Edges[0].Kind != HasSession {
+					t.Errorf("expected HasSession edge, got %s", path.Edges[0].Kind)
+				}
+				if path.Root().ID != nodeIDs[0] {
+					t.Errorf("expected root node A, got %d", path.Root().ID)
+				}
+				if path.Terminal().ID != nodeIDs[3] {
+					t.Errorf("expected terminal node D, got %d", path.Terminal().ID)
+				}
+
+				return nil
+			})
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Find shortest paths from A to D with MemberOf only — should be 3-hop path
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		return tx.Relationships().
+			Filter(query.InIDs(query.StartID(), nodeIDs[0])).
+			Filter(query.InIDs(query.EndID(), nodeIDs[3])).
+			Filter(query.KindIn(query.Relationship(), MemberOf)).
+			FetchAllShortestPaths(func(cursor graph.Cursor[graph.Path]) error {
+				var paths []graph.Path
+				for path := range cursor.Chan() {
+					paths = append(paths, path)
+				}
+				if err := cursor.Error(); err != nil {
+					return err
+				}
+
+				if len(paths) != 1 {
+					t.Errorf("expected 1 path, got %d", len(paths))
+					return nil
+				}
+
+				path := paths[0]
+				if len(path.Edges) != 3 {
+					t.Errorf("expected 3-hop path, got %d edges", len(path.Edges))
+				}
+
+				return nil
+			})
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherQuery verifies raw Cypher string execution.
+func TestCypherQuery(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// Two User nodes and one Group node.
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), User)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Charlie"}), Group)
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Test: MATCH (n:User) RETURN n
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (n:User) RETURN n", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 1 {
+				t.Errorf("expected 1 value per row, got %d", len(values))
+				// Skip the type check: values[0] may not exist.
+				continue
+			}
+			if _, ok := values[0].(*graph.Node); !ok {
+				t.Errorf("expected *graph.Node, got %T", values[0])
+			}
+		}
+		if count != 2 {
+			t.Errorf("expected 2 User nodes from cypher, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Test: MATCH (n) RETURN n LIMIT 1
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (n) RETURN n LIMIT 1", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+		}
+		if count != 1 {
+			t.Errorf("expected 1 node with LIMIT, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherScan verifies that Scan maps Cypher results to graph types.
+func TestCypherScan(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (n:User) RETURN n", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		if !result.Next() {
+			t.Fatal("expected at least one result")
+		}
+
+		var node graph.Node
+		if err := result.Scan(&node); err != nil {
+			return err
+		}
+
+		// Report a missing or non-string property instead of comparing "".
+		name, err := node.Properties.Get("name").String()
+		if err != nil {
+			return err
+		}
+		if name != "Alice" {
+			t.Errorf("expected Alice, got %s", name)
+		}
+
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestMultipleShortestPaths verifies that all equally-short paths are returned.
+func TestMultipleShortestPaths(t *testing.T) { + db := NewDatabase() + ctx := context.Background() + + // Build a diamond: A -> B -> D, A -> C -> D (both 2-hop) + var nodeIDs [4]graph.ID + + err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + names := []string{"A", "B", "C", "D"} + for i, name := range names { + n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User) + if err != nil { + return err + } + nodeIDs[i] = n.ID + } + + edges := [][2]int{{0, 1}, {0, 2}, {1, 3}, {2, 3}} + for _, e := range edges { + if _, err := tx.CreateRelationshipByIDs(nodeIDs[e[0]], nodeIDs[e[1]], MemberOf, graph.NewProperties()); err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + err = db.ReadTransaction(ctx, func(tx graph.Transaction) error { + return tx.Relationships(). + Filter(query.InIDs(query.StartID(), nodeIDs[0])). + Filter(query.InIDs(query.EndID(), nodeIDs[3])). + FetchAllShortestPaths(func(cursor graph.Cursor[graph.Path]) error { + var paths []graph.Path + for path := range cursor.Chan() { + paths = append(paths, path) + } + if err := cursor.Error(); err != nil { + return err + } + + if len(paths) != 2 { + t.Errorf("expected 2 shortest paths (diamond), got %d", len(paths)) + } + + for _, path := range paths { + if len(path.Edges) != 2 { + t.Errorf("expected 2-hop path, got %d edges", len(path.Edges)) + } + } + + return nil + }) + }) + if err != nil { + t.Fatal(err) + } +} + +// TestCypherNegatedKind tests the exact query BHE's explore page sends. 
+func TestCypherNegatedKind(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// One ordinary User node and one MigrationData node that the query must exclude.
+	seed := func(tx graph.Transaction) error {
+		if _, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User); err != nil {
+			return err
+		}
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "migration"}), graph.StringKind("MigrationData"))
+		return err
+	}
+	if err := db.WriteTransaction(ctx, seed); err != nil {
+		t.Fatal(err)
+	}
+
+	verify := func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (A) WHERE NOT A:MigrationData RETURN A LIMIT 10", nil)
+		if err := result.Error(); err != nil {
+			return err
+		}
+		defer result.Close()
+
+		rows := 0
+		for result.Next() {
+			rows++
+
+			// Rows whose first value is not a node are counted but not inspected.
+			node, isNode := result.Values()[0].(*graph.Node)
+			if !isNode {
+				continue
+			}
+			name, _ := node.Properties.Get("name").String()
+			if name == "migration" {
+				t.Error("MigrationData node should have been excluded")
+			}
+		}
+		if rows != 1 {
+			t.Errorf("expected 1 non-MigrationData node, got %d", rows)
+		}
+		return result.Error()
+	}
+	if err := db.ReadTransaction(ctx, verify); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestConcurrentReadWrite verifies no deadlock when readers and writers run simultaneously.
+func TestConcurrentReadWrite(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		for i := 0; i < 10; i++ {
+			if _, err := tx.CreateNode(graph.AsProperties(map[string]any{"i": i}), User); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// The writer reports its outcome over a buffered channel: the test must not
+	// be failed with t.Fatal from a goroutine other than its own, and the buffer
+	// lets the writer finish even if nothing ever receives from the channel.
+	writeErr := make(chan error, 1)
+	go func() {
+		writeErr <- db.BatchOperation(ctx, func(batch graph.Batch) error {
+			for i := 10; i < 20; i++ {
+				if err := batch.CreateNode(&graph.Node{
+					Kinds:      graph.Kinds{User},
+					Properties: graph.AsProperties(map[string]any{"i": i}),
+				}); err != nil {
+					return err
+				}
+			}
+			return nil
+		})
+	}()
+
+	readErr := db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.Nodes().Count()
+		return err
+	})
+
+	// Wait for the writer before reporting anything so the goroutine never
+	// outlives the test, then surface both outcomes.
+	if err := <-writeErr; err != nil {
+		t.Errorf("concurrent batch write: %v", err)
+	}
+	if readErr != nil {
+		t.Errorf("concurrent read: %v", readErr)
+	}
+}
+
+// TestCypherRelationshipQuery verifies: MATCH (s)-[r:MemberOf]->(e) RETURN s, r, e
+func TestCypherRelationshipQuery(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		a, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		b, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), Group)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateRelationshipByIDs(a.ID, b.ID, MemberOf, graph.NewProperties())
+		if err != nil {
+			return err
+		}
+		// Add a HasSession edge that should NOT appear
+		_, err = tx.CreateRelationshipByIDs(a.ID, b.ID, HasSession, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (s)-[r:MemberOf]->(e) RETURN s, r, e", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 3 {
+				t.Errorf("expected 3 values per row, got %d", len(values))
+				continue
+			}
+			if _, ok := values[0].(*graph.Node); !ok {
+				t.Errorf("expected *graph.Node for s, got %T", values[0])
+			}
+			if rel, ok := values[1].(*graph.Relationship); ok {
+				if rel.Kind != MemberOf {
+					t.Errorf("expected MemberOf relationship, got %s", rel.Kind)
+				}
+			} else {
+				t.Errorf("expected *graph.Relationship for r, got %T", values[1])
+			}
+			if _, ok := values[2].(*graph.Node); !ok {
+				t.Errorf("expected *graph.Node for e, got %T", values[2])
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 MemberOf relationship, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherInboundRelationship verifies: MATCH (a)<-[r]-(b) RETURN a, r, b
+func TestCypherInboundRelationship(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	var aliceID graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		alice, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		aliceID = alice.ID
+		bob, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), User)
+		if err != nil {
+			return err
+		}
+		// Bob -MemberOf-> Alice
+		_, err = tx.CreateRelationshipByIDs(bob.ID, alice.ID, MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (a:User)<-[r:MemberOf]-(b) RETURN a, r, b", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 3 {
+				t.Errorf("expected 3 values per row, got %d", len(values))
+				continue
+			}
+			// Report a wrong type instead of skipping the row, otherwise the
+			// direction check below could pass without ever running.
+			n, ok := values[0].(*graph.Node)
+			if !ok {
+				t.Errorf("expected *graph.Node for a, got %T", values[0])
+				continue
+			}
+			if n.ID != aliceID {
+				t.Errorf("expected Alice as 'a', got node %d", n.ID)
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 inbound relationship, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherWithMultiPart verifies: MATCH (n) WITH n MATCH (n)-[r]->(m) RETURN n, r, m
+func TestCypherWithMultiPart(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		a, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		b, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), Group)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateRelationshipByIDs(a.ID, b.ID, MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (n:User) WITH n MATCH (n)-[r]->(m) RETURN n, r, m", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 3 {
+				t.Errorf("expected 3 values per row, got %d", len(values))
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 result from multi-part query, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherPropertyFilter verifies: MATCH (n:User) WHERE n.name = 'Alice' RETURN n
+func TestCypherPropertyFilter(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), User)
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH (n:User) WHERE n.name = 'Alice' RETURN n", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 1 {
+				t.Errorf("expected 1 value per row, got %d", len(values))
+				continue
+			}
+			n, ok := values[0].(*graph.Node)
+			if !ok {
+				t.Errorf("expected *graph.Node, got %T", values[0])
+				continue
+			}
+			name, err := n.Properties.Get("name").String()
+			if err != nil {
+				t.Errorf("reading name of node %d: %v", n.ID, err)
+				continue
+			}
+			if name != "Alice" {
+				t.Errorf("expected Alice, got %s", name)
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 result, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherUnsupportedReturnsError verifies that unsupported constructs return errors.
+func TestCypherUnsupportedReturnsError(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		_, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		// CREATE should return an error
+		result := tx.Query("CREATE (n:User {name: 'Charlie'})", nil)
+		// Close the result on every path, including the failure case where the
+		// driver unexpectedly hands back a live result set.
+		defer result.Close()
+		if result.Error() == nil {
+			t.Error("expected error for CREATE, got nil")
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherShortestPathViaCypher verifies allShortestPaths via Cypher string.
+func TestCypherShortestPathViaCypher(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	var nodeIDs [3]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		names := []string{"A", "B", "C"}
+		for i, name := range names {
+			n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User)
+			if err != nil {
+				return err
+			}
+			nodeIDs[i] = n.ID
+		}
+		// A -> B -> C
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[0], nodeIDs[1], MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+		_, err := tx.CreateRelationshipByIDs(nodeIDs[1], nodeIDs[2], MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		q := fmt.Sprintf(
+			"MATCH p = allShortestPaths((s)-[*]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p",
+			nodeIDs[0], nodeIDs[2],
+		)
+		result := tx.Query(q, nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			// Guard the index below so a malformed row fails the test instead
+			// of panicking it.
+			if len(values) != 1 {
+				t.Errorf("expected 1 value per row, got %d", len(values))
+				continue
+			}
+			if p, ok := values[0].(*graph.Path); ok {
+				if len(p.Edges) != 2 {
+					t.Errorf("expected 2-hop path, got %d edges", len(p.Edges))
+				}
+			} else {
+				t.Errorf("expected *graph.Path, got %T", values[0])
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 shortest path, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherVariableLengthPath verifies: MATCH (a)-[r*1..2]->(b) RETURN a, r, b
+func TestCypherVariableLengthPath(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// Build chain: A -MemberOf-> B -MemberOf-> C -MemberOf-> D
+	var nodeIDs [4]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		names := []string{"A", "B", "C", "D"}
+		for i, name := range names {
+			n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User)
+			if err != nil {
+				return err
+			}
+			nodeIDs[i] = n.ID
+		}
+		for i := 0; i < 3; i++ {
+			if _, err := tx.CreateRelationshipByIDs(nodeIDs[i], nodeIDs[i+1], MemberOf, graph.NewProperties()); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// MATCH (a)-[r*1..2]->(b) WHERE id(a) = A_ID RETURN a, r, b
+	// Should find: A->B (1 hop), A->B->C (2 hops)
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query(
+			fmt.Sprintf("MATCH (a)-[r*1..2]->(b) WHERE id(a) = %d RETURN a, r, b", nodeIDs[0]),
+			nil,
+		)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 3 {
+				t.Errorf("expected 3 values, got %d", len(values))
+				continue
+			}
+			// r should be a slice of relationships
+			if rels, ok := values[1].([]*graph.Relationship); ok {
+				if len(rels) < 1 || len(rels) > 2 {
+					t.Errorf("expected 1 or 2 relationships, got %d", len(rels))
+				}
+			} else {
+				t.Errorf("expected []*graph.Relationship, got %T", values[1])
+			}
+		}
+		if count != 2 {
+			t.Errorf("expected 2 paths (1-hop and 2-hop), got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherVariableLengthUnbounded verifies: MATCH (a)-[*]->(b) with no upper bound
+func TestCypherVariableLengthUnbounded(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// Build chain: A -> B -> C
+	var nodeIDs [3]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		names := []string{"A", "B", "C"}
+		for i, name := range names {
+			n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User)
+			if err != nil {
+				return err
+			}
+			nodeIDs[i] = n.ID
+		}
+		for i := 0; i < 2; i++ {
+			if _, err := tx.CreateRelationshipByIDs(nodeIDs[i], nodeIDs[i+1], MemberOf, graph.NewProperties()); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// MATCH (a)-[*]->(b) WHERE id(a) = A_ID RETURN b
+	// Should find: A->B (1 hop), A->B->C (2 hops)
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query(
+			fmt.Sprintf("MATCH (a)-[*]->(b) WHERE id(a) = %d RETURN b", nodeIDs[0]),
+			nil,
+		)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+		}
+		if count != 2 {
+			t.Errorf("expected 2 reachable nodes, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherVariableLengthNoCycles verifies that variable-length paths don't loop.
+func TestCypherVariableLengthNoCycles(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// Build cycle: A -> B -> A
+	var nodeIDs [2]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		a, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "A"}), User)
+		if err != nil {
+			return err
+		}
+		b, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "B"}), User)
+		if err != nil {
+			return err
+		}
+		nodeIDs[0] = a.ID
+		nodeIDs[1] = b.ID
+		if _, err := tx.CreateRelationshipByIDs(a.ID, b.ID, MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+		_, err = tx.CreateRelationshipByIDs(b.ID, a.ID, MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// A->B is depth 1. B->A would revisit the start node, so it's blocked by cycle detection.
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query(
+			fmt.Sprintf("MATCH (a)-[*1..10]->(b) WHERE id(a) = %d RETURN b", nodeIDs[0]),
+			nil,
+		)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+		}
+		if count != 1 {
+			t.Errorf("expected 1 path (cycle should be prevented), got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherVariableLengthWithKind verifies kind filtering on variable-length edges.
+func TestCypherVariableLengthWithKind(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	// A -MemberOf-> B -HasSession-> C -MemberOf-> D
+	var nodeIDs [4]graph.ID
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		names := []string{"A", "B", "C", "D"}
+		for i, name := range names {
+			n, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": name}), User)
+			if err != nil {
+				return err
+			}
+			nodeIDs[i] = n.ID
+		}
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[0], nodeIDs[1], MemberOf, graph.NewProperties()); err != nil {
+			return err
+		}
+		if _, err := tx.CreateRelationshipByIDs(nodeIDs[1], nodeIDs[2], HasSession, graph.NewProperties()); err != nil {
+			return err
+		}
+		_, err := tx.CreateRelationshipByIDs(nodeIDs[2], nodeIDs[3], MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Only follow MemberOf edges — should reach B but NOT C or D
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query(
+			fmt.Sprintf("MATCH (a)-[:MemberOf*1..5]->(b) WHERE id(a) = %d RETURN b", nodeIDs[0]),
+			nil,
+		)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 1 {
+				t.Errorf("expected 1 value per row, got %d", len(values))
+				continue
+			}
+			// Report a wrong type instead of skipping the row, otherwise the
+			// reachability check below could pass without ever running.
+			n, ok := values[0].(*graph.Node)
+			if !ok {
+				t.Errorf("expected *graph.Node, got %T", values[0])
+				continue
+			}
+			if n.ID == nodeIDs[2] || n.ID == nodeIDs[3] {
+				t.Errorf("should not reach node %d through MemberOf-only path", n.ID)
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 reachable node via MemberOf, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestCypherAnonymousNodes verifies: MATCH ()-[r]->() RETURN r
+func TestCypherAnonymousNodes(t *testing.T) {
+	db := NewDatabase()
+	ctx := context.Background()
+
+	err := db.WriteTransaction(ctx, func(tx graph.Transaction) error {
+		a, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Alice"}), User)
+		if err != nil {
+			return err
+		}
+		b, err := tx.CreateNode(graph.AsProperties(map[string]any{"name": "Bob"}), Group)
+		if err != nil {
+			return err
+		}
+		_, err = tx.CreateRelationshipByIDs(a.ID, b.ID, MemberOf, graph.NewProperties())
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.ReadTransaction(ctx, func(tx graph.Transaction) error {
+		result := tx.Query("MATCH ()-[r]->() RETURN r LIMIT 5", nil)
+		if result.Error() != nil {
+			return result.Error()
+		}
+		defer result.Close()
+
+		count := 0
+		for result.Next() {
+			count++
+			values := result.Values()
+			if len(values) != 1 {
+				t.Errorf("expected 1 value per row, got %d", len(values))
+				continue
+			}
+			if _, ok := values[0].(*graph.Relationship); !ok {
+				t.Errorf("expected *graph.Relationship, got %T", values[0])
+			}
+		}
+		if count != 1 {
+			t.Errorf("expected 1 relationship, got %d", count)
+		}
+		return result.Error()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/drivers/sonic/transaction.go b/drivers/sonic/transaction.go
new file mode 100644
index 0000000..9c0a6f9
--- /dev/null
+++ b/drivers/sonic/transaction.go
@@ -0,0 +1,164 @@
+package sonic
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+
+	"github.com/specterops/dawgs/cypher/frontend"
+	"github.com/specterops/dawgs/graph"
+	"github.com/specterops/dawgs/util/size"
+)
+
+// transaction performs node and relationship operations directly against the
+// maps of its Database. Each mutating method takes the database mutex for its
+// own change and nothing is buffered, so Commit has no work to do.
+type transaction struct {
+	db  *Database
+	ctx context.Context
+}
+
+// WithGraph returns the receiver unchanged; graphSchema is ignored.
+func (tx *transaction) WithGraph(graphSchema graph.Graph) graph.Transaction {
+	return tx
+}
+
+// CreateNode stores a new node with a freshly allocated ID and the given kinds
+// and properties, then returns it. Nil properties are replaced with an empty
+// set so the stored node never carries a nil Properties pointer.
+func (tx *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) {
+	if properties == nil {
+		properties = graph.NewProperties()
+	}
+
+	node := &graph.Node{
+		ID:         tx.db.newID(),
+		Kinds:      kinds,
+		Properties: properties,
+	}
+
+	tx.db.mu.Lock()
+	defer tx.db.mu.Unlock()
+
+	tx.db.nodes[node.ID] = node
+	return node, nil
+}
+
+// UpdateNode replaces the stored node that shares node.ID with node. It returns
+// an error if node is nil or if no node with that ID exists.
+func (tx *transaction) UpdateNode(node *graph.Node) error {
+	if node == nil {
+		return fmt.Errorf("cannot update a nil node")
+	}
+
+	tx.db.mu.Lock()
+	defer tx.db.mu.Unlock()
+
+	if _, ok := tx.db.nodes[node.ID]; !ok {
+		return fmt.Errorf("node %d not found", node.ID)
+	}
+	tx.db.nodes[node.ID] = node
+	return nil
+}
+
+// Nodes returns a node query bound to this transaction's database.
+func (tx *transaction) Nodes() graph.NodeQuery {
+	return &nodeQuery{db: tx.db}
+}
+
+// CreateRelationshipByIDs stores a new relationship of the given kind from
+// startNodeID to endNodeID and appends its ID to the outEdges and inEdges
+// adjacency indexes. It returns an error if either endpoint node is missing.
+// Nil properties are replaced with an empty set.
+func (tx *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) {
+	if properties == nil {
+		properties = graph.NewProperties()
+	}
+
+	tx.db.mu.Lock()
+	defer tx.db.mu.Unlock()
+
+	if _, ok := tx.db.nodes[startNodeID]; !ok {
+		return nil, fmt.Errorf("start node %d not found", startNodeID)
+	}
+	if _, ok := tx.db.nodes[endNodeID]; !ok {
+		return nil, fmt.Errorf("end node %d not found", endNodeID)
+	}
+
+	// The ID is allocated only after both endpoints are validated, so a
+	// rejected call does not consume one.
+	id := tx.db.newID()
+	rel := &graph.Relationship{
+		ID:         id,
+		StartID:    startNodeID,
+		EndID:      endNodeID,
+		Kind:       kind,
+		Properties: properties,
+	}
+
+	tx.db.edges[id] = rel
+	tx.db.outEdges[startNodeID] = append(tx.db.outEdges[startNodeID], id)
+	tx.db.inEdges[endNodeID] = append(tx.db.inEdges[endNodeID], id)
+	return rel, nil
+}
+
+// UpdateRelationship replaces the stored relationship that shares
+// relationship.ID with relationship. It returns an error if relationship is nil
+// or if no relationship with that ID exists.
+func (tx *transaction) UpdateRelationship(relationship *graph.Relationship) error {
+	if relationship == nil {
+		return fmt.Errorf("cannot update a nil relationship")
+	}
+
+	tx.db.mu.Lock()
+	defer tx.db.mu.Unlock()
+
+	if _, ok := tx.db.edges[relationship.ID]; !ok {
+		return fmt.Errorf("relationship %d not found", relationship.ID)
+	}
+	tx.db.edges[relationship.ID] = relationship
+	return nil
+}
+
+// Relationships returns a relationship query bound to this transaction's database.
+func (tx *transaction) Relationships() graph.RelationshipQuery {
+	return &relQuery{db: tx.db}
+}
+
+// Raw forwards to Query. Sonic has no SQL layer, so the raw query text is
+// interpreted as Cypher.
+func (tx *transaction) Raw(query string, parameters map[string]any) graph.Result {
+	return tx.Query(query, parameters)
+}
+
+// Query parses queryStr as Cypher, executes it against the database with the
+// given parameters, and returns the result. Parse and execution failures are
+// logged and returned as an error result.
+func (tx *transaction) Query(queryStr string, parameters map[string]any) graph.Result {
+	parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), queryStr)
+	if err != nil {
+		slog.Error("sonic: failed to parse cypher", slog.String("query", queryStr), slog.String("error", err.Error()))
+		return graph.NewErrorResult(fmt.Errorf("sonic: failed to parse cypher: %w", err))
+	}
+
+	result, err := tx.db.executeCypher(parsedQuery, parameters)
+	if err != nil {
+		slog.Error("sonic: cypher execution failed", slog.String("query", queryStr), slog.String("error", err.Error()))
+		return graph.NewErrorResult(err)
+	}
+
+	// Per-query success tracing is logged at Debug so that routine queries do
+	// not flood Info-level logs.
+	slog.Debug("sonic: cypher query executed", slog.String("query", queryStr), slog.Int("rows", len(result.rows)), slog.Any("keys", result.keys))
+	return result
+}
+
+// Commit always returns nil; there is nothing buffered to apply.
+func (tx *transaction) Commit() error {
+	return nil
+}
+
+// GraphQueryMemoryLimit returns the database's configured queryMemoryLimit.
+func (tx *transaction) GraphQueryMemoryLimit() size.Size {
+	return tx.db.queryMemoryLimit
+}