Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 72 additions & 71 deletions pkg/sql/colexec/table_function/ivf_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
package table_function

import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"

"github.com/bytedance/sonic"
"github.com/matrixorigin/matrixone/pkg/catalog"
"github.com/matrixorigin/matrixone/pkg/common/moerr"
"github.com/matrixorigin/matrixone/pkg/container/batch"
Expand All @@ -35,7 +34,6 @@ import (
"github.com/matrixorigin/matrixone/pkg/vectorindex/sqlexec"
"github.com/matrixorigin/matrixone/pkg/vm"
"github.com/matrixorigin/matrixone/pkg/vm/process"
"golang.org/x/exp/rand"
)

const (
Expand All @@ -61,7 +59,6 @@ type ivfCreateState struct {
idxcfg vectorindex.IndexConfig
data32 [][]float32
data64 [][]float64
rand *rand.Rand
nsample uint
sample_ratio float64
offset int
Expand All @@ -85,7 +82,7 @@ func clustering[T types.RealNumbers](u *ivfCreateState, tf *TableFunction, proc

// NOTE: We use L2 distance to calculate the centroids. The Ivfflat metric is used only for searching.
var centers [][]T
if clusterer, err = elkans.NewKMeans[T](
if clusterer, err = elkans.NewKMeans(
data, int(u.idxcfg.Ivfflat.Lists),
int(u.tblcfg.KmeansMaxIteration),
defaultKmeansDeltaThreshold,
Expand All @@ -95,6 +92,7 @@ func clustering[T types.RealNumbers](u *ivfCreateState, tf *TableFunction, proc
int(nworker)); err != nil {
return err
}

anycenters, err := clusterer.Cluster(proc.Ctx)
if err != nil {
return err
Expand All @@ -108,7 +106,7 @@ func clustering[T types.RealNumbers](u *ivfCreateState, tf *TableFunction, proc
// insert into centroid table
values := make([]string, 0, len(centers))
for i, c := range centers {
s := types.ArrayToString[T](c)
s := types.ArrayToString(c)
values = append(values, fmt.Sprintf("(%d, %d, '%s')", version, i, s))
}

Expand Down Expand Up @@ -143,9 +141,9 @@ func (u *ivfCreateState) end(tf *TableFunction, proc *process.Process) error {
}

if u.data32 != nil {
return clustering[float32](u, tf, proc, u.data32)
return clustering(u, tf, proc, u.data32)
} else {
return clustering[float64](u, tf, proc, u.data64)
return clustering(u, tf, proc, u.data64)
}
}

Expand Down Expand Up @@ -194,7 +192,7 @@ func (u *ivfCreateState) start(tf *TableFunction, proc *process.Process, nthRow

if !u.inited {
if len(tf.Params) > 0 {
err = json.Unmarshal([]byte(tf.Params), &u.param)
err = sonic.Unmarshal([]byte(tf.Params), &u.param)
if err != nil {
return err
}
Expand Down Expand Up @@ -232,27 +230,11 @@ func (u *ivfCreateState) start(tf *TableFunction, proc *process.Process, nthRow
if len(cfgstr) == 0 {
return moerr.NewInternalError(proc.Ctx, "IndexTableConfig is empty")
}
err := json.Unmarshal([]byte(cfgstr), &u.tblcfg)
err := sonic.Unmarshal([]byte(cfgstr), &u.tblcfg)
if err != nil {
return err
}

// support both vecf32 and vecf64
f32aVec := tf.ctr.argVecs[1]
supported := false
for _, t := range ClusterCentersSupportTypes {
if f32aVec.GetType().Oid == t {
supported = true
break
}
}
if !supported {
return moerr.NewInvalidInput(proc.Ctx, "Second argument (vector must be a vecf32 or vecf64 type")
}
dimension := f32aVec.GetType().Width

// dimension
u.idxcfg.Ivfflat.Dimensions = uint(dimension)
u.idxcfg.Type = "ivfflat"

u.nsample = u.idxcfg.Ivfflat.Lists * 50
Expand All @@ -268,12 +250,6 @@ func (u *ivfCreateState) start(tf *TableFunction, proc *process.Process, nthRow
u.nsample = min_nsample
}

if f32aVec.GetType().Oid == types.T_array_float32 {
u.data32 = make([][]float32, 0, u.nsample)
} else {
u.data64 = make([][]float64, 0, u.nsample)
}

u.sample_ratio = train_percent
if u.tblcfg.DataSize > 0 {
u.sample_ratio = float64(u.nsample) / float64(u.tblcfg.DataSize)
Expand All @@ -282,51 +258,76 @@ func (u *ivfCreateState) start(tf *TableFunction, proc *process.Process, nthRow
}
}

u.rand = rand.New(rand.NewSource(uint64(time.Now().UnixMicro())))
u.batch = tf.createResultBatch()
u.inited = true
//os.Stderr.WriteString(fmt.Sprintf("nsample %d, train_percent %f, iter %d\n", u.nsample, train_percent, u.tblcfg.KmeansMaxIteration))
}

// reset slice
u.offset = 0
// run SQL
sql := fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE `%s` IS NOT NULL AND RAND() < %f LIMIT %d",
u.tblcfg.KeyPart,
u.tblcfg.DbName,
u.tblcfg.SrcTable,
u.tblcfg.KeyPart,
u.sample_ratio,
u.nsample)

// cleanup the batch
u.batch.CleanOnlyData()

datasz := 0
if u.data32 != nil {
datasz = len(u.data32)
} else {
datasz = len(u.data64)
}
if uint(datasz) >= u.nsample {
// enough sample data
return nil
}
res, err := ivf_runSql(sqlexec.NewSqlProcess(proc), sql)
if err != nil {
return err
}
defer res.Close()

fpaVec := tf.ctr.argVecs[1]
if fpaVec.IsNull(uint64(nthRow)) {
return nil
}
if len(res.Batches) == 0 {
return nil
}

if u.sample_ratio < u.rand.Float64() {
// skip the sample
return nil
}
embedvec := res.Batches[0].Vecs[0]
supported := false
for _, t := range ClusterCentersSupportTypes {
if embedvec.GetType().Oid == t {
supported = true
break
}
}
if !supported {
return moerr.NewInvalidInput(proc.Ctx, "Second argument (vector must be a vecf32 or vecf64 type")
}

if fpaVec.GetType().Oid == types.T_array_float32 {
f32a := types.BytesToArray[float32](fpaVec.GetBytesAt(nthRow))
if uint(len(f32a)) != u.idxcfg.Ivfflat.Dimensions {
return moerr.NewInternalError(proc.Ctx, "vector dimension mismatch")
if embedvec.GetType().Oid == types.T_array_float32 {
u.data32 = make([][]float32, 0, u.nsample)
} else {
u.data64 = make([][]float64, 0, u.nsample)
}
u.data32 = append(u.data32, append(make([]float32, 0, len(f32a)), f32a...))
} else {
f64a := types.BytesToArray[float64](fpaVec.GetBytesAt(nthRow))
if uint(len(f64a)) != u.idxcfg.Ivfflat.Dimensions {
return moerr.NewInternalError(proc.Ctx, "vector dimension mismatch")

// dimension
dimension := embedvec.GetType().Width
u.idxcfg.Ivfflat.Dimensions = uint(dimension)
//elemsz := res.Batches[0].Vecs[0].GetType().GetArrayElementSize()

for _, bat := range res.Batches {
evec := bat.Vecs[0]
for i := 0; i < bat.RowCount(); i++ {
switch evec.GetType().Oid {
case types.T_array_float32:
f32a := types.BytesToArray[float32](evec.GetBytesAt(i))
if uint(len(f32a)) != u.idxcfg.Ivfflat.Dimensions {
return moerr.NewInternalError(proc.Ctx, "vector dimension mismatch")
}
u.data32 = append(u.data32, append(make([]float32, 0, len(f32a)), f32a...))
case types.T_array_float64:
f64a := types.BytesToArray[float64](evec.GetBytesAt(i))
if uint(len(f64a)) != u.idxcfg.Ivfflat.Dimensions {
return moerr.NewInternalError(proc.Ctx, "vector dimension mismatch")
}
u.data64 = append(u.data64, append(make([]float64, 0, len(f64a)), f64a...))
}
}
}
u.data64 = append(u.data64, append(make([]float64, 0, len(f64a)), f64a...))

// reset slice
u.offset = 0

u.batch = tf.createResultBatch()
u.inited = true
//os.Stderr.WriteString(fmt.Sprintf("nsample %d, train_percent %f, iter %d\n", u.nsample, train_percent, u.tblcfg.KmeansMaxIteration))
// cleanup the batch
u.batch.CleanOnlyData()
}

return nil
Expand Down
Loading
Loading